
Commit 521dce4

cakeng authored and Akshat-Tripathi committed

Expert Parallelism (EP) Support for DeepSeek V2 (vllm-project#12583)

1 parent 449c61f commit 521dce4

File tree

19 files changed (+527, -59 lines)

benchmarks/kernels/benchmark_moe.py

Lines changed: 2 additions & 1 deletion

@@ -468,7 +468,8 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    elif config.architectures[0] == "DeepseekV3ForCausalLM":
+    elif (config.architectures[0] == "DeepseekV3ForCausalLM"
+          or config.architectures[0] == "DeepseekV2ForCausalLM"):
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
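
For context, a minimal sketch (not part of the commit) of how the MoE fields this branch now reads for DeepSeek V2 can be pulled from the Hugging Face config. The model id, the tp_size value, and the sharding line are illustrative assumptions that simply mirror the surrounding benchmark code.

# Illustrative sketch only; not part of the commit.
# Assumes the DeepSeek-V2-Lite config exposes the fields used above.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V2-Lite",
                                    trust_remote_code=True)
tp_size = 2  # placeholder tensor-parallel size
E = config.n_routed_experts                 # routed experts
topk = config.num_experts_per_tok           # experts activated per token
intermediate_size = config.moe_intermediate_size
# mirrors the sharding line shown in the hunk above
shard_intermediate_size = 2 * intermediate_size // tp_size
print(E, topk, intermediate_size, shard_intermediate_size)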
Lines changed: 227 additions & 0 deletions (new file)

@@ -0,0 +1,227 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import List, Literal, NamedTuple, Optional

import pytest

from vllm.config import TaskOption
from vllm.logger import init_logger

from ..utils import compare_two_settings, fork_new_process_for_each_test

logger = init_logger("test_expert_parallel")


class ParallelSetup(NamedTuple):
    tp_size: int
    eager_mode: bool
    chunked_prefill: bool


class EPTestOptions(NamedTuple):
    trust_remote_code: bool
    tokenizer_mode: Optional[str]
    load_format: Optional[str] = None
    hf_overrides: Optional[str] = None


@dataclass
class EPTestSettings:
    parallel_setups: List[ParallelSetup]
    distributed_backends: List[str]
    task: TaskOption
    test_options: EPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp", "ray"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    def iter_params(self, model_name: str):
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for distributed_backend in self.distributed_backends:
                yield (model_name, parallel_setup, distributed_backend,
                       self.task, opts)


# NOTE: You can adjust tp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

# yapf: disable
TEST_MODELS = {
    "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast(
        trust_remote_code=True),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4),
}


def _compare_tp(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate"],
):
    (
        tp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup
    (
        trust_remote_code,
        tokenizer_mode,
        load_format,
        hf_overrides,
    ) = test_options

    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} GPUs")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
        "--load-format",
        "auto",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", hf_overrides])

    ep_env = {
        "VLLM_TEST_ENABLE_EP": "1",
    }

    ep_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without expert parallelism
    tp_env = {
        "VLLM_TEST_ENABLE_EP": "0",
    }

    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_name,
                             ep_args,
                             tp_args,
                             ep_env,
                             tp_env,
                             method=method,
                             max_wait_seconds=360)
    except Exception:
        raise


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "test_options"),
    [
        params for model_name, settings in TEST_MODELS.items()
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_ep(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available,
):
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                test_options,
                num_gpus_available,
                method="generate")
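
As a rough illustration of what one parametrized case exercises, here is a minimal sketch (not part of the commit): the same engine arguments are launched twice, once with VLLM_TEST_ENABLE_EP set to 1 and once with it set to 0, and compare_two_settings from the test utils checks that the generations agree. The concrete values below are placeholders taken from the code above; the actual comparison call is left commented out because it needs multiple GPUs.

# Sketch only; values are illustrative. The real test builds these lists
# per ParallelSetup as shown above.
tp_size = 2
common_args = [
    "--dtype", "float16",
    "--max-model-len", "2048",
    "--max-num-seqs", "8",
    "--load-format", "auto",
    "--trust-remote-code",                  # needed by DeepSeek-V2-Lite-Chat
    "--tensor-parallel-size", str(tp_size),
    "--distributed-executor-backend", "mp",
]
ep_env = {"VLLM_TEST_ENABLE_EP": "1"}       # enable expert parallelism
tp_env = {"VLLM_TEST_ENABLE_EP": "0"}       # baseline tensor parallelism

# compare_two_settings("deepseek-ai/DeepSeek-V2-Lite-Chat",
#                      common_args, common_args, ep_env, tp_env,
#                      method="generate", max_wait_seconds=360)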

tests/kernels/test_awq_marlin.py

Lines changed: 2 additions & 7 deletions

@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq(
         num_bits=num_bits,
     )

-    torch_output = torch_moe(
-        a,
-        w_ref1.transpose(1, 2),
-        w_ref2.transpose(1, 2),
-        score,
-        topk,
-    )
+    torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
+                             score, topk, None)

     assert compute_max_diff(marlin_output, torch_output) < 4e-2

tests/kernels/test_moe.py

Lines changed: 59 additions & 6 deletions

@@ -26,6 +26,7 @@
 from vllm.scalar_type import scalar_types

 NUM_EXPERTS = [8, 64]
+EP_SIZE = [1, 4]
 TOP_KS = [2, 6]


@@ -34,24 +35,54 @@
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_fused_moe(
     m: int,
     n: int,
     k: int,
     e: int,
     topk: int,
+    ep_size: int,
     dtype: torch.dtype,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
-    torch_output = torch_moe(a, w1, w2, score, topk)
+
+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device="cuda",
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
+        w1 = w1[e_ids]
+        w2 = w2[e_ids]
+    else:
+        e_map = None
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
+    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
-    iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
+    iterative_output = iterative_moe(a,
+                                     w1,
+                                     w2,
+                                     score,
+                                     topk,
+                                     global_num_experts=e,
+                                     expert_map=e_map,
+                                     renormalize=False)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,

@@ -63,13 +94,14 @@ def test_fused_moe(
 @pytest.mark.parametrize("k", [128, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("group_size", [64, 128])
 @pytest.mark.parametrize("has_zp", [True, False])
 @pytest.mark.parametrize("weight_bits", [4, 8])
 def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
-                        dtype: torch.dtype, group_size: int, has_zp: bool,
-                        weight_bits: int):
+                        ep_size: int, dtype: torch.dtype, group_size: int,
+                        has_zp: bool, weight_bits: int):
     print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10

@@ -130,6 +162,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
         if has_zp:
             w_qzeros[expert_id] = qzeros

+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device="cuda",
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
+        w1_ref = w1_ref[e_ids]
+        w2_ref = w2_ref[e_ids]
+        w1_qweight = w1_qweight[e_ids]
+        w2_qweight = w2_qweight[e_ids]
+        w1_scales = w1_scales[e_ids]
+        w2_scales = w2_scales[e_ids]
+        w1_qzeros = w1_qzeros[e_ids]
+        w2_qzeros = w2_qzeros[e_ids]
+    else:
+        e_map = None
+
     triton_output = fused_moe(a,
                               w1_qweight,
                               w2_qweight,

@@ -138,12 +189,14 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                               renormalize=False,
                               use_int4_w4a16=weight_bits == 4,
                               use_int8_w8a16=weight_bits == 8,
+                              global_num_experts=e,
+                              expert_map=e_map,
                               w1_scale=w1_scales,
                               w2_scale=w2_scales,
                               w1_zp=w1_qzeros if has_zp else None,
                               w2_zp=w2_qzeros if has_zp else None,
                               block_shape=[0, group_size])
-    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
+    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)

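
To make the expert_map convention in these tests concrete, here is a small self-contained sketch (not part of the commit): each EP rank owns a subset of the global experts, and expert_map sends a global expert id to its local index, or to -1 when the expert lives on another rank. The sketch uses CPU tensors and a contiguous partition purely for illustration; the tests above pick a random expert subset on CUDA instead.

# Illustration only; the tests above use torch.randint on CUDA.
import torch

e, ep_size, rank = 8, 4, 1              # 8 global experts split over 4 EP ranks
local_e = e // ep_size                  # 2 experts per rank
e_ids = torch.arange(rank * local_e, (rank + 1) * local_e)  # experts owned by rank 1

e_map = torch.full((e, ), -1, dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, dtype=torch.int32)

print(e_ids.tolist())  # [2, 3]
print(e_map.tolist())  # [-1, -1, 0, 1, -1, -1, -1, -1]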

tests/kernels/utils.py

Lines changed: 3 additions & 1 deletion

@@ -1053,14 +1053,16 @@ def compute_max_diff(output, output_ref):
         torch.abs(output_ref))


-def torch_moe(a, w1, w2, score, topk):
+def torch_moe(a, w1, w2, score, topk, expert_map):
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
     score = torch.softmax(score, dim=-1, dtype=torch.float32)
     topk_weight, topk_ids = torch.topk(score, topk)
     topk_weight = topk_weight.view(-1)
     topk_ids = topk_ids.view(-1)
+    if expert_map is not None:
+        topk_ids = expert_map[topk_ids]
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
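
A quick illustration (not part of the commit) of how the expert_map remapping above behaves: remapped ids of -1 never equal a local expert index i in the loop, so tokens routed to experts owned by other ranks simply contribute zero rows to out. The tensors below are made up for the demo.

# Illustration only; values are arbitrary.
import torch

expert_map = torch.tensor([-1, -1, 0, 1, -1, -1, -1, -1])  # this rank owns global experts 2 and 3
topk_ids = torch.tensor([2, 5, 3, 0])                      # global routing decisions
local_ids = expert_map[topk_ids]
print(local_ids.tolist())   # [0, -1, 1, -1] -> experts 5 and 0 are skipped locally

for i in range(2):          # iterate over the 2 local experts, as torch_moe does
    mask = local_ids == i
    print(i, mask.tolist())  # which token slots hit local expert i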
