@@ -559,45 +559,56 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()

     def _get_forward_metadata_across_dp(
-        self,
-        maybe_padded_num_tokens: int,
-        num_tokens: int,
-        with_prefill: bool,
-        enable_dbo: bool = False,
+        self,
+        num_tokens: int,
+        with_prefill: bool,
+        enable_dbo: bool,
+    ) -> tuple[Optional[torch.Tensor], bool, bool]:
+
+        # Pack the all_reduce metadata: one num_tokens slot per rank, then
+        # the with_prefill and negated enable_dbo flags in the last two.
+        # After the SUM all_reduce, the with_prefill slot behaves as a
+        # logical OR across ranks and the negated enable_dbo slot as a
+        # logical AND: a single rank can force prefill on or DBO off.
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+        self,
+        num_tokens: int,
+        with_prefill: bool,
+        enable_dbo: bool,
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
-
-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before the all_reduce and true
-        # after it, we need to revert the padding.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size * 2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
-
-        # NOTE: in torchair graph mode we need to pad local_num_tokens to
-        # `max_tokens_across_dp`; in other situations it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+            return num_tokens, None, with_prefill, enable_dbo
+
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
+
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+
+        maybe_padded_num_tokens = num_tokens
+        (num_tokens_across_dp, with_prefill,
+         enable_dbo) = self._get_forward_metadata_across_dp(
+             num_tokens, with_prefill, enable_dbo)

-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
-
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ),
+                                              maybe_padded_num_tokens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+
     def get_eagle_atten_dict(
         self,
         scheduler_output: "SchedulerOutput",
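The refactored `_get_forward_metadata_across_dp` folds three pieces of metadata into a single SUM all_reduce. Because `with_prefill` is packed as a plain 0/1 and `enable_dbo` is packed negated, the sum acts as a logical OR for the former and a logical AND for the latter. A minimal single-process sketch of that semantics; the per-rank values below are made up for illustration:

```python
import torch

dp_size = 4
per_rank = [
    # (num_tokens, with_prefill, enable_dbo) per DP rank, made-up values
    (128, False, True),
    (96, True, True),    # one rank still has a prefill in flight
    (64, False, False),  # one rank cannot run dual-batch overlap
    (128, False, True),
]

# Each rank builds this tensor locally; here we build all four at once.
contribs = []
for rank, (num_tokens, with_prefill, enable_dbo) in enumerate(per_rank):
    t = torch.zeros(dp_size + 2, dtype=torch.int32)
    t[rank] = num_tokens
    t[-2] = int(with_prefill)
    t[-1] = int(not enable_dbo)
    contribs.append(t)

# dist.all_reduce defaults to SUM, which this elementwise sum reproduces.
reduced = torch.stack(contribs).sum(dim=0)

num_tokens_across_dp = reduced[:-2]  # tensor([128,  96,  64, 128])
with_prefill = bool(reduced[-2])     # True: at least one rank set the flag
enable_dbo = not bool(reduced[-1])   # False: one rank opted out
print(num_tokens_across_dp.tolist(), with_prefill, enable_dbo)
```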
@@ -1073,13 +1084,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]

-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)

         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
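When torchair graph mode is enabled and every rank is decode-only, all ranks pad to one shared captured batch size so the graph can be replayed in lockstep, and `graph_pad_size` is the number of dummy tokens the local rank appends. A hedged sketch of that step, assuming `select_torchair_padded_batch_size` picks the smallest captured size that fits; the bucket list here is illustrative, not the real configuration:

```python
import torch

# Assumed captured graph sizes; the real list comes from the runner config.
torchair_graph_batch_sizes = [8, 16, 32, 64]

def select_torchair_padded_batch_size(num_tokens: int) -> int:
    # Smallest captured batch size that can hold num_tokens.
    for size in torchair_graph_batch_sizes:
        if num_tokens <= size:
            return size
    raise ValueError(f"{num_tokens} exceeds the largest captured size")

dp_size = 4
num_tokens_across_dp = torch.tensor([9, 3, 17, 5], dtype=torch.int32)

# Every rank pads to the same bucket, derived from the global maximum.
max_num_token = int(num_tokens_across_dp.max().item())
padded_num_tokens_across_dp = select_torchair_padded_batch_size(max_num_token)
num_tokens_across_dp = torch.full((dp_size, ),
                                  padded_num_tokens_across_dp,
                                  dtype=torch.int32)

# On the rank that scheduled 9 tokens, 23 dummy tokens are appended.
total_num_scheduled_tokens = 9
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
print(padded_num_tokens_across_dp, graph_pad_size)  # 32 23
```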
@@ -1712,15 +1722,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, num_tokens, with_prefill, False)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)

         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
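`_dummy_run` still goes through `_get_forward_metadata_across_dp_and_pad` because the CPU all_reduce is a collective: a DP rank with no real work that skipped it would leave the busy ranks blocked. A minimal two-rank demonstration with the gloo backend; the port and the rank payloads are illustrative:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29511",
                            rank=rank,
                            world_size=world_size)
    # Rank 1 has no real work: this is its "dummy run" contribution.
    num_tokens = 64 if rank == 0 else 0
    t = torch.zeros(world_size + 2, dtype=torch.int32)
    t[rank] = num_tokens
    dist.all_reduce(t)  # blocks forever unless every rank calls it
    if rank == 0:
        print("num_tokens_across_dp:", t[:-2].tolist())
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2, ), nprocs=2)
```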