refactor tensor query

This commit is contained in:
Michael Yang
2024-04-03 15:00:31 -07:00
parent c5c451ca3b
commit 8b2c10061c
4 changed files with 54 additions and 42 deletions

View File

@@ -13,16 +13,6 @@ type GGML struct {
model
}
func (ggml *GGML) LayerSize(prefix string) (n int64) {
for _, t := range ggml.Tensors() {
if strings.HasPrefix(t.Name, prefix) {
n += int64(t.size())
}
}
return
}
const (
fileTypeF32 uint32 = iota
fileTypeF16
@@ -101,7 +91,7 @@ func fileType(fileType uint32) string {
type model interface {
KV() KV
Tensors() []*Tensor
Tensors() Tensors
}
type KV map[string]any
@@ -167,6 +157,36 @@ func (kv KV) ContextLength() uint64 {
return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
}
type Tensors []*Tensor
func (ts Tensors) Layers() map[string]Layer {
layers := make(map[string]Layer)
for _, t := range ts {
parts := strings.Split(t.Name, ".")
if parts[0] == "blk" {
parts = parts[1:]
}
if _, ok := layers[parts[0]]; !ok {
layers[parts[0]] = make(Layer)
}
layers[parts[0]][strings.Join(parts[1:], ".")] = t
}
return layers
}
type Layer map[string]*Tensor
func (l Layer) size() (size uint64) {
for _, t := range l {
size += t.size()
}
return size
}
type Tensor struct {
Name string `json:"name"`
Kind uint32 `json:"kind"`
@@ -310,20 +330,16 @@ func (llm GGML) GraphSize(context, batch int) (int64, bool) {
headCountKV := llm.KV().HeadCountKV()
vocabLength := len(llm.KV()["tokenizer.ggml.tokens"].([]any))
layers := llm.Tensors().Layers()
var attnQKVWeight1 uint64 = 0
for _, t := range llm.Tensors() {
if strings.HasSuffix(t.Name, ".attn_qkv.weight") && len(t.Shape) >= 2 {
attnQKVWeight1 = t.Shape[1]
break
}
if t, ok := layers["0"]["attn_qkv.weight"]; ok && len(t.Shape) > 2 {
attnQKVWeight1 = t.Shape[1]
}
var ffnGate1 uint64 = 0
for _, t := range llm.Tensors() {
if strings.Index(t.Name, ".ffn_gate") > 0 && len(t.Shape) >= 2 {
ffnGate1 = t.Shape[1]
break
}
var ffnGate0Weight1 uint64 = 0
if t, ok := layers["0"]["ffn_gate.0.weight"]; ok && len(t.Shape) > 2 {
ffnGate0Weight1 = t.Shape[1]
}
switch llm.KV().Architecture() {
@@ -340,11 +356,11 @@ func (llm GGML) GraphSize(context, batch int) (int64, bool) {
4*int64(batch)*int64(1+2*embeddingLength+uint64(context)+uint64(context)*headCount),
), true
case "llama":
if ffnGate1 > 0 {
if ffnGate0Weight1 > 0 {
// moe
return 4 * int64(batch) * int64(2+3*embeddingLength+uint64(context)+uint64(context)*headCount+2*headCountKV+ffnGate1), true
return 4 * int64(batch) * int64(2+3*embeddingLength+uint64(context)+uint64(context)*headCount+2*headCountKV+ffnGate0Weight1), true
}
return 4 * int64(batch) * int64(1+4*embeddingLength+uint64(context)+uint64(context)*headCount), true
}