From 5794a8eba3c2bd9131ff708e243bb3d2fcc5e340 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 12 Jun 2025 19:30:43 -0700 Subject: [PATCH 01/16] [feat] sm100 cutlass blockscaled group gemm Signed-off-by: Duncan Moss --- CMakeLists.txt | 21 + csrc/ops.h | 9 + .../moe/blockwise_scaled_group_mm_sm100.cu | 410 ++++++++++++++++++ csrc/torch_bindings.cpp | 7 + .../kernels/moe/test_cutlass_grouped_gemm.py | 160 +++++++ vllm/_custom_ops.py | 18 + 6 files changed, 625 insertions(+) create mode 100644 csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu create mode 100644 tests/kernels/moe/test_cutlass_grouped_gemm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index bd389823fbb2..1910ca3c7a0c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -565,6 +565,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "in CUDA target architectures") endif() endif() + + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu" + "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " + "if you intend on running FP8 quantized MoE models on Hopper.") + else() + message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() # # Machete kernels diff --git a/csrc/ops.h b/csrc/ops.h index f02f5083ac19..601739a7456d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -239,6 +239,15 @@ void cutlass_moe_mm( torch::Tensor const& b_strides, torch::Tensor const& c_strides, bool per_act_token, bool per_out_ch); +void cutlass_blockwise_scaled_grouped_mm( + torch::Tensor& output, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets); + void cutlass_fp4_group_mm( torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, diff --git a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu new file mode 100644 index 000000000000..a244dfd01acd --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu @@ -0,0 +1,410 @@ +#include +#include + +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" 
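// ---------------------------------------------------------------------------
// Overview (informal sketch of the code that follows; the exact types and
// template plumbing live in the kernel and launcher below):
//  1. get_ggemm_starts: a one-block device kernel (one thread per expert) that
//     expands expert_offsets / problem_sizes into per-expert pointers for A,
//     B, the output and both scale tensors, and writes the CUTLASS blockwise
//     scale layouts.
//  2. run_blockwise_scaled_group_mm: instantiates the SM100 CUTLASS grouped
//     GEMM with blockwise scaling (1x128 groups on the activations, 128x128
//     blocks on the weights) and launches it.
//  3. cutlass_blockwise_scaled_grouped_mm: the torch-facing entry point with
//     dtype/shape checks and the bf16 / fp16 output dispatch.
// The per-expert pointer arithmetic used in (1) is, in effect:
//   a_ptr[e]       = A_base       + expert_offsets[e] * K;
//   b_ptr[e]       = B_base       + e * N * K;
//   out_ptr[e]     = C_base       + expert_offsets[e] * N;
//   a_scale_ptr[e] = scalesA_base + expert_offsets[e] * (K / 128);
//   b_scale_ptr[e] = scalesB_base + e * (N / 128) * (K / 128);
// ---------------------------------------------------------------------------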
+#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include + +using namespace cute; + +template < + typename ElementAB, + typename ElementC, + typename ElementAccumulator, + typename LayoutSFA, + typename LayoutSFB, + typename ScaleConfig> +__global__ void get_ggemm_starts( + int32_t* expert_offsets, + ElementAB** a_offsets, + ElementAB** b_offsets, + ElementC** out_offsets, + ElementAccumulator** a_scale_offsets, + ElementAccumulator** b_scale_offsets, + ElementAB* a_base_as_int, + ElementAB* b_base_as_int, + ElementC* out_base_as_int, + ElementAccumulator* a_scale_base_as_int, + ElementAccumulator* b_scale_base_as_int, + LayoutSFA* layout_sfa_base_as_int, + LayoutSFB* layout_sfb_base_as_int, + int* problem_sizes) { + + int expert_id = threadIdx.x; + + if (expert_id >= gridDim.x * blockDim.x) { + return; + } + + int m = problem_sizes[expert_id * 3]; + int n = problem_sizes[expert_id * 3 + 1]; + int k = problem_sizes[expert_id * 3 + 2]; + + int32_t expert_offset = expert_offsets[expert_id]; + int a_stride = expert_offset * k; + int b_stride = expert_id * k * n; + int a_scale_stride = expert_offset * k / 128; + int b_scale_stride = expert_id * k * n / 128 / 128; + + a_offsets[expert_id] = a_base_as_int + a_stride; + b_offsets[expert_id] = b_base_as_int + b_stride; + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride; + b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride; + + LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; + + *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); +} + +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_ggemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + reinterpret_cast(layout_sfa.data_ptr()), \ + reinterpret_cast(layout_sfb.data_ptr()), \ + static_cast(problem_sizes.data_ptr())); \ + } + + +template +void run_get_ggemm_starts( + torch::Tensor const& expert_offsets, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, + torch::Tensor& b_scales_ptrs, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor out_tensors, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& layout_sfa, + torch::Tensor const& layout_sfb, + torch::Tensor const& problem_sizes) { + TORCH_CHECK(a_tensors.dtype() == 
torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0); + TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0); + + int num_experts = (int)expert_offsets.size(0); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if (false) {} + __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig) + __CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA, LayoutSFB, ScaleConfig) + else { + TORCH_CHECK(false, "Unsupported output tensor type"); + } +} + +template +void run_blockwise_scaled_group_mm( + torch::Tensor& out_ptrs, + const torch::Tensor& a_ptrs, + const torch::Tensor& b_ptrs, + const torch::Tensor& a_scales_ptrs, + const torch::Tensor& b_scales_ptrs, + const torch::Tensor& stride_a, + const torch::Tensor& stride_b, + const torch::Tensor& stride_c, + const torch::Tensor& layout_sfa, + const torch::Tensor& layout_sfb, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets) { + + using ProblemShape = cutlass::gemm::GroupProblemShape>; + + // Types + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = LayoutD; + + // Alignments + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + void, + LayoutC*, + AlignmentC, + ElementD, + LayoutC*, + AlignmentC, + typename ScheduleConfig::EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename ScheduleConfig::KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = (int)expert_offsets.size(0); + + Gemm gemm_op; + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(stride_a.data_ptr()), + 
static_cast(b_ptrs.data_ptr()), + static_cast(stride_b.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr()) + }; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = a_ptrs.get_device(); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(stride_c.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(stride_c.data_ptr()) + }; + + UnderlyingProblemShape* problem_sizes_as_shapes = static_cast(problem_sizes.data_ptr()); + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info}; + + at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM"); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto status = gemm_op.initialize(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +template +void blockwise_scaled_group_mm_dispatch_shape( + torch::Tensor& output, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets) { + + struct MmaConfig { + using ElementA = cutlass::float_e4m3_t; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + using LayoutC = cutlass::layout::RowMajor; + using MmaTileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + }; + + int num_experts = (int)expert_offsets.size(0); + + auto a_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto b_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto out_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto a_scales_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto b_scales_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + + auto layout_sfa = torch::empty({num_experts, 5}, torch::TensorOptions().dtype(torch::kInt32).device(a.device())); + auto layout_sfb = torch::empty({num_experts, 5}, torch::TensorOptions().dtype(torch::kInt32).device(a.device())); + + auto stride_a = 
torch::full({num_experts}, a.size(1), torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto stride_b = torch::full({num_experts}, a.size(1), torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto stride_c = torch::full({num_experts}, output.size(1), torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + + torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + run_get_ggemm_starts( + expert_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a, + b, + output, + scales_a, + scales_b, + layout_sfa, + layout_sfb, + problem_sizes + ); + + run_blockwise_scaled_group_mm( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_c, + layout_sfa, + layout_sfb, + problem_sizes, + expert_offsets + ); +} + + +void cutlass_blockwise_scaled_grouped_mm( + torch::Tensor& output, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets) { + + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); + TORCH_CHECK( + problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32"); + TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, "a must be kFloat8_e4m3fn"); + TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, "b must be kFloat8_e4m3fn"); + TORCH_CHECK( + output.scalar_type() == torch::kBFloat16 || output.scalar_type() == torch::kHalf, + "output must be bfloat16 or half"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be float32"); + TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32"); + + TORCH_CHECK(output.dim() == 2, "output must be 2D tensor"); + TORCH_CHECK(a.dim() == 2, "a must be 2D tensor"); + TORCH_CHECK(b.dim() == 3, "b must be 3D tensor"); + TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor"); + TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor"); + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); + TORCH_CHECK( + problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32"); + TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor"); + + #if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100 + if (output.scalar_type() == torch::kBFloat16) { + blockwise_scaled_group_mm_dispatch_shape( + output, + a, + b, + scales_a, + scales_b, + problem_sizes, + expert_offsets); + } else if (output.scalar_type() == torch::kFloat16) { + blockwise_scaled_group_mm_dispatch_shape( + output, + a, + b, + scales_a, + scales_b, + problem_sizes, + expert_offsets); + } else { + TORCH_CHECK(false, "Unsupported output tensor type"); + } + #endif +} \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 1a1896b4c1ee..43ab61d31e88 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp 
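[Editor's sketch, not part of the patch] Before the binding below, it may help to see how a caller is expected to lay out the grouping metadata and the blockwise scales. The shapes follow from the TORCH_CHECKs and the offset arithmetic in get_ggemm_starts above (1x128 scale groups on the activations, 128x128 scale blocks on the weights) and assume N and K are multiples of 128; the helper name is illustrative only.

import torch

def build_group_metadata(group_ms: list[int], n: int, k: int,
                         device: str = "cuda"):
    # One (M_e, N, K) triple per expert, int32, as the entry point requires.
    problem_sizes = torch.tensor([[m, n, k] for m in group_ms],
                                 dtype=torch.int32, device=device)
    # Row offset of each expert's slice inside the stacked A / output tensors.
    starts = [0]
    for m in group_ms:
        starts.append(starts[-1] + m)
    expert_offsets = torch.tensor(starts[:-1], dtype=torch.int32, device=device)

    total_m = sum(group_ms)
    # Scale shapes implied by get_ggemm_starts:
    #   scales_a: one float32 per (row, 128-wide K slice)   -> [sum(M_e), K // 128]
    #   scales_b: one float32 per 128x128 block, per expert -> [E, N // 128, K // 128]
    scales_a_shape = (total_m, k // 128)
    scales_b_shape = (len(group_ms), n // 128, k // 128)
    return problem_sizes, expert_offsets, scales_a_shape, scales_b_shape

These are essentially the tensors the tests in this series build before calling
ops.cutlass_blockwise_scaled_grouped_mm(out, a_fp8, b_fp8, scales_a, scales_b,
problem_sizes, expert_offsets).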
@@ -393,6 +393,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); + // cutlass blockwise scaledgroup GEMM + ops.def( + "cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, " + "Tensor problem_sizes, Tensor expert_offsets) -> ()", + {stride_tag}); + ops.impl("cutlass_blockwise_scaled_grouped_mm", torch::kCUDA, &cutlass_blockwise_scaled_grouped_mm); + // cutlass nvfp4 block scaled group GEMM ops.def( "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b," diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py new file mode 100644 index 000000000000..1aec0f53d59a --- /dev/null +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random +import pytest +import torch +from typing import Tuple + +from vllm import _custom_ops as ops + +def cdiv(a, b): + return (a + b - 1) // b + +def scale_shape(shape, group_shape): + return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def baseline_scaled_mm(a, b, a_scales, b_scales, out_dtype): + + def group_broadcast(t, shape): + for i, s in enumerate(shape): + if t.shape[i] != s and t.shape[i] != 1: + assert s% t.shape[i] == 0 + t = ( + t.unsqueeze(i+1) + .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:]) + .flatten(i, i+1) + ) + return t + + scale_a = group_broadcast(a_scales, a.shape) + scale_b = group_broadcast(b_scales, b.shape) + + return torch.mm( + (scale_a * a.to(dtype=torch.float32)), + (scale_b * b.to(dtype=torch.float32)) + ).to(dtype=out_dtype) + +@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ + (4, 8192, 7168, 4096), + # (16, 128, 128, 128), + # (64, 128, 128, 128), +]) +@pytest.mark.parametrize("out_dtype", [torch.half]) +def test_cutlass_grouped_gemm( + num_groups: int, + expected_m_per_group: int, + k: int, + n: int, + out_dtype: torch.dtype, +): + device = "cuda" + alignment = 8 + group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + m = sum([cdiv(m, alignment) * alignment for m in group_ms]) + + x = torch.randn((m, k), device=device, dtype=out_dtype) + y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype) + m_indicies = torch.empty(m, device=device, dtype=torch.int32) + out = torch.empty((m, n), device=device, dtype=out_dtype) + ref_out = torch.randn((m, n), device=device, dtype=out_dtype) + + start = 0 + for i, group_m in enumerate(group_ms): + actual_end = start + group_m + aligned_end = start + cdiv(group_m, alignment) * alignment + m_indicies[start:aligned_end] = i + m_indicies[aligned_end:actual_end] = -1 + ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() + start = aligned_end + ref_out = torch.where((m_indicies == -1)) + + + print(num_groups, expected_m_per_group, k, n) + + + + + +@pytest.mark.parametrize("num_experts", [8, 16, 64]) +@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16]) +def test_cutlass_grouped_moe( + num_experts: int, + out_dtype: torch.dtype, +): + device = "cuda" + alignment = 8 + n_g = alignment * random.randint(1, 5) * 128 + k_g = alignment * random.randint(1, 5) * 128 + + scale_a_group_shape = (1, 128) + scale_b_group_shape = 
(128, 128) + + expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32) + problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32) + + a_tensors = [] + b_tensors = [] + a_scales_tensors = [] + b_scales_tensors = [] + baseline_tensors = [] + + for g in range(num_experts): + m_g = alignment * random.randint(1, 64) + expert_offsets[g+1] = expert_offsets[g] + m_g + problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device) + + a_g = to_fp8(torch.randn((m_g, k_g), device=device)) + b_g = to_fp8(torch.randn((n_g, k_g), device=device).t()) + a_tensors.append(a_g) + b_tensors.append(b_g) + + scale_a_shape = scale_shape(a_g.shape, scale_a_group_shape) + scale_b_shape = scale_shape(b_g.shape, scale_b_group_shape) + + a_scales_tensors.append(torch.randn(scale_a_shape, device=device) * 0.001) + b_scales_tensors.append(torch.randn(scale_b_shape, device=device) * 0.001) + + baseline = baseline_scaled_mm( + a_g, b_g, a_scales_tensors[-1], b_scales_tensors[-1], out_dtype + ) + baseline_tensors.append(baseline) + + a_stack = torch.empty((expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn) + b_stack = torch.empty((num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn) + + for g in range(num_experts): + a_stack[expert_offsets[g]:expert_offsets[g+1]] = a_tensors[g] + b_stack[g] = b_tensors[g].t() + b_stack = b_stack.transpose(1, 2) + + a_scale_stack = torch.empty((expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32) + b_scale_stack = torch.empty((num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32) + + for g in range(num_experts): + a_scale_stack[expert_offsets[g]:expert_offsets[g+1]] = a_scales_tensors[g] + b_scale_stack[g] = b_scales_tensors[g].t() + b_scale_stack = b_scale_stack.transpose(1, 2) + + c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype) + + ops.cutlass_blockwise_scaled_grouped_mm( + c_out, + a_stack, + b_stack, + a_scale_stack, + b_scale_stack, + problem_sizes, + expert_offsets[:-1], + ) + + for g in range(num_experts): + baseline = baseline_tensors[g] + actual = c_out[expert_offsets[g]:expert_offsets[g+1]] + + torch.testing.assert_close(baseline, actual, atol=1e-2, rtol=5e-4) + \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e26c90bf70cb..48e13d97a589 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -646,6 +646,24 @@ def _ggml_moe_a8_vec_fake( def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) +def cutlass_blockwise_scaled_grouped_mm( + output: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scales_a: torch.Tensor, + scales_b: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, +): + torch.ops._C.cutlass_blockwise_scaled_grouped_mm( + output, + a, + b, + scales_a, + scales_b, + problem_sizes, + expert_offsets) + def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, block_scale_a: torch.Tensor, From eea2a2d0b6a9a625367fa3c519b8fdb368854802 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Mon, 16 Jun 2025 14:11:23 -0700 Subject: [PATCH 02/16] added vllm moe hooks Signed-off-by: Duncan Moss --- .../kernels/moe/test_cutlass_grouped_gemm.py | 161 +++++++----------- vllm/envs.py | 5 + .../layers/fused_moe/cutlass_moe.py | 107 ++++++++++++ .../layers/fused_moe/fused_moe.py | 16 +- .../model_executor/layers/quantization/fp8.py | 16 +- 5 files changed, 203 insertions(+), 
102 deletions(-) diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index 1aec0f53d59a..c2e21f0227c1 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 +# DeepGEMM Style Cutlass Grouped GEMM Test +# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py + import random import pytest import torch @@ -7,16 +10,34 @@ from vllm import _custom_ops as ops +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + def cdiv(a, b): return (a + b - 1) // b -def scale_shape(shape, group_shape): - return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) - -def to_fp8(tensor: torch.Tensor) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - +def per_token_cast_to_fp8(x : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (128 - (n % 128)) % 128 + x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + +def per_block_cast_to_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((cdiv(m, 128)* 128, cdiv(n, 128)* 128), device=x.device, dtype=x.dtype) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) def baseline_scaled_mm(a, b, a_scales, b_scales, out_dtype): @@ -41,10 +62,13 @@ def group_broadcast(t, shape): @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ (4, 8192, 7168, 4096), - # (16, 128, 128, 128), - # (64, 128, 128, 128), + (4, 8192, 2048, 7168), + (8, 4096, 7168, 4096), + (8, 4096, 2048, 7168), + (32, 1024, 7168, 4096), + (32, 1024, 2048, 7168), ]) -@pytest.mark.parametrize("out_dtype", [torch.half]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) def test_cutlass_grouped_gemm( num_groups: int, expected_m_per_group: int, @@ -53,7 +77,7 @@ def test_cutlass_grouped_gemm( out_dtype: torch.dtype, ): device = "cuda" - alignment = 8 + alignment = 128 group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] m = sum([cdiv(m, alignment) * alignment for m in group_ms]) @@ -63,98 +87,35 @@ def test_cutlass_grouped_gemm( out = torch.empty((m, n), device=device, dtype=out_dtype) ref_out = torch.randn((m, n), device=device, dtype=out_dtype) - start = 0 - for i, group_m in enumerate(group_ms): - actual_end = start + group_m - aligned_end = start + cdiv(group_m, alignment) * alignment - m_indicies[start:aligned_end] = i - m_indicies[aligned_end:actual_end] = -1 - ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() - start = aligned_end - ref_out = torch.where((m_indicies == -1)) - - - print(num_groups, expected_m_per_group, k, n) - - - - - 
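[Editor's sketch, not part of the patch] The two casting helpers added above mirror the DeepGEMM reference: values are scaled into the FP8 E4M3 range by 448.0 / amax, and the inverse scale (amax / 448.0) is kept in float32, one per 1x128 row slice for activations and one per 128x128 tile for weights. A minimal dequantization sketch, handy for checking the helpers by hand (function names are illustrative):

import torch

def dequant_per_token(fp8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # fp8: [m, k] float8_e4m3fn, scales: [m, ceil(k / 128)] float32
    out = fp8.to(torch.float32)
    for j in range(scales.shape[1]):
        out[:, j * 128:(j + 1) * 128] *= scales[:, j:j + 1]
    return out

def dequant_per_block(fp8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # fp8: [n, k] float8_e4m3fn, scales: [ceil(n / 128), ceil(k / 128)] float32
    out = fp8.to(torch.float32)
    for i in range(scales.shape[0]):
        for j in range(scales.shape[1]):
            out[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] *= scales[i, j]
    return out

Up to FP8 rounding, dequant_per_token(*per_token_cast_to_fp8(x)) recovers x,
which is what the group_broadcast step in baseline_scaled_mm relies on.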
-@pytest.mark.parametrize("num_experts", [8, 16, 64]) -@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16]) -def test_cutlass_grouped_moe( - num_experts: int, - out_dtype: torch.dtype, -): - device = "cuda" - alignment = 8 - n_g = alignment * random.randint(1, 5) * 128 - k_g = alignment * random.randint(1, 5) * 128 - - scale_a_group_shape = (1, 128) - scale_b_group_shape = (128, 128) - - expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32) - problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32) - - a_tensors = [] - b_tensors = [] - a_scales_tensors = [] - b_scales_tensors = [] - baseline_tensors = [] - - for g in range(num_experts): - m_g = alignment * random.randint(1, 64) - expert_offsets[g+1] = expert_offsets[g] + m_g - problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device) - - a_g = to_fp8(torch.randn((m_g, k_g), device=device)) - b_g = to_fp8(torch.randn((n_g, k_g), device=device).t()) - a_tensors.append(a_g) - b_tensors.append(b_g) - - scale_a_shape = scale_shape(a_g.shape, scale_a_group_shape) - scale_b_shape = scale_shape(b_g.shape, scale_b_group_shape) - - a_scales_tensors.append(torch.randn(scale_a_shape, device=device) * 0.001) - b_scales_tensors.append(torch.randn(scale_b_shape, device=device) * 0.001) - - baseline = baseline_scaled_mm( - a_g, b_g, a_scales_tensors[-1], b_scales_tensors[-1], out_dtype - ) - baseline_tensors.append(baseline) - - a_stack = torch.empty((expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn) - b_stack = torch.empty((num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn) - - for g in range(num_experts): - a_stack[expert_offsets[g]:expert_offsets[g+1]] = a_tensors[g] - b_stack[g] = b_tensors[g].t() - b_stack = b_stack.transpose(1, 2) - - a_scale_stack = torch.empty((expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32) - b_scale_stack = torch.empty((num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32) - - for g in range(num_experts): - a_scale_stack[expert_offsets[g]:expert_offsets[g+1]] = a_scales_tensors[g] - b_scale_stack[g] = b_scales_tensors[g].t() - b_scale_stack = b_scale_stack.transpose(1, 2) - - c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype) + ep_offset = [0] + [sum(group_ms[:i]) for i in range(1, num_groups)] + [m] + pb_size = [] + for i in range(num_groups): + pb_size.append([ep_offset[i+1] - ep_offset[i], n, k]) + problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32) + expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32) + + x_fp8 = per_token_cast_to_fp8(x) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float)) + for i in range(num_groups): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + for i in range(num_groups): + a = x_fp8[0][ep_offset[i]:ep_offset[i+1]] + a_scale = x_fp8[1][ep_offset[i]:ep_offset[i+1]] + b = y_fp8[0][i].t() + b_scale = y_fp8[1][i].t() + baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype) + ref_out[ep_offset[i]:ep_offset[i+1]] = baseline + ref_out = torch.where((m_indicies == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out) ops.cutlass_blockwise_scaled_grouped_mm( - c_out, - a_stack, - b_stack, - a_scale_stack, - b_scale_stack, + out, + x_fp8[0], + y_fp8[0], + x_fp8[1], + y_fp8[1], problem_sizes, expert_offsets[:-1], ) - for g in range(num_experts): - baseline = baseline_tensors[g] - actual 
= c_out[expert_offsets[g]:expert_offsets[g+1]] - - torch.testing.assert_close(baseline, actual, atol=1e-2, rtol=5e-4) - \ No newline at end of file + assert calc_diff(ref_out, out) < 1e-3, f"Cutlass grouped gemm is not accurate" \ No newline at end of file diff --git a/vllm/envs.py b/vllm/envs.py index f24ae64396f3..097db3510758 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -117,6 +117,7 @@ VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False + VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -810,6 +811,10 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + # Allow use of Cutlass Blockwise Scaled Grouped GEMM kernels for fused moe ops. + "VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM": + lambda: bool(int(os.getenv("VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM", "0"))), + # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. # It can be changed with this variable if needed for some reason. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index f380cb77c7e8..ab2824274c36 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -480,3 +480,110 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, out = (c2.view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()).sum(dim=1) return out.to(dtype=out_dtype) + + +def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> bool: + + def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int): + return 128 <= M and N % 128 == 0 and K % 128 == 0 + + m = hidden_states.size(0) + _, K, N = w2.size() + if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K): + logger.debug("CutlassBlockScaledGroupedGemm disabled: unalinged problem size.") + return False + + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + logger.debug("CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).") + return False + + return True + + +def run_cutlass_block_scaled_fused_experts( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + w1_scale = w1_scale.transpose(1, 2) + w2_scale = w2_scale.transpose(1, 2) + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert a.shape[0] == topk_ids.shape[0], "a and topk_ids must have the same batch size" + assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn" + assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn" + assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[0], "w1_scale expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[0], "w2_scale expert number mismatch" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + + out_dtype = a.dtype + num_experts = w1_q.size(0) + m = 
a.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + + expert_offsets = torch.empty((num_experts + 1,), dtype=torch.int32, device="cuda") + problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda") + problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda") + + topk = topk_ids.size(1) + + a_q, a1_scale = _fp8_quantize(a, A_scale=None, per_act_token=False, block_shape=[128, 128]) + device = a_q.device + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + ops.get_cutlass_moe_mm_data( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + num_experts, + n, + k, + ) + + rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a1_scales = a1_scale[a_map] + + c1 = torch.empty((m * topk, n * 2), dtype=out_dtype, device=device) + c2 = torch.empty((m * topk, k), dtype=out_dtype, device=device) + + ops.cutlass_blockwise_scaled_grouped_mm( + c1, + rep_a_q, + w1_q, + rep_a1_scales, + w1_scale, + problem_sizes1, + expert_offsets[:-1], + ) + + intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device) + torch.ops._C.silu_and_mul(intermediate, c1) + + intermediate_q, a2_scale = _fp8_quantize(intermediate, A_scale=None, per_act_token=False, block_shape=[128, 128]) + + ops.cutlass_blockwise_scaled_grouped_mm( + c2, + intermediate_q, + w2_q, + a2_scale, + w2_scale, + problem_sizes2, + expert_offsets[:-1], + ) + + return (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d9b1ba132671..609381413d85 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1159,7 +1159,9 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: + allow_deep_gemm: bool = False, + allow_cutlass_block_scaled_grouped_gemm: bool = False, + ) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. N = w1.shape[1] @@ -1182,6 +1184,18 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) + elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 and N > 512 + and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): + assert apply_router_weight_on_input is False + return run_cutlass_block_scaled_fused_experts( + a=hidden_states, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids + ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b3042bfaed3d..e7222965b292 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -464,12 +464,26 @@ def __init__(self, quant_config: Fp8Config): logger.warning_once( "DeepGemm not supported on the current platform.") + # Check for CutlassBlockScaledGroupedGemm support. 
+ self.allow_cutlass_block_scaled_grouped_gemm = False + if envs.VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM: + if not self.block_quant: + logger.warning_once("Model is not block quantized. Not using " + "CutlassBlockScaledGroupedGemm kernels") + elif (current_platform.is_cuda() + and current_platform.has_device_capability(100)): + logger.info_once("Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod.") + self.allow_cutlass_block_scaled_grouped_gemm = True + else: + logger.warning_once("CutlassBlockScaledGroupedGemm not supported on the current platform.") + self.topk_indices_dtype = None self.fused_experts = functools.partial( # type: ignore fused_experts, use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm) + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=self.allow_cutlass_block_scaled_grouped_gemm) def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, From df3c61c273d20cb6f3a507e0751011c402d99cf5 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Mon, 16 Jun 2025 16:29:18 -0700 Subject: [PATCH 03/16] fix Signed-off-by: Duncan Moss --- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 609381413d85..40b2b4f4b0cc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -20,6 +20,9 @@ MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + _valid_cutlass_block_scaled_grouped_gemm, + run_cutlass_block_scaled_fused_experts) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op From 9b568f06aa31dd8d3a8842d9f0c78c2897b11b53 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Mon, 16 Jun 2025 17:05:37 -0700 Subject: [PATCH 04/16] syntax errors Signed-off-by: Duncan Moss --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index ab2824274c36..f5f5d3255eee 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -502,15 +502,15 @@ def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int): def run_cutlass_block_scaled_fused_experts( a: torch.Tensor, - w1_q: torch.Tensor, - w2_q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ) -> torch.Tensor: - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) + w1_q = w1.transpose(1, 2) + w2_q = w2.transpose(1, 2) w1_scale = w1_scale.transpose(1, 2) w2_scale = w2_scale.transpose(1, 2) From d42d7212ac9b1184ee82c29ba32281a4389dea95 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Mon, 16 Jun 2025 17:17:30 -0700 Subject: [PATCH 05/16] missing import Signed-off-by: Duncan Moss --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 
f5f5d3255eee..e1e02c639bef 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -9,7 +9,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache, _fp8_quantize from vllm.scalar_type import scalar_types From 80b1b120201ff0d5c8d778d4868fb3a88ae7d3e6 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Tue, 17 Jun 2025 14:13:40 -0700 Subject: [PATCH 06/16] PR updates Signed-off-by: Duncan Moss --- CMakeLists.txt | 9 ++++----- tests/kernels/moe/test_cutlass_grouped_gemm.py | 2 -- vllm/envs.py | 3 ++- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 5 ++++- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 7 +++++-- 6 files changed, 16 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1910ca3c7a0c..e2e3e21e2723 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -296,6 +296,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" + "csrc/quantization/cutlass_w8a8/moe/moe_data.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" "csrc/attention/mla/cutlass_mla_entry.cu") @@ -547,8 +548,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" - "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -568,8 +568,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu" - "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -580,7 +579,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is " "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " - "if you intend on running FP8 quantized MoE models on Hopper.") + "if you intend on running FP8 quantized MoE models on Blackwell.") else() message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found " "in CUDA target architectures") diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index c2e21f0227c1..c1255f4c4642 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -83,7 +83,6 @@ def test_cutlass_grouped_gemm( x = torch.randn((m, k), device=device, dtype=out_dtype) y = torch.randn((num_groups, n, k), 
device=device, dtype=out_dtype) - m_indicies = torch.empty(m, device=device, dtype=torch.int32) out = torch.empty((m, n), device=device, dtype=out_dtype) ref_out = torch.randn((m, n), device=device, dtype=out_dtype) @@ -106,7 +105,6 @@ def test_cutlass_grouped_gemm( b_scale = y_fp8[1][i].t() baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype) ref_out[ep_offset[i]:ep_offset[i+1]] = baseline - ref_out = torch.where((m_indicies == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out) ops.cutlass_blockwise_scaled_grouped_mm( out, diff --git a/vllm/envs.py b/vllm/envs.py index 097db3510758..c07ac3717cf6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -813,7 +813,8 @@ def get_vllm_port() -> Optional[int]: # Allow use of Cutlass Blockwise Scaled Grouped GEMM kernels for fused moe ops. "VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM": - lambda: bool(int(os.getenv("VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM", "0"))), + lambda: bool(int(os.getenv( + "VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM", "0"))), # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index e1e02c639bef..8c3b7cbd53aa 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -7,11 +7,14 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache, _fp8_quantize +from vllm.model_executor.layers.fused_moe.utils import ( + _fp8_perm, _resize_cache, _fp8_quantize, _fp8_dequantize) from vllm.scalar_type import scalar_types +logger = init_logger(__name__) def run_cutlass_moe_fp8( output: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 40b2b4f4b0cc..132903756f3d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1187,7 +1187,7 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 and N > 512 + elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): assert apply_router_weight_on_input is False return run_cutlass_block_scaled_fused_experts( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e7222965b292..cb994ab5082a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -475,7 +475,8 @@ def __init__(self, quant_config: Fp8Config): logger.info_once("Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod.") self.allow_cutlass_block_scaled_grouped_gemm = True else: - logger.warning_once("CutlassBlockScaledGroupedGemm not supported on the current platform.") + logger.warning_once( + "CutlassBlockScaledGroupedGemm not supported on the current platform.") self.topk_indices_dtype = None self.fused_experts = functools.partial( # type: ignore @@ -483,7 +484,9 @@ def __init__(self, quant_config: Fp8Config): use_fp8_w8a8=True, 
block_shape=self.quant_config.weight_block_size, allow_deep_gemm=self.allow_deep_gemm, - allow_cutlass_block_scaled_grouped_gemm=self.allow_cutlass_block_scaled_grouped_gemm) + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm) + ) def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, From 8a68070ee09e33c02add5f8055a0699e7b572a1a Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Tue, 17 Jun 2025 15:04:43 -0700 Subject: [PATCH 07/16] import fix Signed-off-by: Duncan Moss --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 8c3b7cbd53aa..40e463988ebd 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.utils import ( - _fp8_perm, _resize_cache, _fp8_quantize, _fp8_dequantize) + _fp8_perm, _resize_cache, _fp8_quantize) from vllm.scalar_type import scalar_types logger = init_logger(__name__) From fb1bff742425d292d942fa518b3e70f1e6c9db2b Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Tue, 24 Jun 2025 20:38:43 -0700 Subject: [PATCH 08/16] various pre-commit fixes and minor updates Signed-off-by: Duncan Moss --- CMakeLists.txt | 4 +- csrc/ops.h | 10 +- .../moe/blockwise_scaled_group_mm_sm100.cu | 623 ++++++++---------- csrc/torch_bindings.cpp | 8 +- .../kernels/moe/test_cutlass_grouped_gemm.py | 68 +- vllm/_custom_ops.py | 26 +- vllm/envs.py | 6 - .../layers/fused_moe/cutlass_moe.py | 58 +- .../layers/fused_moe/fused_moe.py | 62 +- .../model_executor/layers/quantization/fp8.py | 27 +- 10 files changed, 434 insertions(+), 458 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e2e3e21e2723..fc7f23aacbf6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -259,7 +259,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "f115c3f85467d5d9619119d1dbeb9c03c3d73864" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -283,7 +283,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) endif() FetchContent_MakeAvailable(cutlass) diff --git a/csrc/ops.h b/csrc/ops.h index 601739a7456d..6a1db1040f41 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -240,13 +240,9 @@ void cutlass_moe_mm( bool per_act_token, bool per_out_ch); void cutlass_blockwise_scaled_grouped_mm( - torch::Tensor& output, - const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& scales_a, - const torch::Tensor& scales_b, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets); + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets); void cutlass_fp4_group_mm( torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, diff --git a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu index a244dfd01acd..ef57e503b21a 100644 --- a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu @@ -31,380 +31,337 @@ using namespace cute; -template < - typename ElementAB, - typename ElementC, - typename ElementAccumulator, - typename LayoutSFA, - typename LayoutSFB, - typename ScaleConfig> +template __global__ void get_ggemm_starts( - int32_t* expert_offsets, - ElementAB** a_offsets, - ElementAB** b_offsets, - ElementC** out_offsets, - ElementAccumulator** a_scale_offsets, - ElementAccumulator** b_scale_offsets, - ElementAB* a_base_as_int, - ElementAB* b_base_as_int, - ElementC* out_base_as_int, + int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + ElementC** out_offsets, ElementAccumulator** a_scale_offsets, + ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int, + ElementAB* b_base_as_int, ElementC* out_base_as_int, ElementAccumulator* a_scale_base_as_int, - ElementAccumulator* b_scale_base_as_int, - LayoutSFA* layout_sfa_base_as_int, - LayoutSFB* layout_sfb_base_as_int, - int* problem_sizes) { - - int expert_id = threadIdx.x; - - if (expert_id >= gridDim.x * blockDim.x) { - return; - } - - int m = problem_sizes[expert_id * 3]; - int n = problem_sizes[expert_id * 3 + 1]; - int k = problem_sizes[expert_id * 3 + 2]; - - int32_t expert_offset = expert_offsets[expert_id]; - int a_stride = expert_offset * k; - int b_stride = expert_id * k * n; - int a_scale_stride = expert_offset * k / 128; - int b_scale_stride = expert_id * k * n / 128 / 128; - - a_offsets[expert_id] = a_base_as_int + a_stride; - b_offsets[expert_id] = b_base_as_int + b_stride; - out_offsets[expert_id] = out_base_as_int + expert_offset * n; - a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride; - b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride; - - LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; - LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; - - *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); - *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); -} + ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int, + LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) { + int expert_id = threadIdx.x; -#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, 
ScaleConfig) \ - else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ - get_ggemm_starts \ - <<<1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ - static_cast(a_ptrs.data_ptr()), \ - static_cast(b_ptrs.data_ptr()), \ - static_cast(out_ptrs.data_ptr()), \ - static_cast(a_scales_ptrs.data_ptr()), \ - static_cast(b_scales_ptrs.data_ptr()), \ - static_cast(a_tensors.data_ptr()), \ - static_cast(b_tensors.data_ptr()), \ - static_cast(out_tensors.data_ptr()), \ - static_cast(a_scales.data_ptr()), \ - static_cast(b_scales.data_ptr()), \ - reinterpret_cast(layout_sfa.data_ptr()), \ - reinterpret_cast(layout_sfb.data_ptr()), \ - static_cast(problem_sizes.data_ptr())); \ + if (expert_id >= gridDim.x * blockDim.x) { + return; } + int m = problem_sizes[expert_id * 3]; + int n = problem_sizes[expert_id * 3 + 1]; + int k = problem_sizes[expert_id * 3 + 2]; + + int32_t expert_offset = expert_offsets[expert_id]; + int a_stride = expert_offset * k; + int b_stride = expert_id * k * n; + int a_scale_stride = expert_offset * k / 128; + int b_scale_stride = expert_id * k * n / 128 / 128; + + a_offsets[expert_id] = a_base_as_int + a_stride; + b_offsets[expert_id] = b_base_as_int + b_stride; + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride; + b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride; + + LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; + + *layout_sfa_ptr = + ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + *layout_sfb_ptr = + ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); +} + +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \ + ScaleConfig) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_ggemm_starts<<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + reinterpret_cast(layout_sfa.data_ptr()), \ + reinterpret_cast(layout_sfb.data_ptr()), \ + static_cast(problem_sizes.data_ptr())); \ + } template void run_get_ggemm_starts( - torch::Tensor const& expert_offsets, - torch::Tensor& a_ptrs, - torch::Tensor& b_ptrs, - torch::Tensor& out_ptrs, - torch::Tensor& a_scales_ptrs, - torch::Tensor& b_scales_ptrs, - torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, - torch::Tensor out_tensors, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& layout_sfa, - torch::Tensor const& layout_sfb, - torch::Tensor const& problem_sizes) { - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0); - TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0); - - int num_experts = (int)expert_offsets.size(0); - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); - - if (false) {} - 
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig) - __CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA, LayoutSFB, ScaleConfig) - else { - TORCH_CHECK(false, "Unsupported output tensor type"); - } + torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, + torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, + torch::Tensor out_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& layout_sfa, + torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0); + TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0); + + int num_experts = (int)expert_offsets.size(0); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if (false) { + } + __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA, + LayoutSFB, ScaleConfig) + __CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA, + LayoutSFB, ScaleConfig) + else { + TORCH_CHECK(false, "Unsupported output tensor type"); + } } template void run_blockwise_scaled_group_mm( - torch::Tensor& out_ptrs, - const torch::Tensor& a_ptrs, - const torch::Tensor& b_ptrs, - const torch::Tensor& a_scales_ptrs, - const torch::Tensor& b_scales_ptrs, - const torch::Tensor& stride_a, - const torch::Tensor& stride_b, - const torch::Tensor& stride_c, - const torch::Tensor& layout_sfa, - const torch::Tensor& layout_sfb, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets) { - - using ProblemShape = cutlass::gemm::GroupProblemShape>; - - // Types - using ElementA = cutlass::float_e4m3_t; - using ElementB = cutlass::float_e4m3_t; - using ElementC = OutType; - using ElementD = ElementC; - using ElementAccumulator = float; - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = LayoutD; - - // Alignments - static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using ArchTag = cutlass::arch::Sm100; - using OperatorClass = cutlass::arch::OpClassTensorOp; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - typename ScheduleConfig::MmaTileShape, - typename ScheduleConfig::ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, - ElementAccumulator, - void, - LayoutC*, - AlignmentC, - ElementD, - LayoutC*, - AlignmentC, - typename ScheduleConfig::EpilogueSchedule - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementA, - cute::tuple, - AlignmentA, - ElementB, - cute::tuple, - AlignmentB, - ElementAccumulator, - typename ScheduleConfig::MmaTileShape, - typename ScheduleConfig::ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - typename ScheduleConfig::KernelSchedule - 
>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - CollectiveMainloop, - CollectiveEpilogue, - void - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - using StrideA = typename Gemm::GemmKernel::InternalStrideA; - using StrideB = typename Gemm::GemmKernel::InternalStrideB; - using StrideC = typename Gemm::GemmKernel::InternalStrideC; - using StrideD = typename Gemm::GemmKernel::InternalStrideD; - - using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; - int num_experts = (int)expert_offsets.size(0); - - Gemm gemm_op; - - // Mainloop Arguments - typename GemmKernel::MainloopArguments mainloop_args{ + torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs, + const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs, + const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a, + const torch::Tensor& stride_b, const torch::Tensor& stride_c, + const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) { + using ProblemShape = cutlass::gemm::GroupProblemShape>; + + // Types + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = LayoutD; + + // Alignments + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*, + AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, + cute::tuple, + AlignmentA, ElementB, + cute::tuple, + AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename ScheduleConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = (int)expert_offsets.size(0); + + Gemm gemm_op; + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ static_cast(a_ptrs.data_ptr()), static_cast(stride_a.data_ptr()), static_cast(b_ptrs.data_ptr()), static_cast(stride_b.data_ptr()), static_cast(a_scales_ptrs.data_ptr()), - reinterpret_cast(layout_sfa.data_ptr()), + reinterpret_cast( + layout_sfa.data_ptr()), static_cast(b_scales_ptrs.data_ptr()), - 
reinterpret_cast(layout_sfb.data_ptr()) - }; + reinterpret_cast( + layout_sfb.data_ptr())}; - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = a_ptrs.get_device(); - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = a_ptrs.get_device(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); - // Epilogue Arguments - typename GemmKernel::EpilogueArguments epilogue_args{ + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ {}, // epilogue.thread nullptr, static_cast(stride_c.data_ptr()), static_cast(out_ptrs.data_ptr()), - static_cast(stride_c.data_ptr()) - }; + static_cast(stride_c.data_ptr())}; - UnderlyingProblemShape* problem_sizes_as_shapes = static_cast(problem_sizes.data_ptr()); + UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); - // Gemm Arguments - typename GemmKernel::Arguments args{ + // Gemm Arguments + typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, {num_experts, problem_sizes_as_shapes, nullptr}, mainloop_args, epilogue_args, hw_info}; - at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()}; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device()); + at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()}; + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(a_ptrs.get_device()); - auto can_implement_status = gemm_op.can_implement(args); - TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM"); + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM"); - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device()); - auto workspace = torch::empty(workspace_size, workspace_options); + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device()); + auto workspace = torch::empty(workspace_size, workspace_options); - auto status = gemm_op.initialize(args, workspace.data_ptr(), stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + auto status = gemm_op.initialize(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); - status = gemm_op.run(stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } template void blockwise_scaled_group_mm_dispatch_shape( - torch::Tensor& output, - const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& scales_a, - const torch::Tensor& scales_b, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets) { - - struct MmaConfig { - using ElementA = cutlass::float_e4m3_t; - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; - using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>; - using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); - using LayoutSFB = 
decltype(ScaleConfig::deduce_layoutSFB()); - using LayoutC = cutlass::layout::RowMajor; - using MmaTileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _1, _1>; - }; - - int num_experts = (int)expert_offsets.size(0); - - auto a_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto b_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto out_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto a_scales_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto b_scales_ptrs = torch::empty({num_experts}, torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - - auto layout_sfa = torch::empty({num_experts, 5}, torch::TensorOptions().dtype(torch::kInt32).device(a.device())); - auto layout_sfb = torch::empty({num_experts, 5}, torch::TensorOptions().dtype(torch::kInt32).device(a.device())); - - auto stride_a = torch::full({num_experts}, a.size(1), torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto stride_b = torch::full({num_experts}, a.size(1), torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - auto stride_c = torch::full({num_experts}, output.size(1), torch::TensorOptions().dtype(torch::kInt64).device(a.device())); - - torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device()); - - run_get_ggemm_starts( - expert_offsets, - a_ptrs, - b_ptrs, - out_ptrs, - a_scales_ptrs, - b_scales_ptrs, - a, - b, - output, - scales_a, - scales_b, - layout_sfa, - layout_sfb, - problem_sizes - ); - - run_blockwise_scaled_group_mm( - out_ptrs, - a_ptrs, - b_ptrs, - a_scales_ptrs, - b_scales_ptrs, - stride_a, - stride_b, - stride_c, - layout_sfa, - layout_sfb, - problem_sizes, - expert_offsets - ); + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) { + struct MmaConfig { + using ElementA = cutlass::float_e4m3_t; + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< + 1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + using LayoutC = cutlass::layout::RowMajor; + using MmaTileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + }; + + int num_experts = (int)expert_offsets.size(0); + + auto a_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto b_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto out_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto a_scales_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto b_scales_ptrs = torch::empty( + {num_experts}, + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + + auto layout_sfa = torch::empty( + {num_experts, 5}, + torch::TensorOptions().dtype(torch::kInt32).device(a.device())); + auto layout_sfb = torch::empty( + 
{num_experts, 5}, + torch::TensorOptions().dtype(torch::kInt32).device(a.device())); + + auto stride_a = torch::full( + {num_experts}, a.size(1), + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto stride_b = torch::full( + {num_experts}, a.size(1), + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + auto stride_c = torch::full( + {num_experts}, output.size(1), + torch::TensorOptions().dtype(torch::kInt64).device(a.device())); + + torch::TensorOptions options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + run_get_ggemm_starts( + expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a, + b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes); + + run_blockwise_scaled_group_mm( + out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a, + stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes, + expert_offsets); } - void cutlass_blockwise_scaled_grouped_mm( - torch::Tensor& output, - const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& scales_a, - const torch::Tensor& scales_b, - const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets) { - - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); - TORCH_CHECK( - problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32"); - TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, "a must be kFloat8_e4m3fn"); - TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, "b must be kFloat8_e4m3fn"); - TORCH_CHECK( - output.scalar_type() == torch::kBFloat16 || output.scalar_type() == torch::kHalf, - "output must be bfloat16 or half"); - TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be float32"); - TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be float32"); - TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32"); - - TORCH_CHECK(output.dim() == 2, "output must be 2D tensor"); - TORCH_CHECK(a.dim() == 2, "a must be 2D tensor"); - TORCH_CHECK(b.dim() == 3, "b must be 3D tensor"); - TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor"); - TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor"); - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); - TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); - TORCH_CHECK( - problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32"); - TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor"); - - #if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100 - if (output.scalar_type() == torch::kBFloat16) { - blockwise_scaled_group_mm_dispatch_shape( - output, - a, - b, - scales_a, - scales_b, - problem_sizes, - expert_offsets); - } else if (output.scalar_type() == torch::kFloat16) { - blockwise_scaled_group_mm_dispatch_shape( - output, - a, - b, - scales_a, - scales_b, - problem_sizes, - expert_offsets); - } else { - TORCH_CHECK(false, "Unsupported output tensor type"); - } - #endif + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const 
torch::Tensor& scales_b, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) { + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have shape (num_experts, 3)"); + TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32"); + TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn, + "a must be kFloat8_e4m3fn"); + TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn, + "b must be kFloat8_e4m3fn"); + TORCH_CHECK(output.scalar_type() == torch::kBFloat16 || + output.scalar_type() == torch::kHalf, + "output must be bfloat16 or half"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, + "scales_a must be float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, + "scales_b must be float32"); + TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, + "expert_offsets must be int32"); + + TORCH_CHECK(output.dim() == 2, "output must be 2D tensor"); + TORCH_CHECK(a.dim() == 2, "a must be 2D tensor"); + TORCH_CHECK(b.dim() == 3, "b must be 3D tensor"); + TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor"); + TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor"); + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have shape (num_experts, 3)"); + TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32"); + TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor"); + +#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100 + if (output.scalar_type() == torch::kBFloat16) { + blockwise_scaled_group_mm_dispatch_shape( + output, a, b, scales_a, scales_b, problem_sizes, expert_offsets); + } else if (output.scalar_type() == torch::kFloat16) { + blockwise_scaled_group_mm_dispatch_shape( + output, a, b, scales_a, scales_b, problem_sizes, expert_offsets); + } else { + TORCH_CHECK(false, "Unsupported output tensor type"); + } +#endif } \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 43ab61d31e88..0d29835b83c5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -393,12 +393,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); - // cutlass blockwise scaledgroup GEMM + // cutlass blockwise scaledgroup GEMM ops.def( - "cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, " + "cutlass_blockwise_scaled_grouped_mm(Tensor! 
output, Tensor a, Tensor b, " + "Tensor scales_a, Tensor scales_b, " "Tensor problem_sizes, Tensor expert_offsets) -> ()", {stride_tag}); - ops.impl("cutlass_blockwise_scaled_grouped_mm", torch::kCUDA, &cutlass_blockwise_scaled_grouped_mm); + ops.impl("cutlass_blockwise_scaled_grouped_mm", torch::kCUDA, + &cutlass_blockwise_scaled_grouped_mm); // cutlass nvfp4 block scaled group GEMM ops.def( diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index c1255f4c4642..8347eaa2b709 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -4,61 +4,63 @@ # See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py import random + import pytest import torch -from typing import Tuple from vllm import _custom_ops as ops -def calc_diff(x, y): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim def cdiv(a, b): return (a + b - 1) // b -def per_token_cast_to_fp8(x : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + +def per_token_cast_to_fp8( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape pad_size = (128 - (n % 128)) % 128 - x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x + x = torch.nn.functional.pad(x, + (0, pad_size), value=0) if pad_size > 0 else x x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn) + fp8_data = (x_view * + (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn) return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) -def per_block_cast_to_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + +def per_block_cast_to_fp8( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((cdiv(m, 128)* 128, cdiv(n, 128)* 128), device=x.device, dtype=x.dtype) + x_padded = torch.zeros((cdiv(m, 128) * 128, cdiv(n, 128) * 128), + device=x.device, + dtype=x.dtype) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( + x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + def baseline_scaled_mm(a, b, a_scales, b_scales, out_dtype): def group_broadcast(t, shape): for i, s in enumerate(shape): if t.shape[i] != s and t.shape[i] != 1: - assert s% t.shape[i] == 0 - t = ( - t.unsqueeze(i+1) - .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:]) - .flatten(i, i+1) - ) + assert s % t.shape[i] == 0 + t = (t.unsqueeze(i + + 1).expand(*t.shape[:i + 1], s // t.shape[i], + *t.shape[i + 1:]).flatten(i, i + 1)) return t scale_a = group_broadcast(a_scales, a.shape) scale_b = group_broadcast(b_scales, b.shape) - return torch.mm( - (scale_a * a.to(dtype=torch.float32)), - (scale_b * b.to(dtype=torch.float32)) - ).to(dtype=out_dtype) + return torch.mm((scale_a * a.to(dtype=torch.float32)), + (scale_b * b.to(dtype=torch.float32))).to(dtype=out_dtype) + @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ (4, 8192, 7168, 4096), @@ -68,7 +70,7 @@ def group_broadcast(t, 
shape): (32, 1024, 7168, 4096), (32, 1024, 2048, 7168), ]) -@pytest.mark.parametrize("out_dtype", [torch.bfloat16]) +@pytest.mark.parametrize("out_dtype", [torch.float16]) def test_cutlass_grouped_gemm( num_groups: int, expected_m_per_group: int, @@ -78,7 +80,10 @@ def test_cutlass_grouped_gemm( ): device = "cuda" alignment = 128 - group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + group_ms = [ + int(expected_m_per_group * random.uniform(0.7, 1.3)) + for _ in range(num_groups) + ] m = sum([cdiv(m, alignment) * alignment for m in group_ms]) x = torch.randn((m, k), device=device, dtype=out_dtype) @@ -89,22 +94,25 @@ def test_cutlass_grouped_gemm( ep_offset = [0] + [sum(group_ms[:i]) for i in range(1, num_groups)] + [m] pb_size = [] for i in range(num_groups): - pb_size.append([ep_offset[i+1] - ep_offset[i], n, k]) + pb_size.append([ep_offset[i + 1] - ep_offset[i], n, k]) problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32) expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32) x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, cdiv(n, 128), k // 128), + device=device, + dtype=torch.float)) for i in range(num_groups): y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) for i in range(num_groups): - a = x_fp8[0][ep_offset[i]:ep_offset[i+1]] - a_scale = x_fp8[1][ep_offset[i]:ep_offset[i+1]] + a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] + a_scale = x_fp8[1][ep_offset[i]:ep_offset[i + 1]] b = y_fp8[0][i].t() b_scale = y_fp8[1][i].t() baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype) - ref_out[ep_offset[i]:ep_offset[i+1]] = baseline + ref_out[ep_offset[i]:ep_offset[i + 1]] = baseline ops.cutlass_blockwise_scaled_grouped_mm( out, @@ -116,4 +124,4 @@ def test_cutlass_grouped_gemm( expert_offsets[:-1], ) - assert calc_diff(ref_out, out) < 1e-3, f"Cutlass grouped gemm is not accurate" \ No newline at end of file + torch.testing.assert_close(ref_out, out, atol=5e-1, rtol=1e-3) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 48e13d97a589..9dbf879f9cbc 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -646,23 +646,19 @@ def _ggml_moe_a8_vec_fake( def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) + def cutlass_blockwise_scaled_grouped_mm( - output: torch.Tensor, - a: torch.Tensor, - b: torch.Tensor, - scales_a: torch.Tensor, - scales_b: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, + output: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scales_a: torch.Tensor, + scales_b: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, ): - torch.ops._C.cutlass_blockwise_scaled_grouped_mm( - output, - a, - b, - scales_a, - scales_b, - problem_sizes, - expert_offsets) + torch.ops._C.cutlass_blockwise_scaled_grouped_mm(output, a, b, scales_a, + scales_b, problem_sizes, + expert_offsets) def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, diff --git a/vllm/envs.py b/vllm/envs.py index c07ac3717cf6..f24ae64396f3 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -117,7 +117,6 @@ VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False - 
VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -811,11 +810,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), - # Allow use of Cutlass Blockwise Scaled Grouped GEMM kernels for fused moe ops. - "VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM": - lambda: bool(int(os.getenv( - "VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM", "0"))), - # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. # It can be changed with this variable if needed for some reason. diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 40e463988ebd..6db6a3e9b2fc 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -10,12 +10,14 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.utils import ( - _fp8_perm, _resize_cache, _fp8_quantize) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, + _fp8_quantize, + _resize_cache) from vllm.scalar_type import scalar_types logger = init_logger(__name__) + def run_cutlass_moe_fp8( output: torch.Tensor, hidden_states: torch.Tensor, @@ -485,21 +487,25 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, return out.to(dtype=out_dtype) -def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> bool: +def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor) -> bool: def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int): - return 128 <= M and N % 128 == 0 and K % 128 == 0 + return M >= 128 and N % 128 == 0 and K % 128 == 0 m = hidden_states.size(0) _, K, N = w2.size() if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K): - logger.debug("CutlassBlockScaledGroupedGemm disabled: unalinged problem size.") + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: unalinged problem size.") return False - + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): - logger.debug("CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).") + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).") return False - + return True @@ -518,14 +524,17 @@ def run_cutlass_block_scaled_fused_experts( w2_scale = w2_scale.transpose(1, 2) assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert a.shape[0] == topk_ids.shape[0], "a and topk_ids must have the same batch size" + assert a.shape[0] == topk_ids.shape[ + 0], "a and topk_ids must have the same batch size" assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn" assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn" assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[0], "w1_scale expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[0], "w2_scale expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[ + 0], "w1_scale expert number mismatch" + assert 
w1_q.shape[0] == w2_scale.shape[ + 0], "w2_scale expert number mismatch" assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype" out_dtype = a.dtype @@ -534,13 +543,22 @@ def run_cutlass_block_scaled_fused_experts( k = w1_q.size(1) n = w2_q.size(1) - expert_offsets = torch.empty((num_experts + 1,), dtype=torch.int32, device="cuda") - problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda") - problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda") + expert_offsets = torch.empty((num_experts + 1, ), + dtype=torch.int32, + device="cuda") + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device="cuda") + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device="cuda") topk = topk_ids.size(1) - a_q, a1_scale = _fp8_quantize(a, A_scale=None, per_act_token=False, block_shape=[128, 128]) + a_q, a1_scale = _fp8_quantize(a, + A_scale=None, + per_act_token=False, + block_shape=[128, 128]) device = a_q.device a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) @@ -576,8 +594,11 @@ def run_cutlass_block_scaled_fused_experts( intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device) torch.ops._C.silu_and_mul(intermediate, c1) - - intermediate_q, a2_scale = _fp8_quantize(intermediate, A_scale=None, per_act_token=False, block_shape=[128, 128]) + + intermediate_q, a2_scale = _fp8_quantize(intermediate, + A_scale=None, + per_act_token=False, + block_shape=[128, 128]) ops.cutlass_blockwise_scaled_grouped_mm( c2, @@ -589,4 +610,5 @@ def run_cutlass_block_scaled_fused_experts( expert_offsets[:-1], ) - return (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) + return (c2[c_map].view(m, topk, k) * + topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 132903756f3d..cbb66fbd4557 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -12,6 +12,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + _valid_cutlass_block_scaled_grouped_gemm, + run_cutlass_block_scaled_fused_experts) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( @@ -20,9 +23,6 @@ MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) -from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - _valid_cutlass_block_scaled_grouped_gemm, - run_cutlass_block_scaled_fused_experts) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @@ -1140,31 +1140,32 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: return torch_vllm_outplace_fused_experts -def fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = 
False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False, - allow_cutlass_block_scaled_grouped_gemm: bool = False, - ) -> torch.Tensor: +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + allow_deep_gemm: bool = False, + allow_cutlass_block_scaled_grouped_gemm: bool = False, +) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. N = w1.shape[1] @@ -1188,7 +1189,7 @@ def fused_experts(hidden_states: torch.Tensor, apply_router_weight_on_input=apply_router_weight_on_input, ) elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): + and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): assert apply_router_weight_on_input is False return run_cutlass_block_scaled_fused_experts( a=hidden_states, @@ -1197,8 +1198,7 @@ def fused_experts(hidden_states: torch.Tensor, w1_scale=w1_scale, w2_scale=w2_scale, topk_weights=topk_weights, - topk_ids=topk_ids - ) + topk_ids=topk_ids) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index cb994ab5082a..add397869a30 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -466,17 +466,19 @@ def __init__(self, quant_config: Fp8Config): # Check for CutlassBlockScaledGroupedGemm support. self.allow_cutlass_block_scaled_grouped_gemm = False - if envs.VLLM_USE_CUTLASS_BLOCKSCALED_GROUPED_GEMM: - if not self.block_quant: - logger.warning_once("Model is not block quantized. Not using " - "CutlassBlockScaledGroupedGemm kernels") - elif (current_platform.is_cuda() - and current_platform.has_device_capability(100)): - logger.info_once("Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod.") - self.allow_cutlass_block_scaled_grouped_gemm = True - else: - logger.warning_once( - "CutlassBlockScaledGroupedGemm not supported on the current platform.") + if not self.block_quant: + logger.warning_once("Model is not block quantized. Not using " + "CutlassBlockScaledGroupedGemm kernels") + elif (current_platform.is_cuda() + and current_platform.has_device_capability(100)): + logger.info_once( + "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod." 
+ ) + self.allow_cutlass_block_scaled_grouped_gemm = True + else: + logger.warning_once( + "CutlassBlockScaledGroupedGemm not supported on the current platform." + ) self.topk_indices_dtype = None self.fused_experts = functools.partial( # type: ignore @@ -485,8 +487,7 @@ def __init__(self, quant_config: Fp8Config): block_shape=self.quant_config.weight_block_size, allow_deep_gemm=self.allow_deep_gemm, allow_cutlass_block_scaled_grouped_gemm=( - self.allow_cutlass_block_scaled_grouped_gemm) - ) + self.allow_cutlass_block_scaled_grouped_gemm)) def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, From 795a0d7aaa68758aba6f9d909b502d2c88da0c3f Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Wed, 25 Jun 2025 04:02:42 +0000 Subject: [PATCH 09/16] fixed linter errors Signed-off-by: Duncan Moss --- vllm/model_executor/layers/quantization/fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index add397869a30..c90bdd3feded 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -477,8 +477,8 @@ def __init__(self, quant_config: Fp8Config): self.allow_cutlass_block_scaled_grouped_gemm = True else: logger.warning_once( - "CutlassBlockScaledGroupedGemm not supported on the current platform." - ) + "CutlassBlockScaledGroupedGemm not supported on the current " + "platform.") self.topk_indices_dtype = None self.fused_experts = functools.partial( # type: ignore From 4adbbaa87f0068fc171974ceab5d3ba55d348aa7 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Wed, 25 Jun 2025 04:09:59 +0000 Subject: [PATCH 10/16] test cleanup Signed-off-by: Duncan Moss --- .../kernels/moe/test_cutlass_grouped_gemm.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index 8347eaa2b709..0165aa7f892f 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -8,6 +8,7 @@ import pytest import torch +from tests.kernels.utils import baseline_scaled_mm from vllm import _custom_ops as ops @@ -44,24 +45,6 @@ def per_block_cast_to_fp8( x_amax / 448.0).view(x_view.size(0), x_view.size(2)) -def baseline_scaled_mm(a, b, a_scales, b_scales, out_dtype): - - def group_broadcast(t, shape): - for i, s in enumerate(shape): - if t.shape[i] != s and t.shape[i] != 1: - assert s % t.shape[i] == 0 - t = (t.unsqueeze(i + - 1).expand(*t.shape[:i + 1], s // t.shape[i], - *t.shape[i + 1:]).flatten(i, i + 1)) - return t - - scale_a = group_broadcast(a_scales, a.shape) - scale_b = group_broadcast(b_scales, b.shape) - - return torch.mm((scale_a * a.to(dtype=torch.float32)), - (scale_b * b.to(dtype=torch.float32))).to(dtype=out_dtype) - - @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ (4, 8192, 7168, 4096), (4, 8192, 2048, 7168), From c9e47eddbed7084cb1631bfe80fb8cb3dff67313 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Tue, 1 Jul 2025 09:29:43 -0700 Subject: [PATCH 11/16] updated to use cutlass v4 tag Signed-off-by: Duncan Moss --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fc7f23aacbf6..7186300e438b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -259,7 +259,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY 
ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "f115c3f85467d5d9619119d1dbeb9c03c3d73864" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -283,7 +283,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW FALSE + GIT_SHALLOW TRUE ) endif() FetchContent_MakeAvailable(cutlass) From 074bf3a7f0d9813726aa4199d3519c968c39d557 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Tue, 1 Jul 2025 21:08:58 +0000 Subject: [PATCH 12/16] update cd alignemnt for cutlass v4.0.0: Signed-off-by: Duncan Moss --- csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh | 5 +++-- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index 8f4df836bcc8..19c33e344a39 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -43,7 +43,7 @@ struct cutlass_3x_gemm { using Epilogue = Epilogue_; using StrideD = Stride, Int<0>>; - using ElementC = void; + using ElementC = ElementD_; using StrideC = StrideD; using EVTCompute = typename Epilogue::EVTCompute; @@ -51,7 +51,8 @@ struct cutlass_3x_gemm { // These are the minimum alignments needed for the kernels to compile static constexpr int AlignmentAB = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentCD = 4; + static constexpr int AlignmentCD = + 128 / cutlass::sizeof_bits::value; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh index c22523da4e43..351b8299968c 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh @@ -69,7 +69,7 @@ struct cutlass_sparse_3x_gemm { using Epilogue = Epilogue_; - using ElementC = void; + using ElementC = ElementD_; using LayoutC = cutlass::layout::RowMajor; using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose::type; @@ -79,7 +79,8 @@ struct cutlass_sparse_3x_gemm { // These are the minimum alignments needed for the kernels to compile static constexpr int AlignmentAB = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentCD = 4; + static constexpr int AlignmentCD = + 128 / cutlass::sizeof_bits::value; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< From 7436ec0aac5d8302aca0bbaf3739a5316bb0c0a2 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Wed, 2 Jul 2025 18:44:13 +0000 Subject: [PATCH 13/16] fixed error in elementC type Signed-off-by: Duncan Moss --- csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh | 2 +- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index 19c33e344a39..6054d7974ebb 100644 --- 
a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -43,7 +43,7 @@ struct cutlass_3x_gemm { using Epilogue = Epilogue_; using StrideD = Stride, Int<0>>; - using ElementC = ElementD_; + using ElementC = void; using StrideC = StrideD; using EVTCompute = typename Epilogue::EVTCompute; diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh index 351b8299968c..637bba1384a4 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh @@ -69,7 +69,7 @@ struct cutlass_sparse_3x_gemm { using Epilogue = Epilogue_; - using ElementC = ElementD_; + using ElementC = void; using LayoutC = cutlass::layout::RowMajor; using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose::type; From ddf21eae12746088b8180da11fa6224f9878d78a Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Wed, 2 Jul 2025 21:49:17 +0000 Subject: [PATCH 14/16] remove tensor_predicate Signed-off-by: Duncan Moss --- ...m90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index d922a3349e1e..ce7f47cf7233 100644 --- a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -45,7 +45,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass_extensions/gemm/dispatch_policy.hpp" From 61c270d4be83ad3f9e804affb2b7d7738c988242 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 3 Jul 2025 15:42:42 +0000 Subject: [PATCH 15/16] missing tensor_predicate Signed-off-by: Duncan Moss --- csrc/quantization/machete/machete_mainloop.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index eca5d328b00c..2f52a6b7a024 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -38,7 +38,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/atom/copy_traits_sm90_tma.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" From 23d5c72463be09da71b3b691059ad755d274fa2d Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 3 Jul 2025 22:38:22 +0000 Subject: [PATCH 16/16] skip test if not sm100 Signed-off-by: Duncan Moss --- tests/kernels/moe/test_cutlass_grouped_gemm.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index 0165aa7f892f..bf228dcece3c 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -10,6 +10,7 @@ from tests.kernels.utils import baseline_scaled_mm from vllm import _custom_ops as ops +from vllm.platforms import current_platform def cdiv(a, b): @@ -54,6 +55,10 @@ def per_block_cast_to_fp8( (32, 1024, 2048, 7168), ]) 
@pytest.mark.parametrize("out_dtype", [torch.float16]) +@pytest.mark.skipif( + (lambda x: x is None or x.to_int() != 100)( + current_platform.get_device_capability()), + reason="Block Scaled Grouped GEMM is only supported on SM100.") def test_cutlass_grouped_gemm( num_groups: int, expected_m_per_group: int,
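
For reference, the call contract established by the patches above (FP8 E4M3 operands, float32 block scales at 1x128 activation / 128x128 weight granularity, int32 problem_sizes of shape (num_experts, 3), and int32 per-expert row offsets) can be exercised directly through the Python wrapper added in vllm/_custom_ops.py. The snippet below is a minimal smoke-test sketch rather than a definitive usage: it assumes an SM100 device and a build with ENABLE_CUTLASS_MOE_SM100=1, and the toy shapes and unit scales are placeholders chosen only to satisfy the dtype, dimensionality, and 128-alignment checks in blockwise_scaled_group_mm_sm100.cu.

import torch
from vllm import _custom_ops as ops

num_groups, n, k = 4, 512, 256   # n and k are multiples of 128 (blockwise scale granularity)
m_per_group = 128                # per-expert row count, already padded to the 128-row alignment
m = num_groups * m_per_group

# FP8 operands: a is (m, k) activations, b is (num_experts, n, k) weights.
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(num_groups, n, k, device="cuda").to(torch.float8_e4m3fn)

# Block scales: 1x128 for activations -> (m, k // 128); 128x128 for weights
# -> (num_experts, n // 128, k // 128). Unit scales keep the example trivial.
scales_a = torch.ones(m, k // 128, device="cuda", dtype=torch.float32)
scales_b = torch.ones(num_groups, n // 128, k // 128, device="cuda", dtype=torch.float32)

# Per-expert (m_i, n, k) triples and the starting row of each expert's slice of a / out.
problem_sizes = torch.tensor([[m_per_group, n, k]] * num_groups, device="cuda", dtype=torch.int32)
expert_offsets = torch.arange(0, m, m_per_group, device="cuda", dtype=torch.int32)

out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)
ops.cutlass_blockwise_scaled_grouped_mm(out, a, b, scales_a, scales_b, problem_sizes, expert_offsets)

In vLLM itself this op is only reached through run_cutlass_block_scaled_fused_experts, which builds the FP8 tensors, block scales, problem sizes, and offsets from the fused-MoE routing results instead of the hand-rolled placeholders above.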