This commit is contained in:
Michael Yang
2024-06-20 11:00:08 -07:00
parent 269ed6e6a2
commit 2c3fe1fd97
5 changed files with 224 additions and 113 deletions

View File

@@ -11,8 +11,13 @@ import (
"github.com/ollama/ollama/template"
)
func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// extract system messages which should always be included
type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// pull out any system messages which should always be included in the prompt
var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
@@ -23,32 +28,35 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
return false
})
if len(system) == 0 && r.model.System != "" {
if len(system) == 0 && m.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: r.model.System})
system = append(system, api.Message{Role: "system", Content: m.System})
}
// always include the last message
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
var b bytes.Buffer
if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
return "", nil, err
}
s, err := r.llama.Tokenize(ctx, b.String())
s, err := tokenize(ctx, b.String())
if err != nil {
return "", nil, err
}
c := len(s)
if r.model.ProjectorPaths != nil {
if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// TODO: get image embedding length from project metadata
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
c += 768 * len(m.Images)
}
}
if c > r.NumCtx {
if c > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
@@ -56,8 +64,9 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
}
}
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
return "", nil, err
}