@@ -571,11 +571,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
571571
572572 def _get_forward_metadata_across_dp (
573573 self ,
574- maybe_padded_num_tokens : int ,
575574 num_tokens : int ,
576575 with_prefill : bool ,
577576 enable_dbo : bool = False ,
578577 ) -> tuple [int , Optional [torch .Tensor ], bool , bool ]:
578+ maybe_padded_num_tokens = num_tokens
579+ if self .torchair_graph_enabled and not with_prefill :
580+ maybe_padded_num_tokens = self .select_torchair_padded_batch_size (
581+ num_tokens )
579582 if self .dp_size == 1 :
580583 return maybe_padded_num_tokens , None , with_prefill , enable_dbo
581584
@@ -1108,14 +1111,9 @@ def _process_reqs(
11081111 attn_state ,
11091112 total_num_scheduled_tokens )
11101113
1111- maybe_padded_num_tokens = total_num_scheduled_tokens
1112- if self .torchair_graph_enabled and not with_prefill :
1113- maybe_padded_num_tokens = self .select_torchair_padded_batch_size (
1114- total_num_scheduled_tokens )
11151114 (padded_num_tokens_across_dp , num_tokens_across_dp , with_prefill ,
11161115 enable_dbo ) = self ._get_forward_metadata_across_dp (
1117- maybe_padded_num_tokens , total_num_scheduled_tokens , with_prefill ,
1118- enable_dbo )
1116+ total_num_scheduled_tokens , with_prefill , enable_dbo )
11191117 extra_builder_kwargs ['enable_dbo_across_dp' ] = enable_dbo
11201118
11211119 if self .torchair_graph_enabled and not with_prefill :
@@ -1791,15 +1789,9 @@ def _dummy_run(
17911789 with_prefill : bool = False ,
17921790 is_torchair_compile : bool = False ,
17931791 ) -> torch .Tensor :
1794- maybe_padded_num_tokens = num_tokens
1795- if self .torchair_graph_enabled and not with_prefill :
1796- maybe_padded_num_tokens = self .select_torchair_padded_batch_size (
1797- num_tokens )
1798-
17991792 # Padding for DP
18001793 (num_tokens , num_tokens_across_dp , with_prefill ,
1801- _ ) = self ._get_forward_metadata_across_dp (maybe_padded_num_tokens ,
1802- num_tokens , with_prefill ,
1794+ _ ) = self ._get_forward_metadata_across_dp (num_tokens , with_prefill ,
18031795 False )
18041796
18051797 # Set num_scheduled_tokens based on num_tokens and max_num_seqs
0 commit comments