convert gemma2

This commit is contained in:
Michael Yang
2024-06-28 13:27:05 -07:00
parent beb49eef65
commit 3546bbd08c
13 changed files with 132 additions and 46 deletions

View File

@@ -74,8 +74,7 @@ func (p *phi3) Tensors(ts []Tensor) []llm.Tensor {
out := make([]llm.Tensor, 0, len(ts)+2)
for _, t := range ts {
name := p.tensorName(t.Name())
if strings.HasPrefix(name, "blk.0.") {
if strings.HasPrefix(t.Name(), "blk.0.") {
addRopeFactors.Do(func() {
out = append(out, llm.Tensor{
Name: "rope_factors_long.weight",
@@ -92,7 +91,7 @@ func (p *phi3) Tensors(ts []Tensor) []llm.Tensor {
}
out = append(out, llm.Tensor{
Name: name,
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
@@ -102,8 +101,8 @@ func (p *phi3) Tensors(ts []Tensor) []llm.Tensor {
return out
}
func (p *phi3) tensorName(n string) string {
return strings.NewReplacer(
func (p *phi3) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
@@ -114,7 +113,7 @@ func (p *phi3) tensorName(n string) string {
"mlp.down_proj", "ffn_down",
"mlp.gate_up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm",
).Replace(n)
}
}
// ropeFactor is a named slice of float32 values.
// NOTE(review): presumably holds the RoPE scaling factors emitted as the
// "rope_factors_*.weight" tensors by (*phi3).Tensors; its methods and uses
// are outside this view, so confirm before relying on that.
type ropeFactor []float32