refactor readseeker

This commit is contained in:
Michael Yang
2024-03-09 12:28:36 -08:00
parent f878e91070
commit 0085297928
3 changed files with 72 additions and 70 deletions

View File

@@ -103,7 +103,7 @@ type model interface {
type container interface {
Name() string
Decode(*readSeekOffset) (model, error)
Decode(io.ReadSeeker) (model, error)
}
const (
@@ -122,11 +122,9 @@ const (
var ErrUnsupportedFormat = errors.New("unsupported model format")
func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
ro := readSeekOffset{ReadSeeker: r}
func DecodeGGML(rs io.ReadSeeker) (*GGML, error) {
var magic uint32
if err := binary.Read(&ro, binary.LittleEndian, &magic); err != nil {
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
return nil, err
}
@@ -144,38 +142,22 @@ func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
return nil, errors.New("invalid file magic")
}
model, err := c.Decode(&ro)
model, err := c.Decode(rs)
if errors.Is(err, io.EOF) {
// noop
} else if err != nil {
return nil, err
}
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
// final model type
return &GGML{
container: c,
model: model,
Size: ro.offset,
Size: offset,
}, nil
}
type readSeekOffset struct {
io.ReadSeeker
offset int64
}
func (rso *readSeekOffset) Seek(offset int64, whence int) (int64, error) {
offset, err := rso.ReadSeeker.Seek(offset, whence)
if err != nil {
return 0, err
}
rso.offset = offset
return offset, nil
}
func (rso *readSeekOffset) Read(p []byte) (int, error) {
n, err := rso.ReadSeeker.Read(p)
rso.offset += int64(n)
return n, err
}