refactor convert

This commit is contained in:
Michael Yang
2024-05-31 20:00:49 -07:00
parent 6b252918fb
commit 5e9db9fb0b
24 changed files with 1514 additions and 1494 deletions

46
convert/reader_torch.go Normal file
View File

@@ -0,0 +1,46 @@
package convert
import (
"io"
"github.com/nlpodyssey/gopickle/pytorch"
"github.com/nlpodyssey/gopickle/types"
)
func parseTorch(ps ...string) ([]Tensor, error) {
var ts []Tensor
for _, p := range ps {
pt, err := pytorch.Load(p)
if err != nil {
return nil, err
}
for _, k := range pt.(*types.Dict).Keys() {
t := pt.(*types.Dict).MustGet(k)
var shape []uint64
for dim := range t.(*pytorch.Tensor).Size {
shape = append(shape, uint64(dim))
}
ts = append(ts, torch{
storage: t.(*pytorch.Tensor).Source,
tensorBase: &tensorBase{
name: k.(string),
shape: shape,
},
})
}
}
return ts, nil
}
type torch struct {
storage pytorch.StorageInterface
*tensorBase
}
func (pt torch) WriteTo(w io.Writer) (int64, error) {
return 0, nil
}