Convert the REPL to use /api/chat for interactive responses (#1936)

This commit is contained in:
Patrick Devine
2024-01-12 12:05:52 -08:00
committed by GitHub
parent 40a0a90a88
commit 565f8a3c44
2 changed files with 155 additions and 72 deletions

View File

@@ -1,7 +1,6 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
@@ -43,16 +42,16 @@ func modelIsMultiModal(cmd *cobra.Command, name string) bool {
return slices.Contains(resp.Details.Families, "clip")
}
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
multiModal := modelIsMultiModal(cmd, opts.Model)
// load the model
loadOpts := generateOptions{
Model: opts.Model,
Prompt: "",
Images: []ImageData{},
loadOpts := runOptions{
Model: opts.Model,
Prompt: "",
Messages: []api.Message{},
}
if err := generate(cmd, loadOpts); err != nil {
if _, err := chat(cmd, loadOpts); err != nil {
return err
}
@@ -141,6 +140,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
var sb strings.Builder
var multiline MultilineState
opts.Messages = make([]api.Message, 0)
for {
line, err := scanner.Readline()
@@ -409,22 +409,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
}
if sb.Len() > 0 && multiline == MultilineNone {
opts.Prompt = sb.String()
newMessage := api.Message{Role: "user", Content: sb.String()}
if multiModal {
newPrompt, images, err := extractFileData(sb.String())
msg, images, err := extractFileData(sb.String())
if err != nil {
return err
}
opts.Prompt = newPrompt
newMessage.Content = msg
// reset the context if we find another image
if len(images) > 0 {
opts.Images = images
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
cmd.SetContext(ctx)
newMessage.Images = append(newMessage.Images, images...)
// reset the context for the new image
opts.Messages = []api.Message{}
} else {
if len(opts.Messages) > 1 {
newMessage.Images = append(newMessage.Images, opts.Messages[len(opts.Messages)-2].Images...)
}
}
if len(opts.Images) == 0 {
if len(newMessage.Images) == 0 {
fmt.Println("This model requires you to add a jpeg, png, or svg image.")
fmt.Println()
sb.Reset()
@@ -432,9 +436,18 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
}
}
if err := generate(cmd, opts); err != nil {
if opts.System != "" {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
}
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
sb.Reset()
}
@@ -476,9 +489,9 @@ func extractFileNames(input string) []string {
return re.FindAllString(input, -1)
}
func extractFileData(input string) (string, []ImageData, error) {
func extractFileData(input string) (string, []api.ImageData, error) {
filePaths := extractFileNames(input)
var imgs []ImageData
var imgs []api.ImageData
for _, fp := range filePaths {
nfp := normalizeFilePath(fp)