fix falcon decode

get model and file type from bin file
This commit is contained in:
Michael Yang
2023-09-12 10:01:20 -07:00
parent f221637053
commit 7dee25a07f
5 changed files with 123 additions and 158 deletions

View File

@@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
"log"
"path"
"sync"
)
@@ -87,38 +86,37 @@ func (llm *ggufModel) NumKV() uint64 {
return llm.V2.NumKV
}
func (llm *ggufModel) ModelFamily() ModelFamily {
func (llm *ggufModel) ModelFamily() string {
t, ok := llm.kv["general.architecture"].(string)
if ok {
return ModelFamily(t)
return t
}
log.Printf("unknown model family: %T", t)
return ModelFamilyUnknown
return "unknown"
}
func (llm *ggufModel) ModelType() ModelType {
func (llm *ggufModel) ModelType() string {
switch llm.ModelFamily() {
case ModelFamilyLlama:
blocks, ok := llm.kv["llama.block_count"].(uint32)
if ok {
return ModelType(blocks)
case "llama":
if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
return llamaModelType(blocks)
}
case "falcon":
if blocks, ok := llm.kv["falcon.block_count"].(uint32); ok {
return falconModelType(blocks)
}
}
return ModelType7B
return "Unknown"
}
func (llm *ggufModel) FileType() FileType {
switch llm.ModelFamily() {
case ModelFamilyLlama:
t, ok := llm.kv["general.file_type"].(uint32)
if ok {
return llamaFileType(t)
}
func (llm *ggufModel) FileType() string {
t, ok := llm.kv["general.file_type"].(uint32)
if ok {
return fileType(t)
}
return llamaFileTypeF16
return "Unknown"
}
func (llm *ggufModel) Decode(r io.Reader) error {