partial offloading
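Changes GraphSize to return separate partialOffload and fullOffload byte estimates per architecture (as two uint64 values) instead of a single (int64, bool) estimate, so callers can budget compute-graph memory for partial versus full GPU offload.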
llm/ggml.go (77 changed lines)
@@ -324,45 +324,52 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 	}, offset, nil
 }
 
-func (llm GGML) GraphSize(context, batch int) (int64, bool) {
-	embeddingLength := llm.KV().EmbeddingLength()
-	headCount := llm.KV().HeadCount()
-	headCountKV := llm.KV().HeadCountKV()
-	vocabLength := len(llm.KV()["tokenizer.ggml.tokens"].([]any))
-
-	layers := llm.Tensors().Layers()
-
-	var attnQKVWeight1 uint64 = 0
-	if t, ok := layers["0"]["attn_qkv.weight"]; ok && len(t.Shape) > 2 {
-		attnQKVWeight1 = t.Shape[1]
-	}
-
-	var ffnGate0Weight1 uint64 = 0
-	if t, ok := layers["0"]["ffn_gate.0.weight"]; ok && len(t.Shape) > 2 {
-		ffnGate0Weight1 = t.Shape[1]
-	}
+func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
+	embedding := llm.KV().EmbeddingLength()
+	heads := llm.KV().HeadCount()
+	headsKV := llm.KV().HeadCountKV()
+	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))
 
 	switch llm.KV().Architecture() {
-	case "gemma", "command-r":
-		return 4 * int64(batch) * int64(embeddingLength+uint64(vocabLength)), true
-	case "phi2":
-		return max(
-			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
-			4*int64(batch)*int64(1+4*embeddingLength+uint64(context)+attnQKVWeight1+uint64(context)*headCount),
-		), true
-	case "qwen2":
-		return max(
-			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
-			4*int64(batch)*int64(1+2*embeddingLength+uint64(context)+uint64(context)*headCount),
-		), true
 	case "llama":
-		if ffnGate0Weight1 > 0 {
-			// moe
-			return 4 * int64(batch) * int64(2+3*embeddingLength+uint64(context)+uint64(context)*headCount+2*headCountKV+ffnGate0Weight1), true
-		}
+		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))
 
-		return 4 * int64(batch) * int64(1+4*embeddingLength+uint64(context)+uint64(context)*headCount), true
+		partialOffload = 4 * batch * embedding
+		partialOffload += max(
+			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
+			4*batch*(embedding+vocab)+embedding*vocab*105/128,
+		)
+	case "gemma":
+		fullOffload = 4 * batch * (embedding + vocab)
+		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
+	case "command-r":
+		fullOffload = max(
+			4*batch*(embedding+vocab),
+			4*batch*(2+4*embedding+context*(1+heads)),
+		)
+
+		partialOffload = max(
+			4*batch*(embedding+vocab)+embedding*vocab*105/128,
+			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
+		)
+	case "qwen2":
+		fullOffload = max(
+			4*batch*(embedding+vocab),
+			4*batch*(1+2*embedding+context+context*heads),
+		)
+
+		partialOffload = max(
+			4*batch*(embedding+vocab)+embedding*vocab*105/128,
+			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
+		)
+	case "phi2":
+		fullOffload = max(
+			4*batch*(embedding+vocab),
+			4*batch*(1+4*embedding+context+context*heads),
+		)
+
+		partialOffload = 4*batch*(2*embedding+vocab) + embedding*vocab*105/128
 	}
 
-	return 0, false
+	return
 }
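As a rough sanity check on the scale of these estimates, the standalone sketch below plugs the new llama-case formulas into a throwaway program. All hyperparameter values are assumptions in the shape of a 7B llama model, not values taken from this commit; the built-in max requires Go 1.21+.

package main

import "fmt"

func main() {
	// Hypothetical 7B-class llama hyperparameters (illustrative only;
	// in ollama the real values come from the model's GGUF KV metadata).
	var (
		context   uint64 = 2048  // n_ctx
		batch     uint64 = 512   // n_batch
		embedding uint64 = 4096  // hidden size
		heads     uint64 = 32    // attention heads
		headsKV   uint64 = 32    // KV heads (no GQA)
		vocab     uint64 = 32000 // len(tokenizer.ggml.tokens)
	)

	// Mirrors the "llama" case added in this commit.
	fullOffload := 4 * batch * (1 + 4*embedding + context*(1+heads))

	partialOffload := 4 * batch * embedding
	partialOffload += max(
		4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
		4*batch*(embedding+vocab)+embedding*vocab*105/128,
	)

	fmt.Printf("fullOffload    ≈ %.2f GiB\n", float64(fullOffload)/(1<<30))
	fmt.Printf("partialOffload ≈ %.2f GiB\n", float64(partialOffload)/(1<<30))
}

With these inputs the estimates land around 0.16 GiB for fullOffload and 0.19 GiB for partialOffload of graph scratch space, which would come on top of the weights and the KV cache themselves.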