                                                   MLACommonImpl,
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
 
 logger = init_logger(__name__)
 
+# NOTE(matt): This is an arbitrary number, copied from
+# woosuk's implementation in standard FlashAttention backend
+_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
+
 
 class FlashAttnMLABackend(MLACommonBackend):
 
@@ -48,6 +53,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
     max_query_len: int
     max_seq_len: int
     scheduler_metadata: Optional[torch.Tensor] = None
+    max_num_splits: int = 0
 
 
 @dataclass
@@ -57,14 +63,41 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
 
 class FlashAttnMLAMetadataBuilder(
         MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_BATCH
+
     reorder_batch_threshold: ClassVar[int] = 512
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                          FlashAttnMLAMetadata)
+        self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = (get_flash_attn_version() == 3)
 
+        self.use_full_cuda_graph = \
+            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
+
+        if self.use_full_cuda_graph and self.fa_aot_schedule:
+            self.max_cudagraph_size = self.compilation_config.max_capture_size
+
+            if self.max_cudagraph_size > 992:
+                # This condition derives from FA3's internal heuristic.
+                # TODO(woosuk): Support larger cudagraph sizes.
+                raise ValueError(
+                    "Capture size larger than 992 is not supported for "
+                    "full cuda graph.")
+
+            self.scheduler_metadata = torch.zeros(
+                vllm_config.scheduler_config.max_num_seqs + 1,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # When using cuda graph, we need to set the upper bound of the
+            # number of splits so that large enough intermediate buffers are
+            # pre-allocated during capture.
+            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
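
Aside (not part of the diff): the persistent `scheduler_metadata` tensor above exists because a captured CUDA graph replays against fixed storage, so per-step metadata has to be staged into one pre-allocated buffer rather than a freshly allocated tensor. A minimal, hypothetical sketch of that pattern (the names `persistent` and `fresh` are invented for illustration; `_build_decode` further down does the real version):

```python
import torch

# Pre-allocate once, sized for the worst case (e.g. max_num_seqs + 1 slots).
persistent = torch.zeros(9, dtype=torch.int32)

# Each step produces metadata of varying length; stage it into the buffer.
fresh = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
n = fresh.shape[0]
assert n <= persistent.shape[0]
persistent[:n] = fresh   # copy the live entries into the fixed allocation
persistent[n:] = 0       # zero the tail so stale entries are never consumed
metadata_view = persistent[:n]  # kernels read a view of the fixed storage
```
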
@@ -81,14 +114,16 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
+                num_splits=self.max_num_splits,
             )
         return None
 
-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor
-    ) -> FlashAttnMLADecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
         query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_cpu.max().item()
@@ -102,13 +137,37 @@ def _build_decode(
             causal=True,
         )
 
+        # For FA3 + full cudagraph
+        max_num_splits = 0
+        if self.use_full_cuda_graph and scheduler_metadata is not None:
+            n = scheduler_metadata.shape[0]
+            # Ensure the persistent buffer is large enough
+            assert n <= self.scheduler_metadata.shape[0], \
+                f"Scheduler metadata size {n} exceeds buffer size " \
+                f"{self.scheduler_metadata.shape[0]}"
+            self.scheduler_metadata[:n] = scheduler_metadata
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
+            if num_decode_tokens <= self.max_cudagraph_size:
+                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+                # usage, because the intermediate buffers of size [num_splits,
+                # num_heads, num_tokens, head_size] are allocated. Therefore,
+                # we only set num_splits when using cuda graphs.
+                max_num_splits = self.max_num_splits
+
         return FlashAttnMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
             query_start_loc=query_start_loc_device,
             max_query_len=max_query_len,
             max_seq_len=max_seq_len,
             scheduler_metadata=scheduler_metadata,
+            max_num_splits=max_num_splits,
         )
 
 
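
Aside (not part of the diff): a back-of-the-envelope sketch of why `max_num_splits` is only raised when full CUDA graphs are in use. The split-KV intermediate buffers noted above have shape [num_splits, num_heads, num_tokens, head_size]; every value below other than `num_splits` is an assumed example size, not something taken from this change:

```python
num_splits = 16       # _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
num_heads = 16        # assumed number of query heads
num_tokens = 512      # assumed number of captured decode tokens
head_size = 576       # assumed MLA head size (latent + rope dims)
bytes_per_elem = 4    # assuming fp32 partial accumulators

workspace = num_splits * num_heads * num_tokens * head_size * bytes_per_elem
print(f"{workspace / 2**20:.0f} MiB")  # 288 MiB with these assumed sizes
```
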
@@ -175,12 +234,17 @@ def _forward_decode(
         kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
         k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]
 
+        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but
+        # the kernel uses this to calculate grid dimensions. Ensure it's at
+        # least 1 to prevent invalid grid configuration during graph capture.
+        max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
+
         o = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
             q_v=q_nope,
-            max_seqlen_q=attn_metadata.decode.max_query_len,
+            max_seqlen_q=max_seqlen_q,
             cu_seqlens_q=attn_metadata.decode.query_start_loc,
             max_seqlen_k=attn_metadata.decode.max_seq_len,
             seqused_k=attn_metadata.decode.seq_lens,
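
Aside (not part of the diff): a tiny sketch of the capture-time corner case the clamp above guards against, assuming a dummy decode batch whose `query_start_loc` never advances, so every per-request query length is zero:

```python
import torch

query_start_loc_cpu = torch.zeros(5, dtype=torch.int32)  # assumed dummy batch
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_seqlen_q = max(int(query_lens.max()), 1)  # clamp avoids a zero-sized grid
print(max_seqlen_q)  # 1
```
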
@@ -189,6 +253,7 @@ def _forward_decode(
             causal=True,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
+            num_splits=attn_metadata.decode.max_num_splits,
         )
 
         return self._v_up_proj(o)