llm: remove internal subprocess req and resp types (#9324)

This commit refactors the LLM subsystem by removing internal subprocess
request and response types. It consolidates duplicate type definitions
across the codebase, moving them to centralized locations. The change also
standardizes interfaces between components, simplifies the ServerStatusResp
struct, and moves the ParseDurationMs function to a common package. This
cleanup reduces code duplication between different runner implementations
(llamarunner and ollamarunner).
This commit is contained in:
Bruce MacDonald
2025-03-14 15:21:53 -07:00
committed by GitHub
parent 4e320b8b90
commit 3892c3a703
4 changed files with 125 additions and 354 deletions

View File

@@ -24,6 +24,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/runner/common"
)
@@ -99,7 +100,7 @@ type NewSequenceParams struct {
embedding bool
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
var inputs []input
var parts []string
var matches [][]string
@@ -229,7 +230,7 @@ type Server struct {
image *ImageContext
// status for external health reporting - loading, ready to serve, etc.
status ServerStatus
status llm.ServerStatus
// current progress on loading the model
progress float32
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return nil
}
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner
type Options struct {
api.Runner
NumKeep int `json:"n_keep"`
Seed int `json:"seed"`
NumPredict int `json:"n_predict"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
Temperature float32 `json:"temperature"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
Stop []string `json:"stop"`
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
AspectRatioID int `json:"aspect_ratio_id"`
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Options
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
var req llm.CompletionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
var samplingParams llama.SamplingParams
samplingParams.TopK = req.TopK
samplingParams.TopP = req.TopP
samplingParams.MinP = req.MinP
samplingParams.TypicalP = req.TypicalP
samplingParams.Temp = req.Temperature
samplingParams.RepeatLastN = req.RepeatLastN
samplingParams.PenaltyRepeat = req.RepeatPenalty
samplingParams.PenaltyFreq = req.FrequencyPenalty
samplingParams.PenaltyPresent = req.PresencePenalty
samplingParams.Mirostat = req.Mirostat
samplingParams.MirostatTau = req.MirostatTau
samplingParams.MirostatEta = req.MirostatEta
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
// Extract options from the CompletionRequest
samplingParams := llama.SamplingParams{
TopK: req.Options.TopK,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TypicalP: req.Options.TypicalP,
Temp: req.Options.Temperature,
RepeatLastN: req.Options.RepeatLastN,
PenaltyRepeat: req.Options.RepeatPenalty,
PenaltyFreq: req.Options.FrequencyPenalty,
PenaltyPresent: req.Options.PresencePenalty,
Mirostat: req.Options.Mirostat,
MirostatTau: req.Options.MirostatTau,
MirostatEta: req.Options.MirostatEta,
Seed: uint32(req.Options.Seed),
Grammar: req.Grammar,
}
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: req.NumKeep,
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: req.Options.NumKeep,
samplingParams: &samplingParams,
embedding: false,
})
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
// Send the final response
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Stop: true,
StoppedLimit: seq.doneReason == "limit",
Timings: Timings{
PromptN: seq.numPromptInputs,
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
PredictedN: seq.numDecoded,
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
},
doneReason := "stop"
if seq.doneReason == "limit" {
doneReason = "length"
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numDecoded,
EvalDuration: time.Since(seq.startGenerationTime),
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
type EmbeddingRequest struct {
Content string `json:"content"`
CachePrompt bool `json:"cache_prompt"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req EmbeddingRequest
var req llm.EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
type ServerStatus int
const (
ServerStatusReady ServerStatus = iota
ServerStatusLoadingModel
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "ok"
case ServerStatusLoadingModel:
return "loading model"
default:
return "server error"
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status.ToString(),
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
Status: s.status,
Progress: s.progress,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
panic(err)
}
s.status = ServerStatusReady
s.status = llm.ServerStatusReady
s.ready.Done()
}
@@ -937,7 +852,7 @@ func Execute(args []string) error {
parallel: *parallel,
seqs: make([]*Sequence, *parallel),
seqsSem: semaphore.NewWeighted(int64(*parallel)),
status: ServerStatusLoadingModel,
status: llm.ServerStatusLoadingModel,
}
var tensorSplitFloats []float32