server/internal: copy bmizerany/ollama-go to internal package (#9294)

This commit copies (without history) the bmizerany/ollama-go repository
with the intention of integrating it into ollama as a replacement for
the pushing and pulling of models, and for the management of the cache
they are pushed to and pulled from.

New homes for these packages will be determined as they are integrated
and we have a better understanding of proper package boundaries.
This commit is contained in:
Blake Mizerany
2025-02-24 22:39:44 -08:00
committed by GitHub
parent 0b7e1676eb
commit 348b3e0983
29 changed files with 4974 additions and 6 deletions

View File

@@ -0,0 +1,220 @@
// Package safetensors provides a reader for safetensors directories and files.
package safetensors
import (
"encoding/json"
"fmt"
"io"
"io/fs"
"iter"
"slices"
"strconv"
"strings"
)
// Tensor represents a single tensor in a safetensors file.
//
// Its zero value is not valid. Use [Model.Tensors] to get valid tensors.
//
// It is not safe for use across multiple goroutines.
type Tensor struct {
	name     string  // tensor name, taken from the header's JSON key
	dataType string  // value of the header entry's "dtype" field
	shape    []int64 // value of the header entry's "shape" field

	fsys   fs.FS  // file system that contains fname
	fname  string // entry name in fsys
	offset int64  // byte offset in fname at which Reader starts reading
	size   int64  // maximum number of bytes Reader yields
}
// Model gives access to the tensors stored in the safetensors files of a
// file system. Use [Read] to create one.
type Model struct {
	fsys fs.FS // searched for "*.safetensors" entries by Tensors
}
// Read returns a Model backed by fsys. Nothing is opened or parsed here;
// files are only touched once the model's tensors are iterated. The
// returned error is currently always nil.
func Read(fsys fs.FS) (*Model, error) {
	m := new(Model)
	m.fsys = fsys
	return m, nil
}
// Tensors returns an iterator over the tensors of every entry matching
// "*.safetensors" in the model's file system.
//
// Each file's header is parsed in full before any of its tensors are
// yielded. If globbing or parsing fails, the iterator yields a single
// (nil, err) pair and stops. Iteration also stops as soon as the consumer
// returns false from yield.
func (m *Model) Tensors() iter.Seq2[*Tensor, error] {
	return func(yield func(*Tensor, error) bool) {
		names, err := fs.Glob(m.fsys, "*.safetensors")
		if err != nil {
			yield(nil, err)
			return
		}
		for i := 0; i < len(names); i++ {
			tensors, err := m.readTensors(names[i])
			if err != nil {
				yield(nil, err)
				return
			}
			for j := range tensors {
				if keepGoing := yield(tensors[j], nil); !keepGoing {
					return
				}
			}
		}
	}
}
// readTensors parses the header of the safetensors file fname and returns a
// Tensor for every header entry whose name starts with "model.layer". Other
// entries (including any metadata entry) are skipped.
//
// A safetensors file is laid out as an 8-byte little-endian header length N,
// followed by N bytes of JSON, followed by the tensor data. Each entry's
// "data_offsets" are relative to the start of the tensor data, not to the
// start of the file, so they are validated against the size of the data
// section and translated to absolute file offsets before being stored.
//
// An error is returned if the file cannot be opened or read, if the header
// length is negative or larger than the file, if the header is not valid
// JSON, or if an entry's offsets fall outside the data section.
//
// The order of the returned tensors is unspecified (it follows map
// iteration order).
func (m *Model) readTensors(fname string) ([]*Tensor, error) {
	f, err := m.fsys.Open(fname)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	finfo, err := f.Stat()
	if err != nil {
		return nil, err
	}

	// The file begins with the header length, encoded in 8 bytes.
	const headerLenSize = 8
	headerSize, err := readInt64(f)
	if err != nil {
		return nil, err
	}
	// Validate before allocating: a negative length would panic in make,
	// and a length larger than the file is either corruption or an attempt
	// to make us allocate an arbitrary amount of memory.
	if headerSize < 0 || headerSize > finfo.Size()-headerLenSize {
		return nil, fmt.Errorf("invalid header size in %q: %d", fname, headerSize)
	}
	dataStart := headerLenSize + headerSize // absolute offset of the tensor data
	dataSize := finfo.Size() - dataStart

	data := make([]byte, headerSize)
	_, err = io.ReadFull(f, data)
	if err != nil {
		return nil, err
	}

	var raws map[string]json.RawMessage
	if err := json.Unmarshal(data, &raws); err != nil {
		return nil, err
	}

	// TODO(bmizerany): do something with metadata? This could be another
	// header read if needed. We also need to figure out if the metadata is
	// present in only one .safetensors file or if each file may have their
	// own and if it needs to follow each tensor. Currently, I (bmizerany)
	// am only seeing them show up with one entry for file type which is
	// always "pt".

	tt := make([]*Tensor, 0, len(raws))
	for name, raw := range raws {
		if !strings.HasPrefix(name, "model.layer") {
			continue
		}
		var v struct {
			DataType string  `json:"dtype"`
			Shape    []int64 `json:"shape"`
			Offsets  []int64 `json:"data_offsets"`
		}
		if err := json.Unmarshal(raw, &v); err != nil {
			return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err)
		}
		if len(v.Offsets) != 2 {
			return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets)
		}

		// TODO(bmizerany): after collecting, validate all offests make
		// tensors contiguous?
		begin, end := v.Offsets[0], v.Offsets[1]
		// Offsets are relative to the data section, so bound them by its
		// size rather than by the size of the whole file.
		if err := checkBeginEnd(dataSize, begin, end); err != nil {
			return nil, fmt.Errorf("invalid offsets for %q: %w", name, err)
		}

		// TODO(bmizerany): just yield.. don't be silly and make a slice :)
		tt = append(tt, &Tensor{
			name:     name,
			dataType: v.DataType,
			shape:    v.Shape,
			fsys:     m.fsys,
			fname:    fname,
			offset:   dataStart + begin, // absolute position in the file
			size:     end - begin,
		})
	}
	return tt, nil
}
// checkBeginEnd verifies that [begin, end) is a valid byte range within a
// region of the given size: both bounds non-negative, begin <= end, and
// end <= size. It returns nil when the range is valid and otherwise an
// error describing the first violated condition.
func checkBeginEnd(size, begin, end int64) error {
	switch {
	case begin < 0:
		return fmt.Errorf("begin must not be negative: %d", begin)
	case end < 0:
		return fmt.Errorf("end must not be negative: %d", end)
	case end < begin:
		return fmt.Errorf("end must be >= begin: %d < %d", end, begin)
	case end > size:
		return fmt.Errorf("end must be <= size: %d > %d", end, size)
	}
	return nil
}
// readInt64 reads exactly 8 bytes from r and decodes them as a
// little-endian 64-bit integer. The unsigned bit pattern is reinterpreted
// as int64, so values with the top bit set come back negative. Any error
// from io.ReadFull is returned unchanged.
func readInt64(r io.Reader) (int64, error) {
	var raw [8]byte
	if _, err := io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	// Fold from the most significant byte (the last one) down to the
	// least significant byte (the first one).
	var u uint64
	for i := len(raw) - 1; i >= 0; i-- {
		u = u<<8 | uint64(raw[i])
	}
	return int64(u), nil
}
// Shape is the list of dimension sizes of a tensor.
type Shape []int64

