kvcache: create cache ctx per layer

each cache layer creates and maintains its own context instead of using
a large context for all layers
This commit is contained in:
Michael Yang
2025-02-25 12:57:49 -08:00
parent bfce55db3d
commit 764e199d67
4 changed files with 68 additions and 46 deletions

View File

@@ -55,8 +55,8 @@ type Causal struct {
shiftFn shiftFn
backend ml.Backend
cacheCtx ml.Context
keys, values []ml.Tensor
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
}
type cacheCell struct {
@@ -70,11 +70,23 @@ type cellRange struct {
}
func NewCausalCache(shift shiftFn) *Causal {
return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
return &Causal{
windowSize: math.MaxInt32,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
return &Causal{windowSize: windowSize, shiftFn: shift}
return &Causal{
windowSize: windowSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -103,7 +115,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
c.cacheCtx = backend.NewContext()
}
func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -115,7 +126,9 @@ func (c *Causal) SetConfig(config ml.CacheConfig) {
}
func (c *Causal) Close() {
c.cacheCtx.Close()
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
@@ -239,13 +252,11 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
for i := range c.keys {
if c.keys[i] == nil {
for i, key := range c.keys {
if key == nil {
continue
}
key := c.keys[i]
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
@@ -305,7 +316,7 @@ func (c *Causal) defrag() {
layers++
}
maxMoves := ctx.MaxTensors() / (6 * layers)
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
@@ -377,11 +388,6 @@ func (c *Causal) defrag() {
}
func (c *Causal) SetLayer(layer int) {
if layer >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
}
c.curLayer = layer
}
@@ -433,13 +439,19 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
}
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
if _, ok := c.ctxs[c.curLayer]; !ok {
c.ctxs[c.curLayer] = c.backend.NewContext()
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
}
}

View File

@@ -35,13 +35,17 @@ type EncoderCache struct {
encoderPos int32
// ** cache data storage **
cacheCtx ml.Context
keys, values []ml.Tensor
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
}
func NewEncoderCache() *EncoderCache {
return &EncoderCache{}
return &EncoderCache{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -57,7 +61,7 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
c.cacheCtx = backend.NewContext()
c.backend = backend
}
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
@@ -69,7 +73,9 @@ func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
}
func (c *EncoderCache) Close() {
c.cacheCtx.Close()
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
@@ -80,11 +86,6 @@ func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []in
}
func (c *EncoderCache) SetLayer(layer int) {
if layer >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
}
c.curLayer = layer
}
@@ -104,9 +105,16 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
value = value.Permute(ctx, 1, 2, 0, 3)
}
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
if _, ok := c.ctxs[c.curLayer]; !ok {
c.ctxs[c.curLayer] = c.backend.NewContext()
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
}
if _, ok := c.values[c.curLayer]; !ok {
c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
}
ctx.Forward(