common stream producer

This commit is contained in:
Michael Yang
2023-07-11 11:54:22 -07:00
parent 62620914e9
commit 2a66a1164a
2 changed files with 61 additions and 85 deletions

View File

@@ -8,8 +8,6 @@ import (
"os"
"path"
"strconv"
"github.com/jmorganca/ollama/api"
)
const directoryURL = "https://ollama.ai/api/models"
@@ -36,14 +34,6 @@ func (m *Model) FullName() string {
return path.Join(home, ".ollama", "models", m.Name+".bin")
}
func pull(model string, progressCh chan<- api.PullProgress) error {
remote, err := getRemote(model)
if err != nil {
return fmt.Errorf("failed to pull model: %w", err)
}
return saveModel(remote, progressCh)
}
func getRemote(model string) (*Model, error) {
// resolve the model download from our directory
resp, err := http.Get(directoryURL)
@@ -68,7 +58,7 @@ func getRemote(model string) (*Model, error) {
return nil, fmt.Errorf("model not found in directory: %s", model)
}
func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
func saveModel(model *Model, fn func(total, completed int64)) error {
// this models cache directory is created by the server on startup
client := &http.Client{}
@@ -98,11 +88,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
// already downloaded
progressCh <- api.PullProgress{
Total: alreadyDownloaded,
Completed: alreadyDownloaded,
Percent: 100,
}
fn(alreadyDownloaded, alreadyDownloaded)
return nil
}
@@ -136,19 +122,9 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
totalBytes += int64(n)
// send progress updates
progressCh <- api.PullProgress{
Total: totalSize,
Completed: totalBytes,
Percent: float64(totalBytes) / float64(totalSize) * 100,
}
}
progressCh <- api.PullProgress{
Total: totalSize,
Completed: totalSize,
Percent: 100,
fn(totalSize, totalBytes)
}
fn(totalSize, totalSize)
return nil
}