// String formats s as a bracketed, comma-separated list of decimal numbers
// without spaces, for example "[2,3,4]". A nil or empty Shape formats as
// "[]".
func (s Shape) String() string {
	dims := make([]string, len(s))
	for i, d := range s {
		dims[i] = strconv.FormatInt(d, 10)
	}
	return "[" + strings.Join(dims, ",") + "]"
}
// Name returns the tensor's name.
func (t *Tensor) Name() string { return t.name }

// DataType returns the tensor's data type string.
func (t *Tensor) DataType() string { return t.dataType }

// Size returns the size of the tensor's data in bytes.
func (t *Tensor) Size() int64 { return t.size }

// Shape returns a copy of the tensor's shape, so the caller may modify the
// result without affecting the tensor.
func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) }
// Reader opens the tensor's file and returns a reader positioned at the
// tensor's offset that yields at most Size bytes. The caller must Close the
// returned reader to release the underlying file.
//
// An error is returned if the file cannot be opened or if no reader could
// be positioned at the tensor's offset.
func (t *Tensor) Reader() (io.ReadCloser, error) {
	f, err := t.fsys.Open(t.fname)
	if err != nil {
		return nil, err
	}
	r := newSectionReader(f, t.offset, t.size)
	if r == nil {
		// A nil reader means the offset could not be reached. Close the
		// file so it does not leak, and report an error instead of
		// handing back a reader that would panic on first use.
		f.Close()
		return nil, fmt.Errorf("safetensors: cannot read %q at offset %d", t.fname, t.offset)
	}
	rc := struct {
		io.Reader
		io.Closer
	}{r, f}
	return rc, nil
}
// newSectionReader returns an io.Reader that yields at most n bytes of r
// starting at offset. It is a convenience for getting io.SectionReader-like
// behavior when r may not be an io.ReaderAt.
//
// Three strategies are tried, in order:
//
//   - If r is an io.ReaderAt, an io.SectionReader over it is returned.
//   - If r is an io.ReadSeeker, r is seeked to offset (from the start) and
//     wrapped in an io.LimitReader.
//   - Otherwise (the slow path) offset bytes are read from r and discarded,
//     and the remainder is wrapped in an io.LimitReader.
//
// If positioning fails (the seek fails, or r ends before offset), the
// returned reader is non-nil and every Read on it returns the positioning
// error, so the failure surfaces on first use instead of being lost.
func newSectionReader(r io.Reader, offset, n int64) io.Reader {
	if ra, ok := r.(io.ReaderAt); ok {
		return io.NewSectionReader(ra, offset, n)
	}
	if rs, ok := r.(io.ReadSeeker); ok {
		if _, err := rs.Seek(offset, io.SeekStart); err != nil {
			return &failedReader{err: fmt.Errorf("seeking to offset %d: %w", offset, err)}
		}
		return io.LimitReader(rs, n)
	}
	// Discard to offset and return a limited reader.
	if _, err := io.CopyN(io.Discard, r, offset); err != nil {
		return &failedReader{err: fmt.Errorf("skipping to offset %d: %w", offset, err)}
	}
	return io.LimitReader(r, n)
}

