Skip to content

Commit 6eacbb5

Browse files
committed
v1/spec_decode/eagle: make dummy_run CUDA graph gating robust with cudagraphs_enabled
Signed-off-by: xiaohajiayou <923390377@qq.com>
1 parent 5fc3033 commit 6eacbb5

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

vllm/v1/spec_decode/eagle.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1050,21 +1050,19 @@ def dummy_run(
10501050
num_tokens: int,
10511051
use_cudagraphs=True,
10521052
) -> None:
1053-
if (
1054-
use_cudagraphs
1055-
and self.use_cuda_graph
1056-
and num_tokens <= self.cudagraph_batch_sizes[-1]
1057-
):
1053+
# Determine if CUDA graphs should be used for this run.
1054+
cudagraphs_enabled = (
1055+
use_cudagraphs and self.use_cuda_graph and bool(self.cudagraph_batch_sizes)
1056+
)
1057+
if cudagraphs_enabled and num_tokens <= self.cudagraph_batch_sizes[-1]:
10581058
num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
10591059

10601060
with set_forward_context(
10611061
None,
10621062
self.vllm_config,
10631063
num_tokens=num_tokens,
10641064
cudagraph_runtime_mode=(
1065-
CUDAGraphMode.PIECEWISE
1066-
if (use_cudagraphs and self.use_cuda_graph)
1067-
else CUDAGraphMode.NONE
1065+
CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
10681066
),
10691067
):
10701068
if self.supports_mm_inputs:

0 commit comments

Comments
 (0)