Skip to content

Commit 0a9bbaa

Browse files
xsank唯勤
andauthored
[Misc] support model prefix & add deepseek vl2 tiny fused moe config (#17763)
Signed-off-by: 唯勤 <xsank.mz@alibaba-inc.com> Co-authored-by: 唯勤 <xsank.mz@alibaba-inc.com>
1 parent 39956ef commit 0a9bbaa

File tree

2 files changed

+161
-9
lines changed

2 files changed

+161
-9
lines changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
from contextlib import nullcontext
77
from datetime import datetime
88
from itertools import product
9+
from types import SimpleNamespace
910
from typing import Any, TypedDict
1011

1112
import ray
1213
import torch
1314
from ray.experimental.tqdm_ray import tqdm
14-
from transformers import AutoConfig
1515

1616
from vllm.model_executor.layers.fused_moe.fused_moe import *
1717
from vllm.platforms import current_platform
18+
from vllm.transformers_utils.config import get_config
1819
from vllm.triton_utils import triton
1920
from vllm.utils import FlexibleArgumentParser
2021

@@ -534,8 +535,12 @@ def get_weight_block_size_safety(config, default_value=None):
534535
def main(args: argparse.Namespace):
535536
print(args)
536537

537-
config = AutoConfig.from_pretrained(
538-
args.model, trust_remote_code=args.trust_remote_code)
538+
config = get_config(model=args.model,
539+
trust_remote_code=args.trust_remote_code)
540+
if args.model_prefix:
541+
config = getattr(config, args.model_prefix)
542+
config = SimpleNamespace(**config)
543+
539544
if config.architectures[0] == "DbrxForCausalLM":
540545
E = config.ffn_config.moe_num_experts
541546
topk = config.ffn_config.moe_top_k
@@ -546,15 +551,14 @@ def main(args: argparse.Namespace):
546551
topk = config.num_experts_per_tok
547552
intermediate_size = config.intermediate_size
548553
shard_intermediate_size = 2 * intermediate_size // args.tp_size
549-
elif (config.architectures[0] == "DeepseekV3ForCausalLM"
550-
or config.architectures[0] == "DeepseekV2ForCausalLM"):
554+
elif (config.architectures[0]
555+
in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")):
551556
E = config.n_routed_experts
552557
topk = config.num_experts_per_tok
553558
intermediate_size = config.moe_intermediate_size
554559
shard_intermediate_size = 2 * intermediate_size // args.tp_size
555-
elif config.architectures[0] in [
556-
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
557-
]:
560+
elif config.architectures[0] in ("Qwen2MoeForCausalLM",
561+
"Qwen3MoeForCausalLM"):
558562
E = config.num_experts
559563
topk = config.num_experts_per_tok
560564
intermediate_size = config.moe_intermediate_size
@@ -569,7 +573,8 @@ def main(args: argparse.Namespace):
569573
shard_intermediate_size = 2 * intermediate_size // args.tp_size
570574

571575
hidden_size = config.hidden_size
572-
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
576+
dtype = torch.float16 if current_platform.is_rocm() else getattr(
577+
torch, config.torch_dtype)
573578
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
574579
use_int8_w8a16 = args.dtype == "int8_w8a16"
575580
block_quant_shape = get_weight_block_size_safety(config)
@@ -659,6 +664,7 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
659664
parser.add_argument("--batch-size", type=int, required=False)
660665
parser.add_argument("--tune", action="store_true")
661666
parser.add_argument("--trust-remote-code", action="store_true")
667+
parser.add_argument("--model-prefix", type=str, required=False)
662668
args = parser.parse_args()
663669

664670
main(args)
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 16,
4+
"BLOCK_SIZE_N": 64,
5+
"BLOCK_SIZE_K": 64,
6+
"GROUP_SIZE_M": 64,
7+
"num_warps": 4,
8+
"num_stages": 5
9+
},
10+
"2": {
11+
"BLOCK_SIZE_M": 16,
12+
"BLOCK_SIZE_N": 64,
13+
"BLOCK_SIZE_K": 128,
14+
"GROUP_SIZE_M": 64,
15+
"num_warps": 4,
16+
"num_stages": 3
17+
},
18+
"4": {
19+
"BLOCK_SIZE_M": 16,
20+
"BLOCK_SIZE_N": 64,
21+
"BLOCK_SIZE_K": 64,
22+
"GROUP_SIZE_M": 1,
23+
"num_warps": 4,
24+
"num_stages": 3
25+
},
26+
"8": {
27+
"BLOCK_SIZE_M": 16,
28+
"BLOCK_SIZE_N": 64,
29+
"BLOCK_SIZE_K": 64,
30+
"GROUP_SIZE_M": 1,
31+
"num_warps": 4,
32+
"num_stages": 5
33+
},
34+
"16": {
35+
"BLOCK_SIZE_M": 16,
36+
"BLOCK_SIZE_N": 64,
37+
"BLOCK_SIZE_K": 64,
38+
"GROUP_SIZE_M": 1,
39+
"num_warps": 4,
40+
"num_stages": 4
41+
},
42+
"24": {
43+
"BLOCK_SIZE_M": 16,
44+
"BLOCK_SIZE_N": 64,
45+
"BLOCK_SIZE_K": 64,
46+
"GROUP_SIZE_M": 1,
47+
"num_warps": 4,
48+
"num_stages": 5
49+
},
50+
"32": {
51+
"BLOCK_SIZE_M": 16,
52+
"BLOCK_SIZE_N": 64,
53+
"BLOCK_SIZE_K": 128,
54+
"GROUP_SIZE_M": 1,
55+
"num_warps": 4,
56+
"num_stages": 3
57+
},
58+
"48": {
59+
"BLOCK_SIZE_M": 16,
60+
"BLOCK_SIZE_N": 128,
61+
"BLOCK_SIZE_K": 64,
62+
"GROUP_SIZE_M": 1,
63+
"num_warps": 4,
64+
"num_stages": 5
65+
},
66+
"64": {
67+
"BLOCK_SIZE_M": 16,
68+
"BLOCK_SIZE_N": 64,
69+
"BLOCK_SIZE_K": 64,
70+
"GROUP_SIZE_M": 1,
71+
"num_warps": 4,
72+
"num_stages": 5
73+
},
74+
"96": {
75+
"BLOCK_SIZE_M": 16,
76+
"BLOCK_SIZE_N": 64,
77+
"BLOCK_SIZE_K": 64,
78+
"GROUP_SIZE_M": 1,
79+
"num_warps": 4,
80+
"num_stages": 5
81+
},
82+
"128": {
83+
"BLOCK_SIZE_M": 16,
84+
"BLOCK_SIZE_N": 64,
85+
"BLOCK_SIZE_K": 64,
86+
"GROUP_SIZE_M": 1,
87+
"num_warps": 4,
88+
"num_stages": 4
89+
},
90+
"256": {
91+
"BLOCK_SIZE_M": 32,
92+
"BLOCK_SIZE_N": 128,
93+
"BLOCK_SIZE_K": 64,
94+
"GROUP_SIZE_M": 1,
95+
"num_warps": 4,
96+
"num_stages": 3
97+
},
98+
"512": {
99+
"BLOCK_SIZE_M": 64,
100+
"BLOCK_SIZE_N": 64,
101+
"BLOCK_SIZE_K": 64,
102+
"GROUP_SIZE_M": 1,
103+
"num_warps": 4,
104+
"num_stages": 4
105+
},
106+
"1024": {
107+
"BLOCK_SIZE_M": 64,
108+
"BLOCK_SIZE_N": 64,
109+
"BLOCK_SIZE_K": 64,
110+
"GROUP_SIZE_M": 1,
111+
"num_warps": 4,
112+
"num_stages": 4
113+
},
114+
"1536": {
115+
"BLOCK_SIZE_M": 64,
116+
"BLOCK_SIZE_N": 64,
117+
"BLOCK_SIZE_K": 64,
118+
"GROUP_SIZE_M": 1,
119+
"num_warps": 4,
120+
"num_stages": 4
121+
},
122+
"2048": {
123+
"BLOCK_SIZE_M": 64,
124+
"BLOCK_SIZE_N": 64,
125+
"BLOCK_SIZE_K": 64,
126+
"GROUP_SIZE_M": 16,
127+
"num_warps": 4,
128+
"num_stages": 4
129+
},
130+
"3072": {
131+
"BLOCK_SIZE_M": 64,
132+
"BLOCK_SIZE_N": 64,
133+
"BLOCK_SIZE_K": 64,
134+
"GROUP_SIZE_M": 1,
135+
"num_warps": 4,
136+
"num_stages": 4
137+
},
138+
"4096": {
139+
"BLOCK_SIZE_M": 64,
140+
"BLOCK_SIZE_N": 128,
141+
"BLOCK_SIZE_K": 64,
142+
"GROUP_SIZE_M": 32,
143+
"num_warps": 4,
144+
"num_stages": 4
145+
}
146+
}

0 commit comments

Comments
 (0)