Merged
vllm_ascend/worker/model_runner_v1.py (122 changes: 76 additions & 46 deletions)
@@ -348,13 +348,38 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         torch._logging.set_logs(
             recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
 
+        self.check_batch_sizes_consistency()
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False
 
         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
+
+    def check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes,
+                             device="cpu",
+                             dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size,
+                        group=get_dp_group().cpu_group)
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(
+                as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -570,44 +595,58 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-        self,
-        maybe_padded_num_tokens: int,
-        num_tokens: int,
-        with_prefill: bool,
-        enable_dbo: bool = False,
+            self, num_tokens: int, with_prefill: bool,
+            enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
+
+        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+        self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo
 
-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
 
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
+            num_tokens, with_prefill, enable_dbo)
 
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ),
+                                              maybe_padded_num_tokens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
 
     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
                             attn_state: AscendAttentionState,
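Note: the refactored `_get_forward_metadata_across_dp` piggybacks two booleans on the token-count all-reduce. Slot `[-2]` carries `with_prefill` and slot `[-1]` carries `not enable_dbo`, so after the SUM all-reduce a nonzero `[-2]` means some rank has prefill (OR semantics), while a nonzero `[-1]` means some rank vetoed dual-batch overlap, i.e. DBO stays enabled only if every rank agreed (AND semantics). A standalone sketch of the same packing, with illustrative names rather than the PR's API:

```python
# Sketch: pack per-rank token counts plus two flags into one CPU all_reduce.
# dp_rank, dp_size, and group are illustrative parameters.
import torch
import torch.distributed as dist

def exchange_forward_metadata(dp_rank: int, dp_size: int, num_tokens: int,
                              with_prefill: bool, enable_dbo: bool,
                              group=None):
    buf = torch.zeros(dp_size + 2, dtype=torch.int32, device="cpu")
    buf[dp_rank] = num_tokens       # each rank fills only its own slot
    buf[-2] = int(with_prefill)     # SUM acts as OR across ranks
    buf[-1] = int(not enable_dbo)   # SUM of negations: nonzero => not all agree
    dist.all_reduce(buf, group=group)
    return buf[:-2], bool(buf[-2].item()), not bool(buf[-1].item())
```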
@@ -1108,16 +1147,13 @@ def _process_reqs(
                                               attn_state,
                                               total_num_scheduled_tokens)
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
         enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
                                               attn_state,
                                               total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
-             enable_dbo)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
         extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
 
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
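Note: with the padding decision now made inside `_get_forward_metadata_across_dp_and_pad`, `_process_reqs` only needs the residual `graph_pad_size` to stretch its inputs to the captured graph's batch size. An illustrative helper showing how such a pad is typically applied (an assumption; the PR's actual padding code is outside this hunk):

```python
# Illustrative right-padding of flattened token ids up to the padded
# batch size; not the PR's code.
import torch
import torch.nn.functional as F

def pad_for_graph(input_ids: torch.Tensor, graph_pad_size: int) -> torch.Tensor:
    if graph_pad_size <= 0:
        return input_ids
    return F.pad(input_ids, (0, graph_pad_size), value=0)
```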

@@ -1791,16 +1827,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
-                                                   num_tokens, with_prefill,
-                                                   False)
+         _) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
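Note: both call sites (`_process_reqs` and `_dummy_run`) now delegate rounding to `select_torchair_padded_batch_size` inside `_get_forward_metadata_across_dp_and_pad`. Its implementation is not shown in this diff; a plausible sketch (an assumption, not the repo's code) rounds the token count up to the smallest captured graph batch size:

```python
# Assumed behavior of a select-padded-batch-size helper: round up to the
# smallest captured torchair graph batch size; not the repo's code.
def select_padded_batch_size(num_tokens: int,
                             graph_batch_sizes: list[int]) -> int:
    for size in sorted(graph_batch_sizes):
        if num_tokens <= size:
            return size
    raise ValueError(
        f"num_tokens={num_tokens} exceeds the largest captured size")
```

For example, with captured sizes [8, 16, 32], 9 tokens pad to 16, so every DP rank replays the same compiled graph.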