refactor readseeker

This commit is contained in:
Michael Yang
2024-03-09 12:28:36 -08:00
parent f878e91070
commit 0085297928
3 changed files with 72 additions and 70 deletions

View File

@@ -42,18 +42,18 @@ func (c *ContainerGGUF) Name() string {
return "gguf"
}
func (c *ContainerGGUF) Decode(rso *readSeekOffset) (model, error) {
binary.Read(rso, c.ByteOrder, &c.Version)
func (c *ContainerGGUF) Decode(rs io.ReadSeeker) (model, error) {
binary.Read(rs, c.ByteOrder, &c.Version)
switch c.Version {
case 1:
binary.Read(rso, c.ByteOrder, &c.V1)
binary.Read(rs, c.ByteOrder, &c.V1)
default:
binary.Read(rso, c.ByteOrder, &c.V2)
binary.Read(rs, c.ByteOrder, &c.V2)
}
model := NewGGUFModel(c)
if err := model.Decode(rso); err != nil {
if err := model.Decode(rs); err != nil {
return nil, err
}
@@ -633,49 +633,49 @@ func (llm *GGUFModel) writeString(f *os.File, s string) error {
return nil
}
func (llm *GGUFModel) Decode(rso *readSeekOffset) error {
func (llm *GGUFModel) Decode(rs io.ReadSeeker) error {
// decode key-values
for i := 0; uint64(i) < llm.NumKV(); i++ {
k, err := llm.readString(rso)
k, err := llm.readString(rs)
if err != nil {
return err
}
vtype := llm.readU32(rso)
vtype := llm.readU32(rs)
var v any
switch vtype {
case GGUFTypeUint8:
v = llm.readU8(rso)
v = llm.readU8(rs)
case GGUFTypeInt8:
v = llm.readI8(rso)
v = llm.readI8(rs)
case GGUFTypeUint16:
v = llm.readU16(rso)
v = llm.readU16(rs)
case GGUFTypeInt16:
v = llm.readI16(rso)
v = llm.readI16(rs)
case GGUFTypeUint32:
v = llm.readU32(rso)
v = llm.readU32(rs)
case GGUFTypeInt32:
v = llm.readI32(rso)
v = llm.readI32(rs)
case GGUFTypeUint64:
v = llm.readU64(rso)
v = llm.readU64(rs)
case GGUFTypeInt64:
v = llm.readI64(rso)
v = llm.readI64(rs)
case GGUFTypeFloat32:
v = llm.readF32(rso)
v = llm.readF32(rs)
case GGUFTypeFloat64:
v = llm.readF64(rso)
v = llm.readF64(rs)
case GGUFTypeBool:
v = llm.readBool(rso)
v = llm.readBool(rs)
case GGUFTypeString:
s, err := llm.readString(rso)
s, err := llm.readString(rs)
if err != nil {
return err
}
v = s
case GGUFTypeArray:
a, err := llm.readArray(rso)
a, err := llm.readArray(rs)
if err != nil {
return err
}
@@ -690,23 +690,23 @@ func (llm *GGUFModel) Decode(rso *readSeekOffset) error {
// decode tensors
for i := 0; uint64(i) < llm.NumTensor(); i++ {
name, err := llm.readString(rso)
name, err := llm.readString(rs)
if err != nil {
return err
}
// dims is the number of dimensions in the tensor
dims := llm.readU32(rso)
dims := llm.readU32(rs)
shape := [4]uint64{1, 1, 1, 1}
for i := 0; uint32(i) < dims; i++ {
shape[i] = llm.readU64(rso)
shape[i] = llm.readU64(rs)
}
tensor := Tensor{
Name: name,
Kind: llm.readU32(rso),
Offset: llm.readU64(rso),
Kind: llm.readU32(rs),
Offset: llm.readU64(rs),
Shape: shape[:],
}
@@ -719,10 +719,20 @@ func (llm *GGUFModel) Decode(rso *readSeekOffset) error {
alignment = 32
}
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
if _, err := rs.Seek(int64(alignment)-offset%int64(alignment), io.SeekCurrent); err != nil {
return err
}
for _, tensor := range llm.Tensors {
padded := (int64(tensor.Size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
rso.Seek(padded, io.SeekCurrent)
if _, err := rs.Seek(padded, io.SeekCurrent); err != nil {
return err
}
}
return nil