gemma2 impl

This commit is contained in:
Patrick Devine
2025-02-07 15:58:15 -08:00
committed by Michael Yang
parent 4dcf80167a
commit 5f74d1fd47
18 changed files with 1057 additions and 24 deletions

View File

@@ -20,14 +20,15 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -40,8 +41,9 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
return key, nil