Disable causal attention based on batch index

Currently we use positions to identify the tokens exempt from causal
attention, but positions are relative to a sequence and may not be
unique within a batch. Batch indices are unique, so use those instead.
commit a8e83a7654
parent 475005504e
Author:    Jesse Gross
Committer: Michael Yang
Date:      2025-03-10 17:17:19 -07:00

2 changed files with 10 additions and 12 deletions
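
To illustrate the problem the commit message describes (a minimal standalone sketch, not the model code itself): positions restart at zero for every sequence, so two sequences batched together can share the same position value, while each token's index within the batch is unique.

package main

import "fmt"

func main() {
	// Two sequences batched together. Positions are relative to each
	// sequence and restart at zero, so they repeat across the batch;
	// the batch index i is unique.
	positions := []int32{0, 1, 2, 0, 1} // seq A: 0-2, seq B: 0-1
	for i, pos := range positions {
		fmt.Printf("batch index %d -> position %d\n", i, pos)
	}
	// Positions 0 and 1 each appear twice, which is why identifying
	// tokens by position is ambiguous in a batch.
}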


@@ -173,10 +173,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
+func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
 	var embedding ml.Tensor
 	var src, dst, length int
-	var except []int32
+	var except []int
 
 	for _, image := range multimodal {
 		imageToken := image.Multimodal.(imageToken)
@@ -204,7 +204,7 @@ func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []inpu
 			length = 1
 		}
 
-		except = append(except, positions[imageDst])
+		except = append(except, imageDst)
 	}
 
 	if embedding != nil {
@@ -219,7 +219,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
-	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
+	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
 
 	for i, layer := range m.Layers {
 		// gemma alternates between the sliding window (local) and causal (global)
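
The diff does not show where the returned except slice is consumed. As a hedged sketch only (buildCausalMask is a hypothetical helper, not ollama's actual kvcache API), a causal mask keyed by batch index could exempt image tokens like this, assuming image tokens attend to each other bidirectionally:

package main

import "fmt"

// buildCausalMask is a hypothetical illustration, not ollama's kvcache
// implementation. mask[q][k] reports whether query token q may attend to
// key token k. Indices listed in except (the batch indices of image
// tokens) may attend to each other regardless of order.
func buildCausalMask(n int, except []int) [][]bool {
	exempt := make(map[int]bool, len(except))
	for _, i := range except {
		exempt[i] = true
	}
	mask := make([][]bool, n)
	for q := range mask {
		mask[q] = make([]bool, n)
		for k := range mask[q] {
			// Causal rule, relaxed for exempt (image) token pairs.
			mask[q][k] = k <= q || (exempt[q] && exempt[k])
		}
	}
	return mask
}

func main() {
	// Tokens 1 and 2 are image tokens, identified by batch index
	// rather than by position.
	for _, row := range buildCausalMask(4, []int{1, 2}) {
		fmt.Println(row)
	}
}

Because the exempt set is keyed by batch index, it stays unambiguous even when multiple sequences in the batch contain tokens at the same position.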