import numpy.typing as npt
import torch
import torch.nn as nn
+from torch.distributed import ReduceOp
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig

from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.patch.platform.patch_common.patch_distributed import \
+    get_dp_group
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.utils import vllm_version_is
@@ -328,6 +331,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
            False) and self.vllm_config.model_config.use_mla
        self.use_cached_npu_graph = additional_config.get(
            "use_cached_npu_graph", False)
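+        # Whether this rank has completed a prefill step; used together with
+        # the data-parallel group below to keep torchair graph execution in
+        # sync across DP ranks.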
+        self.has_prefilled = False
+        self.dp_group = get_dp_group()

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
@@ -635,6 +640,22 @@ def _process_reqs(
                                  device=input_ids.device)
            input_ids = torch.cat([input_ids, padding])
            positions = torch.cat([positions, padding])
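+        # Reset the prefill flag as soon as this rank sees a non-decode batch,
+        # then re-evaluate it across all DP ranks when graph mode is enabled.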
+        if self.has_prefilled and not attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            self.has_prefilled = False
+        if not self.has_prefilled and self.enable_torchair_graph_mode:
+            self.has_prefilled = self.has_prefilled_all_rank(
+                attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
+
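+        # A decode-only rank in graph mode spins on dummy decode steps until
+        # every DP rank reports that it has finished prefill, so collective
+        # communication stays aligned across ranks.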
+        if self.dp_group:
+            while not self.has_prefilled and self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                self._dummy_run(1)
+                tensor = torch.tensor([1], dtype=torch.int32, device="cpu")
+                torch.distributed.all_reduce(tensor,
+                                             op=ReduceOp.MAX,
+                                             group=self.dp_group)
+                self.has_prefilled = self.has_prefilled_all_rank(
+                    attn_metadata.attn_state ==
+                    AscendAttentionState.DecodeOnly)

        # Run forward pass
        with set_forward_context(attn_metadata,
@@ -644,7 +665,7 @@ def _process_reqs(
            if self.enable_torchair_graph_mode:
                model_kwargs["kv_caches"] = self.kv_caches
                model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly and self.has_prefilled:
                torch._dynamo.mark_static(input_ids)
                torch._dynamo.mark_static(positions)
                torch._dynamo.mark_static(attn_metadata.decode.block_table)
@@ -772,6 +793,15 @@ def _calc_spec_decode_metadata(
        )
        return metadata

+    def has_prefilled_all_rank(self, has_prefilled: bool) -> bool:
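+        # All-reduce with MIN across the DP group: the result is True only if
+        # every rank has completed its prefill step.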
+        tensor = torch.tensor([has_prefilled], dtype=torch.int32, device="cpu")
+        if self.dp_group:
+            torch.distributed.all_reduce(tensor,
+                                         op=ReduceOp.MIN,
+                                         group=self.dp_group)
+        aggregated_has_prefilled = bool(tensor.item())
+        return aggregated_has_prefilled
+
    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
@@ -1039,7 +1069,11 @@ def _profile_multimodal(self) -> None:
        self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

    @torch.inference_mode()
-    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
+    ) -> torch.Tensor:
        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
@@ -1083,11 +1117,34 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
            })

        with set_forward_context(None, self.vllm_config):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds)
+            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
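+                # Decode-only dummy runs go through the compiled torchair
+                # graph with statically marked inputs, so the traced graph
+                # matches the one used for real decode steps.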
+                attn_metadata = self.attn_metadata_builder.dummy_build(
+                    num_reqs=num_tokens, num_actual_tokens=1)
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(
+                    attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    assert isinstance(kv,
+                                      tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    kv_caches=self.kv_caches,
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
        return hidden_states

    def profile_run(self) -> None: