@@ -571,16 +571,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
 
     def _get_forward_metadata_across_dp(
         self,
-        maybe_padded_num_tokens: int,
         num_tokens: int,
         with_prefill: bool,
         enable_dbo: bool = False,
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo
 
         num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
+        num_tokens_across_dp[self.dp_rank] = num_tokens
         num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
         forward_metadata = torch.tensor(num_tokens_across_dp +
                                         [with_prefill, not enable_dbo],
@@ -589,24 +588,13 @@ def _get_forward_metadata_across_dp(
         dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
         with_prefill = bool(forward_metadata[-2])
 
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
         if with_prefill:
             num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
                                                     2]
-            maybe_padded_num_tokens = num_tokens
         else:
             num_tokens_across_dp = forward_metadata[:self.dp_size]
 
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
-                                                self.dp_size,
-                                                device="cpu",
-                                                dtype=torch.int32)
-
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
+        return num_tokens, num_tokens_across_dp, with_prefill, not bool(
             forward_metadata[-1])
 
     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
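
A note for readers tracing the exchange above: each rank packs a one-hot slice of token counts plus two flags into one CPU tensor, and the SUM all-reduce merges the slices, so `bool(forward_metadata[-2])` turns `with_prefill` into an OR across ranks, while storing `not enable_dbo` turns DBO into an AND (any rank that opts out disables it everywhere). A minimal sketch of those semantics, simulating `dist.all_reduce` with a plain sum over made-up rank values:

```python
import torch

dp_size = 4
# Hypothetical per-rank inputs: local token counts and flags.
ranks = [
    dict(num_tokens=7, with_prefill=False, enable_dbo=True),
    dict(num_tokens=3, with_prefill=True, enable_dbo=True),
    dict(num_tokens=5, with_prefill=False, enable_dbo=False),
    dict(num_tokens=9, with_prefill=False, enable_dbo=True),
]

# Each rank packs [one-hot token counts | one-hot token counts |
# with_prefill | not enable_dbo]; a SUM all-reduce merges the slices.
reduced = torch.zeros(2 * dp_size + 2, dtype=torch.int32)
for dp_rank, r in enumerate(ranks):
    local = [0] * dp_size * 2
    local[dp_rank] = r["num_tokens"]
    local[dp_size + dp_rank] = r["num_tokens"]
    reduced += torch.tensor(local + [r["with_prefill"], not r["enable_dbo"]],
                            dtype=torch.int32)

with_prefill = bool(reduced[-2])    # OR: true if any rank has prefill
enable_dbo = not bool(reduced[-1])  # AND: every rank must enable DBO
# After this commit both halves hold the same raw counts.
num_tokens_across_dp = reduced[:dp_size]
print(with_prefill, enable_dbo, num_tokens_across_dp.tolist())
# -> True False [7, 3, 5, 9]
```

Packing the counts and flags into a single tensor keeps the synchronization to one CPU all-reduce per step rather than a separate collective per field.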
@@ -1108,14 +1096,9 @@ def _process_reqs(
                                               attn_state,
                                               total_num_scheduled_tokens)
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
          enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
-             enable_dbo)
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
         extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
 
         if self.torchair_graph_enabled and not with_prefill:
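
The lines removed here called `select_torchair_padded_batch_size` before the DP exchange; this commit drops that pre-padding at the call site. For context, a hypothetical sketch of what such a bucket-padding helper does; the standalone name, signature, and bucket list below are assumptions for illustration, not the actual implementation:

```python
from typing import Sequence

def select_padded_batch_size(num_tokens: int,
                             graph_batch_sizes: Sequence[int]) -> int:
    """Round num_tokens up to the smallest pre-captured graph batch size."""
    for size in sorted(graph_batch_sizes):
        if num_tokens <= size:
            return size
    raise ValueError(f"{num_tokens} exceeds the largest captured batch size")

# Hypothetical buckets: a batch of 5 tokens pads up to the 8-token graph.
assert select_padded_batch_size(5, (1, 2, 4, 8, 16)) == 8
```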
@@ -1791,15 +1774,9 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
-                                                   num_tokens, with_prefill,
+         _) = self._get_forward_metadata_across_dp(num_tokens, with_prefill,
                                                    False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
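
Net effect on both call sites (`_process_reqs` and `_dummy_run`): callers now pass only the raw token count. A self-contained sketch of the new signature, using a stub runner to exercise the `dp_size == 1` fast path from the first hunk (the stub is illustrative, not the real class):

```python
import torch
from typing import Optional

class StubRunner:
    dp_size = 1  # single DP rank: the method short-circuits

    def _get_forward_metadata_across_dp(
        self,
        num_tokens: int,
        with_prefill: bool,
        enable_dbo: bool = False,
    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
        if self.dp_size == 1:
            return num_tokens, None, with_prefill, enable_dbo
        raise NotImplementedError  # multi-rank path needs torch.distributed

# New call shape: the raw token count is the only size argument.
(num_tokens, num_tokens_across_dp, with_prefill,
 _) = StubRunner()._get_forward_metadata_across_dp(128, False, False)
assert (num_tokens, num_tokens_across_dp, with_prefill) == (128, None, False)
```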