fix embeddings invalid values

This commit is contained in:
Bruce MacDonald
2023-08-09 16:13:24 -04:00
parent 9738ef85db
commit 984c9c628c
2 changed files with 9 additions and 39 deletions

View File

@@ -94,7 +94,6 @@ import (
"io"
"log"
"os"
"reflect"
"strings"
"sync"
"unicode/utf8"
@@ -421,27 +420,20 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
return nil, errors.New("llama: tokenize embedding")
}
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread))
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
if retval != 0 {
return nil, errors.New("llama: eval")
}
n := int(C.llama_n_embd(llm.ctx))
n := C.llama_n_embd(llm.ctx)
if n <= 0 {
return nil, errors.New("llama: no embeddings generated")
}
cEmbeddings := unsafe.Slice(C.llama_get_embeddings(llm.ctx), n)
embedPtr := C.llama_get_embeddings(llm.ctx)
if embedPtr == nil {
return nil, errors.New("llama: embedding retrieval failed")
embeddings := make([]float64, len(cEmbeddings))
for i, v := range cEmbeddings {
embeddings[i] = float64(v)
}
header := reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(embedPtr)),
Len: n,
Cap: n,
}
embedSlice := *(*[]float64)(unsafe.Pointer(&header))
return embedSlice, nil
return embeddings, nil
}