@@ -346,13 +346,35 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         torch._logging.set_logs(
             recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
 
+        self.check_batch_sizes_consistency()
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False
 
         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
+
+    def check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes, device="cpu", dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size, group=get_dp_group().cpu_group)
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )
+
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
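Why the new check uses a SUM `all_reduce` rather than an `all_gather`: if every DP rank holds the same `torchair_graph_batch_sizes`, the element-wise sum across ranks must equal `local * dp_size`, so any divergent entry surfaces as a mismatched index. A minimal sketch of that invariant, emulating two ranks locally without a process group:

```python
import torch

# Simulated per-rank torchair_graph_batch_sizes; rank 1 diverges at index 2.
dp_size = 2
rank_values = [
    torch.tensor([8, 16, 32], dtype=torch.int32),  # rank 0
    torch.tensor([8, 16, 48], dtype=torch.int32),  # rank 1 (misconfigured)
]

# dist.all_reduce with the default SUM op leaves every rank holding the
# element-wise sum; emulate that here without a process group.
gathered = torch.stack(rank_values).sum(dim=0)

# If all ranks held identical lists, the sum would equal local * dp_size.
local = rank_values[0]
expected = local * dp_size
diff_idxs = (gathered != expected).nonzero(as_tuple=False).flatten().tolist()
assert diff_idxs == [2]
```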
@@ -568,44 +590,55 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-            self,
-            maybe_padded_num_tokens: int,
-            num_tokens: int,
-            with_prefill: bool,
-            enable_dbo: bool = False,
+            self,
+            num_tokens: int,
+            with_prefill: bool, enable_dbo: bool
+    ) -> tuple[torch.Tensor, bool, bool]:
+
+        # Pack the all_reduce metadata into one tensor: num_tokens for each rank plus the with_prefill and enable_dbo flags.
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2, dtype=torch.int32, device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int,
+            with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo
 
-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
 
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
 
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(num_tokens, with_prefill, enable_dbo)
+
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size,), maybe_padded_num_tokens, dtype=torch.int32, device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
 
     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
                             attn_state: AscendAttentionState,
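The rewritten helper folds the two boolean flags into the same int32 tensor as the per-rank token counts, so a single CPU `all_reduce` synchronizes everything. With the default SUM op, the `with_prefill` slot reduces as a logical OR (any rank with prefill leaves it nonzero), and storing `not enable_dbo` makes the last slot reduce as a logical AND on `enable_dbo` (any rank that disables DBO disables it everywhere). A sketch of the encoding, emulating the SUM reduce for two ranks:

```python
import torch

def pack(dp_size: int, dp_rank: int, num_tokens: int,
         with_prefill: bool, enable_dbo: bool) -> torch.Tensor:
    # Layout: [tokens_rank0 .. tokens_rank(N-1), with_prefill, not enable_dbo]
    buf = torch.zeros(dp_size + 2, dtype=torch.int32)
    buf[dp_rank] = num_tokens
    buf[-2] = int(with_prefill)
    buf[-1] = int(not enable_dbo)
    return buf

# Emulate the SUM all_reduce for two ranks: rank 0 is decoding,
# rank 1 is prefilling with DBO disabled.
reduced = (pack(2, 0, 4, with_prefill=False, enable_dbo=True) +
           pack(2, 1, 96, with_prefill=True, enable_dbo=False))

num_tokens_across_dp = reduced[:-2]   # tensor([ 4, 96], dtype=torch.int32)
with_prefill = bool(reduced[-2])      # OR across ranks  -> True
enable_dbo = not bool(reduced[-1])    # AND across ranks -> False
assert with_prefill and not enable_dbo
```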
@@ -1106,16 +1139,13 @@ def _process_reqs(
                                               attn_state,
                                               total_num_scheduled_tokens)
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
-             enable_dbo)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
         extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
-
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
 
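Padding for torchair now happens once, after the DP sync, against the per-step maximum: every rank rounds `max(num_tokens_across_dp)` up to a compiled graph batch size, and `graph_pad_size` is the number of filler tokens the local rank adds. `select_torchair_padded_batch_size` is not shown in this diff; the sketch below assumes it returns the smallest configured size that fits, via a hypothetical stand-in `select_padded_batch_size`:

```python
import bisect

# Hypothetical stand-in for select_torchair_padded_batch_size (the real
# method is not shown in this diff); assumed to return the smallest
# configured graph batch size that can hold n tokens.
def select_padded_batch_size(graph_batch_sizes: list[int], n: int) -> int:
    sizes = sorted(graph_batch_sizes)
    return sizes[bisect.bisect_left(sizes, n)]

graph_batch_sizes = [8, 16, 32, 64]
padded = select_padded_batch_size(graph_batch_sizes, 21)  # -> 32
graph_pad_size = padded - 21                              # -> 11 filler tokens
assert (padded, graph_pad_size) == (32, 11)
```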
@@ -1759,16 +1789,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
-                                                   num_tokens, with_prefill,
-                                                   False)
+         _) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
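`_dummy_run` no longer pre-pads before the DP sync; it passes the raw token count and the shared helper hands back the padded value. On a KV-consumer decode node with a single torchair graph batch size, the helper short-circuits to that fixed size with no collective, so warmup and steady-state decode compile exactly one shape. A small sketch of that fast path (values are illustrative, not from the source):

```python
import torch

# Assumed configuration for a KV-consumer decode node: one compiled torchair
# graph batch size and dp_size data-parallel ranks (illustrative values).
torchair_graph_batch_sizes = [64]
dp_size = 4

# The helper returns the fixed size for every rank with with_prefill=False,
# so _dummy_run (and every real decode step) sees the same padded shape.
fixed = torchair_graph_batch_sizes[0]
num_tokens_across_dp = torch.tensor([fixed] * dp_size, dtype=torch.int32)
assert num_tokens_across_dp.tolist() == [64, 64, 64, 64]
```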