// failedReader is an io.Reader whose Read always fails with err.
type failedReader struct{ err error }

// Read reports the stored error without reading anything.
func (f *failedReader) Read([]byte) (int, error) { return 0, f.err }

View File

@@ -0,0 +1,366 @@
package main
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"mime"
"net/http"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
"golang.org/x/sync/errgroup"
)
// stdout is where the commands write their progress and status output. It
// is a package-level variable rather than a direct use of os.Stdout,
// presumably so that output can be redirected (e.g. in tests).
var stdout io.Writer = os.Stdout
// usage is the help text printed by flag.Usage.
const usage = `Opp is a tool for pushing and pulling Ollama models.
Usage:
opp [flags] <push|pull|import>
Commands:
push Upload a model to the Ollama server.
pull Download a model from the Ollama server.
import Import a model from a local safetensor directory.
Examples:
# Pull a model from the Ollama server.
opp pull library/llama3.2:latest
# Push a model to the Ollama server.
opp push username/my_model:8b
# Import a model from a local safetensor directory.
opp import /path/to/safetensor
Environment Variables:
OLLAMA_MODELS
The directory where models are pushed and pulled from
(default ~/.ollama/models).
`
// main parses the command line, sets up the default cache and the registry
// client from the environment, and dispatches to the sub-command named by
// the first positional argument.
//
// Exit status: 1 if setup or the sub-command fails, 2 if the command is
// missing (usage is printed) or unknown.
func main() {
	flag.Usage = func() {
		fmt.Fprint(os.Stderr, usage)
	}
	flag.Parse()

	c, err := ollama.DefaultCache()
	if err != nil {
		log.Fatal(err)
	}
	rc, err := ollama.RegistryFromEnv()
	if err != nil {
		log.Fatal(err)
	}

	ctx := context.Background()

	dispatch := func() error {
		cmd := flag.Arg(0)
		switch cmd {
		case "pull":
			return cmdPull(ctx, rc, c)
		case "push":
			return cmdPush(ctx, rc, c)
		case "import":
			return cmdImport(ctx, c)
		}

		// Not a known command: explain and exit with a usage error.
		if cmd != "" {
			fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
		} else {
			flag.Usage()
		}
		os.Exit(2)
		return errors.New("unreachable")
	}

	err = dispatch()
	if err != nil {
		fmt.Fprintf(os.Stderr, "opp: %v\n", err)
		os.Exit(1)
	}
}
// cmdPull implements "opp pull <model>". It pulls the model named by the
// second command-line argument from the registry rc into the cache c,
// printing per-layer progress to stdout once per second and once more when
// the pull finishes. It returns the error from rc.Pull.
//
// If no model argument is given, it prints the usage text and exits the
// process with status 1.
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
	model := flag.Arg(1)
	if model == "" {
		flag.Usage()
		os.Exit(1)
	}

	// Give the registry client its own transport (a clone of the default)
	// so it can be configured without touching http.DefaultTransport.
	tr := http.DefaultTransport.(*http.Transport).Clone()
	// TODO(bmizerany): configure transport?
	rc.HTTPClient = &http.Client{Transport: tr}

	// mu guards p, which is written by the trace callback below and read
	// (and pruned) by printProgress.
	var mu sync.Mutex
	p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]

	// pb is scratch space reused by printProgress across calls.
	var pb bytes.Buffer
	printProgress := func() {
		pb.Reset()
		mu.Lock()
		for d, s := range p {
			// Write progress to a buffer first to avoid blocking
			// on stdout while holding the lock.
			stamp := time.Now().Format("2006/01/02 15:04:05")
			// NOTE(review): when s[0] (the total) is zero the percentage
			// below is 0/0 in floating point and prints as NaN.
			fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
			// A finished layer is reported one last time and then dropped
			// so it does not show up in later progress output.
			if s[0] == s[1] {
				delete(p, d)
			}
		}
		mu.Unlock()
		io.Copy(stdout, &pb)
	}

	// Record progress reported by the registry client. Errors other than
	// ErrCached are printed immediately and do not update the progress map.
	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			if err != nil && !errors.Is(err, ollama.ErrCached) {
				fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
				return
			}

			mu.Lock()
			p[l.Digest] = [2]int64{l.Size, n}
			mu.Unlock()
		},
	})

	// Run the pull in the background so this goroutine can print progress
	// on a timer. The loop below returns only after receiving from errc,
	// so the goroutine's send never blocks forever.
	errc := make(chan error)
	go func() {
		errc <- rc.Pull(ctx, c, model)
	}()

	t := time.NewTicker(time.Second)
	defer t.Stop()

	for {
		select {
		case <-t.C:
			printProgress()
		case err := <-errc:
			// Flush whatever progress accumulated since the last tick.
			printProgress()
			return err
		}
	}
}
// cmdPush implements "opp push [-from name] <model>". It resolves the
// manifest of the local model named by -from (defaulting to <model>) in the
// cache c and pushes it to the registry rc under the name <model>, printing
// one line per layer event to stdout.
//
// It returns an error if the model argument is missing, if the local model
// cannot be resolved, or if the push fails. Flag parsing errors terminate
// the process because the flag set uses flag.ExitOnError.
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
	args := flag.Args()[1:]
	fset := flag.NewFlagSet("push", flag.ExitOnError)
	flagFrom := fset.String("from", "", "Use the manifest from a model by another name.")
	fset.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
		fset.PrintDefaults()
	}
	fset.Parse(args)

	model := fset.Arg(0)
	if model == "" {
		return fmt.Errorf("missing model argument")
	}

	from := cmp.Or(*flagFrom, model)
	m, err := ollama.ResolveLocal(c, from)
	if err != nil {
		return err
	}

	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			switch {
			case errors.Is(err, ollama.ErrCached):
				// Every status line ends in a newline so that consecutive
				// events do not run together on one line.
				fmt.Fprintf(stdout, "opp: uploading %s %d (existed)\n", l.Digest.Short(), n)
			case err != nil:
				fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
			case n == 0:
				// First event for this layer: announce it, using the
				// manifest's copy of the layer for its media type.
				l := m.Layer(l.Digest)
				// A media type that fails to parse yields empty values,
				// which fall through to the default case below.
				mt, p, _ := mime.ParseMediaType(l.MediaType)
				mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
				switch mt {
				case "tensor":
					fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
				default:
					fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
				}
			}
		},
	})

	return rc.Push(ctx, c, model, &ollama.PushParams{
		From: from,
	})
}
// trackingReader wraps an io.Reader and adds the number of bytes returned
// by each Read to the shared counter n.
type trackingReader struct {
	io.Reader
	n *atomic.Int64
}

