attention: Remove unnecessary contiguous operations

Prior to performing attention, we need to permute query, key
and value. Currently we call Contiguous after each of these
permutations, which is correct but expensive. Avoiding the
3 calls to Contiguous increases performance by over 20%.

The permutations of query and key do not violate the continuity
rules for mulmat and the Contiguous call can be simply removed.

Value requires a different permutation and does require Contiguous.
However, we can use the copy into the cache as a way to perform this
without further overhead.

To support this and avoid unexpected tensor shapes that are seen by
models, we need tighter integration between attention, cache
and backend. Future optimization will also likely need this structure
 - for example, flash attention has special padding requirements in
the cache and other backends may have their own needs.

This further contains the operations that go into attention so that
these and other optimizations can be handled transparently. Models
that have special requirements for attention can still implement
their own version of it.
This commit is contained in:
Jesse Gross
2025-02-22 21:34:10 -08:00
committed by Jesse Gross
parent 96a97adf9b
commit 854a9195f3
10 changed files with 270 additions and 86 deletions

View File

@@ -1,6 +1,8 @@
package kvcache
import (
"fmt"
"github.com/ollama/ollama/ml"
)
@@ -11,6 +13,9 @@ import (
//
// Not currently safe for multiple sequences
type EncoderCache struct {
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass **
// the active layer for Get and Put
@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
c.cacheCtx = backend.NewContext()
}
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
}
func (c *EncoderCache) Close() {
c.cacheCtx.Close()
}
@@ -75,6 +100,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.encoderPos = c.curPos
c.encoderCached = true
if c.config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3)
}
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)