update graph size estimate

This commit is contained in:
Michael Yang
2024-04-02 11:15:14 -07:00
parent cd135317d2
commit 12e923e158
2 changed files with 52 additions and 4 deletions

View File

@@ -303,3 +303,50 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
model: model,
}, offset, nil
}
// GraphSize estimates the compute graph size for the given context length and
// batch size using a formula selected by the model architecture ("gemma",
// "phi2", "qwen2" or "llama").
//
// The boolean result reports whether an estimate is available. It is false
// when the architecture has no formula here, or when the selected formula
// needs the vocabulary size and the "tokenizer.ggml.tokens" metadata entry is
// missing or is not a []any.
//
// NOTE(review): the unit is presumably bytes (the leading factor 4 looks like
// the size of a 32-bit float) — confirm against callers.
func (llm GGML) GraphSize(context, batch int) (int64, bool) {
	kv := llm.KV()
	embeddingLength := kv.EmbeddingLength()
	headCount := kv.HeadCount()
	headCountKV := kv.HeadCountKV()

	// Use the comma-ok form: a bare assertion panics when the key is absent
	// or holds a different type, even for architectures whose formula never
	// looks at the vocabulary size.
	tokens, hasVocab := kv["tokenizer.ggml.tokens"].([]any)
	vocabLength := uint64(len(tokens))

	// Single pass over the tensors, recording the second dimension of the
	// first fused QKV weight and of the first ".ffn_gate" tensor that have at
	// least two dimensions.
	var attnQKVWeight1, ffnGate1 uint64
	var foundQKV, foundGate bool
	for _, t := range llm.Tensors() {
		if len(t.Shape) < 2 {
			continue
		}
		if !foundQKV && strings.HasSuffix(t.Name, ".attn_qkv.weight") {
			attnQKVWeight1, foundQKV = t.Shape[1], true
		}
		// "> 0" (not ">= 0") is deliberate: a name that starts with
		// ".ffn_gate" does not count as a match.
		if !foundGate && strings.Index(t.Name, ".ffn_gate") > 0 {
			ffnGate1, foundGate = t.Shape[1], true
		}
		if foundQKV && foundGate {
			break
		}
	}

	scale := 4 * int64(batch)
	contextLength := uint64(context)
	// Term shared by the gemma, phi2 and qwen2 formulas.
	outputSize := scale * int64(embeddingLength+vocabLength)

	switch kv.Architecture() {
	case "gemma":
		if !hasVocab {
			return 0, false
		}
		return outputSize, true
	case "phi2":
		if !hasVocab {
			return 0, false
		}
		return max(
			outputSize,
			scale*int64(1+4*embeddingLength+contextLength+attnQKVWeight1+contextLength*headCount),
		), true
	case "qwen2":
		if !hasVocab {
			return 0, false
		}
		return max(
			outputSize,
			scale*int64(1+2*embeddingLength+contextLength+contextLength*headCount),
		), true
	case "llama":
		// NOTE(review): any ".ffn_gate" tensor with two or more dimensions
		// selects the MoE formula — confirm dense llama models do not carry
		// a tensor that also matches.
		if ffnGate1 > 0 {
			// moe
			return scale * int64(2+3*embeddingLength+contextLength+contextLength*headCount+2*headCountKV+ffnGate1), true
		}
		return scale * int64(1+4*embeddingLength+contextLength+contextLength*headCount), true
	}

	return 0, false
}