ollamarunner: Multi-modal worst case graph
We currently preallocate compute graph memory for the worst-case batch of text tokens. This adds support for doing the same for images. Note that image models are more complicated than text models in how they process their inputs, so there may be cases where this approach isn't completely generic for all models. It does cover all currently supported models, though.
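The worst case being reserved is driven by a synthetic image: the diff below encodes a blank 2048x2048 grayscale BMP entirely in memory and feeds it through the model's multimodal encoder. As a minimal standalone illustration of just that step (the packages and dimensions come straight from the diff; the surrounding program is illustrative):

```go
package main

import (
	"bytes"
	"fmt"
	"image"

	"golang.org/x/image/bmp"
)

func main() {
	// A blank 2048x2048 grayscale image - the same placeholder the runner
	// encodes when sizing the worst-case multimodal graph.
	img := image.NewGray(image.Rect(0, 0, 2048, 2048))

	var buf bytes.Buffer
	if err := bmp.Encode(&buf, img); err != nil {
		panic(err)
	}

	// These bytes play the role of the image data handed to EncodeMultimodal.
	fmt.Printf("encoded %d bytes of BMP data\n", buf.Len())
}
```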
@@ -1,12 +1,14 @@
 package ollamarunner
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
 	"flag"
 	"fmt"
 	"hash/maphash"
+	"image"
 	"log"
 	"log/slog"
 	"net"
@@ -20,6 +22,7 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"golang.org/x/image/bmp"
 	"golang.org/x/sync/semaphore"
 
 	"github.com/ollama/ollama/api"
@@ -444,7 +447,7 @@ func (s *Server) processBatch() error {
 
 		batchInputs = append(batchInputs, inp.Token)
 		if inp.Multimodal != nil {
-			mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal)
+			mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
 			if err != nil {
 				return err
 			}
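The only change to processBatch is a new trailing boolean on getMultimodal: false here, and true in the reserveWorstCaseGraph loop below. The store implementation isn't part of this diff, but the flag plausibly tells the store whether the caller needs the real tensor data or only something of the right shape for sizing the graph. A hypothetical sketch of that pattern (all names and types here are illustrative, not the runner's actual API):

```go
package main

import "fmt"

// tensor is a stand-in for a backend tensor; only the shape matters when
// reserving memory for a worst-case graph.
type tensor struct {
	shape []int
	data  []float32 // nil for reservation-only placeholders
}

type store struct {
	cache map[string]tensor
}

// get mirrors the assumed role of getMultimodal's new boolean: with
// reserve=false the cached data is returned; with reserve=true a placeholder
// of the same shape comes back, so the graph can be sized without real data.
func (s *store) get(key string, reserve bool) (tensor, error) {
	t, ok := s.cache[key]
	if !ok {
		return tensor{}, fmt.Errorf("no multimodal entry for %q", key)
	}
	if reserve {
		return tensor{shape: t.shape}, nil
	}
	return t, nil
}

func main() {
	s := &store{cache: map[string]tensor{
		"img0": {shape: []int{729, 1152}, data: make([]float32, 729*1152)},
	}}

	full, _ := s.get("img0", false)       // processBatch path
	placeholder, _ := s.get("img0", true) // reserveWorstCaseGraph path
	fmt.Println(full.shape, len(full.data), placeholder.shape, len(placeholder.data))
}
```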
@@ -732,12 +735,71 @@ func (s *Server) reserveWorstCaseGraph() error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
+	var err error
+	inputs := make([]input.Input, s.batchSize)
+	mmStore := newMultimodalStore()
+
+	// Multimodal strategy:
+	// - Encode a 2048x2048 image. This assumes that a single image of this
+	//   size is sufficient to trigger the worst case. This is currently true
+	//   because for existing models, only a single image fits in a batch.
+	// - Add the embedding to a full batch of tokens - this is necessary because
+	//   the model may be looking for non-image data, such as <image> tags.
+	// - Run PostTokenize to execute any transformations between generated
+	//   embeddings and what the forward pass expects.
+	// - The result may now be larger than a batch (images may not fit in a
+	//   single batch), so trim based on what will fit and must be grouped together.
+	// - Fill out the rest of the space with text tokens.
+	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+		mmCtx := s.model.Backend().NewContext()
+		defer mmCtx.Close()
+
+		img := image.NewGray(image.Rect(0, 0, 2048, 2048))
+		var buf bytes.Buffer
+		bmp.Encode(&buf, img)
+
+		if inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes()); err == nil {
+			mmStore.addMultimodal(inputs[0].Multimodal)
+
+			inputs, err = multimodalProcessor.PostTokenize(inputs)
+			if err != nil {
+				return err
+			}
+
+			for i, inp := range inputs {
+				minBatch := 1 + inp.SameBatch
+				if minBatch > s.batchSize {
+					inputs = inputs[i:min(i+minBatch, len(inputs))]
+					break
+				} else if i+minBatch > s.batchSize {
+					inputs = inputs[:i]
+					break
+				}
+			}
+
+			if len(inputs) < s.batchSize {
+				newInputs := make([]input.Input, s.batchSize)
+				copy(newInputs, inputs)
+				inputs = newInputs
+			}
+		}
+	}
+
 	var batch input.Batch
 
-	inputs := make([]int32, s.batchSize)
+	batchInputs := make([]int32, len(inputs))
 	batch.Positions = make([]int32, len(inputs))
 	batch.Sequences = make([]int, len(inputs))
-	for i := range inputs {
+	for i, inp := range inputs {
+		batchInputs[i] = inp.Token
+		if inp.Multimodal != nil {
+			mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
+			if err != nil {
+				return err
+			}
+			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
+		}
+
 		batch.Positions[i] = int32(i)
 	}
 
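The trimming loop is the subtle part of the new code: inp.SameBatch says how many of the following inputs must be scheduled in the same batch as input i (an image's embedding inputs can't be split up), so the slice is either narrowed to a single oversized group or cut just before a group that would straddle the batch boundary. Here is the same logic extracted into a standalone function with a small usage example (Input is a stripped-down stand-in for the runner's input.Input type):

```go
package main

import "fmt"

// Input is a stripped-down stand-in for the runner's input.Input type.
type Input struct {
	Token     int32
	SameBatch int // how many following inputs must share this input's batch
}

// trimToBatch applies the same rule as the loop in reserveWorstCaseGraph:
// a group larger than a whole batch is kept on its own (it becomes the
// worst-case batch by itself), while a group that merely crosses the batch
// boundary is cut off just before it starts.
func trimToBatch(inputs []Input, batchSize int) []Input {
	for i, inp := range inputs {
		minBatch := 1 + inp.SameBatch
		if minBatch > batchSize {
			return inputs[i:min(i+minBatch, len(inputs))]
		} else if i+minBatch > batchSize {
			return inputs[:i]
		}
	}
	return inputs
}

func main() {
	// Ten inputs where input 2 starts an image group of six that must stay together.
	inputs := make([]Input, 10)
	inputs[2].SameBatch = 5

	fmt.Println(len(trimToBatch(inputs, 8))) // 8: cut at the batch boundary
	fmt.Println(len(trimToBatch(inputs, 4))) // 6: the oversized image group alone
}
```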
@@ -746,8 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Outputs[i] = int32(i)
 	}
 
-	var err error
-	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+	batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
 	if err != nil {
 		return err
 	}