
Commit 7419186

[Perf] Delete redundant operations in model_runner and forward_context (#3677)
### What this PR does / why we need it?

Remove redundant operations from `model_runner` and `forward_context`. This optimization can significantly reduce the idle time (bubble) before decoding when running models with small parameter counts (e.g., Qwen/Qwen2.5-0.5B).

Testing on 800I A2, the bubble is reduced from 3.8 ms to 2.8 ms.

Before: [profiling timeline](https://github.com/user-attachments/assets/d7608e52-2438-46dd-8fc9-391fd6274495)

After: [profiling timeline](https://github.com/user-attachments/assets/56daf081-2dba-4d2e-99d4-e055187d9806)

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@releases/v0.11.1

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
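At a high level, the patch gates MoE-only work (the `reserved_mc2_mask` buffer and MoE communication-method selection) behind `is_moe_model`, caches a repeated `hasattr` probe in a new `has_layer_idx` helper, and removes a duplicated `input_ids` host-to-device copy in `_prepare_inputs`. A minimal sketch of the gating idea, using simplified names that are not the exact vllm-ascend API:

```python
import torch


def init_moe_buffers(is_moe: bool, mc2_tokens_capacity: int,
                     device: str = "cpu"):
    """Allocate MoE-only buffers only when the model actually has MoE layers.

    Illustrative sketch of the pattern this PR applies; dense models skip the
    allocation (and, later, the comm-method selection) entirely, shrinking the
    host-side bubble before decode.
    """
    if is_moe:
        return torch.zeros(mc2_tokens_capacity, dtype=torch.bool, device=device)
    return None


# A dense model such as Qwen2.5-0.5B takes the cheap path.
assert init_moe_buffers(is_moe=False, mc2_tokens_capacity=512) is None
```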
1 parent 0d1859a commit 7419186

File tree

5 files changed: +34, -25 lines


tests/ut/worker/test_model_runner_v1.py

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,8 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
     with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
                return_value=soc_version), \
          patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
+               return_value=True), \
+         patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
                return_value=True):

        # Bind the real method to the mock object
@@ -102,6 +104,8 @@ def test_select_moe_comm_method_unsupported_soc():
                return_value=unsupported_soc), \
          patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
                return_value=True), \
+         patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
+               return_value=True), \
          pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):

        NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)
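Both tests now also patch `is_moe_model` to return `True`: after this change `_select_moe_comm_method` short-circuits to `None` for non-MoE models, so the patch is required for the selection logic (and the `ValueError` path) to be exercised at all.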

vllm_ascend/ascend_forward_context.py

Lines changed: 3 additions & 4 deletions
@@ -11,7 +11,8 @@
                                  set_forward_context)

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.utils import enable_sp, is_moe_model, version_check
+from vllm_ascend.utils import (enable_sp, has_layer_idx, is_moe_model,
+                               version_check)

 if TYPE_CHECKING:
     from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
@@ -137,9 +138,7 @@ def set_ascend_forward_context(
     # set layer_idx to enable optimization features that depend on this information.
     # This is only applicable to models that contain these necessary attributes.
     forward_context.layer_idx = None
-    if model_instance is not None and \
-            hasattr(model_instance, "model") and \
-            hasattr(model_instance.model, "start_layer"):
+    if has_layer_idx(model_instance):
         forward_context.layer_idx = model_instance.model.start_layer

     # TODO(rjg-lyh): refactor mlp weight prefetch method

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@

 def get_moe_comm_method(
         moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
-    return _MoECommMethods.get(moe_comm_type)
+    return _MoECommMethods.get(moe_comm_type, None)


 def setup_moe_comm_method(moe_config):
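For reference, `dict.get(key)` already returns `None` for a missing key, so spelling out the default does not change behavior; it only makes explicit that the `None` comm type a dense model now produces maps to no MoE comm method. A standalone check (not part of the patch):

```python
# dict.get behaves identically with and without an explicit None default.
methods = {"mc2": "MC2CommImpl"}
assert methods.get(None) is None
assert methods.get(None, None) is None
```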

vllm_ascend/utils.py

Lines changed: 12 additions & 0 deletions
@@ -58,6 +58,7 @@
 _MIN_DP_BUFFER_SIZE = 50
 _IS_MOE_MODEL = None
 _ENABLE_SP = None
+_HAS_LAYER_IDX = None


 def is_310p():
@@ -807,3 +808,14 @@ def version_check():
     if full_date >= "20250919":
         return True
     return False
+
+
+def has_layer_idx(model_instance: torch.nn.Module) -> bool:
+    if model_instance is None:
+        return False
+
+    global _HAS_LAYER_IDX
+    if _HAS_LAYER_IDX is None:
+        _HAS_LAYER_IDX = hasattr(model_instance, "model") and \
+            hasattr(model_instance.model, "start_layer")
+    return _HAS_LAYER_IDX
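Like the existing `_IS_MOE_MODEL` and `_ENABLE_SP` globals, `has_layer_idx` memoizes its answer at module level, so the `hasattr` probes run once per process instead of on every call to `set_ascend_forward_context`. A small usage sketch, assuming a vllm-ascend build that includes this patch and using toy model classes:

```python
from vllm_ascend.utils import has_layer_idx  # assumes this patch is installed


class _ToyInner:
    start_layer = 0


class _ToyModel:
    model = _ToyInner()


# First call runs the hasattr probes and caches the result in _HAS_LAYER_IDX...
print(has_layer_idx(_ToyModel()))  # True
# ...later calls return the cached value without re-probing.
print(has_layer_idx(_ToyModel()))  # True
print(has_layer_idx(None))         # False (None is handled before the cache)
```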

vllm_ascend/worker/model_runner_v1.py

Lines changed: 14 additions & 20 deletions
@@ -136,7 +136,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
                                enable_sp, get_ascend_soc_version, is_310p,
-                               is_enable_nz, lmhead_tp_enable,
+                               is_enable_nz, is_moe_model, lmhead_tp_enable,
                                prefill_context_parallel_enable,
                                vllm_version_is)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -515,11 +515,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.in_profile_run = False

         self._init_mc2_tokens_capacity()
-        self.reserved_mc2_mask = torch.zeros(
-            self.mc2_tokens_capacity,
-            dtype=torch.bool,
-            device=self.device,
-        )
+        if is_moe_model(vllm_config):
+            self.reserved_mc2_mask = torch.zeros(
+                self.mc2_tokens_capacity,
+                dtype=torch.bool,
+                device=self.device,
+            )
+        else:
+            self.reserved_mc2_mask = None
         self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
         if self.dynamic_eplb:
             EPLBParamUtils.check_dynamic_eplb(self.ascend_config.dynamic_eplb)
@@ -1497,9 +1500,7 @@ def _prepare_inputs(
         self.query_lens = torch.from_numpy(num_scheduled_tokens)

         # Copy the tensors to the NPU.
-        self.input_ids[:total_num_scheduled_tokens].copy_(
-            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-
+        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
         self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
         self.positions[:num_input_tokens].copy_(
             self.positions_cpu[:num_input_tokens], non_blocking=True)
@@ -1521,16 +1522,6 @@ def _prepare_inputs(
         self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
         attn_metadata: dict[str, Any] = {}

-        # Prepare input_ids
-        token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
-        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
-                           0,
-                           torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
-        # Copy the tensors to the NPU.
-        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
-
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
@@ -2075,7 +2066,7 @@ def _pool(
         )

     def _select_moe_comm_method(self, num_tokens: int,
-                                with_prefill: bool) -> MoECommType:
+                                with_prefill: bool) -> Optional[MoECommType]:
         """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
         are designed for expert parallelism.
         2. If expert parallel is enabled, we need to consider the soc version and the
@@ -2098,6 +2089,9 @@ def _select_moe_comm_method(self, num_tokens: int,
         Returns:
             MoECommType: The selected MoE communication method.
         """
+        if not is_moe_model(self.vllm_config):
+            return None
+
         soc_version = get_ascend_soc_version()
         quant_type = getattr(self.vllm_config.model_config.hf_config,
                              'moe_quantize', None)
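Taken together, `_select_moe_comm_method` now returns `Optional[MoECommType]`, and the `None` it produces for dense models flows into `get_moe_comm_method`, which maps it to `None` as well. A rough stand-in for the resulting control flow, with simplified signatures rather than the real methods:

```python
from typing import Optional


def select_moe_comm_method(is_moe: bool, soc_version: str) -> Optional[str]:
    """Simplified stand-in for NPUModelRunner._select_moe_comm_method."""
    if not is_moe:
        return None  # dense models skip SoC/quantization inspection entirely
    return "mc2" if soc_version == "A3" else "allgather"


def get_moe_comm_method(moe_comm_type: Optional[str]) -> Optional[str]:
    """Simplified stand-in for the registry lookup in moe_comm_method.py."""
    methods = {"mc2": "MC2CommImpl", "allgather": "AllGatherCommImpl"}
    return methods.get(moe_comm_type, None)


# Dense model: no MoE comm method is selected or configured.
assert get_moe_comm_method(select_moe_comm_method(False, "A2")) is None
```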
