mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 07:46:59 +00:00
sample: add sampling package for new engine (#8410)
This commit is contained in:
@@ -65,8 +65,8 @@ type Sequence struct {
|
||||
// number of tokens to predict
|
||||
numPredict int
|
||||
|
||||
// set of samplers to run on generated logits
|
||||
samplers []sample.Sampler
|
||||
// sampler with transforms to run on generated logits
|
||||
sampler sample.Sampler
|
||||
|
||||
// channel to send back the embedding if embedding only
|
||||
embedding chan []float32
|
||||
@@ -93,7 +93,7 @@ type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int32
|
||||
samplers []sample.Sampler
|
||||
sampler sample.Sampler
|
||||
embedding bool
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplers: params.samplers,
|
||||
sampler: params.sampler,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
@@ -393,13 +393,7 @@ func (s *Server) processBatch() error {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
|
||||
f32s := modelOutput.Floats()
|
||||
|
||||
// TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s
|
||||
logits := make([]float64, len(f32s))
|
||||
for i, f32 := range f32s {
|
||||
logits[i] = float64(f32)
|
||||
}
|
||||
logits := modelOutput.Floats()
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil {
|
||||
@@ -433,14 +427,12 @@ func (s *Server) processBatch() error {
|
||||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(f32s) / len(options.Outputs)
|
||||
tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
vocabSize := len(logits) / len(options.Outputs)
|
||||
|
||||
// TODO(jessegross): Sampler will output a single int32 in the future
|
||||
token := int32(tokens[0])
|
||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sample token: %w", err)
|
||||
}
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
@@ -565,27 +557,6 @@ type CompletionResponse struct {
|
||||
Timings Timings `json:"timings"`
|
||||
}
|
||||
|
||||
func getSamplers(_ CompletionRequest) []sample.Sampler {
|
||||
// TODO(jessegross): Waiting for sampling code
|
||||
|
||||
/*samplingParams.TopK = req.TopK
|
||||
samplingParams.TopP = req.TopP
|
||||
samplingParams.MinP = req.MinP
|
||||
samplingParams.TypicalP = req.TypicalP
|
||||
samplingParams.Temp = req.Temperature
|
||||
samplingParams.RepeatLastN = req.RepeatLastN
|
||||
samplingParams.PenaltyRepeat = req.RepeatPenalty
|
||||
samplingParams.PenaltyFreq = req.FrequencyPenalty
|
||||
samplingParams.PenaltyPresent = req.PresencePenalty
|
||||
samplingParams.Mirostat = req.Mirostat
|
||||
samplingParams.MirostatTau = req.MirostatTau
|
||||
samplingParams.MirostatEta = req.MirostatEta
|
||||
samplingParams.Seed = uint32(req.Seed)
|
||||
samplingParams.Grammar = req.Grammar*/
|
||||
|
||||
return []sample.Sampler{sample.Greedy()}
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var req CompletionRequest
|
||||
req.Options = Options(api.DefaultOptions())
|
||||
@@ -604,11 +575,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sampler, err := sample.NewSampler(
|
||||
req.Temperature,
|
||||
req.TopK,
|
||||
req.TopP,
|
||||
req.MinP,
|
||||
req.Seed,
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
numKeep: int32(req.NumKeep),
|
||||
samplers: getSamplers(req),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user