This commit is contained in:
Michael Yang
2023-07-06 14:05:55 -07:00
parent 3d6009aae3
commit c4b9e84945
6 changed files with 75 additions and 152 deletions

View File

@@ -3,15 +3,13 @@ package cmd
import (
"bufio"
"context"
"encoding/json"
"fmt"
"log"
"net"
"os"
"path"
"sync"
"github.com/gosuri/uiprogress"
"github.com/schollz/progressbar/v3"
"github.com/spf13/cobra"
"golang.org/x/term"
@@ -28,43 +26,33 @@ func cacheDir() string {
return path.Join(home, ".ollama")
}
func bytesToGB(bytes int) float64 {
return float64(bytes) / float64(1<<30)
func RunRun(cmd *cobra.Command, args []string) error {
if err := pull(args[0]); err != nil {
return err
}
fmt.Println("Up to date.")
return RunGenerate(cmd, args)
}
func RunRun(cmd *cobra.Command, args []string) error {
func pull(model string) error {
client, err := NewAPIClient()
if err != nil {
return err
}
pr := api.PullRequest{
Model: args[0],
}
var bar *uiprogress.Bar
mutex := &sync.Mutex{}
var progressData api.PullProgress
pullCallback := func(progress api.PullProgress) {
mutex.Lock()
progressData = progress
if bar == nil {
uiprogress.Start()
bar = uiprogress.AddBar(int(progress.Total))
bar.PrependFunc(func(b *uiprogress.Bar) string {
return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total))
})
bar.AppendFunc(func(b *uiprogress.Bar) string {
return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100))
})
}
bar.Set(int(progress.Completed))
mutex.Unlock()
}
if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
return err
}
fmt.Println("Up to date.")
return RunGenerate(cmd, args)
var bar *progressbar.ProgressBar
return client.Pull(
context.Background(),
&api.PullRequest{Model: model},
func(progress api.PullProgress) error {
if bar == nil {
bar = progressbar.DefaultBytes(progress.Total)
}
return bar.Set64(progress.Completed)
},
)
}
func RunGenerate(_ *cobra.Command, args []string) error {
@@ -86,13 +74,9 @@ func generate(model string, prompts ...string) error {
}
for _, prompt := range prompts {
client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(bts []byte) {
var resp api.GenerateResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return
}
client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error {
fmt.Print(resp.Response)
return nil
})
}