Mirror of https://github.com/dogkeeper886/ollama37.git
ml: Enable support for flash attention
The GGML flash attention kernel has specific requirements for padding and permutation. This adds support to the KV cache for conforming to these requirements so that flash attention can be enabled. Flash attention can be used in the same situations as the llama engine and is enabled by the user in the same way.
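As background for the padding requirement (a minimal sketch, not code from this commit): both the cache length and the mask batch dimension have to be rounded up to the multiples the kernel expects, and the diff below plumbs those multiples through CacheConfig. The rounding itself is plain integer arithmetic, shown here with a hypothetical roundUp helper:

    // roundUp pads n up to the next multiple of m. For example, a 1030-token
    // context becomes 1056 cells under the non-flash padding of 32, but
    // 1280 cells under the flash attention padding of 256 used below.
    func roundUp(n, m int) int {
        return ((n + m - 1) / m) * m
    }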
@@ -46,6 +46,14 @@ type CacheConfig struct {
 	// and return the permuted version via Get. This uses the cache copy operation
 	// to avoid a Contiguous call on the permuted tensor.
 	PermutedV bool
+
+	// MaskDType specifies the data type for generating the mask. If unset it will
+	// default to DTypeF32.
+	MaskDType DType
+
+	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
+	// Any position that does not correspond to an actual token will be filled with -Inf.
+	MaskBatchPadding int
 }
 
 // BackendParams controls how the backend loads and executes models
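To illustrate the MaskBatchPadding comment above (a hypothetical sketch assuming a buildMask helper and an import of "math"; this is not the actual cache code): the mask's batch dimension is rounded up to the padding multiple, and every row or position that does not correspond to a real token is filled with -Inf so the softmax gives it zero weight. A real implementation would also convert the result to MaskDType, i.e. F16 when flash attention is enabled.

    // buildMask returns a (paddedBatch x numCells) mask where padding rows
    // and any position rejected by the masked callback are set to -Inf.
    // Assumes batchPadding >= 1.
    func buildMask(batchSize, numCells, batchPadding int, masked func(i, j int) bool) []float32 {
        paddedBatch := ((batchSize + batchPadding - 1) / batchPadding) * batchPadding
        mask := make([]float32, paddedBatch*numCells)
        for i := 0; i < paddedBatch; i++ {
            for j := 0; j < numCells; j++ {
                if i >= batchSize || masked(i, j) {
                    mask[i*numCells+j] = float32(math.Inf(-1))
                }
            }
        }
        return mask
    }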
@@ -61,6 +69,9 @@ type BackendParams struct {
 
 	// TensorSplit is the fraction of the model to offload to each GPU
 	TensorSplit []float32
+
+	// FlashAttention indicates that we should use a fused flash attention kernel
+	FlashAttention bool
 }
 
 var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
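The new FlashAttention field travels through BackendParams just like the existing fields such as TensorSplit. A hypothetical usage sketch (the "ggml" registry key, the helper name, and the error handling are assumptions, not part of this diff; assumes imports of "os" and "errors"):

    // loadBackend looks up a registered constructor and passes the user's
    // flash attention preference along with the rest of the parameters.
    func loadBackend(f *os.File, useFlashAttention bool) (Backend, error) {
        newBackend, ok := backends["ggml"]
        if !ok {
            return nil, errors.New("ggml backend not registered")
        }
        return newBackend(f, BackendParams{
            TensorSplit:    []float32{1},
            FlashAttention: useFlashAttention, // set when the user opts in, as with the llama engine
        })
    }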
@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
 })
 
 type Backend struct {
+	flashAttention bool
+
 	meta       *fs.GGML
 	cpus, gpus []Context
 	tensors    map[string]*Context
@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	}
 
 	return &Backend{
-		meta: meta,
-		cpus: cpus,
-		gpus: gpus,
+		flashAttention: params.FlashAttention,
+		meta:           meta,
+		cpus:           cpus,
+		gpus:           gpus,
 		sched: C.ggml_backend_sched_new(
 			(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
 			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
@@ -248,7 +251,11 @@ func (b *Backend) NewContext() ml.Context {
 }
 
 func (b *Backend) CacheConfig() ml.CacheConfig {
-	return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+	if b.flashAttention {
+		return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
+	} else {
+		return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+	}
 }
 
 type Context struct {
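The two branches encode the kernel requirements from the commit message: with flash attention the cache is padded to a multiple of 256 cells, the mask is generated in F16 with its batch dimension padded to GGML_KQ_MASK_PAD, and PermutedV is dropped, apparently because the flash path permutes V at call time (see the ScaledDotProductAttention change below); without it the old 32-cell padding and permuted V layout are kept. A hypothetical consumer of the returned config (assuming ml.Backend exposes CacheConfig(), which the method above implements; this is not the actual kvcache code):

    // cacheCells rounds the requested context length up to whatever
    // multiple the active attention path requires.
    func cacheCells(b ml.Backend, contextLen int) int {
        cfg := b.CacheConfig()
        return ((contextLen + cfg.CachePadding - 1) / cfg.CachePadding) * cfg.CachePadding
    }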
@@ -705,14 +712,22 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
 	query := t.Permute(ctx, 0, 2, 1, 3)
 	key = key.Permute(ctx, 0, 2, 1, 3)
 
-	kq := key.MulmatFullPrec(ctx, query)
-	kq = &Tensor{
-		b: t.b,
-		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
-	}
-
-	kqv := value.Mulmat(ctx, kq)
-	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	if t.b.flashAttention {
+		value = value.Permute(ctx, 0, 2, 1, 3)
+
+		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
+		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
+		return &Tensor{b: t.b, t: kqv}
+	} else {
+		kq := key.MulmatFullPrec(ctx, query)
+		kq = &Tensor{
+			b: t.b,
+			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
+		}
+
+		kqv := value.Mulmat(ctx, kq)
+		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	}
 }
 
 func (b *Backend) SystemInfo() string {
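Both branches of ScaledDotProductAttention compute the same softmax(Q·K^T·scale + mask)·V result; the flash branch fuses it into a single ggml_flash_attn_ext call pinned to F32 precision, while the fallback materializes KQ, applies the softmax, and multiplies by V separately. A hypothetical call site in a model's attention block (the attend helper, headDim, and the import of "math" are illustrative; only the method signature comes from the diff above, assuming ml.Tensor exposes ScaledDotProductAttention as the receiver here implements):

    // attend applies the usual 1/sqrt(headDim) scaling and lets the backend
    // pick the fused or fallback attention path.
    func attend(ctx ml.Context, query, key, value, mask ml.Tensor, headDim int) ml.Tensor {
        scale := 1.0 / math.Sqrt(float64(headDim))
        return query.ScaledDotProductAttention(ctx, key, value, mask, scale)
    }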