convert: only extract large files

Michael Yang
2024-06-29 16:53:59 -07:00
parent 781fc2d576
commit eafc607abb
10 changed files with 120 additions and 200 deletions

View File

@@ -5,9 +5,8 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"os"
"path/filepath"
"github.com/ollama/ollama/llm"
)
@@ -67,8 +66,8 @@ type Converter interface {
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
-func Convert(path string, ws io.WriteSeeker) error {
-bts, err := os.ReadFile(filepath.Join(path, "config.json"))
+func Convert(fsys fs.FS, ws io.WriteSeeker) error {
+bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return err
}
@@ -98,7 +97,7 @@ func Convert(path string, ws io.WriteSeeker) error {
return err
}
-t, err := parseTokenizer(path, conv.specialTokenTypes())
+t, err := parseTokenizer(fsys, conv.specialTokenTypes())
if err != nil {
return err
}
@@ -114,7 +113,7 @@ func Convert(path string, ws io.WriteSeeker) error {
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
-ts, err := parseTensors(path)
+ts, err := parseTensors(fsys)
if err != nil {
return err
}
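With this change Convert no longer walks a directory path itself; the caller supplies any fs.FS plus an io.WriteSeeker for the output. A minimal sketch of driving the new signature from a plain directory (the model path and output filename below are hypothetical, not taken from this commit):

package main

import (
    "log"
    "os"

    "github.com/ollama/ollama/convert"
)

func main() {
    // os.DirFS turns an ordinary directory of config.json, tokenizer and
    // safetensors files into an fs.FS; a zip-backed fs.FS works the same way.
    fsys := os.DirFS("/path/to/hf-model") // hypothetical path

    out, err := os.Create("model-f16.bin") // hypothetical output name
    if err != nil {
        log.Fatal(err)
    }
    defer out.Close()

    if err := convert.Convert(fsys, out); err != nil {
        log.Fatal(err)
    }
}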

View File

@@ -6,6 +6,7 @@ import (
"flag"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"os"
@@ -17,7 +18,7 @@ import (
"golang.org/x/exp/maps"
)
-func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
+func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -26,7 +27,7 @@ func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
}
defer f.Close()
-if err := Convert(d, f); err != nil {
+if err := Convert(fsys, f); err != nil {
t.Fatal(err)
}
@@ -76,7 +77,7 @@ func TestConvertFull(t *testing.T) {
t.Skipf("%s not found", p)
}
-f, kv, tensors := convertFull(t, p)
+f, kv, tensors := convertFull(t, os.DirFS(p))
actual := make(map[string]string)
for k, v := range kv {
if s, ok := v.(json.Marshaler); !ok {
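convertFull now takes an fs.FS, so the test wraps the downloaded model directory in os.DirFS. The same signature also admits in-memory fixtures; a sketch using testing/fstest (the test name and placeholder contents are mine, not part of this commit):

package convert

import (
    "io/fs"
    "testing"
    "testing/fstest"
)

// Illustrative only: fstest.MapFS satisfies fs.FS, so small fixtures can be
// assembled in memory. A real conversion would still need valid tensor and
// tokenizer data, so this only checks that the fixture is readable.
func TestInMemoryFixtureSketch(t *testing.T) {
    fsys := fstest.MapFS{
        "config.json": {Data: []byte(`{"architectures": ["..."]}`)}, // placeholder
    }

    if _, err := fs.ReadFile(fsys, "config.json"); err != nil {
        t.Fatal(err)
    }
}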

convert/fs.go (new file, 58 lines)
View File

@@ -0,0 +1,58 @@
package convert
import (
"archive/zip"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
)
type ZipReader struct {
r *zip.Reader
p string
// limit is the maximum size of a file that can be read directly
// from the zip archive. Files larger than this size will be extracted
limit int64
}
func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
return &ZipReader{r, p, limit}
}
func (z *ZipReader) Open(name string) (fs.File, error) {
r, err := z.r.Open(name)
if err != nil {
return nil, err
}
defer r.Close()
if fi, err := r.Stat(); err != nil {
return nil, err
} else if fi.Size() < z.limit {
return r, nil
}
if !filepath.IsLocal(name) {
return nil, zip.ErrInsecurePath
}
n := filepath.Join(z.p, name)
if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
w, err := os.Create(n)
if err != nil {
return nil, err
}
defer w.Close()
if _, err := io.Copy(w, r); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
return os.Open(n)
}
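ZipReader is the piece that gives the commit its name: Open serves entries smaller than limit straight from the archive and only extracts entries at or above it into the directory p, reusing an already-extracted copy on later opens. A hedged sketch of wiring it up (the archive name, staging directory, and 32 MiB limit are assumptions, not values from this commit):

package main

import (
    "archive/zip"
    "io/fs"
    "log"
    "os"

    "github.com/ollama/ollama/convert"
)

func main() {
    // Hypothetical zip of a Hugging Face-style model directory.
    zr, err := zip.OpenReader("model.zip")
    if err != nil {
        log.Fatal(err)
    }
    defer zr.Close()

    // Staging directory for entries that exceed the limit.
    tmp, err := os.MkdirTemp("", "convert")
    if err != nil {
        log.Fatal(err)
    }
    defer os.RemoveAll(tmp)

    fsys := convert.NewZipReader(&zr.Reader, tmp, 32<<20) // 32 MiB: assumed limit

    // A small entry such as config.json stays inside the archive...
    if _, err := fs.ReadFile(fsys, "config.json"); err != nil {
        log.Fatal(err)
    }

    // ...while a multi-gigabyte tensor file (name hypothetical) is written to
    // tmp on first Open and re-opened from disk.
    f, err := fsys.Open("model-00001-of-00002.safetensors")
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()
}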

View File

@@ -3,7 +3,7 @@ package convert
import (
"errors"
"io"
"path/filepath"
"io/fs"
"strings"
)
@@ -55,8 +55,8 @@ func (t *tensorBase) SetRepacker(fn repacker) {
type repacker func(string, []float32, []uint64) ([]float32, error)
-func parseTensors(d string) ([]Tensor, error) {
-patterns := map[string]func(...string) ([]Tensor, error){
+func parseTensors(fsys fs.FS) ([]Tensor, error) {
+patterns := map[string]func(fs.FS, ...string) ([]Tensor, error){
"model-*-of-*.safetensors": parseSafetensors,
"model.safetensors": parseSafetensors,
"pytorch_model-*-of-*.bin": parseTorch,
@@ -65,13 +65,13 @@ func parseTensors(d string) ([]Tensor, error) {
}
for pattern, parseFn := range patterns {
-matches, err := filepath.Glob(filepath.Join(d, pattern))
+matches, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, err
}
if len(matches) > 0 {
-return parseFn(matches...)
+return parseFn(fsys, matches...)
}
}
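Discovery now goes through fs.Glob instead of filepath.Glob, so the pattern table matches names relative to the fs.FS root rather than joined filesystem paths. A self-contained sketch of that dispatch using the patterns visible in this hunk (the directory path is hypothetical, and the real table may list more patterns):

package main

import (
    "fmt"
    "io/fs"
    "log"
    "os"
)

func main() {
    fsys := os.DirFS("/path/to/model") // hypothetical; any fs.FS works

    // Try each pattern until one matches, mirroring the loop in parseTensors.
    patterns := []string{
        "model-*-of-*.safetensors",
        "model.safetensors",
        "pytorch_model-*-of-*.bin",
    }

    for _, pattern := range patterns {
        matches, err := fs.Glob(fsys, pattern)
        if err != nil {
            log.Fatal(err)
        }
        if len(matches) > 0 {
            fmt.Println("would parse:", matches)
            return
        }
    }

    fmt.Println("no tensor files found")
}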

View File

@@ -6,7 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"os"
"io/fs"
"slices"
"github.com/d4l3k/go-bfloat16"
@@ -20,10 +20,10 @@ type safetensorMetadata struct {
Offsets []int64 `json:"data_offsets"`
}
-func parseSafetensors(ps ...string) ([]Tensor, error) {
+func parseSafetensors(fsys fs.FS, ps ...string) ([]Tensor, error) {
var ts []Tensor
for _, p := range ps {
-f, err := os.Open(p)
+f, err := fsys.Open(p)
if err != nil {
return nil, err
}
@@ -50,6 +50,7 @@ func parseSafetensors(ps ...string) ([]Tensor, error) {
for _, key := range keys {
if value := headers[key]; value.Type != "" {
ts = append(ts, safetensor{
+fs: fsys,
path: p,
dtype: value.Type,
offset: safetensorsPad(n, value.Offsets[0]),
@@ -72,6 +73,7 @@ func safetensorsPad(n, offset int64) int64 {
}
type safetensor struct {
+fs fs.FS
path string
dtype string
offset int64
@@ -80,14 +82,20 @@ type safetensor struct {
}
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
-f, err := os.Open(st.path)
+f, err := st.fs.Open(st.path)
if err != nil {
return 0, err
}
defer f.Close()
-if _, err = f.Seek(st.offset, io.SeekStart); err != nil {
-return 0, err
+if seeker, ok := f.(io.Seeker); ok {
+if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
+return 0, err
+}
+} else {
+if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
+return 0, err
+}
}
var f32s []float32
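The WriteTo change exists because fs.File only guarantees Read, Stat, and Close: a tensor file that was extracted to disk is an *os.File and can seek to the tensor's offset, while an entry streamed straight out of the zip cannot, so the leading bytes are discarded instead. The same pattern in isolation (the helper name is mine, not from the commit):

package convert

import (
    "io"
    "io/fs"
)

// skipTo positions f at offset: it seeks when the underlying file supports it
// (the extracted *os.File case) and otherwise reads and throws away the
// leading bytes (the in-archive case). Hypothetical helper, for illustration.
func skipTo(f fs.File, offset int64) error {
    if seeker, ok := f.(io.Seeker); ok {
        _, err := seeker.Seek(offset, io.SeekStart)
        return err
    }
    _, err := io.CopyN(io.Discard, f, offset)
    return err
}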

View File

@@ -2,12 +2,13 @@ package convert
import (
"io"
"io/fs"
"github.com/nlpodyssey/gopickle/pytorch"
"github.com/nlpodyssey/gopickle/types"
)
-func parseTorch(ps ...string) ([]Tensor, error) {
+func parseTorch(fsys fs.FS, ps ...string) ([]Tensor, error) {
var ts []Tensor
for _, p := range ps {
pt, err := pytorch.Load(p)

View File

@@ -7,9 +7,9 @@ import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"log/slog"
"os"
"path/filepath"
"slices"
)
@@ -32,8 +32,8 @@ type Tokenizer struct {
Template string
}
-func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
-v, err := parseVocabulary(d)
+func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
+v, err := parseVocabulary(fsys)
if err != nil {
return nil, err
}
@@ -44,7 +44,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
}
addedTokens := make(map[string]token)
if f, err := os.Open(filepath.Join(d, "tokenizer.json")); errors.Is(err, os.ErrNotExist) {
if f, err := fsys.Open("tokenizer.json"); errors.Is(err, os.ErrNotExist) {
} else if err != nil {
return nil, err
} else {
@@ -87,7 +87,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
}
}
if f, err := os.Open(filepath.Join(d, "tokenizer_config.json")); errors.Is(err, os.ErrNotExist) {
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
} else if err != nil {
return nil, err
} else {
@@ -172,8 +172,8 @@ type Vocabulary struct {
Types []int32
}
-func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
-f, err := os.Open(filepath.Join(p, "tokenizer.json"))
+func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
+f, err := fsys.Open("tokenizer.json")
if err != nil {
return nil, err
}
@@ -219,20 +219,20 @@ func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
return &v, nil
}
-func parseVocabulary(d string) (*Vocabulary, error) {
-patterns := map[string]func(string) (*Vocabulary, error){
+func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
+patterns := map[string]func(fs.FS) (*Vocabulary, error){
"tokenizer.model": parseSentencePiece,
"tokenizer.json": parseVocabularyFromTokenizer,
}
for pattern, parseFn := range patterns {
-if _, err := os.Stat(filepath.Join(d, pattern)); errors.Is(err, os.ErrNotExist) {
+if _, err := fs.Stat(fsys, pattern); errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
return nil, err
}
-return parseFn(d)
+return parseFn(fsys)
}
return nil, errors.New("unknown tensor format")

View File

@@ -5,8 +5,8 @@ import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"slices"
"google.golang.org/protobuf/proto"
@@ -14,8 +14,8 @@ import (
"github.com/ollama/ollama/convert/sentencepiece"
)
-func parseSentencePiece(d string) (*Vocabulary, error) {
-bts, err := os.ReadFile(filepath.Join(d, "tokenizer.model"))
+func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
+bts, err := fs.ReadFile(fsys, "tokenizer.model")
if err != nil {
return nil, err
}
@@ -41,7 +41,7 @@ func parseSentencePiece(d string) (*Vocabulary, error) {
}
}
f, err := os.Open(filepath.Join(d, "added_tokens.json"))
f, err := fsys.Open("added_tokens.json")
if errors.Is(err, os.ErrNotExist) {
return &v, nil
} else if err != nil {