Commit 20d1afc

Merge branch 'main' into dev_multistream_overlap
2 parents 9b50520 + c94afd7 commit 20d1afc

File tree: 12 files changed (+459 -166 lines)

docs/source/developer_guide/evaluation/index.md

Lines changed: 1 addition & 0 deletions
@@ -12,4 +12,5 @@ using_evalscope
 :caption: Performance
 :maxdepth: 1
 performance_benchmark
+profile_execute_duration
 :::
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+# Profile Execute Duration
+
+The execution duration of each stage (including pre/post-processing, model forward, etc.) usually needs to be captured during a complete inference process. Typically this is done by calling `torch.npu.synchronize()` and taking CPU timestamps, which adds host/device synchronization overhead.
+
+**To reduce this overhead, this feature uses the NPU event timestamp mechanism to observe device execution time asynchronously.**
+
+## Usage
+* Set the environment variable `VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE` to enable this feature.
+* Use the non-blocking API `ProfileExecuteDuration().capture_async` to set observation points asynchronously wherever you need to observe an execution duration.
+* Use the blocking API `ProfileExecuteDuration().pop_captured_sync` at an appropriate time to retrieve and print the execution durations of all observed stages.
+
+## Example Output
+
+```
+5691:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.17ms [prepare input and forward]:9.57ms [forward]:4.14ms
+5695:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.29ms [prepare input and forward]:10.19ms [forward]:4.14ms
+5697:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.81ms [prepare input and forward]:10.29ms [forward]:3.99ms
+5701:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.10ms [prepare input and forward]:10.62ms [forward]:4.33ms
+5705:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.65ms [prepare input and forward]:9.58ms [forward]:4.20ms
+5709:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.43ms [prepare input and forward]:9.88ms [forward]:4.20ms
+5711:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.89ms [prepare input and forward]:10.49ms [forward]:4.19ms
+5715:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.14ms [prepare input and forward]:11.21ms [forward]:4.18ms
+5719:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.71ms [prepare input and forward]:10.15ms [forward]:4.42ms
+5723:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.31ms [forward]:4.25ms
+5725:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.12ms [prepare input and forward]:10.33ms [forward]:4.24ms
+5729:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.58ms [prepare input and forward]:10.85ms [forward]:4.32ms
+5733:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.32ms [prepare input and forward]:9.79ms [forward]:4.28ms
+5737:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:15.06ms [prepare input and forward]:9.89ms [forward]:4.32ms
+5739:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.48ms [forward]:4.27ms
+5743:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.60ms [prepare input and forward]:10.71ms [forward]:4.61ms
+5747:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.21ms [prepare input and forward]:10.10ms [forward]:4.52ms
+5751:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:15.03ms [prepare input and forward]:10.00ms [forward]:4.42ms
+
+```
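For reference, here is a minimal usage sketch of the two APIs documented above. The stage tags and the `preprocess`/`model_forward`/`postprocess` helpers are illustrative placeholders, not part of vllm-ascend; only the environment variable and the two `ProfileExecuteDuration` calls come from the documentation and the test added in this commit.

```python
import os

from vllm_ascend.utils import ProfileExecuteDuration

# Enable the feature before the worker runs (can also be exported in the shell).
os.environ["VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE"] = "1"


def run_one_step(model_input):
    # Non-blocking: each context manager records NPU events around its stage.
    with ProfileExecuteDuration().capture_async("prepare input"):
        inputs = preprocess(model_input)        # placeholder helper
    with ProfileExecuteDuration().capture_async("forward"):
        outputs = model_forward(inputs)         # placeholder helper
    with ProfileExecuteDuration().capture_async("post process"):
        results = postprocess(outputs)          # placeholder helper

    # Blocking: resolves the recorded events and returns {tag: duration_in_ms}.
    durations = ProfileExecuteDuration().pop_captured_sync()
    print(" ".join(f"[{tag}]:{ms:.2f}ms" for tag, ms in durations.items()))
    return results
```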

tests/singlecard/test_ascend_config.py

Lines changed: 1 addition & 0 deletions
@@ -114,5 +114,6 @@ def test_ascend_config_load_error():
             },
         }
         with VllmRunner("facebook/opt-125m",
+                        enforce_eager=False,
                         additional_config=input_additional_config_fake_2):
             pass
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import time
+from unittest.mock import patch
+
+import torch
+import vllm  # noqa: F401
+
+from vllm_ascend.utils import ProfileExecuteDuration
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": "1"})
+def test_execue_duration_enabled_discrepancy():
+    a = torch.randn(10000, 10000).npu()
+    b = torch.randn(10000, 10000).npu()
+
+    # warmup
+    torch.matmul(a, b)
+    torch.npu.synchronize()
+
+    cpu_start = time.perf_counter()
+    with ProfileExecuteDuration().capture_async("forward"):
+        torch.matmul(a, b)
+        torch.npu.synchronize()
+    cpu_duration = (time.perf_counter() - cpu_start) * 1000
+    npu_durations = ProfileExecuteDuration().pop_captured_sync()
+    assert npu_durations and 'forward' in npu_durations
+    assert not ProfileExecuteDuration._observations
+
+    # Assert discrepancy between CPU and NPU duration is within 50% roughly
+    diff = abs(cpu_duration - npu_durations['forward']) / max(
+        cpu_duration, npu_durations['forward'])
+    assert diff <= 0.5, (
+        f"CPU={cpu_duration:.2f}ms, NPU={npu_durations['forward']:.2f}ms")
+
+
+def test_execue_duration_disabled():
+    a = torch.randn(100, 100).npu()
+    b = torch.randn(100, 100).npu()
+
+    with ProfileExecuteDuration().capture_async("forward"):
+        torch.matmul(a, b)
+        torch.npu.synchronize()
+    npu_durations = ProfileExecuteDuration().pop_captured_sync()
+    assert not npu_durations

vllm_ascend/ascend_config.py

Lines changed: 3 additions & 1 deletion
@@ -53,6 +53,8 @@ def __init__(self, torchair_graph_config):
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
+        self.enable_multistream_shared_expert = torchair_graph_config.get(
+            "enable_multistream_shared_expert", False)

         if not isinstance(self.graph_batch_sizes, list):
             raise TypeError("graph_batch_sizes must be list[int]")
@@ -105,7 +107,7 @@ def check_ascend_config(vllm_config, enforce_eager):
     ascend_config = get_ascend_config()

     # Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode.
-    if ascend_config.torchair_graph_config.enabled and not enforce_eager:
+    if ascend_config.torchair_graph_config.enabled and enforce_eager:
         raise RuntimeError(
             "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
         )
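As a usage note, the new option is read from `torchair_graph_config`, so it would be switched on through `additional_config` in the same way as the other torchair options. The sketch below is a hypothetical configuration example: the model name is a placeholder, and only the `enabled` and `enable_multistream_shared_expert` keys are confirmed by this diff.

```python
from vllm import LLM

# Hypothetical example: enable torchair graph mode together with the new
# multistream shared-expert path. Graph mode requires enforce_eager=False,
# matching the check in check_ascend_config above.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model name
    enforce_eager=False,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_shared_expert": True,
        },
    },
)
```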

vllm_ascend/envs.py

Lines changed: 45 additions & 8 deletions
@@ -27,47 +27,84 @@
 # begin-env-vars-definition

 env_variables: Dict[str, Callable[[], Any]] = {
-    # max compile thread num
+    # max compile thread number for package building. Usually, it is set to
+    # the number of CPU cores. If not set, the default value is None, which
+    # means all CPU cores will be used.
     "MAX_JOBS":
     lambda: os.getenv("MAX_JOBS", None),
+    # The build type of the package. It can be one of the following values:
+    # Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
     "CMAKE_BUILD_TYPE":
     lambda: os.getenv("CMAKE_BUILD_TYPE"),
+    # Whether to compile custom kernels. If not set, the default value is True.
+    # If set to False, the custom kernels will not be compiled. Please note that
+    # the sleep mode feature will be disabled as well if custom kernels are not
+    # compiled.
     "COMPILE_CUSTOM_KERNELS":
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
+    # The CXX compiler used for compiling the package. If not set, the default
+    # value is None, which means the system default CXX compiler will be used.
+    "CXX_COMPILER":
+    lambda: os.getenv("CXX_COMPILER", None),
+    # The C compiler used for compiling the package. If not set, the default
+    # value is None, which means the system default C compiler will be used.
+    "C_COMPILER":
+    lambda: os.getenv("C_COMPILER", None),
+    # Whether to enable MC2 for DeepSeek. If not set, the default value is False.
+    # MC2 is a fusion operator provided by Ascend to speed up computing and communication.
+    # Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    # Whether to enable the topk optimization. It's disabled by default because
+    # the support is experimental. We'll enable it by default in the future.
     "VLLM_ASCEND_ENABLE_TOPK_OPTIMZE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMZE", '0'))),
+    # Whether to use LCCL communication. If not set, the default value is False.
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
+    # The version of the Ascend chip. If not set, the default value is
+    # ASCEND910B1. It's used for package building. Please make sure that the
+    # version is correct.
     "SOC_VERSION":
     lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
     # If set, vllm-ascend will print verbose logs during compilation
     "VERBOSE":
     lambda: bool(int(os.getenv('VERBOSE', '0'))),
+    # The home path for the CANN toolkit. If not set, the default value is
+    # /usr/local/Ascend/ascend-toolkit/latest
     "ASCEND_HOME_PATH":
     lambda: os.getenv("ASCEND_HOME_PATH", None),
-    "LD_LIBRARY_PATH":
-    lambda: os.getenv("LD_LIBRARY_PATH", None),
-    # Used for disaggregated prefilling
+    # The path for the HCCN tool, which is called in the disaggregated
+    # prefilling case.
     "HCCN_PATH":
     lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
+    # The path for the HCCL library, used by the pyhccl communicator backend.
+    # If not set, the default value is libhccl.so.
     "HCCL_SO_PATH":
     lambda: os.environ.get("HCCL_SO_PATH", None),
+    # The prefill device id for the disaggregated prefilling case.
     "PROMPT_DEVICE_ID":
     lambda: os.getenv("PROMPT_DEVICE_ID", None),
+    # The decode device id for the disaggregated prefilling case.
     "DECODE_DEVICE_ID":
     lambda: os.getenv("DECODE_DEVICE_ID", None),
+    # The port number for llmdatadist communication. If not set, the default
+    # value is 26000.
     "LLMDATADIST_COMM_PORT":
     lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
+    # The wait time for llmdatadist sync cache. If not set, the default value is
+    # 5000ms.
     "LLMDATADIST_SYNC_CACHE_WAIT_TIME":
     lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"),
-    "CXX_COMPILER":
-    lambda: os.getenv("CXX_COMPILER", None),
-    "C_COMPILER":
-    lambda: os.getenv("C_COMPILER", None),
+    # The version of vllm that is installed. This value is used by developers
+    # who install vllm from source locally, in which case the reported vllm
+    # version is usually different. For example, if vllm "0.9.0" is installed
+    # from source, its version may be reported as "0.9.1"; developers then need
+    # to set this value to "0.9.0" to make sure the correct version is matched.
     "VLLM_VERSION":
     lambda: os.getenv("VLLM_VERSION", None),
+    # Whether to enable tracing of recompiles from pytorch.
     "VLLM_ASCEND_TRACE_RECOMPILES":
     lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
     "VLLM_ASCEND_ENABLE_DBO":

vllm_ascend/models/deepseek_v2.py

Lines changed: 19 additions & 2 deletions
@@ -216,6 +216,8 @@ def __init__(

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_shared_expert = \
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert

     def forward(
         self,
@@ -238,6 +240,8 @@ def forward(

         num_tokens, hidden_size = hidden_states.shape

+        multistream = self.enable_multistream_shared_expert and not is_prefill
+
         old_hidden_states = hidden_states.clone()

         if self.tp_size > 1:
@@ -259,13 +263,25 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

+        kwargs = {}
+        if multistream:
+            kwargs.update({
+                "shared_experts": self.shared_experts,
+                "shared_hidden_states": old_hidden_states
+            })
+
         hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+            **kwargs)
+
+        if multistream:
+            hidden_states, shared_output = hidden_states
+
+        hidden_states = hidden_states * self.routed_scaling_factor

         if self.tp_size > 1:
             if self.torchair_graph_enabled:
@@ -288,7 +304,8 @@ def forward(
             hidden_states = hidden_states[:-num_padding_tokens]

         if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(old_hidden_states)
+            if not multistream:
+                shared_output = self.shared_experts(old_hidden_states)

         if shared_output is not None:
             hidden_states = hidden_states + shared_output
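Conceptually, passing `shared_experts` and `shared_hidden_states` into `self.experts` lets the MoE layer run the shared expert on a secondary stream while the routed experts run on the default stream. The snippet below is only an illustration of that overlap idea using the torch.npu stream API; it is not the implementation used by this commit, and the function and argument names are placeholders.

```python
import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend

shared_stream = torch.npu.Stream()


def routed_and_shared(routed_fn, shared_experts, hidden_states, shared_hidden_states):
    # Launch the shared expert on a side stream so it can overlap with the
    # routed-expert computation issued on the default stream.
    shared_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(shared_stream):
        shared_output = shared_experts(shared_hidden_states)

    routed_output = routed_fn(hidden_states)

    # Make the default stream wait for the shared expert before the two
    # outputs are combined.
    torch.npu.current_stream().wait_stream(shared_stream)
    return routed_output, shared_output
```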

vllm_ascend/ops/fused_moe.py

Lines changed: 28 additions & 19 deletions
@@ -39,19 +39,18 @@
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM


-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: Optional[str] = None,
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: Optional[str] = None,
+                           **kwargs) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -81,9 +80,9 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)

-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
@@ -119,7 +118,7 @@ def fused_experts_with_mc2(
     down_out_list = torch.cat(down_out_list, dim=0)

     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -141,9 +140,9 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)

-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

     return hidden_states

@@ -675,7 +674,8 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
         elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -772,6 +772,8 @@ def __init__(

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_shared_expert = \
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -818,7 +820,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k=None,
+                **kwargs):
         assert self.quant_method is not None

         if top_k:
@@ -862,7 +865,11 @@ def forward(self,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            **kwargs)
+
+        if self.enable_multistream_shared_expert and not is_prefill:
+            hidden_states, shared_output = hidden_states

         if self.dp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
@@ -886,6 +893,8 @@ def forward(self,
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)

+        if self.enable_multistream_shared_expert and not is_prefill:
+            return hidden_states, shared_output
         return hidden_states

 # ----------------------------------------- TBO-related --------------------------------------------
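For callers, the visible change is the return contract of `AscendFusedMoE.forward`: with `enable_multistream_shared_expert` on and outside prefill it returns a `(routed_hidden_states, shared_output)` tuple, otherwise a single tensor, which is why `CustomDeepseekV2MoE.forward` above unpacks it conditionally. A small caller-side sketch of handling both shapes (function and variable names are illustrative, not from this commit):

```python
def call_moe(experts, hidden_states, router_logits, is_prefill, multistream, **kwargs):
    out = experts(hidden_states=hidden_states,
                  router_logits=router_logits,
                  is_prefill=is_prefill,
                  **kwargs)

    shared_output = None
    if multistream:
        # Decode path with multistream shared expert: the layer also returns
        # the shared-expert output computed inside the fused call.
        out, shared_output = out

    return out, shared_output
```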
