diff --git a/CMakeLists.txt b/CMakeLists.txt index 3314f05fd2a0..fc52d332dca2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -287,6 +287,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/cutlass_moe/moe_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" @@ -491,21 +492,21 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" - "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set(SRCS "csrc/cutlass_moe/moe_mm_c3x.cu" + "csrc/cutlass_moe/moe_data.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1") - message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + message(STATUS "Building moe_mm_c3x for archs: ${SCALED_MM_ARCHS}") else() if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + message(STATUS "Not building moe_mm_c3x kernels as CUDA Compiler version is " "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " "if you intend on running FP8 quantized MoE models on Hopper.") else() - message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + message(STATUS "Not building moe_mm_c3x as no compatible archs found " "in CUDA target architectures") endif() endif() diff --git a/benchmarks/kernels/benchmark_cutlass_moe.py b/benchmarks/kernels/benchmark_cutlass_moe.py new file mode 100644 index 000000000000..23e5387abb08 --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_moe.py @@ -0,0 +1,446 @@ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from itertools import product +from typing import Optional + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES_MOE + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk) +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = [ + "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", + "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" +] +DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False, True] +PER_OUT_CH_OPTS = [False, True] + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def is_16bit(dtype: torch.dtype) -> bool: + return dtype.itemsize == 2 + + +def is_8bit(dtype: torch.dtype) -> bool: + return dtype.itemsize == 1 + + +@dataclasses.dataclass +class MOETensors: + a: torch.Tensor + w1: torch.Tensor + w2: torch.Tensor + w1_t: torch.Tensor # Transposed w1 for cutlass_moe + w2_t: torch.Tensor # Transposed w2 for cutlass_moe + ab_strides1: torch.Tensor + c_strides1: torch.Tensor + 
ab_strides2: torch.Tensor + c_strides2: torch.Tensor + # quantized + a_q: Optional[torch.Tensor] = None # a -> a_q + w1_q: Optional[torch.Tensor] = None # w1 -> w1_q + w2_q: Optional[torch.Tensor] = None # w2 -> w2_q + a_scale: Optional[torch.Tensor] = None + w1_scale: Optional[torch.Tensor] = None + w2_scale: Optional[torch.Tensor] = None + + @staticmethod + def make_moe_tensors(in_dtype: torch.dtype, m: int, k: int, n: int, e: int, + per_act_token: bool, + per_out_channel: bool) -> "MOETensors": + + # For fp8, use torch.half to create 16bit tensors that can be later + # quantized into fp8. + dtype = in_dtype if is_16bit(in_dtype) else torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + if is_16bit(in_dtype): + assert not (per_act_token or per_out_channel) + return MOETensors(a=a, + w1=w1, + w2=w2, + w1_t=w1.transpose(1, 2), + w2_t=w2.transpose(1, 2), + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2) + + assert in_dtype == torch.float8_e4m3fn + q_dtype = torch.float8_e4m3fn + # a -> a_q, w1 -> w1_q, w2 -> w2_q + n_b_scales = 2 * n if per_out_channel else 1 + k_b_scales = k if per_out_channel else 1 + # Get the right scale for tests. + _, a_scale = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale, + use_per_token_if_dynamic=per_act_token) + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) + w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) + + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_channel) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_channel) + + return MOETensors(a=a, + w1=w1, + w2=w2, + w1_t=w1.transpose(1, 2), + w2_t=w2.transpose(1, 2), + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2, + a_q=a_q, + w1_q=w1_q, + w2_q=w2_q, + a_scale=a_scale, + w1_scale=w1_scale, + w2_scale=w2_scale) + + def as_8bit_tensors(self) -> "MOETensors": + assert all([ + x is not None for x in + [self.w1_q, self.w2_q, self.w1_scale, self.w2_scale, self.a_scale] + ]) + return MOETensors(a=self.a, + w1=self.w1_q, + w2=self.w2_q, + w1_t=self.w1_q.transpose(1, 2), + w2_t=self.w2_q.transpose(1, 2), + ab_strides1=self.ab_strides1, + c_strides1=self.c_strides1, + ab_strides2=self.ab_strides2, + c_strides2=self.c_strides2, + a_q=None, + w1_q=None, + w2_q=None, + a_scale=self.a_scale, + w1_scale=self.w1_scale, + w2_scale=self.w2_scale) + + def as_16bit_tensors(self) -> "MOETensors": + return MOETensors(a=self.a, + w1=self.w1, + w2=self.w2, + w1_t=self.w1.transpose(1, 2), + w2_t=self.w2.transpose(1, 2), + ab_strides1=self.ab_strides1, + c_strides1=self.c_strides1, + ab_strides2=self.ab_strides2, + c_strides2=self.c_strides2, + a_q=None, + w1_q=None, + w2_q=None, + a_scale=None, + 
w1_scale=None, + w2_scale=None) + + +def bench_run(results: list[benchmark.Measurement], dtype: torch.dtype, + model: str, num_experts: int, topk: int, per_act_token: bool, + per_out_ch: bool, mkn: tuple[int, int, int]): + label = "Quant Matmul" if dtype == torch.float8_e4m3fn else "Matmul" + + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, + mkn)) + + print(f"Testing: {sub_label}") + + (m, k, n) = mkn + tensors = MOETensors.make_moe_tensors(dtype, + m=m, + k=k, + n=n, + e=num_experts, + per_act_token=per_act_token, + per_out_channel=per_out_ch) + tensors = tensors.as_8bit_tensors() if is_8bit( + dtype) else tensors.as_16bit_tensors() + + score_dtype = torch.half if is_8bit(dtype) else dtype + score = torch.randn((m, num_experts), device="cuda", dtype=score_dtype) + topk_weights, topk_ids = fused_topk(tensors.a, + score, + topk, + renormalize=False) + + def run_triton_moe(tensors: MOETensors, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_repeats: int): + use_fp8_w8a8 = (tensors.a_scale is not None + and tensors.w1_scale is not None) + for _ in range(num_repeats): + fused_experts(tensors.a, + tensors.w1, + tensors.w2, + topk_weights, + topk_ids, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=tensors.w1_scale, + w2_scale=tensors.w2_scale, + a1_scale=tensors.a_scale, + a2_scale=tensors.a_scale) + + def run_cutlass_moe(tensors: MOETensors, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_repeats: int): + for _ in range(num_repeats): + cutlass_moe(tensors.a, + tensors.w1_t, + tensors.w2_t, + topk_weights, + topk_ids, + tensors.ab_strides1, + tensors.c_strides1, + tensors.ab_strides2, + tensors.c_strides2, + w1_scale=tensors.w1_scale, + w2_scale=tensors.w2_scale, + a1_scale=tensors.a_scale) + + def run_cutlass_from_graph(tensors: MOETensors, topk_weights: torch.Tensor, + topk_ids: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe(tensors.a, + tensors.w1_t, + tensors.w2_t, + topk_weights, + topk_ids, + tensors.ab_strides1, + tensors.c_strides1, + tensors.ab_strides2, + tensors.c_strides2, + w1_scale=tensors.w1_scale, + w2_scale=tensors.w2_scale, + a1_scale=tensors.a_scale) + + def run_triton_from_graph(tensors: MOETensors, topk_weights: torch.Tensor, + topk_ids: torch.Tensor): + use_fp8_w8a8 = (tensors.a_scale is not None + and tensors.w1_scale is not None) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return fused_experts(tensors.a, + tensors.w1, + tensors.w2, + topk_weights, + topk_ids, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=tensors.w1_scale, + w2_scale=tensors.w2_scale, + a1_scale=tensors.a_scale, + a2_scale=tensors.a_scale) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() + torch.cuda.synchronize() + + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph(tensors, topk_weights, topk_ids) + torch.cuda.synchronize() + + if not per_act_token and not per_out_ch: + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph(tensors, topk_weights, topk_ids) + torch.cuda.synchronize() + else: + triton_graph = [] + + min_run_time = 5 + num_warmup = 5 + num_runs = 25 + + globals 
= { + # Baseline params + "score": score, + "topk": topk, + "tensors": tensors, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, + # Gen params + "a": tensors.a, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "num_runs": num_runs, + # Kernels + "run_triton_moe": run_triton_moe, + "run_cutlass_moe": run_cutlass_moe, + "replay_graph": replay_graph, + } + + if not per_act_token and not per_out_ch: + # Warmup + run_triton_moe(tensors, topk_weights, topk_ids, num_warmup) + + results.append( + benchmark.Timer( + stmt= + "run_triton_moe(tensors, topk_weights, topk_ids, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + run_cutlass_moe(tensors, topk_weights, topk_ids, num_warmup) + + results.append( + benchmark.Timer( + stmt= + "run_cutlass_moe(tensors, topk_weights, topk_ids, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + replay_graph(cutlass_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(cutlass_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: list[benchmark.Measurement] = [] + + quant_schemes = product(PER_ACT_TOKEN_OPTS, PER_OUT_CH_OPTS) if is_8bit( + args.dtype) else [(False, False)] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token, per_out_ch in quant_schemes: + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + bench_run(results, args.dtype, model, num_experts, + topk, per_act_token, per_out_ch, mkn) + + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + + def str_to_dtype(dtype_str: str) -> torch.dtype: + if dtype_str == "fp8": + return torch.float8_e4m3fn + if dtype_str == "fp16": + return torch.float16 + if dtype_str == "bf16": + return torch.bfloat16 + raise ValueError(f"Unrecognized dtype str {dtype_str}") + + parser = FlexibleArgumentParser(description=""" + Benchmark Cutlass MOE layer against Triton MOE Layer. 
\n + Example : python3 benchmarks/kernels/benchmark_cutlass_moe.py + --dtype bf16 + --models nm-testing/Mixtral-8x7B-Instruct-v0.1 + --batch-sizes 1 16 32 + """) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--dtype", + type=str_to_dtype, + required=True, + help="Please choose one from fp8, fp16 or bf16") + parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", + nargs="+", + type=int, + default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py deleted file mode 100644 index bcdbf6c7551a..000000000000 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ /dev/null @@ -1,340 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.utils.benchmark as benchmark -from benchmark_shapes import WEIGHT_SHAPES_MOE - -from vllm import _custom_ops as ops -from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, - fused_experts, - fused_topk) -from vllm.utils import FlexibleArgumentParser - -DEFAULT_MODELS = [ - "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", - "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" -] -DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] -DEFAULT_TP_SIZES = [1] - -PER_ACT_TOKEN_OPTS = [False] -PER_OUT_CH_OPTS = [False] - - -def to_fp8(tensor: torch.Tensor): - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def bench_run(results: list[benchmark.Measurement], model: str, - num_experts: int, topk: int, per_act_token: bool, - per_out_ch: bool, mkn: tuple[int, int, int]): - label = "Quant Matmul" - - sub_label = ( - "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " - "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, - mkn)) - - print(f"Testing: {sub_label}") - - (m, k, n) = mkn - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10 - - _, a_scale = ops.scaled_fp8_quant(a) - - w1_q = torch.empty((num_experts, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((num_experts, k, n), - device="cuda", - dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((num_experts, 1, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((num_experts, 1, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((num_experts, ), - k, - device="cuda", - dtype=torch.int64) - c_strides1 = torch.full((num_experts, ), - 2 * n, - device="cuda", - dtype=torch.int64) - ab_strides2 = torch.full((num_experts, ), - n, - device="cuda", - dtype=torch.int64) - c_strides2 = torch.full((num_experts, ), - k, 
- device="cuda", - dtype=torch.int64) - - for expert in range(num_experts): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) - w1_q_notransp = w1_q.clone() - w2_q_notransp = w2_q.clone() - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - score = torch.randn((m, num_experts), device="cuda", dtype=dtype) - - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a_scale: torch.Tensor, num_repeats: int): - for _ in range(num_repeats): - fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) - - def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor, - num_repeats: int): - for _ in range(num_repeats): - cutlass_moe_fp8(a, - w1, - w2, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - def run_cutlass_from_graph( - a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, w1_scale: torch.Tensor, - w2_scale: torch.Tensor, a_scale: torch.Tensor): - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) - - def replay_graph(graph, num_repeats): - for _ in range(num_repeats): - graph.replay() - torch.cuda.synchronize() - - cutlass_stream = torch.cuda.Stream() - cutlass_graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): - run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, c_strides1, - ab_strides2, c_strides2) - torch.cuda.synchronize() - - triton_stream = torch.cuda.Stream() - triton_graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(triton_graph, stream=triton_stream): - run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights, - topk_ids, w1_scale, w2_scale, a_scale) - torch.cuda.synchronize() - - min_run_time = 5 - num_warmup = 5 - num_runs = 25 - - globals = { - # Baseline params - "w1": w1, - "w2": w2, - "score": score, - "topk": topk, - "w1_q_notransp": w1_q_notransp, - "w2_q_notransp": w2_q_notransp, - # Cutlass params - "a_scale": a_scale, - "w1_q": w1_q, - "w2_q": w2_q, - "w1_scale": w1_scale, - "w2_scale": w2_scale, - 
"ab_strides1": ab_strides1, - "c_strides1": c_strides1, - "ab_strides2": ab_strides2, - "c_strides2": c_strides2, - # cuda graph params - "cutlass_graph": cutlass_graph, - "triton_graph": triton_graph, - # Gen params - "a": a, - "topk_weights": topk_weights, - "topk_ids": topk_ids, - "num_runs": num_runs, - # Kernels - "run_triton_moe": run_triton_moe, - "run_cutlass_moe": run_cutlass_moe, - "replay_graph": replay_graph, - } - - # Warmup - run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, - w1_scale, w2_scale, a_scale, num_warmup) - - results.append( - benchmark.Timer( - stmt= - "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="triton_moe", - ).blocked_autorange(min_run_time=min_run_time)) - - # Warmup - replay_graph(triton_graph, num_warmup) - - results.append( - benchmark.Timer( - stmt="replay_graph(triton_graph, num_runs)", - globals=globals, - label=label, - sub_label=sub_label, - description="triton_moe_cuda_graphs", - ).blocked_autorange(min_run_time=min_run_time)) - - # Warmup - run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, - num_warmup) - - results.append( - benchmark.Timer( - stmt= - "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="grouped_gemm_moe", - ).blocked_autorange(min_run_time=min_run_time)) - - # Warmup - replay_graph(cutlass_graph, num_warmup) - - results.append( - benchmark.Timer( - stmt="replay_graph(cutlass_graph, num_runs)", - globals=globals, - label=label, - sub_label=sub_label, - description="grouped_gemm_moe_cuda_graphs", - ).blocked_autorange(min_run_time=min_run_time)) - - -def main(args): - print("Benchmarking models:") - for i, model in enumerate(args.models): - print(f"[{i}] {model}") - - results: list[benchmark.Measurement] = [] - - for model in args.models: - for tp in args.tp_sizes: - for layer in WEIGHT_SHAPES_MOE[model]: - num_experts = layer[0] - topk = layer[1] - size_k = layer[2] - size_n = layer[3] // tp - - if len(args.limit_k) > 0 and size_k not in args.limit_k: - continue - - if len(args.limit_n) > 0 and size_n not in args.limit_n: - continue - - for per_act_token in PER_ACT_TOKEN_OPTS: - for per_out_ch in PER_OUT_CH_OPTS: - for size_m in DEFAULT_BATCH_SIZES: - mkn = (size_m, size_k, size_n) - bench_run(results, model, num_experts, topk, - per_act_token, per_out_ch, mkn) - - compare = benchmark.Compare(results) - compare.print() - - -if __name__ == "__main__": - parser = FlexibleArgumentParser( - description="Benchmark Marlin across specified models/shapes/batches") - parser.add_argument( - "--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES_MOE.keys(), - ) - parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) - parser.add_argument("--limit-k", nargs="+", type=int, default=[]) - parser.add_argument("--limit-n", nargs="+", type=int, default=[]) - parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) - parser.add_argument("--limit-per-act-token", - nargs="+", - type=int, - default=[]) - parser.add_argument("--limit-per-out-ch", 
nargs="+", type=int, default=[]) - - args = parser.parse_args() - main(args) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index 70190ba24d9d..466a2d2f5554 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -76,12 +76,18 @@ ], } +# yapf: disable WEIGHT_SHAPES_MOE = { "nm-testing/Mixtral-8x7B-Instruct-v0.1": [ - [8, 2, 4096, 28672], - [8, 2, 14336, 4096], + [8, 2, 4096, 14336], ], - "nm-testing/deepseekv2-lite": [ + "nm-testing/Mixtral-8x7B-Instruct-v0.1-TP2": [ + [8, 2, 4096, 14336 // 2], + ], + "nm-testing/Mixtral-8x7B-Instruct-v0.1-EP2": [ + [8 // 2, 2, 4096, 14336], + ], + "nm-testing/deepseekv2-lite-TP1": [ [64, 6, 2048, 1408], ], "ibm-granite/granite-3.0-1b-a400m": [ @@ -90,4 +96,62 @@ "ibm-granite/granite-3.0-3b-a800m": [ [40, 8, 1024, 1536], ], + "ai21labs/Jamba-v0.1" : [ + [16, 2, 4096, 14336] + ], + "ai21labs/Jamba-v0.1-TP2" : [ + [16, 2, 4096, 14336 // 2] + ], + "ai21labs/Jamba-v0.1-EP2" : [ + [16 // 2, 2, 4096, 14336] + ], + "deepseek-ai/DeepSeek-V2" : [ + [160, 6, 5120, 1536] + ], + "deepseek-ai/DeepSeek-V2-TP8" : [ + [160, 6, 5120, 1536 // 8] + ], + "deepseek-ai/DeepSeek-V2-EP8" : [ + [160 // 8, 6, 5120, 1536] + ], + "Qwen/Qwen1.5-MoE-A2.7B-Chat" : [ + [60, 4, 2048, 1408] + ], + "mistralai/Mixtral-8x22B-v0.1" : [ + [8, 2, 6144, 16384] + ], + "mistralai/Mixtral-8x22B-v0.1-TP8" : [ + [8, 2, 6144, 16384 // 8] + ], + "mistralai/Mixtral-8x22B-v0.1-EP8" : [ + [8 // 8, 2, 6144, 16384] + ], + "deepseek-ai/DeepSeek-R1" : [ + [256, 8, 7168, 18432] + ], + "deepseek-ai/DeepSeek-R1-TP8" : [ + [256, 8, 7168, 18432 // 8] + ], + "deepseek-ai/DeepSeek-R1-EP8" : [ + [256 // 8, 8, 7168, 18432] + ], + "meta-llama/Llama-4-Maverick-17B-128E-Instruct" : [ + [128, 1, 5120, 8192] + ], + "meta-llama/Llama-4-Maverick-17B-128E-Instruct-TP8" : [ + [128, 1, 5120, 8192 // 8] + ], + "meta-llama/Llama-4-Maverick-17B-128E-Instruct-EP8" : [ + [128 // 8, 1, 5120, 8192] + ], + "meta-llama/Llama-4-Scout-17B-16E" : [ + [16, 1, 5120, 8192] + ], + "meta-llama/Llama-4-Scout-17B-16E-TP4" : [ + [16, 1, 5120, 8192 // 4] + ], + "meta-llama/Llama-4-Scout-17B-16E-EP4" : [ + [16 // 4, 1, 5120, 8192] + ] } +# yapf: disable diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/cutlass_moe/moe_data.cu similarity index 95% rename from csrc/quantization/cutlass_w8a8/moe/moe_data.cu rename to csrc/cutlass_moe/moe_data.cu index 894727383a63..3b7d5cfef7e8 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/cutlass_moe/moe_data.cu @@ -7,7 +7,7 @@ constexpr uint64_t THREADS_PER_EXPERT = 512; -__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, +__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* atomic_buffer, @@ -45,7 +45,7 @@ __global__ void compute_expert_offsets( } } -__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, +__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, const int32_t* __restrict__ expert_offsets, int32_t* input_permutation, int32_t* output_permutation, @@ -66,7 +66,7 @@ __global__ void compute_arg_sorts(const int* __restrict__ topk_ids, // for "invalid" topk_ids. 
output_permutation[i] = num_tokens; } else if (expert_id == blk_expert_id) { - int start = atomicAdd(&atomic_buffer[expert_id], 1); + int start = atomicAdd(&atomic_buffer[blk_expert_id], 1); input_permutation[start] = i / topk; output_permutation[i] = start; } diff --git a/csrc/cutlass_moe/moe_mm_c3x.cu b/csrc/cutlass_moe/moe_mm_c3x.cu new file mode 100644 index 000000000000..5adf877fd232 --- /dev/null +++ b/csrc/cutlass_moe/moe_mm_c3x.cu @@ -0,0 +1,248 @@ +#include + +#include +#include + +#include "cutlass/cutlass.h" +#include "moe_mm_c3x_8_bit.cuh" +#include "moe_mm_c3x_16_bit.cuh" + +using namespace cute; + +namespace { + +template typename Epilogue> +struct sm90_8_bit_config_default { + // M in (16, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_moe_gemm; +}; + +template typename Epilogue> +struct sm90_8_bit_config_M16 { + // M in [1, 16] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_moe_gemm; +}; + +template typename Epilogue> +struct sm90_8_bit_config_K8192 { + // K in [8192, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_moe_gemm; +}; + +template typename Epilogue> +struct sm90_8_bit_config_N8192 { + // N in [8192, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_moe_gemm; +}; + +template typename Epilogue> +struct sm90_16_bit_config_M512 { + // M in [1, 512] + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_moe_gemm; +}; + +template typename Epilogue> +struct sm90_16_bit_config_default { + // M in (1024, inf] + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_moe_gemm; +}; + +template +void run_cutlass_moe_mm_sm90_8_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size(0) 
> 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); + + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, + "A tensors must be of type float8_e4m3fn."); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, + "B tensors must be of type float8_e4m3fn."); + + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmN8192 = typename sm90_8_bit_config_N8192< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmK8192 = typename sm90_8_bit_config_K8192< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmM16 = typename sm90_8_bit_config_M16< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmDefault = typename sm90_8_bit_config_default< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + + uint32_t const m = a_tensors.size(0); + uint32_t const n = out_tensors.size(1); + uint32_t const k = a_tensors.size(1); + + if (n >= 8192) { + cutlass_moe_gemm_caller_8_bit( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else if (k >= 8192) { + cutlass_moe_gemm_caller_8_bit( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else if (m <= 16) { + cutlass_moe_gemm_caller_8_bit( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else { + cutlass_moe_gemm_caller_8_bit( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } +} + +template +void run_cutlass_moe_mm_sm90_16_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); + + using Cutlass3xGemmM512 = typename sm90_16_bit_config_M512< + InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; + using Cutlass3xGemmDefault = typename sm90_16_bit_config_default< + InType, OutType, vllm::c3x::TrivialEpilogue>::Cutlass3xGemm; + + uint32_t const m = a_tensors.size(0); + uint32_t const n = out_tensors.size(1); + uint32_t const k = a_tensors.size(1); + + if (m <= 512) { + cutlass_moe_gemm_caller_16_bit( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } else { + cutlass_moe_gemm_caller_16_bit( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } +} + +void dispatch_moe_mm_sm90_8_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + if (out_tensors.dtype() == torch::kBFloat16) { + run_cutlass_moe_mm_sm90_8_bit( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, 
expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } else { + run_cutlass_moe_mm_sm90_8_bit( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); + } +} + +void dispatch_moe_mm_sm90_16_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + if (out_tensors.dtype() == torch::kBFloat16) { + run_cutlass_moe_mm_sm90_16_bit( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } else { + run_cutlass_moe_mm_sm90_16_bit( + out_tensors, a_tensors, b_tensors, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides); + } +} + +} // namespace + +void cutlass_moe_mm_sm90_8_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + dispatch_moe_mm_sm90_8_bit(out_tensors, a_tensors, b_tensors, a_scales, + b_scales, expert_offsets, problem_sizes, a_strides, + b_strides, c_strides); +} + +void cutlass_moe_mm_sm90_16_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + dispatch_moe_mm_sm90_16_bit(out_tensors, a_tensors, b_tensors, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides); +} diff --git a/csrc/cutlass_moe/moe_mm_c3x_16_bit.cuh b/csrc/cutlass_moe/moe_mm_c3x_16_bit.cuh new file mode 100644 index 000000000000..6b0cdd6ead67 --- /dev/null +++ b/csrc/cutlass_moe/moe_mm_c3x_16_bit.cuh @@ -0,0 +1,82 @@ +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/bfloat16.h" + +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "cutlass_extensions/common.hpp" +#include "moe_mm_c3x_common.cuh" + +using namespace cute; + +namespace { + +template +void cutlass_moe_gemm_caller_16_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int num_experts = static_cast(expert_offsets.size(0)); + int k_size = a_tensors.size(1); + int n_size = out_tensors.size(1); + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + + run_get_moe_gemm_starts_16_bit(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + a_tensors, 
b_tensors, out_tensors); + + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename GemmKernel::InternalStrideC; + + ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast( + problem_sizes.data_ptr()); + ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; + + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr())}; + + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args(), nullptr, + static_cast(c_strides.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides.data_ptr())}; + + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, + epilogue_args}; + + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +} // namespace diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/cutlass_moe/moe_mm_c3x_8_bit.cuh similarity index 58% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh rename to csrc/cutlass_moe/moe_mm_c3x_8_bit.cuh index db827b7c5e18..99466d443d40 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/cutlass_moe/moe_mm_c3x_8_bit.cuh @@ -8,70 +8,14 @@ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/common.hpp" -#include "get_group_starts.cuh" +#include "moe_mm_c3x_common.cuh" using namespace cute; namespace { -using ProblemShape = - cutlass::gemm::GroupProblemShape>; - -using ElementAccumulator = float; -using ArchTag = cutlass::arch::Sm90; -using OperatorClass = cutlass::arch::OpClassTensorOp; - -using LayoutA = cutlass::layout::RowMajor; -using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::RowMajor; - -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> -struct cutlass_3x_group_gemm { - using ElementAB = ElementAB_; - using ElementC = void; - using ElementD = ElementC_; - using ElementAccumulator = float; - - using Epilogue = Epilogue_; - - using StrideC = - cute::remove_pointer_t, cute::Int<0>>>; - - static constexpr int AlignmentAB = - 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using EVTCompute = typename Epilogue::EVTCompute; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, - LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - using CollectiveMainloop = - typename 
cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, - LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, - Stages, KernelSchedule>::CollectiveOp; - - using KernelType = enable_sm90_only>; - - struct GemmKernel : public KernelType {}; -}; - template -void cutlass_group_gemm_caller( +void cutlass_moe_gemm_caller_8_bit( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, @@ -98,9 +42,9 @@ void cutlass_group_gemm_caller( torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); - run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, - a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors, - out_tensors, a_scales, b_scales); + run_get_moe_gemm_starts_8_bit(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + a_scales_ptrs, b_scales_ptrs, a_tensors, + b_tensors, out_tensors, a_scales, b_scales); using GemmKernel = typename Gemm::GemmKernel; using StrideA = Stride, Int<0>>; diff --git a/csrc/cutlass_moe/moe_mm_c3x_common.cuh b/csrc/cutlass_moe/moe_mm_c3x_common.cuh new file mode 100644 index 000000000000..d5f68f2d71e9 --- /dev/null +++ b/csrc/cutlass_moe/moe_mm_c3x_common.cuh @@ -0,0 +1,180 @@ +#pragma once + +#include +#include +#include + +#include "core/scalar_type.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/float8.h" + +// get tensors with pointers pointing to the start index of each group's data + +template +__global__ void get_moe_gemm_starts( + int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, + ElementC** out_offsets, ElementAccumulator** a_scales_offsets, + ElementAccumulator** b_scales_offsets, ElementAB* a_base, ElementAB* b_base, + ElementC* out_base, ElementAccumulator* a_scales_base, + ElementAccumulator* b_scales_base, int64_t n, int64_t k, bool per_act_token, + bool per_out_ch) { + int expert_id = threadIdx.x; + + int64_t expert_offset = expert_offsets[expert_id]; + + a_offsets[expert_id] = a_base + expert_offset * k; + b_offsets[expert_id] = b_base + expert_id * k * n; + out_offsets[expert_id] = out_base + expert_offset * n; + if (a_scales_offsets != nullptr && a_scales_base != nullptr) + a_scales_offsets[expert_id] = + a_scales_base + (per_act_token ? expert_offset : 0); + if (b_scales_offsets != nullptr && b_scales_base != nullptr) + b_scales_offsets[expert_id] = + b_scales_base + (per_out_ch ? 
n * expert_id : expert_id); +} + +#define __CALL_GET_STARTS_KERNEL_8_BIT(TENSOR_C_TYPE, C_TYPE) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_moe_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), out_tensors.size(1), \ + a_tensors.size(1), per_act_token, per_out_ch); \ + } + +#define __CALL_GET_STARTS_KERNEL_16_BIT(TENSOR_ABC_TYPE, ABC_TYPE) \ + else if (out_tensors.dtype() == TENSOR_ABC_TYPE) { \ + get_moe_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), nullptr, nullptr, \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), nullptr, nullptr, \ + out_tensors.size(1), a_tensors.size(1), false, false); \ + } + +namespace { + +void run_get_moe_gemm_starts_8_bit( + torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, + torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, + torch::Tensor& out_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + int num_experts = static_cast(expert_offsets.size(0)); + bool per_act_token = a_scales.numel() != 1; + bool per_out_ch = b_scales.numel() != num_experts; + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if (false) { + } + __CALL_GET_STARTS_KERNEL_8_BIT(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL_8_BIT(torch::kFloat16, half) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +void run_get_moe_gemm_starts_16_bit(torch::Tensor const& expert_offsets, + torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, + torch::Tensor& out_ptrs, + torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + torch::Tensor& out_tensors) { + TORCH_CHECK(a_tensors.dtype() == torch::kBFloat16 || + a_tensors.dtype() == torch::kFloat16); + TORCH_CHECK(a_tensors.dtype() == b_tensors.dtype()); + TORCH_CHECK(a_tensors.dtype() == out_tensors.dtype()); + + int num_experts = (int)expert_offsets.size(0); + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if (false) { + } + __CALL_GET_STARTS_KERNEL_16_BIT(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL_16_BIT(torch::kFloat16, half) + else { + TORCH_CHECK(false, "Invalid i/o type (must be float16 or bfloat16)"); + } +} + +// common structs and types used by moe gemm + +using ProblemShape = + cutlass::gemm::GroupProblemShape>; + +using ElementAccumulator = float; +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = 
cutlass::layout::RowMajor; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_moe_gemm { + using ElementAB = ElementAB_; + using ElementC = void; + using ElementD = ElementC_; + using ElementAccumulator = float; + + using Epilogue = Epilogue_; + + using StrideC = + cute::remove_pointer_t, cute::Int<0>>>; + + static constexpr int AlignmentAB = + 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, + LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, + Stages, KernelSchedule>::CollectiveOp; + + using KernelType = enable_sm90_only>; + + struct GemmKernel : public KernelType {}; +}; + +} // namespace \ No newline at end of file diff --git a/csrc/cutlass_moe/moe_mm_entry.cu b/csrc/cutlass_moe/moe_mm_entry.cu new file mode 100644 index 000000000000..b2a324b8c401 --- /dev/null +++ b/csrc/cutlass_moe/moe_mm_entry.cu @@ -0,0 +1,83 @@ +#include + +#include +#include + +#include "cutlass_extensions/common.hpp" + +#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90 + +void cutlass_moe_mm_sm90_8_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); + +void cutlass_moe_mm_sm90_16_bit( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides); + +void get_cutlass_moe_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k); + +#endif + +void cutlass_moe_mm(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + std::optional const& a_scales, + std::optional const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& c_strides) { + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 + if (a_tensors.dtype() == torch::kBFloat16 || + a_tensors.dtype() == torch::kFloat16) { + TORCH_CHECK(!a_scales.has_value()); + TORCH_CHECK(!b_scales.has_value()); + 
cutlass_moe_mm_sm90_16_bit(out_tensors, a_tensors, b_tensors, + expert_offsets, problem_sizes, a_strides, + b_strides, c_strides); + } else { + TORCH_CHECK(a_scales.has_value()); + TORCH_CHECK(b_scales.has_value()); + cutlass_moe_mm_sm90_8_bit( + out_tensors, a_tensors, b_tensors, a_scales.value(), b_scales.value(), + expert_offsets, problem_sizes, a_strides, b_strides, c_strides); + } + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, + ". Required capability: 90"); +} + +void get_cutlass_moe_mm_data( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k) { + // This function currently gets compiled only if we have a valid cutlass moe + // mm to run it for. + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 + get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, input_permutation, + output_permutation, num_experts, n, k); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " + "CUDA device capability: ", + version_num, ". Required capability: 90"); +} diff --git a/csrc/ops.h b/csrc/ops.h index fe120af5d568..9b86b99d9def 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -194,12 +194,15 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_moe_mm( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides); +void cutlass_moe_mm(torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, + std::optional const& a_scales, + std::optional const& b_scales, + torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, + torch::Tensor const& a_strides, + torch::Tensor const& b_strides, + torch::Tensor const& c_strides); void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh deleted file mode 100644 index 6c6e89790847..000000000000 --- a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "core/scalar_type.hpp" -#include "cutlass/bfloat16.h" -#include "cutlass/float8.h" - -template -__global__ void get_group_gemm_starts( - int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, - ElementC** out_offsets, ElementAccumulator** a_scales_offsets, - ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, - ElementAB* b_base_as_int, ElementC* out_base_as_int, - ElementAccumulator* a_scales_base_as_int, - ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k, - bool per_act_token, bool per_out_ch) { - int expert_id = threadIdx.x; - - int64_t expert_offset = expert_offsets[expert_id]; - - a_offsets[expert_id] = a_base_as_int + expert_offset * k; - b_offsets[expert_id] = 
b_base_as_int + expert_id * k * n; - out_offsets[expert_id] = out_base_as_int + expert_offset * n; - a_scales_offsets[expert_id] = - a_scales_base_as_int + (per_act_token ? expert_offset : 0); - b_scales_offsets[expert_id] = - b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id); -} - -#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ - else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ - get_group_gemm_starts \ - <<<1, num_experts, 0, stream>>>( \ - static_cast(expert_offsets.data_ptr()), \ - static_cast(a_ptrs.data_ptr()), \ - static_cast(b_ptrs.data_ptr()), \ - static_cast(out_ptrs.data_ptr()), \ - static_cast(a_scales_ptrs.data_ptr()), \ - static_cast(b_scales_ptrs.data_ptr()), \ - static_cast(a_tensors.data_ptr()), \ - static_cast(b_tensors.data_ptr()), \ - static_cast(out_tensors.data_ptr()), \ - static_cast(a_scales.data_ptr()), \ - static_cast(b_scales.data_ptr()), out_tensors.size(1), \ - a_tensors.size(1), per_act_token, per_out_ch); \ - } - -namespace { - -void run_get_group_gemm_starts( - torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, - torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, - torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, - torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, - torch::Tensor& out_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int num_experts = static_cast(expert_offsets.size(0)); - bool per_act_token = a_scales.numel() != 1; - bool per_out_ch = b_scales.numel() != num_experts; - - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); - - if (false) { - } - __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) - __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) - else { - TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); - } -} - -} // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu deleted file mode 100644 index 2b8bc3fb0b26..000000000000 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ /dev/null @@ -1,160 +0,0 @@ -#include - -#include -#include - -#include "cutlass/cutlass.h" -#include "grouped_mm_c3x.cuh" - -using namespace cute; - -namespace { - -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (16, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = - cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; - - using Cutlass3xGemm = - cutlass_3x_group_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M16 { - // M in [1, 16] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = - cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; - - using Cutlass3xGemm = - cutlass_3x_group_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_K8192 { - // K in [8192, inf) - static_assert(std::is_same()); - using KernelSchedule = - 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = - cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; - - using Cutlass3xGemm = - cutlass_3x_group_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_N8192 { - // N in [8192, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = - cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; - - using Cutlass3xGemm = - cutlass_3x_group_gemm; -}; - -template -void run_cutlass_moe_mm_sm90( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { - TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); - TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); - TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); - - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, - "A tensors must be of type float8_e4m3fn."); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, - "B tensors must be of type float8_e4m3fn."); - - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - - using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< - InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< - InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< - InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmDefault = typename sm90_fp8_config_default< - InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - - uint32_t const m = a_tensors.size(0); - uint32_t const n = out_tensors.size(1); - uint32_t const k = a_tensors.size(1); - - if (n >= 8192) { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); - } else if (k >= 8192) { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); - } else if (m <= 16) { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); - } else { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); - } -} - -void dispatch_moe_mm_sm90( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { - if (out_tensors.dtype() == torch::kBFloat16) { - run_cutlass_moe_mm_sm90( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); - } else { - 
run_cutlass_moe_mm_sm90( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); - } -} - -} // namespace - -void cutlass_moe_mm_sm90( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { - dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes, a_strides, b_strides, - c_strides); -} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 54b63894e4cb..925a6207466a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -30,19 +30,6 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); -void cutlass_moe_mm_sm90( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides); - -void get_cutlass_moe_mm_data_caller( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k); - #endif #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 @@ -195,46 +182,6 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, version_num); } -void cutlass_moe_mm( - torch::Tensor& out_tensors, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, - torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { - int32_t version_num = get_sm_version_num(); -#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 - cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes, a_strides, b_strides, - c_strides); - return; -#endif - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, - ". Required capability: 90"); -} - -void get_cutlass_moe_mm_data( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k) { - // This function currently gets compiled only if we have a valid cutlass moe - // mm to run it for. 
- int32_t version_num = get_sm_version_num(); -#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 - get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, input_permutation, - output_permutation, num_experts, n, k); - return; -#endif - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " - "CUDA device capability: ", - version_num, ". Required capability: 90"); -} - void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c9a120976b1c..ad92b213b24b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -398,12 +398,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool"); ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported); - // CUTLASS w8a8 grouped GEMM + // CUTLASS MoE GEMM ops.def( "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, " - " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " - " Tensor problem_sizes, Tensor a_strides, " - " Tensor b_strides, Tensor c_strides) -> ()", + " Tensor? a_scales, Tensor? b_scales, " + " Tensor expert_offsets, Tensor problem_sizes, " + " Tensor a_strides, Tensor b_strides, Tensor c_strides" + ") -> ()", {stride_tag}); ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 975cd418a171..03b323cd02a0 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, fused_topk) from vllm.platforms import current_platform @@ -140,7 +140,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, def slice_experts(): slice_params = [ - "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1", + "w1", "w2", "ab_strides1", "ab_strides2", "c_strides1", "c_strides2", "w1_scale", "w2_scale" ] full_tensors = { @@ -169,7 +169,7 @@ def slice_experts(): out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"]) for kwargs in slice_experts(): - out_tensor = out_tensor + cutlass_moe_fp8(**kwargs) + out_tensor = out_tensor + cutlass_moe(**kwargs) return out_tensor @@ -187,8 +187,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit, kwargs = { 'a': moe_tensors.a, - 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] - 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] + 'w1': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] + 'w2': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'topk_weights': topk_weights, 'topk_ids_': topk_ids, 'ab_strides1': moe_tensors.ab_strides1, @@ -203,7 +203,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, num_experts = moe_tensors.w1.size(0) with_ep = num_local_experts is not None or num_local_experts == num_experts if not with_ep: - return cutlass_moe_fp8(**kwargs) + return cutlass_moe(**kwargs) assert num_local_experts is not None return run_with_expert_maps( @@ -212,6 +212,34 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 
**kwargs) +def run_16_bit(moe_tensors: MOETensors, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_local_experts: Optional[int] = None) -> torch.Tensor: + + kwargs = { + "a": moe_tensors.a, + "w1": moe_tensors.w1.transpose(1, 2), + "w2": moe_tensors.w2.transpose(1, 2), + "topk_weights": topk_weights, + "topk_ids_": topk_ids, + "ab_strides1": moe_tensors.ab_strides1, + "c_strides1": moe_tensors.c_strides1, + "ab_strides2": moe_tensors.ab_strides2, + "c_strides2": moe_tensors.c_strides2 + } + + num_experts = moe_tensors.w1.size(0) + with_ep = num_local_experts is not None or num_local_experts == num_experts + if not with_ep: + return cutlass_moe(**kwargs) + + return run_with_expert_maps( + num_experts, + num_local_experts, # type: ignore[arg-type] + **kwargs) + + @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @@ -251,6 +279,10 @@ def test_cutlass_moe_8_bit_no_graph( cutlass_output = run_8_bit(mt, topk_weights, topk_ids) + #print(triton_output) + #print(cutlass_output) + #print("*") + torch.testing.assert_close(triton_output, cutlass_output, atol=5e-2, @@ -305,12 +337,108 @@ def test_cutlass_moe_8_bit_cuda_graph( graph.replay() torch.cuda.synchronize() + #print(triton_output) + #print(cutlass_output) + #print("*") + torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_16_bit_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + mt = MOETensors.make_moe_tensors(m, k, n, e, dtype=dtype) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + triton_output = fused_experts(mt.a, mt.w1, mt.w2, topk_weights, + topk_ids) + cutlass_output = run_16_bit(mt, topk_weights, topk_ids) + + # print(triton_output) + # print(cutlass_output) + # print("*") + + torch.testing.assert_close(triton_output.view(cutlass_output.shape), + cutlass_output, + atol=2e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_16_bit_cuda_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + mt = MOETensors.make_moe_tensors(m, k, n, e, dtype=dtype) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + triton_output = fused_experts(mt.a, mt.w1, mt.w2, topk_weights, + 
topk_ids) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + cutlass_output = run_16_bit(mt, topk_weights, topk_ids) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + # print(triton_output) + # print(cutlass_output) + # print("*") + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=2e-2, + rtol=1e-2) + + @pytest.mark.parametrize("m", [64]) @pytest.mark.parametrize("n", [1024]) @pytest.mark.parametrize("k", [4096]) @@ -362,3 +490,53 @@ def test_cutlass_moe_8_bit_EP( cutlass_output, atol=5e-2, rtol=1e-2) + + +@pytest.mark.parametrize("m", [64]) +@pytest.mark.parametrize("n", [1024]) +@pytest.mark.parametrize("k", [4096]) +@pytest.mark.parametrize("e", [16]) +@pytest.mark.parametrize("topk", [1, 8]) +@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_16_bit_EP( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + mt = MOETensors.make_moe_tensors(m, k, n, e, dtype=dtype) + + score = torch.randn((m, e), device="cuda", dtype=torch.half) + topk_weights, topk_ids = fused_topk(mt.a, + score, + topk, + renormalize=False) + + # Note that we are using the dequantized versions of the tensors. + # Using a, w1 and w2 directly results in minor output differences. + triton_output = fused_experts(mt.a, mt.w1, mt.w2, topk_weights, + topk_ids) + + assert e % ep_size == 0, "Cannot distribute experts evenly" + cutlass_output = run_16_bit(mt, + topk_weights, + topk_ids, + num_local_experts=e // ep_size) + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 8084d9bf2c2d..3234e11ec1ff 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -516,8 +516,8 @@ def test_cutlass_support_opcheck(): (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, - per_out_ch: bool, use_bias: bool): +def test_cutlass_fp8_moe_gemm(num_experts: int, per_act_token: bool, + per_out_ch: bool, use_bias: bool): # Device and dtype setup device = "cuda" @@ -639,3 +639,89 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, print(c) print("*") torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4) + + +@pytest.mark.parametrize("num_experts", [8, 64]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_fp16_moe_gemm(num_experts: int, dtype: torch.dtype): + + # Device and dtype setup + device = "cuda" + + # Create separate A, B, C tensors for each group + a_tensors = [] + b_tensors = [] + 
baseline_tensors = [] + + expert_offsets = torch.zeros((num_experts + 1), + device=device, + dtype=torch.int32) + + problem_sizes = torch.zeros((num_experts, 3), + device=device, + dtype=torch.int32) + + alignment = 16 + # For variation, each group has dimensions + n_g = alignment * random.randint(1, 64) + k_g = alignment * random.randint(1, 64) + for g in range(num_experts): + m_g = alignment * random.randint(1, 64) + + expert_offsets[g + 1] = expert_offsets[g] + m_g + problem_sizes[g][0] = m_g + problem_sizes[g][1] = n_g + problem_sizes[g][2] = k_g + + # Create group-specific A and B (FP16) and output (FP16/FP32) + a_g = torch.randn((m_g, k_g), device=device, dtype=dtype) + b_g = torch.randn((n_g, k_g), device=device, dtype=dtype).t() + a_tensors.append(a_g) + b_tensors.append(b_g) + + # Compute baseline result for this group + baseline_g = a_g.matmul(b_g) + baseline_tensors.append(baseline_g) + + a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g), + device=device, + dtype=dtype) + b_tensors_stacked = torch.empty((num_experts, n_g, k_g), + device=device, + dtype=dtype) + + for g in range(num_experts): + a_tensors_stacked[expert_offsets[g]:expert_offsets[g + + 1]] = a_tensors[g] + b_tensors_stacked[g] = b_tensors[g].t() + b_tensors_stacked = b_tensors_stacked.transpose(1, 2) + + out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g), + device=device, + dtype=dtype) + + ab_strides = torch.full((num_experts, ), + a_tensors_stacked.stride(0), + device=device, + dtype=torch.int64) + c_strides = torch.full((num_experts, ), + out_tensors_stacked.stride(0), + device=device, + dtype=torch.int64) + + ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked, + b_tensors_stacked, None, None, expert_offsets[:-1], + problem_sizes, ab_strides, ab_strides, c_strides) + + # Validate each group's result against the baseline + for g in range(num_experts): + baseline = baseline_tensors[g] + c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] + print(baseline) + print(c) + print("*") + torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-3) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4c577c1c47e7..16bab0a2de2c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -726,10 +726,11 @@ def get_cutlass_moe_mm_data( def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - b_tensors: torch.Tensor, a_scales: torch.Tensor, - b_scales: torch.Tensor, expert_offsets: torch.Tensor, - problem_sizes: torch.Tensor, a_strides: torch.Tensor, - b_strides: torch.Tensor, c_strides: torch.Tensor): + b_tensors: torch.Tensor, a_scales: Optional[torch.Tensor], + b_scales: Optional[torch.Tensor], + expert_offsets: torch.Tensor, problem_sizes: torch.Tensor, + a_strides: torch.Tensor, b_strides: torch.Tensor, + c_strides: torch.Tensor): """ A single grouped matrix multiplication used in CUTLASS-based fused MoE. The function executes fp8-quantized OUT = AB matrix multiplication. 
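For reference, a minimal sketch (not part of this patch) of driving the updated `ops.cutlass_moe_mm` custom op on the new unquantized path, with `a_scales`/`b_scales` passed as `None`. It mirrors `test_cutlass_fp16_moe_gemm` above; the shapes and per-expert token counts are illustrative, and it assumes a Hopper (SM90) GPU with a build compiled with `ENABLE_CUTLASS_MOE_SM90`:

```python
# Sketch only: run the grouped MoE GEMM without scales (16-bit path).
import torch

from vllm import _custom_ops as ops

device = "cuda"
dtype = torch.bfloat16
num_experts, n, k = 4, 1024, 512
tokens_per_expert = [128, 64, 256, 32]  # 16-aligned, like the test above

# expert_offsets[g] is the row offset of expert g in the stacked A/C tensors;
# problem_sizes[g] holds (m, n, k) for expert g's GEMM.
expert_offsets = torch.zeros((num_experts + 1, ),
                             device=device,
                             dtype=torch.int32)
problem_sizes = torch.zeros((num_experts, 3),
                            device=device,
                            dtype=torch.int32)
for g, m_g in enumerate(tokens_per_expert):
    expert_offsets[g + 1] = expert_offsets[g] + m_g
    problem_sizes[g][0] = m_g
    problem_sizes[g][1] = n
    problem_sizes[g][2] = k

total_m = int(expert_offsets[-1].item())
a = torch.randn((total_m, k), device=device, dtype=dtype)
# Weights are stored as [E, N, K] and passed as a transposed [E, K, N] view,
# matching the tests.
b = torch.randn((num_experts, n, k), device=device, dtype=dtype).transpose(1, 2)
out = torch.zeros((total_m, n), device=device, dtype=dtype)

ab_strides = torch.full((num_experts, ), k, device=device, dtype=torch.int64)
c_strides = torch.full((num_experts, ), n, device=device, dtype=torch.int64)

# a_scales/b_scales are now Optional; passing None selects the 16-bit kernels.
ops.cutlass_moe_mm(out, a, b, None, None, expert_offsets[:-1], problem_sizes,
                   ab_strides, ab_strides, c_strides)
```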
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9829ccdb384f..1159a16c1b52 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -35,8 +35,7 @@ def get_config() -> Optional[Dict[str, Any]]: # import to register the custom ops import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) + from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) @@ -47,5 +46,5 @@ def get_config() -> Optional[Dict[str, Any]]: "fused_experts", "get_config_file_name", "grouped_topk", - "cutlass_moe_fp8", + "cutlass_moe", ] diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 960c7f834857..c9de881ffbc6 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -8,21 +8,20 @@ #TODO make the grouped gemm kernel consistent with scaled gemm kernel -def cutlass_moe_fp8( +def cutlass_moe( a: torch.Tensor, - w1_q: torch.Tensor, - w2_q: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids_: torch.Tensor, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.half, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: @@ -35,74 +34,94 @@ def cutlass_moe_fp8( Parameters: - a (torch.Tensor): The input tensor to the MoE layer. Shape: [M, K] - - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + - w1 (torch.Tensor): The first set of expert weights. Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + - w2 (torch.Tensor): The second set of expert weights. Shape: [num_experts, N, K] (the weights are passed transposed) - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mapping. - ab_strides1 (torch.Tensor): The input and weights strides of the first grouped gemm. - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. - ab_strides2 (torch.Tensor): The input and weights strides of the second grouped gemm. - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. + - w1_scale (Optional[torch.Tensor]): The optional fp32 scale + to dequantize w1. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (Optional[torch.Tensor]): The optional fp32 scale + to dequantize w2. 
+ Shape: [num_experts] or [num_experts, K] - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. Shape: scalar or [M] - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize the intermediate result between the gemms. Shape: scalar or [M] - - out_dtype (torch.Tensor): The output tensor type. - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, - every Rank is responsible for a subset of experts. expert_map is a - mapping from global expert-id to local expert-id. When expert_map[i] - is -1, it means that this Rank is not responsible for global - expert-id i. + every Rank is responsible for some experts. expert_map is a mapping + from global expert-id to local expert-id. When expert_map[i] is -1, + it means that this Rank is not responsible for global expert-id i. - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. Returns: - - torch.Tensor: The fp16 output tensor after applying the MoE layer. + + - torch.Tensor: The output tensor after applying the MoE layer. """ assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" - assert w1_q.dtype == torch.float8_e4m3fn - assert w2_q.dtype == torch.float8_e4m3fn - assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" - assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" - assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ - 0], "Input scale shape mismatch" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1_q.shape[2], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2_q.shape[2], "W2 scale shape mismatch" - assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[ - 0], "w1 scales expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[ - 0], "w2 scales expert number mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - assert ab_strides1.shape[0] == w1_q.shape[ - 0], "AB Strides 1 expert number mismatch" - assert c_strides1.shape[0] == w1_q.shape[ - 0], "C Strides 1 expert number mismatch" - assert ab_strides2.shape[0] == w2_q.shape[ - 0], "AB Strides 2 expert number mismatch" - assert c_strides2.shape[0] == w2_q.shape[ - 0], "C Strides 2 expert number mismatch" - assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - - num_experts = w1_q.size(0) + assert a.shape[1] == w1.shape[1], "Hidden size mismatch w1" + assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2" + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert ab_strides1.shape[0] == w1.shape[0], \ + "AB Strides 1 expert number mismatch" + assert c_strides1.shape[0] == w1.shape[0], \ + "C Strides 1 expert number mismatch" + assert ab_strides2.shape[0] == w2.shape[0], \ + "AB Strides 2 expert number mismatch" + assert c_strides2.shape[0] == w2.shape[0], \ + "C Strides 2 expert number mismatch" + + assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" + assert w1.dtype in [torch.float8_e4m3fn, torch.half,torch.bfloat16], \ + "Invalid weight type" + assert w1.dtype == w2.dtype, "Weights type mismatch" + + if w1.dtype in [torch.half, torch.bfloat16]: + assert w1.dtype == a.dtype, \ + 
"Unquantized input and weights type mismatch" + assert w1_scale is None and w2_scale is None \ + and a1_scale is None and a2_scale is None, \ + "Received scales for unquantized input type" + elif w1.dtype == torch.float8_e4m3fn: + assert w1_scale is not None and w2_scale is not None, \ + "Missing scales for quantized input type" + + if w1_scale is not None: + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 \ + or w1_scale.shape[1] == w1.shape[2], "W1 scale shape mismatch" + assert w1.shape[0] == w1_scale.shape[0], \ + "w1 scales expert number mismatch" + if w2_scale is not None: + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 \ + or w2_scale.shape[1] == w2.shape[2], "W2 scale shape mismatch" + assert w2.shape[0] == w2_scale.shape[0], \ + "w2 scales expert number mismatch" + if a1_scale is not None: + assert a1_scale.dim() == 0 or a1_scale.shape[0] == 1 \ + or a1_scale.shape[0] == a.shape[0], "Input scale shape mismatch" + if a2_scale is not None: + assert a1_scale is None or a2_scale.shape == a1_scale.shape, \ + "Intermediate scale shape mismatch" + + is_quantized = w1.dtype == torch.float8_e4m3fn + + device = a.device + num_experts = w1.size(0) m = a.size(0) - k = w1_q.size(1) - n = w2_q.size(1) + k = w1.size(1) + n = w2.size(1) + out_dtype = a.dtype local_topk_ids = topk_ids_ if expert_map is not None: @@ -112,17 +131,17 @@ def cutlass_moe_fp8( topk = local_topk_ids.size(1) - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) if apply_router_weight_on_input: assert topk == 1, \ "apply_router_weight_on_input is only implemented for topk=1" # TODO: this only works for topK=1, will need to update for topK>1 a = a * topk_weights.to(out_dtype) - a_q, a1_scale = ops.scaled_fp8_quant( - a, a1_scale, use_per_token_if_dynamic=per_act_token) - device = a_q.device + if is_quantized: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + a, a1_scale = ops.scaled_fp8_quant( + a, a1_scale, use_per_token_if_dynamic=per_act_token) expert_offsets = torch.empty((num_experts + 1), dtype=torch.int32, @@ -154,23 +173,30 @@ def cutlass_moe_fp8( problem_sizes2, a_map, c_map, num_experts, n, k) - rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale + if is_quantized: + rep_a = a.view(dtype=torch.uint8)[a_map].view(dtype=a.dtype) + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale + else: + rep_a = a[a_map] + rep_a1_scales = None c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype) - ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, + ops.cutlass_moe_mm(c1, rep_a, w1, rep_a1_scales, w1_scale, expert_offsets[:-1], problem_sizes1, ab_strides1, ab_strides1, c_strides1) intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) torch.ops._C.silu_and_mul(intermediate, c1) - intemediate_q, a2_scale = ops.scaled_fp8_quant( - intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) + if is_quantized: + rep_a = a.view(dtype=torch.uint8)[a_map].view(dtype=a.dtype) + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale + intermediate, a2_scale = ops.scaled_fp8_quant( + intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) - ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, + ops.cutlass_moe_mm(c2, intermediate, w2, 
                        expert_offsets[:-1], problem_sizes2, ab_strides2,
                        ab_strides2, c_strides2)
     # Gather tokens
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index a209715ede77..4936d6c527c0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1486,8 +1486,8 @@ def fused_moe(
         Defaults to False.
     - global_num_experts (int): The total number of experts in the global
         expert space.
-    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
-        from the global expert space to the local expert space of the expert
+    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
+        from the global expert space to the local expert space of the expert
         parallel shard.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 43fb311289fd..7d51260dd10f 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -24,9 +24,11 @@ from vllm.utils import direct_register_custom_op
 if current_platform.is_cuda_alike():
+    from .cutlass_moe import cutlass_moe
     from .fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
+    cutlass_moe = None  # type: ignore
 if current_platform.is_tpu():
     # the iterative moe implementation is used until the moe_pallas is fixed
     from .moe_torch_iterative import fused_moe as fused_moe_pallas
@@ -72,8 +74,19 @@ def apply(
         raise NotImplementedError
-@CustomOp.register("unquantized_fused_moe")
-class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
+    """MoE method without quantization."""
+
+    @staticmethod
+    def get_moe_method(activation: str) -> "UnquantizedFusedMoEMethod":
+        if UnquantizedFusedCutlassMoEMethod.check_supported(activation):
+            return UnquantizedFusedCutlassMoEMethod()
+        else:
+            return UnquantizedFusedTritonMoEMethod()
+
+
+@CustomOp.register("unquantized_fused_triton_moe")
+class UnquantizedFusedTritonMoEMethod(FusedMoEMethodBase, CustomOp):
+    """MoE method without quantization."""
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -329,6 +342,129 @@ def forward_tpu(
     forward_native = forward_tpu if current_platform.is_tpu() else forward_cuda
+@CustomOp.register("unquantized_fused_cutlass_moe")
+class UnquantizedFusedCutlassMoEMethod(FusedMoEMethodBase, CustomOp):
+    """CUTLASS MoE method without quantization."""
+
+    @staticmethod
+    def check_supported(activation: str, error: bool = True) -> bool:
+        required_capability = 90
+        capability_tuple = current_platform.get_device_capability()
+
+        if capability_tuple is not None:
+            capability = capability_tuple.to_int()
+            arch_supported = (capability == required_capability
+                              and not current_platform.is_cpu()
+                              and not current_platform.is_rocm())
+            functions_supported = activation == "silu"
+            if not arch_supported:
+                warn_msg = (
+                    "UnquantizedFusedCutlassMoEMethod is not supported "
+                    "for the current device. Required "
+                    f"GPU with capability: {required_capability}. Current "
+                    f"capability: {capability}.")
+                logger.warning(warn_msg)
+            if not functions_supported:
+                logger.warning(
+                    "UnquantizedFusedCutlassMoEMethod is not supported "
+                    "for the required functionality. 
" + "Required activation: silu, expert map not supported.") + return arch_supported and functions_supported + else: + return False + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + device = layer.w13_weight.device + self.ab_strides1 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + self.c_strides1 = torch.full((num_experts, ), + 2 * intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.ab_strides2 = torch.full((num_experts, ), + intermediate_size_per_partition, + device=device, + dtype=torch.int64) + self.c_strides2 = torch.full((num_experts, ), + hidden_size, + device=device, + dtype=torch.int64) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + # TODO half() + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu" + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return cutlass_moe( + x, + layer.w13_weight.transpose(1, 2), + layer.w2_weight.transpose(1, 2), + topk_weights, + topk_ids, + self.ab_strides1, + self.c_strides1, + self.ab_strides2, + self.c_strides2, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + def determine_expert_map( ep_size: int, ep_rank: int, global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]: @@ -498,7 +634,7 @@ def __init__( # for heuristic purposes, so it must be initialized first. 
if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod()) + UnquantizedFusedMoEMethod.get_moe_method(self.activation)) else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 721e36af2b28..b79c79988d37 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -523,25 +523,24 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8 + from vllm.model_executor.layers.fused_moe import cutlass_moe - return cutlass_moe_fp8( + return cutlass_moe( x, layer.w13_weight.transpose(1, 2), layer.w2_weight.transpose(1, 2), - layer.w13_weight_scale, - layer.w2_weight_scale, topk_weights, topk_ids, self.ab_strides1, self.c_strides1, self.ab_strides2, self.c_strides2, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - out_dtype=x.dtype, - expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, )
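For reference, a minimal usage sketch (not part of this patch) of the renamed `cutlass_moe` wrapper on the new unquantized 16-bit path, with the fp8 call shown in a comment for comparison. Shapes are illustrative, the strides follow the layout used in the tests and benchmark above, and the `set_current_vllm_config(...)` setup the tests wrap around these calls is omitted for brevity:

```python
# Sketch only: the same cutlass_moe entry point now serves both the fp8 path
# (scales passed as keyword arguments) and the unquantized 16-bit path
# (no scales; the output dtype is taken from the activations).
import torch

from vllm.model_executor.layers.fused_moe import cutlass_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

e, m, n, k, topk = 8, 32, 1024, 512, 2
device, dtype = "cuda", torch.bfloat16

a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10

score = torch.randn((m, e), device=device, dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

# Per-expert strides for the two grouped GEMMs, as in MOETensors above.
ab_strides1 = torch.full((e, ), k, device=device, dtype=torch.int64)
c_strides1 = torch.full((e, ), 2 * n, device=device, dtype=torch.int64)
ab_strides2 = torch.full((e, ), n, device=device, dtype=torch.int64)
c_strides2 = torch.full((e, ), k, device=device, dtype=torch.int64)

# Unquantized experts: weights are passed transposed, no scales needed.
out = cutlass_moe(a, w1.transpose(1, 2), w2.transpose(1, 2), topk_weights,
                  topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2)

# The fp8 path keeps working through the same function, e.g.:
#   cutlass_moe(a, w1_q.transpose(1, 2), w2_q.transpose(1, 2), topk_weights,
#               topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
#               w1_scale=w1_scale, w2_scale=w2_scale,
#               a1_scale=a1_scale, a2_scale=a2_scale)
```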