runner.go: Better abstract vision model integration

-Update mllama to take the cross attention state as embeddings in
a batch, more similar to how Llava handles it. This improves
integration with the input cache.
-Pass locations in a prompt for embeddings using tags similar to Llava.
-Abstract interface to vision models so the main runner accesses Clip
and Mllama similarly

Co-authored-by: Michael Yang <mxyng@pm.me>
This commit is contained in:
Jesse Gross
2024-10-11 15:34:01 -07:00
committed by Jesse Gross
parent 712e99d477
commit c826e57475
13 changed files with 534 additions and 454 deletions

View File

@@ -75,11 +75,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
currMsgIdx := n
if isMllama {
lastMsgIdx := len(msgs) - 1
for i := lastMsgIdx; i >= currMsgIdx; i-- {
if len(msgs[i].Images) > 0 {
data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
for cnt, msg := range msgs[currMsgIdx:] {
prefix := ""
imgPrompt := ""
prompt := msg.Content
for _, i := range msg.Images {
var imgData llm.ImageData
if isMllama {
data, aspectRatioID, err := imageproc.Preprocess(i)
if err != nil {
return "", nil, err
}
@@ -90,37 +95,30 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
return "", nil, err
}
imgData := llm.ImageData{
imgData = llm.ImageData{
ID: len(images),
Data: buf.Bytes(),
AspectRatioID: aspectRatioID,
}
msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
images = append(images, imgData)
break
}
}
} else {
for cnt, msg := range msgs[currMsgIdx:] {
prefix := ""
prompt := msg.Content
for _, i := range msg.Images {
imgData := llm.ImageData{
imgPrompt = "<|image|>"
} else {
imgData = llm.ImageData{
ID: len(images),
Data: i,
}
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") {
prefix += imgTag
} else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
images = append(images, imgData)
imgPrompt = " "
}
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt)
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") {
prefix += imgTag
} else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
images = append(images, imgData)
}
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + imgPrompt + prompt)
}
// truncate any messages that do not fit into the context window

View File

@@ -249,7 +249,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
},
expect: expect{
prompt: "<|image|>How many hotdogs are in this image? ",
prompt: "[img-0]<|image|>How many hotdogs are in this image? ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
@@ -264,7 +264,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
prompt: "You're a test, Harry! I-I'm a what? [img-0]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
@@ -279,8 +279,8 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf2},
prompt: "[img-0]<|image|>You're a test, Harry! I-I'm a what? [img-1]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf, imgBuf2},
aspectRatioID: 1,
},
},
@@ -294,7 +294,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "Which ones have mustard?"},
},
expect: expect{
prompt: "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
prompt: "[img-0]<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},

View File

@@ -205,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
images[i] = llm.ImageData{Data: buf.Bytes(), AspectRatioID: aspectRatioID}
images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: aspectRatioID}
} else {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
@@ -239,11 +239,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
for _, i := range images {
imgPrompt := ""
if isMllama {
msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
} else {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
imgPrompt = "<|image|>"
}
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})