// Read forwards to the wrapped reader and records how many bytes it
// returned, whether or not the read also reported an error.
func (r *trackingReader) Read(p []byte) (int, error) {
	got, err := r.Reader.Read(p)
	r.n.Add(int64(got))
	return got, err
}
// cmdImport implements "opp import [-as name] [dir]". It reads the tensors
// of the safetensors files in dir (default "."), imports each tensor's
// bytes into the cache c as its own layer, writes a manifest listing those
// layers, and links the manifest under the name given by -as.
//
// Tensors are imported concurrently, limited to GOMAXPROCS at a time, while
// this goroutine redraws a progress line on stdout once per second.
//
// It returns the first error encountered while reading tensors, importing
// them, or writing and linking the manifest. If ctx is cancelled before all
// imports have been started, the context error is reported as well.
func cmdImport(ctx context.Context, c *blob.DiskCache) error {
	args := flag.Args()[1:]
	fset := flag.NewFlagSet("import", flag.ExitOnError)
	flagAs := fset.String("as", "", "Import using the provided name.")
	fset.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
		fset.PrintDefaults()
	}
	// Parse errors terminate the process (flag.ExitOnError).
	fset.Parse(args)

	dir := cmp.Or(fset.Arg(0), ".")
	fmt.Fprintf(os.Stderr, "Reading %s\n", dir)

	m, err := safetensors.Read(os.DirFS(dir))
	if err != nil {
		return err
	}

	// Collect every tensor up front so the total size is known before the
	// progress display starts.
	var total int64
	var tt []*safetensors.Tensor
	for t, err := range m.Tensors() {
		if err != nil {
			return err
		}
		tt = append(tt, t)
		total += t.Size()
	}

	var n atomic.Int64 // bytes imported so far, across all workers
	done := make(chan error)
	go func() {
		layers := make([]*ollama.Layer, len(tt))
		var g errgroup.Group
		g.SetLimit(runtime.GOMAXPROCS(0))
		var ctxErr error
		for i, t := range tt {
			if ctx.Err() != nil {
				// The context may cancel AFTER we exit the
				// loop, and so if we use ctx.Err() after the
				// loop we may report it as the error that
				// broke the loop, when it was not. This can
				// manifest as a false-negative, leading the
				// user to think their import failed when it
				// did not, so capture it if and only if we
				// exit the loop because of a ctx.Err() and
				// report it.
				ctxErr = ctx.Err()
				break
			}
			g.Go(func() (err error) {
				rc, err := t.Reader()
				if err != nil {
					return err
				}
				// Covers the early-return paths; on success the
				// explicit Close below runs first so that its error
				// is not lost.
				defer rc.Close()
				tr := &trackingReader{rc, &n}
				d, err := c.Import(tr, t.Size())
				if err != nil {
					return err
				}
				if err := rc.Close(); err != nil {
					return err
				}

				// Each worker writes only its own slot, so no lock is
				// needed around layers.
				layers[i] = &ollama.Layer{
					Digest: d,
					Size:   t.Size(),
					MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
						"name":  t.Name(),
						"dtype": t.DataType(),
						"shape": t.Shape().String(),
					}),
				}

				return nil
			})
		}

		done <- func() error {
			if err := errors.Join(g.Wait(), ctxErr); err != nil {
				return err
			}
			manifest := &ollama.Manifest{Layers: layers}
			data, err := json.MarshalIndent(manifest, "", "  ")
			if err != nil {
				return err
			}
			d := blob.DigestFromBytes(data)
			err = blob.PutBytes(c, d, data)
			if err != nil {
				return err
			}
			return c.Link(*flagAs, d)
		}()
	}()

	fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)

	csiHideCursor(stdout)
	defer csiShowCursor(stdout)

	// Remember where the progress line starts so that each update can
	// overwrite the previous one.
	csiSavePos(stdout)
	writeProgress := func() {
		csiRestorePos(stdout)
		nn := n.Load()
		// When no tensors were found total is zero; report 100% instead
		// of dividing by zero.
		pct := int64(100)
		if total > 0 {
			pct = nn * 100 / total
		}
		fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
			formatNatural(nn),
			formatNatural(total),
			pct,
			ansiClearToEndOfLine,
		)
	}

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			writeProgress()
		case err := <-done:
			// Draw the final state before returning.
			writeProgress()
			return err
		}
	}
}
// formatNatural renders a byte count for display using 1024-based units.
// Values below 1024 are printed as whole bytes ("%d B"); larger values are
// printed with one decimal place in KB, MB or GB, where GB is the largest
// unit used.
func formatNatural(n int64) string {
	const (
		kb = 1024
		mb = 1024 * kb
		gb = 1024 * mb
	)
	if n < kb {
		return fmt.Sprintf("%d B", n)
	}
	v := float64(n)
	if n < mb {
		return fmt.Sprintf("%.1f KB", v/kb)
	}
	if n < gb {
		return fmt.Sprintf("%.1f MB", v/mb)
	}
	return fmt.Sprintf("%.1f GB", v/gb)
}
// ansiClearToEndOfLine is the ANSI escape sequence (CSI K) that erases from
// the cursor to the end of the current line.
const ansiClearToEndOfLine = "\033[K"

