return more info in generate response

Author: Michael Yang
Date: 2023-07-12 18:18:06 -07:00
parent 31590284a7
commit 05e08d2310
4 changed files with 116 additions and 31 deletions


@@ -79,9 +79,11 @@ llama_token llama_sample(
 import "C"
 import (
 	"errors"
+	"fmt"
 	"io"
 	"os"
 	"strings"
+	"time"
 	"unsafe"
 
 	"github.com/jmorganca/ollama/api"
@@ -147,7 +149,7 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(prompt string, fn func(string)) error {
+func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
 	if tokens := llm.tokenize(prompt); tokens != nil {
 		return llm.generate(tokens, fn)
 	}
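With this change, Predict's callback receives a structured api.GenerateResponse instead of a plain string, so callers can stream token text and also pick up the final stats emitted when generation finishes. A minimal sketch of a caller under the new signature (the streamDemo name and the pre-initialized *llama value are assumptions; the Response, Done, EvalCount, and EvalDuration fields come from this diff):

func streamDemo(llm *llama) error {
	return llm.Predict("why is the sky blue?", func(r api.GenerateResponse) {
		if r.Done {
			// the final callback carries timing stats rather than text
			fmt.Printf("\nevaluated %d token(s) in %s\n", r.EvalCount, r.EvalDuration)
			return
		}

		// intermediate callbacks stream the detokenized output
		fmt.Print(r.Response)
	})
}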
@@ -176,7 +178,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 }
 
-func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
+func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
 	var opts C.struct_llama_sample_options
 	opts.repeat_penalty = C.float(llm.RepeatPenalty)
 	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
@@ -190,38 +192,58 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
 	opts.mirostat_tau = C.float(llm.MirostatTau)
 	opts.mirostat_eta = C.float(llm.MirostatEta)
 
-	pastTokens := deque[C.llama_token]{capacity: llm.RepeatLastN}
+	output := deque[C.llama_token]{capacity: llm.NumCtx}
 
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
 		}
 
-		token, err := llm.sample(pastTokens, &opts)
-		switch {
-		case errors.Is(err, io.EOF):
-			return nil
-		case err != nil:
+		token, err := llm.sample(output, &opts)
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
 			return err
 		}
 
-		fn(llm.detokenize(token))
+		// call the callback
+		fn(api.GenerateResponse{
+			Response: llm.detokenize(token),
+		})
 
-		tokens = []C.llama_token{token}
-		pastTokens.PushLeft(token)
+		output.PushLeft(token)
+		input = []C.llama_token{token}
 	}
 
+	dur := func(ms float64) time.Duration {
+		d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+		if err != nil {
+			panic(err)
+		}
+
+		return d
+	}
+
+	timings := C.llama_get_timings(llm.ctx)
+	fn(api.GenerateResponse{
+		Done:               true,
+		PromptEvalCount:    int(timings.n_p_eval),
+		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
+		EvalCount:          int(timings.n_eval),
+		EvalDuration:       dur(float64(timings.t_eval_ms)),
+	})
+
 	return nil
 }
 
-func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
+func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
 	numVocab := int(C.llama_n_vocab(llm.ctx))
 	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
 
-	candidates := make([]C.struct_llama_token_data, 0, numVocab)
-	for i := 0; i < numVocab; i++ {
-		candidates = append(candidates, C.llama_token_data{
+	candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
+	for i := 0; i < candidates.Cap(); i++ {
+		candidates.PushLeft(C.struct_llama_token_data{
 			id:    C.int(i),
 			logit: logits[i],
 			p:     0,
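The new dur helper converts llama.cpp's millisecond timings (t_p_eval_ms and t_eval_ms, both C floats) to time.Duration by formatting the value as a string and parsing it back with time.ParseDuration, panicking if the parse fails. For comparison, a hedged sketch of a direct arithmetic conversion that avoids the string round-trip; this is an alternative, not code from the commit:

// durDirect is a hypothetical drop-in for dur: scale milliseconds to
// nanoseconds and truncate to a time.Duration.
func durDirect(ms float64) time.Duration {
	return time.Duration(ms * float64(time.Millisecond))
}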
@@ -230,8 +252,8 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
 	token := C.llama_sample(
 		llm.ctx,
-		unsafe.SliceData(candidates), C.ulong(len(candidates)),
-		unsafe.SliceData(pastTokens.Data()), C.ulong(pastTokens.Len()),
+		unsafe.SliceData(candidates.Data()), C.ulong(candidates.Len()),
+		unsafe.SliceData(output.Data()), C.ulong(output.Len()),
 		opts)
 
 	if token != C.llama_token_eos() {
 		return token, nil
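Both generate and sample now lean on the package's generic deque[T], which previously tracked only the repeat-penalty window and here also backs the candidates list and the kept output. The deque implementation is not part of this diff; a minimal sketch of the surface the changed code assumes, with the eviction behavior an explicit assumption:

// Sketch only: the real deque lives elsewhere in the repository and its
// semantics may differ.
type deque[T any] struct {
	data     []T
	capacity int
}

// PushLeft prepends v; when over capacity, the rightmost (oldest) entries
// are assumed to be dropped.
func (d *deque[T]) PushLeft(v T) {
	d.data = append([]T{v}, d.data...)
	if d.capacity > 0 && len(d.data) > d.capacity {
		d.data = d.data[:d.capacity]
	}
}

func (d *deque[T]) Data() []T { return d.data }
func (d *deque[T]) Len() int  { return len(d.data) }
func (d *deque[T]) Cap() int  { return d.capacity }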