set non-causal attention

Michael Yang
2025-03-07 13:52:45 -08:00
parent 631fecc6d9
commit 0df1800436
6 changed files with 57 additions and 25 deletions
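The title refers to the attention mask rather than anything visible in the hunks below: under causal attention a position may attend only to itself and earlier positions, while non-causal (bidirectional) attention lifts that restriction, as multimodal models typically do over image-token spans. A minimal sketch of the two mask shapes, assuming the common additive 0/-Inf mask convention (illustrative only, not the ollama kvcache API):

```go
package main

import (
	"fmt"
	"math"
)

// buildMask returns an n x n additive attention mask: 0 where a query may
// attend to a key, -Inf where it may not. With causal=true, position i can
// see only positions j <= i; with causal=false (the non-causal case the
// commit title refers to) every position can see every other one.
func buildMask(n int, causal bool) [][]float64 {
	mask := make([][]float64, n)
	for i := range mask {
		mask[i] = make([]float64, n)
		for j := range mask[i] {
			if causal && j > i {
				mask[i][j] = math.Inf(-1)
			}
		}
	}
	return mask
}

func main() {
	fmt.Println(buildMask(3, true))  // [[0 -Inf -Inf] [0 0 -Inf] [0 0 0]]
	fmt.Println(buildMask(3, false)) // [[0 0 0] [0 0 0] [0 0 0]]
}
```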


@@ -7,6 +7,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type TextOptions struct {
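The new import supplies the input.MultimodalIndex type used in the rewritten Forward signature below. Judging only from how it is accessed there (.Index in integer arithmetic, .Multimodal through a type assertion back to ml.Tensor), the type is roughly this sketch:

```go
package input

// Sketch inferred from usage in the second hunk, not copied from the repo:
// Index is a token position in the input sequence, and Multimodal carries an
// opaque per-modality payload (here the vision encoder output, an ml.Tensor).
type MultimodalIndex struct {
	Index      int
	Multimodal any
}
```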
@@ -165,12 +166,15 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
-	if embeddings == nil {
-		embeddings = m.TokenEmbedding.Forward(ctx, inputs)
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	if multimodal != nil {
+		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
+		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
+		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(0))
 	}
 
-	hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
 	if len(m.Layers) == gemma27BLayerCount {
 		m.TextOptions.largeModelScaling = true
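The interesting part of the rewritten Forward is the splice: instead of receiving pre-merged embeddings, the model now embeds the tokens itself and overwrites the image-placeholder rows with the vision encoder output, with Stride converting a row offset into a buffer offset; only then is everything scaled by sqrt(hiddenSize) as Gemma requires. A flat-slice analogy of that Set call (hypothetical toy data, element strides instead of byte strides, not the ollama ml API):

```go
package main

import "fmt"

// spliceVision overwrites rows of a flattened [rows x rowStride] buffer with
// vision rows, starting at rowOffset. This mirrors hiddenState.Set(ctx,
// visionOutputs, offset*hiddenState.Stride(0)) from the diff, but in elements
// rather than bytes.
func spliceVision(hidden, vision []float32, rowOffset, rowStride int) {
	copy(hidden[rowOffset*rowStride:], vision)
}

func main() {
	const d = 2 // toy hidden size, i.e. the row stride in elements

	// Embeddings for 5 token positions; positions 1 and 2 are image
	// placeholders that the vision rows will replace.
	hidden := []float32{
		0.1, 0.2, // position 0: text
		0, 0, // position 1: placeholder
		0, 0, // position 2: placeholder
		0.3, 0.4, // position 3: text
		0.5, 0.6, // position 4: text
	}
	vision := []float32{9, 9, 8, 8} // two encoded image rows

	// Mirrors offset := multimodal[0].Index - 1 - visionOutputs.Dim(1),
	// assuming a hypothetical Index of 4 and Dim(1) == 2 vision rows:
	// 4 - 1 - 2 = 1, so the splice starts at position 1.
	spliceVision(hidden, vision, 4-1-2, d)
	fmt.Println(hidden) // [0.1 0.2 9 9 8 8 0.3 0.4 0.5 0.6]
}
```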