Add unit test of API routes (#1528)

This commit is contained in:
Patrick Devine
2023-12-14 16:47:40 -08:00
committed by GitHub
parent 6e16098a60
commit 630518f0d9
4 changed files with 122 additions and 30 deletions

View File

@@ -32,6 +32,10 @@ import (
var mode string = gin.DebugMode
type Server struct {
WorkDir string
}
func init() {
switch mode {
case gin.DebugMode:
@@ -800,27 +804,27 @@ var defaultAllowOrigins = []string{
"0.0.0.0",
}
func Serve(ln net.Listener, allowOrigins []string) error {
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
func NewServer() (*Server, error) {
workDir, err := os.MkdirTemp("", "ollama")
if err != nil {
return nil, err
}
manifestsPath, err := GetManifestPath()
if err != nil {
return err
}
return &Server{
WorkDir: workDir,
}, nil
}
if err := PruneDirectory(manifestsPath); err != nil {
return err
}
func (s *Server) GenerateRoutes() http.Handler {
var origins []string
if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
origins = strings.Split(o, ",")
}
config := cors.DefaultConfig()
config.AllowWildcard = true
config.AllowOrigins = allowOrigins
config.AllowOrigins = origins
for _, allowOrigin := range defaultAllowOrigins {
config.AllowOrigins = append(config.AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin),
@@ -830,17 +834,11 @@ func Serve(ln net.Listener, allowOrigins []string) error {
)
}
workDir, err := os.MkdirTemp("", "ollama")
if err != nil {
return err
}
defer os.RemoveAll(workDir)
r := gin.Default()
r.Use(
cors.New(config),
func(c *gin.Context) {
c.Set("workDir", workDir)
c.Set("workDir", s.WorkDir)
c.Next()
},
)
@@ -868,8 +866,34 @@ func Serve(ln net.Listener, allowOrigins []string) error {
})
}
return r
}
func Serve(ln net.Listener) error {
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
manifestsPath, err := GetManifestPath()
if err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
return err
}
}
s, err := NewServer()
if err != nil {
return err
}
r := s.GenerateRoutes()
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
s := &http.Server{
srvr := &http.Server{
Handler: r,
}
@@ -881,7 +905,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
if loaded.runner != nil {
loaded.runner.Close()
}
os.RemoveAll(workDir)
os.RemoveAll(s.WorkDir)
os.Exit(0)
}()
@@ -892,7 +916,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
}
}
return s.Serve(ln)
return srvr.Serve(ln)
}
func waitForStream(c *gin.Context, ch chan interface{}) {