This commit is contained in:
Michael Yang
2024-06-20 11:00:08 -07:00
parent 269ed6e6a2
commit 2c3fe1fd97
5 changed files with 224 additions and 113 deletions

View File

@@ -11,8 +11,13 @@ import (
"github.com/ollama/ollama/template"
)
func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// extract system messages which should always be included
type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// pull out any system messages which should always be included in the prompt
var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
@@ -23,32 +28,35 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
return false
})
if len(system) == 0 && r.model.System != "" {
if len(system) == 0 && m.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: r.model.System})
system = append(system, api.Message{Role: "system", Content: m.System})
}
// always include the last message
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
var b bytes.Buffer
if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
return "", nil, err
}
s, err := r.llama.Tokenize(ctx, b.String())
s, err := tokenize(ctx, b.String())
if err != nil {
return "", nil, err
}
c := len(s)
if r.model.ProjectorPaths != nil {
if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// TODO: get image embedding length from project metadata
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
c += 768 * len(m.Images)
}
}
if c > r.NumCtx {
if c > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
@@ -56,8 +64,9 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
}
}
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
return "", nil, err
}

View File

@@ -7,15 +7,10 @@ import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
)
type mock struct {
llm.LlamaServer
}
func (m mock) Tokenize(_ context.Context, s string) (tokens []int, err error) {
func tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
@@ -48,7 +43,7 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "truncate messages",
name: "truncate messages",
limit: 1,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -60,7 +55,7 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "truncate messages with image",
name: "truncate messages with image",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -75,7 +70,7 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "truncate messages with images",
name: "truncate messages with images",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -90,7 +85,7 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "messages with images",
name: "messages with images",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -106,7 +101,7 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "message with image tag",
name: "message with image tag",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
@@ -122,7 +117,7 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "messages with interleaved images",
name: "messages with interleaved images",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -140,7 +135,7 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "truncate message with interleaved images",
name: "truncate message with interleaved images",
limit: 1024,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -157,7 +152,7 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "message with system prompt",
name: "message with system prompt",
limit: 2048,
msgs: []api.Message{
{Role: "system", Content: "You are the Test Who Lived."},
@@ -181,14 +176,9 @@ func TestChatPrompt(t *testing.T) {
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
r := runnerRef{
llama: mock{},
model: &Model{Template: tmpl, ProjectorPaths: []string{"vision"}},
Options: &api.Options{},
}
r.NumCtx = tt.limit
prompt, images, err := chatPrompt(context.TODO(), &r, tt.msgs)
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
if err != nil {
t.Fatal(err)
}

View File

@@ -54,6 +54,8 @@ func init() {
gin.SetMode(mode)
}
var errRequired = errors.New("is required")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -69,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
if name == "" {
return nil, errors.New("model is required")
return nil, fmt.Errorf("model %w", errRequired)
}
model, err := GetModel(name)
@@ -121,7 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
} else if err != nil {
handleScheduleError(c, err)
handleScheduleError(c, req.Model, err)
return
}
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
@@ -139,23 +151,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
}
if req.Prompt != "" {
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
if len(msgs) == 0 {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
tmpl := r.model.Template
if req.Template != "" {
@@ -256,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, err)
handleScheduleError(c, req.Model, err)
return
}
@@ -1135,7 +1135,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return
} else if err != nil {
handleScheduleError(c, err)
handleScheduleError(c, req.Model, err)
return
}
@@ -1150,7 +1150,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages)
prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1215,12 +1215,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
streamResponse(c, ch)
}
func handleScheduleError(c *gin.Context, err error) {
func handleScheduleError(c *gin.Context, name string, err error) {
switch {
case errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"})
case errors.Is(err, ErrMaxQueue):
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
case errors.Is(err, os.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}