Sync with upstream ollama/ollama and restore Tesla K80 (compute 3.7) support

This commit is a complete rework after pulling the latest changes from the
official ollama/ollama repository and re-applying the Tesla K80 compatibility patches.

## Key Changes

### CUDA Compute Capability 3.7 Support (Tesla K80)
- Added sm_37 (compute 3.7) to CMAKE_CUDA_ARCHITECTURES in CMakeLists.txt
- Updated CMakePresets.json to include compute 3.7 in "CUDA 11" preset
- Using 37-virtual (PTX with JIT compilation) for maximum driver-side compatibility; see the sketch below
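
A minimal sketch of the flag wiring described above (illustrative only; the real CMakeLists.txt and the "CUDA 11" preset in CMakePresets.json carry the full upstream architecture lists):

```cmake
# Sketch: include compute 3.7 in the CUDA architecture list for the CUDA 11 build.
# "37-virtual" emits PTX only, which the driver JIT-compiles on the K80; the
# other entries stand in for whatever upstream already targets.
set(CMAKE_CUDA_ARCHITECTURES "37-virtual;50;61;70;75;80")
```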

### Legacy Toolchain Compatibility
- **NVIDIA Driver**: 470.256.02 (last version supporting Kepler/K80)
- **CUDA Version**: 11.4.4 (last CUDA 11.x supporting compute 3.7)
- **GCC Version**: 10.5.0 (required by CUDA 11.4 host_config.h)

### CPU Architecture Trade-offs
Due to the GCC 10.5 ceiling, some newer CPU optimizations are sacrificed (see the sketch after this list):
- Alderlake CPU variant enabled WITHOUT AVX_VNNI (AVX_VNNI requires GCC 11+)
- Still supports: SSE4.2, AVX, F16C, AVX2, BMI2, FMA
- Performance impact: ~3-7% on newer CPUs (acceptable for K80 compatibility)
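
A sketch of the variant definition, assuming ggml's `ggml_add_cpu_backend_variant` helper keeps its current shape (feature list illustrative, not verbatim from the tree):

```cmake
# ml/backend/ggml/ggml/src/CMakeLists.txt (sketch): define the Alderlake CPU
# variant without AVX_VNNI, since GCC 10.5 has no -mavxvnni support.
ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA)
```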

### Build System Updates
- Modified ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt for compute 3.7
- Added the -Wno-deprecated-gpu-targets flag to suppress nvcc's compute 3.7 deprecation warnings (sketch below)
- Updated ml/backend/ggml/ggml/src/CMakeLists.txt for Alderlake without AVX_VNNI
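
The warning suppression noted above reduces to one extra nvcc flag; a hedged sketch of how it can be wired in CMake:

```cmake
# ggml-cuda/CMakeLists.txt (sketch): CUDA 11.4 deprecates compute 3.7, so nvcc
# warns on every translation unit; this silences the warning without changing
# code generation.
add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:-Wno-deprecated-gpu-targets>")
```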

### Upstream Sync
Merged the latest llama.cpp changes, including:
- Enhanced KV cache management with ISWA and hybrid memory support
- Improved multi-modal support (mtmd framework)
- New model architectures (Gemma3, Llama4, Qwen3, etc.)
- GPU backend improvements for CUDA, Metal, and ROCm
- Updated quantization support and GGUF format handling

### Documentation
- Updated CLAUDE.md with comprehensive build instructions
- Documented toolchain constraints and CPU architecture trade-offs
- Removed outdated CI/CD workflows (tesla-k80-*.yml)
- Cleaned up temporary development artifacts

## Rationale

This fork maintains Tesla K80 GPU support (compute 3.7), which official Ollama
dropped because of its legacy driver/CUDA requirements. The toolchain constraint
creates a chain of forced downgrades:
- K80 → Driver 470 → CUDA 11.4 → GCC 10 → No AVX_VNNI

We accept the loss of cutting-edge CPU optimizations to enable running modern
LLMs on legacy but still capable Tesla K80 hardware (12GB VRAM per GPU).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Author: Shang Chieh Tseng
Date:   2025-11-05 14:03:05 +08:00
Parent: fabe2c5cb7
Commit: ef14fb5b26

817 changed files with 241634 additions and 70888 deletions


@@ -1,21 +1,20 @@
-// openai package provides middleware for partial compatibility with the OpenAI REST API
+// openai package provides core transformation logic for partial compatibility with the OpenAI REST API
 package openai

 import (
     "bytes"
     "encoding/base64"
     "encoding/binary"
     "encoding/json"
     "errors"
     "fmt"
     "io"
     "log/slog"
     "math/rand"
     "net/http"
     "slices"
     "strings"
     "time"

     "github.com/gin-gonic/gin"
     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/types/model"
 )
@@ -76,8 +75,10 @@ type JsonSchema struct {
 }

 type EmbedRequest struct {
-    Input any    `json:"input"`
-    Model string `json:"model"`
+    Input          any    `json:"input"`
+    Model          string `json:"model"`
+    Dimensions     int    `json:"dimensions,omitempty"`
+    EncodingFormat string `json:"encoding_format,omitempty"` // "float" or "base64"
 }

 type StreamOptions struct {
@@ -85,7 +86,7 @@ type StreamOptions struct {
 }

 type Reasoning struct {
-    Effort *string `json:"effort,omitempty"`
+    Effort string `json:"effort,omitempty"`
 }

 type ChatCompletionRequest struct {
@@ -103,16 +104,19 @@ type ChatCompletionRequest struct {
     ResponseFormat  *ResponseFormat `json:"response_format"`
     Tools           []api.Tool      `json:"tools"`
     Reasoning       *Reasoning      `json:"reasoning,omitempty"`
+    ReasoningEffort *string         `json:"reasoning_effort,omitempty"`
+    DebugRenderOnly bool            `json:"_debug_render_only"`
 }

 type ChatCompletion struct {
-    Id                string   `json:"id"`
-    Object            string   `json:"object"`
-    Created           int64    `json:"created"`
-    Model             string   `json:"model"`
-    SystemFingerprint string   `json:"system_fingerprint"`
-    Choices           []Choice `json:"choices"`
-    Usage             Usage    `json:"usage,omitempty"`
+    Id                string         `json:"id"`
+    Object            string         `json:"object"`
+    Created           int64          `json:"created"`
+    Model             string         `json:"model"`
+    SystemFingerprint string         `json:"system_fingerprint"`
+    Choices           []Choice       `json:"choices"`
+    Usage             Usage          `json:"usage,omitempty"`
+    DebugInfo         *api.DebugInfo `json:"_debug_info,omitempty"`
 }

 type ChatCompletionChunk struct {
@@ -139,6 +143,7 @@ type CompletionRequest struct {
     Temperature *float32 `json:"temperature"`
     TopP        float32  `json:"top_p"`
     Suffix      string   `json:"suffix"`
+    DebugRenderOnly bool `json:"_debug_render_only"`
 }

 type Completion struct {
@@ -179,9 +184,9 @@ type Model struct {
 }

 type Embedding struct {
-    Object    string    `json:"object"`
-    Embedding []float32 `json:"embedding"`
-    Index     int       `json:"index"`
+    Object    string `json:"object"`
+    Embedding any    `json:"embedding"` // Can be []float32 (float format) or string (base64 format)
+    Index     int    `json:"index"`
 }

 type ListCompletion struct {
@@ -215,11 +220,12 @@ func NewError(code int, message string) ErrorResponse {
     return ErrorResponse{Error{Type: etype, Message: message}}
 }

-func toUsage(r api.ChatResponse) Usage {
+// ToUsage converts an api.ChatResponse to Usage
+func ToUsage(r api.ChatResponse) Usage {
     return Usage{
-        PromptTokens:     r.PromptEvalCount,
-        CompletionTokens: r.EvalCount,
-        TotalTokens:      r.PromptEvalCount + r.EvalCount,
+        PromptTokens:     r.Metrics.PromptEvalCount,
+        CompletionTokens: r.Metrics.EvalCount,
+        TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
     }
 }
@@ -232,7 +238,8 @@ func toolCallId() string {
     return "call_" + strings.ToLower(string(b))
 }

-func toToolCalls(tc []api.ToolCall) []ToolCall {
+// ToToolCalls converts api.ToolCall to OpenAI ToolCall format
+func ToToolCalls(tc []api.ToolCall) []ToolCall {
     toolCalls := make([]ToolCall, len(tc))
     for i, tc := range tc {
         toolCalls[i].ID = toolCallId()
@@ -251,8 +258,9 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
     return toolCalls
 }

-func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
-    toolCalls := toToolCalls(r.Message.ToolCalls)
+// ToChatCompletion converts an api.ChatResponse to ChatCompletion
+func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+    toolCalls := ToToolCalls(r.Message.ToolCalls)
     return ChatCompletion{
         Id:      id,
         Object:  "chat.completion",
@@ -271,13 +279,14 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
             }
             return nil
         }(r.DoneReason),
-        }},
-        Usage: toUsage(r),
+        }}, Usage: ToUsage(r),
+        DebugInfo: r.DebugInfo,
     }
 }

-func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
-    toolCalls := toToolCalls(r.Message.ToolCalls)
+// ToChunk converts an api.ChatResponse to ChatCompletionChunk
+func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
+    toolCalls := ToToolCalls(r.Message.ToolCalls)
     return ChatCompletionChunk{
         Id:      id,
         Object:  "chat.completion.chunk",
@@ -300,15 +309,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
     }
 }

-func toUsageGenerate(r api.GenerateResponse) Usage {
+// ToUsageGenerate converts an api.GenerateResponse to Usage
+func ToUsageGenerate(r api.GenerateResponse) Usage {
     return Usage{
-        PromptTokens:     r.PromptEvalCount,
-        CompletionTokens: r.EvalCount,
-        TotalTokens:      r.PromptEvalCount + r.EvalCount,
+        PromptTokens:     r.Metrics.PromptEvalCount,
+        CompletionTokens: r.Metrics.EvalCount,
+        TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
     }
 }

-func toCompletion(id string, r api.GenerateResponse) Completion {
+// ToCompletion converts an api.GenerateResponse to Completion
+func ToCompletion(id string, r api.GenerateResponse) Completion {
     return Completion{
         Id:     id,
         Object: "text_completion",
@@ -325,11 +336,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
             return nil
         }(r.DoneReason),
         }},
-        Usage: toUsageGenerate(r),
+        Usage: ToUsageGenerate(r),
     }
 }

-func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
+// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
+func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
     return CompletionChunk{
         Id:     id,
         Object: "text_completion",
@@ -349,7 +361,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
     }
 }

-func toListCompletion(r api.ListResponse) ListCompletion {
+// ToListCompletion converts an api.ListResponse to ListCompletion
+func ToListCompletion(r api.ListResponse) ListCompletion {
     var data []Model
     for _, m := range r.Models {
         data = append(data, Model{
@@ -366,13 +379,22 @@ func toListCompletion(r api.ListResponse) ListCompletion {
     }
 }

-func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
+// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
+// encodingFormat can be "float", "base64", or empty (defaults to "float")
+func ToEmbeddingList(model string, r api.EmbedResponse, encodingFormat string) EmbeddingList {
     if r.Embeddings != nil {
         var data []Embedding
         for i, e := range r.Embeddings {
+            var embedding any
+            if strings.EqualFold(encodingFormat, "base64") {
+                embedding = floatsToBase64(e)
+            } else {
+                embedding = e
+            }
+
             data = append(data, Embedding{
                 Object:    "embedding",
-                Embedding: e,
+                Embedding: embedding,
                 Index:     i,
             })
         }
@@ -391,7 +413,15 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
     return EmbeddingList{}
 }

-func toModel(r api.ShowResponse, m string) Model {
+// floatsToBase64 encodes a []float32 to a base64 string
+func floatsToBase64(floats []float32) string {
+    var buf bytes.Buffer
+    binary.Write(&buf, binary.LittleEndian, floats)
+    return base64.StdEncoding.EncodeToString(buf.Bytes())
+}
+
+// ToModel converts an api.ShowResponse to Model
+func ToModel(r api.ShowResponse, m string) Model {
     return Model{
         Id:     m,
         Object: "model",
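
Context for the base64 path added above: the encoded payload is the raw little-endian float32 bytes. A client-side decoding sketch (hypothetical `base64ToFloats` helper, not part of this diff):

```go
package main

import (
    "bytes"
    "encoding/base64"
    "encoding/binary"
    "fmt"
)

// base64ToFloats is a hypothetical client-side inverse of floatsToBase64:
// it decodes a base64 string back into little-endian float32 values.
func base64ToFloats(s string) ([]float32, error) {
    raw, err := base64.StdEncoding.DecodeString(s)
    if err != nil {
        return nil, err
    }
    if len(raw)%4 != 0 {
        return nil, fmt.Errorf("payload length %d is not a multiple of 4", len(raw))
    }
    out := make([]float32, len(raw)/4)
    if err := binary.Read(bytes.NewReader(raw), binary.LittleEndian, out); err != nil {
        return nil, err
    }
    return out, nil
}

func main() {
    enc := base64.StdEncoding.EncodeToString([]byte{0, 0, 128, 63}) // 1.0 as LE float32
    v, _ := base64ToFloats(enc)
    fmt.Println(v) // [1]
}
```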
@@ -400,7 +430,8 @@ func toModel(r api.ShowResponse, m string) Model {
     }
 }

-func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
+// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
+func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
     var messages []api.Message
     for _, msg := range r.Messages {
         toolName := ""
@@ -412,7 +443,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
         }
         switch content := msg.Content.(type) {
         case string:
-            toolCalls, err := fromCompletionToolCall(msg.ToolCalls)
+            toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
             if err != nil {
                 return nil, err
             }
@@ -444,6 +475,11 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
                     types := []string{"jpeg", "jpg", "png", "webp"}
                     valid := false
+                    // support blank mime type to match api/chat taking just unadorned base64
+                    if strings.HasPrefix(url, "data:;base64,") {
+                        url = strings.TrimPrefix(url, "data:;base64,")
+                        valid = true
+                    }
                     for _, t := range types {
                         prefix := "data:image/" + t + ";base64,"
                         if strings.HasPrefix(url, prefix) {
@@ -470,7 +506,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
         // since we might have added multiple messages above, if we have tools
         // calls we'll add them to the last message
         if len(messages) > 0 && len(msg.ToolCalls) > 0 {
-            toolCalls, err := fromCompletionToolCall(msg.ToolCalls)
+            toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
             if err != nil {
                 return nil, err
             }
@@ -541,10 +577,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
         options["top_p"] = 1.0
     }

-    if r.Reasoning != nil {
-        options["reasoning"] = *r.Reasoning.Effort
-    }
-
     var format json.RawMessage
     if r.ResponseFormat != nil {
         switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
@@ -559,20 +591,35 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
     }

     var think *api.ThinkValue
+    var effort string
     if r.Reasoning != nil {
-        think = &api.ThinkValue{
-            Value: *r.Reasoning.Effort,
-        }
+        effort = r.Reasoning.Effort
+    } else if r.ReasoningEffort != nil {
+        effort = *r.ReasoningEffort
     }
+
+    if effort != "" {
+        if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
+            return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
+        }
+        if effort == "none" {
+            think = &api.ThinkValue{Value: false}
+        } else {
+            think = &api.ThinkValue{Value: effort}
+        }
+    }

     return &api.ChatRequest{
-        Model:    r.Model,
-        Messages: messages,
-        Format:   format,
-        Options:  options,
-        Stream:   &r.Stream,
-        Tools:    r.Tools,
-        Think:    think,
+        Model:           r.Model,
+        Messages:        messages,
+        Format:          format,
+        Options:         options,
+        Stream:          &r.Stream,
+        Tools:           r.Tools,
+        Think:           think,
+        DebugRenderOnly: r.DebugRenderOnly,
     }, nil
 }
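
A sketch of how the new effort plumbing behaves, as a fragment in this package (model name and message are placeholders; field names assume this package's `Message` type):

```go
// Fragment (same package): "none" disables thinking; "low"/"medium"/"high"
// pass through to Think; anything else is rejected by FromChatRequest.
func exampleEffortMapping() {
    req := ChatCompletionRequest{
        Model:     "some-model", // placeholder
        Messages:  []Message{{Role: "user", Content: "hi"}},
        Reasoning: &Reasoning{Effort: "none"},
    }
    chatReq, err := FromChatRequest(req)
    _ = err           // nil here; "extreme" would yield the "invalid reasoning value" error
    _ = chatReq.Think // &api.ThinkValue{Value: false}
}
```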
@@ -590,7 +637,8 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
     return ""
 }

-func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
+// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
+func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
     apiToolCalls := make([]api.ToolCall, len(toolCalls))
     for i, tc := range toolCalls {
         apiToolCalls[i].Function.Name = tc.Function.Name
@@ -603,7 +651,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
     return apiToolCalls, nil
 }

-func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
+// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
+func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
     options := make(map[string]any)

     switch stop := r.Stop.(type) {
@@ -646,420 +695,11 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
     }

     return api.GenerateRequest{
-        Model:   r.Model,
-        Prompt:  r.Prompt,
-        Options: options,
-        Stream:  &r.Stream,
-        Suffix:  r.Suffix,
+        Model:           r.Model,
+        Prompt:          r.Prompt,
+        Options:         options,
+        Stream:          &r.Stream,
+        Suffix:          r.Suffix,
+        DebugRenderOnly: r.DebugRenderOnly,
     }, nil
 }
-
-type BaseWriter struct {
-    gin.ResponseWriter
-}
-
-type ChatWriter struct {
-    stream        bool
-    streamOptions *StreamOptions
-    id            string
-    toolCallSent  bool
-    BaseWriter
-}
-
-type CompleteWriter struct {
-    stream        bool
-    streamOptions *StreamOptions
-    id            string
-    BaseWriter
-}
-
-type ListWriter struct {
-    BaseWriter
-}
-
-type RetrieveWriter struct {
-    BaseWriter
-    model string
-}
-
-type EmbedWriter struct {
-    BaseWriter
-    model string
-}
-
-func (w *BaseWriter) writeError(data []byte) (int, error) {
-    var serr api.StatusError
-    err := json.Unmarshal(data, &serr)
-    if err != nil {
-        return 0, err
-    }
-
-    w.ResponseWriter.Header().Set("Content-Type", "application/json")
-    err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
-    if err != nil {
-        return 0, err
-    }
-
-    return len(data), nil
-}
-
-func (w *ChatWriter) writeResponse(data []byte) (int, error) {
-    var chatResponse api.ChatResponse
-    err := json.Unmarshal(data, &chatResponse)
-    if err != nil {
-        return 0, err
-    }
-
-    // chat chunk
-    if w.stream {
-        c := toChunk(w.id, chatResponse, w.toolCallSent)
-        d, err := json.Marshal(c)
-        if err != nil {
-            return 0, err
-        }
-        if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
-            w.toolCallSent = true
-        }
-        w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
-        _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
-        if err != nil {
-            return 0, err
-        }
-
-        if chatResponse.Done {
-            if w.streamOptions != nil && w.streamOptions.IncludeUsage {
-                u := toUsage(chatResponse)
-                c.Usage = &u
-                c.Choices = []ChunkChoice{}
-                d, err := json.Marshal(c)
-                if err != nil {
-                    return 0, err
-                }
-                _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
-                if err != nil {
-                    return 0, err
-                }
-            }
-            _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
-            if err != nil {
-                return 0, err
-            }
-        }
-
-        return len(data), nil
-    }
-
-    // chat completion
-    w.ResponseWriter.Header().Set("Content-Type", "application/json")
-    err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
-    if err != nil {
-        return 0, err
-    }
-
-    return len(data), nil
-}
-
-func (w *ChatWriter) Write(data []byte) (int, error) {
-    code := w.ResponseWriter.Status()
-    if code != http.StatusOK {
-        return w.writeError(data)
-    }
-
-    return w.writeResponse(data)
-}
-
-func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
-    var generateResponse api.GenerateResponse
-    err := json.Unmarshal(data, &generateResponse)
-    if err != nil {
-        return 0, err
-    }
-
-    // completion chunk
-    if w.stream {
-        c := toCompleteChunk(w.id, generateResponse)
-        if w.streamOptions != nil && w.streamOptions.IncludeUsage {
-            c.Usage = &Usage{}
-        }
-        d, err := json.Marshal(c)
-        if err != nil {
-            return 0, err
-        }
-        w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
-        _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
-        if err != nil {
-            return 0, err
-        }
-
-        if generateResponse.Done {
-            if w.streamOptions != nil && w.streamOptions.IncludeUsage {
-                u := toUsageGenerate(generateResponse)
-                c.Usage = &u
-                c.Choices = []CompleteChunkChoice{}
-                d, err := json.Marshal(c)
-                if err != nil {
-                    return 0, err
-                }
-                _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
-                if err != nil {
-                    return 0, err
-                }
-            }
-            _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
-            if err != nil {
-                return 0, err
-            }
-        }
-
-        return len(data), nil
-    }
-
-    // completion
-    w.ResponseWriter.Header().Set("Content-Type", "application/json")
-    err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
-    if err != nil {
-        return 0, err
-    }
-
-    return len(data), nil
-}
-
-func (w *CompleteWriter) Write(data []byte) (int, error) {
-    code := w.ResponseWriter.Status()
-    if code != http.StatusOK {
-        return w.writeError(data)
-    }
-
-    return w.writeResponse(data)
-}
-
-func (w *ListWriter) writeResponse(data []byte) (int, error) {
-    var listResponse api.ListResponse
-    err := json.Unmarshal(data, &listResponse)
-    if err != nil {
-        return 0, err
-    }
-
-    w.ResponseWriter.Header().Set("Content-Type", "application/json")
-    err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
-    if err != nil {
-        return 0, err
-    }
-
-    return len(data), nil
-}
-
-func (w *ListWriter) Write(data []byte) (int, error) {
-    code := w.ResponseWriter.Status()
-    if code != http.StatusOK {
-        return w.writeError(data)
-    }
-
-    return w.writeResponse(data)
-}
-
-func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
-    var showResponse api.ShowResponse
-    err := json.Unmarshal(data, &showResponse)
-    if err != nil {
-        return 0, err
-    }
-
-    // retrieve completion
-    w.ResponseWriter.Header().Set("Content-Type", "application/json")
-    err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
-    if err != nil {
-        return 0, err
-    }
-
-    return len(data), nil
-}
-
-func (w *RetrieveWriter) Write(data []byte) (int, error) {
-    code := w.ResponseWriter.Status()
-    if code != http.StatusOK {
-        return w.writeError(data)
-    }
-
-    return w.writeResponse(data)
-}
-
-func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
-    var embedResponse api.EmbedResponse
-    err := json.Unmarshal(data, &embedResponse)
-    if err != nil {
-        return 0, err
-    }
-
-    w.ResponseWriter.Header().Set("Content-Type", "application/json")
-    err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
-    if err != nil {
-        return 0, err
-    }
-
-    return len(data), nil
-}
-
-func (w *EmbedWriter) Write(data []byte) (int, error) {
-    code := w.ResponseWriter.Status()
-    if code != http.StatusOK {
-        return w.writeError(data)
-    }
-
-    return w.writeResponse(data)
-}
-
-func ListMiddleware() gin.HandlerFunc {
-    return func(c *gin.Context) {
-        w := &ListWriter{
-            BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-        }
-
-        c.Writer = w
-
-        c.Next()
-    }
-}
-
-func RetrieveMiddleware() gin.HandlerFunc {
-    return func(c *gin.Context) {
-        var b bytes.Buffer
-        if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
-            c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
-            return
-        }
-
-        c.Request.Body = io.NopCloser(&b)
-
-        // response writer
-        w := &RetrieveWriter{
-            BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-            model:      c.Param("model"),
-        }
-
-        c.Writer = w
-
-        c.Next()
-    }
-}
-
-func CompletionsMiddleware() gin.HandlerFunc {
-    return func(c *gin.Context) {
-        var req CompletionRequest
-        err := c.ShouldBindJSON(&req)
-        if err != nil {
-            c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-            return
-        }
-
-        var b bytes.Buffer
-        genReq, err := fromCompleteRequest(req)
-        if err != nil {
-            c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-            return
-        }
-
-        if err := json.NewEncoder(&b).Encode(genReq); err != nil {
-            c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
-            return
-        }
-
-        c.Request.Body = io.NopCloser(&b)
-
-        w := &CompleteWriter{
-            BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
-            stream:        req.Stream,
-            id:            fmt.Sprintf("cmpl-%d", rand.Intn(999)),
-            streamOptions: req.StreamOptions,
-        }
-
-        c.Writer = w
-
-        c.Next()
-    }
-}
-
-func EmbeddingsMiddleware() gin.HandlerFunc {
-    return func(c *gin.Context) {
-        var req EmbedRequest
-        err := c.ShouldBindJSON(&req)
-        if err != nil {
-            c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-            return
-        }
-
-        if req.Input == "" {
-            req.Input = []string{""}
-        }
-
-        if req.Input == nil {
-            c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
-            return
-        }
-
-        if v, ok := req.Input.([]any); ok && len(v) == 0 {
-            c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
-            return
-        }
-
-        var b bytes.Buffer
-        if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
-            c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
-            return
-        }
-
-        c.Request.Body = io.NopCloser(&b)
-
-        w := &EmbedWriter{
-            BaseWriter: BaseWriter{ResponseWriter: c.Writer},
-            model:      req.Model,
-        }
-
-        c.Writer = w
-
-        c.Next()
-    }
-}
-
-func ChatMiddleware() gin.HandlerFunc {
-    return func(c *gin.Context) {
-        var req ChatCompletionRequest
-        err := c.ShouldBindJSON(&req)
-        if err != nil {
-            c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-            return
-        }
-
-        if len(req.Messages) == 0 {
-            c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
-            return
-        }
-
-        var b bytes.Buffer
-        chatReq, err := fromChatRequest(req)
-        if err != nil {
-            c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
-            return
-        }
-
-        if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
-            c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
-            return
-        }
-
-        c.Request.Body = io.NopCloser(&b)
-
-        w := &ChatWriter{
-            BaseWriter:    BaseWriter{ResponseWriter: c.Writer},
-            stream:        req.Stream,
-            id:            fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
-            streamOptions: req.StreamOptions,
-        }
-
-        c.Writer = w
-
-        c.Next()
-    }
-}