ollamarunner: Multi-modal worst case graph

We currently preallocate compute graph memory for the worst case
batch of text tokens. This adds support for doing the same for
images.

Note that image models are more complicated than text models in
how they process their inputs, so there may be cases where this
approach isn't completely generic for all models. It does cover all
currently supported models, though.
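
To make the mechanism concrete, the sketch below shows how a runner could
drive the new reserve path: populate a multimodalStore with the tensors a
maximally sized image would produce, then resolve them with reserve=true so
the backend sizes the compute graph without executing it. This is an
illustration only; reserveWorstCaseMultimodal and its arguments are
hypothetical names rather than code from this commit, and it assumes it sits
in the same package as multimodalStore.

// Hypothetical sketch: reserve graph memory for a worst-case image batch.
// The store is assumed to already hold the worst-case tensors (added via
// addMultimodal); the function and variable names are invented, while the
// getMultimodal call uses the reserve parameter introduced by this change.
func reserveWorstCaseMultimodal(backend ml.Backend, ctx ml.Context, store multimodalStore, worstCase []input.Multimodal) error {
	// reserve=true: build the image-processing graph and reserve memory for
	// it, but do not compute anything or copy data back from the backend.
	placeholders, err := store.getMultimodal(backend, ctx, worstCase, true)
	if err != nil {
		return err
	}

	// The returned tensors are empty but have the correct shapes and dtypes,
	// so the model's forward pass can be constructed against them as usual
	// when estimating the worst-case graph.
	_ = placeholders
	return nil
}

A real decode takes the same route with reserve=false, in which case the
embeddings are actually computed and their data copied into the caller's
context.
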
Author:     Jesse Gross
Date:       2025-04-07 13:59:11 -07:00
Committer:  Jesse Gross
Parent:     3c14461d5d
Commit:     fe623c2cf4

2 changed files with 88 additions and 14 deletions

@@ -48,12 +48,12 @@ func (m multimodalStore) addMultimodal(embedding []input.Multimodal) {
 // getMultimodal takes a source set of tensors (which may contain a whole or
 // parts of one or more images) and returns the equivalent that can be used in
 // the current context
-func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal) ([]input.Multimodal, error) {
+func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
 	out := make([]input.Multimodal, len(in))
 	for i := range out {
 		if in[i].Tensor != nil {
 			var err error
-			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor)
+			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
 			if err != nil {
 				return nil, err
 			}
@@ -65,7 +65,7 @@ func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []
 	return out, nil
 }
 
-func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor) (ml.Tensor, error) {
+func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
 	entry := m[in]
 
 	if entry.data == nil {
@@ -83,19 +83,32 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
 			return nil, nil
 		}
 
-		computeCtx.Forward(tensors...).Compute(tensors...)
+		computeCtx.Forward(tensors...)
 
-		entry.data = make([][]float32, len(entry.mm))
-		for i, t := range entry.mm {
-			if t.Tensor != nil {
-				entry.data[i] = t.Tensor.Floats()
+		if !reserve {
+			computeCtx.Compute(tensors...)
+
+			entry.data = make([][]float32, len(entry.mm))
+			for i, t := range entry.mm {
+				if t.Tensor != nil {
+					entry.data[i] = t.Tensor.Floats()
+				}
 			}
+		} else {
+			err := computeCtx.Reserve()
+			if err != nil {
+				return nil, err
+			}
 		}
 	}
 
 	for i, t := range entry.mm {
 		if in == t.Tensor {
-			return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			if !reserve {
+				return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			} else {
+				return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
+			}
 		}
 	}
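
As a reading aid for the new branch above: Forward only records operations
into the compute graph, Compute executes them so that Floats can read the
results back, and Reserve instead asks the backend to allocate worst-case
memory for the recorded graph without running it. The snippet below is a
condensed paraphrase of that logic under the same assumptions as the diff
(computeCtx, tensors, and reserve as in getTensor); it is not additional code
from the commit.

// Paraphrase of the compute/reserve split in getTensor, for illustration.
computeCtx.Forward(tensors...) // record the image-processing ops into the graph

if !reserve {
	// Normal request: execute the graph so t.Tensor.Floats() can read the
	// computed embeddings back into entry.data.
	computeCtx.Compute(tensors...)
} else {
	// Worst-case sizing: allocate graph memory only. No results exist, which
	// is why the reserve path later hands back ctx.Input().Empty placeholders
	// instead of rebuilding tensors with FromFloatSlice.
	if err := computeCtx.Reserve(); err != nil {
		return nil, err
	}
}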