sample: add sampling package for new engine (#8410)

This commit is contained in:
Parth Sareen
2025-02-24 17:19:01 -08:00
committed by GitHub
parent 314573bfe8
commit 0b7e1676eb
7 changed files with 600 additions and 127 deletions

View File

@@ -65,8 +65,8 @@ type Sequence struct {
// number of tokens to predict
numPredict int
// set of samplers to run on generated logits
samplers []sample.Sampler
// sampler with transforms to run on generated logits
sampler sample.Sampler
// channel to send back the embedding if embedding only
embedding chan []float32
@@ -93,7 +93,7 @@ type NewSequenceParams struct {
numPredict int
stop []string
numKeep int32
samplers []sample.Sampler
sampler sample.Sampler
embedding bool
}
@@ -136,7 +136,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplers: params.samplers,
sampler: params.sampler,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
@@ -393,13 +393,7 @@ func (s *Server) processBatch() error {
return fmt.Errorf("failed to decode batch: %w", err)
}
f32s := modelOutput.Floats()
// TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s
logits := make([]float64, len(f32s))
for i, f32 := range f32s {
logits[i] = float64(f32)
}
logits := modelOutput.Floats()
for i, seq := range s.seqs {
if seq == nil {
@@ -433,14 +427,12 @@ func (s *Server) processBatch() error {
}
// sample a token
vocabSize := len(f32s) / len(options.Outputs)
tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...)
if err != nil {
return err
}
vocabSize := len(logits) / len(options.Outputs)
// TODO(jessegross): Sampler will output a single int32 in the future
token := int32(tokens[0])
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
return fmt.Errorf("failed to sample token: %w", err)
}
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
@@ -565,27 +557,6 @@ type CompletionResponse struct {
Timings Timings `json:"timings"`
}
// getSamplers builds the list of samplers to run on generated logits for a
// completion request. The request argument is currently ignored: the result
// always contains exactly one greedy sampler.
func getSamplers(_ CompletionRequest) []sample.Sampler {
	// TODO(jessegross): Waiting for sampling code
	//
	// Intended mapping from request options to sampling parameters, kept
	// here for reference until the sampling code is available:
	//
	//	samplingParams.TopK = req.TopK
	//	samplingParams.TopP = req.TopP
	//	samplingParams.MinP = req.MinP
	//	samplingParams.TypicalP = req.TypicalP
	//	samplingParams.Temp = req.Temperature
	//	samplingParams.RepeatLastN = req.RepeatLastN
	//	samplingParams.PenaltyRepeat = req.RepeatPenalty
	//	samplingParams.PenaltyFreq = req.FrequencyPenalty
	//	samplingParams.PenaltyPresent = req.PresencePenalty
	//	samplingParams.Mirostat = req.Mirostat
	//	samplingParams.MirostatTau = req.MirostatTau
	//	samplingParams.MirostatEta = req.MirostatEta
	//	samplingParams.Seed = uint32(req.Seed)
	//	samplingParams.Grammar = req.Grammar
	greedy := sample.Greedy()
	return []sample.Sampler{greedy}
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
@@ -604,11 +575,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
sampler, err := sample.NewSampler(
req.Temperature,
req.TopK,
req.TopP,
req.MinP,
req.Seed,
)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
return
}
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: int32(req.NumKeep),
samplers: getSamplers(req),
sampler: sampler,
embedding: false,
})
if err != nil {