Commit 92e6aa9
[bugfix] add with_prefill cpu allreduce to handle D-node recomputation (#2129)
### What this PR does / why we need it?

Add a with-prefill CPU AllReduce across data-parallel ranks to handle D-node (decode-node) recomputation situations.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

gsm8k
http://image.huawei.com/tiny-lts/v1/images/mdstorm/dcbc43b858db666f185d73868f7933fb_1242x502.png

livecodebench
http://image.huawei.com/tiny-lts/v1/images/mdstorm/78a2e9695c3d841870d02c840f032154_1242x502.png

vllm benchmark
http://image.huawei.com/tiny-lts/v1/images/mdstorm/a4d32f4f2d702cf89854b83ae4d58337_1242x502.png

performance
http://image.huawei.com/tiny-lts/v1/images/mdstorm/38e194a09c3c9ae902a3772f1dca6862_1609x1095.png

Signed-off-by: liziyu <liziyu16@huawei.com>
1 parent: 5c9e7a0
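The heart of the change is agreeing on `with_prefill` across all data-parallel ranks before any rank commits to the torchair decode-only path: a MAX all-reduce over a boolean on a CPU process group acts as a logical OR, so if any rank (for example a D-node recomputing a request) is prefilling, every rank sees it. Below is a minimal runnable sketch of that pattern, assuming a gloo-backed CPU group; the script name and helper function are illustrative, not part of the PR:

```python
# with_prefill_sync.py -- run with: torchrun --nproc_per_node=2 with_prefill_sync.py
# Sketch of the synchronization pattern this commit adds (not the
# vllm-ascend code itself): MAX over booleans == logical OR across ranks.
import torch
import torch.distributed as dist


def sync_with_prefill(with_prefill: bool, cpu_group) -> bool:
    """Return True iff any rank in `cpu_group` currently has prefill work."""
    flag = torch.tensor([with_prefill], device="cpu", dtype=torch.bool)
    dist.all_reduce(flag, group=cpu_group, op=dist.ReduceOp.MAX)
    return bool(flag.item())


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # CPU-capable backend
    rank = dist.get_rank()
    # Pretend only rank 0 has a prefill batch locally (e.g. a recompute).
    local = (rank == 0)
    print(f"rank {rank}: local={local} "
          f"global={sync_with_prefill(local, dist.group.WORLD)}")
    dist.destroy_process_group()
```

In the PR itself the reduce runs over `get_dp_group().cpu_group`, so every DP rank makes the same prefill-vs-decode decision for the step.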

File tree

2 files changed: +14 additions, −19 deletions


vllm_ascend/models/deepseek_v2.py (1 addition, 10 deletions)

```diff
@@ -34,8 +34,7 @@
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_dp_group, get_pp_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -335,10 +334,6 @@ def __init__(
 
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
-        self.kv_consumer = None
-        transfer_config = get_current_vllm_config().kv_transfer_config
-        if transfer_config is not None:
-            self.kv_consumer = transfer_config.kv_role == "kv_consumer"
 
     def forward(
         self,
@@ -353,10 +348,6 @@ def forward(
         enable_force_load_balance = forward_context.in_profile_run
 
         is_prefill = forward_context.with_prefill
-        # If this node is kv_consumer, we force the moe always runs in decode path to make sure
-        # the behaviour aligned between dummy_run and normal model_execute.
-        if self.kv_consumer:
-            is_prefill = False
 
         # router_logits: (num_tokens, n_experts)
         if self.enable_multistream_moe:
```

vllm_ascend/worker/model_runner_v1.py (13 additions, 9 deletions)

```diff
@@ -636,12 +636,19 @@ def _get_forward_metadata_across_dp(
         if self.is_kv_consumer and self.torchair_graph_enabled and len(
                 self.torchair_graph_batch_sizes
         ) == 1 and not self.in_profile_run:
-            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
-            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
-                                                self.dp_size,
-                                                device="cpu",
-                                                dtype=torch.int32)
-            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+            with_prefill_tensor = torch.tensor([with_prefill],
+                                               device="cpu",
+                                               dtype=torch.bool)
+            dist.all_reduce(with_prefill_tensor,
+                            group=get_dp_group().cpu_group,
+                            op=dist.ReduceOp.MAX)
+            if not with_prefill_tensor.item():
+                max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+                num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
+                                                    self.dp_size,
+                                                    device="cpu",
+                                                    dtype=torch.int32)
+                return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
 
         num_tokens_across_dp = [0] * self.dp_size * 2
         num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
```
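For context, here is the gist of the new branch above condensed into a standalone function; this is a sketch under assumed names (`dp_size` and `graph_batch_size` stand in for `self.dp_size` and `self.torchair_graph_batch_sizes[0]`), not the runner method itself:

```python
# Condensed sketch of the gating this hunk adds: take the fixed-size
# decode fast path only when no data-parallel rank is prefilling.
import torch
import torch.distributed as dist


def decode_only_metadata(with_prefill: bool, cpu_group, dp_size: int,
                         graph_batch_size: int):
    flag = torch.tensor([with_prefill], device="cpu", dtype=torch.bool)
    dist.all_reduce(flag, group=cpu_group, op=dist.ReduceOp.MAX)
    if flag.item():
        # Some rank is prefilling (e.g. a D-node recompute): fall through.
        return None
    # All ranks are pure decode: pad every rank to the single captured
    # torchair graph size so the graph can be replayed as-is.
    num_tokens_across_dp = torch.full((dp_size,), graph_batch_size,
                                      device="cpu", dtype=torch.int32)
    return graph_batch_size, num_tokens_across_dp
```

Returning `None` here models falling through to the general DP padding path below, which is what keeps `_dummy_run` and real execution shapes consistent when a D-node has to recompute.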
```diff
@@ -1644,9 +1651,6 @@ def _dummy_run(
         maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
             num_tokens)
 
-        # For kv producer, with prefill always true
-        if self.is_kv_producer:
-            with_prefill = True
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
          enable_dbo) = self._get_forward_metadata_across_dp(
```
