ml: Enable support for flash attention

The GGML flash attention kernel has specific requirements for
padding and permutation. This adds support to the KV cache
for conforming to these requirements so that flash attention
can be enabled.

Flash attention can be used in the same situations as the llama
engine and is enabled by the user in the same way.
This commit is contained in:
Jesse Gross
2025-02-25 17:24:36 -08:00
committed by Jesse Gross
parent ee141cc821
commit 21aa666a1e
4 changed files with 73 additions and 21 deletions

View File

@@ -90,6 +90,14 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
c.config.CachePadding = 1
}
if c.config.MaskBatchPadding == 0 {
c.config.MaskBatchPadding = 1
}
if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}
c.DType = dtype
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
@@ -192,13 +200,14 @@ func roundUp(length, pad int) int {
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
// TODO(jessegross): This does not do mask padding, which is required for flash attention
// Align and pad the cache range as required by the backend
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, c.curBatchSize*length)
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
@@ -209,7 +218,24 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
}
}
return ctx.FromFloatSlice(mask, length, c.curBatchSize)
// Mask out any padding tokens we added. For padding that we added to the cache history, this
// has already been masked out because the sequence doesn't match.
for i := c.curBatchSize * length; i < len(mask); i++ {
mask[i] = float32(math.Inf(-1))
}
maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
ctx.Forward(maskTensor.Copy(ctx, out))
maskTensor = out
}
return maskTensor, nil
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {