image processing for llama3.2 (#6963)

Co-authored-by: jmorganca <jmorganca@gmail.com>
Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Jesse Gross <jesse@ollama.com>
This commit is contained in:
Patrick Devine
2024-10-18 16:12:35 -07:00
committed by GitHub
parent bf4018b9ec
commit c7cb0f0602
35 changed files with 3851 additions and 203 deletions

240
server/imageproc/images.go Normal file
View File

@@ -0,0 +1,240 @@
package imageproc
import (
"bytes"
"fmt"
"image"
"image/color"
_ "image/jpeg"
_ "image/png"
"math"
"slices"
"golang.org/x/image/draw"
)
func GetSupportedAspectRatios(maxTiles int) []image.Point {
ratios := []image.Point{}
for w := range maxTiles {
for h := range maxTiles {
if (w+1)*(h+1) <= maxTiles {
ratios = append(ratios, image.Point{w + 1, h + 1})
}
}
}
return ratios
}
func clip(a, a_min, a_max int) int {
if a < a_min {
return a_min
} else if a > a_max {
return a_max
}
return a
}
func getImageSizeFitToCanvas(imageSize, canvasSize image.Point, tileSize int) image.Point {
targetWidth := clip(imageSize.X, tileSize, canvasSize.X)
targetHeight := clip(imageSize.Y, tileSize, canvasSize.Y)
scaleWidth := float64(targetWidth) / float64(imageSize.X)
scaleHeight := float64(targetHeight) / float64(imageSize.Y)
var w, h int
if scaleWidth < scaleHeight {
w = targetWidth
h = min(int(math.Floor(float64(imageSize.Y)*scaleWidth)), targetHeight)
} else {
w = min(int(math.Floor(float64(imageSize.X)*scaleHeight)), targetWidth)
h = targetHeight
}
return image.Point{w, h}
}
func getOptimalTiledCanvas(imageSize image.Point, maxImageTiles, tileSize int) image.Point {
possibleTileArrangements := GetSupportedAspectRatios(maxImageTiles)
possibleCanvasSizes := []image.Point{}
for _, pta := range possibleTileArrangements {
possibleCanvasSizes = append(possibleCanvasSizes, image.Point{pta.X * tileSize, pta.Y * tileSize})
}
scales := []float64{}
for _, pcs := range possibleCanvasSizes {
scaleHeight := float64(pcs.Y) / float64(imageSize.Y)
scaleWidth := float64(pcs.X) / float64(imageSize.X)
if scaleWidth > scaleHeight {
scales = append(scales, scaleHeight)
} else {
scales = append(scales, scaleWidth)
}
}
var minUpscale float64
var maxDownscale float64
var upscale bool
for _, s := range scales {
if s > 1.0 {
upscale = true
if minUpscale == 0 {
minUpscale = s
} else {
minUpscale = math.Min(minUpscale, s)
}
} else {
maxDownscale = math.Max(maxDownscale, s)
}
}
selectedScale := maxDownscale
if upscale {
selectedScale = minUpscale
}
var selectedCanvas image.Point
for n, pcs := range possibleCanvasSizes {
if scales[n] == selectedScale {
// choose the smallest possible canvas
if selectedCanvas.X == 0 && selectedCanvas.Y == 0 {
selectedCanvas = pcs
} else if pcs.X*pcs.Y < selectedCanvas.X*selectedCanvas.Y {
selectedCanvas = pcs
}
}
}
return selectedCanvas
}
func splitToTiles(img image.Image, numTilesSize image.Point) []image.Image {
b := img.Bounds()
width := b.Max.X - b.Min.X
height := b.Max.Y - b.Min.Y
tileHeight := height / numTilesSize.Y
tileWidth := width / numTilesSize.X
images := []image.Image{}
for h := range numTilesSize.Y {
for w := range numTilesSize.X {
rect := image.Rect(tileWidth*w, tileHeight*h, tileWidth*(w+1), tileHeight*(h+1))
images = append(images, img.(interface {
SubImage(image.Rectangle) image.Image
}).SubImage(rect))
}
}
return images
}
// remove the "alpha" channel by drawing over a prefilled image
func compositeImage(img image.Image) image.Image {
dst := image.NewRGBA(img.Bounds())
white := color.RGBA{255, 255, 255, 255}
draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src)
draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
return dst
}
func ResizeImage(img image.Image, format string, outputSize image.Point, maxImageTiles int) (image.Image, image.Point) {
if format == "png" {
img = compositeImage(img)
}
b := img.Bounds()
tileSize := outputSize.Y
canvasSize := getOptimalTiledCanvas(b.Max, maxImageTiles, tileSize)
aspectRatio := image.Point{canvasSize.X / tileSize, canvasSize.Y / tileSize}
newSize := getImageSizeFitToCanvas(b.Max, canvasSize, tileSize)
dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
// scaling choices:
// NearestNeighbor fast, blocky output
// ApproxBiLinear fast, medium quality
// BiLinear slow, high quality
// CatmullRom very slow, very high quality
draw.BiLinear.Scale(dst, dst.Rect, img, b, draw.Over, nil)
return dst, aspectRatio
}
func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image {
paddedSize := image.Point{
X: outputSize.X * aspectRatio.X,
Y: outputSize.Y * aspectRatio.Y,
}
dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)
return dst
}
func PackImages(img image.Image, aspectRatio image.Point, mean, std [3]float32) []float32 {
subImages := splitToTiles(img, aspectRatio)
var pixelVals []float32
for _, subImg := range subImages {
bounds := subImg.Bounds()
var rVals, gVals, bVals []float32
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := subImg.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
}
return pixelVals
}
func Preprocess(imageData []byte) ([]float32, int, error) {
// todo: need guard in here for bad image data
// mllama values
outputSize := image.Point{560, 560}
maxTiles := 4
// clip values
mean := [3]float32{0.48145466, 0.4578275, 0.40821073}
std := [3]float32{0.26862954, 0.26130258, 0.27577711}
img, format, err := image.Decode(bytes.NewReader(imageData))
if err != nil {
return nil, 0, fmt.Errorf("failed to decode image: %w", err)
}
newImage, aspectRatio := ResizeImage(img, format, outputSize, maxTiles)
newImage = PadImage(newImage, outputSize, aspectRatio)
data := PackImages(newImage, aspectRatio, mean, std)
aspectRatioIndex := slices.Index(GetSupportedAspectRatios(maxTiles), aspectRatio) + 1
return data, aspectRatioIndex, nil
}

