commit 8450bf66e6 (parent b4e11be8ef)
Author: Michael Yang
Date:   2024-01-31 17:39:38 -08:00

    trim images

3 changed files with 41 additions and 20 deletions


@@ -312,11 +312,16 @@ func GenerateHandler(c *gin.Context) {
 			ch <- resp
 		}
 
+		images := make(map[int]api.ImageData)
+		for i := range req.Images {
+			images[i] = req.Images[i]
+		}
+
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  req.Images,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
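In GenerateHandler the request's image slice is re-keyed into a map indexed by position before it reaches llm.PredictOpts. A minimal sketch of why a map is the better carrier here, assuming api.ImageData is a raw byte payload (the local ImageData alias below is a stand-in):

```go
package main

import "fmt"

// ImageData is a stand-in for api.ImageData, assumed here to be a raw
// byte payload as the diff's usage suggests.
type ImageData []byte

func main() {
	reqImages := []ImageData{[]byte("img-0"), []byte("img-1"), []byte("img-2")}

	// Mirror the GenerateHandler change: key each image by its position
	// in the request.
	images := make(map[int]ImageData)
	for i := range reqImages {
		images[i] = reqImages[i]
	}

	// Unlike a slice, the map keeps indices stable when entries are
	// dropped, so a prompt that references image 2 still resolves it.
	delete(images, 0)
	fmt.Println(string(images[2])) // img-2
}
```

For GenerateHandler the copy is one-to-one; the stable keys matter in ChatHandler below, where context trimming can drop entries.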
@@ -1143,7 +1148,8 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+	prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1186,7 +1192,7 @@ func ChatHandler(c *gin.Context) {
 	predictReq := llm.PredictOpts{
 		Prompt:  prompt,
 		Format:  req.Format,
-		Images:  chat.CurrentImages,
+		Images:  images,
 		Options: opts,
 	}
 	if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
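ChatHandler now forwards whatever image set trimmedPrompt hands back instead of chat.CurrentImages, so images attached to messages that fell out of the context window are no longer sent. The diff doesn't show how chat.CurrentImages was populated; the sketch below assumes it aggregated images independently of trimming, using hypothetical pared-down types:

```go
package main

import "fmt"

// A hypothetical, pared-down chat history: each prompt may carry images.
type prompt struct {
	text   string
	images []string
}

func main() {
	history := []prompt{
		{text: "describe this", images: []string{"cat.png"}},
		{text: "and this one", images: []string{"dog.png"}},
	}

	// Suppose context-length trimming drops the oldest prompt. A
	// handler-wide image list assembled before trimming would still
	// include cat.png; collecting images only from the prompts that
	// survived sends just dog.png.
	kept := history[1:]
	var images []string
	for _, p := range kept {
		images = append(images, p.images...)
	}
	fmt.Println(images) // [dog.png]
}
```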
@@ -1233,25 +1239,27 @@ type promptInfo struct {
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, map[int]api.ImageData, error) {
 	if len(chat.Prompts) == 0 {
-		return "", nil
+		return "", nil, nil
 	}
 
 	var promptsToAdd []promptInfo
 	var totalTokenLength int
 	var systemPromptIncluded bool
 
+	images := make(map[int]api.ImageData)
+
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
 		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
@@ -1261,6 +1269,10 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		totalTokenLength += len(encodedTokens)
 		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
 		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+
+		for _, image := range chat.Prompts[i].Images {
+			images[image.Rank] = image.ImageData
+		}
 	}
 
 	// ensure the system prompt is included, if not already
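Inside the reverse loop, images are collected only for prompts that made it under the context budget, keyed by each image's Rank, which appears to be its position across the whole conversation. A runnable sketch of that collection step, with hypothetical stand-ins for PromptVars and its image type:

```go
package main

import "fmt"

// Pared-down stand-ins for the types the diff relies on: each prompt's
// image carries a Rank, its position across the whole conversation.
type promptImage struct {
	Rank      int
	ImageData []byte
}

type promptVars struct {
	Text   string
	Images []promptImage
}

func main() {
	prompts := []promptVars{
		{Text: "oldest", Images: []promptImage{{Rank: 0, ImageData: []byte("a")}}},
		{Text: "middle", Images: []promptImage{{Rank: 1, ImageData: []byte("b")}}},
		{Text: "newest"},
	}

	// Mimic the loop above: reverse-iterate and keep only the prompts
	// that fit the budget (here, the two most recent), collecting their
	// images keyed by original rank.
	images := make(map[int][]byte)
	for i := len(prompts) - 1; i >= 1; i-- {
		for _, img := range prompts[i].Images {
			images[img.Rank] = img.ImageData
		}
	}
	fmt.Println(len(images), string(images[1])) // 1 b
}
```

Because the keys are original ranks rather than positions in the trimmed list, any index-based references in the final prompt still point at the right image.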
@@ -1268,7 +1280,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		var err error
 		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 	}
@@ -1279,11 +1291,11 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	for i, prompt := range promptsToAdd {
 		promptText, err := promptString(model, prompt.vars, i == 0)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		result = promptText + result
 	}
 
-	return result, nil
+	return result, images, nil
 }
 
 // promptString applies the model template to the prompt
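Putting the pieces together, here is a condensed, hedged model of the trimming strategy: whitespace tokenization stands in for loaded.runner.Encode, a plain int for loaded.NumCtx, and the system-prompt handling (includeSystemPrompt) is omitted:

```go
package main

import (
	"fmt"
	"strings"
)

type promptVars struct {
	Text   string
	Images map[int][]byte
}

// trimmed is a condensed model of trimmedPrompt: whitespace tokens stand
// in for runner.Encode, numCtx for loaded.NumCtx, and system-prompt
// preservation is omitted.
func trimmed(prompts []promptVars, numCtx int) (string, map[int][]byte) {
	var kept []promptVars
	images := make(map[int][]byte)
	total := 0

	// Reverse-iterate so the newest prompts claim the budget first; the
	// final prompt is always kept, matching the i != len-1 guard above.
	for i := len(prompts) - 1; i >= 0; i-- {
		n := len(strings.Fields(prompts[i].Text))
		if total+n > numCtx && i != len(prompts)-1 {
			break
		}
		total += n
		kept = append([]promptVars{prompts[i]}, kept...)
		for rank, data := range prompts[i].Images {
			images[rank] = data
		}
	}

	var b strings.Builder
	for _, p := range kept {
		b.WriteString(p.Text)
		b.WriteString("\n")
	}
	return b.String(), images
}

func main() {
	prompts := []promptVars{
		{Text: "old question with an image", Images: map[int][]byte{0: []byte("x")}},
		{Text: "newer follow up question", Images: map[int][]byte{1: []byte("y")}},
	}
	text, imgs := trimmed(prompts, 5)
	fmt.Printf("%q keeps %d image(s)\n", text, len(imgs))
}
```

With a five-token budget only the newer prompt fits, so its image is the only one returned, which is exactly the behavior the commit title, "trim images", describes.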