// csiSavePos writes the escape sequence that saves the cursor position.
func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") }

// csiRestorePos writes the escape sequence that restores the cursor to the
// last saved position.
func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }

// csiHideCursor writes the escape sequence that hides the cursor.
func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }

// csiShowCursor writes the escape sequence that makes the cursor visible.
func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }

View File

@@ -0,0 +1,11 @@
package main
import (
"fmt"
"os"
)
// main only tells the user how to run the benchmarks in this package and
// then exits with status 1; the package has no behavior as a command.
func main() {
	const hint = "Run as 'go test -bench=.' to run the benchmarks"
	fmt.Println(hint)
	os.Exit(1)
}

View File

@@ -0,0 +1,107 @@
package main
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/chunks"
"golang.org/x/sync/errgroup"
)
// BenchmarkDownload measures downloading a 100 MiB range of a blob with
// chunk sizes of 8, 16, 32, 64 and 128 MiB, one sub-benchmark per chunk
// size, named "size=<bytes>/chunksize=<bytes>".
func BenchmarkDownload(b *testing.B) {
	const fileSize int64 = 100 << 20

	chunkSizes := []int64{
		8 << 20,
		16 << 20,
		32 << 20,
		64 << 20,
		128 << 20, // 1 chunk
	}
	for _, chunkSize := range chunkSizes {
		name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
		b.Run(name, func(b *testing.B) {
			benchmarkDownload(b, fileSize, chunkSize)
		})
	}
}
// run downloads the byte range described by chunk from a fixed blob URL
// using c, and discards the data.
//
// It returns an error if the request cannot be built or sent, if the server
// answers with a status other than 200 or 206, or if the body ends before
// chunk.Size() bytes have been read (in which case the error is io.EOF).
func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
	const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
	req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
	res, err := c.Do(req)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	// Without this check an error response whose body happens to be long
	// enough would be measured as if it were a successful download.
	if res.StatusCode != http.StatusPartialContent && res.StatusCode != http.StatusOK {
		return fmt.Errorf("fetching bytes=%s: unexpected status %s", chunk, res.Status)
	}

	_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
	return err
}
// sleepTime is the number of seconds benchmarkDownload pauses before its
// timed loop. It starts at zero, so the first benchmark does not wait, and
// is set to 3 by the first call so that every later benchmark does.
var sleepTime atomic.Int64

