chat api endpoint (#1392)

This commit is contained in:
Bruce MacDonald
2023-12-05 14:57:33 -05:00
committed by GitHub
parent 00d06619a1
commit 195e3d9dbd
9 changed files with 550 additions and 132 deletions

View File

@@ -531,21 +531,30 @@ type prediction struct {
// maxBufferSize is a size limit of 512 units of format.KiloByte.
// NOTE(review): the code that consumes this constant is not visible in this
// excerpt — confirm what buffer it bounds before changing the value.
const maxBufferSize = 512 * format.KiloByte
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil {
return err
}
// PredictOpts bundles the per-request inputs that Predict consumes.
type PredictOpts struct {
	// Model is echoed back in every PredictResult handed to the callback.
	Model string
	// Prompt is sent verbatim as the "prompt" field of the prediction request.
	Prompt string
	// Format selects the output format; the value "json" makes Predict attach
	// the JSON grammar to the request.
	Format string
	// CheckpointStart is the instant from which Predict measures
	// TotalDuration on the final result.
	// CheckpointLoaded is not read by the Predict code visible here.
	// NOTE(review): presumably the time the model finished loading — confirm
	// against the caller.
	CheckpointStart, CheckpointLoaded time.Time
}
// Remove leading spaces from prevConvo if present
prevConvo = strings.TrimPrefix(prevConvo, " ")
var nextContext strings.Builder
nextContext.WriteString(prevConvo)
nextContext.WriteString(prompt)
// PredictResult is the value Predict passes to its callback: once for each
// non-empty chunk of generated content, and once more as a final summary when
// the prediction stops.
type PredictResult struct {
	// Model is copied from PredictOpts.Model.
	Model string
	// CreatedAt is the time the result was produced; Predict fills it with
	// time.Now().UTC().
	CreatedAt time.Time

	// TotalDuration is the time elapsed since PredictOpts.CheckpointStart,
	// filled in on the final result.
	// LoadDuration is not populated by the Predict code visible here.
	// NOTE(review): presumably set by the caller — confirm.
	TotalDuration, LoadDuration time.Duration

	// Content is the generated text carried by a streaming result.
	Content string
	// Done is set to true on the final result emitted when the prediction stops.
	Done bool

	// PromptEvalCount and PromptEvalDuration come from the prompt timings
	// reported when the prediction stops.
	PromptEvalCount    int
	PromptEvalDuration time.Duration

	// EvalCount and EvalDuration come from the predicted-token timings
	// reported when the prediction stops.
	EvalCount    int
	EvalDuration time.Duration
}
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
request := map[string]any{
"prompt": nextContext.String(),
"prompt": predict.Prompt,
"stream": true,
"n_predict": llm.NumPredict,
"n_keep": llm.NumKeep,
@@ -567,7 +576,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
"stop": llm.Stop,
}
if format == "json" {
if predict.Format == "json" {
request["grammar"] = jsonGrammar
}
@@ -624,25 +633,25 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
}
if p.Content != "" {
fn(api.GenerateResponse{Response: p.Content})
nextContext.WriteString(p.Content)
fn(PredictResult{
Model: predict.Model,
CreatedAt: time.Now().UTC(),
Content: p.Content,
})
}
if p.Stop {
embd, err := llm.Encode(ctx, nextContext.String())
if err != nil {
return fmt.Errorf("encoding context: %v", err)
}
fn(PredictResult{
Model: predict.Model,
CreatedAt: time.Now().UTC(),
TotalDuration: time.Since(predict.CheckpointStart),
fn(api.GenerateResponse{
Done: true,
Context: embd,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
return nil
}
}