sample: do all sorting in topK

This commit is contained in:
ParthSareen
2025-03-12 13:40:25 -04:00
committed by Parth Sareen
parent 3ba91634c1
commit 4aeb67ef4c
3 changed files with 35 additions and 25 deletions

View File

@@ -84,11 +84,8 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return greedy(tokens), nil
}
if s.topK > 0 {
tokens = topK(tokens, s.topK)
} else {
sortLogits(tokens)
}
// topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK)
// token logit values are updated to probabilities
tokens = temperature(tokens, s.temperature)