@@ -348,13 +348,38 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         torch._logging.set_logs(
             recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)

+        self.check_batch_sizes_consistency()
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False

         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
+
+    def check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes,
+                             device="cpu",
+                             dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size,
+                        group=get_dp_group().cpu_group)
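+        # all_reduce defaults to SUM, so when every DP rank holds the same
+        # torchair_graph_batch_sizes, the summed vector equals local * dp_size.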
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(
+                as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -570,44 +595,58 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()

     def _get_forward_metadata_across_dp(
-            self,
-            maybe_padded_num_tokens: int,
-            num_tokens: int,
-            with_prefill: bool,
-            enable_dbo: bool = False,
+            self, num_tokens: int, with_prefill: bool,
+            enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
+
+        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
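+        # The SUM all_reduce doubles as a logical OR for the two flag slots: any
+        # rank with prefill, or with DBO disabled, leaves a non-zero value there.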
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo

-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
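+        # KV producer ranks (unless chunked MC2 is enabled) skip the DP sync and
+        # report prefill unconditionally, using the local token count for every rank.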
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo

-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
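+        # A KV consumer with a single torchair graph batch size always pads decode
+        # batches to that size, so the cross-DP sync can be skipped as well.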
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
+            num_tokens, with_prefill, enable_dbo)

-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
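+        # In torchair graph mode every DP rank must use the same padded batch size:
+        # take the max token count across ranks and round it up to a torchair size.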
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ),
+                                              maybe_padded_num_tokens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo

     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
                             attn_state: AscendAttentionState,
@@ -1108,16 +1147,13 @@ def _process_reqs(
                                               attn_state,
                                               total_num_scheduled_tokens)

-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
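+        # DBO validity is evaluated on the local batch first; the DP sync below can
+        # still disable it when any other rank cannot run DBO.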
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
-             enable_dbo)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
         extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
-
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens

@@ -1791,16 +1827,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
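+        # _get_forward_metadata_across_dp_and_pad applies the torchair graph padding
+        # itself, so the dummy run passes the raw num_tokens.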
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
-                                                   num_tokens, with_prefill,
-                                                   False)
+         _) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)

         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively