- update chat docs
- add messages chat endpoint
- remove deprecated context and template generate parameters from docs
- context and template are still supported for the time being and will continue to work as expected
- add partial response to chat history
This commit is contained in:
Bruce MacDonald
2023-12-04 18:01:06 -05:00
committed by GitHub
parent 0cca1486dd
commit 7a0899d62d
9 changed files with 667 additions and 256 deletions

View File

@@ -531,21 +531,31 @@ type prediction struct {
const maxBufferSize = 512 * format.KiloByte
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil {
return err
}
type PredictRequest struct {
Model string
Prompt string
Format string
CheckpointStart time.Time
CheckpointLoaded time.Time
}
// Remove leading spaces from prevConvo if present
prevConvo = strings.TrimPrefix(prevConvo, " ")
var nextContext strings.Builder
nextContext.WriteString(prevConvo)
nextContext.WriteString(prompt)
type PredictResponse struct {
Model string
CreatedAt time.Time
TotalDuration time.Duration
LoadDuration time.Duration
Content string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
Context []int
}
func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(PredictResponse)) error {
request := map[string]any{
"prompt": nextContext.String(),
"prompt": predict.Prompt,
"stream": true,
"n_predict": llm.NumPredict,
"n_keep": llm.NumKeep,
@@ -567,7 +577,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
"stop": llm.Stop,
}
if format == "json" {
if predict.Format == "json" {
request["grammar"] = jsonGrammar
}
@@ -624,25 +634,25 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
}
if p.Content != "" {
fn(api.GenerateResponse{Response: p.Content})
nextContext.WriteString(p.Content)
fn(PredictResponse{
Model: predict.Model,
CreatedAt: time.Now().UTC(),
Content: p.Content,
})
}
if p.Stop {
embd, err := llm.Encode(ctx, nextContext.String())
if err != nil {
return fmt.Errorf("encoding context: %v", err)
}
fn(PredictResponse{
Model: predict.Model,
CreatedAt: time.Now().UTC(),
TotalDuration: time.Since(predict.CheckpointStart),
fn(api.GenerateResponse{
Done: true,
Context: embd,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
return nil
}
}

View File

@@ -14,7 +14,7 @@ import (
)
type LLM interface {
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
Predict(context.Context, PredictRequest, func(PredictResponse)) error
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)