@@ -337,6 +337,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         torch._logging.set_logs(
             recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)

+        # NOTE: make sure `torchair_graph_batch_sizes` is consistent across all DP ranks
+        self.check_batch_sizes_consistency()
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False

@@ -345,6 +347,33 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer

+    def _check_dbo_is_valid(self, query_lens, attn_state,
+                            num_tokens) -> bool:
+        """Reserved DBO validation interface, not actually implemented yet."""
+        # TODO: add real DBO validation later
+        return False
+
+    def check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes,
+                             device="cpu",
+                             dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size,
+                        group=get_dp_group().cpu_group)
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(
+                as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
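A note on the sum-based consistency check added above: when every DP rank holds an identical `torchair_graph_batch_sizes`, an element-wise sum `all_reduce` returns exactly the local vector multiplied by `dp_size`, so any deviation pinpoints a mismatch. A minimal single-process sketch of that invariant (the helper name and per-rank lists are hypothetical; the collective sum is emulated with a plain loop):

import torch

def batch_sizes_consistent(batch_sizes_per_rank: list[list[int]]) -> bool:
    # Emulate dist.all_reduce(SUM): element-wise sum over every rank's sizes.
    dp_size = len(batch_sizes_per_rank)
    summed = torch.stack([
        torch.tensor(sizes, dtype=torch.int32) for sizes in batch_sizes_per_rank
    ]).sum(dim=0)
    # Rank 0's view of the check: the sum must equal local * dp_size.
    local = torch.tensor(batch_sizes_per_rank[0], dtype=torch.int32)
    return torch.equal(summed, local * dp_size)

assert batch_sizes_consistent([[8, 16, 32], [8, 16, 32]])      # all ranks agree
assert not batch_sizes_consistent([[8, 16, 32], [8, 16, 64]])  # rank 1 differs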
@@ -559,44 +588,58 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()

     def _get_forward_metadata_across_dp(
-            self,
-            maybe_padded_num_tokens: int,
-            num_tokens: int,
-            with_prefill: bool,
-            enable_dbo: bool = False,
+            self, num_tokens: int, with_prefill: bool,
+            enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
+
+        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo

-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo

-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo

-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
+            num_tokens, with_prefill, enable_dbo)
+
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ),
+                                              maybe_padded_num_tokens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo

     def get_eagle_atten_dict(
         self,
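For clarity on the packed metadata used in `_get_forward_metadata_across_dp` above: a single length `dp_size + 2` int32 CPU tensor carries each rank's token count plus two flags, so one sum `all_reduce` answers everything at once. Under a sum, `with_prefill` acts as an OR across ranks (any rank with prefill switches it on), while `enable_dbo`, stored inverted, acts as an AND (it stays on only if no rank disabled it). A single-process sketch of the flag encoding (the function name and per-rank tuples are hypothetical; the collective is replaced by an in-process sum):

import torch

def combine_forward_metadata(per_rank: list[tuple[int, bool, bool]]):
    # per_rank holds (num_tokens, with_prefill, enable_dbo) for each DP rank.
    dp_size = len(per_rank)
    summed = torch.zeros(dp_size + 2, dtype=torch.int32)
    for rank, (num_tokens, with_prefill, enable_dbo) in enumerate(per_rank):
        contribution = torch.zeros(dp_size + 2, dtype=torch.int32)
        contribution[rank] = num_tokens
        contribution[-2] = int(with_prefill)    # OR across ranks via sum
        contribution[-1] = int(not enable_dbo)  # AND across ranks via inverted sum
        summed += contribution                  # stands in for dist.all_reduce(SUM)
    return summed[:-2], bool(summed[-2]), not bool(summed[-1])

tokens, with_prefill, enable_dbo = combine_forward_metadata([(7, False, True),
                                                             (128, True, False)])
# tokens -> tensor([7, 128]); with_prefill -> True; enable_dbo -> False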
@@ -1073,13 +1116,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]

-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)

         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
@@ -1712,15 +1754,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, num_tokens, with_prefill, False)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)

         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
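Both call sites now leave padding to `_get_forward_metadata_across_dp_and_pad`, which in the decode-only torchair graph path rounds the DP-wide maximum token count up to a captured graph batch size via `select_torchair_padded_batch_size`. That helper is defined elsewhere in this file and may differ from the following sketch, which only illustrates the intended rounding behaviour under the assumption that the captured sizes are the entries of `torchair_graph_batch_sizes`:

def select_padded_batch_size_sketch(num_tokens: int,
                                    graph_batch_sizes: list[int]) -> int:
    # Round up to the smallest captured graph batch size that can hold the tokens.
    for candidate in sorted(graph_batch_sizes):
        if candidate >= num_tokens:
            return candidate
    # Nothing large enough was captured; fall back to the largest size.
    return max(graph_batch_sizes)

assert select_padded_batch_size_sketch(9, [4, 8, 16, 32]) == 16
assert select_padded_batch_size_sketch(32, [4, 8, 16, 32]) == 32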