diff --git a/convert/convert_mixtral.go b/convert/convert_mixtral.go
index 17580ff8..7d60146b 100644
--- a/convert/convert_mixtral.go
+++ b/convert/convert_mixtral.go
@@ -2,9 +2,6 @@ package convert
 
 import (
 	"fmt"
-	"io"
-	"slices"
-	"strings"
 
 	"github.com/ollama/ollama/fs/ggml"
 )
@@ -30,65 +27,39 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
 }
 
 func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
-	oldnew := []string{
-		"model.layers", "blk",
-		"w1", "ffn_gate_exps",
-		"w2", "ffn_down_exps",
-		"w3", "ffn_up_exps",
-	}
-
-	for i := range p.NumLocalExperts {
-		oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
-	}
-
-	// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
-	namer := strings.NewReplacer(oldnew...)
-	experts := make(map[string]experts)
-
-	// merge experts into a single tensor while removing them from ts
-	ts = slices.DeleteFunc(ts, func(t Tensor) bool {
-		if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
-			return false
-		}
-
-		name := namer.Replace(t.Name())
-		experts[name] = append(experts[name], t)
-		return true
-	})
-
-	var out []*ggml.Tensor
-	for n, e := range experts {
-		// TODO(mxyng): sanity check experts
-		out = append(out, &ggml.Tensor{
-			Name:     n,
-			Kind:     e[0].Kind(),
-			Shape:    append([]uint64{uint64(len(e))}, e[0].Shape()...),
-			WriterTo: e,
+	// experts follow llama naming: w1 is the gate, w2 the down, w3 the up projection
+	merges := make([]merge, 0, p.NumHiddenLayers*6)
+	for i := range p.NumHiddenLayers {
+		merges = append(merges, merge{
+			fmt.Sprintf("blk.%d.*.w1.weight", i),
+			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
+		}, merge{
+			fmt.Sprintf("blk.%d.*.w1.bias", i),
+			fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i),
+		}, merge{
+			fmt.Sprintf("blk.%d.*.w2.weight", i),
+			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
+		}, merge{
+			fmt.Sprintf("blk.%d.*.w2.bias", i),
+			fmt.Sprintf("blk.%d.ffn_down_exps.bias", i),
+		}, merge{
+			fmt.Sprintf("blk.%d.*.w3.weight", i),
+			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
+		}, merge{
+			fmt.Sprintf("blk.%d.*.w3.bias", i),
+			fmt.Sprintf("blk.%d.ffn_up_exps.bias", i),
 		})
 	}
 
+	out, ts := mergeTensors(ts, merges...)
 	return append(out, p.llamaModel.Tensors(ts)...)
 }
 
 func (p *mixtralModel) Replacements() []string {
 	return append(
 		p.llamaModel.Replacements(),
+		"model.layers", "blk",
 		"block_sparse_moe.gate", "ffn_gate_inp",
+		"block_sparse_moe.experts.", ".",
 	)
 }
-
-type experts []Tensor
-
-func (e experts) WriteTo(w io.Writer) (int64, error) {
-	// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
-	for _, t := range e {
-		// the canonical merged experts tensor stacks all experts along a new, 0 axis,
-		// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
-		// this accomplishes the same thing by writing each expert tensor in sequence
-		if _, err := t.WriteTo(w); err != nil {
-			return 0, err
-		}
-	}
-
-	return 0, nil
-}
diff --git a/convert/tensor.go b/convert/tensor.go
index 9d6919e3..c9565ed4 100644
--- a/convert/tensor.go
+++ b/convert/tensor.go
@@ -2,7 +2,9 @@ package convert
 
 import (
 	"cmp"
+	"io"
 	"iter"
+	"path"
 	"slices"
 	"strings"
 
@@ -74,3 +76,60 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
 		}
 	}
 }
+
+type merge struct {
+	pattern, name string
+}
+
+// mergeTensors merges the tensors matching each pattern into a single tensor.
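+// Patterns use path.Match syntax, so '*' matches any run of characters
+// except '/', letting one pattern capture every expert index in a layer.
+// Matched tensors are stacked in input order along a new leading axis;
+// tensors matching no pattern are returned unchanged.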
+func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
+	var matched []Tensor
+	for i := range merges {
+		matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
+			ok, _ := path.Match(merges[i].pattern, t.Name())
+			return ok
+		})
+
+		if len(matched) > 0 {
+			out = append(out, &ggml.Tensor{
+				Name:     merges[i].name,
+				Kind:     matched[0].Kind(),
+				Shape:    append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
+				WriterTo: mergeGroup(matched),
+			})
+		}
+	}
+
+	return out, unmatched
+}
+
+// slicesSplitFunc splits a slice into two slices based on a predicate function.
+func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
+	for _, e := range s {
+		if fn(e) {
+			matched = append(matched, e)
+		} else {
+			unmatched = append(unmatched, e)
+		}
+	}
+
+	return matched, unmatched
+}
+
+// mergeGroup writes its tensors back to back, stacking them along a new
+// leading axis without allocating an intermediate buffer.
+type mergeGroup []Tensor
+
+func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
+	for _, t := range g {
+		if _, err := t.WriteTo(w); err != nil {
+			return 0, err
+		}
+	}
+
+	return 0, nil
+}
diff --git a/convert/tensor_test.go b/convert/tensor_test.go
index ea12d0f5..0b2db5ba 100644
--- a/convert/tensor_test.go
+++ b/convert/tensor_test.go
@@ -9,6 +9,8 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/google/go-cmp/cmp"
+
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/pdevine/tensor"
 )
@@ -302,3 +304,100 @@ func TestSplitDim(t *testing.T) {
 		}
 	})
 }
+
+func TestMerge(t *testing.T) {
+	unmatched := []Tensor{
+		&fakeTensor{
+			name:  "a.0.b",
+			shape: []uint64{5, 2},
+			data:  []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
+		},
+		&fakeTensor{
+			name:  "a.1.b",
+			shape: []uint64{5, 2},
+			data:  []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
+		},
+		&fakeTensor{
+			name:  "c.0.d",
+			shape: []uint64{5, 2},
+			data:  []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
+		},
+		&fakeTensor{
+			name:  "c.1.d",
+			shape: []uint64{5, 2},
+			data:  []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
+		},
+		&fakeTensor{
+			name:  "e.0.f",
+			shape: []uint64{5, 2},
+			data:  []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
+		},
+	}
+
+	checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
+		for i := range n {
+			got := matched[i]
+			if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
+				t.Errorf("unexpected shape (-want +got):\n%s", diff)
+			}
+
+			var b bytes.Buffer
+			if _, err := got.WriteTo(&b); err != nil {
+				t.Fatal(err)
+			}
+
+			f32s := make([]float32, 20)
+			if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
+				t.Fatal(err)
+			}
+
+			// merged tensor i holds the 20 consecutive values starting at 10+20*i
+			offset := 10 + (i * 20)
+			want := make([]float32, 20)
+			for j := range 20 {
+				want[j] = float32(offset + j)
+			}
+
+			if diff := cmp.Diff(want, f32s); diff != "" {
+				t.Errorf("unexpected data (-want +got):\n%s", diff)
+			}
+		}
+	}
+
+	t.Run("single merge", func(t *testing.T) {
+		matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
+		if len(unmatched) != 3 {
+			t.Error("expected 3 remaining tensors, got", len(unmatched))
+		}
+
+		if len(matched) != 1 {
+			t.Error("expected 1 merged tensor, got", len(matched))
+		}
+
+		checkMatched(t, 1, matched)
+	})
+
+	t.Run("multiple merges", func(t *testing.T) {
+		matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
+		if len(unmatched) != 1 {
+			t.Error("expected 1 remaining tensor, got", len(unmatched))
+		}
+
+		if len(matched) != 2 {
+			t.Error("expected 2 merged tensors, got", len(matched))
+		}
+
+		checkMatched(t, 2, matched)
+	})
+
+	t.Run("no match", func(t *testing.T) {
+		matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
mergeTensors(unmatched, merge{"x.*.y", "x.y"}) + if len(unmatched) != 5 { + t.Error("expected 5 remaining tensors, got", len(unmatched)) + } + + if len(matched) != 0 { + t.Error("expected no merged tensors, got", len(matched)) + } + }) +}
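
For context, a minimal standalone sketch of the path.Match behavior the new merge patterns rely on. The tensor names below are illustrative stand-ins for post-Replacements names, not values taken from a real checkpoint:

package main

import (
	"fmt"
	"path"
)

func main() {
	// path.Match treats '.' as an ordinary character; only '/' separates
	// path elements, so "blk.0.*.w1.weight" matches any expert index.
	for _, name := range []string{
		"blk.0..0.w1.weight", // hypothetical post-Replacements expert name
		"blk.0..7.w1.weight",
		"blk.1..0.w1.weight", // different layer, so no match
	} {
		ok, _ := path.Match("blk.0.*.w1.weight", name)
		fmt.Printf("%-20s %v\n", name, ok)
	}
}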