Revert "chat api (#991)" while context variable is fixed

This reverts commit 7a0899d62d.
This commit is contained in:
Jeffrey Morgan
2023-12-04 21:16:27 -08:00
parent f1ef3f9947
commit 00d06619a1
8 changed files with 144 additions and 559 deletions

View File

@@ -531,31 +531,21 @@ type prediction struct {
const maxBufferSize = 512 * format.KiloByte
// PredictRequest groups the inputs for a single call to (*llama).Predict.
type PredictRequest struct {
// Model is copied into every PredictResponse handed to the callback.
Model string
// Prompt is sent as the "prompt" field of the prediction request.
Prompt string
// Format, when equal to "json", makes Predict attach jsonGrammar to the
// request as its "grammar" field.
Format string
// CheckpointStart is the instant from which the reported TotalDuration is
// measured (via time.Since) when the prediction stops.
CheckpointStart time.Time
// NOTE(review): presumably the instant the model finished loading; it is
// not read in the code visible here — confirm against callers.
CheckpointLoaded time.Time
}
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil {
return err
}
// PredictResponse is the value (*llama).Predict passes to its callback, both
// for pieces of generated content and when the prediction stops.
type PredictResponse struct {
// Model is taken from PredictRequest.Model.
Model string
// CreatedAt is set to time.Now().UTC() when the response is built.
CreatedAt time.Time
// TotalDuration is set on the stop response to the time elapsed since
// PredictRequest.CheckpointStart.
TotalDuration time.Duration
// NOTE(review): not assigned in the code visible here; presumably the model
// load time filled in by the caller — confirm.
LoadDuration time.Duration
// Content carries the generated text of one update; Predict only emits a
// content response when the text is non-empty.
Content string
// NOTE(review): presumably true only on the final (stop) response — the
// literal that would set it is cut off in this view; confirm.
Done bool
// NOTE(review): the four evaluation fields below presumably mirror the
// prediction timings (prompt/predicted token counts and durations) on the
// stop response — confirm against the full Predict body.
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
// NOTE(review): presumably the encoded token context for follow-up calls;
// no assignment to this field is visible here — confirm.
Context []int
}
// Remove leading spaces from prevConvo if present
prevConvo = strings.TrimPrefix(prevConvo, " ")
var nextContext strings.Builder
nextContext.WriteString(prevConvo)
nextContext.WriteString(prompt)
func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(PredictResponse)) error {
request := map[string]any{
"prompt": predict.Prompt,
"prompt": nextContext.String(),
"stream": true,
"n_predict": llm.NumPredict,
"n_keep": llm.NumKeep,
@@ -577,7 +567,7 @@ func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(P
"stop": llm.Stop,
}
if predict.Format == "json" {
if format == "json" {
request["grammar"] = jsonGrammar
}
@@ -634,25 +624,25 @@ func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(P
}
if p.Content != "" {
fn(PredictResponse{
Model: predict.Model,
CreatedAt: time.Now().UTC(),
Content: p.Content,
})
fn(api.GenerateResponse{Response: p.Content})
nextContext.WriteString(p.Content)
}
if p.Stop {
fn(PredictResponse{
Model: predict.Model,
CreatedAt: time.Now().UTC(),
TotalDuration: time.Since(predict.CheckpointStart),
embd, err := llm.Encode(ctx, nextContext.String())
if err != nil {
return fmt.Errorf("encoding context: %v", err)
}
fn(api.GenerateResponse{
Done: true,
Context: embd,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
return nil
}
}

View File

@@ -14,7 +14,7 @@ import (
)
type LLM interface {
Predict(context.Context, PredictRequest, func(PredictResponse)) error
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)