This commit is contained in:
Michael Yang
2024-05-15 14:55:57 -07:00
parent 547132e820
commit bbbd9f20f3
9 changed files with 35 additions and 51 deletions

View File

@@ -34,18 +34,13 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
slog.Debug("getting torch tensors")
var files []string
var err error
files, err = filepath.Glob(filepath.Join(dirpath, "consolidated.*.pth"))
if err != nil {
files, err = filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
if err != nil {
slog.Error("didn't find any torch files")
return nil, err
}
if pt, _ := filepath.Glob(filepath.Join(dirpath, "consolidated*.pth")); len(pt) > 0 {
files = append(files, pt...)
} else if pt, _ := filepath.Glob(filepath.Join(dirpath, "pytorch_model*.pth")); len(pt) > 0 {
files = append(files, pt...)
}
var offset uint64
var tensors []llm.Tensor
for _, fn := range files {
m, err := pytorch.Load(fn)