Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
Expand All @@ -595,6 +595,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
"if you intend on running FP8 quantized MoE models on Blackwell.")
else()
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()

# moe_data.cu is used by all CUTLASS MoE kernels.
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
Expand Down
13 changes: 9 additions & 4 deletions csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ using ProblemShape =
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;

using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;

using LayoutA = cutlass::layout::RowMajor;
Expand All @@ -33,7 +32,7 @@ using LayoutD_Transpose =
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;

template <typename ElementAB_, typename ElementC_,
template <typename ElementAB_, typename ElementC_, typename ArchTag_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule, bool swap_ab_ = false>
Expand All @@ -43,6 +42,7 @@ struct cutlass_3x_group_gemm {
using ElementC = void;
using ElementD = ElementC_;
using ElementAccumulator = float;
using ArchTag = ArchTag_;

using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;

Expand Down Expand Up @@ -77,7 +77,7 @@ struct cutlass_3x_group_gemm {
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
Stages, KernelSchedule>::CollectiveOp>;

using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;

struct GemmKernel : public KernelType {};
Expand Down Expand Up @@ -156,9 +156,14 @@ void cutlass_group_gemm_caller(
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())};

int device_id = a_tensors.device().index();
static const cutlass::KernelHardwareInfo hw_info{
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
device_id)};

typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
epilogue_args};
epilogue_args, hw_info};

using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
Expand Down
140 changes: 140 additions & 0 deletions csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

using namespace cute;

namespace {

template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm100;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
// M in [1,64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm100;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};

template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_N8192 {
// N in [8192, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm100;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");

using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmM64 = typename sm100_fp8_config_M64<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;

uint32_t const m = a_tensors.size(0);
uint32_t const n = out_tensors.size(1);

if (m <= 64) {
cutlass_group_gemm_caller<Cutlass3xGemmM64>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
} // namespace

void dispatch_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}

void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ struct sm90_fp8_config_default {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
using ArchTag = cutlass::arch::Sm90;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
Expand All @@ -38,10 +39,12 @@ struct sm90_fp8_config_M4 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm90;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, true>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};

template <typename InType, typename OutType,
Expand All @@ -55,10 +58,12 @@ struct sm90_fp8_config_M64 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm90;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, true>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};

template <typename InType, typename OutType,
Expand All @@ -72,10 +77,11 @@ struct sm90_fp8_config_K8192 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
using ArchTag = cutlass::arch::Sm90;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
Expand All @@ -89,10 +95,11 @@ struct sm90_fp8_config_N8192 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
using ArchTag = cutlass::arch::Sm90;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
Expand All @@ -112,9 +119,6 @@ void run_cutlass_moe_mm_sm90(
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");

TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);

using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
Expand Down
2 changes: 1 addition & 1 deletion csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
}
45 changes: 34 additions & 11 deletions csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90(

#endif

#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
Expand Down Expand Up @@ -130,22 +140,25 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// and at least SM90 (Hopper)

#if defined CUDA_VERSION
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 100) {
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} else if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
}
#endif

return false;
}

bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS grouped FP8 kernels need at least CUDA 12.3
// and SM90 (Hopper)
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
// or CUDA 12.8 and SM100 (Blackwell)

#if defined CUDA_VERSION
if (cuda_device_capability == 90) {
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
}
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12030;
}
#endif
Expand Down Expand Up @@ -234,16 +247,26 @@ void cutlass_moe_mm(
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
if (version_num >= 90) {
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90");
". Required capability: 90 or 100");
}

void get_cutlass_moe_mm_data(
Expand Down
Loading