ml: Allow models to constrain inputs to a single batch

Models may require that a set of inputs all be processed as part
of the same batch. For example, if an image has multiple patches
with fully connected attention between them, we should not split
the batch in the middle of an image.

Fixes #9697
Jesse Gross
2025-03-12 16:56:11 -07:00
committed by Jesse Gross
parent 3892c3a703
commit 9679f40146
5 changed files with 64 additions and 66 deletions
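For context before the diff: the change below reads a per-input SameBatch field, where minBatch := 1 + inp.SameBatch implies SameBatch counts how many *following* inputs must be processed in the same batch as this one. Here is a minimal, self-contained sketch (not part of this commit) of how a model might mark an image's inputs under that reading; the Input struct is a stripped-down stand-in for the runner's input.Input, and the patch count is made up.

package main

import "fmt"

// Input is an illustrative stand-in for the runner's input.Input.
// SameBatch is the number of *following* inputs that must be processed
// in the same batch as this one.
type Input struct {
	Token     int32
	SameBatch int
}

func main() {
	const numPatches = 256 // hypothetical patch count for one image

	// Two ordinary text tokens, which may be split across batches freely.
	inputs := []Input{{Token: 10}, {Token: 11}}

	// The first image position pins the remaining patches to its batch:
	// with fully connected attention between patches, splitting the image
	// across batches would run attention over an incomplete set.
	inputs = append(inputs, Input{SameBatch: numPatches - 1})
	for i := 1; i < numPatches; i++ {
		inputs = append(inputs, Input{})
	}

	fmt.Printf("%d inputs, %d of which must share one batch\n",
		len(inputs), numPatches)
}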


@@ -352,6 +352,8 @@ func (s *Server) processBatch() error {
 			seq.cache.Inputs = []input.Input{}
 		}
 
+		batchSize := s.batchSize
+
 		for j, inp := range seq.inputs {
 			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
 				if len(seq.pendingInputs) == 0 {
@@ -364,7 +366,15 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			if j >= s.batchSize {
+			// If we are required to put following inputs into a single batch then extend the
+			// batch size. Since we are only extending the size the minimum amount possible, this
+			// will cause a break if we have pending inputs.
+			minBatch := 1 + inp.SameBatch
+			if minBatch > batchSize {
+				batchSize = minBatch
+			}
+
+			if len(seq.pendingInputs)+minBatch > batchSize {
 				break
 			}
 
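The heart of the change is the scheduling rule above: grow the batch only by the minimum amount a pinned group needs, and break (flushing any pending inputs) when the group still does not fit alongside them. Below is a standalone sketch of that rule under the same assumption about SameBatch; fitsInBatch and the numbers are illustrative, not part of the runner.

package main

import "fmt"

// Input is an illustrative stand-in for the runner's input.Input; only the
// SameBatch field matters here.
type Input struct {
	SameBatch int
}

// fitsInBatch reports whether the next input can join a batch that already
// holds pending inputs, and returns the (possibly extended) batch size.
// The batch grows only by the minimum needed, so a group that does not fit
// next to pending inputs forces a break first.
func fitsInBatch(pending, batchSize int, inp Input) (bool, int) {
	minBatch := 1 + inp.SameBatch
	if minBatch > batchSize {
		batchSize = minBatch
	}
	return pending+minBatch <= batchSize, batchSize
}

func main() {
	// A group of 10 inputs (1 + SameBatch of 9) against a default batch of 8,
	// with 3 inputs already pending: the size extends to 10, but 3+10 > 10,
	// so the runner breaks here and flushes the pending inputs.
	ok, size := fitsInBatch(3, 8, Input{SameBatch: 9})
	fmt.Println(ok, size) // false 10

	// On the next pass nothing is pending, so the extended batch fits the
	// whole group and the image is processed in one batch.
	ok, size = fitsInBatch(0, 8, Input{SameBatch: 9})
	fmt.Println(ok, size) // true 10
}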