View File

@@ -0,0 +1,344 @@
package imageproc
import (
"bytes"
"image"
"image/png"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestAspectRatios(t *testing.T) {
type aspectCase struct {
MaxTiles int
Expected []image.Point
}
cases := []aspectCase{
{
MaxTiles: 1,
Expected: []image.Point{{1, 1}},
},
{
MaxTiles: 2,
Expected: []image.Point{{1, 1}, {1, 2}, {2, 1}},
},
{
MaxTiles: 3,
Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {2, 1}, {3, 1}},
},
{
MaxTiles: 4,
Expected: []image.Point{{1, 1}, {1, 2}, {1, 3}, {1, 4}, {2, 1}, {2, 2}, {3, 1}, {4, 1}},
},
}
for _, c := range cases {
actual := GetSupportedAspectRatios(c.MaxTiles)
if diff := cmp.Diff(actual, c.Expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func TestGetImageSizeFitToCanvas(t *testing.T) {
type imageSizeCase struct {
ImageRect image.Point
CanvasRect image.Point
TileSize int
Expected image.Point
}
cases := []imageSizeCase{
{
ImageRect: image.Point{400, 400},
CanvasRect: image.Point{640, 480},
TileSize: 200,
Expected: image.Point{400, 400},
},
{
ImageRect: image.Point{1024, 768},
CanvasRect: image.Point{640, 480},
TileSize: 200,
Expected: image.Point{640, 480},
},
{
ImageRect: image.Point{500, 500},
CanvasRect: image.Point{1000, 1000},
TileSize: 750,
Expected: image.Point{750, 750},
},
{
ImageRect: image.Point{500, 1000},
CanvasRect: image.Point{2000, 2000},
TileSize: 2000,
Expected: image.Point{1000, 2000},
},
{
ImageRect: image.Point{4000, 3000},
CanvasRect: image.Point{2000, 1000},
TileSize: 1000,
Expected: image.Point{1333, 1000},
},
{
ImageRect: image.Point{667, 1000},
CanvasRect: image.Point{1000, 1000},
TileSize: 560,
Expected: image.Point{667, 1000},
},
}
for _, c := range cases {
actual := getImageSizeFitToCanvas(c.ImageRect, c.CanvasRect, c.TileSize)
if actual != c.Expected {
t.Errorf("incorrect image rect: '%#v'. expected: '%#v'", actual, c.Expected)
}
}
}
func TestGetOptimalTiledCanvas(t *testing.T) {
type tiledCanvasSizeCase struct {
ImageSize image.Point
MaxImageTiles int
TileSize int
Expected image.Point
}
cases := []tiledCanvasSizeCase{
{
ImageSize: image.Point{1024, 768},
MaxImageTiles: 4,
TileSize: 1000,
Expected: image.Point{2000, 1000},
},
{
ImageSize: image.Point{1024, 768},
MaxImageTiles: 4,
TileSize: 560,
Expected: image.Point{1120, 1120},
},
}
for _, c := range cases {
actual := getOptimalTiledCanvas(c.ImageSize, c.MaxImageTiles, c.TileSize)
if actual != c.Expected {
t.Errorf("incorrect tiled canvas: '%#v'. expected: '%#v'", actual, c.Expected)
}
}
}
func TestSplitToTiles(t *testing.T) {
type splitCase struct {
TestImage image.Image
NumTilesSize image.Point
Expected []image.Image
}
cases := []splitCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
NumTilesSize: image.Point{1, 1},
Expected: []image.Image{image.NewRGBA(image.Rect(0, 0, 1024, 768))},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1000, 500)),
NumTilesSize: image.Point{2, 1},
Expected: []image.Image{
image.NewRGBA(image.Rect(0, 0, 500, 500)),
image.NewRGBA(image.Rect(500, 0, 1000, 500)),
},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1000, 1000)),
NumTilesSize: image.Point{2, 2},
Expected: []image.Image{
image.NewRGBA(image.Rect(0, 0, 500, 500)),
image.NewRGBA(image.Rect(500, 0, 1000, 500)),
image.NewRGBA(image.Rect(0, 500, 500, 1000)),
image.NewRGBA(image.Rect(500, 500, 1000, 1000)),
},
},
}
for _, c := range cases {
actual := splitToTiles(c.TestImage, c.NumTilesSize)
if len(actual) != len(c.Expected) {
t.Errorf("incorrect number of images '%d': expected: '%d'", len(actual), len(c.Expected))
}
for i := range actual {
if actual[i].Bounds() != c.Expected[i].Bounds() {
t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual[i].Bounds(), c.Expected[i].Bounds())
}
}
}
}
func TestResize(t *testing.T) {
type resizeCase struct {
TestImage image.Image
OutputSize image.Point
MaxImageTiles int
ExpectedImage image.Image
ExpectedAspectRatio image.Point
}
cases := []resizeCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 200, 200)),
OutputSize: image.Point{100, 100},
MaxImageTiles: 1,
ExpectedImage: image.NewRGBA(image.Rect(0, 0, 100, 100)),
ExpectedAspectRatio: image.Point{1, 1},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 200, 200)),
OutputSize: image.Point{100, 100},
MaxImageTiles: 2,
ExpectedImage: image.NewRGBA(image.Rect(0, 0, 100, 100)),
ExpectedAspectRatio: image.Point{1, 1},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)),
OutputSize: image.Point{560, 560},
MaxImageTiles: 4,
ExpectedImage: image.NewRGBA(image.Rect(0, 0, 560, 560)),
ExpectedAspectRatio: image.Point{1, 1},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 2560, 1920)),
OutputSize: image.Point{560, 560},
MaxImageTiles: 4,
ExpectedImage: image.NewRGBA(image.Rect(0, 0, 1120, 840)),
ExpectedAspectRatio: image.Point{2, 2},
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
OutputSize: image.Point{560, 560},
MaxImageTiles: 4,
ExpectedImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
ExpectedAspectRatio: image.Point{2, 2},
},
}
for _, c := range cases {
actualImage, actualAspectRatio := ResizeImage(c.TestImage, "png", c.OutputSize, c.MaxImageTiles)
if actualImage.Bounds() != c.ExpectedImage.Bounds() {
t.Errorf("image size incorrect: '%#v': expected: '%#v'", actualImage.Bounds(), c.ExpectedImage.Bounds())
}
if actualAspectRatio != c.ExpectedAspectRatio {
t.Errorf("aspect ratio incorrect: '%#v': expected: '%#v'", actualAspectRatio, c.ExpectedAspectRatio)
}
}
}
func TestPad(t *testing.T) {
type padCase struct {
TestImage image.Image
OutputSize image.Point
AspectRatio image.Point
Expected image.Image
}
cases := []padCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1000, 667)),
OutputSize: image.Point{560, 560},
AspectRatio: image.Point{2, 2},
Expected: image.NewRGBA(image.Rect(0, 0, 1120, 1120)),
},
}
for _, c := range cases {
actual := PadImage(c.TestImage, c.OutputSize, c.AspectRatio)
if actual.Bounds() != c.Expected.Bounds() {
t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
}
}
}
func TestPackImages(t *testing.T) {
type packCase struct {
TestImage image.Image
AspectRatio image.Point
ExpectedVals int
}
mean := [3]float32{0.48145466, 0.4578275, 0.40821073}
std := [3]float32{0.26862954, 0.26130258, 0.27577711}
cases := []packCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1120, 1120)),
AspectRatio: image.Point{2, 2},
ExpectedVals: 2 * 2 * 3 * 560 * 560,
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 560, 560)),
AspectRatio: image.Point{1, 1},
ExpectedVals: 1 * 1 * 3 * 560 * 560,
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1120, 560)),
AspectRatio: image.Point{1, 2},
ExpectedVals: 1 * 2 * 3 * 560 * 560,
},
}
for _, c := range cases {
actualVals := PackImages(c.TestImage, c.AspectRatio, mean, std)
if len(actualVals) != c.ExpectedVals {
t.Errorf("packed image size incorrect: '%d': expected: '%d'", len(actualVals), c.ExpectedVals)
}
}
}
func TestPreprocess(t *testing.T) {
type preprocessCase struct {
TestImage image.Image
ExpectedVals int
ExpectedAspectRatioID int
}
cases := []preprocessCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)),
ExpectedVals: 0,
ExpectedAspectRatioID: 1,
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
ExpectedVals: 0,
ExpectedAspectRatioID: 6,
},
}
for _, c := range cases {
var buf bytes.Buffer
err := png.Encode(&buf, c.TestImage)
if err != nil {
t.Fatal(err)
}
imgData, aspectRatioID, err := Preprocess(buf.Bytes())
if err != nil {
t.Fatalf("error processing: %q", err)
}
if len(imgData) == 0 {
t.Errorf("no image data returned")
}
if aspectRatioID != c.ExpectedAspectRatioID {
t.Errorf("aspect ratio incorrect: '%d': expected: '%d'", aspectRatioID, c.ExpectedAspectRatioID)
}
}
}

