add fixes for llama

Patrick Devine
2024-05-08 16:07:46 -07:00
committed by Michael Yang
parent c8cf0d94ed
commit d355d2020f
5 changed files with 55 additions and 24 deletions

@@ -23,12 +23,24 @@ type LlamaModel struct {
 }
 func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
-	slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
-	data := r.storage.(*pytorch.HalfStorage).Data
-	tData := make([]uint16, len(data))
-	for cnt, v := range data {
-		tData[cnt] = uint16(float16.Fromfloat32(v))
+	var tData []uint16
+	switch r.storage.(type) {
+	case *pytorch.HalfStorage:
+		data := r.storage.(*pytorch.HalfStorage).Data
+		tData = make([]uint16, len(data))
+		for cnt, v := range data {
+			tData[cnt] = uint16(float16.Fromfloat32(v))
+		}
+	case *pytorch.BFloat16Storage:
+		data := r.storage.(*pytorch.BFloat16Storage).Data
+		tData = make([]uint16, len(data))
+		for cnt, v := range data {
+			tData[cnt] = uint16(float16.Fromfloat32(v))
+		}
+	default:
+		return fmt.Errorf("unknown storage type for torch")
 	}
 	var err error
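The handler now branches on the concrete pytorch storage type and converts both float16 and bfloat16 data to raw half-precision bit patterns before repacking. A minimal standalone sketch of that conversion step; the x448/float16 import path is an assumption inferred from the float16.Fromfloat32 call, and toHalfBits is a hypothetical helper, not part of the patch:

package main

import (
	"fmt"

	"github.com/x448/float16"
)

// toHalfBits converts float32 tensor values to their IEEE 754 half-precision
// bit patterns, mirroring the per-storage-type loop in llamaTorchLayerHandler.
func toHalfBits(data []float32) []uint16 {
	out := make([]uint16, len(data))
	for i, v := range data {
		out[i] = uint16(float16.Fromfloat32(v)) // keep the raw fp16 bits, not a rounded float
	}
	return out
}

func main() {
	fmt.Println(toHalfBits([]float32{1.0, -0.5, 3.14159}))
}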
@@ -44,8 +56,6 @@ func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
 		return fmt.Errorf("unknown layer type")
 	}
-	slog.Debug(fmt.Sprintf("heads = %d", heads))
 	tData, err = llamaRepack(tData, int(heads), r.t.Shape)
 	if err != nil {
 		return err
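llamaRepack itself is not shown in this diff; it receives the head count and tensor shape so the attention q/k weights can be permuted into the layout GGUF expects. A hypothetical sketch of that kind of head-wise permutation over a flat row-major weight slice; the function name and layout here are assumptions for illustration, not the converter's actual implementation:

package main

import "fmt"

// repackHeads permutes a row-major [rows, cols] weight matrix by viewing it as
// [heads, 2, rows/heads/2, cols] and swapping the middle two axes, the usual
// head-wise reordering applied when converting HF-style llama attention weights.
func repackHeads(data []uint16, heads, rows, cols int) []uint16 {
	d := rows / heads / 2 // half of the per-head dimension
	out := make([]uint16, 0, len(data))
	for h := 0; h < heads; h++ {
		for k := 0; k < d; k++ { // axes swapped: k now varies before j
			for j := 0; j < 2; j++ {
				row := (h*2+j)*d + k // source row in the original layout
				out = append(out, data[row*cols:(row+1)*cols]...)
			}
		}
	}
	return out
}

func main() {
	// one head, four rows, one column: rows reorder from 0,1,2,3 to 0,2,1,3
	fmt.Println(repackHeads([]uint16{0, 1, 2, 3}, 1, 4, 1))
}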
@@ -106,7 +116,6 @@ func (m *LlamaModel) GetTensors() error {
 	for _, l := range t {
 		matches := re.FindAllStringSubmatch(l.Name, -1)
 		if len(matches) > 0 {
-			slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
 			switch m.Format.(type) {
 			case *TorchFormat:
 				wt := l.WriterTo.(torchWriterTo)
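In GetTensors, only tensors whose names match the attention q/k pattern get the repack handler attached; everything else is written through unchanged. A rough sketch of that selection step, with an assumed name pattern and stand-in tensor/handler types:

package main

import (
	"fmt"
	"regexp"
)

type namedTensor struct {
	Name    string
	Handler func() error // stand-in for the writer's repack handler
}

func main() {
	// assumed pattern: only attention q/k weight tensors need repacking
	re := regexp.MustCompile(`attn_(q|k)\.weight$`)
	tensors := []namedTensor{
		{Name: "blk.0.attn_q.weight"},
		{Name: "blk.0.ffn_up.weight"},
	}
	for i := range tensors {
		if re.MatchString(tensors[i].Name) {
			tensors[i].Handler = func() error { return nil } // would be llamaTorchLayerHandler
		}
		fmt.Println(tensors[i].Name, "handler set:", tensors[i].Handler != nil)
	}
}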
@@ -182,10 +191,8 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
//"general.file_type": uint32(1),
"general.file_type": uint32(2),
//"tokenizer.ggml.model": "llama",
"tokenizer.ggml.model": "gpt2",
"general.file_type": uint32(2),
"tokenizer.ggml.model": "gpt2",
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.token_type": m.Vocab.Types,
@@ -193,8 +200,6 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.unknown_token_id": uint32(0),
//"tokenizer.ggml.add_bos_token": true,
//"tokenizer.ggml.add_eos_token": false,
}
if len(m.Vocab.Merges) > 0 {
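The metadata hunks drop the commented-out alternatives and settle on tokenizer.ggml.model = "gpt2", a BPE vocabulary rather than a SentencePiece ("llama") one, which is why merge rules are written only when the vocabulary carries them. A hedged sketch of assembling that key/value block; the buildTokenizerKV helper, its arguments, and the tokenizer.ggml.merges key are assumptions beyond what the diff shows:

package main

import "fmt"

// buildTokenizerKV mirrors the tokenizer-related GGUF metadata written above.
func buildTokenizerKV(bosID, eosID uint32, merges []string) map[string]any {
	kv := map[string]any{
		"general.file_type":               uint32(2),
		"tokenizer.ggml.model":            "gpt2", // BPE tokenizer, not SentencePiece ("llama")
		"tokenizer.ggml.bos_token_id":     bosID,
		"tokenizer.ggml.eos_token_id":     eosID,
		"tokenizer.ggml.unknown_token_id": uint32(0),
	}
	if len(merges) > 0 {
		// assumed: BPE merge rules go under the standard GGUF merges key
		kv["tokenizer.ggml.merges"] = merges
	}
	return kv
}

func main() {
	fmt.Println(buildTokenizerKV(128000, 128001, []string{"Ġ t", "Ġ a"}))
}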