From 1eacd3d4a79a94964406c0e0faaaa6f5b9e9f14a Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 3 Jul 2025 17:29:04 +0000 Subject: [PATCH 01/16] [feat]: add SM100 support for cutlass groupGEMM Signed-off-by: Duncan Moss --- CMakeLists.txt | 22 ++++- .../cutlass_w8a8/moe/grouped_mm_c3x.cuh | 14 ++- .../cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu | 90 +++++++++++++++++++ ...ouped_mm_c3x.cu => grouped_mm_c3x_sm90.cu} | 15 ++-- .../cutlass_w8a8/scaled_mm_entry.cu | 27 +++++- .../compressed_tensors/compressed_tensors.py | 10 ++- .../compressed_tensors_moe.py | 3 +- 7 files changed, 161 insertions(+), 20 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu rename csrc/quantization/cutlass_w8a8/moe/{grouped_mm_c3x.cu => grouped_mm_c3x_sm90.cu} (93%) diff --git a/CMakeLists.txt b/CMakeLists.txt index edc64f87730a..10f8667db649 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -577,7 +577,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -595,6 +595,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " + "if you intend on running FP8 quantized MoE models on Blackwell.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + # moe_data.cu is used by all CUTLASS MoE kernels. 
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index 3225378a6ca0..6657549df511 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -18,7 +18,6 @@ using ProblemShape = cutlass::gemm::GroupProblemShape>; using ElementAccumulator = float; -using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; using LayoutA = cutlass::layout::RowMajor; @@ -33,7 +32,7 @@ using LayoutD_Transpose = using LayoutC = LayoutD; using LayoutC_Transpose = LayoutD_Transpose; -template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule, bool swap_ab_ = false> @@ -43,6 +42,7 @@ struct cutlass_3x_group_gemm { using ElementC = void; using ElementD = ElementC_; using ElementAccumulator = float; + using ArchTag = ArchTag_; using Epilogue = Epilogue_; @@ -77,7 +77,7 @@ struct cutlass_3x_group_gemm { LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp>; - using KernelType = enable_sm90_only>; struct GemmKernel : public KernelType {}; @@ -156,9 +156,15 @@ void cutlass_group_gemm_caller( static_cast(out_ptrs.data_ptr()), static_cast(c_strides.data_ptr())}; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = a_tensors.device().index(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, - epilogue_args}; + epilogue_args, hw_info}; using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu new file mode 100644 index 000000000000..3f8902d2b8d7 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu @@ -0,0 +1,90 @@ +#include + +#include +#include + +#include "cutlass/cutlass.h" +#include "grouped_mm_c3x.cuh" + +using namespace cute; + +namespace { + +template typename Epilogue> +struct sm100_fp8_config_default { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template +void run_cutlass_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { + TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); + + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, + "A tensors must be of type float8_e4m3fn."); + TORCH_CHECK(b_tensors.dtype() == 
torch::kFloat8_e4m3fn, + "B tensors must be of type float8_e4m3fn."); + + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = typename sm100_fp8_config_default< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); +} + +void dispatch_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { + if (out_tensors.dtype() == torch::kBFloat16) { + run_cutlass_moe_mm_sm100( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else { + run_cutlass_moe_mm_sm100( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } +} + +} // namespace + +void cutlass_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { + dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); +} diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu similarity index 93% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu rename to csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu index b024482208d3..b182dd249e08 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu @@ -21,10 +21,11 @@ struct sm90_fp8_config_default { cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template ; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template ; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 31b60488dfb7..0a07caf26842 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90( #endif +#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 +void cutlass_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + 
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch); +#endif + #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, @@ -141,13 +151,16 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { } bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { - // CUTLASS grouped FP8 kernels need at least CUDA 12.3 - // and SM90 (Hopper) + // CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper) + // or CUDA 12.8 and SM100 (Blackwell) #if defined CUDA_VERSION - if (cuda_device_capability == 90) { + if (cuda_device_capability >= 90 && cuda_device_capability < 100) { return CUDA_VERSION >= 12030; } + if (cuda_device_capability >= 100) { + return CUDA_VERSION >= 12080; + } #endif return false; @@ -234,6 +247,12 @@ void cutlass_moe_mm( torch::Tensor const& b_strides, torch::Tensor const& c_strides, bool per_act_token, bool per_out_ch) { int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 + cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); + return; +#endif #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, @@ -243,7 +262,7 @@ void cutlass_moe_mm( TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, - ". Required capability: 90"); + ". 
Required capability: 90 or 100"); } void get_cutlass_moe_mm_data( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index e7f65d13181d..cac8104fd8e9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -327,10 +327,12 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation - def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: - return (self._check_scheme_supported(90, error=False, match_exact=True) - and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w8a8_sm90_or_sm100(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return ( + self._check_scheme_supported(90, error=False, match_exact=True) + or self._check_scheme_supported(100, error=False, match_exact=True) + and self._is_fp8_w8a8(weight_quant, input_quant)) def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 1a31410c3385..a41af93d40a1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -83,7 +83,8 @@ def get_moe_method( return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A4MoeMethod() - elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant): + elif quant_config._is_fp8_w8a8_sm90_or_sm100(weight_quant, + input_quant): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) From 32a7d2163caeb71084d9cba2f3eae32cd7c9c263 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 3 Jul 2025 18:39:41 +0000 Subject: [PATCH 02/16] remove redundant checks Signed-off-by: Duncan Moss --- csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu | 3 --- csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu | 3 --- 2 files changed, 6 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu index 3f8902d2b8d7..8fcf35874664 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu @@ -43,9 +43,6 @@ void run_cutlass_moe_mm_sm100( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, "B tensors must be of type float8_e4m3fn."); - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - using Cutlass3xGemmDefault = typename sm100_fp8_config_default< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu index b182dd249e08..a20c5876661a 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu @@ -115,9 +115,6 @@ 
void run_cutlass_moe_mm_sm90( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, "B tensors must be of type float8_e4m3fn."); - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< From f0d819b6d027bbf25fba44efc039c5cde1d7834a Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 3 Jul 2025 19:52:30 +0000 Subject: [PATCH 03/16] separate sm90 and sm100 _is_fp8_w8a8 Signed-off-by: Duncan Moss --- .../compressed_tensors/compressed_tensors.py | 16 ++++++++++------ .../compressed_tensors/compressed_tensors_moe.py | 4 ++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index cac8104fd8e9..90b45e32a688 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -327,12 +327,16 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation - def _is_fp8_w8a8_sm90_or_sm100(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: - return ( - self._check_scheme_supported(90, error=False, match_exact=True) - or self._check_scheme_supported(100, error=False, match_exact=True) - and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False, match_exact=True) + and self._is_fp8_w8a8(weight_quant, input_quant)) + + def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported( + 100, error=False, match_exact=True) + and self._is_fp8_w8a8(weight_quant, input_quant)) def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a41af93d40a1..a0cf09a5d159 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -83,8 +83,8 @@ def get_moe_method( return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A4MoeMethod() - elif quant_config._is_fp8_w8a8_sm90_or_sm100(weight_quant, - input_quant): + elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) + or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) From 3ae294b9a098872f362298178eef88a32cfbcc8a Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 3 Jul 2025 21:00:38 +0000 Subject: [PATCH 04/16] further updates Signed-off-by: Duncan Moss --- csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git 
a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index 6657549df511..659941de182e 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -156,11 +156,10 @@ void cutlass_group_gemm_caller( static_cast(out_ptrs.data_ptr()), static_cast(c_strides.data_ptr())}; - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = a_tensors.device().index(); - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); + int device_id = a_tensors.device().index(); + static const cutlass::KernelHardwareInfo hw_info{ + device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + device_id)}; typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, From a7ac8820e684c48707a9cd9e9614ce7df4106239 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Mon, 7 Jul 2025 20:40:41 +0000 Subject: [PATCH 05/16] nit buster Signed-off-by: Duncan Moss --- csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 0a07caf26842..f9cf014ffaed 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -140,10 +140,10 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { // and at least SM90 (Hopper) #if defined CUDA_VERSION - if (cuda_device_capability >= 90 && cuda_device_capability < 100) { - return CUDA_VERSION >= 12000; - } else if (cuda_device_capability >= 100) { + if (cuda_device_capability >= 100) { return CUDA_VERSION >= 12080; + } else if (cuda_device_capability >= 90) { + return CUDA_VERSION >= 12000; } #endif @@ -155,12 +155,12 @@ bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { // or CUDA 12.8 and SM100 (Blackwell) #if defined CUDA_VERSION - if (cuda_device_capability >= 90 && cuda_device_capability < 100) { - return CUDA_VERSION >= 12030; - } if (cuda_device_capability >= 100) { return CUDA_VERSION >= 12080; } + if (cuda_device_capability >= 90) { + return CUDA_VERSION >= 12030; + } #endif return false; From 271fbad914d507e692b7a7034866f13ffb840d12 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Tue, 8 Jul 2025 19:58:57 +0000 Subject: [PATCH 06/16] add version_num check to cutlass_moe_mm function Signed-off-by: Duncan Moss --- .../cutlass_w8a8/scaled_mm_entry.cu | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index f9cf014ffaed..106bacb4883c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -248,16 +248,20 @@ void cutlass_moe_mm( bool per_act_token, bool per_out_ch) { int32_t version_num = get_sm_version_num(); #if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 - cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes, a_strides, b_strides, - c_strides, per_act_token, per_out_ch); - return; + if (version_num >= 100) { + cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); + return; + } #endif #if 
defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 - cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes, a_strides, b_strides, - c_strides, per_act_token, per_out_ch); - return; + if (version_num >= 90) { + cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); + return; + } #endif TORCH_CHECK_NOT_IMPLEMENTED( false, From af06e98335d06110099cb7ebc38f0be06c3ab1aa Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Mon, 14 Jul 2025 18:21:55 -0700 Subject: [PATCH 07/16] add cutlass tuning Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../cutlass_w8a8/moe/grouped_mm_c3x.cuh | 47 +++++++++- .../cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu | 91 ++++++++++++++++++- 2 files changed, 128 insertions(+), 10 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index 659941de182e..d49c60914494 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -32,6 +32,13 @@ using LayoutD_Transpose = using LayoutC = LayoutD; using LayoutC_Transpose = LayoutD_Transpose; +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; +using LayoutC_Transpose = + typename cutlass::layout::LayoutTranspose::type; + template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, @@ -92,7 +99,6 @@ void cutlass_group_gemm_caller( torch::Tensor const& b_strides, torch::Tensor const& c_strides, bool per_act_token, bool per_out_ch) { static constexpr bool swap_ab = Gemm::swap_ab; - using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; @@ -114,13 +120,44 @@ void cutlass_group_gemm_caller( out_tensors, a_scales, b_scales); using GemmKernel = typename Gemm::GemmKernel; - using StrideA = Stride, Int<0>>; - using StrideB = Stride, Int<0>>; - using StrideC = typename GemmKernel::InternalStrideC; + + // Define stride types based on swap_ab + using StrideA = cute::conditional_t, Int<0>>, // B->A: ColumnMajor transposed + Stride, Int<0>> // A: RowMajor + >; + using StrideB = cute::conditional_t, Int<0>>, // A->B: RowMajor transposed + Stride, Int<0>> // B: ColumnMajor + >; + using StrideC = cute::conditional_t; + + // Handle problem shape for swapped case + torch::Tensor effective_problem_sizes = problem_sizes; + if constexpr (swap_ab) { + // When swapping A and B, problem dimensions need to be adjusted + // Original: (M, N, K) -> Swapped: (N, M, K) + effective_problem_sizes = torch::empty_like(problem_sizes); + auto* orig_sizes = static_cast(problem_sizes.data_ptr()); + auto* new_sizes = static_cast(effective_problem_sizes.data_ptr()); + + for (int i = 0; i < num_experts; ++i) { + int32_t m = orig_sizes[i * 3 + 0]; + int32_t n = orig_sizes[i * 3 + 1]; + int32_t k = orig_sizes[i * 3 + 2]; + + // Swap M and N for the kernel + new_sizes[i * 3 + 0] = n; // new M = original N + new_sizes[i * 3 + 1] = m; // new N = original M + new_sizes[i * 3 + 2] = k; // K unchanged + } + } ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = static_cast( - problem_sizes.data_ptr()); + effective_problem_sizes.data_ptr()); ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; 
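For reference, the swap-AB path introduced above runs each grouped GEMM on the transposed problem: D = A(MxK) * B(KxN) is computed as D^T = B^T(NxK) * A^T(KxM), so the small per-expert token count M moves to the N side of the tile. A minimal sketch of the per-expert shape bookkeeping, with hypothetical names (not part of this diff):

#include <cstdint>

// Hedged sketch: how swap-AB remaps one expert's (M, N, K) problem.
// The operand and scale pointers are exchanged to match, and the kernel
// writes D^T, which aliases row-major D when stored column-major.
struct ProblemMNK {
  int32_t m, n, k;
};

inline ProblemMNK swap_ab_problem(const ProblemMNK& p) {
  // M and N trade places; the reduction dimension K is unchanged.
  return ProblemMNK{p.n, p.m, p.k};
}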
typename GemmKernel::MainloopArguments mainloop_args; diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu index 8fcf35874664..02fd39856099 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu @@ -13,6 +13,24 @@ namespace { template typename Epilogue> struct sm100_fp8_config_default { + // M in [1, 16) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm100_fp8_config_M32 { + // M in [32,inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; @@ -26,6 +44,40 @@ struct sm100_fp8_config_default { ClusterShape, KernelSchedule, EpilogueSchedule>; }; +template typename Epilogue> +struct sm100_fp8_config_N4096 { + // N in [4096, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm100_fp8_config_SwapAB { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + template void run_cutlass_moe_mm_sm100( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, @@ -45,11 +97,40 @@ void run_cutlass_moe_mm_sm100( using Cutlass3xGemmDefault = typename sm100_fp8_config_default< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides, per_act_token, - per_out_ch); + using Cutlass3xGemmN4096 = typename sm100_fp8_config_N4096< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmM32 = typename sm100_fp8_config_M32< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmSwapAB = typename sm100_fp8_config_SwapAB< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + uint32_t const m = a_tensors.size(0); + uint32_t const n = out_tensors.size(1); + uint32_t const k = a_tensors.size(1); + bool swap_ab = m < 32; + if (swap_ab) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } + else{ + if (n >= 4096) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else if (m >= 32) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, 
b_strides, c_strides, per_act_token, + per_out_ch); + } else { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } + } } void dispatch_moe_mm_sm100( From d221ffd552bb9c18ed0d4cffc010151a94373c5b Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Mon, 14 Jul 2025 20:59:36 -0700 Subject: [PATCH 08/16] update swap AB Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../cutlass_w8a8/moe/get_group_starts.cuh | 32 ++++++--- .../cutlass_w8a8/moe/grouped_mm_c3x.cuh | 72 +++++++++++-------- 2 files changed, 65 insertions(+), 39 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index 6c6e89790847..251ca5c4923c 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -16,18 +16,32 @@ __global__ void get_group_gemm_starts( ElementAB* b_base_as_int, ElementC* out_base_as_int, ElementAccumulator* a_scales_base_as_int, ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k, - bool per_act_token, bool per_out_ch) { + bool per_act_token, bool per_out_ch, bool swap_ab = false) { int expert_id = threadIdx.x; int64_t expert_offset = expert_offsets[expert_id]; - a_offsets[expert_id] = a_base_as_int + expert_offset * k; - b_offsets[expert_id] = b_base_as_int + expert_id * k * n; + if (swap_ab) { + + a_offsets[expert_id] = b_base_as_int + expert_id * k * n; // First operand gets B data + b_offsets[expert_id] = a_base_as_int + expert_offset * k; // Second operand gets A data + + // Swap scale pointers accordingly + a_scales_offsets[expert_id] = + b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); + b_scales_offsets[expert_id] = + a_scales_base_as_int + (per_act_token ? expert_offset : 0); + } else { + // Normal case + a_offsets[expert_id] = a_base_as_int + expert_offset * k; + b_offsets[expert_id] = b_base_as_int + expert_id * k * n; + + a_scales_offsets[expert_id] = + a_scales_base_as_int + (per_act_token ? expert_offset : 0); + b_scales_offsets[expert_id] = + b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); + } out_offsets[expert_id] = out_base_as_int + expert_offset * n; - a_scales_offsets[expert_id] = - a_scales_base_as_int + (per_act_token ? expert_offset : 0); - b_scales_offsets[expert_id] = - b_scales_base_as_int + (per_out_ch ? 
n * expert_id : expert_id); } #define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ @@ -45,7 +59,7 @@ __global__ void get_group_gemm_starts( static_cast(out_tensors.data_ptr()), \ static_cast(a_scales.data_ptr()), \ static_cast(b_scales.data_ptr()), out_tensors.size(1), \ - a_tensors.size(1), per_act_token, per_out_ch); \ + a_tensors.size(1), per_act_token, per_out_ch, swap_ab); \ } namespace { @@ -56,7 +70,7 @@ void run_get_group_gemm_starts( torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor& out_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { + torch::Tensor const& b_scales, bool swap_ab = false) { TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index d49c60914494..fa02527de5ec 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -12,6 +12,26 @@ using namespace cute; +// Kernel to swap M and N dimensions in problem_sizes for swap_ab +__global__ void swap_problem_sizes_kernel( + const int32_t* input_problem_sizes, + int32_t* output_problem_sizes, + int num_experts) { + int expert_id = blockIdx.x * blockDim.x + threadIdx.x; + if (expert_id < num_experts) { + // Each expert has 3 values: M, N, K + int base_idx = expert_id * 3; + int32_t M = input_problem_sizes[base_idx + 0]; + int32_t N = input_problem_sizes[base_idx + 1]; + int32_t K = input_problem_sizes[base_idx + 2]; + + // Swap M and N for swap_ab + output_problem_sizes[base_idx + 0] = N; // New M = old N + output_problem_sizes[base_idx + 1] = M; // New N = old M + output_problem_sizes[base_idx + 2] = K; // K remains the same + } +} + namespace { using ProblemShape = @@ -117,42 +137,34 @@ void cutlass_group_gemm_caller( run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors, - out_tensors, a_scales, b_scales); + out_tensors, a_scales, b_scales, swap_ab); using GemmKernel = typename Gemm::GemmKernel; - // Define stride types based on swap_ab - using StrideA = cute::conditional_t, Int<0>>, // B->A: ColumnMajor transposed - Stride, Int<0>> // A: RowMajor - >; - using StrideB = cute::conditional_t, Int<0>>, // A->B: RowMajor transposed - Stride, Int<0>> // B: ColumnMajor - >; - using StrideC = cute::conditional_t; - - // Handle problem shape for swapped case - torch::Tensor effective_problem_sizes = problem_sizes; + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename GemmKernel::InternalStrideC; + + // Create effective problem sizes - may need to swap M and N for swap_ab + torch::Tensor effective_problem_sizes; if constexpr (swap_ab) { - // When swapping A and B, problem dimensions need to be adjusted - // Original: (M, N, K) -> Swapped: (N, M, K) + // For swap_ab, we need to swap M and N dimensions in problem sizes effective_problem_sizes = torch::empty_like(problem_sizes); - auto* orig_sizes = static_cast(problem_sizes.data_ptr()); - auto* new_sizes = static_cast(effective_problem_sizes.data_ptr()); + // Launch kernel to swap M and N dimensions + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + int num_experts = static_cast(expert_offsets.size(0)); + int block_size = 256; + int 
grid_size = (num_experts + block_size - 1) / block_size; + + swap_problem_sizes_kernel<<>>( + static_cast(problem_sizes.data_ptr()), + static_cast(effective_problem_sizes.data_ptr()), + num_experts); - for (int i = 0; i < num_experts; ++i) { - int32_t m = orig_sizes[i * 3 + 0]; - int32_t n = orig_sizes[i * 3 + 1]; - int32_t k = orig_sizes[i * 3 + 2]; - - // Swap M and N for the kernel - new_sizes[i * 3 + 0] = n; // new M = original N - new_sizes[i * 3 + 1] = m; // new N = original M - new_sizes[i * 3 + 2] = k; // K unchanged - } + // Synchronize to ensure the swapped problem sizes are ready + cudaStreamSynchronize(stream); + } else { + effective_problem_sizes = problem_sizes; } ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = From 7bdae95ff3443da1f2edf40e2dd67c04f93bfd8f Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Mon, 14 Jul 2025 21:59:40 -0700 Subject: [PATCH 09/16] update shape Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu | 62 ++++++++++++------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu index 02fd39856099..667a7dab47ab 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu @@ -35,7 +35,7 @@ struct sm100_fp8_config_M32 { using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; - using TileShape = cute::Shape; + using TileShape = cute::Shape; using ClusterShape = cute::Shape; using ArchTag = cutlass::arch::Sm100; @@ -106,31 +106,47 @@ void run_cutlass_moe_mm_sm100( uint32_t const m = a_tensors.size(0); uint32_t const n = out_tensors.size(1); uint32_t const k = a_tensors.size(1); - bool swap_ab = m < 32; - if (swap_ab) { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides, per_act_token, - per_out_ch); - } - else{ +// bool swap_ab = false; +// if (swap_ab) { +// cutlass_group_gemm_caller( +// out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, +// problem_sizes, a_strides, b_strides, c_strides, per_act_token, +// per_out_ch); +// } +// else{ +// if (n >= 4096) { +// cutlass_group_gemm_caller( +// out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, +// problem_sizes, a_strides, b_strides, c_strides, per_act_token, +// per_out_ch); +// } else if (m >= 32) { +// cutlass_group_gemm_caller( +// out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, +// problem_sizes, a_strides, b_strides, c_strides, per_act_token, +// per_out_ch); +// } else { +// cutlass_group_gemm_caller( +// out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, +// problem_sizes, a_strides, b_strides, c_strides, per_act_token, +// per_out_ch); +// } +// } if (n >= 4096) { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides, per_act_token, - per_out_ch); - } else if (m >= 32) { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides, per_act_token, - per_out_ch); - } else { - cutlass_group_gemm_caller( - out_tensors, 
a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, per_act_token, per_out_ch); - } - } + } else if (m >= 32) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } } void dispatch_moe_mm_sm100( From b9a576389254545ae1aa6187c886c0b181485c63 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Tue, 15 Jul 2025 20:18:41 -0700 Subject: [PATCH 10/16] fix swapAB Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../cutlass_w8a8/moe/get_group_starts.cuh | 32 ++---- .../cutlass_w8a8/moe/grouped_mm_c3x.cuh | 49 +-------- .../cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu | 104 +++++------------- .../quantization/cutlass_w8a8/moe/moe_data.cu | 25 +---- 4 files changed, 48 insertions(+), 162 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index 251ca5c4923c..6c6e89790847 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -16,32 +16,18 @@ __global__ void get_group_gemm_starts( ElementAB* b_base_as_int, ElementC* out_base_as_int, ElementAccumulator* a_scales_base_as_int, ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k, - bool per_act_token, bool per_out_ch, bool swap_ab = false) { + bool per_act_token, bool per_out_ch) { int expert_id = threadIdx.x; int64_t expert_offset = expert_offsets[expert_id]; - if (swap_ab) { - - a_offsets[expert_id] = b_base_as_int + expert_id * k * n; // First operand gets B data - b_offsets[expert_id] = a_base_as_int + expert_offset * k; // Second operand gets A data - - // Swap scale pointers accordingly - a_scales_offsets[expert_id] = - b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); - b_scales_offsets[expert_id] = - a_scales_base_as_int + (per_act_token ? expert_offset : 0); - } else { - // Normal case - a_offsets[expert_id] = a_base_as_int + expert_offset * k; - b_offsets[expert_id] = b_base_as_int + expert_id * k * n; - - a_scales_offsets[expert_id] = - a_scales_base_as_int + (per_act_token ? expert_offset : 0); - b_scales_offsets[expert_id] = - b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); - } + a_offsets[expert_id] = a_base_as_int + expert_offset * k; + b_offsets[expert_id] = b_base_as_int + expert_id * k * n; out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scales_offsets[expert_id] = + a_scales_base_as_int + (per_act_token ? expert_offset : 0); + b_scales_offsets[expert_id] = + b_scales_base_as_int + (per_out_ch ? 
n * expert_id : expert_id); } #define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ @@ -59,7 +45,7 @@ __global__ void get_group_gemm_starts( static_cast(out_tensors.data_ptr()), \ static_cast(a_scales.data_ptr()), \ static_cast(b_scales.data_ptr()), out_tensors.size(1), \ - a_tensors.size(1), per_act_token, per_out_ch, swap_ab); \ + a_tensors.size(1), per_act_token, per_out_ch); \ } namespace { @@ -70,7 +56,7 @@ void run_get_group_gemm_starts( torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor& out_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, bool swap_ab = false) { + torch::Tensor const& b_scales) { TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index fa02527de5ec..75eef618746c 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -11,27 +11,6 @@ #include "get_group_starts.cuh" using namespace cute; - -// Kernel to swap M and N dimensions in problem_sizes for swap_ab -__global__ void swap_problem_sizes_kernel( - const int32_t* input_problem_sizes, - int32_t* output_problem_sizes, - int num_experts) { - int expert_id = blockIdx.x * blockDim.x + threadIdx.x; - if (expert_id < num_experts) { - // Each expert has 3 values: M, N, K - int base_idx = expert_id * 3; - int32_t M = input_problem_sizes[base_idx + 0]; - int32_t N = input_problem_sizes[base_idx + 1]; - int32_t K = input_problem_sizes[base_idx + 2]; - - // Swap M and N for swap_ab - output_problem_sizes[base_idx + 0] = N; // New M = old N - output_problem_sizes[base_idx + 1] = M; // New N = old M - output_problem_sizes[base_idx + 2] = K; // K remains the same - } -} - namespace { using ProblemShape = @@ -137,39 +116,17 @@ void cutlass_group_gemm_caller( run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors, - out_tensors, a_scales, b_scales, swap_ab); + out_tensors, a_scales, b_scales); using GemmKernel = typename Gemm::GemmKernel; - + using StrideA = Stride, Int<0>>; using StrideB = Stride, Int<0>>; using StrideC = typename GemmKernel::InternalStrideC; - // Create effective problem sizes - may need to swap M and N for swap_ab - torch::Tensor effective_problem_sizes; - if constexpr (swap_ab) { - // For swap_ab, we need to swap M and N dimensions in problem sizes - effective_problem_sizes = torch::empty_like(problem_sizes); - // Launch kernel to swap M and N dimensions - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); - int num_experts = static_cast(expert_offsets.size(0)); - int block_size = 256; - int grid_size = (num_experts + block_size - 1) / block_size; - - swap_problem_sizes_kernel<<>>( - static_cast(problem_sizes.data_ptr()), - static_cast(effective_problem_sizes.data_ptr()), - num_experts); - - // Synchronize to ensure the swapped problem sizes are ready - cudaStreamSynchronize(stream); - } else { - effective_problem_sizes = problem_sizes; - } - ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = static_cast( - effective_problem_sizes.data_ptr()); + problem_sizes.data_ptr()); ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; typename GemmKernel::MainloopArguments mainloop_args; 
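The sm100 changes below settle on three tuned configurations and pick between them from the aggregate problem shape. A condensed sketch of that selection logic (enum and function names are illustrative; the thresholds are the ones used in the dispatch below):

#include <cstdint>

// Hedged sketch of the SM100 grouped-GEMM config selection:
// small aggregate M goes to the swap-AB config, very wide N to the
// 2-SM-cluster config, and everything else to the default 1-SM config.
enum class Sm100MoeConfig { kM64SwapAB, kN8192TwoSm, kDefault };

inline Sm100MoeConfig select_sm100_moe_config(uint32_t m, uint32_t n) {
  if (m <= 64) return Sm100MoeConfig::kM64SwapAB;     // M in [1, 64]
  if (n >= 8192) return Sm100MoeConfig::kN8192TwoSm;  // N in [8192, inf)
  return Sm100MoeConfig::kDefault;
}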
diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu index 667a7dab47ab..641e5997f0fd 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu @@ -13,12 +13,11 @@ namespace { template typename Epilogue> struct sm100_fp8_config_default { - // M in [1, 16) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; - using TileShape = cute::Shape; + using TileShape = cute::Shape; using ClusterShape = cute::Shape; using ArchTag = cutlass::arch::Sm100; @@ -29,30 +28,30 @@ struct sm100_fp8_config_default { template typename Epilogue> -struct sm100_fp8_config_M32 { - // M in [32,inf) +struct sm100_fp8_config_M64 { + // M in [1,64] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; - using TileShape = cute::Shape; + using TileShape = cute::Shape; using ClusterShape = cute::Shape; using ArchTag = cutlass::arch::Sm100; using Cutlass3xGemm = cutlass_3x_group_gemm; + ClusterShape, KernelSchedule, EpilogueSchedule, + true>; }; template typename Epilogue> -struct sm100_fp8_config_N4096 { - // N in [4096, inf) +struct sm100_fp8_config_N8192 { + // N in [8192, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; - using EpilogueSchedule = - cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using ArchTag = cutlass::arch::Sm100; @@ -62,22 +61,6 @@ struct sm100_fp8_config_N4096 { ClusterShape, KernelSchedule, EpilogueSchedule>; }; -template typename Epilogue> -struct sm100_fp8_config_SwapAB { - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; - using ArchTag = cutlass::arch::Sm100; - - using Cutlass3xGemm = - cutlass_3x_group_gemm; -}; - template void run_cutlass_moe_mm_sm100( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, @@ -97,57 +80,32 @@ void run_cutlass_moe_mm_sm100( using Cutlass3xGemmDefault = typename sm100_fp8_config_default< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmN4096 = typename sm100_fp8_config_N4096< + using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmM32 = typename sm100_fp8_config_M32< - InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmSwapAB = typename sm100_fp8_config_SwapAB< + using Cutlass3xGemmM64 = typename sm100_fp8_config_M64< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + uint32_t const m = a_tensors.size(0); uint32_t const n = out_tensors.size(1); - uint32_t const k = a_tensors.size(1); -// bool swap_ab = false; -// if (swap_ab) { -// cutlass_group_gemm_caller( -// out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, -// problem_sizes, a_strides, b_strides, c_strides, per_act_token, -// per_out_ch); -// } -// else{ -// if (n >= 4096) 
{ -// cutlass_group_gemm_caller( -// out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, -// problem_sizes, a_strides, b_strides, c_strides, per_act_token, -// per_out_ch); -// } else if (m >= 32) { -// cutlass_group_gemm_caller( -// out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, -// problem_sizes, a_strides, b_strides, c_strides, per_act_token, -// per_out_ch); -// } else { -// cutlass_group_gemm_caller( -// out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, -// problem_sizes, a_strides, b_strides, c_strides, per_act_token, -// per_out_ch); -// } -// } - if (n >= 4096) { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides, per_act_token, - per_out_ch); - } else if (m >= 32) { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides, per_act_token, - per_out_ch); - } else { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides, per_act_token, - per_out_ch); - } + + if (m <= 64) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else if (n >= 8192) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } } +} // namespace void dispatch_moe_mm_sm100( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, @@ -169,8 +127,6 @@ void dispatch_moe_mm_sm100( } } -} // namespace - void cutlass_moe_mm_sm100( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 623c9a2f096b..cec6db299a95 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -9,7 +9,6 @@ constexpr uint64_t THREADS_PER_EXPERT = 512; // threshold must match the dispatch logic in run_cutlass_moe_mm_sm90() constexpr int SWAP_AB_THRESHOLD = 64; -template __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, @@ -118,23 +117,11 @@ void get_cutlass_moe_mm_data_caller( torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); - - if (topk_ids.numel() > SWAP_AB_THRESHOLD) { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } else { - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, - k); - } - + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + 
static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); if (blockscale_offsets.has_value()) { compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), @@ -190,4 +177,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, static_cast(problem_sizes2.data_ptr()), static_cast(expert_num_tokens.data_ptr()), padded_m, n, k); -} +} \ No newline at end of file From c403a57fd1cf717c061621361468d29d03d658ec Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Wed, 16 Jul 2025 21:53:14 -0700 Subject: [PATCH 11/16] fall back logic Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../compressed_tensors_moe.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index a0cf09a5d159..6d0fa3b9f723 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -952,7 +952,27 @@ def apply( per_act_token = ( self.input_quant.strategy == QuantizationStrategy.TOKEN) - + per_channel_quant = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + if topk_ids.shape[0] <= 8: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) if self.fused_experts is None: # If no modular kernel is provided, use cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import ( From 2de546ce9575bf45442503cdecf6d6d669e6e0a6 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 18 Jul 2025 10:08:58 -0700 Subject: [PATCH 12/16] rebase and fix Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../quantization/cutlass_w8a8/moe/moe_data.cu | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index cec6db299a95..d4f8f472583d 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -9,6 +9,7 @@ constexpr uint64_t THREADS_PER_EXPERT = 512; // threshold must match the dispatch logic in run_cutlass_moe_mm_sm90() constexpr int SWAP_AB_THRESHOLD = 64; +template __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, @@ -117,11 +118,21 @@ void get_cutlass_moe_mm_data_caller( torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, 
k); + if (topk_ids.numel() > SWAP_AB_THRESHOLD) { + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, + k); + } else { + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, + k); + } if (blockscale_offsets.has_value()) { compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), From de8ac9efca223841a1cb1931cf8fb2753ebd9e99 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 18 Jul 2025 10:13:31 -0700 Subject: [PATCH 13/16] fix Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh | 7 ------- 1 file changed, 7 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index 75eef618746c..e4a0b7ab032a 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -31,13 +31,6 @@ using LayoutD_Transpose = using LayoutC = LayoutD; using LayoutC_Transpose = LayoutD_Transpose; -using LayoutA_Transpose = - typename cutlass::layout::LayoutTranspose::type; -using LayoutB_Transpose = - typename cutlass::layout::LayoutTranspose::type; -using LayoutC_Transpose = - typename cutlass::layout::LayoutTranspose::type; - template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, From d789068083010fa433de4307408980ed56f79de3 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 18 Jul 2025 10:15:07 -0700 Subject: [PATCH 14/16] lint Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh | 3 ++- csrc/quantization/cutlass_w8a8/moe/moe_data.cu | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index e4a0b7ab032a..659941de182e 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -11,6 +11,7 @@ #include "get_group_starts.cuh" using namespace cute; + namespace { using ProblemShape = @@ -91,6 +92,7 @@ void cutlass_group_gemm_caller( torch::Tensor const& b_strides, torch::Tensor const& c_strides, bool per_act_token, bool per_out_ch) { static constexpr bool swap_ab = Gemm::swap_ab; + using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; @@ -112,7 +114,6 @@ void cutlass_group_gemm_caller( out_tensors, a_scales, b_scales); using GemmKernel = typename Gemm::GemmKernel; - using StrideA = Stride, Int<0>>; using StrideB = Stride, Int<0>>; using StrideC = typename GemmKernel::InternalStrideC; diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index d4f8f472583d..993c30c48c84 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -118,6 +118,7 @@ void get_cutlass_moe_mm_data_caller( torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); + if 
(topk_ids.numel() > SWAP_AB_THRESHOLD) { compute_problem_sizes<<>>( static_cast(topk_ids.data_ptr()), @@ -133,6 +134,7 @@ void get_cutlass_moe_mm_data_caller( static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); } + if (blockscale_offsets.has_value()) { compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), From 457ddf9128384c8e3db35ae82bb8d7a94ca4a30c Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 18 Jul 2025 12:06:52 -0700 Subject: [PATCH 15/16] add comemnt and fix logic Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../compressed_tensors/compressed_tensors_moe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 6d0fa3b9f723..d71cae0288c7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -741,6 +741,8 @@ def __init__( self.topk_indices_dtype = None self.fused_experts = None # type: ignore self.disable_expert_map = False + self.is_fp8_w8a8_sm100 = self.quant_config._is_fp8_w8a8_sm100( + self.weight_quant, self.input_quant) def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -954,7 +956,9 @@ def apply( self.input_quant.strategy == QuantizationStrategy.TOKEN) per_channel_quant = ( self.weight_quant.strategy == QuantizationStrategy.CHANNEL) - if topk_ids.shape[0] <= 8: + # Triton fused_experts is faster in small batch sizes on SM100. + # Fall back to fused_experts in small batch sizes. + if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: from vllm.model_executor.layers.fused_moe import fused_experts return fused_experts( x, From 9c7e2863a444c7fb45f169a3052b1c8bc2d5a604 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 18 Jul 2025 17:19:35 -0700 Subject: [PATCH 16/16] fix Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu index a20c5876661a..8f21623b52fa 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu @@ -39,10 +39,12 @@ struct sm90_fp8_config_M4 { cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template ; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template