ml: Enable support for flash attention

The GGML flash attention kernel has specific requirements for
padding and permutation. This adds support to the KV cache
for conforming to these requirements so that flash attention
can be enabled.
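
To make the padding requirement concrete, here is a minimal, self-contained Go sketch of rounding a cache's usable length up to the CachePadding the backend advertises. The roundUp helper and the example lengths are assumptions for illustration, not the actual kvcache code.

package main

import "fmt"

// roundUp is a hypothetical helper in the spirit of GGML_PAD: round n up to
// the next multiple of pad.
func roundUp(n, pad int) int {
	return (n + pad - 1) / pad * pad
}

func main() {
	// With flash attention the cache length must be padded to 256 entries;
	// the fallback path only needs 32 (see CacheConfig in the diff below).
	for _, pad := range []int{32, 256} {
		for _, used := range []int{1, 100, 300} {
			fmt.Printf("pad=%3d used=%3d -> padded cache length=%d\n", pad, used, roundUp(used, pad))
		}
	}
}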

Flash attention can be used in the same situations as the llama
engine and is enabled by the user in the same way.
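
As a hedged sketch of the opt-in flow: the user-facing switch in Ollama is typically the OLLAMA_FLASH_ATTENTION environment variable, which ends up as the FlashAttention field on the backend params. The types below mirror the diff but are redefined locally so the example runs on its own; they are stand-ins, not the real ml or ggml packages.

package main

import (
	"fmt"
	"os"
)

// Local stand-ins for ml.BackendParams and the ggml Backend from this commit.
type BackendParams struct {
	FlashAttention bool
}

type Backend struct {
	flashAttention bool
}

func New(params BackendParams) *Backend {
	return &Backend{flashAttention: params.FlashAttention}
}

func main() {
	// Hypothetical plumbing: honor the same user-facing switch the llama
	// engine uses before constructing the backend.
	params := BackendParams{FlashAttention: os.Getenv("OLLAMA_FLASH_ATTENTION") == "1"}
	b := New(params)
	fmt.Println("flash attention enabled:", b.flashAttention)
}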
Author:    Jesse Gross
Date:      2025-02-25 17:24:36 -08:00
Committer: Jesse Gross
Commit:    21aa666a1e
Parent:    ee141cc821

4 changed files with 73 additions and 21 deletions

@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
 })
 
 type Backend struct {
+	flashAttention bool
+
 	meta       *fs.GGML
 	cpus, gpus []Context
 	tensors    map[string]*Context
@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	}
 
 	return &Backend{
-		meta: meta,
-		cpus: cpus,
-		gpus: gpus,
+		flashAttention: params.FlashAttention,
+		meta:           meta,
+		cpus:           cpus,
+		gpus:           gpus,
 		sched: C.ggml_backend_sched_new(
 			(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
 			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
@@ -248,7 +251,11 @@ func (b *Backend) NewContext() ml.Context {
 }
 
 func (b *Backend) CacheConfig() ml.CacheConfig {
-	return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+	if b.flashAttention {
+		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
+	} else {
+		return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+	}
 }
 
 type Context struct {
@@ -705,14 +712,22 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
 	query := t.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 
-	kq := key.MulmatFullPrec(ctx, query)
-	kq = &Tensor{
-		b: t.b,
-		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
-	}
+	if t.b.flashAttention {
+		value = value.Permute(ctx, 0, 2, 1, 3)
 
-	kqv := value.Mulmat(ctx, kq)
-	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
+		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
+		return &Tensor{b: t.b, t: kqv}
+	} else {
+		kq := key.MulmatFullPrec(ctx, query)
+		kq = &Tensor{
+			b: t.b,
+			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
+		}
+
+		kqv := value.Mulmat(ctx, kq)
+		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	}
 }
 
 func (b *Backend) SystemInfo() string {
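
Both branches of ScaledDotProductAttention above compute the same quantity, softmax(scale * K^T q + mask) as weights over V; the flash path simply fuses it into ggml_flash_attn_ext with the op's precision forced to F32. The plain-Go reference below is included only to make that equivalence concrete; the dimensions, values, and scale are arbitrary, and it is not how the GGML graph evaluates attention.

package main

import (
	"fmt"
	"math"
)

// attend computes single-head attention for one query vector q against
// keys/values of shape [seqLen][dim]: softmax(scale*K·q + mask) weighted
// over V. mask entries are 0 (attend) or -Inf (masked), as in the KQ mask.
func attend(q []float32, keys, values [][]float32, mask []float32, scale float32) []float32 {
	scores := make([]float64, len(keys))
	maxScore := math.Inf(-1)
	for i, k := range keys {
		var dot float64
		for d := range q {
			dot += float64(q[d]) * float64(k[d])
		}
		scores[i] = dot*float64(scale) + float64(mask[i])
		if scores[i] > maxScore {
			maxScore = scores[i]
		}
	}

	// Numerically stable softmax over the scaled, masked scores.
	var sum float64
	for i := range scores {
		scores[i] = math.Exp(scores[i] - maxScore)
		sum += scores[i]
	}

	// Weighted sum over the value rows.
	out := make([]float32, len(values[0]))
	for i, v := range values {
		w := scores[i] / sum
		for d := range out {
			out[d] += float32(w * float64(v[d]))
		}
	}
	return out
}

func main() {
	q := []float32{1, 0}
	keys := [][]float32{{1, 0}, {0, 1}}
	values := [][]float32{{1, 2}, {3, 4}}
	mask := []float32{0, float32(math.Inf(-1))} // mask out the second position
	fmt.Println(attend(q, keys, values, mask, 1.0))
}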