Move quantization to new backend (#10363)

* Move quantization logic to GGML via new backend

This moves the model aware logic to Go code and calls GGMLs quantization code for model creation.

* Remove "add model quantizations"

This is no longer needed now that quantization is implemented in Go+GGML code directly.
This commit is contained in:
Daniel Hiltgen
2025-05-06 11:20:48 -07:00
committed by GitHub
parent 95e744beeb
commit 424810450f
39 changed files with 1854 additions and 440 deletions

View File

@@ -36,12 +36,12 @@ func (kv KV) ParameterCount() uint64 {
return keyValue(kv, "general.parameter_count", uint64(0))
}
func (kv KV) FileType() fileType {
func (kv KV) FileType() FileType {
if t := kv.Uint("general.file_type"); t > 0 {
return fileType(t)
return FileType(t)
}
return fileTypeUnknown
return FileTypeUnknown
}
func (kv KV) BlockCount() uint64 {
@@ -226,7 +226,11 @@ func (t Tensor) block() (n int) {
}
func (t Tensor) blockSize() uint64 {
switch t.Kind {
return (TensorType)(t.Kind).BlockSize()
}
func (t TensorType) BlockSize() uint64 {
switch t {
case
0, // F32
1, // F16
@@ -252,73 +256,77 @@ func (t Tensor) blockSize() uint64 {
}
func (t Tensor) typeSize() uint64 {
blockSize := t.blockSize()
return TensorType(t.Kind).TypeSize()
}
switch t.Kind {
case 0: // FP32
func (t TensorType) TypeSize() uint64 {
blockSize := t.BlockSize()
switch t {
case TensorTypeF32:
return 4
case 1: // FP16
case TensorTypeF16:
return 2
case 2: // Q4_0
case TensorTypeQ4_0:
return 2 + blockSize/2
case 3: // Q4_1
case TensorTypeQ4_1:
return 2 + 2 + blockSize/2
case 6: // Q5_0
case TensorTypeQ5_0:
return 2 + 4 + blockSize/2
case 7: // Q5_1
case TensorTypeQ5_1:
return 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
case TensorTypeQ8_0:
return 2 + blockSize
case 9: // Q8_1
case TensorTypeQ8_1:
return 2 + 2 + blockSize
case 10: // Q2_K
case TensorTypeQ2_K:
return blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
case TensorTypeQ3_K:
return blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
case TensorTypeQ4_K:
return 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
case TensorTypeQ5_K:
return 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
case TensorTypeQ6_K:
return blockSize/2 + blockSize/4 + blockSize/16 + 2
case 15: // Q8_K
case TensorTypeQ8_K:
return 4 + blockSize + 2*blockSize/16
case 16: // IQ2_XXS
case tensorTypeIQ2_XXS:
return 2 + 2*blockSize/8
case 17: // IQ2_XS
case tensorTypeIQ2_XS:
return 2 + 2*blockSize/8 + blockSize/32
case 18: // IQ3_XXS
case tensorTypeIQ3_XXS:
return 2 + blockSize/4 + blockSize/8
case 19: // IQ1_S
case tensorTypeIQ1_S:
return 2 + blockSize/8 + blockSize/16
case 20: // IQ4_NL
case tensorTypeIQ4_NL:
return 2 + blockSize/2
case 21: // IQ3_S
case tensorTypeIQ3_S:
return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
case 22: // IQ2_S
case tensorTypeIQ2_S:
return 2 + blockSize/4 + blockSize/16
case 23: // IQ4_XS
case tensorTypeIQ4_XS:
return 2 + 2 + blockSize/2 + blockSize/64
case 24: // I8
case TensorTypeI8:
return 1
case 25: // I16
case TensorTypeI16:
return 2
case 26: // I32
case TensorTypeI32:
return 4
case 27: // I64
case TensorTypeI64:
return 8
case 28: // F64
case TensorTypeF64:
return 8
case 29: // IQ1_M
case tensorTypeIQ1_M:
return blockSize/8 + blockSize/16 + blockSize/32
case 30: // BF16
case TensorTypeBF16:
return 2
default:
return 0
}
}
func (t Tensor) parameters() uint64 {
func (t Tensor) Elements() uint64 {
var count uint64 = 1
for _, n := range t.Shape {
count *= n
@@ -327,11 +335,11 @@ func (t Tensor) parameters() uint64 {
}
func (t Tensor) Size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize()
return t.Elements() * t.typeSize() / t.blockSize()
}
func (t Tensor) Type() string {
return fileType(t.Kind).String()
return TensorType(t.Kind).String()
}
type container interface {
@@ -480,7 +488,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
var ropeFreqsCount uint64
if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
ropeFreqsCount = ropeFreqsWeights.parameters()
ropeFreqsCount = ropeFreqsWeights.Elements()
}
}