Commit 15ba07e

[Minor] Fused experts refactor (#15914)
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent d2b58ca commit 15ba07e

File tree: 8 files changed (+790, -737 lines)


tests/kernels/test_block_fp8.py

Lines changed: 6 additions & 3 deletions
@@ -9,8 +9,11 @@
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    deep_gemm_moe_fp8, fused_topk, moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
+    deep_gemm_moe_fp8)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
+    moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform

@@ -437,7 +440,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
         pytest.skip(
             f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")

-    if (N <= 512):
+    if N <= 512:
         pytest.skip("Skipping N <= 512 until performance issues solved.")

     vllm_config = VllmConfig()
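
The hunk above only retargets imports: helpers that used to live in fused_moe.py are now split into dedicated modules. A short annotated sketch of the new locations (the module paths come straight from the diff; the one-line comments are a gloss on each helper's role, not text from the commit):

# DeepGEMM-backed fp8 fused-MoE path, now in its own module.
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import deep_gemm_moe_fp8
# Top-k routing over the gating scores stays in fused_moe.py.
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
# Sorting/padding of token->expert assignments to the kernel block size, also split out.
from vllm.model_executor.layers.fused_moe.moe_align_block_size import moe_align_block_size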

tests/kernels/test_cutlass_moe.py

Lines changed: 8 additions & 8 deletions
@@ -4,8 +4,8 @@

 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
-                                                            fused_experts,
+from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
+from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                             fused_topk)
 from vllm.platforms import current_platform

@@ -131,9 +131,9 @@ def test_cutlass_moe_no_graph(
                                     c_strides2,
                                     a1_scale=a_scale1)

-    print(triton_output)
-    print(cutlass_output)
-    print("*")
+    #print(triton_output)
+    #print(cutlass_output)
+    #print("*")

     torch.testing.assert_close(triton_output,
                                cutlass_output,
@@ -234,9 +234,9 @@ def test_cutlass_moe_cuda_graph(
     graph.replay()
     torch.cuda.synchronize()

-    print(triton_output)
-    print(cutlass_output)
-    print("*")
+    #print(triton_output)
+    #print(cutlass_output)
+    #print("*")

     torch.testing.assert_close(triton_output,
                                cutlass_output,

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -35,9 +35,11 @@ def get_config() -> Optional[Dict[str, Any]]:
     # import to register the custom ops
     import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
     import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
+    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+        cutlass_moe_fp8)
     from vllm.model_executor.layers.fused_moe.fused_moe import (
-        cutlass_moe_fp8, fused_experts, fused_moe, fused_topk,
-        get_config_file_name, grouped_topk)
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)

     __all__ += [
         "fused_moe",

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 144 additions & 0 deletions

@@ -0,0 +1,144 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Fused MoE kernel."""
+from typing import Optional
+
+import torch
+
+from vllm import _custom_ops as ops
+
+
+#TODO make the grouped gemm kernel consistent with scaled gemm kernel
+def cutlass_moe_fp8(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    ab_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.half,
+) -> torch.Tensor:
+    """
+    This function computes an a8w8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and a top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [M, K]
+    - w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
+        Shape: [num_experts, K, 2N] (the weights are passed transposed)
+    - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
+        Shape: [num_experts, N, K] (the weights are passed transposed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts] or [num_experts, 2N]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts] or [num_experts, K]
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - ab_strides1 (torch.Tensor): The input and weights strides of the first
+        grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - ab_strides2 (torch.Tensor): The input and weights strides of the second
+        grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [M]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [M]
+    - out_dtype (torch.Tensor): The output tensor type.
+
+    Returns:
+    - torch.Tensor: The fp16 output tensor after applying the MoE layer.
+    """
+
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.float8_e4m3fn
+    assert w2_q.dtype == torch.float8_e4m3fn
+    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
+    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert a1_scale is None or a1_scale.dim(
+    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
+        0], "Input scale shape mismatch"
+    assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
+        1] == w1_q.shape[2], "W1 scale shape mismatch"
+    assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
+        1] == w2_q.shape[2], "W2 scale shape mismatch"
+    assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[
+        0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[
+        0], "w2 scales expert number mismatch"
+    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501
+    assert ab_strides1.shape[0] == w1_q.shape[
+        0], "AB Strides 1 expert number mismatch"
+    assert c_strides1.shape[0] == w1_q.shape[
+        0], "C Strides 1 expert number mismatch"
+    assert ab_strides2.shape[0] == w2_q.shape[
+        0], "AB Strides 2 expert number mismatch"
+    assert c_strides2.shape[0] == w2_q.shape[
+        0], "C Strides 2 expert number mismatch"
+    assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
+
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(1)
+    n = w2_q.size(1)
+
+    topk = topk_ids.size(1)
+    per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+        a2_scale.numel() != 1 if a2_scale is not None else False)
+
+    a_q, a1_scale = ops.scaled_fp8_quant(
+        a, a1_scale, use_per_token_if_dynamic=per_act_token)
+    device = a_q.device
+
+    expert_offsets = torch.empty((num_experts + 1),
+                                 dtype=torch.int32,
+                                 device=device)
+    problem_sizes1 = torch.empty((num_experts, 3),
+                                 dtype=torch.int32,
+                                 device=device)
+    problem_sizes2 = torch.empty((num_experts, 3),
+                                 dtype=torch.int32,
+                                 device=device)
+
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+
+    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
+                                problem_sizes2, a_map, c_map, num_experts, n,
+                                k)
+
+    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
+    rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
+    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
+
+    ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
+                       expert_offsets[:-1], problem_sizes1, ab_strides1,
+                       ab_strides1, c_strides1)
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
+    torch.ops._C.silu_and_mul(intermediate, c1)
+
+    intemediate_q, a2_scale = ops.scaled_fp8_quant(
+        intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
+
+    ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
+                       expert_offsets[:-1], problem_sizes2, ab_strides2,
+                       ab_strides2, c_strides2)
+
+    return (c2[c_map].view(m, topk, k) *
+            topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
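
For orientation, a hedged sketch of how this new entry point can be driven end to end. The sizes, identity scales, stride values, and the fused_topk routing call are illustrative assumptions consistent with the docstring and shape assertions above, not a copy of the project's tests; real callers quantize actual weights and lay them out as the CUTLASS grouped gemm expects.

import torch
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

# Illustrative sizes: M tokens, hidden size K, intermediate size N, E experts.
M, K, N, E, topk = 32, 512, 1024, 8, 2
device = "cuda"

a = torch.randn(M, K, device=device, dtype=torch.half)

# fp8 expert weights in the documented (transposed) layouts:
# w1_q: [E, K, 2N], w2_q: [E, N, K], with per-expert fp32 scales.
w1_q = torch.randn(E, K, 2 * N, device=device).to(torch.float8_e4m3fn)
w2_q = torch.randn(E, N, K, device=device).to(torch.float8_e4m3fn)
w1_scale = torch.ones(E, 1, device=device, dtype=torch.float32)
w2_scale = torch.ones(E, 1, device=device, dtype=torch.float32)

# Top-k routing from gating scores.
score = torch.randn(M, E, device=device, dtype=torch.half)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

# One stride entry per expert for each grouped gemm.
ab_strides1 = torch.full((E, ), K, device=device, dtype=torch.int64)
c_strides1 = torch.full((E, ), 2 * N, device=device, dtype=torch.int64)
ab_strides2 = torch.full((E, ), N, device=device, dtype=torch.int64)
c_strides2 = torch.full((E, ), K, device=device, dtype=torch.int64)

out = cutlass_moe_fp8(a, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                      topk_ids, ab_strides1, c_strides1, ab_strides2,
                      c_strides2, out_dtype=torch.half)  # -> [M, K]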
