display pull progress

This commit is contained in:
Bruce MacDonald
2023-07-06 14:18:40 -04:00
committed by Jeffrey Morgan
parent 580fe8951c
commit 7cf5905063
7 changed files with 81 additions and 19 deletions

View File

@@ -9,15 +9,14 @@ import (
"os"
"path"
"strconv"
"github.com/jmorganca/ollama/api"
)
// const directoryURL = "https://ollama.ai/api/models"
// TODO
const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json"
type directoryCtxKey string
var dirCtx directoryCtxKey = "directory"
type Model struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
@@ -31,7 +30,7 @@ type Model struct {
License string `json:"license"`
}
func pull(model string, progressCh chan<- string) error {
func pull(model string, progressCh chan<- api.PullProgress) error {
remote, err := getRemote(model)
if err != nil {
return fmt.Errorf("failed to pull model: %w", err)
@@ -64,7 +63,7 @@ func getRemote(model string) (*Model, error) {
return nil, fmt.Errorf("model not found in directory: %s", model)
}
func saveModel(model *Model, progressCh chan<- string) error {
func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
// this models cache directory is created by the server on startup
home, err := os.UserHomeDir()
if err != nil {
@@ -130,11 +129,18 @@ func saveModel(model *Model, progressCh chan<- string) error {
totalBytes += n
// send progress updates
progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100)
progressCh <- api.PullProgress{
Total: totalSize,
Completed: totalBytes,
Percent: float64(totalBytes) / float64(totalSize) * 100,
}
}
// send completion message
progressCh <- "Download complete!"
progressCh <- api.PullProgress{
Total: totalSize,
Completed: totalSize,
Percent: 100,
}
return nil
}

View File

@@ -107,7 +107,7 @@ func Serve(ln net.Listener) error {
return
}
progressCh := make(chan string)
progressCh := make(chan api.PullProgress)
go func() {
defer close(progressCh)
if err := pull(req.Model, progressCh); err != nil {