fix drift from main

Author: Jesse Gross
Date: 2025-03-07 09:30:10 -08:00
Committed by: Michael Yang
Parent: 4b037a97dc
Commit: 4346c2409d

5 changed files with 49 additions and 22 deletions


@@ -66,9 +66,6 @@ func newTextModel(c ml.Config) *TextModel {
 		},
 	}
 
-	slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
-
 	return &m
 }
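The removed lines built the model's hybrid KV cache inside newTextModel: a sliding-window cache and a causal cache wrapped together, with the per-layer choice made later via wc.SetLayerType (visible in the Forward hunk below). A minimal sketch of what they did, assuming the construction now happens elsewhere during model setup; the destination isn't part of this diff:

// Sketch of the removed construction, using only the kvcache calls visible
// in this diff. Where this code moved to is an assumption, not shown here.
slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(
	// Sliding-window cache: local-attention layers only need the most
	// recent slidingWindowLen positions.
	kvcache.NewSWACache(slidingWindowLen, m.Shift),
	// Causal cache: global-attention layers keep the full history.
	kvcache.NewCausalCache(m.Shift),
)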
@@ -145,12 +142,20 @@ type TextLayer struct {
 	PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
 }
 
-func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
 	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
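The new outputs argument is a last-layer optimization: during prompt processing, logits are typically needed for only a few positions (often just the final token), so once attention has run, the layer gathers just those rows and the residual add, MLP, and everything after operate on a fraction of the batch. Both hiddenState and residual are pruned with the same indices so the residual connection still lines up row for row. A self-contained toy of the gather semantics in plain Go (gatherRows is a hypothetical stand-in for Tensor.Rows, not part of the ml package):

package main

import "fmt"

// gatherRows is a toy stand-in for Tensor.Rows: it selects whole rows of
// hidden state by index, so everything downstream runs on len(outputs)
// rows instead of the full batch.
func gatherRows(hidden [][]float32, outputs []int32) [][]float32 {
	pruned := make([][]float32, len(outputs))
	for i, row := range outputs {
		pruned[i] = hidden[row]
	}
	return pruned
}

func main() {
	// Four prompt positions, but logits are only wanted for the last one.
	hidden := [][]float32{{1, 1}, {2, 2}, {3, 3}, {4, 4}}
	fmt.Println(gatherRows(hidden, []int32{3})) // [[4 4]]
}

For a long prompt that only needs logits for its final position, the last layer's tail runs on one row instead of hundreds.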
@@ -181,7 +186,13 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
 		cache.SetLayer(i)
 		wc := cache.(*kvcache.WrapperCache)
 		wc.SetLayerType(cacheType)
-		hiddenState = layer.Forward(ctx, i, hiddenState, positions, cache, m.TextOptions)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
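Only the final layer receives a non-nil outputs: each earlier layer's full output is the next layer's input, so rows can't be dropped before the last layer without starving later computation. Passing nil otherwise keeps TextLayer.Forward's signature uniform. The same sentinel pattern in the toy terms of the previous sketch:

// layerForward keeps every row unless outputs says otherwise; the caller
// passes nil on all but the last iteration, mirroring the loop above.
// (gatherRows is the hypothetical helper from the previous sketch.)
func layerForward(hidden [][]float32, outputs []int32) [][]float32 {
	if outputs != nil {
		hidden = gatherRows(hidden, outputs)
	}
	// ... attention and MLP would run here on the remaining rows ...
	return hidden
}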
@@ -190,7 +201,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
 	// final logit softcap
 	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
 	hiddenState = hiddenState.Tanh(ctx)
-	hiddenState = hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
-
-	return hiddenState.Rows(ctx, outputs)
+	return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
 }
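The tail of Forward is final logit softcapping, softcap(x) = cap * tanh(x / cap), which smoothly clamps logits into (-cap, cap). The change folds the re-scale into the return and drops the trailing Rows gather, which now happens inside the last TextLayer (second hunk above). A quick plain-Go check of the formula (softcap here is an illustrative helper, not repo code):

package main

import (
	"fmt"
	"math"
)

// softcap mirrors the three tensor ops above: scale by 1/limit, tanh,
// scale back by limit. Small logits pass through almost unchanged; large
// ones saturate smoothly just under ±limit.
func softcap(x, limit float64) float64 {
	return limit * math.Tanh(x/limit)
}

func main() {
	fmt.Println(softcap(1, 30))   // ≈ 0.9996 (near-identity)
	fmt.Println(softcap(100, 30)) // ≈ 29.92 (clamped near the cap)
}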


@@ -53,7 +53,7 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Visio
 }
 
 type VisionEncoderLayer struct {
-	LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
+	LayerNorm1    *nn.LayerNorm `gguf:"layer_norm1"`
 	SelfAttention *VisionSelfAttention
 	LayerNorm2    *nn.LayerNorm `gguf:"layer_norm2"`