update convert test to check result data

This commit is contained in:
Michael Yang
2024-06-03 09:49:13 -07:00
parent c4c84b7a0d
commit 6b252918fb
8 changed files with 924 additions and 37 deletions

View File

@@ -1,29 +1,36 @@
//go:build slow
package convert
import (
"crypto/sha256"
"encoding/json"
"flag"
"fmt"
"io"
"log/slog"
"math"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/llm"
"golang.org/x/exp/maps"
)
func convertFull(t *testing.T, p string) (llm.KV, llm.Tensors) {
func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
t.Helper()
mf, err := GetModelFormat(p)
mf, err := GetModelFormat(d)
if err != nil {
t.Fatal(err)
}
params, err := mf.GetParams(p)
params, err := mf.GetParams(d)
if err != nil {
t.Fatal(err)
}
arch, err := mf.GetModelArch("", p, params)
arch, err := mf.GetModelArch("", d, params)
if err != nil {
t.Fatal(err)
}
@@ -50,53 +57,91 @@ func convertFull(t *testing.T, p string) (llm.KV, llm.Tensors) {
if err != nil {
t.Fatal(err)
}
defer r.Close()
t.Cleanup(func() { r.Close() })
m, _, err := llm.DecodeGGML(r)
m, _, err := llm.DecodeGGML(r, math.MaxInt)
if err != nil {
t.Fatal(err)
}
return m.KV(), m.Tensors()
if _, err := r.Seek(0, io.SeekStart); err != nil {
t.Fatal(err)
}
return r, m.KV(), m.Tensors()
}
func TestMain(m *testing.M) {
var level slog.Level
flag.TextVar(&level, "level", slog.LevelInfo, "log level")
flag.Parse()
slog.SetLogLoggerLevel(level)
os.Exit(m.Run())
}
func TestConvertFull(t *testing.T) {
cases := []struct {
path string
arch string
tensors int
layers int
}{
{"Meta-Llama-3-8B-Instruct", "llama", 291, 35},
{"Mistral-7B-Instruct-v0.2", "llama", 291, 35},
{"Mixtral-8x7B-Instruct-v0.1", "llama", 291, 35},
{"gemma-2b-it", "gemma", 164, 20},
cases := []string{
"Meta-Llama-3-8B-Instruct",
"Mistral-7B-Instruct-v0.2",
"Mixtral-8x7B-Instruct-v0.1",
"gemma-2b-it",
}
for _, tt := range cases {
t.Run(tt.path, func(t *testing.T) {
p := filepath.Join("testdata", tt.path)
if _, err := os.Stat(p); err != nil {
for i := range cases {
tt := cases[i]
t.Run(tt, func(t *testing.T) {
t.Parallel()
p := filepath.Join("testdata", tt)
if testing.Short() {
t.Skip("skipping in short mode")
} else if _, err := os.Stat(p); err != nil {
t.Skipf("%s not found", p)
}
kv, tensors := convertFull(t, p)
f, kv, tensors := convertFull(t, p)
actual := make(map[string]string)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
actual[k] = fmt.Sprintf("%v", v)
} else {
bts, err := json.Marshal(s)
if err != nil {
t.Fatal(err)
}
if kv.Architecture() != tt.arch {
t.Fatalf("expected llama, got %s", kv.Architecture())
actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
}
}
if kv.FileType().String() != "F16" {
t.Fatalf("expected F16, got %s", kv.FileType())
for _, tensor := range tensors.Items {
sha256sum := sha256.New()
sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
if _, err := io.Copy(sha256sum, sr); err != nil {
t.Fatal(err)
}
actual[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
}
if len(tensors) != tt.tensors {
t.Fatalf("expected %d tensors, got %d", tt.tensors, len(tensors))
expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
if err != nil {
t.Fatal(err)
}
layers := tensors.Layers()
if len(layers) != tt.layers {
t.Fatalf("expected %d layers, got %d", tt.layers, len(layers))
var expect map[string]string
if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
t.Fatal(err)
}
keys := maps.Keys(expect)
slices.Sort(keys)
for _, k := range keys {
if v, ok := actual[k]; !ok {
t.Errorf("missing %s", k)
} else if v != expect[k] {
t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
}
}
})
}