Skip to content

Commit bd7599d

Browse files
authored
[V1][TPU] Do not compile sampling more than needed (#15883)
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent 01b6113 commit bd7599d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

vllm/v1/worker/tpu_model_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,9 @@ def capture_model(self) -> None:
862862
out = self.model.sample_from_hidden(dummy_hidden,
863863
sampling_meta)
864864
out = out.cpu()
865-
if num_reqs_to_sample >= self.max_num_reqs:
865+
# Requests can't be more than tokens. But do compile for the
866+
# next bigger value in case num_tokens uses bucketed padding.
867+
if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
866868
break
867869
# Make sure to compile the `max_num_reqs` upper-limit case
868870
num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(

0 commit comments

Comments
 (0)