From aa2e772b2f97430cde544b990a5da88c19b36a1e Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 2 Apr 2025 06:13:42 +0000 Subject: [PATCH 01/22] Enable BF16 weights in CUTLASS MoE Signed-off-by: ElizaWszola --- .../kernels/benchmark_grouped_gemm_cutlass.py | 92 +++++----- csrc/ops.h | 6 + .../cutlass_w8a8/moe/get_group_starts.cuh | 51 ++++-- .../cutlass_w8a8/moe/grouped_mm_c3x.cu | 133 ++++++++++++++ .../cutlass_w8a8/moe/grouped_mm_fp16_c3x.cuh | 138 ++++++++++++++ .../quantization/cutlass_w8a8/moe/moe_data.cu | 4 +- .../cutlass_w8a8/scaled_mm_entry.cu | 23 +++ csrc/torch_bindings.cpp | 9 + tests/kernels/test_cutlass.py | 86 +++++++++ tests/kernels/test_cutlass_moe.py | 170 ++++++++++++++++-- vllm/_custom_ops.py | 20 +++ .../layers/fused_moe/__init__.py | 5 +- .../layers/fused_moe/fused_moe.py | 57 ++++++ vllm/model_executor/layers/fused_moe/layer.py | 51 +++++- 14 files changed, 767 insertions(+), 78 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/moe/grouped_mm_fp16_c3x.cuh diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index bcdbf6c7551a..d516d1e6246c 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -18,8 +18,8 @@ DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] -PER_ACT_TOKEN_OPTS = [False] -PER_OUT_CH_OPTS = [False] +PER_ACT_TOKEN_OPTS = [False, True] +PER_OUT_CH_OPTS = [False, True] def to_fp8(tensor: torch.Tensor): @@ -48,7 +48,8 @@ def bench_run(results: list[benchmark.Measurement], model: str, w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10 - _, a_scale = ops.scaled_fp8_quant(a) + _, a_scale = ops.scaled_fp8_quant(a, + use_per_token_if_dynamic=per_act_token) w1_q = torch.empty((num_experts, 2 * n, k), device="cuda", @@ -56,10 +57,12 @@ def bench_run(results: list[benchmark.Measurement], model: str, w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((num_experts, 1, 1), + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 + w1_scale = torch.empty((num_experts, n_b_scales, 1), device="cuda", dtype=torch.float32) - w2_scale = torch.empty((num_experts, 1, 1), + w2_scale = torch.empty((num_experts, k_b_scales, 1), device="cuda", dtype=torch.float32) @@ -81,8 +84,10 @@ def bench_run(results: list[benchmark.Measurement], model: str, dtype=torch.int64) for expert in range(num_experts): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) w1_q_notransp = w1_q.clone() w2_q_notransp = w2_q.clone() w1_q = w1_q.transpose(1, 2) @@ -105,7 +110,8 @@ def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, - a1_scale=a_scale) + a1_scale=a_scale, + a2_scale=a_scale) def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -165,7 +171,8 @@ def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, - a1_scale=a_scale) + a1_scale=a_scale, + a2_scale=a_scale) 
def replay_graph(graph, num_repeats): for _ in range(num_repeats): @@ -180,12 +187,16 @@ def replay_graph(graph, num_repeats): ab_strides2, c_strides2) torch.cuda.synchronize() - triton_stream = torch.cuda.Stream() - triton_graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(triton_graph, stream=triton_stream): - run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights, - topk_ids, w1_scale, w2_scale, a_scale) - torch.cuda.synchronize() + if not per_act_token and not per_out_ch: + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, + topk_weights, topk_ids, w1_scale, w2_scale, + a_scale) + torch.cuda.synchronize() + else: + triton_graph = [] min_run_time = 5 num_warmup = 5 @@ -223,31 +234,32 @@ def replay_graph(graph, num_repeats): "replay_graph": replay_graph, } - # Warmup - run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, - w1_scale, w2_scale, a_scale, num_warmup) - - results.append( - benchmark.Timer( - stmt= - "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="triton_moe", - ).blocked_autorange(min_run_time=min_run_time)) - - # Warmup - replay_graph(triton_graph, num_warmup) - - results.append( - benchmark.Timer( - stmt="replay_graph(triton_graph, num_runs)", - globals=globals, - label=label, - sub_label=sub_label, - description="triton_moe_cuda_graphs", - ).blocked_autorange(min_run_time=min_run_time)) + if not per_act_token and not per_out_ch: + # Warmup + run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, + w1_scale, w2_scale, a_scale, num_warmup) + + results.append( + benchmark.Timer( + stmt= + "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) # Warmup run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, diff --git a/csrc/ops.h b/csrc/ops.h index 1ea9f465cf21..dd0415a09fc8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -183,6 +183,12 @@ void cutlass_moe_mm( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); +void cutlass_moe_mm_fp16( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); + void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index 6c6e89790847..46c6dcf844d0 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -12,22 +12,23 
@@ template __global__ void get_group_gemm_starts( int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, ElementAccumulator** a_scales_offsets, - ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, - ElementAB* b_base_as_int, ElementC* out_base_as_int, - ElementAccumulator* a_scales_base_as_int, - ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k, - bool per_act_token, bool per_out_ch) { + ElementAccumulator** b_scales_offsets, ElementAB* a_base, ElementAB* b_base, + ElementC* out_base, ElementAccumulator* a_scales_base, + ElementAccumulator* b_scales_base, int64_t n, int64_t k, bool per_act_token, + bool per_out_ch) { int expert_id = threadIdx.x; int64_t expert_offset = expert_offsets[expert_id]; - a_offsets[expert_id] = a_base_as_int + expert_offset * k; - b_offsets[expert_id] = b_base_as_int + expert_id * k * n; - out_offsets[expert_id] = out_base_as_int + expert_offset * n; - a_scales_offsets[expert_id] = - a_scales_base_as_int + (per_act_token ? expert_offset : 0); - b_scales_offsets[expert_id] = - b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); + a_offsets[expert_id] = a_base + expert_offset * k; + b_offsets[expert_id] = b_base + expert_id * k * n; + out_offsets[expert_id] = out_base + expert_offset * n; + if (a_scales_offsets != nullptr && a_scales_base != nullptr) + a_scales_offsets[expert_id] = + a_scales_base + (per_act_token ? expert_offset : 0); + if (b_scales_offsets != nullptr && b_scales_base != nullptr) + b_scales_offsets[expert_id] = + b_scales_base + (per_out_ch ? n * expert_id : expert_id); } #define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ @@ -77,4 +78,30 @@ void run_get_group_gemm_starts( } } +void run_get_group_gemm_starts_fp16(torch::Tensor const& expert_offsets, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& out_ptrs, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor& out_tensors) { + TORCH_CHECK(a_tensors.dtype() == torch::kBFloat16); + TORCH_CHECK(b_tensors.dtype() == torch::kBFloat16); + + int num_experts = (int)expert_offsets.size(0); + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + get_group_gemm_starts + <<<1, num_experts, 0, stream>>>( + static_cast(expert_offsets.data_ptr()), + static_cast(a_ptrs.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(out_ptrs.data_ptr()), nullptr, + nullptr, static_cast(a_tensors.data_ptr()), + static_cast(b_tensors.data_ptr()), + static_cast(out_tensors.data_ptr()), nullptr, + nullptr, out_tensors.size(1), a_tensors.size(1), false, false); +} + } // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu index 2b8bc3fb0b26..8dc555f7fc69 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu @@ -5,6 +5,7 @@ #include "cutlass/cutlass.h" #include "grouped_mm_c3x.cuh" +#include "grouped_mm_fp16_c3x.cuh" using namespace cute; @@ -78,6 +79,74 @@ struct sm90_fp8_config_N8192 { KernelSchedule, EpilogueSchedule>; }; +template typename Epilogue> +struct sm90_fp16_config_default { + // M in (16, inf) + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + 
cutlass_3x_group_gemm_fp16; +}; + +template typename Epilogue> +struct sm90_fp16_config_M16 { + // M in [1, 16] + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm_fp16; +}; + +template typename Epilogue> +struct sm90_fp16_config_K8192 { + // K in [8192, inf) + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm_fp16; +}; + +template typename Epilogue> +struct sm90_fp16_config_N8192 { + // N in [8192, inf) + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm_fp16; +}; + template void run_cutlass_moe_mm_sm90( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, @@ -129,6 +198,49 @@ void run_cutlass_moe_mm_sm90( } } +template +void run_cutlass_moe_mm_fp16_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); + + using Cutlass3xGemmN8192 = typename sm90_fp16_config_N8192< + InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; + using Cutlass3xGemmK8192 = typename sm90_fp16_config_K8192< + InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; + using Cutlass3xGemmM16 = + typename sm90_fp16_config_M16::Cutlass3xGemm; + using Cutlass3xGemmDefault = typename sm90_fp16_config_default< + InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; + + uint32_t const m = a_tensors.size(0); + uint32_t const n = out_tensors.size(1); + uint32_t const k = a_tensors.size(1); + + if (n >= 8192) { + cutlass_group_gemm_fp16_caller( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } else if (k >= 8192) { + cutlass_group_gemm_fp16_caller( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } else if (m <= 16) { + cutlass_group_gemm_fp16_caller( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } else { + cutlass_group_gemm_fp16_caller( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } +} + void dispatch_moe_mm_sm90( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, @@ -146,6 +258,18 @@ void dispatch_moe_mm_sm90( } } +void dispatch_moe_mm_fp16_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, 
torch::Tensor const& c_strides) { + TORCH_CHECK(out_tensors.dtype() == torch::kBFloat16, + "output must be bfloat16"); + run_cutlass_moe_mm_fp16_sm90( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); +} + } // namespace void cutlass_moe_mm_sm90( @@ -158,3 +282,12 @@ void cutlass_moe_mm_sm90( expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } + +void cutlass_moe_mm_fp16_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + dispatch_moe_mm_fp16_sm90(out_tensors, a_tensors, b_tensors, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); +} diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_fp16_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_fp16_c3x.cuh new file mode 100644 index 000000000000..8973fa877df1 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_fp16_c3x.cuh @@ -0,0 +1,138 @@ +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/bfloat16.h" + +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "cutlass_extensions/common.hpp" +#include "get_group_starts.cuh" + +using namespace cute; + +namespace { + +using ProblemShape = + cutlass::gemm::GroupProblemShape>; + +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_group_gemm_fp16 { + using ElementAB = ElementAB_; + using ElementC = void; + using ElementD = ElementC_; + using ElementAccumulator = float; + + using Epilogue = Epilogue_; + + using StrideC = + cute::remove_pointer_t, cute::Int<0>>>; + + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, + LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, + Stages, KernelSchedule>::CollectiveOp; + + using KernelType = enable_sm90_only>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_group_gemm_fp16_caller( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, 
torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int num_experts = static_cast(expert_offsets.size(0)); + int k_size = a_tensors.size(1); + int n_size = out_tensors.size(1); + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + + run_get_group_gemm_starts_fp16(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + a_tensors, b_tensors, out_tensors); + + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename GemmKernel::InternalStrideC; + + ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast( + problem_sizes.data_ptr()); + ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; + + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr())}; + + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args(), nullptr, + static_cast(c_strides.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides.data_ptr())}; + + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, + epilogue_args}; + + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +} // namespace diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 2fb0417ce6c4..a2bdafa62e84 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -7,7 +7,7 @@ constexpr uint64_t THREADS_PER_EXPERT = 512; -__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, +__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* atomic_buffer, @@ -45,7 +45,7 @@ __global__ void compute_expert_offsets( } } -__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, +__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, int32_t* input_permutation, int32_t* output_permutation, int32_t* atomic_buffer, const int topk_length, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 54b63894e4cb..d950dca89627 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -37,6 +37,12 @@ void cutlass_moe_mm_sm90( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor 
const& c_strides); +void cutlass_moe_mm_fp16_sm90( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); + void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -214,6 +220,23 @@ void cutlass_moe_mm( ". Required capability: 90"); } +void cutlass_moe_mm_fp16( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 + cutlass_moe_mm_fp16_sm90(out_tensors, a_tensors, b_tensors, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, + ". Required capability: 90"); +} + void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 60ad6430336a..1c84becc9110 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -379,6 +379,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); + // CUTLASS w16a16 grouped GEMM + ops.def( + "cutlass_moe_mm_fp16(Tensor! out_tensors, Tensor a_tensors, " + " Tensor b_tensors, Tensor expert_offsets, " + " Tensor problem_sizes, Tensor a_strides, " + " Tensor b_strides, Tensor c_strides) -> ()", + {stride_tag}); + ops.impl("cutlass_moe_mm_fp16", torch::kCUDA, &cutlass_moe_mm_fp16); + // A function that computes data required to run fused MoE with w8a8 grouped // GEMM. It takes topk_ids as an input, and computes expert_offsets // (token start indices of each expert). 
In addition to this, it computes diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index f11ce6f45a98..53abc78bb891 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -641,3 +641,89 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, print(c) print("*") torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4) + + +@pytest.mark.parametrize("num_experts", [8, 64]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_fp16_group_gemm(num_experts: int, dtype: torch.dtype): + + # Device and dtype setup + device = "cuda" + + # Create separate A, B, C tensors for each group + a_tensors = [] + b_tensors = [] + baseline_tensors = [] + + expert_offsets = torch.zeros((num_experts + 1), + device=device, + dtype=torch.int32) + + problem_sizes = torch.zeros((num_experts, 3), + device=device, + dtype=torch.int32) + + alignment = 16 + # For variation, each group has dimensions + n_g = alignment * random.randint(1, 64) + k_g = alignment * random.randint(1, 64) + for g in range(num_experts): + m_g = alignment * random.randint(1, 64) + + expert_offsets[g + 1] = expert_offsets[g] + m_g + problem_sizes[g][0] = m_g + problem_sizes[g][1] = n_g + problem_sizes[g][2] = k_g + + # Create group-specific A and B (FP16) and output (FP16/FP32) + a_g = torch.randn((m_g, k_g), device=device, dtype=dtype) + b_g = torch.randn((n_g, k_g), device=device, dtype=dtype).t() + a_tensors.append(a_g) + b_tensors.append(b_g) + + # Compute baseline result for this group + baseline_g = a_g.matmul(b_g) + baseline_tensors.append(baseline_g) + + a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g), + device=device, + dtype=dtype) + b_tensors_stacked = torch.empty((num_experts, n_g, k_g), + device=device, + dtype=dtype) + + for g in range(num_experts): + a_tensors_stacked[expert_offsets[g]:expert_offsets[g + + 1]] = a_tensors[g] + b_tensors_stacked[g] = b_tensors[g].t() + b_tensors_stacked = b_tensors_stacked.transpose(1, 2) + + out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g), + device=device, + dtype=dtype) + + ab_strides = torch.full((num_experts, ), + a_tensors_stacked.stride(0), + device=device, + dtype=torch.int64) + c_strides = torch.full((num_experts, ), + out_tensors_stacked.stride(0), + device=device, + dtype=torch.int64) + + ops.cutlass_moe_mm_fp16(out_tensors_stacked, a_tensors_stacked, + b_tensors_stacked, expert_offsets[:-1], + problem_sizes, ab_strides, ab_strides, c_strides) + + # Validate each group's result against the baseline + for g in range(num_experts): + baseline = baseline_tensors[g] + c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] + print(baseline) + print(c) + print("*") + torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-3) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 1652c72d86fe..798b706460e6 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -5,6 +5,7 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, + cutlass_moe_fp16, fused_experts, fused_topk) from vllm.platforms import current_platform @@ -35,13 +36,31 @@ def run(a: 
torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, a1_scale=a_scale) -@pytest.mark.parametrize("m", [2, 64, 224]) +def run_fp16(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe_fp16(a, w1, w2, topk_weights, topk_ids, ab_strides1, + c_strides1, ab_strides2, c_strides2) + + +@pytest.mark.parametrize("m", [2, 64, 224, 512, 163840]) @pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("k", [1024, 1536, 5120]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_act_token", [False]) @pytest.mark.parametrize("per_out_ch", [True, False]) +# @pytest.mark.parametrize("m", [64, 224, 512, 163840]) +# @pytest.mark.parametrize("n", [1024]) +# @pytest.mark.parametrize("k", [5120]) +# @pytest.mark.parametrize("e", [16]) +# @pytest.mark.parametrize("topk", [1]) +# @pytest.mark.parametrize("per_act_token", [False]) +# @pytest.mark.parametrize("per_out_ch", [False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), @@ -69,11 +88,12 @@ def test_cutlass_moe_no_graph( # Get the right scale for tests. _, a_scale1 = ops.scaled_fp8_quant( a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) + a_scale2 = a_scale1.clone() + # a_q, _ = ops.scaled_fp8_quant(a, + # a_scale1, + # use_per_token_if_dynamic=per_act_token) - a_d = a_q.float().mul(a_scale1).to(dtype) + # a_d = a_q.float().mul(a_scale1).to(dtype) n_b_scales = 2 * n if per_out_ch else 1 k_b_scales = k if per_out_ch else 1 @@ -116,7 +136,16 @@ def test_cutlass_moe_no_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) + triton_output = fused_experts(a, + w1_q.transpose(1, 2), + w2_q.transpose(1, 2), + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale1, + a2_scale=a_scale2) cutlass_output = cutlass_moe_fp8(a, w1_q, @@ -129,15 +158,16 @@ def test_cutlass_moe_no_graph( c_strides1, ab_strides2, c_strides2, - a1_scale=a_scale1) + a1_scale=a_scale1, + a2_scale=a_scale2) - print(triton_output) - print(cutlass_output) + print(triton_output.view(cutlass_output.shape).t()[0]) + print(cutlass_output.t()[0]) print("*") - torch.testing.assert_close(triton_output, + torch.testing.assert_close(triton_output.view(cutlass_output.shape), cutlass_output, - atol=5e-2, + atol=2e-2, rtol=1e-2) @@ -242,3 +272,117 @@ def test_cutlass_moe_cuda_graph( cutlass_output, atol=9e-2, rtol=1e-2) + + +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512, 163840]) +@pytest.mark.parametrize("n", [1024, 2048, 3072]) +@pytest.mark.parametrize("k", [1024, 1536, 2048]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), 
+ reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_fp16_moe_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + a = torch.ones((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn( + (e, 2 * n, k), device="cuda", dtype=dtype).transpose(1, 2) / 10 + w2 = torch.randn( + (e, k, n), device="cuda", dtype=dtype).transpose(1, 2) / 10 + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + triton_output = fused_experts(a, w1.transpose(1, + 2), w2.transpose(1, 2), + topk_weights, topk_ids) + cutlass_output = cutlass_moe_fp16(a, w1, w2, topk_weights, topk_ids, + ab_strides1, c_strides1, ab_strides2, + c_strides2) + + print(triton_output) + print(cutlass_output) + print("*") + + torch.testing.assert_close(triton_output.view(cutlass_output.shape), + cutlass_output, + atol=2e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512, 163840]) +@pytest.mark.parametrize("n", [1024, 2048, 3072]) +@pytest.mark.parametrize("k", [1024, 1536, 2048]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) #, torch.half]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_fp16_moe_cuda_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + a = torch.ones((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn( + (e, 2 * n, k), device="cuda", dtype=dtype).transpose(1, 2) / 10 + w2 = torch.randn( + (e, k, n), device="cuda", dtype=dtype).transpose(1, 2) / 10 + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + triton_output = fused_experts(a, w1.transpose(1, + 2), w2.transpose(1, 2), + topk_weights, topk_ids) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + cutlass_output = run_fp16(a, w1, w2, topk_weights, topk_ids, + ab_strides1, c_strides1, ab_strides2, + c_strides2) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + print(triton_output) + print(cutlass_output) + print("*") + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=2e-2, + rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 2ffcef414cb2..9a307989e7a0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -730,6 +730,26 @@ def 
cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
                                 a_strides, b_strides, c_strides)
 
 
+def cutlass_moe_mm_fp16(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
+                        b_tensors: torch.Tensor, expert_offsets: torch.Tensor,
+                        problem_sizes: torch.Tensor, a_strides: torch.Tensor,
+                        b_strides: torch.Tensor, c_strides: torch.Tensor):
+    """
+    A single grouped matrix multiplication used in CUTLASS-based fused MoE.
+    The function executes an unquantized 16-bit OUT = AB matrix multiplication.
+
+    - expert_offsets: Indices that mark at which token index each expert begins
+                      its computation. The number of tokens computed with
+                      expert E is expert_offsets[E + 1] - expert_offsets[E]
+    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
+                     MMs used in the fused MoE operation.
+    - a/b/c_strides: The data strides passed to grouped matrix multiplication.
+    """
+    torch.ops._C.cutlass_moe_mm_fp16(out_tensors, a_tensors, b_tensors,
+                                     expert_offsets, problem_sizes, a_strides,
+                                     b_strides, c_strides)
+
+
 # aqlm
 def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
               codebooks: torch.Tensor, scales: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index e096d14fc6f9..ca7226c96d9b 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -36,8 +36,8 @@ def get_config() -> Optional[Dict[str, Any]]:
     import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
     import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
     from vllm.model_executor.layers.fused_moe.fused_moe import (
-        cutlass_moe_fp8, fused_experts, fused_moe, fused_topk,
-        get_config_file_name, grouped_topk)
+        cutlass_moe_fp8, cutlass_moe_fp16, fused_experts, fused_moe,
+        fused_topk, get_config_file_name, grouped_topk)
 
     __all__ += [
         "fused_moe",
@@ -46,4 +46,5 @@ def get_config() -> Optional[Dict[str, Any]]:
         "get_config_file_name",
         "grouped_topk",
         "cutlass_moe_fp8",
+        "cutlass_moe_fp16",
     ]
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 0929530ebec4..df060e50fe91 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1760,3 +1760,60 @@ def cutlass_moe_fp8(
 
     return (c2[c_map].view(m, topk, k) *
             topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
+
+
+def cutlass_moe_fp16(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    ab_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+) -> torch.Tensor:
+    num_experts = w1.shape[0]
+    m = a.shape[0]
+    k = w1.shape[1]
+    n = w2.shape[1]
+
+    topk = topk_ids.shape[1]
+    device = a.device
+
+    out_dtype = a.dtype
+
+    expert_offsets = torch.empty((num_experts + 1),
+                                 dtype=torch.int32,
+                                 device=device)
+    problem_sizes1 = torch.empty((num_experts, 3),
+                                 dtype=torch.int32,
+                                 device=device)
+    problem_sizes2 = torch.empty((num_experts, 3),
+                                 dtype=torch.int32,
+                                 device=device)
+
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+
+    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
+                                problem_sizes2, a_map, c_map, num_experts, n,
+                                k)
+
+    rep_a = a[a_map]
+
+    c1 = torch.zeros((m * topk, n * 2), device=device, dtype=out_dtype)
+    c2 = torch.zeros((m * topk, k), device=device,
dtype=out_dtype) + + ops.cutlass_moe_mm_fp16(c1, rep_a, w1, expert_offsets[:-1], problem_sizes1, + ab_strides1, ab_strides1, c_strides1) + + intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) + torch.ops._C.silu_and_mul(intermediate, c1) + + ops.cutlass_moe_mm_fp16(c2, intermediate, w2, expert_offsets[:-1], + problem_sizes2, ab_strides2, ab_strides2, + c_strides2) + + return (c2[c_map].view(m, topk, k) * + topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ef33852e3162..a0fcca6840af 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -97,6 +97,24 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + device = layer.w13_weight.device + self.ab_strides1 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full((num_experts, ), + 2 * intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full((num_experts, ), + intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides2 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. This is an optimization on ROCm platform, which # can benefit from tensors located far enough from one another in memory @@ -202,15 +220,30 @@ def forward_cuda( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map) + # TODO will need a similar class split as fp8 because we don't have + # expert map and other activations here either + from vllm.model_executor.layers.fused_moe import cutlass_moe_fp16 + return cutlass_moe_fp16( + x, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + topk_weights, + topk_ids, + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + ) + + # return fused_experts(hidden_states=x, + # w1=layer.w13_weight, + # w2=layer.w2_weight, + # topk_weights=topk_weights, + # topk_ids=topk_ids, + # inplace=True, + # activation=activation, + # global_num_experts=global_num_experts, + # expert_map=expert_map) def forward_cpu( self, From 482e9adbbb43c2c2b585087280d20ba0ce5f3b19 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 2 Apr 2025 12:18:18 +0000 Subject: [PATCH 02/22] cleanup tests Signed-off-by: ElizaWszola --- tests/kernels/test_cutlass_moe.py | 56 +++++++++++-------------------- 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 798b706460e6..fd9d860d45b7 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -47,20 +47,13 @@ def run_fp16(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, c_strides1, ab_strides2, c_strides2) -@pytest.mark.parametrize("m", [2, 64, 224, 512, 163840]) +@pytest.mark.parametrize("m", [2, 64, 224]) @pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536, 5120]) +@pytest.mark.parametrize("k", [1024, 1536]) 
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [False]) +@pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) -# @pytest.mark.parametrize("m", [64, 224, 512, 163840]) -# @pytest.mark.parametrize("n", [1024]) -# @pytest.mark.parametrize("k", [5120]) -# @pytest.mark.parametrize("e", [16]) -# @pytest.mark.parametrize("topk", [1]) -# @pytest.mark.parametrize("per_act_token", [False]) -# @pytest.mark.parametrize("per_out_ch", [False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), @@ -88,12 +81,11 @@ def test_cutlass_moe_no_graph( # Get the right scale for tests. _, a_scale1 = ops.scaled_fp8_quant( a, use_per_token_if_dynamic=per_act_token) - a_scale2 = a_scale1.clone() - # a_q, _ = ops.scaled_fp8_quant(a, - # a_scale1, - # use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale1, + use_per_token_if_dynamic=per_act_token) - # a_d = a_q.float().mul(a_scale1).to(dtype) + a_d = a_q.float().mul(a_scale1).to(dtype) n_b_scales = 2 * n if per_out_ch else 1 k_b_scales = k if per_out_ch else 1 @@ -136,16 +128,7 @@ def test_cutlass_moe_no_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - triton_output = fused_experts(a, - w1_q.transpose(1, 2), - w2_q.transpose(1, 2), - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale1, - a2_scale=a_scale2) + triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) cutlass_output = cutlass_moe_fp8(a, w1_q, @@ -158,16 +141,15 @@ def test_cutlass_moe_no_graph( c_strides1, ab_strides2, c_strides2, - a1_scale=a_scale1, - a2_scale=a_scale2) + a1_scale=a_scale1) - print(triton_output.view(cutlass_output.shape).t()[0]) - print(cutlass_output.t()[0]) + print(triton_output) + print(cutlass_output) print("*") torch.testing.assert_close(triton_output.view(cutlass_output.shape), cutlass_output, - atol=2e-2, + atol=5e-2, rtol=1e-2) @@ -274,9 +256,9 @@ def test_cutlass_moe_cuda_graph( rtol=1e-2) -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512, 163840]) -@pytest.mark.parametrize("n", [1024, 2048, 3072]) -@pytest.mark.parametrize("k", [1024, 1536, 2048]) +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @@ -328,12 +310,12 @@ def test_cutlass_fp16_moe_no_graph( rtol=1e-2) -@pytest.mark.parametrize("m", [2, 16, 32, 64, 224, 512, 163840]) -@pytest.mark.parametrize("n", [1024, 2048, 3072]) -@pytest.mark.parametrize("k", [1024, 1536, 2048]) +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) #, torch.half]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), From 4e93c7f26a5946e47fdb1cf13794b68e3c57afa4 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 2 Apr 2025 13:20:10 +0000 Subject: [PATCH 03/22] 
Pick the right model based on arch, activation and expert_map Signed-off-by: ElizaWszola --- tests/kernels/test_cutlass_moe.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 199 ++++++++++++++---- 2 files changed, 154 insertions(+), 47 deletions(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index fd9d860d45b7..3c9e2a56f418 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -147,7 +147,7 @@ def test_cutlass_moe_no_graph( print(cutlass_output) print("*") - torch.testing.assert_close(triton_output.view(cutlass_output.shape), + torch.testing.assert_close(triton_output, cutlass_output, atol=5e-2, rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a0fcca6840af..37717901c5c2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -24,9 +24,10 @@ from vllm.utils import direct_register_custom_op if current_platform.is_cuda_alike(): - from .fused_moe import fused_experts + from .fused_moe import cutlass_moe_fp16, fused_experts else: fused_experts = None # type: ignore + cutlass_moe_fp16 = None # type: ignore if current_platform.is_tpu(): # the iterative moe implementation is used until the moe_pallas is fixed from .moe_torch_iterative import fused_moe as fused_moe_pallas @@ -70,8 +71,23 @@ def apply( raise NotImplementedError -@CustomOp.register("unquantized_fused_moe") -class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): +class UnquantizedFusedMoEMethod(FusedMoEMethodBase): + """MoE method without quantization.""" + + @staticmethod + def get_moe_method( + activation: str, + expert_map: Optional[torch.Tensor], + ) -> "UnquantizedFusedMoEMethod": + if (UnquantizedFusedCutlassMoEMethod.check_supported( + activation, expert_map)): + return UnquantizedFusedCutlassMoEMethod() + else: + return UnquantizedFusedTritonMoEMethod() + + +@CustomOp.register("unquantized_fused_triton_moe") +class UnquantizedFusedTritonMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -97,24 +113,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - device = layer.w13_weight.device - self.ab_strides1 = torch.full((num_experts, ), - hidden_size, - device=device, - dtype=torch.int64) - self.c_strides1 = torch.full((num_experts, ), - 2 * intermediate_size_per_partition, - device=device, - dtype=torch.int64) - self.ab_strides2 = torch.full((num_experts, ), - intermediate_size_per_partition, - device=device, - dtype=torch.int64) - self.c_strides2 = torch.full((num_experts, ), - hidden_size, - device=device, - dtype=torch.int64) - def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. 
This is an optimization on ROCm platform, which # can benefit from tensors located far enough from one another in memory @@ -220,30 +218,15 @@ def forward_cuda( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - # TODO will need a similar class split as fp8 because we don't have - # expert map and other activations here either - from vllm.model_executor.layers.fused_moe import cutlass_moe_fp16 - return cutlass_moe_fp16( - x, - layer.w13_weight.transpose(1, 2), - layer.w2_weight.transpose(1, 2), - topk_weights, - topk_ids, - self.ab_strides1, - self.c_strides1, - self.ab_strides2, - self.c_strides2, - ) - - # return fused_experts(hidden_states=x, - # w1=layer.w13_weight, - # w2=layer.w2_weight, - # topk_weights=topk_weights, - # topk_ids=topk_ids, - # inplace=True, - # activation=activation, - # global_num_experts=global_num_experts, - # expert_map=expert_map) + return fused_experts(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map) def forward_cpu( self, @@ -345,6 +328,129 @@ def forward_tpu( forward_native = forward_cuda +@CustomOp.register("unquantized_fused_cutlass_moe") +class UnquantizedFusedCutlassMoEMethod(FusedMoEMethodBase, CustomOp): + """CUTLASS MoE method without quantization.""" + + @staticmethod + def check_supported(activation: str, + expert_map: Optional[torch.Tensor], + error: bool = True) -> bool: + required_capability = 90 + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + arch_supported = (capability == required_capability + and not current_platform.is_cpu() + and not current_platform.is_rocm()) + functions_supported = activation == "silu" and expert_map is None + if error and not arch_supported: + raise RuntimeError( + "Method is not supported for the current device. Required ", + f"GPU with capability: {required_capability}. Current " + f"capability: {capability}.") + elif error and not functions_supported: + raise RuntimeError( + "Method is not supported for the required functionality. 
", + "Required activation: silu, expert map not supported.", + ) + return arch_supported and functions_supported + else: + return False + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + device = layer.w13_weight.device + self.ab_strides1 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full((num_experts, ), + 2 * intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full((num_experts, ), + intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides2 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + # TODO half() + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + ) -> torch.Tensor: + + assert activation == "silu" + assert global_num_experts == layer.w13_weight.shape[0] + assert expert_map is None + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return cutlass_moe_fp16( + x, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + topk_weights, + topk_ids, + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + ) + + def determine_expert_map( ep_size: int, ep_rank: int, global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]: @@ -511,7 +617,8 @@ def __init__( # for heuristic purposes, so it must be initialized first. 
if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod()) + UnquantizedFusedMoEMethod.get_moe_method( + self.activation, self.expert_map)) else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None From 11d640f43271c7b8efa4ad97a016b19715991bb4 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 2 Apr 2025 14:07:31 +0000 Subject: [PATCH 04/22] Float16 support Signed-off-by: ElizaWszola --- .../cutlass_w8a8/moe/get_group_starts.cuh | 37 +++++++++++++------ .../cutlass_w8a8/moe/grouped_mm_c3x.cu | 14 ++++--- tests/kernels/test_cutlass.py | 2 +- tests/kernels/test_cutlass_moe.py | 4 +- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh index 46c6dcf844d0..841a1fe533af 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh @@ -49,6 +49,20 @@ __global__ void get_group_gemm_starts( a_tensors.size(1), per_act_token, per_out_ch); \ } +#define __CALL_GET_STARTS_KERNEL_FP16(ABC_TENSOR_TYPE, ABC_TYPE) \ + else if (out_tensors.dtype() == ABC_TENSOR_TYPE) { \ + get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), nullptr, nullptr, \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), nullptr, nullptr, \ + out_tensors.size(1), a_tensors.size(1), false, false); \ + } + namespace { void run_get_group_gemm_starts( @@ -85,23 +99,22 @@ void run_get_group_gemm_starts_fp16(torch::Tensor const& expert_offsets, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor& out_tensors) { - TORCH_CHECK(a_tensors.dtype() == torch::kBFloat16); - TORCH_CHECK(b_tensors.dtype() == torch::kBFloat16); + TORCH_CHECK(a_tensors.dtype() == torch::kBFloat16 || + a_tensors.dtype() == torch::kFloat16); + TORCH_CHECK(a_tensors.dtype() == b_tensors.dtype()); + TORCH_CHECK(a_tensors.dtype() == out_tensors.dtype()); int num_experts = (int)expert_offsets.size(0); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); - get_group_gemm_starts - <<<1, num_experts, 0, stream>>>( - static_cast(expert_offsets.data_ptr()), - static_cast(a_ptrs.data_ptr()), - static_cast(b_ptrs.data_ptr()), - static_cast(out_ptrs.data_ptr()), nullptr, - nullptr, static_cast(a_tensors.data_ptr()), - static_cast(b_tensors.data_ptr()), - static_cast(out_tensors.data_ptr()), nullptr, - nullptr, out_tensors.size(1), a_tensors.size(1), false, false); + if (false) { + } + __CALL_GET_STARTS_KERNEL_FP16(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL_FP16(torch::kFloat16, half) + else { + TORCH_CHECK(false, "Invalid i/o type (must be float16 or bfloat16)"); + } } } // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu index 8dc555f7fc69..9543bb2da3eb 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu @@ -263,11 +263,15 @@ void dispatch_moe_mm_fp16_sm90( torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& 
b_strides, torch::Tensor const& c_strides) { - TORCH_CHECK(out_tensors.dtype() == torch::kBFloat16, - "output must be bfloat16"); - run_cutlass_moe_mm_fp16_sm90( - out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, - a_strides, b_strides, c_strides); + if (out_tensors.dtype() == torch::kBFloat16) { + run_cutlass_moe_mm_fp16_sm90( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } else { + run_cutlass_moe_mm_fp16_sm90( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } } } // namespace diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 53abc78bb891..520cbace73d8 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -644,7 +644,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, @pytest.mark.parametrize("num_experts", [8, 64]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 3c9e2a56f418..d44a8b765d31 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -261,7 +261,7 @@ def test_cutlass_moe_cuda_graph( @pytest.mark.parametrize("k", [1024, 1536]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), @@ -315,7 +315,7 @@ def test_cutlass_fp16_moe_no_graph( @pytest.mark.parametrize("k", [1024, 1536]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), From ab0143bb7e4cd62b6a21e561c20cdd4cfcccee04 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 3 Apr 2025 12:58:11 +0000 Subject: [PATCH 05/22] Refactor, merge some common 16- and 8-bit functionalities Signed-off-by: ElizaWszola --- CMakeLists.txt | 8 +- ...mm_cutlass.py => benchmark_cutlass_moe.py} | 50 ++--- csrc/ops.h | 21 +- .../moe/{grouped_mm_c3x.cu => moe_mm_c3x.cu} | 121 ++++++------ ..._mm_fp16_c3x.cuh => moe_mm_c3x_16_bit.cuh} | 62 +----- ...rouped_mm_c3x.cuh => moe_mm_c3x_8_bit.cuh} | 66 +------ ...group_starts.cuh => moe_mm_c3x_common.cuh} | 84 ++++++-- .../cutlass_w8a8/scaled_mm_entry.cu | 53 +++-- csrc/torch_bindings.cpp | 18 +- tests/kernels/test_cutlass.py | 12 +- tests/kernels/test_cutlass_moe.py | 105 +++++----- vllm/_custom_ops.py | 45 ++--- .../layers/fused_moe/__init__.py | 7 +- .../layers/fused_moe/fused_moe.py | 185 +++++++----------- vllm/model_executor/layers/fused_moe/layer.py | 6 +- .../compressed_tensors_moe.py | 9 +- 16 files changed, 372 insertions(+), 480 deletions(-) rename benchmarks/kernels/{benchmark_grouped_gemm_cutlass.py => benchmark_cutlass_moe.py} (92%) rename csrc/quantization/cutlass_w8a8/moe/{grouped_mm_c3x.cu => moe_mm_c3x.cu} (73%) rename csrc/quantization/cutlass_w8a8/moe/{grouped_mm_fp16_c3x.cuh => 
moe_mm_c3x_16_bit.cuh} (55%) rename csrc/quantization/cutlass_w8a8/moe/{grouped_mm_c3x.cuh => moe_mm_c3x_8_bit.cuh} (58%) rename csrc/quantization/cutlass_w8a8/moe/{get_group_starts.cuh => moe_mm_c3x_common.cuh} (61%) diff --git a/CMakeLists.txt b/CMakeLists.txt index e0f1fdf78d14..3058dbb65fed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -469,21 +469,21 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" + set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x.cu" "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1") - message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + message(STATUS "Building moe_mm_c3x for archs: ${SCALED_MM_ARCHS}") else() if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + message(STATUS "Not building moe_mm_c3x kernels as CUDA Compiler version is " "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " "if you intend on running FP8 quantized MoE models on Hopper.") else() - message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + message(STATUS "Not building moe_mm_c3x as no compatible archs found " "in CUDA target architectures") endif() endif() diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_cutlass_moe.py similarity index 92% rename from benchmarks/kernels/benchmark_grouped_gemm_cutlass.py rename to benchmarks/kernels/benchmark_cutlass_moe.py index d516d1e6246c..04bf57ee442f 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_cutlass_moe.py @@ -6,7 +6,7 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, +from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe, fused_experts, fused_topk) from vllm.utils import FlexibleArgumentParser @@ -121,18 +121,18 @@ def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, num_repeats: int): for _ in range(num_repeats): - cutlass_moe_fp8(a, - w1, - w2, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) + cutlass_moe(a, + w1, + w2, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale) def run_cutlass_from_graph( a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, @@ -143,18 +143,18 @@ def run_cutlass_from_graph( with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) + return cutlass_moe(a, + w1_q, + w2_q, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + w1_scale=w1_scale, + w2_scale=w2_scale, + 
a1_scale=a_scale) def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, diff --git a/csrc/ops.h b/csrc/ops.h index dd0415a09fc8..fba5b6ee714b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -176,18 +176,15 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_moe_mm( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides); - -void cutlass_moe_mm_fp16( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides); +void cutlass_moe_mm(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + std::optional const& a_scales, + std::optional const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& c_strides); void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x.cu similarity index 73% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu rename to csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x.cu index 9543bb2da3eb..221e5d58498d 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x.cu @@ -4,8 +4,8 @@ #include #include "cutlass/cutlass.h" -#include "grouped_mm_c3x.cuh" -#include "grouped_mm_fp16_c3x.cuh" +#include "moe_mm_c3x_8_bit.cuh" +#include "moe_mm_c3x_16_bit.cuh" using namespace cute; @@ -13,7 +13,7 @@ namespace { template typename Epilogue> -struct sm90_fp8_config_default { +struct sm90_8_bit_config_default { // M in (16, inf) static_assert(std::is_same()); using KernelSchedule = @@ -24,13 +24,13 @@ struct sm90_fp8_config_default { using ClusterShape = cute::Shape; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_moe_gemm; }; template typename Epilogue> -struct sm90_fp8_config_M16 { +struct sm90_8_bit_config_M16 { // M in [1, 16] static_assert(std::is_same()); using KernelSchedule = @@ -41,13 +41,13 @@ struct sm90_fp8_config_M16 { using ClusterShape = cute::Shape; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_moe_gemm; }; template typename Epilogue> -struct sm90_fp8_config_K8192 { +struct sm90_8_bit_config_K8192 { // K in [8192, inf) static_assert(std::is_same()); using KernelSchedule = @@ -58,13 +58,13 @@ struct sm90_fp8_config_K8192 { using ClusterShape = cute::Shape; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_moe_gemm; }; template typename Epilogue> -struct sm90_fp8_config_N8192 { +struct sm90_8_bit_config_N8192 { // N in [8192, inf) static_assert(std::is_same()); using KernelSchedule = @@ -75,13 +75,13 @@ struct sm90_fp8_config_N8192 { using ClusterShape = cute::Shape; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_moe_gemm; }; template typename Epilogue> -struct sm90_fp16_config_default { +struct sm90_16_bit_config_default { // M in 
(16, inf) using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; @@ -91,14 +91,13 @@ struct sm90_fp16_config_default { using ClusterShape = cute::Shape; using Cutlass3xGemm = - cutlass_3x_group_gemm_fp16; + cutlass_3x_moe_gemm; }; template typename Epilogue> -struct sm90_fp16_config_M16 { +struct sm90_16_bit_config_M16 { // M in [1, 16] using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; @@ -108,14 +107,13 @@ struct sm90_fp16_config_M16 { using ClusterShape = cute::Shape; using Cutlass3xGemm = - cutlass_3x_group_gemm_fp16; + cutlass_3x_moe_gemm; }; template typename Epilogue> -struct sm90_fp16_config_K8192 { +struct sm90_16_bit_config_K8192 { // K in [8192, inf) using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; @@ -125,14 +123,13 @@ struct sm90_fp16_config_K8192 { using ClusterShape = cute::Shape; using Cutlass3xGemm = - cutlass_3x_group_gemm_fp16; + cutlass_3x_moe_gemm; }; template typename Epilogue> -struct sm90_fp16_config_N8192 { +struct sm90_16_bit_config_N8192 { // N in [8192, inf) using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; @@ -142,13 +139,12 @@ struct sm90_fp16_config_N8192 { using ClusterShape = cute::Shape; using Cutlass3xGemm = - cutlass_3x_group_gemm_fp16; + cutlass_3x_moe_gemm; }; template -void run_cutlass_moe_mm_sm90( +void run_cutlass_moe_mm_sm90_8_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, @@ -166,13 +162,13 @@ void run_cutlass_moe_mm_sm90( TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< + using Cutlass3xGemmN8192 = typename sm90_8_bit_config_N8192< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< + using Cutlass3xGemmK8192 = typename sm90_8_bit_config_K8192< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< + using Cutlass3xGemmM16 = typename sm90_8_bit_config_M16< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmDefault = typename sm90_fp8_config_default< + using Cutlass3xGemmDefault = typename sm90_8_bit_config_default< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; uint32_t const m = a_tensors.size(0); @@ -180,26 +176,26 @@ void run_cutlass_moe_mm_sm90( uint32_t const k = a_tensors.size(1); if (n >= 8192) { - cutlass_group_gemm_caller( + cutlass_moe_gemm_caller_8_bit( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } else if (k >= 8192) { - cutlass_group_gemm_caller( + cutlass_moe_gemm_caller_8_bit( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } else if (m <= 16) { - cutlass_group_gemm_caller( + cutlass_moe_gemm_caller_8_bit( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } else { - cutlass_group_gemm_caller( + cutlass_moe_gemm_caller_8_bit( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } } template -void run_cutlass_moe_mm_fp16_sm90( +void run_cutlass_moe_mm_sm90_16_bit( torch::Tensor& 
out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, @@ -208,14 +204,13 @@ void run_cutlass_moe_mm_fp16_sm90( TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); - using Cutlass3xGemmN8192 = typename sm90_fp16_config_N8192< + using Cutlass3xGemmN8192 = typename sm90_16_bit_config_N8192< InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; - using Cutlass3xGemmK8192 = typename sm90_fp16_config_K8192< + using Cutlass3xGemmK8192 = typename sm90_16_bit_config_K8192< InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; - using Cutlass3xGemmM16 = - typename sm90_fp16_config_M16::Cutlass3xGemm; - using Cutlass3xGemmDefault = typename sm90_fp16_config_default< + using Cutlass3xGemmM16 = typename sm90_16_bit_config_M16< + InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; + using Cutlass3xGemmDefault = typename sm90_16_bit_config_default< InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; uint32_t const m = a_tensors.size(0); @@ -223,52 +218,52 @@ void run_cutlass_moe_mm_fp16_sm90( uint32_t const k = a_tensors.size(1); if (n >= 8192) { - cutlass_group_gemm_fp16_caller( + cutlass_moe_gemm_caller_16_bit( out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } else if (k >= 8192) { - cutlass_group_gemm_fp16_caller( + cutlass_moe_gemm_caller_16_bit( out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } else if (m <= 16) { - cutlass_group_gemm_fp16_caller( + cutlass_moe_gemm_caller_16_bit( out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } else { - cutlass_group_gemm_fp16_caller( + cutlass_moe_gemm_caller_16_bit( out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } } -void dispatch_moe_mm_sm90( +void dispatch_moe_mm_sm90_8_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides) { if (out_tensors.dtype() == torch::kBFloat16) { - run_cutlass_moe_mm_sm90( + run_cutlass_moe_mm_sm90_8_bit( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } else { - run_cutlass_moe_mm_sm90( + run_cutlass_moe_mm_sm90_8_bit( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } } -void dispatch_moe_mm_fp16_sm90( +void dispatch_moe_mm_sm90_16_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides) { if (out_tensors.dtype() == torch::kBFloat16) { - run_cutlass_moe_mm_fp16_sm90( + run_cutlass_moe_mm_sm90_16_bit( out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } else { - run_cutlass_moe_mm_fp16_sm90( + run_cutlass_moe_mm_sm90_16_bit( out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } @@ 
-276,22 +271,22 @@ void dispatch_moe_mm_fp16_sm90( } // namespace -void cutlass_moe_mm_sm90( +void cutlass_moe_mm_sm90_8_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides) { - dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes, a_strides, b_strides, - c_strides); + dispatch_moe_mm_sm90_8_bit(out_tensors, a_tensors, b_tensors, a_scales, + b_scales, expert_offsets, problem_sizes, a_strides, + b_strides, c_strides); } -void cutlass_moe_mm_fp16_sm90( +void cutlass_moe_mm_sm90_16_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides) { - dispatch_moe_mm_fp16_sm90(out_tensors, a_tensors, b_tensors, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); + dispatch_moe_mm_sm90_16_bit(out_tensors, a_tensors, b_tensors, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); } diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_fp16_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_16_bit.cuh similarity index 55% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_fp16_c3x.cuh rename to csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_16_bit.cuh index 8973fa877df1..6b0cdd6ead67 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_fp16_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_16_bit.cuh @@ -9,70 +9,14 @@ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/common.hpp" -#include "get_group_starts.cuh" +#include "moe_mm_c3x_common.cuh" using namespace cute; namespace { -using ProblemShape = - cutlass::gemm::GroupProblemShape>; - -using ElementAccumulator = float; -using ArchTag = cutlass::arch::Sm90; -using OperatorClass = cutlass::arch::OpClassTensorOp; - -using LayoutA = cutlass::layout::RowMajor; -using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::RowMajor; - -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> -struct cutlass_3x_group_gemm_fp16 { - using ElementAB = ElementAB_; - using ElementC = void; - using ElementD = ElementC_; - using ElementAccumulator = float; - - using Epilogue = Epilogue_; - - using StrideC = - cute::remove_pointer_t, cute::Int<0>>>; - - static constexpr int AlignmentAB = - 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using EVTCompute = typename Epilogue::EVTCompute; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, - LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - using CollectiveMainloop = - typename 
cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, - LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, - Stages, KernelSchedule>::CollectiveOp; - - using KernelType = enable_sm90_only>; - - struct GemmKernel : public KernelType {}; -}; - template -void cutlass_group_gemm_fp16_caller( +void cutlass_moe_gemm_caller_16_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, @@ -93,7 +37,7 @@ void cutlass_group_gemm_fp16_caller( torch::Tensor b_ptrs = torch::empty(num_experts, options_int); torch::Tensor out_ptrs = torch::empty(num_experts, options_int); - run_get_group_gemm_starts_fp16(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + run_get_moe_gemm_starts_16_bit(expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_tensors, b_tensors, out_tensors); using GemmKernel = typename Gemm::GemmKernel; diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_8_bit.cuh similarity index 58% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh rename to csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_8_bit.cuh index db827b7c5e18..99466d443d40 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_8_bit.cuh @@ -8,70 +8,14 @@ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/common.hpp" -#include "get_group_starts.cuh" +#include "moe_mm_c3x_common.cuh" using namespace cute; namespace { -using ProblemShape = - cutlass::gemm::GroupProblemShape>; - -using ElementAccumulator = float; -using ArchTag = cutlass::arch::Sm90; -using OperatorClass = cutlass::arch::OpClassTensorOp; - -using LayoutA = cutlass::layout::RowMajor; -using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::RowMajor; - -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> -struct cutlass_3x_group_gemm { - using ElementAB = ElementAB_; - using ElementC = void; - using ElementD = ElementC_; - using ElementAccumulator = float; - - using Epilogue = Epilogue_; - - using StrideC = - cute::remove_pointer_t, cute::Int<0>>>; - - static constexpr int AlignmentAB = - 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using EVTCompute = typename Epilogue::EVTCompute; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, - LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, - LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, - Stages, KernelSchedule>::CollectiveOp; - - using KernelType = enable_sm90_only>; - - struct GemmKernel : public KernelType {}; -}; - template -void cutlass_group_gemm_caller( 
+void cutlass_moe_gemm_caller_8_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, @@ -98,9 +42,9 @@ void cutlass_group_gemm_caller( torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); - run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, - a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors, - out_tensors, a_scales, b_scales); + run_get_moe_gemm_starts_8_bit(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + a_scales_ptrs, b_scales_ptrs, a_tensors, + b_tensors, out_tensors, a_scales, b_scales); using GemmKernel = typename Gemm::GemmKernel; using StrideA = Stride, Int<0>>; diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_common.cuh similarity index 61% rename from csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh rename to csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_common.cuh index 841a1fe533af..d5f68f2d71e9 100644 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_common.cuh @@ -8,8 +8,10 @@ #include "cutlass/bfloat16.h" #include "cutlass/float8.h" +// get tensors with pointers pointing to the start index of each group's data + template -__global__ void get_group_gemm_starts( +__global__ void get_moe_gemm_starts( int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, ElementAccumulator** a_scales_offsets, ElementAccumulator** b_scales_offsets, ElementAB* a_base, ElementAB* b_base, @@ -31,9 +33,9 @@ __global__ void get_group_gemm_starts( b_scales_base + (per_out_ch ? 
n * expert_id : expert_id); } -#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ +#define __CALL_GET_STARTS_KERNEL_8_BIT(TENSOR_C_TYPE, C_TYPE) \ else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ - get_group_gemm_starts \ + get_moe_gemm_starts \ <<<1, num_experts, 0, stream>>>( \ static_cast(expert_offsets.data_ptr()), \ static_cast(a_ptrs.data_ptr()), \ @@ -49,9 +51,9 @@ __global__ void get_group_gemm_starts( a_tensors.size(1), per_act_token, per_out_ch); \ } -#define __CALL_GET_STARTS_KERNEL_FP16(ABC_TENSOR_TYPE, ABC_TYPE) \ - else if (out_tensors.dtype() == ABC_TENSOR_TYPE) { \ - get_group_gemm_starts \ +#define __CALL_GET_STARTS_KERNEL_16_BIT(TENSOR_ABC_TYPE, ABC_TYPE) \ + else if (out_tensors.dtype() == TENSOR_ABC_TYPE) { \ + get_moe_gemm_starts \ <<<1, num_experts, 0, stream>>>( \ static_cast(expert_offsets.data_ptr()), \ static_cast(a_ptrs.data_ptr()), \ @@ -65,7 +67,7 @@ __global__ void get_group_gemm_starts( namespace { -void run_get_group_gemm_starts( +void run_get_moe_gemm_starts_8_bit( torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, @@ -85,14 +87,14 @@ void run_get_group_gemm_starts( if (false) { } - __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) - __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) + __CALL_GET_STARTS_KERNEL_8_BIT(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL_8_BIT(torch::kFloat16, half) else { TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); } } -void run_get_group_gemm_starts_fp16(torch::Tensor const& expert_offsets, +void run_get_moe_gemm_starts_16_bit(torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, @@ -110,11 +112,69 @@ void run_get_group_gemm_starts_fp16(torch::Tensor const& expert_offsets, if (false) { } - __CALL_GET_STARTS_KERNEL_FP16(torch::kBFloat16, cutlass::bfloat16_t) - __CALL_GET_STARTS_KERNEL_FP16(torch::kFloat16, half) + __CALL_GET_STARTS_KERNEL_16_BIT(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL_16_BIT(torch::kFloat16, half) else { TORCH_CHECK(false, "Invalid i/o type (must be float16 or bfloat16)"); } } +// common structs and types used by moe gemm + +using ProblemShape = + cutlass::gemm::GroupProblemShape>; + +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_moe_gemm { + using ElementAB = ElementAB_; + using ElementC = void; + using ElementD = ElementC_; + using ElementAccumulator = float; + + using Epilogue = Epilogue_; + + using StrideC = + cute::remove_pointer_t, cute::Int<0>>>; + + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; + + 
static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, + LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, + Stages, KernelSchedule>::CollectiveOp; + + using KernelType = enable_sm90_only>; + + struct GemmKernel : public KernelType {}; +}; + } // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index d950dca89627..671e2ede0be7 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -30,14 +30,14 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_moe_mm_sm90( +void cutlass_moe_mm_sm90_8_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); -void cutlass_moe_mm_fp16_sm90( +void cutlass_moe_mm_sm90_16_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, @@ -201,34 +201,31 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } -void cutlass_moe_mm( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { +void cutlass_moe_mm(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + std::optional const& a_scales, + std::optional const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& c_strides) { int32_t version_num = get_sm_version_num(); #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 - cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes, a_strides, b_strides, - c_strides); - return; -#endif - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, - ". 
Required capability: 90"); -} - -void cutlass_moe_mm_fp16( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { - int32_t version_num = get_sm_version_num(); -#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 - cutlass_moe_mm_fp16_sm90(out_tensors, a_tensors, b_tensors, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); + if (a_tensors.dtype() == torch::kBFloat16 || + a_tensors.dtype() == torch::kFloat16) { + TORCH_CHECK(!a_scales.has_value()); + TORCH_CHECK(!b_scales.has_value()); + cutlass_moe_mm_sm90_16_bit(out_tensors, a_tensors, b_tensors, + expert_offsets, problem_sizes, a_strides, + b_strides, c_strides); + } else { + TORCH_CHECK(a_scales.has_value()); + TORCH_CHECK(b_scales.has_value()); + cutlass_moe_mm_sm90_8_bit( + out_tensors, a_tensors, b_tensors, a_scales.value(), b_scales.value(), + expert_offsets, problem_sizes, a_strides, b_strides, c_strides); + } return; #endif TORCH_CHECK_NOT_IMPLEMENTED( diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 1c84becc9110..b1a67b24ff77 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -370,24 +370,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool"); ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported); - // CUTLASS w8a8 grouped GEMM + // CUTLASS MoE GEMM ops.def( "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, " - " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " - " Tensor problem_sizes, Tensor a_strides, " - " Tensor b_strides, Tensor c_strides) -> ()", + " Tensor? a_scales, Tensor? b_scales, " + " Tensor expert_offsets, Tensor problem_sizes, " + " Tensor a_strides, Tensor b_strides, Tensor c_strides" + ") -> ()", {stride_tag}); ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); - // CUTLASS w16a16 grouped GEMM - ops.def( - "cutlass_moe_mm_fp16(Tensor! out_tensors, Tensor a_tensors, " - " Tensor b_tensors, Tensor expert_offsets, " - " Tensor problem_sizes, Tensor a_strides, " - " Tensor b_strides, Tensor c_strides) -> ()", - {stride_tag}); - ops.impl("cutlass_moe_mm_fp16", torch::kCUDA, &cutlass_moe_mm_fp16); - // A function that computes data required to run fused MoE with w8a8 grouped // GEMM. It takes topk_ids as an input, and computes expert_offsets // (token start indices of each expert). 
In addition to this, it computes diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index 520cbace73d8..4d56a3dc5bde 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -518,8 +518,8 @@ def test_cutlass_support_opcheck(): (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, - per_out_ch: bool, use_bias: bool): +def test_cutlass_fp8_moe_gemm(num_experts: int, per_act_token: bool, + per_out_ch: bool, use_bias: bool): # Device and dtype setup device = "cuda" @@ -649,7 +649,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_fp16_group_gemm(num_experts: int, dtype: torch.dtype): +def test_cutlass_fp16_moe_gemm(num_experts: int, dtype: torch.dtype): # Device and dtype setup device = "cuda" @@ -715,9 +715,9 @@ def test_cutlass_fp16_group_gemm(num_experts: int, dtype: torch.dtype): device=device, dtype=torch.int64) - ops.cutlass_moe_mm_fp16(out_tensors_stacked, a_tensors_stacked, - b_tensors_stacked, expert_offsets[:-1], - problem_sizes, ab_strides, ab_strides, c_strides) + ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked, + b_tensors_stacked, None, None, expert_offsets[:-1], + problem_sizes, ab_strides, ab_strides, c_strides) # Validate each group's result against the baseline for g in range(num_experts): diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index d44a8b765d31..de4e0e79c116 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -4,8 +4,7 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, - cutlass_moe_fp16, +from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe, fused_experts, fused_topk) from vllm.platforms import current_platform @@ -14,37 +13,38 @@ TOP_KS = [6, 8] -def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): +def run_8_bit(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, + w2_q: torch.Tensor, w1_scale: torch.Tensor, + w2_scale: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, ab_strides2: torch.Tensor, + c_strides2: torch.Tensor): with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - -def run_fp16(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + return cutlass_moe(a, + w1_q, + w2_q, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + 
c_strides2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale) + + +def run_16_bit(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor): with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe_fp16(a, w1, w2, topk_weights, topk_ids, ab_strides1, - c_strides1, ab_strides2, c_strides2) + return cutlass_moe(a, w1, w2, topk_weights, topk_ids, ab_strides1, + c_strides1, ab_strides2, c_strides2) @pytest.mark.parametrize("m", [2, 64, 224]) @@ -58,7 +58,7 @@ def run_fp16(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_no_graph( +def test_cutlass_moe_8_bit_no_graph( m: int, n: int, k: int, @@ -130,18 +130,18 @@ def test_cutlass_moe_no_graph( triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - cutlass_output = cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale1) + cutlass_output = cutlass_moe(a, + w1_q, + w2_q, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale1) print(triton_output) print(cutlass_output) @@ -164,7 +164,7 @@ def test_cutlass_moe_no_graph( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_cuda_graph( +def test_cutlass_moe_8_bit_cuda_graph( m: int, n: int, k: int, @@ -239,9 +239,10 @@ def test_cutlass_moe_cuda_graph( stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, - c_strides1, ab_strides2, c_strides2) + cutlass_output = run_8_bit(a, a_scale1, w1_q, w2_q, w1_scale, + w2_scale, topk_weights, topk_ids, + ab_strides1, c_strides1, ab_strides2, + c_strides2) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() @@ -266,7 +267,7 @@ def test_cutlass_moe_cuda_graph( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_fp16_moe_no_graph( +def test_cutlass_moe_16_bit_no_graph( m: int, n: int, k: int, @@ -296,9 +297,9 @@ def test_cutlass_fp16_moe_no_graph( triton_output = fused_experts(a, w1.transpose(1, 2), w2.transpose(1, 2), topk_weights, topk_ids) - cutlass_output = cutlass_moe_fp16(a, w1, w2, topk_weights, topk_ids, - ab_strides1, c_strides1, ab_strides2, - c_strides2) + cutlass_output = cutlass_moe(a, w1, w2, topk_weights, topk_ids, + ab_strides1, c_strides1, ab_strides2, + c_strides2) print(triton_output) print(cutlass_output) @@ -320,7 +321,7 @@ def test_cutlass_fp16_moe_no_graph( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_fp16_moe_cuda_graph( +def test_cutlass_moe_16_bit_cuda_graph( m: int, n: int, k: int, @@ -353,9 
+354,9 @@ def test_cutlass_fp16_moe_cuda_graph( stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - cutlass_output = run_fp16(a, w1, w2, topk_weights, topk_ids, - ab_strides1, c_strides1, ab_strides2, - c_strides2) + cutlass_output = run_16_bit(a, w1, w2, topk_weights, topk_ids, + ab_strides1, c_strides1, ab_strides2, + c_strides2) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9a307989e7a0..6ce543a5a548 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -710,10 +710,11 @@ def get_cutlass_moe_mm_data( def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - b_tensors: torch.Tensor, a_scales: torch.Tensor, - b_scales: torch.Tensor, expert_offsets: torch.Tensor, - problem_sizes: torch.Tensor, a_strides: torch.Tensor, - b_strides: torch.Tensor, c_strides: torch.Tensor): + b_tensors: torch.Tensor, a_scales: Optional[torch.Tensor], + b_scales: Optional[torch.Tensor], + expert_offsets: torch.Tensor, problem_sizes: torch.Tensor, + a_strides: torch.Tensor, b_strides: torch.Tensor, + c_strides: torch.Tensor): """ A single grouped matrix multiplication used in CUTLASS-based fused MoE. The function executes fp8-quantized OUT = AB matrix multiplication. @@ -730,24 +731,24 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, a_strides, b_strides, c_strides) -def cutlass_moe_mm_fp16(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - b_tensors: torch.Tensor, expert_offsets: torch.Tensor, - problem_sizes: torch.Tensor, a_strides: torch.Tensor, - b_strides: torch.Tensor, c_strides: torch.Tensor): - """ - A single grouped matrix multiplication used in CUTLASS-based fused MoE. - The function executes fp8-quantized OUT = AB matrix multiplication. - - - expert_offsets: Indices that mark at which token index each expert begins - its computation. The number of tokens computed with - expert E is expert_offsets[E + 1] - expert_offsets[E] - - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped - MMs used in the fused MoE operation. - - a/b/c_strides: The data strides passed to grouped matrix multiplication. - """ - torch.ops._C.cutlass_moe_mm_fp16(out_tensors, a_tensors, b_tensors, - expert_offsets, problem_sizes, a_strides, - b_strides, c_strides) +# def cutlass_moe_mm_fp16(out_tensors: torch.Tensor, a_tensors: torch.Tensor, +# b_tensors: torch.Tensor, expert_offsets: torch.Tensor, +# problem_sizes: torch.Tensor, a_strides: torch.Tensor, +# b_strides: torch.Tensor, c_strides: torch.Tensor): +# """ +# A single grouped matrix multiplication used in CUTLASS-based fused MoE. +# The function executes fp8-quantized OUT = AB matrix multiplication. + +# - expert_offsets: Indices that mark at which token index each expert begins +# its computation. The number of tokens computed with +# expert E is expert_offsets[E + 1] - expert_offsets[E] +# - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped +# MMs used in the fused MoE operation. +# - a/b/c_strides: The data strides passed to grouped matrix multiplication. 
+# """ +# torch.ops._C.cutlass_moe_mm_fp16(out_tensors, a_tensors, b_tensors, +# expert_offsets, problem_sizes, a_strides, +# b_strides, c_strides) # aqlm diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index ca7226c96d9b..c29849ce707d 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -36,8 +36,8 @@ def get_config() -> Optional[Dict[str, Any]]: import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.fused_moe import ( - cutlass_moe_fp8, cutlass_moe_fp16, fused_experts, fused_moe, - fused_topk, get_config_file_name, grouped_topk) + cutlass_moe, fused_experts, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ "fused_moe", @@ -45,6 +45,5 @@ def get_config() -> Optional[Dict[str, Any]]: "fused_experts", "get_config_file_name", "grouped_topk", - "cutlass_moe_fp8", - "cutlass_moe_fp16", + "cutlass_moe", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index df060e50fe91..0aa1c2accdc5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1626,21 +1626,20 @@ def fused_moe( #TODO make the grouped gemm kernel consistent with scaled gemm kernel -def cutlass_moe_fp8( +def cutlass_moe( a: torch.Tensor, - w1_q: torch.Tensor, - w2_q: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.half, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -1651,75 +1650,89 @@ def cutlass_moe_fp8( Parameters: - a (torch.Tensor): The input tensor to the MoE layer. Shape: [M, K] - - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + - w1 (torch.Tensor): The first set of expert weights. Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + - w2 (torch.Tensor): The second set of expert weights. Shape: [num_experts, N, K] (the weights are passed transposed) - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mapping. - ab_strides1 (torch.Tensor): The input and weights strides of the first grouped gemm. - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. - ab_strides2 (torch.Tensor): The input and weights strides of the second grouped gemm. - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. + - w1_scale (Optional[torch.Tensor]): The optional fp32 scale + to dequantize w1. 
+ Shape: [num_experts] or [num_experts, 2N] + - w2_scale (Optional[torch.Tensor]): The optional fp32 scale + to dequantize w2. + Shape: [num_experts] or [num_experts, K] - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. Shape: scalar or [M] - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize the intermediate result between the gemms. Shape: scalar or [M] - - out_dtype (torch.Tensor): The output tensor type. Returns: - - torch.Tensor: The fp16 output tensor after applying the MoE layer. + - torch.Tensor: The output tensor after applying the MoE layer. """ assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert w1_q.dtype == torch.float8_e4m3fn - assert w2_q.dtype == torch.float8_e4m3fn - assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" - assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" - assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ - 0], "Input scale shape mismatch" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1_q.shape[2], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2_q.shape[2], "W2 scale shape mismatch" - assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[ - 0], "w1 scales expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[ - 0], "w2 scales expert number mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - assert ab_strides1.shape[0] == w1_q.shape[ + assert a.shape[1] == w1.shape[1], "Hidden size mismatch w1" + assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2" + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert ab_strides1.shape[0] == w1.shape[ 0], "AB Strides 1 expert number mismatch" - assert c_strides1.shape[0] == w1_q.shape[ + assert c_strides1.shape[0] == w1.shape[ 0], "C Strides 1 expert number mismatch" - assert ab_strides2.shape[0] == w2_q.shape[ + assert ab_strides2.shape[0] == w2.shape[ 0], "AB Strides 2 expert number mismatch" - assert c_strides2.shape[0] == w2_q.shape[ + assert c_strides2.shape[0] == w2.shape[ 0], "C Strides 2 expert number mismatch" - assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - num_experts = w1_q.size(0) - m = a.size(0) - k = w1_q.size(1) - n = w2_q.size(1) + assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" + assert w1.dtype in [torch.float8_e4m3fn, torch.half, + torch.bfloat16], "Invalid weight type" + assert w1.dtype == w2.dtype, "Weights type mismatch" + + if w1.dtype in [torch.half, torch.bfloat16]: + assert w1.dtype == a.dtype, "Unquantized input and weights type mismatch" # noqa: E501 + assert w1_scale is None and w2_scale is None and a1_scale is None and a2_scale is None, "Received scales for unquantized input type" # noqa: E501 + elif w1.dtype == torch.float8_e4m3fn: + assert w1_scale is not None and w2_scale is not None, "Missing scales for quantized input type" # noqa: E501 + + if w1_scale is not None: + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1.shape[2], "W1 scale shape mismatch" + assert w1.shape[0] == w1_scale.shape[ + 0], "w1 scales expert number mismatch" + if w2_scale is not None: + assert w2_scale.dim() == 1 or 
w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2.shape[2], "W2 scale shape mismatch" + assert w2.shape[0] == w2_scale.shape[ + 0], "w2 scales expert number mismatch" + if a1_scale is not None: + assert a1_scale.dim() == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ + 0] == a.shape[0], "Input scale shape mismatch" + if a2_scale is not None: + assert a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + + is_quantized = w1.dtype == torch.float8_e4m3fn + device = a.device + num_experts = w1.size(0) + m = a.size(0) + k = w1.size(1) + n = w2.size(1) topk = topk_ids.size(1) - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + out_dtype = a.dtype - a_q, a1_scale = ops.scaled_fp8_quant( - a, a1_scale, use_per_token_if_dynamic=per_act_token) - device = a_q.device + if is_quantized: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + a, a1_scale = ops.scaled_fp8_quant( + a, a1_scale, use_per_token_if_dynamic=per_act_token) expert_offsets = torch.empty((num_experts + 1), dtype=torch.int32, @@ -1738,82 +1751,32 @@ def cutlass_moe_fp8( problem_sizes2, a_map, c_map, num_experts, n, k) - rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale + if is_quantized: + rep_a = a.view(dtype=torch.uint8)[a_map].view(dtype=a.dtype) + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale + else: + rep_a = a[a_map] + rep_a1_scales = None c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) - ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, + ops.cutlass_moe_mm(c1, rep_a, w1, rep_a1_scales, w1_scale, expert_offsets[:-1], problem_sizes1, ab_strides1, ab_strides1, c_strides1) intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) torch.ops._C.silu_and_mul(intermediate, c1) - intemediate_q, a2_scale = ops.scaled_fp8_quant( - intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) + if is_quantized: + rep_a = a.view(dtype=torch.uint8)[a_map].view(dtype=a.dtype) + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale + intermediate, a2_scale = ops.scaled_fp8_quant( + intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) - ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, + ops.cutlass_moe_mm(c2, intermediate, w2, a2_scale, w2_scale, expert_offsets[:-1], problem_sizes2, ab_strides2, ab_strides2, c_strides2) return (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) - - -def cutlass_moe_fp16( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, -) -> torch.Tensor: - num_experts = w1.shape[0] - m = a.shape[0] - k = w1.shape[1] - n = w2.shape[1] - - topk = topk_ids.shape[1] - device = a.device - - out_dtype = a.dtype - - expert_offsets = torch.empty((num_experts + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = 
torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - - ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, num_experts, n, - k) - - rep_a = a[a_map] - - c1 = torch.zeros((m * topk, n * 2), device=device, dtype=out_dtype) - c2 = torch.zeros((m * topk, k), device=device, dtype=out_dtype) - - ops.cutlass_moe_mm_fp16(c1, rep_a, w1, expert_offsets[:-1], problem_sizes1, - ab_strides1, ab_strides1, c_strides1) - - intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) - torch.ops._C.silu_and_mul(intermediate, c1) - - ops.cutlass_moe_mm_fp16(c2, intermediate, w2, expert_offsets[:-1], - problem_sizes2, ab_strides2, ab_strides2, - c_strides2) - - return (c2[c_map].view(m, topk, k) * - topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 37717901c5c2..b3f0d05ecb3a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -24,10 +24,10 @@ from vllm.utils import direct_register_custom_op if current_platform.is_cuda_alike(): - from .fused_moe import cutlass_moe_fp16, fused_experts + from .fused_moe import cutlass_moe, fused_experts else: fused_experts = None # type: ignore - cutlass_moe_fp16 = None # type: ignore + cutlass_moe = None # type: ignore if current_platform.is_tpu(): # the iterative moe implementation is used until the moe_pallas is fixed from .moe_torch_iterative import fused_moe as fused_moe_pallas @@ -438,7 +438,7 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return cutlass_moe_fp16( + return cutlass_moe( x, layer.w13_weight.transpose(1, 2), layer.w2_weight.transpose(1, 2), diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index bf32bee89e89..61fdf8997e74 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -457,23 +457,22 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8 + from vllm.model_executor.layers.fused_moe import cutlass_moe - return cutlass_moe_fp8( + return cutlass_moe( x, layer.w13_weight.transpose(1, 2), layer.w2_weight.transpose(1, 2), - layer.w13_weight_scale, - layer.w2_weight_scale, topk_weights, topk_ids, self.ab_strides1, self.c_strides1, self.ab_strides2, self.c_strides2, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - out_dtype=x.dtype, ) From 1c7406796ff27ae01195c08dc901a86f02fc38cd Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 3 Apr 2025 13:16:38 +0000 Subject: [PATCH 06/22] mnk_factors in unit tests Signed-off-by: ElizaWszola --- tests/kernels/test_cutlass_moe.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index de4e0e79c116..43535c70e9be 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -12,6 +12,21 @@ NUM_EXPERTS = [40, 64] TOP_KS = [6, 8] +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 
1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 3072, 1536), + (224, 1024, 1024), + (224, 1024, 1536), + (224, 3072, 1024), + (224, 3072, 1536), +] + def run_8_bit(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, @@ -47,9 +62,7 @@ def run_16_bit(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, c_strides1, ab_strides2, c_strides2) -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) @@ -153,9 +166,7 @@ def test_cutlass_moe_8_bit_no_graph( rtol=1e-2) -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) @@ -257,9 +268,7 @@ def test_cutlass_moe_8_bit_cuda_graph( rtol=1e-2) -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) @@ -311,9 +320,7 @@ def test_cutlass_moe_16_bit_no_graph( rtol=1e-2) -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) From 6fa2f6a26178f90eaf3976491ed9e33f1cb24b56 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 3 Apr 2025 14:09:59 +0000 Subject: [PATCH 07/22] Move cutlass moe source files outside quantized w8a8 directory Signed-off-by: ElizaWszola --- CMakeLists.txt | 4 ++-- .../cutlass_w8a8/moe => cutlass_moe}/moe_data.cu | 0 .../cutlass_w8a8/moe => cutlass_moe}/moe_mm_c3x.cu | 0 .../cutlass_w8a8/moe => cutlass_moe}/moe_mm_c3x_16_bit.cuh | 0 .../cutlass_w8a8/moe => cutlass_moe}/moe_mm_c3x_8_bit.cuh | 0 .../cutlass_w8a8/moe => cutlass_moe}/moe_mm_c3x_common.cuh | 0 6 files changed, 2 insertions(+), 2 deletions(-) rename csrc/{quantization/cutlass_w8a8/moe => cutlass_moe}/moe_data.cu (100%) rename csrc/{quantization/cutlass_w8a8/moe => cutlass_moe}/moe_mm_c3x.cu (100%) rename csrc/{quantization/cutlass_w8a8/moe => cutlass_moe}/moe_mm_c3x_16_bit.cuh (100%) rename csrc/{quantization/cutlass_w8a8/moe => cutlass_moe}/moe_mm_c3x_8_bit.cuh (100%) rename csrc/{quantization/cutlass_w8a8/moe => cutlass_moe}/moe_mm_c3x_common.cuh (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3058dbb65fed..203d2052214d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -469,8 +469,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # to compile MoE kernels that use its output. 
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x.cu" - "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set(SRCS "csrc/cutlass_moe/moe_mm_c3x.cu" + "csrc/cutlass_moe/moe_data.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/cutlass_moe/moe_data.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/moe_data.cu rename to csrc/cutlass_moe/moe_data.cu diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x.cu b/csrc/cutlass_moe/moe_mm_c3x.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x.cu rename to csrc/cutlass_moe/moe_mm_c3x.cu diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_16_bit.cuh b/csrc/cutlass_moe/moe_mm_c3x_16_bit.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_16_bit.cuh rename to csrc/cutlass_moe/moe_mm_c3x_16_bit.cuh diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_8_bit.cuh b/csrc/cutlass_moe/moe_mm_c3x_8_bit.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_8_bit.cuh rename to csrc/cutlass_moe/moe_mm_c3x_8_bit.cuh diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_common.cuh b/csrc/cutlass_moe/moe_mm_c3x_common.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/moe_mm_c3x_common.cuh rename to csrc/cutlass_moe/moe_mm_c3x_common.cuh From 8160305ad1d1cd2f87f2edc535d289eda3f8a471 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 3 Apr 2025 14:20:22 +0000 Subject: [PATCH 08/22] Add separate entry file to cutlass moe Signed-off-by: ElizaWszola --- CMakeLists.txt | 1 + csrc/cutlass_moe/moe_mm_entry.cu | 83 +++++++++++++++++++ .../cutlass_w8a8/scaled_mm_entry.cu | 73 ---------------- 3 files changed, 84 insertions(+), 73 deletions(-) create mode 100644 csrc/cutlass_moe/moe_mm_entry.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 203d2052214d..8ecafc13e529 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -285,6 +285,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/cutlass_moe/moe_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" diff --git a/csrc/cutlass_moe/moe_mm_entry.cu b/csrc/cutlass_moe/moe_mm_entry.cu new file mode 100644 index 000000000000..b2a324b8c401 --- /dev/null +++ b/csrc/cutlass_moe/moe_mm_entry.cu @@ -0,0 +1,83 @@ +#include + +#include +#include + +#include "cutlass_extensions/common.hpp" + +#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 + +void cutlass_moe_mm_sm90_8_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); + +void cutlass_moe_mm_sm90_16_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); + +void 
get_cutlass_moe_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k); + +#endif + +void cutlass_moe_mm(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + std::optional const& a_scales, + std::optional const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& c_strides) { + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 + if (a_tensors.dtype() == torch::kBFloat16 || + a_tensors.dtype() == torch::kFloat16) { + TORCH_CHECK(!a_scales.has_value()); + TORCH_CHECK(!b_scales.has_value()); + cutlass_moe_mm_sm90_16_bit(out_tensors, a_tensors, b_tensors, + expert_offsets, problem_sizes, a_strides, + b_strides, c_strides); + } else { + TORCH_CHECK(a_scales.has_value()); + TORCH_CHECK(b_scales.has_value()); + cutlass_moe_mm_sm90_8_bit( + out_tensors, a_tensors, b_tensors, a_scales.value(), b_scales.value(), + expert_offsets, problem_sizes, a_strides, b_strides, c_strides); + } + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, + ". Required capability: 90"); +} + +void get_cutlass_moe_mm_data( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k) { + // This function currently gets compiled only if we have a valid cutlass moe + // mm to run it for. + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 + get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, input_permutation, + output_permutation, num_experts, n, k); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " + "CUDA device capability: ", + version_num, ". 
Required capability: 90"); +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 671e2ede0be7..925a6207466a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -30,25 +30,6 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_moe_mm_sm90_8_bit( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides); - -void cutlass_moe_mm_sm90_16_bit( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides); - -void get_cutlass_moe_mm_data_caller( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k); - #endif #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 @@ -201,60 +182,6 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } -void cutlass_moe_mm(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, - std::optional const& a_scales, - std::optional const& b_scales, - torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, - torch::Tensor const& a_strides, - torch::Tensor const& b_strides, - torch::Tensor const& c_strides) { - int32_t version_num = get_sm_version_num(); -#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 - if (a_tensors.dtype() == torch::kBFloat16 || - a_tensors.dtype() == torch::kFloat16) { - TORCH_CHECK(!a_scales.has_value()); - TORCH_CHECK(!b_scales.has_value()); - cutlass_moe_mm_sm90_16_bit(out_tensors, a_tensors, b_tensors, - expert_offsets, problem_sizes, a_strides, - b_strides, c_strides); - } else { - TORCH_CHECK(a_scales.has_value()); - TORCH_CHECK(b_scales.has_value()); - cutlass_moe_mm_sm90_8_bit( - out_tensors, a_tensors, b_tensors, a_scales.value(), b_scales.value(), - expert_offsets, problem_sizes, a_strides, b_strides, c_strides); - } - return; -#endif - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, - ". Required capability: 90"); -} - -void get_cutlass_moe_mm_data( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k) { - // This function currently gets compiled only if we have a valid cutlass moe - // mm to run it for. 
- int32_t version_num = get_sm_version_num(); -#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 - get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, input_permutation, - output_permutation, num_experts, n, k); - return; -#endif - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " - "CUDA device capability: ", - version_num, ". Required capability: 90"); -} - void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, From c25904d04114dc16def8a165eb883cda59d45cc7 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 4 Apr 2025 19:47:11 +0000 Subject: [PATCH 09/22] some cleanup Signed-off-by: Tyler Michael Smith --- vllm/_custom_ops.py | 20 ------- .../layers/fused_moe/fused_moe.py | 57 ++++++++++--------- 2 files changed, 31 insertions(+), 46 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6ce543a5a548..138ed79914f4 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -731,26 +731,6 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, a_strides, b_strides, c_strides) -# def cutlass_moe_mm_fp16(out_tensors: torch.Tensor, a_tensors: torch.Tensor, -# b_tensors: torch.Tensor, expert_offsets: torch.Tensor, -# problem_sizes: torch.Tensor, a_strides: torch.Tensor, -# b_strides: torch.Tensor, c_strides: torch.Tensor): -# """ -# A single grouped matrix multiplication used in CUTLASS-based fused MoE. -# The function executes fp8-quantized OUT = AB matrix multiplication. - -# - expert_offsets: Indices that mark at which token index each expert begins -# its computation. The number of tokens computed with -# expert E is expert_offsets[E + 1] - expert_offsets[E] -# - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped -# MMs used in the fused MoE operation. -# - a/b/c_strides: The data strides passed to grouped matrix multiplication. -# """ -# torch.ops._C.cutlass_moe_mm_fp16(out_tensors, a_tensors, b_tensors, -# expert_offsets, problem_sizes, a_strides, -# b_strides, c_strides) - - # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0aa1c2accdc5..eb9856ea7226 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1574,8 +1574,8 @@ def fused_moe( Defaults to False. - global_num_experts (int): The total number of experts in the global expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. 
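The expert_map described in the docstring hunk above maps expert indices from the global expert space to the local space of one expert-parallel rank. A minimal illustrative sketch of that translation, using hypothetical values that are not part of the patch (the torch.where pattern matches what a later commit in this series adds to cutlass_moe):

    # Sketch only, not part of the patch: 8 global experts split across
    # 2 expert-parallel ranks; this rank owns global experts 4..7 as
    # local experts 0..3.
    import torch

    expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3], dtype=torch.int32)
    topk_ids = torch.tensor([[2, 5], [7, 4]])  # global expert ids from routing
    local_topk_ids = torch.where(expert_map[topk_ids] != -1,
                                 expert_map[topk_ids], -1)
    # local_topk_ids == [[-1, 1], [3, 0]]; -1 marks experts this rank skips
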
@@ -1682,41 +1682,46 @@ def cutlass_moe( assert a.shape[1] == w1.shape[1], "Hidden size mismatch w1" assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2" assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert ab_strides1.shape[0] == w1.shape[ - 0], "AB Strides 1 expert number mismatch" - assert c_strides1.shape[0] == w1.shape[ - 0], "C Strides 1 expert number mismatch" - assert ab_strides2.shape[0] == w2.shape[ - 0], "AB Strides 2 expert number mismatch" - assert c_strides2.shape[0] == w2.shape[ - 0], "C Strides 2 expert number mismatch" + assert ab_strides1.shape[0] == w1.shape[0], \ + "AB Strides 1 expert number mismatch" + assert c_strides1.shape[0] == w1.shape[0], \ + "C Strides 1 expert number mismatch" + assert ab_strides2.shape[0] == w2.shape[0], \ + "AB Strides 2 expert number mismatch" + assert c_strides2.shape[0] == w2.shape[0], \ + "C Strides 2 expert number mismatch" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" - assert w1.dtype in [torch.float8_e4m3fn, torch.half, - torch.bfloat16], "Invalid weight type" + assert w1.dtype in [torch.float8_e4m3fn, torch.half,torch.bfloat16], \ + "Invalid weight type" assert w1.dtype == w2.dtype, "Weights type mismatch" if w1.dtype in [torch.half, torch.bfloat16]: - assert w1.dtype == a.dtype, "Unquantized input and weights type mismatch" # noqa: E501 - assert w1_scale is None and w2_scale is None and a1_scale is None and a2_scale is None, "Received scales for unquantized input type" # noqa: E501 + assert w1.dtype == a.dtype, \ + "Unquantized input and weights type mismatch" + assert w1_scale is None and w2_scale is None \ + and a1_scale is None and a2_scale is None, \ + "Received scales for unquantized input type" elif w1.dtype == torch.float8_e4m3fn: - assert w1_scale is not None and w2_scale is not None, "Missing scales for quantized input type" # noqa: E501 + assert w1_scale is not None and w2_scale is not None, \ + "Missing scales for quantized input type" if w1_scale is not None: - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1.shape[2], "W1 scale shape mismatch" - assert w1.shape[0] == w1_scale.shape[ - 0], "w1 scales expert number mismatch" + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 \ + or w1_scale.shape[1] == w1.shape[2], "W1 scale shape mismatch" + assert w1.shape[0] == w1_scale.shape[0], \ + "w1 scales expert number mismatch" if w2_scale is not None: - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2.shape[2], "W2 scale shape mismatch" - assert w2.shape[0] == w2_scale.shape[ - 0], "w2 scales expert number mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 \ + or w2_scale.shape[1] == w2.shape[2], "W2 scale shape mismatch" + assert w2.shape[0] == w2_scale.shape[0], \ + "w2 scales expert number mismatch" if a1_scale is not None: - assert a1_scale.dim() == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ - 0] == a.shape[0], "Input scale shape mismatch" + assert a1_scale.dim() == 0 or a1_scale.shape[0] == 1 \ + or a1_scale.shape[0] == a.shape[0], "Input scale shape mismatch" if a2_scale is not None: - assert a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + assert a1_scale is None or a2_scale.shape == a1_scale.shape, \ + "Intermediate scale shape mismatch" is_quantized = w1.dtype == torch.float8_e4m3fn From 31c4c80b5b9db50ae971f4d183137ee23b5f1d25 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 4 Apr 2025 20:38:09 +0000 Subject: 
[PATCH 10/22] fixes Signed-off-by: Tyler Michael Smith --- vllm/model_executor/layers/fused_moe/layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8fe3740f712e..5de09286605c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -24,7 +24,8 @@ from vllm.utils import direct_register_custom_op if current_platform.is_cuda_alike(): - from .fused_moe import cutlass_moe, fused_experts + from .cutlass_moe import cutlass_moe + from .fused_moe import fused_experts else: fused_experts = None # type: ignore cutlass_moe = None # type: ignore From 847150ab024aa8f3d8c23f56bac24b2ee6b00728 Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Sat, 12 Apr 2025 02:58:59 +0000 Subject: [PATCH 11/22] fix plumbing Signed-off-by: varun sundar rabindranath --- vllm/model_executor/layers/fused_moe/layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index fbc57b3a31ef..21dc078e4473 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -437,9 +437,9 @@ def apply( custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: - assert activation == "silu" assert global_num_experts == layer.w13_weight.shape[0] assert expert_map is None @@ -466,6 +466,7 @@ def apply( self.c_strides1, self.ab_strides2, self.c_strides2, + apply_router_weight_on_input=apply_router_weight_on_input, ) From 1168828f52358493cfabae86c7934e4b16399802 Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Sat, 12 Apr 2025 02:11:04 +0000 Subject: [PATCH 12/22] fp16 configs and expert map support Signed-off-by: varun sundar rabindranath --- csrc/cutlass_moe/moe_data.cu | 19 +- csrc/cutlass_moe/moe_mm_c3x.cu | 64 +-- tests/kernels/test_cutlass_moe.py | 529 +++++++++++------- .../layers/fused_moe/cutlass_moe.py | 28 +- vllm/model_executor/layers/fused_moe/layer.py | 38 +- .../compressed_tensors_moe.py | 5 +- 6 files changed, 406 insertions(+), 277 deletions(-) diff --git a/csrc/cutlass_moe/moe_data.cu b/csrc/cutlass_moe/moe_data.cu index a2bdafa62e84..76597aa65643 100644 --- a/csrc/cutlass_moe/moe_data.cu +++ b/csrc/cutlass_moe/moe_data.cu @@ -46,15 +46,27 @@ __global__ void compute_expert_offsets( } __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, + const int32_t* __restrict__ expert_offsets, int32_t* input_permutation, int32_t* output_permutation, int32_t* atomic_buffer, const int topk_length, const int topk) { - int expert_id = blockIdx.x; + int blk_expert_id = blockIdx.x; + int const num_experts = gridDim.x; + int32_t const num_tokens = expert_offsets[num_experts]; for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { - if (topk_ids[i] == expert_id) { - int start = atomicAdd(&atomic_buffer[expert_id], 1); + int const expert_id = topk_ids[i]; + if (expert_id == -1 && blockIdx.x == 0) { + // output_permutation is used to re-order the moe outputs. It is + // used as c2 = c2[c_map], where c2 is a torch.tensor that is the + // output of the cutlass kernels and c_map is the output_permutation. 
+ // c2 is initialized to zeros, therefore by setting the output_permutation + // to num_tokens, we are guaranteed to fill the moe outputs to zero + // for "invalid" topk_ids. + output_permutation[i] = num_tokens; + } else if (expert_id == blk_expert_id) { + int start = atomicAdd(&atomic_buffer[blk_expert_id], 1); input_permutation[start] = i / topk; output_permutation[i] = start; } @@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller( static_cast(atomic_buffer.data_ptr()), num_experts); compute_arg_sorts<<>>( static_cast(topk_ids.data_ptr()), + static_cast(expert_offsets.data_ptr()), static_cast(input_permutation.data_ptr()), static_cast(output_permutation.data_ptr()), static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), diff --git a/csrc/cutlass_moe/moe_mm_c3x.cu b/csrc/cutlass_moe/moe_mm_c3x.cu index 221e5d58498d..5adf877fd232 100644 --- a/csrc/cutlass_moe/moe_mm_c3x.cu +++ b/csrc/cutlass_moe/moe_mm_c3x.cu @@ -81,13 +81,13 @@ struct sm90_8_bit_config_N8192 { template typename Epilogue> -struct sm90_16_bit_config_default { - // M in (16, inf) +struct sm90_16_bit_config_M512 { + // M in [1, 512] using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; + using TileShape = cute::Shape; using ClusterShape = cute::Shape; using Cutlass3xGemm = @@ -97,46 +97,14 @@ struct sm90_16_bit_config_default { template typename Epilogue> -struct sm90_16_bit_config_M16 { - // M in [1, 16] - using KernelSchedule = - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; - using EpilogueSchedule = - cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; - - using Cutlass3xGemm = - cutlass_3x_moe_gemm; -}; - -template typename Epilogue> -struct sm90_16_bit_config_K8192 { - // K in [8192, inf) - using KernelSchedule = - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; - using EpilogueSchedule = - cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; - - using Cutlass3xGemm = - cutlass_3x_moe_gemm; -}; - -template typename Epilogue> -struct sm90_16_bit_config_N8192 { - // N in [8192, inf) +struct sm90_16_bit_config_default { + // M in (1024, inf] using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_moe_gemm 0, "No input B tensors provided."); TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); - using Cutlass3xGemmN8192 = typename sm90_16_bit_config_N8192< - InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; - using Cutlass3xGemmK8192 = typename sm90_16_bit_config_K8192< - InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; - using Cutlass3xGemmM16 = typename sm90_16_bit_config_M16< + using Cutlass3xGemmM512 = typename sm90_16_bit_config_M512< InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; using Cutlass3xGemmDefault = typename sm90_16_bit_config_default< InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; @@ -217,16 +181,8 @@ void run_cutlass_moe_mm_sm90_16_bit( uint32_t const n = out_tensors.size(1); uint32_t const k = a_tensors.size(1); - if (n >= 8192) { - cutlass_moe_gemm_caller_16_bit( - 
out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, - a_strides, b_strides, c_strides); - } else if (k >= 8192) { - cutlass_moe_gemm_caller_16_bit( - out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, - a_strides, b_strides, c_strides); - } else if (m <= 16) { - cutlass_moe_gemm_caller_16_bit( + if (m <= 512) { + cutlass_moe_gemm_caller_16_bit( out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, a_strides, b_strides, c_strides); } else { diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index f2d8a55d5487..49a8bda1e6a1 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import dataclasses +from typing import Optional + import pytest import torch @@ -28,38 +31,200 @@ ] -def run_8_bit(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, - w2_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, ab_strides2: torch.Tensor, - c_strides2: torch.Tensor): - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe(a, - w1_q, - w2_q, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) - - -def run_16_bit(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe(a, w1, w2, topk_weights, topk_ids, ab_strides1, - c_strides1, ab_strides2, c_strides2) +@dataclasses.dataclass +class MOETensors: + a: torch.Tensor + w1: torch.Tensor + w2: torch.Tensor + ab_strides1: torch.Tensor + c_strides1: torch.Tensor + ab_strides2: torch.Tensor + c_strides2: torch.Tensor + + @staticmethod + def make_moe_tensors(m: int, k: int, n: int, e: int, + dtype: torch.dtype) -> "MOETensors": + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + return MOETensors(a=a, + w1=w1, + w2=w2, + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2) + + +@dataclasses.dataclass +class MOETensors8Bit(MOETensors): + # quantized + a_q: Optional[torch.Tensor] = None # a -> a_q + w1_q: Optional[torch.Tensor] = None # w1 -> w1_q + w2_q: Optional[torch.Tensor] = None # w2 -> w2_q + a_scale: Optional[torch.Tensor] = None + w1_scale: Optional[torch.Tensor] = None + w2_scale: Optional[torch.Tensor] = None + # dequantized + a_d: Optional[torch.Tensor] = None # a -> a_q -> a_d + w1_d: Optional[torch.Tensor] = None # w1 -> w1_q -> w1_d + w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d + + @staticmethod + def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, + per_act_token: bool, + per_out_channel: bool) -> "MOETensors8Bit": + dtype = 
torch.half + q_dtype = torch.float8_e4m3fn + + moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype) + + # a -> a_q, w1 -> w1_q, w2 -> w2_q + n_b_scales = 2 * n if per_out_channel else 1 + k_b_scales = k if per_out_channel else 1 + # Get the right scale for tests. + _, a_scale = ops.scaled_fp8_quant( + moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a, + a_scale, + use_per_token_if_dynamic=per_act_token) + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) + w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) + + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + moe_tensors_fp16.w1[expert], + use_per_token_if_dynamic=per_out_channel) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + moe_tensors_fp16.w2[expert], + use_per_token_if_dynamic=per_out_channel) + + # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d + a_d = a_q.float().mul(a_scale).to(dtype) + w1_d = torch.empty_like(moe_tensors_fp16.w1) + w2_d = torch.empty_like(moe_tensors_fp16.w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half() + + return MOETensors8Bit(a=moe_tensors_fp16.a, + w1=moe_tensors_fp16.w1, + w2=moe_tensors_fp16.w2, + ab_strides1=moe_tensors_fp16.ab_strides1, + c_strides1=moe_tensors_fp16.c_strides1, + ab_strides2=moe_tensors_fp16.ab_strides2, + c_strides2=moe_tensors_fp16.c_strides2, + a_q=a_q, + w1_q=w1_q, + w2_q=w2_q, + a_scale=a_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + a_d=a_d, + w1_d=w1_d, + w2_d=w2_d) + + +def run_with_expert_maps(num_experts: int, num_local_experts: int, + **cutlass_moe_kwargs): + + def slice_experts(): + slice_params = [ + "w1", "w2", "ab_strides1", "ab_strides2", "c_strides1", + "c_strides2", "w1_scale", "w2_scale" + ] + full_tensors = { + k: v + for k, v in cutlass_moe_kwargs.items() + if k in slice_params and k in cutlass_moe_kwargs + } + + for i in range(0, num_experts, num_local_experts): + s, e = i, i + num_local_experts + + # make expert map + expert_map = [-1] * num_experts + expert_map[s:e] = list(range(num_local_experts)) + expert_map = torch.tensor(expert_map, + dtype=torch.int32, + device="cuda") + + # update cutlass moe arg with expert_map + cutlass_moe_kwargs["expert_map"] = expert_map + # update cutlass moe arg tensors + for k, t in full_tensors.items(): + cutlass_moe_kwargs[k] = t[s:e] + + yield cutlass_moe_kwargs + + out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"]) + for kwargs in slice_experts(): + out_tensor = out_tensor + cutlass_moe(**kwargs) + + return out_tensor + + +def run_8_bit(moe_tensors: MOETensors8Bit, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_local_experts: Optional[int] = None) -> torch.Tensor: + + kwargs = { + 'a': moe_tensors.a, + 'w1': moe_tensors.w1_q.transpose(1, 2), + 'w2': moe_tensors.w2_q.transpose(1, 2), + 'topk_weights': topk_weights, + 'topk_ids': topk_ids, + 'ab_strides1': moe_tensors.ab_strides1, + 'c_strides1': moe_tensors.c_strides1, + 'ab_strides2': moe_tensors.ab_strides2, + 'c_strides2': moe_tensors.c_strides2, + 'w1_scale': moe_tensors.w1_scale, + 'w2_scale': moe_tensors.w2_scale, + 'a1_scale': moe_tensors.a_scale + } + + num_experts = moe_tensors.w1.size(0) + with_ep = num_local_experts is not None or 
num_local_experts == num_experts + if not with_ep: + return cutlass_moe(**kwargs) + + return run_with_expert_maps(num_experts, num_local_experts, **kwargs) + + +def run_16_bit(moe_tensors: MOETensors, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_local_experts: Optional[int] = None) -> torch.Tensor: + + kwargs = { + "a": moe_tensors.a, + "w1": moe_tensors.w1.transpose(1, 2), + "w2": moe_tensors.w2.transpose(1, 2), + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "ab_strides1": moe_tensors.ab_strides1, + "c_strides1": moe_tensors.c_strides1, + "ab_strides2": moe_tensors.ab_strides2, + "c_strides2": moe_tensors.c_strides2 + } + + num_experts = moe_tensors.w1.size(0) + with_ep = num_local_experts is not None or num_local_experts == num_experts + if not with_ep: + return cutlass_moe(**kwargs) + + return run_with_expert_maps(num_experts, num_local_experts, **kwargs) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @@ -85,76 +250,21 @@ def test_cutlass_moe_8_bit_no_graph( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, + per_out_ch) - # Get the right scale for tests. - _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) + score = torch.randn((m, e), device="cuda", dtype=torch.half) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) - a_d = a_q.float().mul(a_scale1).to(dtype) + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. 
+ triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, + topk_ids) - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - - cutlass_output = cutlass_moe(a, - w1_q, - w2_q, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale1) + cutlass_output = run_8_bit(mt, topk_weights, topk_ids) #print(triton_output) #print(cutlass_output) @@ -191,69 +301,25 @@ def test_cutlass_moe_8_bit_cuda_graph( dtype = torch.half - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - # Get the right scale for tests. 
- _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) - - a_d = a_q.float().mul(a_scale1).to(dtype) - - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, + per_out_ch) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. 
+ triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, + topk_ids) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - cutlass_output = run_8_bit(a, a_scale1, w1_q, w2_q, w1_scale, - w2_scale, topk_weights, topk_ids, - ab_strides1, c_strides1, ab_strides2, - c_strides2) + cutlass_output = run_8_bit(mt, topk_weights, topk_ids) + torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() @@ -289,26 +355,17 @@ def test_cutlass_moe_16_bit_no_graph( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - a = torch.ones((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn( - (e, 2 * n, k), device="cuda", dtype=dtype).transpose(1, 2) / 10 - w2 = torch.randn( - (e, k, n), device="cuda", dtype=dtype).transpose(1, 2) / 10 - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + mt = MOETensors.make_moe_tensors(m, k, n, e, dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) - triton_output = fused_experts(a, w1.transpose(1, - 2), w2.transpose(1, 2), - topk_weights, topk_ids) - cutlass_output = cutlass_moe(a, w1, w2, topk_weights, topk_ids, - ab_strides1, c_strides1, ab_strides2, - c_strides2) + triton_output = fused_experts(mt.a, mt.w1, mt.w2, topk_weights, + topk_ids) + cutlass_output = run_16_bit(mt, topk_weights, topk_ids) print(triton_output) print(cutlass_output) @@ -341,29 +398,20 @@ def test_cutlass_moe_16_bit_cuda_graph( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - a = torch.ones((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn( - (e, 2 * n, k), device="cuda", dtype=dtype).transpose(1, 2) / 10 - w2 = torch.randn( - (e, k, n), device="cuda", dtype=dtype).transpose(1, 2) / 10 - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + mt = MOETensors.make_moe_tensors(m, k, n, e, dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) - triton_output = fused_experts(a, w1.transpose(1, - 2), w2.transpose(1, 2), - topk_weights, topk_ids) + triton_output = fused_experts(mt.a, mt.w1, mt.w2, topk_weights, + topk_ids) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - cutlass_output = run_16_bit(a, w1, w2, topk_weights, topk_ids, - ab_strides1, c_strides1, ab_strides2, - c_strides2) + cutlass_output = run_16_bit(mt, topk_weights, topk_ids) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() @@ -376,3 +424,106 @@ def test_cutlass_moe_16_bit_cuda_graph( cutlass_output, atol=2e-2, rtol=1e-2) + + +@pytest.mark.parametrize("m", [64]) +@pytest.mark.parametrize("n", [1024]) +@pytest.mark.parametrize("k", [4096]) +@pytest.mark.parametrize("e", [16]) +@pytest.mark.parametrize("topk", [1, 8]) 
+@pytest.mark.parametrize("per_act_token", [True]) +@pytest.mark.parametrize("per_out_channel", [True]) +@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_8_bit_EP( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_channel: bool, + ep_size: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, + per_out_channel) + + score = torch.randn((m, e), device="cuda", dtype=torch.half) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. + triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, + topk_ids) + + assert e % ep_size == 0, "Cannot distribute experts evenly" + cutlass_output = run_8_bit(mt, + topk_weights, + topk_ids, + num_local_experts=e // ep_size) + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m", [64]) +@pytest.mark.parametrize("n", [1024]) +@pytest.mark.parametrize("k", [4096]) +@pytest.mark.parametrize("e", [16]) +@pytest.mark.parametrize("topk", [1, 8]) +@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_16_bit_EP( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: bool, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + mt = MOETensors.make_moe_tensors(m, k, n, e, dtype=dtype) + + score = torch.randn((m, e), device="cuda", dtype=torch.half) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. + triton_output = fused_experts(mt.a, mt.w1, mt.w2, topk_weights, + topk_ids) + + assert e % ep_size == 0, "Cannot distribute experts evenly" + cutlass_output = run_16_bit(mt, + topk_weights, + topk_ids, + num_local_experts=e // ep_size) + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 663d1ade8890..1bd67742e604 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -22,6 +22,7 @@ def cutlass_moe( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ @@ -56,8 +57,15 @@ def cutlass_moe( - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize the intermediate result between the gemms. 
Shape: scalar or [M] + - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, + every Rank is responsible for some experts. expert_map is a mapping + from global expert-id to local expert-id. When expert_map[i] is -1, + it means that this Rank is not responsible for global expert-id i. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is 1. Returns: + - torch.Tensor: The output tensor after applying the MoE layer. """ @@ -116,8 +124,12 @@ def cutlass_moe( topk = topk_ids.size(1) out_dtype = a.dtype - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + local_topk_ids = topk_ids + if expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where(expert_map[topk_ids] != -1, + expert_map[topk_ids], -1) + if apply_router_weight_on_input: assert topk == 1, \ "apply_router_weight_on_input is only implemented for topk=1" @@ -140,10 +152,14 @@ def cutlass_moe( dtype=torch.int32, device=device) - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + a_map = torch.zeros((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + c_map = torch.zeros((local_topk_ids.numel()), + dtype=torch.int32, + device=device) - ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1, problem_sizes2, a_map, c_map, num_experts, n, k) @@ -155,7 +171,7 @@ def cutlass_moe( rep_a1_scales = None c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) - c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) + c2 = torch.zeros((m * topk, k), device=device, dtype=out_dtype) ops.cutlass_moe_mm(c1, rep_a, w1, rep_a1_scales, w1_scale, expert_offsets[:-1], problem_sizes1, ab_strides1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 21dc078e4473..be0b5aeca87f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -78,12 +78,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase): """MoE method without quantization.""" @staticmethod - def get_moe_method( - activation: str, - expert_map: Optional[torch.Tensor], - ) -> "UnquantizedFusedMoEMethod": - if (UnquantizedFusedCutlassMoEMethod.check_supported( - activation, expert_map)): + def get_moe_method(activation: str) -> "UnquantizedFusedMoEMethod": + if (UnquantizedFusedCutlassMoEMethod.check_supported(activation)): return UnquantizedFusedCutlassMoEMethod() else: return UnquantizedFusedTritonMoEMethod() @@ -351,9 +347,7 @@ class UnquantizedFusedCutlassMoEMethod(FusedMoEMethodBase, CustomOp): """CUTLASS MoE method without quantization.""" @staticmethod - def check_supported(activation: str, - expert_map: Optional[torch.Tensor], - error: bool = True) -> bool: + def check_supported(activation: str, error: bool = True) -> bool: required_capability = 90 capability_tuple = current_platform.get_device_capability() @@ -362,17 +356,19 @@ def check_supported(activation: str, arch_supported = (capability == required_capability and not current_platform.is_cpu() and not current_platform.is_rocm()) - functions_supported = activation == "silu" and expert_map is None - if error and not arch_supported: - raise RuntimeError( - "Method is not 
supported for the current device. Required ", + functions_supported = activation == "silu" + if not arch_supported: + warn_msg = ( + "UnquantizedFusedCutlassMoEMethod is not supported" + "for the current device. Required " f"GPU with capability: {required_capability}. Current " f"capability: {capability}.") - elif error and not functions_supported: - raise RuntimeError( - "Method is not supported for the required functionality. ", - "Required activation: silu, expert map not supported.", - ) + logger.warning(warn_msg) + if not functions_supported: + logger.warning( + "UnquantizedFusedCutlassMoEMethod Method is not supported" + "for the required functionality. " + "Required activation: silu, expert map not supported.") return arch_supported and functions_supported else: return False @@ -441,8 +437,6 @@ def apply( activation: str = "silu", ) -> torch.Tensor: assert activation == "silu" - assert global_num_experts == layer.w13_weight.shape[0] - assert expert_map is None topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -466,6 +460,7 @@ def apply( self.c_strides1, self.ab_strides2, self.c_strides2, + expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -637,8 +632,7 @@ def __init__( # for heuristic purposes, so it must be initialized first. if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod.get_moe_method( - self.activation, self.expert_map)) + UnquantizedFusedMoEMethod.get_moe_method(self.activation)) else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f282bcb84954..14a155c2253e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -67,7 +67,7 @@ def get_moe_method( else: return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) - and layer.activation == "silu" and layer.expert_map is None): + and layer.activation == "silu"): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) @@ -460,8 +460,6 @@ def apply( ) -> torch.Tensor: assert activation == "silu" - assert global_num_experts == layer.w13_weight.shape[0] - assert expert_map is None topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -492,6 +490,7 @@ def apply( a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, ) From f930b1d694c4a6df6fc0b1ebec5f04a43d6303f0 Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Sat, 12 Apr 2025 03:13:31 +0000 Subject: [PATCH 13/22] fix lint Signed-off-by: varun sundar rabindranath --- tests/kernels/test_cutlass_moe.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 49a8bda1e6a1..c570e752014f 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -178,6 +178,12 @@ def run_8_bit(moe_tensors: MOETensors8Bit, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_local_experts: Optional[int] = None) -> 
torch.Tensor: + assert not any([ + t is None for t in [ + moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale, + moe_tensors.w2_scale, moe_tensors.a_scale + ] + ]) kwargs = { 'a': moe_tensors.a, @@ -199,6 +205,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, if not with_ep: return cutlass_moe(**kwargs) + assert num_local_experts is not None return run_with_expert_maps(num_experts, num_local_experts, **kwargs) From eecfb15846d07f199bf4c972f94e667dac970199 Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Sat, 12 Apr 2025 03:19:55 +0000 Subject: [PATCH 14/22] fix lint Signed-off-by: varun sundar rabindranath --- tests/kernels/test_cutlass_moe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index c570e752014f..f7bbbd131165 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -187,8 +187,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit, kwargs = { 'a': moe_tensors.a, - 'w1': moe_tensors.w1_q.transpose(1, 2), - 'w2': moe_tensors.w2_q.transpose(1, 2), + 'w1': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] + 'w2': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'topk_weights': topk_weights, 'topk_ids': topk_ids, 'ab_strides1': moe_tensors.ab_strides1, @@ -206,7 +206,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit, return cutlass_moe(**kwargs) assert num_local_experts is not None - return run_with_expert_maps(num_experts, num_local_experts, **kwargs) + return run_with_expert_maps(num_experts, num_local_experts, + **kwargs) # type: ignore[arg-type] def run_16_bit(moe_tensors: MOETensors, From fda5f44db75b97a58c6fbffdfb75d1cde39444fc Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Sat, 12 Apr 2025 03:26:38 +0000 Subject: [PATCH 15/22] fix lint Signed-off-by: varun sundar rabindranath --- tests/kernels/test_cutlass_moe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index f7bbbd131165..d2841105a960 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -206,8 +206,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, return cutlass_moe(**kwargs) assert num_local_experts is not None - return run_with_expert_maps(num_experts, num_local_experts, - **kwargs) # type: ignore[arg-type] + return run_with_expert_maps( + num_experts, + num_local_experts, # type: ignore[arg-type] + **kwargs) def run_16_bit(moe_tensors: MOETensors, From 2e8f3acea22f2ced450c04b8a9643d53a3e3c8b4 Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Sat, 12 Apr 2025 03:32:28 +0000 Subject: [PATCH 16/22] fix lint Signed-off-by: varun sundar rabindranath --- tests/kernels/test_cutlass_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index d2841105a960..f27680b7242d 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -234,7 +234,10 @@ def run_16_bit(moe_tensors: MOETensors, if not with_ep: return cutlass_moe(**kwargs) - return run_with_expert_maps(num_experts, num_local_experts, **kwargs) + return run_with_expert_maps( + num_experts, + num_local_experts, # type: ignore[arg-type] + **kwargs) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) From e0c3a513b9096289aa20cea8e1277696801e75e9 Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Sun, 13 Apr 2025 03:28:44 +0000 
Subject: [PATCH 17/22] c_map zeros -> empty Signed-off-by: varun sundar rabindranath --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 1bd67742e604..80c25bb529ee 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -155,7 +155,7 @@ def cutlass_moe( a_map = torch.zeros((local_topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.zeros((local_topk_ids.numel()), + c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) From 5b9ab4f721a5b58add984e33737dea81e108dc9b Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Sun, 13 Apr 2025 22:01:54 +0000 Subject: [PATCH 18/22] add expert parallel to torch hash Signed-off-by: varun sundar rabindranath Signed-off-by: ElizaWszola --- vllm/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/config.py b/vllm/config.py index b466b765d774..4d10c3ab8037 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1637,6 +1637,7 @@ def compute_hash(self): factors: list[Any] = [] factors.append(self.pipeline_parallel_size) factors.append(self.tensor_parallel_size) + factors.append(self.enable_expert_parallel) return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: From 99edc71b88dfc325aea180296972aeb0fe2ffe85 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 16 Apr 2025 15:27:14 +0000 Subject: [PATCH 19/22] Comment out output prints in tests Signed-off-by: ElizaWszola --- tests/kernels/test_cutlass_moe.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index f27680b7242d..17dfae15079d 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -380,9 +380,9 @@ def test_cutlass_moe_16_bit_no_graph( topk_ids) cutlass_output = run_16_bit(mt, topk_weights, topk_ids) - print(triton_output) - print(cutlass_output) - print("*") + # print(triton_output) + # print(cutlass_output) + # print("*") torch.testing.assert_close(triton_output.view(cutlass_output.shape), cutlass_output, @@ -429,9 +429,9 @@ def test_cutlass_moe_16_bit_cuda_graph( graph.replay() torch.cuda.synchronize() - print(triton_output) - print(cutlass_output) - print("*") + # print(triton_output) + # print(cutlass_output) + # print("*") torch.testing.assert_close(triton_output, cutlass_output, From e47180623f4c45dfaf286f1b368f2aeb16441575 Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Thu, 17 Apr 2025 15:46:58 +0000 Subject: [PATCH 20/22] Add more moe benchmark shapes Signed-off-by: varun sundar rabindranath --- benchmarks/kernels/benchmark_shapes.py | 70 ++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index 70190ba24d9d..466a2d2f5554 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -76,12 +76,18 @@ ], } +# yapf: disable WEIGHT_SHAPES_MOE = { "nm-testing/Mixtral-8x7B-Instruct-v0.1": [ - [8, 2, 4096, 28672], - [8, 2, 14336, 4096], + [8, 2, 4096, 14336], ], - "nm-testing/deepseekv2-lite": [ + "nm-testing/Mixtral-8x7B-Instruct-v0.1-TP2": [ + [8, 2, 4096, 14336 // 2], + ], + "nm-testing/Mixtral-8x7B-Instruct-v0.1-EP2": [ + [8 // 2, 2, 4096, 14336], + ], 
+ "nm-testing/deepseekv2-lite-TP1": [ [64, 6, 2048, 1408], ], "ibm-granite/granite-3.0-1b-a400m": [ @@ -90,4 +96,62 @@ "ibm-granite/granite-3.0-3b-a800m": [ [40, 8, 1024, 1536], ], + "ai21labs/Jamba-v0.1" : [ + [16, 2, 4096, 14336] + ], + "ai21labs/Jamba-v0.1-TP2" : [ + [16, 2, 4096, 14336 // 2] + ], + "ai21labs/Jamba-v0.1-EP2" : [ + [16 // 2, 2, 4096, 14336] + ], + "deepseek-ai/DeepSeek-V2" : [ + [160, 6, 5120, 1536] + ], + "deepseek-ai/DeepSeek-V2-TP8" : [ + [160, 6, 5120, 1536 // 8] + ], + "deepseek-ai/DeepSeek-V2-EP8" : [ + [160 // 8, 6, 5120, 1536] + ], + "Qwen/Qwen1.5-MoE-A2.7B-Chat" : [ + [60, 4, 2048, 1408] + ], + "mistralai/Mixtral-8x22B-v0.1" : [ + [8, 2, 6144, 16384] + ], + "mistralai/Mixtral-8x22B-v0.1-TP8" : [ + [8, 2, 6144, 16384 // 8] + ], + "mistralai/Mixtral-8x22B-v0.1-EP8" : [ + [8 // 8, 2, 6144, 16384] + ], + "deepseek-ai/DeepSeek-R1" : [ + [256, 8, 7168, 18432] + ], + "deepseek-ai/DeepSeek-R1-TP8" : [ + [256, 8, 7168, 18432 // 8] + ], + "deepseek-ai/DeepSeek-R1-EP8" : [ + [256 // 8, 8, 7168, 18432] + ], + "meta-llama/Llama-4-Maverick-17B-128E-Instruct" : [ + [128, 1, 5120, 8192] + ], + "meta-llama/Llama-4-Maverick-17B-128E-Instruct-TP8" : [ + [128, 1, 5120, 8192 // 8] + ], + "meta-llama/Llama-4-Maverick-17B-128E-Instruct-EP8" : [ + [128 // 8, 1, 5120, 8192] + ], + "meta-llama/Llama-4-Scout-17B-16E" : [ + [16, 1, 5120, 8192] + ], + "meta-llama/Llama-4-Scout-17B-16E-TP4" : [ + [16, 1, 5120, 8192 // 4] + ], + "meta-llama/Llama-4-Scout-17B-16E-EP4" : [ + [16 // 4, 1, 5120, 8192] + ] } +# yapf: disable From 2256ab492e0180080359cfe6a0e6bdbde9556e2b Mon Sep 17 00:00:00 2001 From: varun sundar rabindranath Date: Fri, 18 Apr 2025 19:31:32 +0000 Subject: [PATCH 21/22] update benchmark_cutlass_moe Signed-off-by: varun sundar rabindranath --- benchmarks/kernels/benchmark_cutlass_moe.py | 412 ++++++++++++-------- 1 file changed, 253 insertions(+), 159 deletions(-) diff --git a/benchmarks/kernels/benchmark_cutlass_moe.py b/benchmarks/kernels/benchmark_cutlass_moe.py index 04bf57ee442f..23e5387abb08 100644 --- a/benchmarks/kernels/benchmark_cutlass_moe.py +++ b/benchmarks/kernels/benchmark_cutlass_moe.py @@ -1,13 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 +import dataclasses +from itertools import product +from typing import Optional + import torch import torch.utils.benchmark as benchmark from benchmark_shapes import WEIGHT_SHAPES_MOE from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe, - fused_experts, +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, fused_topk) from vllm.utils import FlexibleArgumentParser @@ -28,10 +32,147 @@ def to_fp8(tensor: torch.Tensor): min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) -def bench_run(results: list[benchmark.Measurement], model: str, - num_experts: int, topk: int, per_act_token: bool, +def is_16bit(dtype: torch.dtype) -> bool: + return dtype.itemsize == 2 + + +def is_8bit(dtype: torch.dtype) -> bool: + return dtype.itemsize == 1 + + +@dataclasses.dataclass +class MOETensors: + a: torch.Tensor + w1: torch.Tensor + w2: torch.Tensor + w1_t: torch.Tensor # Transposed w1 for cutlass_moe + w2_t: torch.Tensor # Transposed w2 for cutlass_moe + ab_strides1: torch.Tensor + c_strides1: torch.Tensor + ab_strides2: torch.Tensor + c_strides2: torch.Tensor + # quantized + a_q: Optional[torch.Tensor] = None # a -> 
a_q + w1_q: Optional[torch.Tensor] = None # w1 -> w1_q + w2_q: Optional[torch.Tensor] = None # w2 -> w2_q + a_scale: Optional[torch.Tensor] = None + w1_scale: Optional[torch.Tensor] = None + w2_scale: Optional[torch.Tensor] = None + + @staticmethod + def make_moe_tensors(in_dtype: torch.dtype, m: int, k: int, n: int, e: int, + per_act_token: bool, + per_out_channel: bool) -> "MOETensors": + + # For fp8, use torch.half to create 16bit tensors that can be later + # quantized into fp8. + dtype = in_dtype if is_16bit(in_dtype) else torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + if is_16bit(in_dtype): + assert not (per_act_token or per_out_channel) + return MOETensors(a=a, + w1=w1, + w2=w2, + w1_t=w1.transpose(1, 2), + w2_t=w2.transpose(1, 2), + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2) + + assert in_dtype == torch.float8_e4m3fn + q_dtype = torch.float8_e4m3fn + # a -> a_q, w1 -> w1_q, w2 -> w2_q + n_b_scales = 2 * n if per_out_channel else 1 + k_b_scales = k if per_out_channel else 1 + # Get the right scale for tests. + _, a_scale = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale, + use_per_token_if_dynamic=per_act_token) + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) + w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) + + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_channel) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_channel) + + return MOETensors(a=a, + w1=w1, + w2=w2, + w1_t=w1.transpose(1, 2), + w2_t=w2.transpose(1, 2), + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2, + a_q=a_q, + w1_q=w1_q, + w2_q=w2_q, + a_scale=a_scale, + w1_scale=w1_scale, + w2_scale=w2_scale) + + def as_8bit_tensors(self) -> "MOETensors": + assert all([ + x is not None for x in + [self.w1_q, self.w2_q, self.w1_scale, self.w2_scale, self.a_scale] + ]) + return MOETensors(a=self.a, + w1=self.w1_q, + w2=self.w2_q, + w1_t=self.w1_q.transpose(1, 2), + w2_t=self.w2_q.transpose(1, 2), + ab_strides1=self.ab_strides1, + c_strides1=self.c_strides1, + ab_strides2=self.ab_strides2, + c_strides2=self.c_strides2, + a_q=None, + w1_q=None, + w2_q=None, + a_scale=self.a_scale, + w1_scale=self.w1_scale, + w2_scale=self.w2_scale) + + def as_16bit_tensors(self) -> "MOETensors": + return MOETensors(a=self.a, + w1=self.w1, + w2=self.w2, + w1_t=self.w1.transpose(1, 2), + w2_t=self.w2.transpose(1, 2), + ab_strides1=self.ab_strides1, + c_strides1=self.c_strides1, + ab_strides2=self.ab_strides2, + c_strides2=self.c_strides2, + a_q=None, + w1_q=None, + w2_q=None, + a_scale=None, + w1_scale=None, + w2_scale=None) + + +def bench_run(results: list[benchmark.Measurement], dtype: torch.dtype, + model: 
str, num_experts: int, topk: int, per_act_token: bool, per_out_ch: bool, mkn: tuple[int, int, int]): - label = "Quant Matmul" + label = "Quant Matmul" if dtype == torch.float8_e4m3fn else "Matmul" sub_label = ( "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " @@ -41,138 +182,90 @@ def bench_run(results: list[benchmark.Measurement], model: str, print(f"Testing: {sub_label}") (m, k, n) = mkn - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10 - - _, a_scale = ops.scaled_fp8_quant(a, - use_per_token_if_dynamic=per_act_token) - - w1_q = torch.empty((num_experts, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((num_experts, k, n), - device="cuda", - dtype=torch.float8_e4m3fn) - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - w1_scale = torch.empty((num_experts, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((num_experts, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((num_experts, ), - k, - device="cuda", - dtype=torch.int64) - c_strides1 = torch.full((num_experts, ), - 2 * n, - device="cuda", - dtype=torch.int64) - ab_strides2 = torch.full((num_experts, ), - n, - device="cuda", - dtype=torch.int64) - c_strides2 = torch.full((num_experts, ), - k, - device="cuda", - dtype=torch.int64) - - for expert in range(num_experts): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q_notransp = w1_q.clone() - w2_q_notransp = w2_q.clone() - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - score = torch.randn((m, num_experts), device="cuda", dtype=dtype) - - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a_scale: torch.Tensor, num_repeats: int): + tensors = MOETensors.make_moe_tensors(dtype, + m=m, + k=k, + n=n, + e=num_experts, + per_act_token=per_act_token, + per_out_channel=per_out_ch) + tensors = tensors.as_8bit_tensors() if is_8bit( + dtype) else tensors.as_16bit_tensors() + + score_dtype = torch.half if is_8bit(dtype) else dtype + score = torch.randn((m, num_experts), device="cuda", dtype=score_dtype) + topk_weights, topk_ids = fused_topk(tensors.a, + score, + topk, + renormalize=False) + + def run_triton_moe(tensors: MOETensors, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_repeats: int): + use_fp8_w8a8 = (tensors.a_scale is not None + and tensors.w1_scale is not None) for _ in range(num_repeats): - fused_experts(a, - w1, - w2, + fused_experts(tensors.a, + tensors.w1, + tensors.w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, - a2_scale=a_scale) - - def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor, - num_repeats: int): + use_fp8_w8a8=use_fp8_w8a8, + 
w1_scale=tensors.w1_scale, + w2_scale=tensors.w2_scale, + a1_scale=tensors.a_scale, + a2_scale=tensors.a_scale) + + def run_cutlass_moe(tensors: MOETensors, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_repeats: int): for _ in range(num_repeats): - cutlass_moe(a, - w1, - w2, + cutlass_moe(tensors.a, + tensors.w1_t, + tensors.w2_t, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) - - def run_cutlass_from_graph( - a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + tensors.ab_strides1, + tensors.c_strides1, + tensors.ab_strides2, + tensors.c_strides2, + w1_scale=tensors.w1_scale, + w2_scale=tensors.w2_scale, + a1_scale=tensors.a_scale) + + def run_cutlass_from_graph(tensors: MOETensors, topk_weights: torch.Tensor, + topk_ids: torch.Tensor): with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe(a, - w1_q, - w2_q, + return cutlass_moe(tensors.a, + tensors.w1_t, + tensors.w2_t, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) - - def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, w1_scale: torch.Tensor, - w2_scale: torch.Tensor, a_scale: torch.Tensor): + tensors.ab_strides1, + tensors.c_strides1, + tensors.ab_strides2, + tensors.c_strides2, + w1_scale=tensors.w1_scale, + w2_scale=tensors.w2_scale, + a1_scale=tensors.a_scale) + + def run_triton_from_graph(tensors: MOETensors, topk_weights: torch.Tensor, + topk_ids: torch.Tensor): + use_fp8_w8a8 = (tensors.a_scale is not None + and tensors.w1_scale is not None) with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return fused_experts(a, - w1, - w2, + return fused_experts(tensors.a, + tensors.w1, + tensors.w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, - a2_scale=a_scale) + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=tensors.w1_scale, + w2_scale=tensors.w2_scale, + a1_scale=tensors.a_scale, + a2_scale=tensors.a_scale) def replay_graph(graph, num_repeats): for _ in range(num_repeats): @@ -182,18 +275,14 @@ def replay_graph(graph, num_repeats): cutlass_stream = torch.cuda.Stream() cutlass_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): - run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, c_strides1, - ab_strides2, c_strides2) + run_cutlass_from_graph(tensors, topk_weights, topk_ids) torch.cuda.synchronize() if not per_act_token and not per_out_ch: triton_stream = torch.cuda.Stream() triton_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(triton_graph, stream=triton_stream): - run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, - topk_weights, topk_ids, w1_scale, w2_scale, - a_scale) + run_triton_from_graph(tensors, topk_weights, topk_ids) torch.cuda.synchronize() else: triton_graph = [] @@ -204,27 +293,14 @@ def replay_graph(graph, num_repeats): globals = { # Baseline params - "w1": w1, - "w2": w2, "score": score, "topk": topk, - "w1_q_notransp": 
w1_q_notransp, - "w2_q_notransp": w2_q_notransp, - # Cutlass params - "a_scale": a_scale, - "w1_q": w1_q, - "w2_q": w2_q, - "w1_scale": w1_scale, - "w2_scale": w2_scale, - "ab_strides1": ab_strides1, - "c_strides1": c_strides1, - "ab_strides2": ab_strides2, - "c_strides2": c_strides2, + "tensors": tensors, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, # Gen params - "a": a, + "a": tensors.a, "topk_weights": topk_weights, "topk_ids": topk_ids, "num_runs": num_runs, @@ -236,13 +312,12 @@ def replay_graph(graph, num_repeats): if not per_act_token and not per_out_ch: # Warmup - run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, - w1_scale, w2_scale, a_scale, num_warmup) + run_triton_moe(tensors, topk_weights, topk_ids, num_warmup) results.append( benchmark.Timer( stmt= - "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 + "run_triton_moe(tensors, topk_weights, topk_ids, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -262,18 +337,16 @@ def replay_graph(graph, num_repeats): ).blocked_autorange(min_run_time=min_run_time)) # Warmup - run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, - num_warmup) + run_cutlass_moe(tensors, topk_weights, topk_ids, num_warmup) results.append( benchmark.Timer( stmt= - "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 + "run_cutlass_moe(tensors, topk_weights, topk_ids, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, - description="grouped_gemm_moe", + description="cutlass_moe", ).blocked_autorange(min_run_time=min_run_time)) # Warmup @@ -285,7 +358,7 @@ def replay_graph(graph, num_repeats): globals=globals, label=label, sub_label=sub_label, - description="grouped_gemm_moe_cuda_graphs", + description="cutlass_moe_cuda_graphs", ).blocked_autorange(min_run_time=min_run_time)) @@ -296,6 +369,9 @@ def main(args): results: list[benchmark.Measurement] = [] + quant_schemes = product(PER_ACT_TOKEN_OPTS, PER_OUT_CH_OPTS) if is_8bit( + args.dtype) else [(False, False)] + for model in args.models: for tp in args.tp_sizes: for layer in WEIGHT_SHAPES_MOE[model]: @@ -310,20 +386,34 @@ def main(args): if len(args.limit_n) > 0 and size_n not in args.limit_n: continue - for per_act_token in PER_ACT_TOKEN_OPTS: - for per_out_ch in PER_OUT_CH_OPTS: - for size_m in DEFAULT_BATCH_SIZES: - mkn = (size_m, size_k, size_n) - bench_run(results, model, num_experts, topk, - per_act_token, per_out_ch, mkn) + for per_act_token, per_out_ch in quant_schemes: + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + bench_run(results, args.dtype, model, num_experts, + topk, per_act_token, per_out_ch, mkn) compare = benchmark.Compare(results) compare.print() if __name__ == "__main__": - parser = FlexibleArgumentParser( - description="Benchmark Marlin across specified models/shapes/batches") + + def str_to_dtype(dtype_str: str) -> torch.dtype: + if dtype_str == "fp8": + return torch.float8_e4m3fn + if dtype_str == "fp16": + return torch.float16 + if dtype_str == "bf16": + return torch.bfloat16 + raise ValueError(f"Unrecognized dtype str {dtype_str}") + + parser = FlexibleArgumentParser(description=""" + Benchmark Cutlass MOE layer against Triton MOE Layer. 
\n + Example : python3 benchmarks/kernels/benchmark_cutlass_moe.py + --dtype bf16 + --models nm-testing/Mixtral-8x7B-Instruct-v0.1 + --batch-sizes 1 16 32 + """) parser.add_argument( "--models", nargs="+", @@ -331,6 +421,10 @@ def main(args): default=DEFAULT_MODELS, choices=WEIGHT_SHAPES_MOE.keys(), ) + parser.add_argument("--dtype", + type=str_to_dtype, + required=True, + help="Please choose one from fp8, fp16 or bf16") parser.add_argument("--tp-sizes", nargs="+", type=int, From d5995b219df96cc8690a9a36fd8abcae7b84808b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 29 Apr 2025 14:20:53 +0000 Subject: [PATCH 22/22] Format, cleanup Signed-off-by: ElizaWszola --- tests/kernels/moe/test_cutlass_moe.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 4b1f47f49e8b..03b323cd02a0 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -43,7 +43,7 @@ class MOETensors: @staticmethod def make_moe_tensors(m: int, k: int, n: int, e: int, - dtype: torch.dtype, ep_size: int = 1) -> "MOETensors": + dtype: torch.dtype) -> "MOETensors": a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -77,12 +77,11 @@ class MOETensors8Bit(MOETensors): @staticmethod def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, per_act_token: bool, - per_out_channel: bool, - ep_size: int = 1) -> "MOETensors8Bit": + per_out_channel: bool) -> "MOETensors8Bit": dtype = torch.half q_dtype = torch.float8_e4m3fn - moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype, ep_size) + moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype) # a -> a_q, w1 -> w1_q, w2 -> w2_q n_b_scales = 2 * n if per_out_channel else 1 @@ -347,6 +346,7 @@ def test_cutlass_moe_8_bit_cuda_graph( atol=9e-2, rtol=1e-2) + @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @@ -467,7 +467,7 @@ def test_cutlass_moe_8_bit_EP( pipeline_parallel_size=1))): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_channel, ep_size=ep_size) + per_out_channel) score = torch.randn((m, e), device="cuda", dtype=torch.half) topk_weights, topk_ids = fused_topk(mt.a, @@ -517,7 +517,7 @@ def test_cutlass_moe_16_bit_EP( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - mt = MOETensors.make_moe_tensors(m, k, n, e, dtype=dtype, ep_size=ep_size) + mt = MOETensors.make_moe_tensors(m, k, n, e, dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=torch.half) topk_weights, topk_ids = fused_topk(mt.a,