Support multiple variants for a given llm lib type

In some cases we may want multiple variants for a given GPU type or CPU.
This adds logic to have an optional Variant which we can use to select
an optimal library, but also allows us to try multiple variants in case
some fail to load.

This can be useful for scenarios such as ROCm v5 vs v6 incompatibility
or potentially CPU features.
This commit is contained in:
Daniel Hiltgen
2024-01-05 12:13:08 -08:00
parent b24e8d17b2
commit 8da7bef05f
16 changed files with 428 additions and 212 deletions

View File

@@ -11,14 +11,9 @@ package llm
import "C"
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"log"
"os"
"path/filepath"
"strings"
"sync"
"unsafe"
@@ -34,8 +29,6 @@ type shimExtServer struct {
var shimMutex sync.Mutex
var llm *shimExtServer
const pathComponentCount = 6
func (llm *shimExtServer) llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_init(llm.s, sparams, err)
}
@@ -112,82 +105,3 @@ func (llm *shimExtServer) Embedding(ctx context.Context, input string) ([]float6
func (llm *shimExtServer) Close() {
close(llm)
}
func nativeInit(workdir string) error {
libs, err := extractDynamicLibs(workdir, "llama.cpp/build/*/*/lib/*")
if err != nil {
if err == payloadMissing {
log.Printf("%s", payloadMissing)
return nil
}
return err
}
for _, lib := range libs {
// The last dir component is the variant name
variant := filepath.Base(filepath.Dir(lib))
AvailableShims[variant] = lib
}
if err := verifyDriverAccess(); err != nil {
return err
}
// Report which dynamic libraries we have loaded to assist troubleshooting
variants := make([]string, len(AvailableShims))
i := 0
for variant := range AvailableShims {
variants[i] = variant
i++
}
log.Printf("Dynamic LLM variants %v", variants)
return nil
}
func extractDynamicLibs(workDir, glob string) ([]string, error) {
files, err := fs.Glob(libEmbed, glob)
if err != nil || len(files) == 0 {
return nil, payloadMissing
}
libs := []string{}
for _, file := range files {
pathComps := strings.Split(file, "/")
if len(pathComps) != pathComponentCount {
log.Printf("unexpected payload components: %v", pathComps)
continue
}
// llama.cpp/build/$OS/$VARIANT/lib/$LIBRARY
// Include the variant in the path to avoid conflicts between multiple server libs
targetDir := filepath.Join(workDir, pathComps[pathComponentCount-3])
srcFile, err := libEmbed.Open(file)
if err != nil {
return nil, fmt.Errorf("read payload %s: %v", file, err)
}
defer srcFile.Close()
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return nil, fmt.Errorf("create payload temp dir %s: %v", workDir, err)
}
destFile := filepath.Join(targetDir, filepath.Base(file))
if strings.Contains(destFile, "server") {
libs = append(libs, destFile)
}
_, err = os.Stat(destFile)
switch {
case errors.Is(err, os.ErrNotExist):
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return nil, fmt.Errorf("write payload %s: %v", file, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, srcFile); err != nil {
return nil, fmt.Errorf("copy payload %s: %v", file, err)
}
case err != nil:
return nil, fmt.Errorf("stat payload %s: %v", file, err)
}
}
return libs, nil
}