preserve last system message from modelfile (#2289)

This commit is contained in:
Bruce MacDonald
2024-01-31 21:45:01 -05:00
committed by GitHub
parent 583950c828
commit a896079705
2 changed files with 66 additions and 17 deletions

View File

@@ -256,15 +256,17 @@ func chatHistoryEqual(a, b ChatHistory) bool {
func TestChat(t *testing.T) {
tests := []struct {
name string
template string
msgs []api.Message
want ChatHistory
wantErr string
name string
model Model
msgs []api.Message
want ChatHistory
wantErr string
}{
{
name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
name: "Single Message",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{
{
Role: "system",
@@ -287,8 +289,10 @@ func TestChat(t *testing.T) {
},
},
{
name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
name: "Message History",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{
{
Role: "system",
@@ -323,8 +327,10 @@ func TestChat(t *testing.T) {
},
},
{
name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
name: "Assistant Only",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{
{
Role: "assistant",
@@ -340,6 +346,51 @@ func TestChat(t *testing.T) {
},
},
},
{
name: "Last system message is preserved from modelfile",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "user",
Content: "hi",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Mojo Jojo.",
Prompt: "hi",
First: true,
},
},
LastSystem: "You are Mojo Jojo.",
},
},
{
name: "Last system message is preserved from messages",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "system",
Content: "You are Professor Utonium.",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Professor Utonium.",
First: true,
},
},
LastSystem: "You are Professor Utonium.",
},
},
{
name: "Invalid Role",
msgs: []api.Message{
@@ -353,11 +404,8 @@ func TestChat(t *testing.T) {
}
for _, tt := range tests {
m := Model{
Template: tt.template,
}
t.Run(tt.name, func(t *testing.T) {
got, err := m.ChatPrompts(tt.msgs)
got, err := tt.model.ChatPrompts(tt.msgs)
if tt.wantErr != "" {
if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")