image processing

Co-authored-by: Patrick Devine <patrick@infrahq.com>
This commit is contained in:
Michael Yang
2025-04-16 15:25:34 -07:00
committed by Michael Yang
parent f0c66e6dea
commit 178761aef3
3 changed files with 494 additions and 4 deletions

View File

@@ -15,6 +15,7 @@ import (
type Model struct {
model.Base
model.BytePairEncoding
ImageProcessor
*VisionModel `gguf:"v,vision"`
*Projector `gguf:"mm"`
@@ -43,8 +44,9 @@ func New(c fs.Config) (model.Model, error) {
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
}
m.Cache = kvcache.NewWrapperCache(
@@ -66,21 +68,42 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return nil, err
}
f32s, aspectRatio, err := m.ProcessImage(ctx, img)
pixelsLocal, pixelsGlobal, size, err := m.ProcessImage(img)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s, len(f32s))
tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
if err != nil {
return nil, err
}
ratioW, ratioH := int(size.X/m.imageSize), int(size.Y/m.imageSize)
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, ratioW, size.Y, m.numChannels).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW*size.Y/ratioH, ratioH, ratioW, m.numChannels).Permute(ctx, 0, 3, 2, 1).Contiguous(ctx)
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, size.Y/ratioH, m.numChannels, ratioH*ratioW)
pixelValues := tilesLocal
if len(pixelsGlobal) > 0 {
tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
if err != nil {
return nil, err
}
pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3)
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
return m.Projector.Forward(ctx, visionOutputs), nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {