Commit 84166fe

[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>

1 parent: 6e0cd10

26 files changed: +925 −416 lines

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # CUTLASS MoE kernels

   # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
-  # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
-  # to compile MoE kernels that use its output.
+  # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
+  # if it's possible to compile MoE kernels that use its output.
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
     set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
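
The build gate above only compiles these sources when the CUDA toolkit is 12.3+ and a matching arch (9.0a/10.0a) is in CUDA_ARCHS. As a rough runtime analogue of that constraint — a hedged sketch, not vLLM's actual dispatch logic — a caller could guard the CUTLASS MoE path like this:

```python
import torch


def cutlass_moe_supported() -> bool:
    """Sketch of the build-time gate: Hopper-or-newer GPU and CUDA >= 12.3.

    Illustrative only; the checks mirror the CMake comment above, not vLLM's
    real capability detection.
    """
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, _ = torch.cuda.get_device_capability()
    cuda_major, cuda_minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    return major >= 9 and (cuda_major, cuda_minor) >= (12, 3)
```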

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 11 additions & 50 deletions
@@ -7,8 +7,8 @@

 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    cutlass_moe_fp8,
     fused_experts,
     fused_topk,
 )
@@ -70,18 +70,9 @@ def bench_run(
     w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
     w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

-    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
-    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
-    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
-    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
-
     for expert in range(num_experts):
         w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
         w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
-    w1_q_notransp = w1_q.clone()
-    w2_q_notransp = w2_q.clone()
-    w1_q = w1_q.transpose(1, 2)
-    w2_q = w2_q.transpose(1, 2)

     score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

@@ -122,25 +113,17 @@ def run_cutlass_moe(
         w2_scale: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        ab_strides1: torch.Tensor,
-        c_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides2: torch.Tensor,
         num_repeats: int,
     ):
         for _ in range(num_repeats):
             cutlass_moe_fp8(
                 a,
                 w1,
                 w2,
-                w1_scale,
-                w2_scale,
                 topk_weights,
                 topk_ids,
-                ab_strides1,
-                c_strides1,
-                ab_strides2,
-                c_strides2,
+                w1_scale,
+                w2_scale,
                 a1_scale=a_scale,
             )

@@ -153,10 +136,6 @@ def run_cutlass_from_graph(
         w2_scale: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        ab_strides1: torch.Tensor,
-        c_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides2: torch.Tensor,
     ):
         with set_current_vllm_config(
             VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
@@ -165,14 +144,10 @@ def run_cutlass_from_graph(
                 a,
                 w1_q,
                 w2_q,
-                w1_scale,
-                w2_scale,
                 topk_weights,
                 topk_ids,
-                ab_strides1,
-                c_strides1,
-                ab_strides2,
-                c_strides2,
+                w1_scale,
+                w2_scale,
                 a1_scale=a_scale,
             )

@@ -218,10 +193,6 @@ def replay_graph(graph, num_repeats):
             w2_scale,
             topk_weights,
             topk_ids,
-            ab_strides1,
-            c_strides1,
-            ab_strides2,
-            c_strides2,
         )
     torch.cuda.synchronize()

@@ -230,8 +201,8 @@ def replay_graph(graph, num_repeats):
     with torch.cuda.graph(triton_graph, stream=triton_stream):
         run_triton_from_graph(
             a,
-            w1_q_notransp,
-            w2_q_notransp,
+            w1_q,
+            w2_q,
             topk_weights,
             topk_ids,
             w1_scale,
@@ -250,18 +221,12 @@ def replay_graph(graph, num_repeats):
         "w2": w2,
         "score": score,
         "topk": topk,
-        "w1_q_notransp": w1_q_notransp,
-        "w2_q_notransp": w2_q_notransp,
         # Cutlass params
         "a_scale": a_scale,
         "w1_q": w1_q,
         "w2_q": w2_q,
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
-        "ab_strides1": ab_strides1,
-        "c_strides1": c_strides1,
-        "ab_strides2": ab_strides2,
-        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -279,8 +244,8 @@ def replay_graph(graph, num_repeats):
     # Warmup
     run_triton_moe(
         a,
-        w1_q_notransp,
-        w2_q_notransp,
+        w1_q,
+        w2_q,
         topk_weights,
         topk_ids,
         w1_scale,
@@ -291,7 +256,7 @@ def replay_graph(graph, num_repeats):

     results.append(
         benchmark.Timer(
-            stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
+            stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
@@ -322,16 +287,12 @@ def replay_graph(graph, num_repeats):
         w2_scale,
         topk_weights,
         topk_ids,
-        ab_strides1,
-        c_strides1,
-        ab_strides2,
-        c_strides2,
         num_warmup,
     )

     results.append(
         benchmark.Timer(
-            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
+            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
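
In the updated benchmark, cutlass_moe_fp8 takes the quantized weights untransposed, the per-expert scales directly after topk_ids, and no stride tensors. A minimal usage sketch of that call shape, with illustrative sizes and plain torch.topk standing in for the benchmark's fused_topk routing:

```python
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8

# Illustrative sizes: m tokens, k hidden, n intermediate, E experts, top-k routing.
m, k, n, num_experts, topk = 64, 4096, 14336, 8, 2

a = torch.randn((m, k), device="cuda", dtype=torch.half) / 10
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=torch.half) / 10
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=torch.half) / 10

# Per-expert FP8 quantization, as in the benchmark above.
w1_q = torch.empty_like(w1, dtype=torch.float8_e4m3fn)
w2_q = torch.empty_like(w2, dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
for e in range(num_experts):
    w1_q[e], w1_scale[e] = ops.scaled_fp8_quant(w1[e])
    w2_q[e], w2_scale[e] = ops.scaled_fp8_quant(w2[e])

# Simple softmax + top-k routing (the benchmark uses fused_topk, which yields int32 ids).
score = torch.randn((m, num_experts), device="cuda", dtype=torch.half)
topk_weights, topk_ids = torch.topk(score.softmax(dim=-1), topk, dim=-1)
topk_weights, topk_ids = topk_weights.float(), topk_ids.to(torch.int32)

a_scale = torch.ones(1, device="cuda", dtype=torch.float32)

# New call shape: weights stay untransposed, scales follow topk_ids, no stride tensors.
out = cutlass_moe_fp8(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale,
                      a1_scale=a_scale)
```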

csrc/ops.h

Lines changed: 10 additions & 1 deletion
@@ -236,7 +236,8 @@ void cutlass_moe_mm(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides);
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch);

 void cutlass_fp4_group_mm(
     torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
@@ -251,6 +252,14 @@ void get_cutlass_moe_mm_data(
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);

+void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
+                                  torch::Tensor& problem_sizes1,
+                                  torch::Tensor& problem_sizes2,
+                                  const torch::Tensor& expert_num_tokens,
+                                  const int64_t num_local_experts,
+                                  const int64_t padded_m, const int64_t n,
+                                  const int64_t k);
+
 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
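
get_cutlass_pplx_moe_mm_data fills per-local-expert grouped-GEMM metadata from the token counts produced by the PPLX dispatch. A sketch of the buffers a caller would allocate, with shapes inferred from this declaration and the kernel in moe_data.cu below; the Python wrapper name is an assumption, so the call itself is left commented out:

```python
import torch

# Assumed sizes: E local experts, padded_m slot rows per expert,
# n = intermediate size, k = hidden size.
num_local_experts, padded_m, n, k = 8, 128, 14336, 4096

expert_offsets = torch.empty((num_local_experts,), dtype=torch.int32, device="cuda")
problem_sizes1 = torch.empty((num_local_experts, 3), dtype=torch.int32, device="cuda")
problem_sizes2 = torch.empty((num_local_experts, 3), dtype=torch.int32, device="cuda")
# Tokens routed to each local expert, as reported by the PPLX all-to-all dispatch.
expert_num_tokens = torch.randint(
    0, padded_m, (num_local_experts,), dtype=torch.int32, device="cuda"
)

# Hypothetical binding of the declaration above (the actual op registration
# lives elsewhere in this commit):
# ops.get_cutlass_pplx_moe_mm_data(
#     expert_offsets, problem_sizes1, problem_sizes2,
#     expert_num_tokens, num_local_experts, padded_m, n, k,
# )
```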

csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu

Lines changed: 19 additions & 10 deletions
@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
   TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
   TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
   if (n >= 8192) {
     cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   } else if (k >= 8192) {
     cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   } else if (m <= 16) {
     cutlass_group_gemm_caller<Cutlass3xGemmM16>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   } else {
     cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   }
 }

@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   if (out_tensors.dtype() == torch::kBFloat16) {
     run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   } else {
     run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   }
 }

@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
-                       c_strides);
+                       c_strides, per_act_token, per_out_ch);
 }
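
run_cutlass_moe_mm_sm90 still picks a tile configuration from the problem shape and now forwards the explicit per_act_token/per_out_ch flags. The selection heuristic, restated as a small Python sketch (thresholds and config names taken directly from the diff above):

```python
def select_sm90_moe_config(m: int, n: int, k: int) -> str:
    """Mirrors the shape-based dispatch in run_cutlass_moe_mm_sm90."""
    if n >= 8192:
        return "Cutlass3xGemmN8192"
    if k >= 8192:
        return "Cutlass3xGemmK8192"
    if m <= 16:
        return "Cutlass3xGemmM16"
    return "Cutlass3xGemmDefault"


assert select_sm90_moe_config(m=16, n=14336, k=4096) == "Cutlass3xGemmN8192"
assert select_sm90_moe_config(m=1024, n=4096, k=4096) == "Cutlass3xGemmDefault"
```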

csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh

Lines changed: 2 additions & 4 deletions
@@ -76,17 +76,15 @@ void cutlass_group_gemm_caller(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;

   int num_experts = static_cast<int>(expert_offsets.size(0));
   int k_size = a_tensors.size(1);
   int n_size = out_tensors.size(1);

-  bool per_act_token = a_scales.numel() != 1;
-  bool per_out_ch = b_scales.numel() != num_experts;
-
   auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

   auto options_int =
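
cutlass_group_gemm_caller no longer infers the quantization granularity from the scale tensors; callers must pass per_act_token and per_out_ch explicitly. A sketch of how a caller can recover the flags using the same rule that was removed here (the standalone helper is illustrative):

```python
import torch


def infer_scale_granularity(
    a_scales: torch.Tensor, b_scales: torch.Tensor, num_experts: int
) -> tuple[bool, bool]:
    # Same inference the header used to do internally:
    # more than one activation scale -> per-token activation scales,
    # more than one scale per expert -> per-output-channel weight scales.
    per_act_token = a_scales.numel() != 1
    per_out_ch = b_scales.numel() != num_experts
    return per_act_token, per_out_ch


# One tensor-wide activation scale and one scale per expert -> (False, False).
assert infer_scale_granularity(torch.ones(1), torch.ones(8), num_experts=8) == (False, False)
```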

csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 38 additions & 4 deletions
@@ -7,7 +7,7 @@

 constexpr uint64_t THREADS_PER_EXPERT = 512;

-__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
+__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
                                       int32_t* problem_sizes1,
                                       int32_t* problem_sizes2,
                                       int32_t* atomic_buffer,
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
   }
 }

-__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
+__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
                                   const int32_t* __restrict__ expert_offsets,
                                   int32_t* input_permutation,
                                   int32_t* output_permutation,
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(

   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
   compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const int32_t*>(topk_ids.data_ptr()),
+      static_cast<const uint32_t*>(topk_ids.data_ptr()),
       static_cast<int32_t*>(problem_sizes1.data_ptr()),
       static_cast<int32_t*>(problem_sizes2.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller(
         static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
   }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const int32_t*>(topk_ids.data_ptr()),
+      static_cast<const uint32_t*>(topk_ids.data_ptr()),
       static_cast<const int32_t*>(expert_offsets.data_ptr()),
       static_cast<int32_t*>(input_permutation.data_ptr()),
       static_cast<int32_t*>(output_permutation.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
       topk_ids.size(1));
 }
+
+__global__ void compute_pplx_data(int32_t* expert_offsets,
+                                  int32_t* problem_sizes1,
+                                  int32_t* problem_sizes2,
+                                  const int32_t* __restrict__ expert_num_tokens,
+                                  const int padded_m, const int n,
+                                  const int k) {
+  int expert_idx = threadIdx.x;
+
+  expert_offsets[expert_idx] = expert_idx * padded_m;
+  problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
+  problem_sizes1[expert_idx * 3 + 1] = 2 * n;
+  problem_sizes1[expert_idx * 3 + 2] = k;
+  problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
+  problem_sizes2[expert_idx * 3 + 1] = k;
+  problem_sizes2[expert_idx * 3 + 2] = n;
+}
+
+void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
+                                         torch::Tensor& problem_sizes1,
+                                         torch::Tensor& problem_sizes2,
+                                         const torch::Tensor& expert_num_tokens,
+                                         const int64_t num_local_experts,
+                                         const int64_t padded_m,
+                                         const int64_t n, const int64_t k) {
+  auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
+
+  compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
+      static_cast<int32_t*>(expert_offsets.data_ptr()),
+      static_cast<int32_t*>(problem_sizes1.data_ptr()),
+      static_cast<int32_t*>(problem_sizes2.data_ptr()),
+      static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
+      k);
+}
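
compute_pplx_data reserves a fixed slot of padded_m rows per local expert and writes the two grouped-GEMM problem shapes per expert: (tokens, 2n, k) for the first GEMM and (tokens, k, n) for the second. A CPU reference of that layout in PyTorch, for illustration only (the CUDA kernel above is the actual implementation):

```python
import torch


def compute_pplx_data_reference(
    expert_num_tokens: torch.Tensor,  # (num_local_experts,) int32 token counts
    padded_m: int,
    n: int,
    k: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    e = expert_num_tokens.numel()
    # Each local expert owns a fixed slot of padded_m rows in the workspace.
    expert_offsets = torch.arange(e, dtype=torch.int32) * padded_m
    # First grouped GEMM: (tokens_e x k) @ (k x 2n) -> problem (tokens_e, 2n, k).
    problem_sizes1 = torch.stack(
        [
            expert_num_tokens.to(torch.int32),
            torch.full((e,), 2 * n, dtype=torch.int32),
            torch.full((e,), k, dtype=torch.int32),
        ],
        dim=1,
    )
    # Second grouped GEMM: (tokens_e x n) @ (n x k) -> problem (tokens_e, k, n).
    problem_sizes2 = torch.stack(
        [
            expert_num_tokens.to(torch.int32),
            torch.full((e,), k, dtype=torch.int32),
            torch.full((e,), n, dtype=torch.int32),
        ],
        dim=1,
    )
    return expert_offsets, problem_sizes1, problem_sizes2


offs, ps1, ps2 = compute_pplx_data_reference(
    torch.tensor([3, 0, 5], dtype=torch.int32), padded_m=128, n=14336, k=4096
)
assert offs.tolist() == [0, 128, 256]
assert ps1[0].tolist() == [3, 2 * 14336, 4096]
assert ps2[2].tolist() == [5, 4096, 14336]
```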
