pass model and predict options

Patrick Devine
2023-07-06 17:09:48 -07:00
parent 7974ad58ff
commit 3f1b7177f2
3 changed files with 151 additions and 16 deletions


@@ -26,12 +26,9 @@ var templatesFS embed.FS
 var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
 
 func generate(c *gin.Context) {
-	// TODO: these should be request parameters
-	gpulayers := 1
-	tokens := 512
-	threads := runtime.NumCPU()
-
 	var req api.GenerateRequest
+	req.ModelOptions = api.DefaultModelOptions
+	req.PredictOptions = api.DefaultPredictOptions
 	if err := c.ShouldBindJSON(&req); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
 		return
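
A detail worth calling out in this hunk: the defaults are assigned before ShouldBindJSON runs, and JSON binding only overwrites fields that actually appear in the request body, so a client can override options selectively while everything it omits keeps its default. A runnable sketch of that overlay behavior, using hypothetical stand-in types and JSON tag names (the real definitions live in the api package, in one of the changed files not shown on this page):

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-ins for the api types; field names, values, and tags here are assumptions.
type PredictOptions struct {
	TopK        int     `json:"top_k"`
	Temperature float64 `json:"temperature"`
}

type GenerateRequest struct {
	Prompt         string         `json:"prompt"`
	PredictOptions PredictOptions `json:"predict_options"`
}

func main() {
	// Pre-populate defaults, then bind, mirroring what generate() now does.
	req := GenerateRequest{PredictOptions: PredictOptions{TopK: 90, Temperature: 0.8}}
	_ = json.Unmarshal([]byte(`{"prompt":"hi","predict_options":{"top_k":40}}`), &req)
	fmt.Println(req.PredictOptions.TopK, req.PredictOptions.Temperature) // 40 0.8
}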
@@ -41,7 +38,10 @@ func generate(c *gin.Context) {
 		req.Model = remoteModel.FullName()
 	}
 
-	model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
+	modelOpts := getModelOpts(req)
+	modelOpts.NGPULayers = 1 // hard-code this for now
+
+	model, err := llama.New(req.Model, modelOpts)
 	if err != nil {
 		fmt.Println("Loading the model failed:", err.Error())
 		return
@@ -65,13 +65,16 @@ func generate(c *gin.Context) {
 	}
 
 	ch := make(chan string)
+	model.SetTokenCallback(func(token string) bool {
+		ch <- token
+		return true
+	})
+
+	predictOpts := getPredictOpts(req)
 
 	go func() {
 		defer close(ch)
-		_, err := model.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
-			ch <- token
-			return true
-		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
+		_, err := model.Predict(req.Prompt, predictOpts)
 		if err != nil {
 			panic(err)
 		}
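
With this hunk the token callback moves off the Predict call and onto the model itself, and the hard-coded sampling settings (top-k 90, top-p 0.86, stop word "llama") give way to whatever getPredictOpts builds from the request. Tokens still leave through ch; a minimal sketch of what the consuming side of that channel could look like, assuming Gin's c.Stream helper and a response type with a single Response field (both are assumptions, neither appears in this diff):

// Sketch only: drain ch into the HTTP response as newline-delimited JSON.
func streamTokens(c *gin.Context, ch <-chan string) {
	c.Stream(func(w io.Writer) bool {
		token, ok := <-ch
		if !ok {
			return false // closed by the goroutine's defer: prediction finished
		}
		// api.GenerateResponse{Response: ...} is an assumed shape.
		bts, err := json.Marshal(api.GenerateResponse{Response: token})
		if err != nil {
			return false
		}
		_, err = w.Write(append(bts, '\n'))
		return err == nil // keep streaming while writes succeed
	})
}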
@@ -161,3 +164,53 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) {
 
 	return
 }
+
+func getModelOpts(req api.GenerateRequest) llama.ModelOptions {
+	var opts llama.ModelOptions
+
+	opts.ContextSize = req.ModelOptions.ContextSize
+	opts.Seed = req.ModelOptions.Seed
+	opts.F16Memory = req.ModelOptions.F16Memory
+	opts.MLock = req.ModelOptions.MLock
+	opts.Embeddings = req.ModelOptions.Embeddings
+	opts.MMap = req.ModelOptions.MMap
+	opts.LowVRAM = req.ModelOptions.LowVRAM
+	opts.NBatch = req.ModelOptions.NBatch
+	opts.VocabOnly = req.ModelOptions.VocabOnly
+	opts.NUMA = req.ModelOptions.NUMA
+	opts.NGPULayers = req.ModelOptions.NGPULayers
+	opts.MainGPU = req.ModelOptions.MainGPU
+	opts.TensorSplit = req.ModelOptions.TensorSplit
+
+	return opts
+}
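
getModelOpts is a field-for-field copy, which implies the request-side api.ModelOptions mirrors llama.ModelOptions one to one; the copy layer keeps the HTTP wire format decoupled from the llama binding types. The api struct itself is defined in one of the two changed files not shown on this page; a sketch of its likely shape, inferred only from the assignments above (all field types and JSON tags are guesses):

// Presumed shape of api.ModelOptions; types and tags are assumptions.
type ModelOptions struct {
	ContextSize int    `json:"context_size,omitempty"`
	Seed        int    `json:"seed,omitempty"`
	F16Memory   bool   `json:"memory_f16,omitempty"`
	MLock       bool   `json:"memory_lock,omitempty"`
	Embeddings  bool   `json:"embeddings,omitempty"`
	MMap        bool   `json:"memory_map,omitempty"`
	LowVRAM     bool   `json:"low_vram,omitempty"`
	NBatch      int    `json:"n_batch,omitempty"`
	VocabOnly   bool   `json:"vocab_only,omitempty"`
	NUMA        bool   `json:"numa,omitempty"`
	NGPULayers  int    `json:"gpu_layers,omitempty"`
	MainGPU     string `json:"main_gpu,omitempty"`
	TensorSplit string `json:"tensor_split,omitempty"`
}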
+
+func getPredictOpts(req api.GenerateRequest) llama.PredictOptions {
+	var opts llama.PredictOptions
+
+	if req.PredictOptions.Threads == -1 {
+		opts.Threads = runtime.NumCPU()
+	} else {
+		opts.Threads = req.PredictOptions.Threads
+	}
+
+	opts.Seed = req.PredictOptions.Seed
+	opts.Tokens = req.PredictOptions.Tokens
+	opts.Penalty = req.PredictOptions.Penalty
+	opts.Repeat = req.PredictOptions.Repeat
+	opts.Batch = req.PredictOptions.Batch
+	opts.NKeep = req.PredictOptions.NKeep
+	opts.TopK = req.PredictOptions.TopK
+	opts.TopP = req.PredictOptions.TopP
+	opts.TailFreeSamplingZ = req.PredictOptions.TailFreeSamplingZ
+	opts.TypicalP = req.PredictOptions.TypicalP
+	opts.Temperature = req.PredictOptions.Temperature
+	opts.FrequencyPenalty = req.PredictOptions.FrequencyPenalty
+	opts.PresencePenalty = req.PredictOptions.PresencePenalty
+	opts.Mirostat = req.PredictOptions.Mirostat
+	opts.MirostatTAU = req.PredictOptions.MirostatTAU
+	opts.MirostatETA = req.PredictOptions.MirostatETA
+	opts.MMap = req.PredictOptions.MMap
+
+	return opts
+}
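
The one non-mechanical piece of getPredictOpts is the Threads sentinel: a request value of -1 resolves to every logical CPU, replacing the old unconditional threads := runtime.NumCPU() from the first hunk. A runnable sketch of that resolution rule in isolation (resolveThreads is a hypothetical helper, not part of the commit):

package main

import (
	"fmt"
	"runtime"
)

// resolveThreads mirrors the sentinel handling in getPredictOpts:
// -1 expands to the host's logical CPU count; any other value passes through.
func resolveThreads(requested int) int {
	if requested == -1 {
		return runtime.NumCPU()
	}
	return requested
}

func main() {
	fmt.Println(resolveThreads(-1)) // the machine's logical CPU count, e.g. 8
	fmt.Println(resolveThreads(4))  // 4
}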