 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -139,6 +140,16 @@ def __init__(
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 GPUModelRunner.")
 
+        if self.vllm_config.compilation_config.full_cuda_graph:
+            attn_backend_name = self.attn_backend.__name__
+            flash_attn_version = get_flash_attn_version()
+            if attn_backend_name != "FlashAttentionBackend" or \
+                    flash_attn_version != 3:
+                raise ValueError(
+                    f"full_cuda_graph is only supported with "
+                    f"FA3. Current attention backend is {attn_backend_name}, "
+                    f"FlashAttention version is {flash_attn_version}.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
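
For reference, a standalone sketch (not part of this commit) of the guard added above. It assumes, as the new import suggests, that get_flash_attn_version() reports the FlashAttention major version (e.g. 2 or 3), or None when FlashAttention is unavailable:

from typing import Optional


def validate_full_cuda_graph(attn_backend_name: str,
                             flash_attn_version: Optional[int]) -> None:
    # Mirrors the check above: full CUDA graph capture is only allowed when
    # the FlashAttention backend is selected and it resolves to FA3.
    if attn_backend_name != "FlashAttentionBackend" or flash_attn_version != 3:
        raise ValueError(
            f"full_cuda_graph is only supported with FA3. "
            f"Current attention backend is {attn_backend_name}, "
            f"FlashAttention version is {flash_attn_version}.")


validate_full_cuda_graph("FlashAttentionBackend", 3)   # passes
# validate_full_cuda_graph("SomeOtherBackend", None)   # would raise ValueError

The effect is to fail fast at model-runner construction rather than attempting an unsupported graph capture later.
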
@@ -219,6 +230,16 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)
+
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
 
@@ -271,7 +292,7 @@ def __init__(
                                          pin_memory=self.pin_memory)
         self.positions_np = self.positions_cpu.numpy()
         self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
+                                            dtype=torch.int64,
                                             device="cpu",
                                             pin_memory=self.pin_memory)
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
@@ -589,10 +610,22 @@ def _prepare_inputs(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
 
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+
+        query_start_loc = self.query_start_loc[:num_reqs + 1]
+        seq_lens = self.seq_lens[:num_reqs]
+
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
 
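
Why copy_() into persistent device tensors instead of the previous .to(self.device) calls: CUDA graph replay reads from fixed memory addresses, so per-step inputs have to land in buffers allocated once in __init__ (the new query_start_loc, seq_lens and slot_mapping tensors) rather than in freshly allocated ones. A minimal, vLLM-independent sketch of the pattern; MAX_TOKENS and stage_step are illustrative names, and a CUDA device is required:

import torch

MAX_TOKENS = 8192

# Allocated once; a captured CUDA graph keeps reading these addresses.
slot_mapping = torch.zeros(MAX_TOKENS, dtype=torch.int64, device="cuda")
slot_mapping_cpu = torch.zeros(MAX_TOKENS, dtype=torch.int64, pin_memory=True)


def stage_step(num_tokens: int) -> None:
    # Async H2D copy into a slice of the persistent buffer (pinned source).
    slot_mapping[:num_tokens].copy_(slot_mapping_cpu[:num_tokens],
                                    non_blocking=True)
    # Pad the unused tail with a sentinel so fixed-size kernels replayed by
    # the graph (e.g. reshape_and_cache) can ignore those entries.
    slot_mapping[num_tokens:].fill_(-1)
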
@@ -1478,6 +1511,7 @@ def _get_prompt_logprobs_dict(
     def _dummy_run(
         self,
         num_tokens: int,
+        skip_attn: bool = True,
     ) -> torch.Tensor:
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
@@ -1494,6 +1528,23 @@ def _dummy_run(
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
 
+        if skip_attn:
+            attn_metadata = None
+        else:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+            seq_lens = self.seq_lens[:num_reqs]
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=query_start_loc, seq_lens=seq_lens)
+
+            attn_metadata = self.attn_metadata_builder.build(
+                num_reqs=num_tokens,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
@@ -1522,7 +1573,7 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })
 
-            with set_forward_context(None,
+            with set_forward_context(attn_metadata,
                                      self.vllm_config,
                                      num_tokens=num_tokens):
                 outputs = model(
@@ -1708,11 +1759,12 @@ def capture_model(self) -> None:
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
+            skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
             for num_tokens in reversed(self.cudagraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                self._dummy_run(num_tokens, skip_attn=skip_attn)
 
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
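
For background, the warmup-then-capture flow that capture_model drives (cudagraph_num_of_warmups dummy runs before the run that is actually captured) follows the standard PyTorch CUDA Graphs recipe. A self-contained sketch, unrelated to vLLM internals and requiring a CUDA device:

import torch

model = torch.nn.Linear(1024, 1024).cuda()
static_x = torch.randn(8, 1024, device="cuda")

# Warmup on a side stream so lazy initialization (cuBLAS handles, allocator
# growth, autotuning) happens outside the capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_y = model(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture one iteration; replay reuses the static input/output addresses.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = model(static_x)

static_x.copy_(torch.randn(8, 1024, device="cuda"))
graph.replay()   # recomputes static_y in place for the new static_x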