@@ -569,45 +569,55 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-        self,
-        maybe_padded_num_tokens: int,
-        num_tokens: int,
-        with_prefill: bool,
-        enable_dbo: bool = False,
+        self, num_tokens: int,
+        with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
-        if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
-
-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
 
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+        # Compose the all_reduce metadata: each rank's num_tokens plus the with_prefill and enable_dbo flags.
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2, dtype=torch.int32, device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+        self, num_tokens: int,
+        with_prefill: bool, enable_dbo: bool
+    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        if self.dp_size == 1:
+            return num_tokens, None, with_prefill, enable_dbo
+
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
+
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(num_tokens, with_prefill, enable_dbo)
 
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
-
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ), maybe_padded_num_tokens, dtype=torch.int32, device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+
     def get_eagle_atten_dict(
         self,
         scheduler_output: "SchedulerOutput",
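The refactored `_get_forward_metadata_across_dp` packs each rank's token count and both flags into one int32 CPU tensor and relies on the default SUM `all_reduce`: `with_prefill` effectively becomes a logical OR across ranks, while `enable_dbo` (stored inverted) survives only if every rank enabled it. A minimal local simulation of that reduction semantics, with made-up per-rank values and no distributed group involved (illustration only, not part of the change):

```python
import torch

# Stand-in for the per-rank tensors that are all_reduced (SUM) over the DP
# group's cpu_group in _get_forward_metadata_across_dp. Values are made up.
dp_size = 4
per_rank = [  # (num_tokens, with_prefill, enable_dbo) reported by each rank
    (8, False, True),
    (512, True, True),
    (16, False, False),
    (8, False, True),
]

reduced = torch.zeros(dp_size + 2, dtype=torch.int32)
for rank, (num_tokens, with_prefill, enable_dbo) in enumerate(per_rank):
    contribution = torch.zeros(dp_size + 2, dtype=torch.int32)
    contribution[rank] = num_tokens          # each rank owns one slot
    contribution[-2] = int(with_prefill)     # summed -> logical OR across ranks
    contribution[-1] = int(not enable_dbo)   # summed -> "any rank disabled DBO"
    reduced += contribution                  # local stand-in for the SUM all_reduce

with_prefill = bool(reduced[-2])    # True if any rank has a prefill
enable_dbo = not bool(reduced[-1])  # True only if every rank enabled DBO
num_tokens_across_dp = reduced[:-2]

print(num_tokens_across_dp.tolist(), with_prefill, enable_dbo)
# -> [8, 512, 16, 8] True False
```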
@@ -1083,13 +1093,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
 
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
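When the torchair graph is enabled and every rank is decode-only, the padded size comes from `select_torchair_padded_batch_size`, which rounds the maximum token count across DP ranks up to one of the pre-compiled graph batch sizes so all ranks replay the same captured graph. That helper is not shown in full in this diff; a plausible free-standing sketch of the rounding, assuming `torchair_graph_batch_sizes` holds the compiled bucket sizes, would be:

```python
def round_up_to_graph_bucket(batch_size: int, graph_batch_sizes: list[int]) -> int:
    """Pick the smallest pre-compiled bucket that fits batch_size.

    Hypothetical stand-in for the runner's select_torchair_padded_batch_size;
    the real method works on self.torchair_graph_batch_sizes.
    """
    for candidate in sorted(graph_batch_sizes):
        if candidate >= batch_size:
            return candidate
    # Nothing fits: fall back to the largest compiled bucket.
    return max(graph_batch_sizes)


# Buckets captured for 8/16/32 tokens; 13 scheduled tokens pad up to 16.
assert round_up_to_graph_bucket(13, [8, 16, 32]) == 16
```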
@@ -1722,15 +1731,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, num_tokens, with_prefill, False)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -2590,7 +2594,7 @@ def select_torchair_padded_batch_size(self, batch_size: int):
         return selected_batch_size
 
     def get_supported_pooling_tasks(self):
-        model = self.get_model()
+        model = self.get_model()
         if not is_pooling_model(model):
             return []
 