diff --git a/CMakeLists.txt b/CMakeLists.txt index edc64f87730a..10f8667db649 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -577,7 +577,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -595,6 +595,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " + "if you intend on running FP8 quantized MoE models on Blackwell.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + # moe_data.cu is used by all CUTLASS MoE kernels. cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index 3225378a6ca0..659941de182e 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -18,7 +18,6 @@ using ProblemShape = cutlass::gemm::GroupProblemShape>; using ElementAccumulator = float; -using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; using LayoutA = cutlass::layout::RowMajor; @@ -33,7 +32,7 @@ using LayoutD_Transpose = using LayoutC = LayoutD; using LayoutC_Transpose = LayoutD_Transpose; -template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule, bool swap_ab_ = false> @@ -43,6 +42,7 @@ struct cutlass_3x_group_gemm { using ElementC = void; using ElementD = ElementC_; using ElementAccumulator = float; + using ArchTag = ArchTag_; using Epilogue = Epilogue_; @@ -77,7 +77,7 @@ struct cutlass_3x_group_gemm { LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp>; - using KernelType = enable_sm90_only>; struct GemmKernel : public KernelType {}; @@ -156,9 +156,14 @@ void cutlass_group_gemm_caller( static_cast(out_ptrs.data_ptr()), static_cast(c_strides.data_ptr())}; + int device_id = a_tensors.device().index(); + static const cutlass::KernelHardwareInfo hw_info{ + device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + device_id)}; + typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, - epilogue_args}; + epilogue_args, hw_info}; using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu new file mode 100644 index 000000000000..641e5997f0fd --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu @@ -0,0 +1,140 @@ +#include + +#include +#include + +#include "cutlass/cutlass.h" +#include "grouped_mm_c3x.cuh" + +using namespace cute; + +namespace { + +template typename Epilogue> +struct sm100_fp8_config_default { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm100_fp8_config_M64 { + // M in [1,64] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm100_fp8_config_N8192 { + // N in [8192, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template +void run_cutlass_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { + TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); + + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, + "A tensors must be of type float8_e4m3fn."); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, + "B tensors must be of type float8_e4m3fn."); + + using Cutlass3xGemmDefault = typename sm100_fp8_config_default< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmM64 = typename sm100_fp8_config_M64< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + + uint32_t const m = a_tensors.size(0); + uint32_t const n = out_tensors.size(1); + + if (m <= 64) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else if (n >= 8192) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } +} +} // namespace + +void dispatch_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { + if (out_tensors.dtype() == torch::kBFloat16) { + run_cutlass_moe_mm_sm100( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else { + run_cutlass_moe_mm_sm100( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } +} + +void cutlass_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { + dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); +} diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu similarity index 88% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu rename to csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu index b024482208d3..8f21623b52fa 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu @@ -21,10 +21,11 @@ struct sm90_fp8_config_default { cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template ; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template ; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template ; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template ; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template @@ -112,9 +119,6 @@ void run_cutlass_moe_mm_sm90( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, "B tensors must be of type float8_e4m3fn."); - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 623c9a2f096b..993c30c48c84 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -190,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, static_cast(problem_sizes2.data_ptr()), static_cast(expert_num_tokens.data_ptr()), padded_m, n, k); -} +} \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 31b60488dfb7..106bacb4883c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90( #endif +#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 +void cutlass_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch); +#endif + #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, @@ -130,10 +140,10 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { // and at least SM90 (Hopper) #if defined CUDA_VERSION - if (cuda_device_capability >= 90 && cuda_device_capability < 100) { - return CUDA_VERSION >= 12000; - } else if (cuda_device_capability >= 100) { + if (cuda_device_capability >= 100) { return CUDA_VERSION >= 12080; + } else if (cuda_device_capability >= 90) { + return CUDA_VERSION >= 12000; } #endif @@ -141,11 +151,14 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { } bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { - // CUTLASS grouped FP8 kernels need at least CUDA 12.3 - // and SM90 (Hopper) + // CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper) + // or CUDA 12.8 and SM100 (Blackwell) #if defined CUDA_VERSION - if (cuda_device_capability == 90) { + if (cuda_device_capability >= 100) { + return CUDA_VERSION >= 12080; + } + if (cuda_device_capability >= 90) { return CUDA_VERSION >= 12030; } #endif @@ -234,16 +247,26 @@ void cutlass_moe_mm( torch::Tensor const& b_strides, torch::Tensor const& c_strides, bool per_act_token, bool per_out_ch) { int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 + if (version_num >= 100) { + cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); + return; + } +#endif #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 - cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes, a_strides, b_strides, - c_strides, per_act_token, per_out_ch); - return; + if (version_num >= 90) { + cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); + return; + } #endif TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, - ". Required capability: 90"); + ". Required capability: 90 or 100"); } void get_cutlass_moe_mm_data( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index e7f65d13181d..90b45e32a688 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -332,6 +332,12 @@ def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, return (self._check_scheme_supported(90, error=False, match_exact=True) and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported( + 100, error=False, match_exact=True) + and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: # Confirm weights quantized. diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 1a31410c3385..d71cae0288c7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -83,7 +83,8 @@ def get_moe_method( return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A4MoeMethod() - elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant): + elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) + or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) @@ -740,6 +741,8 @@ def __init__( self.topk_indices_dtype = None self.fused_experts = None # type: ignore self.disable_expert_map = False + self.is_fp8_w8a8_sm100 = self.quant_config._is_fp8_w8a8_sm100( + self.weight_quant, self.input_quant) def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -951,7 +954,29 @@ def apply( per_act_token = ( self.input_quant.strategy == QuantizationStrategy.TOKEN) - + per_channel_quant = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + # Triton fused_experts is faster in small batch sizes on SM100. + # Fall back to fused_experts in small batch sizes. + if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) if self.fused_experts is None: # If no modular kernel is provided, use cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import (