ml: Abstract attention out of model definitions
There are two benefits to doing this:

- Provides a library function that models can use, reducing the code needed for each model implementation
- Enables a single place to drop in optimized implementations of attention based on the backend or other factors; one is provided for GGML

On CUDA this improves token generation rate by about 3%. It does not have a significant effect on Metal.

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
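Taken together, the lines removed in the hunks below pin down what the shared helper has to do. The following is a sketch reconstructed from those removed lines, not the committed nn.Attention source: the tensor operations (Permute, Contiguous, MulmatFullPrec, Scale, Add, Softmax, Mulmat) are the ones visible in the diff, while the nil-mask check is an inference from the cross-attention hunk, where no mask was previously applied.

```go
package nn

import "github.com/ollama/ollama/ml"

// Attention computes scaled dot-product attention,
// softmax(scale*K^T*Q + mask)*V, using the same tensor ops the
// per-model code in this commit used to spell out by hand.
// Sketch only; the real helper may differ in shape checks and
// backend-specific fast paths.
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	// Move the head dimension into position for batched matmuls,
	// exactly as the removed model code did.
	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	// Attention scores at full precision, then scale.
	scores := key.MulmatFullPrec(ctx, query)
	scores = scores.Scale(ctx, scale)

	// The cross-attention hunk can pass a nil mask (fresh
	// cross-attention states), so masking must be optional.
	if mask != nil {
		scores = scores.Add(ctx, mask)
	}
	scores = scores.Softmax(ctx)

	// Weighted sum over the values, then restore the caller's layout.
	attention := value.Mulmat(ctx, scores)
	return attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
```

Each call site then shrinks to computing scaleFactor := 1.0 / math.Sqrt(float64(headDim)) and making one nn.Attention call, as the hunks below show.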
```diff
@@ -86,13 +86,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
-	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
-	kq := k.MulmatFullPrec(ctx, q)
-	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	kq = kq.Add(ctx, mask)
-	kq = kq.Softmax(ctx)
-
-	kqv := v.Mulmat(ctx, kq)
-	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
 	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, kqv)
```
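In standard notation, the block removed above (and its near-identical twin in the next hunk) computes scaled dot-product attention with an additive mask, modulo the backend's tensor layout conventions:

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}} + M\right)V$$

Here $1/\sqrt{d_{\text{head}}}$ matches the scaleFactor computed at each call site, and $M$ is the mask tensor added to the scores before the softmax.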
```diff
@@ -38,13 +38,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
-	scores := key.MulmatFullPrec(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Add(ctx, mask)
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, attention)
```
```diff
@@ -112,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	query = ca.QueryNorm.Forward(ctx, query, opts.eps)
 
-	var key, value ml.Tensor
+	var key, value, mask ml.Tensor
 	if crossAttentionStates != nil {
 		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
```
```diff
@@ -125,19 +120,15 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 
 		cache.Put(ctx, key, value)
 	} else {
-		key, value, _ = cache.Get(ctx)
+		key, value, mask = cache.Get(ctx)
 	}
 
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return ca.Output.Forward(ctx, attention)
```
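The second benefit in the commit message, a single place to drop in backend-optimized attention with one implementation provided for GGML, implies a dispatch at the top of the helper. The sketch below shows one plausible shape for that dispatch; the ScaledDotProductAttention interface and the attentionGeneric fallback are illustrative names standing in for the real code (attentionGeneric corresponds to the pipeline sketched under the commit message above), not identifiers confirmed by this diff.

```go
// Hypothetical optional interface: a backend tensor that implements it
// can run the whole attention computation as one fused kernel.
type ScaledDotProductAttention interface {
	ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor
}

func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	// Fast path: a backend (e.g. GGML) that advertises a fused kernel.
	// Per the commit message, this is worth about 3% token generation
	// rate on CUDA and has no significant effect on Metal.
	if sdpa, ok := query.(ScaledDotProductAttention); ok {
		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
	}

	// Generic fallback: the plain tensor-op pipeline.
	return attentionGeneric(ctx, query, key, value, mask, scale)
}
```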