
Commit 0db6670

yiz-liu and Yizhou Liu authored
[Feature] Implement EP-compatible fused_moe (#121)
### What this PR does / why we need it?

Enable Expert-Parallel (EP) for Ascend devices.

### Does this PR introduce _any_ user-facing change?

To enable EP, add `enable_expert_parallel=True` to your offline inference script, like this:

```python
llm = LLM(
    model="/path/to/model",
    trust_remote_code=True,
    tensor_parallel_size=4,
    max_model_len=4096,
    enforce_eager=True,
    distributed_executor_backend="mp",
    enable_expert_parallel=True,
)
```

### How was this patch tested?

Please use the `main` branch of vLLM.

---------

Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com>
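For readers unfamiliar with expert parallelism: with `enable_expert_parallel=True`, the global experts are split across EP ranks, and each rank keeps only its own slice of the expert weights plus an expert map that translates global expert IDs into local slots (`-1` marks experts owned by other ranks). The new test below builds exactly such a map; the following is a minimal, CPU-only sketch of the same idea, with a purely illustrative expert assignment (not part of the vllm-ascend API):

```python
import torch

e, ep_size = 8, 4              # 8 global experts split over 4 EP ranks
local_e = e // ep_size         # each rank holds 2 expert weight slices

# Suppose this rank owns global experts 3 and 6 (purely illustrative).
local_expert_ids = torch.tensor([3, 6], dtype=torch.int64)

# expert_map: global expert id -> local slot index, or -1 if owned by another rank.
expert_map = torch.full((e,), -1, dtype=torch.int64)
expert_map[local_expert_ids] = torch.arange(local_e, dtype=torch.int64)

# Router output: global top-k expert ids for two tokens.
topk_ids = torch.tensor([[3, 0], [6, 3]])
print(expert_map[topk_ids])    # tensor([[ 0, -1], [ 1,  0]]); -1 entries belong to other ranks
```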
1 parent 4c9d78a commit 0db6670

File tree

2 files changed: +366 additions, -129 deletions


tests/ops/test_fused_moe.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@

```python
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE layers.

Run `pytest tests/ops/test_fused_moe.py`.
"""

import pytest
import torch
from vllm.model_executor.layers.activation import SiluAndMul

from vllm_ascend.ops.fused_moe import fused_experts

NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
DEVICE = ["npu"]


def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    topk_weights = topk_weights.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_fused_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    device: str,
):
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10

    score = torch.randn((m, e), device=device, dtype=dtype)

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randint(0,
                              e, (local_e, ),
                              device=device,
                              dtype=torch.int32)
        e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    score = torch.softmax(score, dim=-1, dtype=dtype)
    topk_weights, topk_ids = torch.topk(score, topk)
    topk_ids = topk_ids.to(torch.int32)

    output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
    torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
    # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
    torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
```
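As a usage note, the `fused_experts` call exercised above can also be driven outside of pytest. A minimal sketch of the single-rank path (`ep_size=1`, so the expert map is `None`), assuming an Ascend NPU with `torch_npu` and this vllm-ascend build available, and with illustrative shapes:

```python
import torch
from vllm_ascend.ops.fused_moe import fused_experts

m, k, n, e, topk = 32, 128, 256, 8, 2
device, dtype = "npu", torch.float16

hidden = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10  # gate + up projection
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10      # down projection

score = torch.softmax(torch.randn((m, e), device=device, dtype=dtype), dim=-1)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)

# expert_map=None means every expert is local (no expert parallelism).
out = fused_experts(hidden, w1, w2, topk_weights, topk_ids, topk, None)
print(out.shape)  # torch.Size([32, 128])
```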
