fix conversion for f16 or f32 inputs

This commit is contained in:
Michael Yang
2024-05-17 12:11:49 -07:00
parent bbbd9f20f3
commit 34d5ef29b3
7 changed files with 152 additions and 294 deletions

View File

@@ -27,9 +27,10 @@ type safetensorWriterTo struct {
bo ByteOrder
filename string
dtype string
start, end, padding uint64
handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
repacker func(string, []float32, []uint64) ([]float32, error)
}
type tensorMetaData struct {
@@ -150,6 +151,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
params: params,
bo: params.ByteOrder,
filename: fn,
dtype: data.Type,
start: uint64(data.Offsets[0]),
end: uint64(data.Offsets[1]),
padding: 8 + jsonSize,
@@ -235,51 +237,54 @@ func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
return 0, err
}
// use the handler if one is present
if r.handler != nil {
return 0, r.handler(w, r, f)
}
remaining := r.end - r.start
bufSize := uint64(10240)
var finished bool
for {
data := make([]byte, min(bufSize, remaining))
b, err := io.ReadFull(f, data)
remaining -= uint64(b)
if err == io.EOF || remaining <= 0 {
finished = true
} else if err != nil {
var f32s []float32
switch r.dtype {
case "F32":
f32s = make([]float32, (r.end-r.start)/4)
if err = binary.Read(f, r.bo, f32s); err != nil {
return 0, err
}
case "F16":
bts := make([]uint16, (r.end-r.start)/2)
if err = binary.Read(f, r.bo, bts); err != nil {
return 0, err
}
// convert bfloat16 -> ieee float32
tDataF32 := bfloat16.DecodeFloat32(data)
switch r.t.Kind {
case 0:
if err := binary.Write(w, r.bo, tDataF32); err != nil {
return 0, err
}
case 1:
// convert float32 -> float16
tempBuf := make([]uint16, len(data)/2)
for cnt, v := range tDataF32 {
tDataF16 := float16.Fromfloat32(v)
tempBuf[cnt] = uint16(tDataF16)
}
if err := binary.Write(w, r.bo, tempBuf); err != nil {
return 0, err
}
for _, b := range bts {
f32s = append(f32s, float16.Frombits(b).Float32())
}
if finished {
break
case "BF16":
bts := make([]byte, r.end-r.start)
if err = binary.Read(f, r.bo, bts); err != nil {
return 0, err
}
f32s = bfloat16.DecodeFloat32(bts)
default:
return 0, fmt.Errorf("unknown data type: %s", r.dtype)
}
if r.repacker != nil {
f32s, err = r.repacker(r.t.Name, f32s, r.t.Shape)
if err != nil {
return 0, err
}
}
return 0, nil
switch r.t.Kind {
case 0:
return 0, binary.Write(w, r.bo, f32s)
case 1:
f16s := make([]uint16, len(f32s))
for i := range f32s {
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
}
return 0, binary.Write(w, r.bo, f16s)
default:
return 0, fmt.Errorf("unknown storage type: %d", r.t.Kind)
}
}
func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {