
Commit 1168828

Author: varun sundar rabindranath

fp16 configs and expert map support

Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>

1 parent: 847150a

File tree: 6 files changed (+406, -277 lines)


csrc/cutlass_moe/moe_data.cu

Lines changed: 16 additions & 3 deletions
@@ -46,15 +46,27 @@ __global__ void compute_expert_offsets(
 }
 
 __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
+                                  const int32_t* __restrict__ expert_offsets,
                                   int32_t* input_permutation,
                                   int32_t* output_permutation,
                                   int32_t* atomic_buffer, const int topk_length,
                                   const int topk) {
-  int expert_id = blockIdx.x;
+  int blk_expert_id = blockIdx.x;
+  int const num_experts = gridDim.x;
+  int32_t const num_tokens = expert_offsets[num_experts];
 
   for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
-    if (topk_ids[i] == expert_id) {
-      int start = atomicAdd(&atomic_buffer[expert_id], 1);
+    int const expert_id = topk_ids[i];
+    if (expert_id == -1 && blockIdx.x == 0) {
+      // output_permutation is used to re-order the moe outputs. It is
+      // used as c2 = c2[c_map], where c2 is a torch.tensor that is the
+      // output of the cutlass kernels and c_map is the output_permutation.
+      // c2 is initialized to zeros, therefore by setting the output_permutation
+      // to num_tokens, we are guaranteed to fill the moe outputs to zero
+      // for "invalid" topk_ids.
+      output_permutation[i] = num_tokens;
+    } else if (expert_id == blk_expert_id) {
+      int start = atomicAdd(&atomic_buffer[blk_expert_id], 1);
       input_permutation[start] = i / topk;
       output_permutation[i] = start;
     }
@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
       static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
       static_cast<const int32_t*>(topk_ids.data_ptr()),
+      static_cast<int32_t*>(expert_offsets.data_ptr()),
       static_cast<int32_t*>(input_permutation.data_ptr()),
       static_cast<int32_t*>(output_permutation.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
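Note on the new expert_id == -1 branch: it implements expert-map support via a gather-based zero-fill. Invalid topk slots are pointed at row num_tokens of the output, which stays all-zero because the buffer is zero-initialized and the GEMMs only write the valid rows. A minimal PyTorch sketch of that c2 = c2[c_map] re-ordering follows; the sizes and names (num_valid, hidden) are illustrative assumptions, not the actual vLLM buffers.

import torch

# Illustrative sizes (assumed, not the real buffers): 6 topk slots, of
# which 4 map to valid experts and 2 carry expert_id == -1.
num_valid, hidden = 4, 8

# c2 stands in for the cutlass output: zero-initialized, with only rows
# [0, num_valid) written, so row `num_valid` is guaranteed to stay zero.
c2 = torch.zeros(6, hidden)
c2[:num_valid] = torch.randn(num_valid, hidden)

# c_map stands in for output_permutation: valid slots point at their GEMM
# output row; invalid slots point at the untouched zero row `num_valid`.
c_map = torch.tensor([0, 2, num_valid, 1, num_valid, 3])

out = c2[c_map]  # the "c2 = c2[c_map]" re-order from the kernel comment
assert out[2].eq(0).all() and out[4].eq(0).all()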

csrc/cutlass_moe/moe_mm_c3x.cu

Lines changed: 10 additions & 54 deletions
@@ -81,13 +81,13 @@ struct sm90_8_bit_config_N8192 {
 
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
-struct sm90_16_bit_config_default {
-  // M in (16, inf)
+struct sm90_16_bit_config_M512 {
+  // M in [1, 512]
   using KernelSchedule =
       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
   using EpilogueSchedule =
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
+  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_64>;
   using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
 
   using Cutlass3xGemm =
@@ -97,46 +97,14 @@ struct sm90_16_bit_config_default {
 
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
-struct sm90_16_bit_config_M16 {
-  // M in [1, 16]
-  using KernelSchedule =
-      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
-  using EpilogueSchedule =
-      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
-  using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;
-
-  using Cutlass3xGemm =
-      cutlass_3x_moe_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                          KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_16_bit_config_K8192 {
-  // K in [8192, inf)
-  using KernelSchedule =
-      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
-  using EpilogueSchedule =
-      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
-  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
-
-  using Cutlass3xGemm =
-      cutlass_3x_moe_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                          KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_16_bit_config_N8192 {
-  // N in [8192, inf)
+struct sm90_16_bit_config_default {
+  // M in (1024, inf]
   using KernelSchedule =
       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
   using EpilogueSchedule =
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
-  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
+  using TileShape = cute::Shape<cute::_64, cute::_256, cute::_64>;
+  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
 
   using Cutlass3xGemm =
       cutlass_3x_moe_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
@@ -204,11 +172,7 @@ void run_cutlass_moe_mm_sm90_16_bit(
   TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
   TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
 
-  using Cutlass3xGemmN8192 = typename sm90_16_bit_config_N8192<
-      InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmK8192 = typename sm90_16_bit_config_K8192<
-      InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM16 = typename sm90_16_bit_config_M16<
+  using Cutlass3xGemmM512 = typename sm90_16_bit_config_M512<
      InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
   using Cutlass3xGemmDefault = typename sm90_16_bit_config_default<
       InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
@@ -217,16 +181,8 @@
   uint32_t const n = out_tensors.size(1);
   uint32_t const k = a_tensors.size(1);
 
-  if (n >= 8192) {
-    cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmN8192>(
-        out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes,
-        a_strides, b_strides, c_strides);
-  } else if (k >= 8192) {
-    cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmK8192>(
-        out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes,
-        a_strides, b_strides, c_strides);
-  } else if (m <= 16) {
-    cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmM16>(
+  if (m <= 512) {
+    cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmM512>(
        out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes,
        a_strides, b_strides, c_strides);
   } else {
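The dispatch above collapses three shape-specialized fp16 configs (N8192, K8192, M16) into a single M-based cutoff. A hedged sketch of the resulting selection logic, with select_16bit_config as a made-up illustrative helper name:

def select_16bit_config(m: int) -> str:
    # Only the batched M dimension picks the tile config now; the n >= 8192
    # and k >= 8192 branches from the old heuristic are gone.
    if m <= 512:
        # sm90_16_bit_config_M512: TileShape 64x128x64, ClusterShape 1x2x1
        return "Cutlass3xGemmM512"
    # sm90_16_bit_config_default: TileShape 64x256x64, ClusterShape 1x1x1
    return "Cutlass3xGemmDefault"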
