Runner for Ollama engine

This provides integration with the new Ollama engine
(5824541 next ollama runner (#7913)) and the rest of the Ollama
infrastructure such as the runner and Ollama server.

In addition, it builds out the KV cache infrastructure to support
the requirements of how Ollama runs models, such as:
 - Parallel processing
 - Memory management for defragmentation and shifting
 - Multi-modal models

Both old and new engines continue to be supported. By default, only
the old engine is used. To enable the new engine:

Start the server with the OLLAMA_NEW_ENGINE environment variable set:
OLLAMA_NEW_ENGINE=1 ./ollama serve

Run a model that is supported by the new engine; this example uses Llama 3.1 8B Q4_K_M:
./ollama run jessegross/llama3.1
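
Once the server is running, requests can also be sent through the standard
API (assuming the default listen address of 127.0.0.1:11434):
curl http://localhost:11434/api/generate -d '{"model": "jessegross/llama3.1", "prompt": "Why is the sky blue?"}'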
Authored and committed by Jesse Gross, 2024-12-17 19:59:41 -08:00
parent 6945617af5
commit ed443a0393
31 changed files with 2952 additions and 244 deletions

runner/llamarunner/cache.go Normal file

@@ -0,0 +1,246 @@
package llamarunner

import (
    "errors"
    "fmt"
    "log/slog"
    "reflect"
    "time"

    "github.com/ollama/ollama/llama"
)

type InputCache struct {
    // context window size (per slot)
    numCtx int

    // individual KV caches
    slots []InputCacheSlot

    // optimize cache eviction for multiple users
    multiUserCache bool

    lc *llama.Context
}
func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) (*InputCache, error) {
    if kvSize/numSlots < 1 {
        return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
    }

    slots := make([]InputCacheSlot, numSlots)
    for i := range slots {
        slots[i] = InputCacheSlot{
            Id:     i,
            Inputs: make([]input, 0),
        }
    }

    return &InputCache{
        numCtx:         kvSize / numSlots,
        slots:          slots,
        multiUserCache: multiUserCache,
        lc:             lc,
    }, nil
}
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and llama.Decode

type InputCacheSlot struct {
    // Index in the KV cache
    Id int

    // Inputs that are stored in the KV cache
    Inputs []input

    // is this cache actively being processed as part of a sequence?
    InUse bool

    // last time this cache was used (as of start of processing)
    lastUsed time.Time
}

func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
    var slot *InputCacheSlot
    var numPast int
    var err error

    // In single-user scenarios, the longest cache slot works fine for getting good input
    // cache hit rates and it reuses the same VRAM over and over again, which is good for
    // GPU performance in situations where we miss the input cache.
    // For multiple users, the "best" cache slot produces better input cache hit rates
    // at the cost of worse performance when we miss the input cache (because it causes
    // GPU L2 cache misses due to spreading out accesses across VRAM).
    if !c.multiUserCache {
        slot, numPast, err = c.findLongestCacheSlot(prompt)
    } else {
        slot, numPast, err = c.findBestCacheSlot(prompt)
    }
    if err != nil {
        return nil, nil, err
    }

    if !cachePrompt {
        numPast = 0
    }

    slot.InUse = true
    slot.lastUsed = time.Now()

    if numPast == len(prompt) {
        // Leave one input to sample so we can get a response
        numPast--
    }

    if !c.lc.KvCacheSeqRm(slot.Id, numPast, -1) {
        // Some models don't support partial erasure
        c.lc.KvCacheSeqRm(slot.Id, 0, -1)
        numPast = 0
    }

    slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
        "used", numPast, "remaining", len(prompt)-numPast)

    prompt = prompt[numPast:]
    slot.Inputs = slot.Inputs[:numPast]

    return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
    longest := -1
    var longestSlot *InputCacheSlot

    for i, s := range c.slots {
        if s.InUse {
            continue
        }

        count := countCommonPrefix(s.Inputs, prompt)
        if count > longest {
            longest = count
            longestSlot = &c.slots[i]
        }
    }

    if longestSlot == nil {
        return nil, 0, errors.New("no available cache slots")
    }

    return longestSlot, longest, nil
}

func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
    oldest := time.Now()
    var oldestSlot *InputCacheSlot
    longest := -1
    var longestSlot *InputCacheSlot

    for i, s := range c.slots {
        count := countCommonPrefix(s.Inputs, prompt)
        if count > longest {
            longest = count
            longestSlot = &c.slots[i]
        }

        if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
            oldest = s.lastUsed
            oldestSlot = &c.slots[i]
        }
    }

    // The best match is already a prefix of the prompt and is free: reuse it directly
    if longest == len(longestSlot.Inputs) && !longestSlot.InUse {
        return longestSlot, longest, nil
    }

    if oldestSlot.InUse {
        return nil, 0, errors.New("no available cache slots")
    }

    if len(oldestSlot.Inputs) != 0 {
        slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
            "used", oldestSlot.lastUsed)
    }

    if longest > 0 && longestSlot != oldestSlot {
        slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
            len(longestSlot.Inputs))
        oldestSlot.Inputs = make([]input, longest)
        copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
        // This is only nil for unit tests
        if c.lc != nil {
            c.lc.KvCacheSeqRm(oldestSlot.Id, 0, -1)
            c.lc.KvCacheSeqCp(longestSlot.Id, oldestSlot.Id, 0, longest)
        }
    }

    return oldestSlot, longest, nil
}
func countCommonPrefix(a []input, b []input) int {
    var count int

    for i := range a {
        if i >= len(b) {
            break
        }

        if !reflect.DeepEqual(a[i], b[i]) {
            break
        }

        count++
    }

    return count
}

// ShiftDiscard returns the number of oldest inputs to drop so that roughly
// half of the non-kept context becomes free (always at least one entry),
// accounting for any space that is already free.
func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
    targetFree := (c.numCtx - numKeep) / 2
    targetFree = max(targetFree, 1)

    currentFree := c.numCtx - inputLen
    discard := targetFree - currentFree

    if discard < 0 {
        discard = 0
    }

    return discard
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
    if numKeep >= c.numCtx {
        return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
    }

    discard := c.ShiftDiscard(len(slot.Inputs), numKeep)

    if discard <= 0 {
        return nil
    }

    slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
        "keep", numKeep, "discard", discard)

    // TODO (jessegross): KV cache removal can fail for certain types of models
    if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
        return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
    }
    c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)

    for i := numKeep + discard; i < len(slot.Inputs); i++ {
        slot.Inputs[i-discard] = slot.Inputs[i]
    }
    slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]

    return nil
}
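
As a worked illustration of ShiftDiscard's arithmetic, here is a minimal
standalone sketch (shiftDiscard is a hypothetical free-function copy of the
method above, not part of this commit); it reproduces the "Shift" and
"No Op" cases from the tests below:

package main

import "fmt"

// shiftDiscard mirrors InputCache.ShiftDiscard: free roughly half of the
// non-kept context (at least one entry), minus any space already free.
func shiftDiscard(numCtx, inputLen, numKeep int) int {
    targetFree := (numCtx - numKeep) / 2
    targetFree = max(targetFree, 1) // max is a builtin as of Go 1.21

    currentFree := numCtx - inputLen
    discard := targetFree - currentFree

    if discard < 0 {
        discard = 0
    }

    return discard
}

func main() {
    // Full 2048-entry context, keeping the first 5 inputs:
    // targetFree = (2048-5)/2 = 1021, currentFree = 0 -> discard 1021.
    fmt.Println(shiftDiscard(2048, 2048, 5)) // 1021

    // An under-full context already has enough free space -> discard 0.
    fmt.Println(shiftDiscard(2048, 512, 5)) // 0
}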

runner/llamarunner/cache_test.go Normal file

@@ -0,0 +1,292 @@
package llamarunner

import (
    "testing"
    "time"
)

func TestCountCommon(t *testing.T) {
    tests := []struct {
        name     string
        t1       []input
        t2       []input
        expected int
    }{
        {
            name:     "Equal",
            t1:       []input{{token: 1}, {token: 2}, {token: 3}},
            t2:       []input{{token: 1}, {token: 2}, {token: 3}},
            expected: 3,
        },
        {
            name:     "Prefix",
            t1:       []input{{token: 1}},
            t2:       []input{{token: 1}, {token: 2}, {token: 3}},
            expected: 1,
        },
        {
            name:     "Embeddings Prefix",
            t1:       []input{{embed: []float32{0.1, 0.2, 0.3}}},
            t2:       []input{{embed: []float32{0.1, 0.2, 0.3}}, {embed: []float32{0.4, 0.5, 0.6}}, {embed: []float32{0.7}}},
            expected: 1,
        },
        {
            name:     "Embeddings Prefix Partial",
            t1:       []input{{embed: []float32{0.1, 0.2, 0.3}}},
            t2:       []input{{embed: []float32{0.1, 0.2}}, {embed: []float32{0.4, 0.5, 0.6}}, {embed: []float32{0.7}}},
            expected: 0,
        },
        {
            name:     "Mixed",
            t1:       []input{{token: 1}, {embed: []float32{0.2, 0.3, 0.4}}},
            t2:       []input{{token: 1}, {embed: []float32{0.2, 0.3, 0.4}}, {token: 5}},
            expected: 2,
        },
        {
            name:     "Empty",
            t1:       []input{},
            t2:       []input{{token: 1}, {token: 2}, {token: 3}},
            expected: 0,
        },
        {
            name:     "Both Empty",
            t1:       []input{},
            t2:       []input{},
            expected: 0,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := countCommonPrefix(tt.t1, tt.t2)
            if result != tt.expected {
                t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected)
            }
        })
    }
}
func TestFindCacheSlot(t *testing.T) {
    type expected struct {
        result int
        len    int
    }

    tests := []struct {
        name    string
        cache   InputCache
        prompt  []input
        longest expected
        best    expected
    }{
        {
            name: "Empty",
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id:       0,
                    Inputs:   []input{},
                    InUse:    false,
                    lastUsed: time.Time{},
                },
                {
                    Id:       1,
                    Inputs:   []input{},
                    InUse:    false,
                    lastUsed: time.Time{},
                },
            }},
            prompt:  []input{{token: 1}},
            longest: expected{result: 0, len: 0},
            best:    expected{result: 0, len: 0},
        },
        {
            name: "Extend",
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id:       0,
                    Inputs:   []input{{token: 1}},
                    InUse:    false,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id:       1,
                    Inputs:   []input{{token: 1}, {token: 2}},
                    InUse:    false,
                    lastUsed: time.Now().Add(-2 * time.Second),
                },
            }},
            prompt:  []input{{token: 1}, {token: 2}},
            longest: expected{result: 1, len: 2},
            best:    expected{result: 1, len: 2},
        },
        {
            name: "New",
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id:       0,
                    Inputs:   []input{{token: 1}, {token: 2}},
                    InUse:    false,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id:       1,
                    Inputs:   []input{},
                    InUse:    false,
                    lastUsed: time.Time{},
                },
            }},
            prompt:  []input{{token: 2}},
            longest: expected{result: 0, len: 0},
            best:    expected{result: 1, len: 0},
        },
        {
            name: "Fork",
            cache: InputCache{
                slots: []InputCacheSlot{
                    {
                        Id:       0,
                        Inputs:   []input{{token: 1}, {token: 2}},
                        InUse:    false,
                        lastUsed: time.Now().Add(-time.Second),
                    },
                    {
                        Id:       1,
                        Inputs:   []input{},
                        InUse:    false,
                        lastUsed: time.Time{},
                    },
                },
            },
            prompt:  []input{{token: 1}},
            longest: expected{result: 0, len: 1},
            best:    expected{result: 1, len: 1},
        },
        {
            name: "Evict",
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id:       0,
                    Inputs:   []input{{token: 1}},
                    InUse:    false,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id:       1,
                    Inputs:   []input{{token: 1}, {token: 2}},
                    InUse:    false,
                    lastUsed: time.Now().Add(-2 * time.Second),
                },
            }},
            prompt:  []input{{token: 2}, {token: 3}},
            longest: expected{result: 0, len: 0},
            best:    expected{result: 1, len: 0},
        },
        {
            name: "In use",
            cache: InputCache{slots: []InputCacheSlot{
                {
                    Id:       0,
                    Inputs:   []input{{token: 1}, {token: 2}},
                    InUse:    true,
                    lastUsed: time.Now().Add(-time.Second),
                },
                {
                    Id:       1,
                    Inputs:   []input{{token: 1}},
                    InUse:    false,
                    lastUsed: time.Now().Add(-2 * time.Second),
                },
            }},
            prompt:  []input{{token: 1}, {token: 2}},
            longest: expected{result: 1, len: 1},
            best:    expected{result: 1, len: 2},
        },
    }

    for _, tt := range tests {
        t.Run("Longest-"+tt.name, func(t *testing.T) {
            result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt)
            if err != nil {
                t.Errorf("findLongestCacheSlot: err %v", err)
            } else if result.Id != tt.longest.result || resultLen != tt.longest.len {
                t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v",
                    result.Id, tt.longest.result, resultLen, tt.longest.len)
            }
        })
    }

    for _, tt := range tests {
        t.Run("Best-"+tt.name, func(t *testing.T) {
            result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt)
            if err != nil {
                t.Errorf("findBestCacheSlot: err %v", err)
            } else if result.Id != tt.best.result || resultLen != tt.best.len {
                t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v",
                    result.Id, tt.best.result, resultLen, tt.best.len)
            }
        })
    }
}
func TestShiftDiscard(t *testing.T) {
    tests := []struct {
        name     string
        numCtx   int
        numKeep  int
        inputLen int
        expected int
    }{
        {
            name:     "Shift",
            numCtx:   2048,
            numKeep:  5,
            inputLen: 2048,
            expected: 1021,
        },
        {
            name:     "Max Keep",
            numCtx:   2048,
            numKeep:  2047,
            inputLen: 2048,
            expected: 1,
        },
        {
            name:     "No Keep",
            numCtx:   2048,
            numKeep:  0,
            inputLen: 2048,
            expected: 1024,
        },
        {
            name:     "Truncate",
            numCtx:   2048,
            numKeep:  5,
            inputLen: 5000,
            expected: 3973,
        },
        {
            name:     "Truncate Keep",
            numCtx:   2048,
            numKeep:  2047,
            inputLen: 5000,
            expected: 2953,
        },
        {
            name:     "No Op",
            numCtx:   2048,
            numKeep:  5,
            inputLen: 512,
            expected: 0,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            c := InputCache{numCtx: tt.numCtx}
            result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
            if result != tt.expected {
                t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
            }
        })
    }
}

runner/llamarunner/image.go Normal file

@@ -0,0 +1,183 @@
package llamarunner

import (
    "errors"
    "fmt"
    "hash/maphash"
    "log/slog"
    "slices"
    "sync"
    "time"

    "github.com/ollama/ollama/llama"
)

const imageCacheSize = 4

type ImageContext struct {
    // mu is required to be held when generating embeddings or accessing the cache
    mu sync.Mutex

    clip   *llama.ClipContext
    mllama *llama.MllamaContext

    // cache of images to embeddings
    images    []imageCache
    imageHash maphash.Hash
}
func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
    arch, err := llama.GetModelArch(modelPath)
    if err != nil {
        return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
    }

    var c ImageContext
    if arch == "clip" {
        c.clip, err = llama.NewClipContext(llamaContext, modelPath)
    } else if arch == "mllama" {
        c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
    } else {
        return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
    }

    if err != nil {
        return nil, err
    }

    c.images = make([]imageCache, imageCacheSize)

    return &c, nil
}

func (c *ImageContext) Free(modelPath string) {
    if c == nil {
        return
    }

    if c.clip != nil {
        c.clip.Free()
    }
    if c.mllama != nil {
        c.mllama.Free()
    }
}
func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) ([][]float32, error) {
    if c == nil {
        return nil, nil
    }

    if len(data) <= 0 {
        return nil, errors.New("received zero length image")
    }

    hash := c.hashImage(data)

    c.mu.Lock()
    defer c.mu.Unlock()

    embed, err := c.findImage(hash)
    if err != nil {
        if c.mllama != nil {
            embed, err = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
            if err != nil {
                return nil, err
            }
        } else if c.clip != nil {
            embed, err = c.clip.NewEmbed(llamaContext, data)
            if err != nil {
                return nil, err
            }
        } else {
            return nil, errors.New("received image but vision model not loaded")
        }

        c.addImage(hash, embed)
    }

    return embed, nil
}
func (c *ImageContext) BatchSize(configuredBatchSize int) int {
    // If images are not supported, we don't need to allocate embedding batches
    if c == nil {
        return 0
    }

    // Mllama maps an image to 1 embedding token (llava creates many tokens)
    // and doesn't support more than a single image per request.
    // The embeddings are large (100 MB), so allocating a big batch can fail
    // on some systems
    if c.mllama != nil {
        return 1
    }

    return configuredBatchSize
}

func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
    if c != nil && c.mllama != nil {
        return c.mllama.EmbedSize(llamaContext)
    } else {
        return llamaContext.Model().NEmbd()
    }
}

func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
    if c == nil || c.mllama == nil {
        return false
    }

    return slices.ContainsFunc(inputs, func(input input) bool {
        return input.embed != nil
    })
}
type imageCache struct {
    key      uint64
    val      [][]float32
    lastUsed time.Time
}

func (c *ImageContext) hashImage(image []byte) uint64 {
    c.imageHash.Reset()
    _, _ = c.imageHash.Write(image)
    return c.imageHash.Sum64()
}

var errImageNotFound = errors.New("image not found in cache")

func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
    for i := range c.images {
        if c.images[i].key == hash {
            slog.Debug("loading image embeddings from cache", "entry", i)
            c.images[i].lastUsed = time.Now()
            return c.images[i].val, nil
        }
    }

    return nil, errImageNotFound
}
func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
    best := time.Now()
    var bestImage int

    // Pick the slot to write: reuse a slot whose key already matches,
    // otherwise take the least recently used one.
    for i := range c.images {
        if c.images[i].key == hash {
            bestImage = i
            break
        }

        if c.images[i].lastUsed.Compare(best) < 0 {
            best = c.images[i].lastUsed
            bestImage = i
        }
    }

    slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
    c.images[bestImage].key = hash
    c.images[bestImage].val = embed
    c.images[bestImage].lastUsed = time.Now()
}
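
The replacement policy in addImage is worth spelling out: a slot whose key
matches the new hash is overwritten in place, and otherwise the least
recently used slot is evicted. A minimal standalone sketch of just that
selection logic (pickSlot and entry are hypothetical names, not part of
this commit):

package main

import (
    "fmt"
    "time"
)

// entry mirrors the fields addImage consults.
type entry struct {
    key      uint64
    lastUsed time.Time
}

// pickSlot reproduces addImage's choice: reuse a slot whose key matches,
// otherwise evict the least recently used slot.
func pickSlot(entries []entry, hash uint64) int {
    best := time.Now()
    bestIdx := 0
    for i, e := range entries {
        if e.key == hash {
            return i // same image: overwrite in place
        }
        if e.lastUsed.Before(best) {
            best = e.lastUsed
            bestIdx = i
        }
    }
    return bestIdx
}

func main() {
    now := time.Now()
    entries := []entry{
        {key: 1, lastUsed: now.Add(-time.Second)},
        {key: 2, lastUsed: now.Add(-time.Minute)}, // least recently used
    }
    fmt.Println(pickSlot(entries, 3)) // 1: evicts the stale slot
    fmt.Println(pickSlot(entries, 1)) // 0: key match wins
}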

runner/llamarunner/image_test.go Normal file

@@ -0,0 +1,80 @@
package llamarunner

import (
    "reflect"
    "testing"
)

func TestImageCache(t *testing.T) {
    cache := ImageContext{images: make([]imageCache, 4)}

    valA := [][]float32{{0.1, 0.2}, {0.3}}
    valB := [][]float32{{0.4}, {0.5}, {0.6}}
    valC := [][]float32{{0.7}}
    valD := [][]float32{{0.8}}
    valE := [][]float32{{0.9}}

    // Empty cache
    result, err := cache.findImage(0x5adb61d31933a946)
    if err != errImageNotFound {
        t.Errorf("found result in empty cache: result %v, err %v", result, err)
    }

    // Insert A
    cache.addImage(0x5adb61d31933a946, valA)

    result, err = cache.findImage(0x5adb61d31933a946)
    if !reflect.DeepEqual(result, valA) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }

    // Insert B
    cache.addImage(0x011551369a34a901, valB)

    result, err = cache.findImage(0x5adb61d31933a946)
    if !reflect.DeepEqual(result, valA) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }
    result, err = cache.findImage(0x011551369a34a901)
    if !reflect.DeepEqual(result, valB) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }

    // Replace B with C
    cache.addImage(0x011551369a34a901, valC)

    result, err = cache.findImage(0x5adb61d31933a946)
    if !reflect.DeepEqual(result, valA) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }
    result, err = cache.findImage(0x011551369a34a901)
    if !reflect.DeepEqual(result, valC) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }

    // Evict A
    cache.addImage(0x756b218a517e7353, valB)
    cache.addImage(0x75e5e8d35d7e3967, valD)
    cache.addImage(0xd96f7f268ca0646e, valE)

    result, err = cache.findImage(0x5adb61d31933a946)
    if reflect.DeepEqual(result, valA) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }
    result, err = cache.findImage(0x756b218a517e7353)
    if !reflect.DeepEqual(result, valB) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }
    result, err = cache.findImage(0x011551369a34a901)
    if !reflect.DeepEqual(result, valC) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }
    result, err = cache.findImage(0x75e5e8d35d7e3967)
    if !reflect.DeepEqual(result, valD) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }
    result, err = cache.findImage(0xd96f7f268ca0646e)
    if !reflect.DeepEqual(result, valE) {
        t.Errorf("failed to find expected value: result %v, err %v", result, err)
    }
}

runner/llamarunner/runner.go Normal file

File diff suppressed because it is too large (1001 lines).