mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-12 00:37:04 +00:00
The current splitDim function only operates on tensors that are split evenly, which isn't always the case, e.g. a QKV tensor. This change allows the function to be used for arbitrary splits.
305 lines
6.3 KiB
Go
305 lines
6.3 KiB
Go
package convert
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"io"
|
|
"iter"
|
|
"slices"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/pdevine/tensor"
|
|
)
|
|
|
|
type fakeTensor struct {
|
|
name string
|
|
shape []uint64
|
|
data []float32
|
|
|
|
repacker Repacker
|
|
}
|
|
|
|
func (f fakeTensor) Name() string {
|
|
return f.name
|
|
}
|
|
|
|
func (f fakeTensor) Shape() []uint64 {
|
|
return f.shape
|
|
}
|
|
|
|
func (f fakeTensor) Kind() uint32 {
|
|
return 0
|
|
}
|
|
|
|
func (f *fakeTensor) SetRepacker(fn Repacker) {
|
|
f.repacker = fn
|
|
}
|
|
|
|
func (f fakeTensor) Clone() Tensor {
|
|
return &fakeTensor{
|
|
name: f.name,
|
|
shape: slices.Clone(f.shape),
|
|
data: slices.Clone(f.data),
|
|
repacker: f.repacker,
|
|
}
|
|
}
|
|
|
|
func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) {
|
|
data := f.data
|
|
if f.repacker != nil {
|
|
data, err = f.repacker(f.name, data, f.shape)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return int64(len(data) * 4), nil
|
|
}
|
|
|
|
// mul returns the product of every dimension in shape, i.e. the number of
// elements a tensor of that shape holds. An empty shape yields 1.
func mul(shape []uint64) int {
	total := 1
	for i := range shape {
		total *= int(shape[i])
	}
	return total
}
|
|
|
|
func TestSplitDim(t *testing.T) {
|
|
r := fakeTensor{
|
|
name: "a.b",
|
|
shape: []uint64{3, 4},
|
|
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
|
}
|
|
|
|
t.Run("no split", func(t *testing.T) {
|
|
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
|
|
if tt.Name != "x.b" {
|
|
t.Fatalf("expected name 'x', got '%s'", tt.Name)
|
|
}
|
|
|
|
if !slices.Equal(tt.Shape, []uint64{3, 4}) {
|
|
t.Fatalf("expected shape [3, 4], got %v", tt.Shape)
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if _, err := tt.WriteTo(&b); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
f32s := make([]float32, mul(tt.Shape))
|
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) {
|
|
t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("even split", func(t *testing.T) {
|
|
next, stop := iter.Pull(splitDim(&r, 1,
|
|
split{Replacer: strings.NewReplacer("a", "x")},
|
|
split{Replacer: strings.NewReplacer("b", "y")},
|
|
))
|
|
defer stop()
|
|
|
|
{
|
|
tt, ok := next()
|
|
if !ok {
|
|
t.Fatal("expected at least one split")
|
|
}
|
|
|
|
if tt.Name != "x.b" {
|
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
|
}
|
|
|
|
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
|
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if _, err := tt.WriteTo(&b); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
f32s := make([]float32, mul(tt.Shape))
|
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
|
|
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
|
|
}
|
|
}
|
|
|
|
{
|
|
tt, ok := next()
|
|
if !ok {
|
|
t.Fatal("expected at least one split")
|
|
}
|
|
|
|
if tt.Name != "a.y" {
|
|
t.Fatal("expected name 'a.y', got", tt.Name)
|
|
}
|
|
|
|
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
|
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if _, err := tt.WriteTo(&b); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
f32s := make([]float32, mul(tt.Shape))
|
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) {
|
|
t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("uneven split", func(t *testing.T) {
|
|
next, stop := iter.Pull(splitDim(&r, 0,
|
|
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
|
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
|
))
|
|
defer stop()
|
|
|
|
{
|
|
tt, ok := next()
|
|
if !ok {
|
|
t.Fatal("expected at least one split")
|
|
}
|
|
|
|
if tt.Name != "x.b" {
|
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
|
}
|
|
|
|
if !slices.Equal(tt.Shape, []uint64{2, 4}) {
|
|
t.Fatal("expected shape [2, 4], got", tt.Shape)
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if _, err := tt.WriteTo(&b); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
f32s := make([]float32, mul(tt.Shape))
|
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) {
|
|
t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s)
|
|
}
|
|
}
|
|
|
|
{
|
|
tt, ok := next()
|
|
if !ok {
|
|
t.Fatal("expected at least one split")
|
|
}
|
|
|
|
if tt.Name != "a.y" {
|
|
t.Fatal("expected name 'a.y', got", tt.Name)
|
|
}
|
|
|
|
if !slices.Equal(tt.Shape, []uint64{1, 4}) {
|
|
t.Fatal("expected shape [1, 4], got", tt.Shape)
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if _, err := tt.WriteTo(&b); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
f32s := make([]float32, mul(tt.Shape))
|
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !slices.Equal(f32s, []float32{8, 9, 10, 11}) {
|
|
t.Fatal("expected data [8, 9, 10, 11], got", f32s)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("split with transpose", func(t *testing.T) {
|
|
next, stop := iter.Pull(splitDim(&r, 1,
|
|
split{Replacer: strings.NewReplacer("a", "x")},
|
|
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
|
|
return tensor.Transpose(tt, 1, 0)
|
|
}},
|
|
))
|
|
defer stop()
|
|
|
|
{
|
|
tt, ok := next()
|
|
if !ok {
|
|
t.Fatal("expected at least one split")
|
|
}
|
|
|
|
if tt.Name != "x.b" {
|
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
|
}
|
|
|
|
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
|
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if _, err := tt.WriteTo(&b); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
f32s := make([]float32, mul(tt.Shape))
|
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
|
|
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
|
|
}
|
|
}
|
|
|
|
{
|
|
tt, ok := next()
|
|
if !ok {
|
|
t.Fatal("expected at least one split")
|
|
}
|
|
|
|
if tt.Name != "a.y" {
|
|
t.Fatal("expected name 'a.y', got", tt.Name)
|
|
}
|
|
|
|
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
|
|
t.Fatal("expected shape [3, 2], got", tt.Shape)
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if _, err := tt.WriteTo(&b); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
f32s := make([]float32, mul(tt.Shape))
|
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) {
|
|
t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s)
|
|
}
|
|
}
|
|
})
|
|
}
|