mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 15:57:04 +00:00
Prior to performing attention, we need to permute query, key and value. Currently we call Contiguous after each of these permutations, which is correct but expensive. Avoiding the 3 calls to Contiguous increases performance by over 20%. The permutations of query and key do not violate the contiguity rules for mulmat, so their Contiguous calls can simply be removed. Value requires a different permutation and does require Contiguous. However, we can use the copy into the cache as a way to perform this without further overhead. To support this and avoid exposing unexpected tensor shapes to models, we need tighter integration between attention, cache and backend. Future optimization will also likely need this structure - for example, flash attention has special padding requirements in the cache and other backends may have their own needs. This further contains the operations that go into attention so that these and other optimizations can be handled transparently. Models that have special requirements for attention can still implement their own version of it.
129 lines
3.0 KiB
Go
129 lines
3.0 KiB
Go
package kvcache
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/ollama/ollama/ml"
|
|
)
|
|
|
|
// Encoder cache stores K and V tensors that are position independent
|
|
//
|
|
// The tensors can be of any shape and will be returned as they were stored
|
|
// The mask is currently always nil
|
|
//
|
|
// Not currently safe for multiple sequences
|
|
type EncoderCache struct {
|
|
// config controls mostly backend-specific optimizations
|
|
config *ml.CacheConfig
|
|
|
|
// ** current forward pass **
|
|
|
|
// the active layer for Get and Put
|
|
curLayer int
|
|
|
|
// if something is stored during this pass, this
|
|
// will be the position (but there is no guarantee
|
|
// anything will be stored)
|
|
curPos int32
|
|
|
|
// ** cache metadata **
|
|
|
|
// was something stored in the cache?
|
|
encoderCached bool
|
|
|
|
// position of the cached data
|
|
encoderPos int32
|
|
|
|
// ** cache data storage **
|
|
|
|
cacheCtx ml.Context
|
|
keys, values []ml.Tensor
|
|
}
|
|
|
|
func NewEncoderCache() *EncoderCache {
|
|
return &EncoderCache{}
|
|
}
|
|
|
|
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
|
if c.config == nil {
|
|
var config ml.CacheConfig
|
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
|
config = cc.CacheConfig()
|
|
}
|
|
c.config = &config
|
|
}
|
|
|
|
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
|
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
|
}
|
|
|
|
c.cacheCtx = backend.NewContext()
|
|
}
|
|
|
|
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
|
if c.config != nil {
|
|
panic("config cannot be changed after being previously set, either by the model or backend")
|
|
}
|
|
|
|
c.config = &config
|
|
}
|
|
|
|
func (c *EncoderCache) Close() {
|
|
c.cacheCtx.Close()
|
|
}
|
|
|
|
func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
|
// The image is always in the first position
|
|
c.curPos = positions[0]
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *EncoderCache) SetLayer(layer int) {
|
|
if layer >= len(c.keys) {
|
|
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
|
|
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
|
|
}
|
|
|
|
c.curLayer = layer
|
|
}
|
|
|
|
func (c *EncoderCache) EncoderCached() bool {
|
|
return c.encoderCached
|
|
}
|
|
|
|
func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
|
return c.keys[c.curLayer], c.values[c.curLayer], nil
|
|
}
|
|
|
|
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
|
c.encoderPos = c.curPos
|
|
c.encoderCached = true
|
|
|
|
if c.config.PermutedV {
|
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
|
}
|
|
|
|
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
|
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
|
|
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
|
|
}
|
|
|
|
ctx.Forward(
|
|
key.Copy(ctx, c.keys[c.curLayer]),
|
|
value.Copy(ctx, c.values[c.curLayer]),
|
|
)
|
|
}
|
|
|
|
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
|
panic("encoder cache does not support multiple sequences")
|
|
}
|
|
|
|
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
|
|
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
|
|
c.encoderCached = false
|
|
}
|
|
|
|
return nil
|
|
}
|