 import numpy.typing as npt
 import torch
 import torch.nn as nn
+from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -59,6 +60,8 @@
 
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.patch.platform.patch_common.patch_distributed import \
+    get_dp_group
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.utils import vllm_version_is
@@ -355,6 +358,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             False) and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = additional_config.get(
             "use_cached_npu_graph", False)
+        self.has_prefilled = False
+        self.dp_group = get_dp_group()
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -659,6 +664,22 @@ def _process_reqs(
                                   device=input_ids.device)
             input_ids = torch.cat([input_ids, padding])
             positions = torch.cat([positions, padding])
+        if self.has_prefilled and not attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            self.has_prefilled = False
+        if not self.has_prefilled and self.enable_torchair_graph_mode:
+            self.has_prefilled = self.has_prefilled_all_rank(
+                attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
+
+        if self.dp_group:
+            while not self.has_prefilled and self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                self._dummy_run(1)
+                tensor = torch.tensor([1], dtype=torch.int32, device="cpu")
+                torch.distributed.all_reduce(tensor,
+                                             op=ReduceOp.MAX,
+                                             group=self.dp_group)
+                self.has_prefilled = self.has_prefilled_all_rank(
+                    attn_metadata.attn_state ==
+                    AscendAttentionState.DecodeOnly)
 
         # Run forward pass
         with set_forward_context(attn_metadata,
@@ -668,7 +689,7 @@ def _process_reqs(
             if self.enable_torchair_graph_mode:
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly and self.has_prefilled:
                 torch._dynamo.mark_static(input_ids)
                 torch._dynamo.mark_static(positions)
                 torch._dynamo.mark_static(attn_metadata.decode.block_table)
@@ -796,6 +817,15 @@ def _calc_spec_decode_metadata(
         )
         return metadata
 
+    def has_prefilled_all_rank(self, has_prefilled: bool) -> bool:
+        tensor = torch.tensor([has_prefilled], dtype=torch.int32, device="cpu")
+        if self.dp_group:
+            torch.distributed.all_reduce(tensor,
+                                         op=ReduceOp.MIN,
+                                         group=self.dp_group)
+        aggregated_has_prefilled = bool(tensor.item())
+        return aggregated_has_prefilled
+
     def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
@@ -1063,7 +1093,11 @@ def _profile_multimodal(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
     @torch.inference_mode()
-    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
+    ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.
@@ -1107,11 +1141,32 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
             })
 
         with set_forward_context(None, self.vllm_config):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds)
+            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
+                attn_metadata = self.attn_metadata_builder.dummy_build(
+                    num_reqs=num_tokens, num_actual_tokens=1)
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    assert isinstance(kv, tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    kv_caches=self.kv_caches,
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
         return hidden_states
 
     def profile_run(self) -> None:
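A minimal, standalone sketch (not part of the commit) of the synchronization idea behind `has_prefilled_all_rank` above: every data-parallel rank contributes a 0/1 flag, and a MIN all-reduce yields 1 only once every rank has finished its prefill phase. The single-process `gloo` group below is an assumption standing in for the real DP group returned by `get_dp_group()`.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp


def has_prefilled_all_rank(local_has_prefilled: bool, group) -> bool:
    # Each rank contributes 0 or 1; MIN over the group is 1 only if
    # every rank reports that it has left the prefill phase.
    flag = torch.tensor([int(local_has_prefilled)], dtype=torch.int32)
    dist.all_reduce(flag, op=ReduceOp.MIN, group=group)
    return bool(flag.item())


if __name__ == "__main__":
    # Single-process "gloo" group so the sketch runs on its own; in the
    # model runner the group would be the CPU data-parallel group.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(has_prefilled_all_rank(True, dist.group.WORLD))   # True
    print(has_prefilled_all_rank(False, dist.group.WORLD))  # False
    dist.destroy_process_group()
```

The MAX all-reduce inside the `while` loop in `_process_reqs` appears to play a complementary role: ranks that are already decode-only keep issuing `_dummy_run(1)` steps plus a throwaway collective on each iteration, so communication stays in lock-step with ranks still prefilling until the MIN reduction reports that all ranks are ready for the torchair decode graph.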