Mirror of https://github.com/dogkeeper886/ollama37.git (synced 2025-12-11 08:17:03 +00:00)
ollamarunner: Multi-modal worst case graph
We currently preallocate compute graph memory for the worst case batch of text tokens. This adds support for doing the same for images.

Note that image models are more complicated than text models in how they process their inputs, so there may be cases where this approach isn't completely generic for all models. It covers all currently supported models though.
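The reserve path can be pictured with a toy graph: record the worst-case output tensors, then size allocations without ever computing values. Below is a minimal, hypothetical sketch of that idea; none of these names are ollama's (the real calls in the diff are computeCtx.Forward, computeCtx.Compute, and computeCtx.Reserve).

package main

import "fmt"

type tensor struct {
	shape []int
	data  []float32
}

type graph struct{ outputs []*tensor }

// forward records the tensors the graph must be able to produce.
func (g *graph) forward(ts ...*tensor) { g.outputs = append(g.outputs, ts...) }

// compute materializes real values (a stand-in for kernel execution).
func (g *graph) compute() {
	for _, t := range g.outputs {
		n := 1
		for _, d := range t.shape {
			n *= d
		}
		t.data = make([]float32, n)
	}
}

// reserve walks the same graph but only sizes allocations for the worst
// case, never producing values.
func (g *graph) reserve() error {
	total := 0
	for _, t := range g.outputs {
		n := 1
		for _, d := range t.shape {
			n *= d
		}
		total += n
	}
	fmt.Println("would reserve", total, "floats")
	return nil
}

func main() {
	// Pretend this is the largest image embedding a model can emit.
	worstCase := &tensor{shape: []int{729, 1152}}
	g := &graph{}
	g.forward(worstCase)
	if err := g.reserve(); err != nil { // the reserve=true path: no compute
		panic(err)
	}
}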
@@ -48,12 +48,12 @@ func (m multimodalStore) addMultimodal(embedding []input.Multimodal) {
 // getMultimodal takes a source set of tensors (which may contain a whole or
 // parts of one or more images) and returns the equivalent that can be used in
 // the current context
-func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal) ([]input.Multimodal, error) {
+func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
 	out := make([]input.Multimodal, len(in))
 	for i := range out {
 		if in[i].Tensor != nil {
 			var err error
-			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor)
+			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
 			if err != nil {
 				return nil, err
 			}
@@ -65,7 +65,7 @@ func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []
 	return out, nil
 }
 
-func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor) (ml.Tensor, error) {
+func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
 	entry := m[in]
 
 	if entry.data == nil {
@@ -83,19 +83,32 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
 			return nil, nil
 		}
 
-		computeCtx.Forward(tensors...).Compute(tensors...)
-
+		computeCtx.Forward(tensors...)
 		entry.data = make([][]float32, len(entry.mm))
-		for i, t := range entry.mm {
-			if t.Tensor != nil {
-				entry.data[i] = t.Tensor.Floats()
-			}
+
+		if !reserve {
+			computeCtx.Compute(tensors...)
+
+			for i, t := range entry.mm {
+				if t.Tensor != nil {
+					entry.data[i] = t.Tensor.Floats()
+				}
+			}
+		} else {
+			err := computeCtx.Reserve()
+			if err != nil {
+				return nil, err
+			}
 		}
 	}
 
 	for i, t := range entry.mm {
 		if in == t.Tensor {
-			return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			if !reserve {
+				return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			} else {
+				return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
+			}
 		}
 	}
 
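Why can reserve mode return an empty tensor at the end of getTensor? Sizing the worst-case graph only needs each tensor's dtype and shape; the float data would only exist after a real Compute, which the reserve path deliberately skips. A toy illustration of that distinction follows, using stand-in types rather than ollama's ml package (fromFloatSlice and empty here merely play the roles of ctx.Input().FromFloatSlice and ctx.Input().Empty).

package main

import "fmt"

type dtype int

const f32 dtype = iota

type tensor struct {
	dt    dtype
	shape []int
	data  []float32 // nil in reserve mode
}

// fromFloatSlice: real data copied in, as on the normal compute path.
func fromFloatSlice(data []float32, shape ...int) (*tensor, error) {
	return &tensor{dt: f32, shape: shape, data: data}, nil
}

// empty: dtype and shape only, no data transfer, as on the reserve path.
func empty(dt dtype, shape ...int) *tensor {
	return &tensor{dt: dt, shape: shape}
}

func main() {
	for _, reserve := range []bool{false, true} {
		var t *tensor
		if !reserve {
			t, _ = fromFloatSlice(make([]float32, 4*8), 4, 8)
		} else {
			t = empty(f32, 4, 8)
		}
		fmt.Printf("reserve=%v shape=%v hasData=%v\n", reserve, t.shape, t.data != nil)
	}
}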