convert gemma2

This commit is contained in:
Michael Yang
2024-06-28 13:27:05 -07:00
parent beb49eef65
commit 3546bbd08c
13 changed files with 132 additions and 46 deletions

View File

@@ -96,14 +96,13 @@ func (p *llama) KV(t *Tokenizer) llm.KV {
func (p *llama) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor
for _, t := range ts {
name := p.tensorName(t.Name())
if strings.HasSuffix(name, "attn_q.weight") ||
strings.HasSuffix(name, "attn_k.weight") {
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(p.repack)
}
out = append(out, llm.Tensor{
Name: name,
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
@@ -113,8 +112,8 @@ func (p *llama) Tensors(ts []Tensor) []llm.Tensor {
return out
}
func (p *llama) tensorName(n string) string {
return strings.NewReplacer(
func (p *llama) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
@@ -128,9 +127,7 @@ func (p *llama) tensorName(n string) string {
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm",
// mixtral
"block_sparse_moe.gate", "ffn_gate_inp",
).Replace(n)
}
}
func (p *llama) repack(name string, data []float32, shape []uint64) ([]float32, error) {
@@ -140,9 +137,9 @@ func (p *llama) repack(name string, data []float32, shape []uint64) ([]float32,
}
var heads uint32
if strings.HasSuffix(name, "q_proj.weight") {
if strings.HasSuffix(name, "attn_q.weight") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "k_proj.weight") {
} else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)