llm: introduce k/v context quantization (vRAM improvements) (#6279)

This commit is contained in:
Sam
2024-12-04 10:57:19 +11:00
committed by GitHub
parent 2b82c5a8a1
commit 1bdab9fdb1
10 changed files with 147 additions and 21 deletions

View File

@@ -360,7 +360,7 @@ func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
}, offset, nil
}
func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffload uint64) {
func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
embedding := llm.KV().EmbeddingLength()
heads := llm.KV().HeadCount()
headsKV := llm.KV().HeadCountKV()
@@ -372,7 +372,8 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
layers := llm.Tensors().Layers()
kv = 2 * context * llm.KV().BlockCount() * (embeddingHeadsK + embeddingHeadsV) * headsKV
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
switch llm.KV().Architecture() {
case "llama":
@@ -527,3 +528,34 @@ func (llm GGML) GraphSize(context, batch uint64) (kv, partialOffload, fullOffloa
return
}
// SupportsKVCacheType checks if the requested cache type is supported
func (ggml GGML) SupportsKVCacheType(cacheType string) bool {
validKVCacheTypes := []string{"f16", "q8_0", "q4_0"}
return slices.Contains(validKVCacheTypes, cacheType)
}
// SupportsFlashAttention checks if the model supports flash attention
func (ggml GGML) SupportsFlashAttention() bool {
_, isEmbedding := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]
if isEmbedding {
return false
}
// Check head counts match and are non-zero
headCountK := ggml.KV().EmbeddingHeadCountK()
headCountV := ggml.KV().EmbeddingHeadCountV()
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
func kvCacheBytesPerElement(cacheType string) float64 {
switch cacheType {
case "q8_0":
return 1 // 1/2 of fp16
case "q4_0":
return 0.5 // 1/4 of fp16
default:
return 2 // f16 (default)
}
}