refactor model parsing

This commit is contained in:
Michael Yang
2024-03-13 11:03:56 -07:00
parent 011bb67351
commit d338d70492
5 changed files with 131 additions and 197 deletions

View File

@@ -6,8 +6,6 @@ import (
"fmt"
"io"
"strings"
"github.com/ollama/ollama/format"
)
type containerGGUF struct {
@@ -90,8 +88,8 @@ const (
type gguf struct {
*containerGGUF
KV
Tensors []Tensor
kv KV
tensors []*Tensor
parameters uint64
}
@@ -99,7 +97,7 @@ type gguf struct {
func newGGUF(container *containerGGUF) *gguf {
return &gguf{
containerGGUF: container,
KV: make(KV),
kv: make(KV),
}
}
@@ -107,6 +105,14 @@ func NewGGUFV3(bo binary.ByteOrder) *gguf {
return newGGUF(&containerGGUF{ByteOrder: bo, Version: 3})
}
func (llm *gguf) KV() KV {
return llm.kv
}
func (llm *gguf) Tensors() []*Tensor {
return llm.tensors
}
func (llm *gguf) numTensor() uint64 {
switch llm.Version {
case 1:
@@ -129,30 +135,6 @@ func (llm *gguf) numKV() uint64 {
}
}
func (llm *gguf) ModelFamily() string {
if t, ok := llm.KV["general.architecture"].(string); ok {
return t
}
return "unknown"
}
func (llm *gguf) ModelType() string {
if llm.parameters > 0 {
return format.HumanNumber(llm.parameters)
}
return "unknown"
}
func (llm *gguf) FileType() string {
if t, ok := llm.KV["general.file_type"].(uint32); ok {
return fileType(t)
}
return "unknown"
}
func (llm *gguf) Decode(rs io.ReadSeeker) error {
// decode key-values
for i := 0; uint64(i) < llm.numKV(); i++ {
@@ -202,7 +184,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
return err
}
llm.KV[k] = v
llm.kv[k] = v
}
// decode tensors
@@ -243,11 +225,14 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
Shape: shape[:],
}
llm.Tensors = append(llm.Tensors, tensor)
llm.tensors = append(llm.tensors, &tensor)
llm.parameters += tensor.parameters()
}
alignment, ok := llm.KV["general.alignment"].(uint32)
// patch KV with parameter count
llm.kv["general.parameter_count"] = llm.parameters
alignment, ok := llm.kv["general.alignment"].(uint32)
if !ok {
alignment = 32
}
@@ -262,7 +247,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
return err
}
for _, tensor := range llm.Tensors {
for _, tensor := range llm.tensors {
padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
if _, err := rs.Seek(padded, io.SeekCurrent); err != nil {
return err
@@ -272,60 +257,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
return nil
}
func (llm *gguf) NumLayers() uint32 {
value, exists := llm.KV[fmt.Sprintf("%s.block_count", llm.ModelFamily())]
if !exists {
return 0
}
return value.(uint32)
}
func (llm *gguf) NumHead() uint32 {
value, exists := llm.KV[fmt.Sprintf("%s.attention.head_count", llm.ModelFamily())]
if !exists {
return 0
}
return value.(uint32)
}
func (llm *gguf) NumEmbed() uint32 {
value, exists := llm.KV[fmt.Sprintf("%s.embedding_length", llm.ModelFamily())]
if !exists {
return 0
}
return value.(uint32)
}
func (llm *gguf) NumHeadKv() uint32 {
value, exists := llm.KV[fmt.Sprintf("%s.attention.head_count_kv", llm.ModelFamily())]
if !exists {
return 0
}
return value.(uint32)
}
func (llm *gguf) NumCtx() uint32 {
value, exists := llm.KV[fmt.Sprintf("%s.context_length", llm.ModelFamily())]
if !exists {
return 0
}
return value.(uint32)
}
func (llm *gguf) NumGQA() uint32 {
numHeadKv := llm.NumHeadKv()
if numHeadKv == 0 {
return 0
}
return llm.NumHead() / numHeadKv
}
func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
var t T
err := binary.Read(r, llm.ByteOrder, &t)