// benchmarkDownload measures downloading fileSize bytes split into chunks
// of chunkSize. Each iteration fetches all chunks concurrently, limited to
// GOMAXPROCS requests at a time, and fails the benchmark on the first error.
//
// Keep-alives are disabled on the client's transport, so every request
// uses a new connection.
func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
	client := &http.Client{
		Transport: func() http.RoundTripper {
			tr := http.DefaultTransport.(*http.Transport).Clone()
			tr.DisableKeepAlives = true
			return tr
		}(),
	}
	defer client.CloseIdleConnections()

	// warm up the client
	//
	// A failure here is only logged: the timed loop below issues the same
	// kind of request and will fail the benchmark if the problem persists.
	if err := run(context.Background(), client, chunks.New(0, 1<<20)); err != nil {
		b.Logf("warm-up request failed: %v", err)
	}

	b.SetBytes(fileSize)
	b.ReportAllocs()

	// Give our CDN a min to breathe between benchmarks.
	//
	// The stored value is a count of seconds; without the multiplication it
	// would be interpreted as nanoseconds and the pause would be a no-op.
	time.Sleep(time.Duration(sleepTime.Swap(3)) * time.Second)

	for b.Loop() {
		g, ctx := errgroup.WithContext(b.Context())
		g.SetLimit(runtime.GOMAXPROCS(0))
		for chunk := range chunks.Of(fileSize, chunkSize) {
			g.Go(func() error { return run(ctx, client, chunk) })
		}
		if err := g.Wait(); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkWrite measures writing to a file in 1 MiB chunks.
func BenchmarkWrite(b *testing.B) {
	const chunkSize = 1 << 20 // 1 MiB
	b.Run("chunksize=1MiB", func(sub *testing.B) {
		benchmarkWrite(sub, chunkSize)
	})
}
// benchmarkWrite measures copying a zero-filled buffer of chunkSize bytes
// into a file in the benchmark's temporary directory. Every iteration
// copies the buffer again to the same open file, so the file grows by
// chunkSize per iteration.
func benchmarkWrite(b *testing.B, chunkSize int) {
	b.ReportAllocs()

	path := filepath.Join(b.TempDir(), "write-single")
	f, err := os.Create(path)
	if err != nil {
		b.Fatal(err)
	}
	defer f.Close()

	payload := make([]byte, chunkSize)
	b.SetBytes(int64(chunkSize))
	src := bytes.NewReader(payload)
	for b.Loop() {
		// Rewind the reader so the whole payload is written again.
		src.Reset(payload)
		if _, err := io.Copy(f, src); err != nil {
			b.Fatal(err)
		}
	}
}