View File

@@ -194,7 +194,9 @@ func parseFromFile(ctx context.Context, command string, baseLayers []*layerGGML,
mediatype := "application/vnd.ollama.image.model"
if ggml.Name() == "ggla" || ggml.KV().Kind() == "adapter" {
mediatype = "application/vnd.ollama.image.adapter"
} else if ggml.KV().Architecture() == "clip" {
}
if _, ok := ggml.KV()[fmt.Sprintf("%s.vision.block_count", ggml.KV().Architecture())]; ok || ggml.KV().Kind() == "projector" {
mediatype = "application/vnd.ollama.image.projector"
}

View File

@@ -3,24 +3,42 @@ package server
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"log/slog"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/imageproc"
"github.com/ollama/ollama/template"
)
type tokenizeFunc func(context.Context, string) ([]int, error)
var errTooManyImages = errors.New("vision model only supports a single image per message")
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
// always include the last message
isMllama := checkMllamaModelFamily(m)
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
for i := n; i >= 0; i-- {
if isMllama && len(msgs[i].Images) > 1 {
return "", nil, errTooManyImages
}
// always include the last message
if i == n {
continue
}
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
@@ -38,16 +56,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
return "", nil, err
}
c := len(s)
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
c += 768 * len(m.Images)
ctxLen += 768 * len(m.Images)
}
}
if c > opts.NumCtx {
if ctxLen > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
@@ -55,20 +73,70 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
}
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
return "", nil, err
currMsgIdx := n
if isMllama {
lastMsgIdx := len(msgs) - 1
for i := lastMsgIdx; i >= currMsgIdx; i-- {
if len(msgs[i].Images) > 0 {
data, aspectRatioID, err := imageproc.Preprocess(msgs[i].Images[0])
if err != nil {
return "", nil, err
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.LittleEndian, data)
if err != nil {
return "", nil, err
}
imgData := llm.ImageData{
Data: buf.Bytes(),
AspectRatioID: aspectRatioID,
}
msgs[i].Content = strings.TrimSpace("<|image|>" + msgs[i].Content)
images = append(images, imgData)
break
}
}
} else {
for cnt, msg := range msgs[currMsgIdx:] {
prefix := ""
prompt := msg.Content
for _, i := range msg.Images {
imgData := llm.ImageData{
ID: len(images),
Data: i,
}
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") {
prefix += imgTag
} else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
images = append(images, imgData)
}
msgs[currMsgIdx+cnt].Content = strings.TrimSpace(prefix + " " + prompt)
}
}
for _, m := range msgs[n:] {
for _, i := range m.Images {
images = append(images, llm.ImageData{
ID: len(images),
Data: i,
})
}
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
return "", nil, err
}
return b.String(), images, nil
}
func checkMllamaModelFamily(m *Model) bool {
for _, arch := range m.Config.ModelFamilies {
if arch == "mllama" {
return true
}
}
return false
}

