gguf: fix write order (#11068)

* ggml: test write gguf order
* ggml: fix write tensor order
This commit is contained in:
Michael Yang
2025-06-16 10:42:32 -07:00
committed by GitHub
parent 502028968d
commit a6fbfc880c
2 changed files with 68 additions and 54 deletions

View File

@@ -527,23 +527,17 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err return err
} }
keys := slices.Collect(maps.Keys(kv)) for _, key := range slices.Sorted(maps.Keys(kv)) {
slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(f, key, kv[key]); err != nil { if err := ggufWriteKV(f, key, kv[key]); err != nil {
return err return err
} }
} }
slices.SortStableFunc(ts, func(a, b *Tensor) int { slices.SortStableFunc(ts, func(a, b *Tensor) int {
if i, j := a.block(), b.block(); i < 0 && j > 0 { if i, j := a.block(), b.block(); i > 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
return -1
} else {
return cmp.Compare(i, j) return cmp.Compare(i, j)
} }
return cmp.Compare(a.Name, b.Name)
}) })
var s uint64 var s uint64

View File

@@ -2,15 +2,37 @@ package ggml
import ( import (
"bytes" "bytes"
"math/rand/v2"
"os" "os"
"slices" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
func TestWriteGGUF(t *testing.T) { func TestWriteGGUF(t *testing.T) {
w, err := os.CreateTemp(t.TempDir(), "*.bin") r := rand.New(rand.NewPCG(0, 0))
for range 8 {
t.Run("shuffle", func(t *testing.T) {
t.Parallel()
ts := []*Tensor{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
}
r.Shuffle(len(ts), func(i, j int) {
ts[i], ts[j] = ts[j], ts[i]
})
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -18,14 +40,7 @@ func TestWriteGGUF(t *testing.T) {
if err := WriteGGUF(w, KV{ if err := WriteGGUF(w, KV{
"general.alignment": uint32(16), "general.alignment": uint32(16),
}, []*Tensor{ }, ts); err != nil {
{Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -40,24 +55,29 @@ func TestWriteGGUF(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if diff := cmp.Diff(ff.KV(), KV{ if diff := cmp.Diff(KV{
"general.alignment": uint32(16), "general.alignment": uint32(16),
"general.parameter_count": uint64(36), "general.parameter_count": uint64(54),
}); diff != "" { }, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff) t.Errorf("Mismatch (-want +got):\n%s", diff)
} }
if diff := cmp.Diff(ff.Tensors(), Tensors{ if diff := cmp.Diff(Tensors{
Offset: 336, Offset: 608,
items: []*Tensor{ items: []*Tensor{
{Name: "test.0", Offset: 0, Shape: []uint64{2, 3}}, {Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "test.1", Offset: 32, Shape: []uint64{2, 3}}, {Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}}, {Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "test.3", Offset: 96, Shape: []uint64{2, 3}}, {Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "test.4", Offset: 128, Shape: []uint64{2, 3}}, {Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "test.5", Offset: 160, Shape: []uint64{2, 3}}, {Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
}, },
}, cmp.AllowUnexported(Tensors{})); diff != "" { }, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff) t.Errorf("Mismatch (-want +got):\n%s", diff)
} }
})
}
} }