fix conversion for f16 or f32 inputs

Michael Yang
2024-05-17 12:11:49 -07:00
parent bbbd9f20f3
commit 34d5ef29b3
7 changed files with 152 additions and 294 deletions


@@ -24,8 +24,8 @@ type torchWriterTo struct {
 	params *Params
 	bo     ByteOrder
 
-	storage pytorch.StorageInterface
-	handler func(w io.Writer, r torchWriterTo) error
+	storage  pytorch.StorageInterface
+	repacker func(string, []float32, []uint64) ([]float32, error)
 }
 
 type TorchFormat struct{}
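
A note on the new hook: the old handler took over the entire write, while the narrower repacker only transforms the decoded float32 data before it is written. A minimal sketch of a function satisfying the new signature, assuming a row-major 2D tensor and fmt in scope; the transpose2D name and the transpose itself are illustrative, not part of this commit:

func transpose2D(name string, f32s []float32, shape []uint64) ([]float32, error) {
	if len(shape) != 2 {
		return nil, fmt.Errorf("%s: expected 2 dimensions, got %d", name, len(shape))
	}

	rows, cols := shape[0], shape[1]
	out := make([]float32, len(f32s))
	for i := uint64(0); i < rows; i++ {
		for j := uint64(0); j < cols; j++ {
			// element (i, j) of the row-major input lands at (j, i) in the output
			out[j*rows+i] = f32s[i*cols+j]
		}
	}

	return out, nil
}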
@@ -230,59 +230,38 @@ func (m *TorchFormat) GetLayerName(n string) (string, error) {
 }
 
 func (r torchWriterTo) WriteTo(w io.Writer) (n int64, err error) {
-	// use the handler if one is present
-	if r.handler != nil {
-		return 0, r.handler(w, r)
-	}
-
-	switch storage := r.storage.(type) {
+	var f32s []float32
+	switch s := r.storage.(type) {
 	case *pytorch.FloatStorage:
-		slog.Warn(fmt.Sprintf("unexpected storage found for layer '%s'; skipping", r.t.Name))
-		return 0, nil
+		f32s = s.Data
 	case *pytorch.HalfStorage:
-		switch r.t.Kind {
-		case 0:
-			data := r.storage.(*pytorch.HalfStorage).Data
-			slog.Debug(fmt.Sprintf("%35s F32 (%d)", r.t.Name, len(data)))
-			if err := binary.Write(w, r.bo, data); err != nil {
-				return 0, err
-			}
-		case 1:
-			data := r.storage.(*pytorch.HalfStorage).Data
-			tData := make([]uint16, len(data))
-			for cnt, v := range data {
-				tData[cnt] = uint16(float16.Fromfloat32(v))
-			}
-			slog.Debug(fmt.Sprintf("%35s F16 (%d)", r.t.Name, len(tData)))
-			if err := binary.Write(w, r.bo, tData); err != nil {
-				return 0, err
-			}
-		}
+		f32s = s.Data
 	case *pytorch.BFloat16Storage:
-		data := r.storage.(*pytorch.BFloat16Storage).Data
-		switch r.t.Kind {
-		case 0:
-			if err = binary.Write(w, r.bo, data); err != nil {
-				return 0, err
-			}
-		case 1:
-			tData := make([]uint16, len(data))
-			for cnt, v := range data {
-				tData[cnt] = uint16(float16.Fromfloat32(v))
-			}
-			if err = binary.Write(w, r.bo, tData); err != nil {
-				return 0, err
-			}
-		default:
-			return 0, fmt.Errorf("unknown storage kind: %d", r.t.Kind)
-		}
+		f32s = s.Data
 	default:
-		return 0, fmt.Errorf("unknown storage type: %T", storage)
+		return 0, fmt.Errorf("unknown data type: %T", s)
 	}
 
-	return 0, nil
+	if r.repacker != nil {
+		f32s, err = r.repacker(r.t.Name, f32s, r.t.Shape)
+		if err != nil {
+			return 0, err
+		}
+	}
+
+	switch r.t.Kind {
+	case 0:
+		return 0, binary.Write(w, r.bo, f32s)
+	case 1:
+		f16s := make([]uint16, len(f32s))
+		for i := range f32s {
+			f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
+		}
+		return 0, binary.Write(w, r.bo, f16s)
+	default:
+		return 0, fmt.Errorf("unknown storage type: %d", r.t.Kind)
+	}
 }
 
 func (m *TorchFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
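
With this change, every supported storage type is first normalized to []float32 and converted exactly once at the end, so f16 and f32 inputs share one code path instead of each storage type hand-rolling its own writes. A standalone sketch of the Kind == 1 half-precision branch, assuming the github.com/x448/float16 package the diff already uses:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"

	"github.com/x448/float16"
)

func main() {
	f32s := []float32{1.0, -0.5, 3.14159}

	// convert each float32 to its IEEE 754 half-precision bit pattern,
	// as in the Kind == 1 branch of WriteTo above
	f16s := make([]uint16, len(f32s))
	for i := range f32s {
		f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
	}

	var buf bytes.Buffer
	if err := binary.Write(&buf, binary.LittleEndian, f16s); err != nil {
		panic(err)
	}

	fmt.Printf("% x\n", buf.Bytes()) // two bytes per element, little-endian
}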