gguf: fix write order (#11068)

* ggml: test write gguf order
* ggml: fix write tensor order
This commit is contained in:
Michael Yang
2025-06-16 10:42:32 -07:00
committed by GitHub
parent 502028968d
commit a6fbfc880c
2 changed files with 68 additions and 54 deletions

View File

@@ -527,23 +527,17 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err return err
} }
keys := slices.Collect(maps.Keys(kv)) for _, key := range slices.Sorted(maps.Keys(kv)) {
slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(f, key, kv[key]); err != nil { if err := ggufWriteKV(f, key, kv[key]); err != nil {
return err return err
} }
} }
slices.SortStableFunc(ts, func(a, b *Tensor) int { slices.SortStableFunc(ts, func(a, b *Tensor) int {
if i, j := a.block(), b.block(); i < 0 && j > 0 { if i, j := a.block(), b.block(); i > 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
return -1
} else {
return cmp.Compare(i, j) return cmp.Compare(i, j)
} }
return cmp.Compare(a.Name, b.Name)
}) })
var s uint64 var s uint64

View File

@@ -2,62 +2,82 @@ package ggml
import ( import (
"bytes" "bytes"
"math/rand/v2"
"os" "os"
"slices" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
func TestWriteGGUF(t *testing.T) { func TestWriteGGUF(t *testing.T) {
w, err := os.CreateTemp(t.TempDir(), "*.bin") r := rand.New(rand.NewPCG(0, 0))
if err != nil { for range 8 {
t.Fatal(err) t.Run("shuffle", func(t *testing.T) {
} t.Parallel()
defer w.Close()
if err := WriteGGUF(w, KV{ ts := []*Tensor{
"general.alignment": uint32(16), {Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
}, []*Tensor{ {Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, {Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, {Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, {Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, {Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, {Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))}, {Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
}); err != nil { {Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
t.Fatal(err) }
}
r, err := os.Open(w.Name()) r.Shuffle(len(ts), func(i, j int) {
if err != nil { ts[i], ts[j] = ts[j], ts[i]
t.Fatal(err) })
}
defer r.Close()
ff, err := Decode(r, 0) w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer w.Close()
if diff := cmp.Diff(ff.KV(), KV{ if err := WriteGGUF(w, KV{
"general.alignment": uint32(16), "general.alignment": uint32(16),
"general.parameter_count": uint64(36), }, ts); err != nil {
}); diff != "" { t.Fatal(err)
t.Errorf("Mismatch (-want +got):\n%s", diff) }
}
if diff := cmp.Diff(ff.Tensors(), Tensors{ r, err := os.Open(w.Name())
Offset: 336, if err != nil {
items: []*Tensor{ t.Fatal(err)
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}}, }
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}}, defer r.Close()
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}}, ff, err := Decode(r, 0)
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}}, if err != nil {
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}}, t.Fatal(err)
}, }
}, cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff) if diff := cmp.Diff(KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(Tensors{
Offset: 608,
items: []*Tensor{
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
},
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
})
} }
} }