diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 12233af057b0..74eb9ae9d325 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -209,7 +209,8 @@ def build( # type: ignore[override] # prepare tensors for cudagraph if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0 - and num_spec_decodes <= self.decode_cudagraph_max_bs): + and num_spec_decodes <= self.decode_cudagraph_max_bs + and m.num_actual_tokens <= self.decode_cudagraph_max_bs): num_total_tokens = self.vllm_config.pad_for_cudagraph( m.num_actual_tokens) batch_size = num_total_tokens // (self.num_spec + 1)