Revert "chat api (#991)" while context variable is fixed

This reverts commit 7a0899d62d.
This commit is contained in:
Jeffrey Morgan
2023-12-04 21:16:27 -08:00
parent f1ef3f9947
commit 00d06619a1
8 changed files with 144 additions and 559 deletions

View File

@@ -47,82 +47,37 @@ type Model struct {
Options map[string]interface{}
}
type PromptVars struct {
System string
Prompt string
Response string
}
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
t := m.Template
if request.Template != "" {
t = request.Template
}
func (m *Model) Prompt(p PromptVars) (string, error) {
var prompt strings.Builder
tmpl, err := template.New("").Parse(m.Template)
tmpl, err := template.New("").Parse(t)
if err != nil {
return "", err
}
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
var vars struct {
First bool
System string
Prompt string
}
vars.First = len(request.Context) == 0
vars.System = m.System
vars.Prompt = request.Prompt
if request.System != "" {
vars.System = request.System
}
var sb strings.Builder
if err := tmpl.Execute(&sb, p); err != nil {
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
prompt.WriteString(sb.String())
prompt.WriteString(p.Response)
return prompt.String(), nil
}
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
// build the prompt from the list of messages
var prompt strings.Builder
currentVars := PromptVars{}
writePrompt := func() error {
p, err := m.Prompt(currentVars)
if err != nil {
return err
}
prompt.WriteString(p)
currentVars = PromptVars{}
return nil
}
for _, msg := range msgs {
switch msg.Role {
case "system":
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.System = msg.Content
case "user":
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.Prompt = msg.Content
case "assistant":
currentVars.Response = msg.Content
if err := writePrompt(); err != nil {
return "", err
}
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
return prompt.String(), nil
return sb.String(), nil
}
type ManifestV2 struct {