View File

@@ -3,6 +3,8 @@ package server
import (
"bytes"
"context"
"image"
"image/png"
"testing"
"github.com/google/go-cmp/cmp"
@@ -13,18 +15,53 @@ import (
func TestChatPrompt(t *testing.T) {
type expect struct {
prompt string
images [][]byte
prompt string
images [][]byte
aspectRatioID int
error error
}
tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
}
visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}}
createImg := func(width, height int) ([]byte, error) {
img := image.NewRGBA(image.Rect(0, 0, 5, 5))
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
imgBuf, err := createImg(5, 5)
if err != nil {
t.Fatal(err)
}
imgBuf2, err := createImg(6, 6)
if err != nil {
t.Fatal(err)
}
cases := []struct {
name string
model Model
limit int
msgs []api.Message
expect
}{
{
name: "messages",
model: visionModel,
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -37,6 +74,7 @@ func TestChatPrompt(t *testing.T) {
},
{
name: "truncate messages",
model: visionModel,
limit: 1,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -49,6 +87,7 @@ func TestChatPrompt(t *testing.T) {
},
{
name: "truncate messages with image",
model: visionModel,
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -64,6 +103,7 @@ func TestChatPrompt(t *testing.T) {
},
{
name: "truncate messages with images",
model: visionModel,
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -79,6 +119,7 @@ func TestChatPrompt(t *testing.T) {
},
{
name: "messages with images",
model: visionModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -95,6 +136,7 @@ func TestChatPrompt(t *testing.T) {
},
{
name: "message with image tag",
model: visionModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
@@ -111,6 +153,7 @@ func TestChatPrompt(t *testing.T) {
},
{
name: "messages with interleaved images",
model: visionModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -129,6 +172,7 @@ func TestChatPrompt(t *testing.T) {
},
{
name: "truncate message with interleaved images",
model: visionModel,
limit: 1024,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -146,6 +190,7 @@ func TestChatPrompt(t *testing.T) {
},
{
name: "message with system prompt",
model: visionModel,
limit: 2048,
msgs: []api.Message{
{Role: "system", Content: "You are the Test Who Lived."},
@@ -159,6 +204,7 @@ func TestChatPrompt(t *testing.T) {
},
{
name: "out of order system",
model: visionModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
@@ -170,23 +216,113 @@ func TestChatPrompt(t *testing.T) {
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
},
}
tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
{
name: "multiple images same prompt",
model: visionModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
},
expect: expect{
prompt: "[img-0][img-1] Compare these two pictures of hotdogs ",
images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
},
},
{
name: "messages with mllama (no images)",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "messages with mllama single prompt",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
},
expect: expect{
prompt: "<|image|>How many hotdogs are in this image? ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
},
{
name: "messages with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
},
{
name: "multiple messages with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? <|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf2},
aspectRatioID: 1,
},
},
{
name: "earlier image with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
{Role: "assistant", Content: "There are four hotdogs."},
{Role: "user", Content: "Which ones have mustard?"},
},
expect: expect{
prompt: "<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
},
{
name: "too many images with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}},
},
expect: expect{
error: errTooManyImages,
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
if err != nil {
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
}
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
@@ -202,8 +338,14 @@ func TestChatPrompt(t *testing.T) {
t.Errorf("expected ID %d, got %d", i, images[i].ID)
}
if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i])
if len(model.Config.ModelFamilies) == 0 {
if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i].Data)
}
} else {
if images[i].AspectRatioID != tt.aspectRatioID {
t.Errorf("expected aspect ratio %d, got %d", tt.aspectRatioID, images[i].AspectRatioID)
}
}
}
})

View File

@@ -119,20 +119,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
model, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
model, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
s.sched.expireRunner(model)
c.JSON(http.StatusOK, api.GenerateResponse{
@@ -169,6 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now()
// load the model
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
@@ -179,6 +181,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
isMllama := checkMllamaModelFamily(model)
if isMllama && len(req.Images) > 1 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
return
}
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
@@ -212,7 +220,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
if isMllama {
msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
} else {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})

View File

@@ -421,22 +421,22 @@ func TestGenerate(t *testing.T) {
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
if w.Code != http.StatusNotFound {
t.Errorf("expected status 404, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
if w.Code != http.StatusNotFound {
t.Errorf("expected status 404, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})

View File

@@ -562,7 +562,7 @@ func TestShow(t *testing.T) {
Modelfile: fmt.Sprintf(
"FROM %s\nFROM %s",
createBinFile(t, llm.KV{"general.architecture": "test"}, nil),
createBinFile(t, llm.KV{"general.architecture": "clip"}, nil),
createBinFile(t, llm.KV{"general.type": "projector", "general.architecture": "clip"}, nil),
),
})