mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-17 03:07:00 +00:00
Also, require the -as flag to be set when importing a model. This prevents the confusing error message "invalid name". Also, allow short names to be used when importing a model and auto-complete the name with the default mask.
225 lines
5.2 KiB
Go
225 lines
5.2 KiB
Go
// Package safetensors provides a reader for safetensors directories and files.
|
|
package safetensors
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"iter"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// Tensor describes a single tensor stored in a safetensors file.
//
// Its zero value is not valid. Use [Model.Tensors] to get valid tensors.
//
// It is not safe for use across multiple goroutines.
type Tensor struct {
	// name is the key of the tensor's header entry; dataType and shape are
	// the entry's "dtype" and "shape" values.
	name, dataType string
	shape          []int64

	fsys  fs.FS  // file system that holds fname
	fname string // entry name in fsys

	// offset is the absolute position of the tensor data within fname and
	// size is the length of that data in bytes.
	offset, size int64
}
|
|
|
|
type Model struct {
|
|
fsys fs.FS
|
|
}
|
|
|
|
func Read(fsys fs.FS) (*Model, error) {
|
|
return &Model{fsys: fsys}, nil
|
|
}
|
|
|
|
func (m *Model) Tensors() iter.Seq2[*Tensor, error] {
|
|
return func(yield func(*Tensor, error) bool) {
|
|
entries, err := fs.Glob(m.fsys, "*.safetensors")
|
|
if err != nil {
|
|
yield(nil, err)
|
|
return
|
|
}
|
|
for _, e := range entries {
|
|
tt, err := m.readTensors(e)
|
|
if err != nil {
|
|
yield(nil, err)
|
|
return
|
|
}
|
|
for _, t := range tt {
|
|
if !yield(t, nil) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
|
f, err := m.fsys.Open(fname)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer f.Close()
|
|
|
|
finfo, err := f.Stat()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
headerSize, err := readInt64(f)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
data := make([]byte, headerSize)
|
|
_, err = io.ReadFull(f, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var raws map[string]json.RawMessage
|
|
if err := json.Unmarshal(data, &raws); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself
|
|
|
|
// TODO(bmizerany): do something with metadata? This could be another
|
|
// header read if needed. We also need to figure out if the metadata is
|
|
// present in only one .safetensors file or if each file may have their
|
|
// own and if it needs to follow each tensor. Currently, I (bmizerany)
|
|
// am only seeing them show up with one entry for file type which is
|
|
// always "pt".
|
|
|
|
tt := make([]*Tensor, 0, len(raws))
|
|
for name, raw := range raws {
|
|
if name == "__metadata__" {
|
|
// TODO(bmizerany): do something with metadata?
|
|
continue
|
|
}
|
|
var v struct {
|
|
DataType string `json:"dtype"`
|
|
Shape []int64 `json:"shape"`
|
|
Offsets []int64 `json:"data_offsets"`
|
|
}
|
|
if err := json.Unmarshal(raw, &v); err != nil {
|
|
return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err)
|
|
}
|
|
if len(v.Offsets) != 2 {
|
|
return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets)
|
|
}
|
|
|
|
// TODO(bmizerany): after collecting, validate all offests make
|
|
// tensors contiguous?
|
|
begin := endOfHeader + v.Offsets[0]
|
|
end := endOfHeader + v.Offsets[1]
|
|
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// TODO(bmizerany): just yield.. don't be silly and make a slice :)
|
|
tt = append(tt, &Tensor{
|
|
name: name,
|
|
dataType: v.DataType,
|
|
shape: v.Shape,
|
|
fsys: m.fsys,
|
|
fname: fname,
|
|
offset: begin,
|
|
size: end - begin,
|
|
})
|
|
}
|
|
return tt, nil
|
|
}
|
|
|
|
// checkBeginEnd validates that begin and end describe a byte range inside
// something of the given size. It returns nil when
// 0 <= begin <= end <= size, and otherwise an error describing the first
// condition that does not hold, checked in that order.
func checkBeginEnd(size, begin, end int64) error {
	switch {
	case begin < 0:
		return fmt.Errorf("begin must not be negative: %d", begin)
	case end < 0:
		return fmt.Errorf("end must not be negative: %d", end)
	case end < begin:
		return fmt.Errorf("end must be >= begin: %d < %d", end, begin)
	case end > size:
		return fmt.Errorf("end must be <= size: %d > %d", end, size)
	}
	return nil
}
|
|
|
|
// readInt64 reads exactly 8 bytes from r and decodes them as a
// little-endian integer, returned as an int64. A value with its highest
// bit set therefore comes back negative. If 8 bytes cannot be read it
// returns 0 and the error from io.ReadFull.
func readInt64(r io.Reader) (int64, error) {
	var raw [8]byte
	if _, err := io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}

	// Fold from the most significant (last) byte down to the first.
	var u uint64
	for i := len(raw) - 1; i >= 0; i-- {
		u = u<<8 | uint64(raw[i])
	}
	return int64(u), nil
}
|
|
|
|
// Shape holds the dimensions of a tensor.
type Shape []int64

// String formats s as its dimensions in base 10, separated by commas and
// enclosed in square brackets, for example "[2,3]". An empty or nil Shape
// formats as "[]".
func (s Shape) String() string {
	dims := make([]string, len(s))
	for i := range s {
		dims[i] = strconv.FormatInt(s[i], 10)
	}
	return "[" + strings.Join(dims, ",") + "]"
}
|
|
|
|
func (t *Tensor) Name() string { return t.name }
|
|
func (t *Tensor) DataType() string { return t.dataType }
|
|
func (t *Tensor) Size() int64 { return t.size }
|
|
func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) }
|
|
|
|
func (t *Tensor) Reader() (io.ReadCloser, error) {
|
|
f, err := t.fsys.Open(t.fname)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
r := newSectionReader(f, t.offset, t.size)
|
|
rc := struct {
|
|
io.Reader
|
|
io.Closer
|
|
}{r, f}
|
|
return rc, nil
|
|
}
|
|
|
|
// newSectionReader returns an io.Reader that yields at most n bytes of r
// starting at offset. It is a convenience for getting the behavior of an
// io.SectionReader when r may not be an io.ReaderAt.
//
// Three strategies are tried in order:
//
//   - If r is an io.ReaderAt, an io.SectionReader over r is returned.
//   - If r is an io.ReadSeeker, r is positioned at offset, relative to the
//     start, and a limited reader over r is returned.
//   - Otherwise offset bytes are read from r and discarded, and a limited
//     reader over what follows is returned.
//
// The returned reader is never nil. If r cannot be positioned at offset,
// the returned reader fails every Read with the positioning error.
func newSectionReader(r io.Reader, offset, n int64) io.Reader {
	if ra, ok := r.(io.ReaderAt); ok {
		return io.NewSectionReader(ra, offset, n)
	}

	if rs, ok := r.(io.ReadSeeker); ok {
		// A failed seek leaves r at an unknown position; reading on
		// would silently return the wrong bytes.
		if _, err := rs.Seek(offset, io.SeekStart); err != nil {
			return sectionErrReader{fmt.Errorf("seeking to offset %d: %w", offset, err)}
		}
		return io.LimitReader(rs, n)
	}

	// Slow path: discard up to offset and return a limited reader.
	if _, err := io.CopyN(io.Discard, r, offset); err != nil {
		if err == io.EOF {
			// Running out of input before offset is not a clean end of
			// stream; do not let callers mistake it for one.
			err = io.ErrUnexpectedEOF
		}
		return sectionErrReader{fmt.Errorf("skipping to offset %d: %w", offset, err)}
	}
	return io.LimitReader(r, n)
}

// sectionErrReader is an io.Reader whose Read always fails with err.
type sectionErrReader struct{ err error }

// Read reads nothing and returns the stored error.
func (s sectionErrReader) Read([]byte) (int, error) { return 0, s.err }
|