@@ -31,6 +31,7 @@ class GDNAttentionMetadata:
     num_decode_tokens: int
     num_spec_decodes: int
     num_spec_decode_tokens: int
+    num_actual_tokens: int
 
     has_initial_state: Optional[torch.Tensor] = None
 
@@ -205,25 +206,22 @@ def build(  # type: ignore[override]
                 has_initial_state = has_initial_state[~spec_sequence_masks]
         else:
             has_initial_state = None
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
+            num_spec_decode_tokens
 
         # prepare tensors for cudagraph
         #
         # With speculative decoding, the xgrammar backend may roll back
         # tokens, causing some sequences to have fewer draft tokens than
         # self.num_spec.
         #
-        # During cudagraph capture, the GDN backends requires an assumption
-        # that num_spec_decode_tokens == num_spec_decodes * (self.num_spec + 1).
-        #
-        # More than one such sequences may break the assumption (less tokens),
-        # causing incompatible inputs for cuda graph replay.
217+         # In above cases, the max possible batch size for n tokens, can be 
218+         # min(n, cudagraph_max_bs). 
         if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
                 and num_spec_decodes <= self.decode_cudagraph_max_bs
-                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
-                and num_spec_decode_tokens == num_spec_decodes *
-            (self.num_spec + 1)):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens // (self.num_spec + 1)
+            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
 
             self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                 spec_state_indices_tensor, non_blocking=True)
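For intuition, here is a minimal, self-contained sketch of the new sizing rule. `pad_for_cudagraph` is stubbed with a hypothetical round-up to the nearest capture size; the real `vllm_config.pad_for_cudagraph` may choose sizes differently.

```python
def pad_for_cudagraph(n: int, capture_sizes=(1, 2, 4, 8, 16, 32, 64)) -> int:
    # Hypothetical stand-in: round the token count up to the nearest
    # captured cudagraph size.
    return next((s for s in capture_sizes if n <= s), capture_sizes[-1])

decode_cudagraph_max_bs = 32

# Three spec-decode sequences with num_spec = 2 would normally contribute
# 3 * (2 + 1) = 9 tokens, but an xgrammar rollback leaves only 7.
num_actual_tokens = pad_for_cudagraph(7)                      # -> 8
batch_size = min(decode_cudagraph_max_bs, num_actual_tokens)  # -> 8
print(num_actual_tokens, batch_size)
```

Unlike the removed `num_total_tokens // (self.num_spec + 1)` rule, this bound stays valid even when rollbacks make the token count no longer a multiple of `self.num_spec + 1`.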
@@ -239,7 +237,7 @@ def build(  # type: ignore[override]
             assert spec_token_masks is not None
             self.spec_token_masks[:spec_token_masks.size(0)].copy_(
                 spec_token_masks, non_blocking=True)
-            spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
+            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
             spec_token_masks[spec_token_masks.size(0):].fill_(False)
 
             self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
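The mask handling above follows the usual cudagraph persistent-buffer pattern: copy the live data into a preallocated buffer, take a view of the padded length, and clear the tail so a replayed graph never reads stale values left over from an earlier, larger batch. A minimal sketch, with hypothetical buffer names and sizes rather than vLLM's actual allocations:

```python
import torch

max_tokens = 16                                        # capture-time capacity
spec_token_masks_buf = torch.zeros(max_tokens, dtype=torch.bool)

live_mask = torch.tensor([True, False, True])          # this step's data
spec_token_masks_buf[:live_mask.size(0)].copy_(live_mask)

num_actual_tokens = 8                                  # padded token count
view = spec_token_masks_buf[:num_actual_tokens]        # graph-sized view
view[live_mask.size(0):].fill_(False)                  # clear the stale tail
print(view)
```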
@@ -258,9 +256,9 @@ def build(  # type: ignore[override]
         if (self.use_full_cuda_graph and num_prefills == 0
                 and num_spec_decodes == 0
                 and num_decodes <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens
+            batch_size = num_actual_tokens
 
             self.non_spec_state_indices_tensor[:num_decodes].copy_(
                 non_spec_state_indices_tensor, non_blocking=True)
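In this pure-decode path each sequence contributes exactly one token, so after padding the token count and the batch size coincide. A small sketch, again using a hypothetical `pad_for_cudagraph`:

```python
def pad_for_cudagraph(n: int, sizes=(1, 2, 4, 8, 16, 32)) -> int:
    # Hypothetical round-up, as in the earlier sketch.
    return next((s for s in sizes if n <= s), sizes[-1])

num_decodes = 5                                    # one token per sequence
num_actual_tokens = pad_for_cudagraph(num_decodes) # -> 8
batch_size = num_actual_tokens                     # padded tokens == batch
print(num_actual_tokens, batch_size)
```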
@@ -284,6 +282,7 @@ def build(  # type: ignore[override]
             num_decode_tokens=num_decode_tokens,
             num_spec_decodes=num_spec_decodes,
             num_spec_decode_tokens=num_spec_decode_tokens,
+            num_actual_tokens=num_actual_tokens,
             has_initial_state=has_initial_state,
             spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,