use non-causal mask only for image positions

This commit is contained in:
Michael Yang
2025-03-10 15:38:58 -07:00
parent 9d2a20a763
commit e95278932b
2 changed files with 18 additions and 10 deletions

View File

@@ -183,8 +183,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, false)
defer causal.SetCausal(ctx, true)
except := make([]int32, visionOutputs.Dim(1))
for i := 0; i < visionOutputs.Dim(1); i++ {
except[i] = int32(offset + i)
}
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
}