fix: relay request opts to loaded llm prediction (#1761)

Bruce MacDonald
2024-01-03 12:01:42 -05:00
committed by GitHub
parent 05face44ef
commit 0b3118e0af
5 changed files with 106 additions and 71 deletions
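
Previously, predict() read every sampling parameter from the api.Options snapshot taken when the model was loaded, so options sent with an individual request (temperature, seed, stop sequences, and so on) were silently ignored. This commit relays the request's options by carrying them inside PredictOpts and having predict() read them from there. A minimal sketch of the caller side, assuming a loaded extServer in a variable named llm (variable names here are illustrative, not from the diff):

	// Illustrative only: per-request options now travel inside PredictOpts
	// instead of being read from the state captured at load time.
	opts := api.DefaultOptions()
	opts.Temperature = 0.1 // honored for this request only
	opts.Seed = 42

	err := llm.Predict(ctx, PredictOpts{
		Prompt:  "Why is the sky blue?",
		Options: opts, // relayed through predict() into the completion request
	}, func(r PredictResult) {
		fmt.Print(r.Content) // assumes PredictResult carries a Content field
	})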

@@ -153,7 +153,7 @@ func newExtServer(server extServer, model string, adapters, projectors []string,
 	return server, nil
 }

-func predict(llm extServer, opts api.Options, ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
+func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
 	var imageData []ImageData
@@ -167,23 +167,23 @@ func predict(llm extServer, opts api.Options, ctx context.Context, predict Predi
 	request := map[string]any{
 		"prompt":            predict.Prompt,
 		"stream":            true,
-		"n_predict":         opts.NumPredict,
-		"n_keep":            opts.NumKeep,
-		"temperature":       opts.Temperature,
-		"top_k":             opts.TopK,
-		"top_p":             opts.TopP,
-		"tfs_z":             opts.TFSZ,
-		"typical_p":         opts.TypicalP,
-		"repeat_last_n":     opts.RepeatLastN,
-		"repeat_penalty":    opts.RepeatPenalty,
-		"presence_penalty":  opts.PresencePenalty,
-		"frequency_penalty": opts.FrequencyPenalty,
-		"mirostat":          opts.Mirostat,
-		"mirostat_tau":      opts.MirostatTau,
-		"mirostat_eta":      opts.MirostatEta,
-		"penalize_nl":       opts.PenalizeNewline,
-		"seed":              opts.Seed,
-		"stop":              opts.Stop,
+		"n_predict":         predict.Options.NumPredict,
+		"n_keep":            predict.Options.NumKeep,
+		"temperature":       predict.Options.Temperature,
+		"top_k":             predict.Options.TopK,
+		"top_p":             predict.Options.TopP,
+		"tfs_z":             predict.Options.TFSZ,
+		"typical_p":         predict.Options.TypicalP,
+		"repeat_last_n":     predict.Options.RepeatLastN,
+		"repeat_penalty":    predict.Options.RepeatPenalty,
+		"presence_penalty":  predict.Options.PresencePenalty,
+		"frequency_penalty": predict.Options.FrequencyPenalty,
+		"mirostat":          predict.Options.Mirostat,
+		"mirostat_tau":      predict.Options.MirostatTau,
+		"mirostat_eta":      predict.Options.MirostatEta,
+		"penalize_nl":       predict.Options.PenalizeNewline,
+		"seed":              predict.Options.Seed,
+		"stop":              predict.Options.Stop,
 		"image_data":        imageData,
 		"cache_prompt":      true,
 	}
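
With the new signature, every value serialized into the completion request comes straight from the caller's PredictOpts; nothing is read from server-wide state anymore. Illustrative effect (values not from the diff):

	// opts.Temperature was set to 0.1 by the caller, so inside predict()
	// request["temperature"] is 0.1 for this call, regardless of the
	// options the server was originally loaded with.
	pred := PredictOpts{Prompt: "hi", Options: opts}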

@@ -60,7 +60,7 @@ func newDefaultExtServer(model string, adapters, projectors []string, numLayers
 }

 func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.Options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }

 func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {

@@ -166,9 +166,10 @@ const maxRetries = 3
 const retryDelay = 1 * time.Second

 type PredictOpts struct {
-	Prompt string
-	Format string
-	Images []api.ImageData
+	Prompt  string
+	Format  string
+	Images  []api.ImageData
+	Options api.Options
 }

 type PredictResult struct {
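
Options is the only new field. Note it is a value, not a pointer, so a caller that leaves it unset sends the zero value of api.Options rather than the model's defaults; presumably callers seed it from api.DefaultOptions() or the model's configured options before calling Predict. A hedged illustration:

	// Illustrative: an unset field means "zero value", not "model default".
	var p PredictOpts
	p.Prompt = "hello"
	// p.Options.Temperature is 0 here unless the caller fills it in,
	// e.g. starting from api.DefaultOptions().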

@@ -92,7 +92,7 @@ func newDynamicShimExtServer(library, model string, adapters, projectors []strin
 }

 func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }

 func (llm *shimExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
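
Both backends now delegate identically, so the stored llm.Options / llm.options no longer shadow what the request asked for. The code that actually populates PredictOpts.Options is outside this excerpt; presumably the server's generate path does something along these lines (hypothetical sketch; names such as loaded.runner and req are not from the diff):

	// Hypothetical: merge default options with the incoming request's
	// overrides, then relay them with the prediction.
	opts := api.DefaultOptions()
	if err := opts.FromMap(req.Options); err != nil { // req.Options: map[string]interface{}
		return err
	}
	err := loaded.runner.Predict(ctx, PredictOpts{
		Prompt:  prompt,
		Format:  req.Format,
		Images:  req.Images,
		Options: opts,
	}, fn)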