add gemma vision encoder

This commit is contained in:
Michael Yang
2025-03-06 12:16:54 -08:00
parent 5f74d1fd47
commit 4b037a97dc
10 changed files with 337 additions and 34 deletions

View File

@@ -160,9 +160,12 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
if embeddings == nil {
embeddings = m.TokenEmbedding.Forward(ctx, inputs)
}
hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.TextOptions.largeModelScaling = true