csrc/cutlass_moe/moe_data.cu (16 additions, 3 deletions)

@@ -46,15 +46,27 @@ __global__ void compute_expert_offsets(
 }
 
 __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
+                                  const int32_t* __restrict__ expert_offsets,
                                   int32_t* input_permutation,
                                   int32_t* output_permutation,
                                   int32_t* atomic_buffer, const int topk_length,
                                   const int topk) {
-  int expert_id = blockIdx.x;
+  int blk_expert_id = blockIdx.x;
+  int const num_experts = gridDim.x;
+  int32_t const num_tokens = expert_offsets[num_experts];
 
   for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
-    if (topk_ids[i] == expert_id) {
-      int start = atomicAdd(&atomic_buffer[expert_id], 1);
+    int const expert_id = topk_ids[i];
+    if (expert_id == -1 && blockIdx.x == 0) {
+      // output_permutation is used to re-order the moe outputs. It is
+      // used as c2 = c2[c_map], where c2 is a torch.tensor that is the
+      // output of the cutlass kernels and c_map is the output_permutation.
+      // c2 is initialized to zeros, therefore by setting the output_permutation
+      // to num_tokens, we are guaranteed to fill the moe outputs to zero
+      // for "invalid" topk_ids.
+      output_permutation[i] = num_tokens;
+    } else if (expert_id == blk_expert_id) {
+      int start = atomicAdd(&atomic_buffer[blk_expert_id], 1);
       input_permutation[start] = i / topk;
       output_permutation[i] = start;
     }
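Each block owns one expert (blk_expert_id), and the atomicAdd hands out consecutive slots inside that expert's segment of the sorted order; block 0 additionally claims the invalid (-1) entries so they are written exactly once. As a reading aid, here is a single-threaded C++ sketch of the same bookkeeping. It is a model of the kernel, not vLLM code, and it assumes atomic_buffer is pre-seeded with each expert's start offset (the role `next` plays below):

// Single-threaded sketch of compute_arg_sorts' bookkeeping (a model, not the
// vLLM kernel). Assumes the kernel's atomic_buffer starts out holding each
// expert's start offset, so atomicAdd hands out consecutive slots within that
// expert's segment; `next` plays that role here.
#include <cstddef>
#include <cstdint>
#include <vector>

void arg_sorts_reference(const std::vector<int32_t>& topk_ids,       // flattened [num_input_tokens * topk]
                         const std::vector<int32_t>& expert_offsets, // [num_experts + 1]
                         std::vector<int32_t>& input_permutation,    // one slot per valid pair
                         std::vector<int32_t>& output_permutation,   // [topk_ids.size()]
                         int topk) {
  int const num_experts = static_cast<int>(expert_offsets.size()) - 1;
  int32_t const num_tokens = expert_offsets[num_experts];  // total valid (token, expert) pairs
  std::vector<int32_t> next(expert_offsets.begin(), expert_offsets.end() - 1);

  for (std::size_t i = 0; i < topk_ids.size(); ++i) {
    int32_t const expert_id = topk_ids[i];
    if (expert_id == -1) {
      // Invalid slot: point at the zero row past the valid region so the
      // later c2 = c2[c_map] gather produces zeros for it.
      output_permutation[i] = num_tokens;
    } else {
      int32_t const start = next[expert_id]++;                    // the kernel's atomicAdd
      input_permutation[start] = static_cast<int32_t>(i) / topk;  // source token row to gather
      output_permutation[i] = start;                              // where this pair's output lands
    }
  }
}

input_permutation thus gathers source token rows into expert-sorted order for the grouped GEMM, while output_permutation records where each (token, expert) pair's result row lands.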
@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
       static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
       static_cast<const int32_t*>(topk_ids.data_ptr()),
+      static_cast<int32_t*>(expert_offsets.data_ptr()),
       static_cast<int32_t*>(input_permutation.data_ptr()),
       static_cast<int32_t*>(output_permutation.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
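The caller now threads expert_offsets through to the kernel so it can read num_tokens = expert_offsets[num_experts]. The zero-fill trick in the kernel comment then works as follows: the grouped GEMM writes only rows [0, num_tokens) of its zero-initialized output, so as long as c2 has at least num_tokens + 1 rows (which holds whenever any topk slot is -1, assuming c2 is allocated with one row per topk slot), row num_tokens is guaranteed to stay zero. A minimal C++ stand-in for the torch-level c2 = c2[c_map] gather, with plain vectors instead of tensors and a hypothetical gather_rows name:

// Hypothetical stand-in for the torch gather c2 = c2[c_map]. c2 must be
// zero-initialized and have at least num_tokens + 1 rows so that row
// num_tokens is a guaranteed-zero row for invalid topk slots to point at.
#include <cstdint>
#include <vector>

std::vector<std::vector<float>> gather_rows(
    const std::vector<std::vector<float>>& c2,  // grouped-GEMM output, zero-initialized
    const std::vector<int32_t>& c_map) {        // output_permutation
  std::vector<std::vector<float>> out;
  out.reserve(c_map.size());
  for (int32_t const row : c_map) {
    out.push_back(c2[row]);  // invalid slots select the all-zero row
  }
  return out;
}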
csrc/cutlass_moe/moe_mm_c3x.cu (10 additions, 54 deletions)

@@ -81,13 +81,13 @@ struct sm90_8_bit_config_N8192 {

 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
-struct sm90_16_bit_config_default {
-  // M in (16, inf)
+struct sm90_16_bit_config_M512 {
+  // M in [1, 512]
   using KernelSchedule =
       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
   using EpilogueSchedule =
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
+  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_64>;
   using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
 
   using Cutlass3xGemm =
@@ -97,46 +97,14 @@ struct sm90_16_bit_config_default {
 
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
-struct sm90_16_bit_config_M16 {
-  // M in [1, 16]
-  using KernelSchedule =
-      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
-  using EpilogueSchedule =
-      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
-  using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;
-
-  using Cutlass3xGemm =
-      cutlass_3x_moe_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                          KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_16_bit_config_K8192 {
-  // K in [8192, inf)
-  using KernelSchedule =
-      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
-  using EpilogueSchedule =
-      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
-  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
-
-  using Cutlass3xGemm =
-      cutlass_3x_moe_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                          KernelSchedule, EpilogueSchedule>;
-};
-
-template <typename InType, typename OutType,
-          template <typename, typename, typename> typename Epilogue>
-struct sm90_16_bit_config_N8192 {
-  // N in [8192, inf)
+struct sm90_16_bit_config_default {
+  // M in (512, inf)
   using KernelSchedule =
       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
   using EpilogueSchedule =
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
-  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
+  using TileShape = cute::Shape<cute::_64, cute::_256, cute::_64>;
+  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
 
   using Cutlass3xGemm =
       cutlass_3x_moe_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
@@ -204,11 +172,7 @@ void run_cutlass_moe_mm_sm90_16_bit(
   TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
   TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
 
-  using Cutlass3xGemmN8192 = typename sm90_16_bit_config_N8192<
-      InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmK8192 = typename sm90_16_bit_config_K8192<
-      InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
-  using Cutlass3xGemmM16 = typename sm90_16_bit_config_M16<
+  using Cutlass3xGemmM512 = typename sm90_16_bit_config_M512<
       InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
   using Cutlass3xGemmDefault = typename sm90_16_bit_config_default<
       InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm;
@@ -217,16 +181,8 @@
   uint32_t const n = out_tensors.size(1);
   uint32_t const k = a_tensors.size(1);

Review comment on lines 181 to 182: "these vars look unused now, it's ok to remove them"

-  if (n >= 8192) {
-    cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmN8192>(
-        out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes,
-        a_strides, b_strides, c_strides);
-  } else if (k >= 8192) {
-    cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmK8192>(
-        out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes,
-        a_strides, b_strides, c_strides);
-  } else if (m <= 16) {
-    cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmM16>(
+  if (m <= 512) {
+    cutlass_moe_gemm_caller_16_bit<Cutlass3xGemmM512>(
         out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes,
         a_strides, b_strides, c_strides);
   } else {
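Taken together, the changes to this file collapse the 16-bit heuristic from four shape buckets (N8192, K8192, M16, default) to two buckets keyed on M alone; the else branch (folded above) dispatches Cutlass3xGemmDefault. A standalone C++ sketch of the resulting selection logic follows. MoeGemmConfig and pick_16_bit_config are illustrative names rather than vLLM symbols, and the real code selects among compile-time template instantiations instead of returning values:

// Illustrative mirror of the new two-bucket heuristic: only M (the number of
// rows routed through the experts) selects the config; N and K no longer
// participate. Tile values are the per-CTA (M, N, K) tile and the
// thread-block cluster shape, copied from the structs above.
#include <cstdint>

struct MoeGemmConfig {
  int tile_m, tile_n, tile_k;
  int cluster_x, cluster_y, cluster_z;
};

MoeGemmConfig pick_16_bit_config(uint32_t m) {
  if (m <= 512) {
    // sm90_16_bit_config_M512: M in [1, 512]
    return {64, 128, 64, 1, 2, 1};
  }
  // sm90_16_bit_config_default: M in (512, inf)
  return {64, 256, 64, 1, 1, 1};
}

One observable consequence: since only m is consulted, the n and k locals flagged in the review comment above are indeed dead and can be removed.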