diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py
index f1af2f3d1573..4e0fc3e8541a 100644
--- a/python/tvm/relax/backend/cuda/flashinfer.py
+++ b/python/tvm/relax/backend/cuda/flashinfer.py
@@ -116,7 +116,7 @@ def get_object_file_path(src: Path) -> Path:
 
     # Determine compute version
     compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split("."))
-    if compute_version in ["90"]:
+    if compute_version in ["90", "100"]:
         compute_version += "a"
     cuda_cflags += [
         "-gencode",
@@ -488,3 +488,97 @@ def gen_sampling_module(target: Target, num_threads: int = 8):
     object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
     modules = _load_flashinfer_modules(object_files)
     return modules
+
+
+def gen_grouped_gemm_module(
+    dtype_a: str,
+    dtype_b: str,
+    dtype_out: str,
+    scale_granularity_m: int,
+    scale_granularity_n: int,
+    scale_granularity_k: int,
+    scale_major_mode: str,
+    mma_sm: int,
+    target: Target,
+    num_threads: int = 8,
+) -> List[tvm.runtime.Module]:
+    """Generate a FlashInfer module for FP8 grouped GEMM.
+
+    Parameters
+    ----------
+    dtype_a : str
+        The data type of matrix A (e.g., "float8_e4m3fn").
+    dtype_b : str
+        The data type of matrix B (e.g., "float8_e4m3fn").
+    dtype_out : str
+        The data type of the output matrix (e.g., "bfloat16").
+    scale_granularity_m : int
+        The scaling granularity in the M dimension.
+    scale_granularity_n : int
+        The scaling granularity in the N dimension.
+    scale_granularity_k : int
+        The scaling granularity in the K dimension.
+    scale_major_mode : str
+        The scale storage mode ("K" or "MN").
+    mma_sm : int
+        The MMA scheduling mode (1 or 2).
+    target : Target
+        The target device to compile for.
+    num_threads : int
+        The number of threads to use for compilation.
+
+    Returns
+    -------
+    List[tvm.runtime.Module]
+        A list of compiled static library modules for FlashInfer FP8 grouped GEMM kernels.
+
+    Note
+    ----
+    When applying grouped GEMM to A: (total_m, k), B: (batch_size, n, k), and
+    m_indptr: (batch_size, ), every entry of m_indptr is required to be a multiple of 4.
+    """
+    try:
+        from flashinfer.jit import (  # pylint: disable=import-outside-toplevel
+            gen_grouped_gemm_fp8_tvm_binding,
+            get_grouped_gemm_fp8_uri,
+        )
+    except ImportError:
+        raise ImportError(
+            "FlashInfer is not installed. Please follow instructions "
+            "in https://docs.flashinfer.ai to install FlashInfer."
+        )
+    try:
+        import torch  # pylint: disable=import-outside-toplevel
+    except ImportError:
+        raise ImportError("PyTorch is not installed. Please install PyTorch to use FlashInfer.")
+
+    torch_dtype_a = getattr(torch, dtype_a)
+    torch_dtype_b = getattr(torch, dtype_b)
+    torch_dtype_out = getattr(torch, dtype_out)
+
+    uri = get_grouped_gemm_fp8_uri(
+        dtype_a=torch_dtype_a,
+        dtype_b=torch_dtype_b,
+        dtype_out=torch_dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        mma_sm=mma_sm,
+    )
+
+    uri, source_paths = gen_grouped_gemm_fp8_tvm_binding(
+        uri=uri,
+        dtype_a=torch_dtype_a,
+        dtype_b=torch_dtype_b,
+        dtype_out=torch_dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        mma_sm=mma_sm,
+    )
+
+    object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
+    modules = _load_flashinfer_modules(object_files)
+    return modules
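For reference, a small illustrative sketch (not part of the patch) of how the new entry point is meant to be used. It mirrors the test added below: the entry-point name "grouped_gemm_fp8_run" and the calling convention come from that test, not from separate API documentation, so treat the details as assumptions.

    # Illustration only: compile the FP8 grouped GEMM kernels and build m_indptr.
    import numpy as np
    import tvm
    from tvm.relax.backend.cuda import flashinfer

    device = tvm.cuda(0)
    target = tvm.target.Target.from_device(device)
    modules = flashinfer.gen_grouped_gemm_module(
        dtype_a="float8_e4m3fn",
        dtype_b="float8_e4m3fn",
        dtype_out="bfloat16",
        scale_granularity_m=1,
        scale_granularity_n=128,
        scale_granularity_k=128,
        scale_major_mode="K",
        mma_sm=1,
        target=target,
    )

    # m_indptr is the prefix sum of per-group row counts of A; per the docstring
    # note above, each entry must be a multiple of 4.
    m_sizes = [128, 256, 192, 320]
    m_indptr = [0] + list(np.cumsum(m_sizes))  # [0, 128, 384, 576, 896]
    assert all(int(m) % 4 == 0 for m in m_indptr)

The compiled modules are then linked into a single library and the kernel is invoked as mod["grouped_gemm_fp8_run"](...), exactly as the test below demonstrates.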
diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py
new file mode 100644
index 000000000000..8333e4b2d66b
--- /dev/null
+++ b/tests/python/relax/test_group_gemm_flashinfer.py
@@ -0,0 +1,496 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test for FlashInfer GroupedGemm TVM integration"""
+
+import math
+import numpy as np
+import pytest
+import torch
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.contrib import utils
+from tvm.relax.backend.cuda import flashinfer
+
+DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024
+fp8_dtype = "float8_e4m3fn"
+
+
+###########################################
+################# Helpers #################
+###########################################
+def has_flashinfer():
+    """Check if FlashInfer is available"""
+    try:
+        from tvm.relax.backend.cuda import (  # pylint: disable=import-outside-toplevel
+            flashinfer,
+        )
+
+        return True
+    except ImportError:
+        return False
+
+
+def has_cutlass():
+    """Check if CUTLASS is available for SM90+ operations"""
+    if not tvm.get_global_func("device_api.cuda", True):
+        return False
+    try:
+        import pynvml  # pylint: disable=import-outside-toplevel
+
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+        return major >= 9  # SM90+
+    except:
+        return False
+
+
+def calc_diff(x: np.ndarray, y: np.ndarray):
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return 1 - sim
+
+
+def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode):
+    """
+    Quantizes a 2D or 3D tensor to FP8.
+
+    Args:
+        x (torch.Tensor): The 2D or 3D input tensor.
+        scale_shape (tuple): The shape of the scale tensor.
+        tile_shape (tuple): The shape of the tiles.
+        scale_major_mode (str): The tiling order, "K" for row-major like,
+            or another value for column-major like.
+
+    Returns:
+        tuple: A tuple containing the quantized FP8 tensor and the
+            calculated float32 scales.
+    """
+    from einops import rearrange, reduce, repeat  # pylint: disable=import-outside-toplevel
+
+    # 1. Assertions and Initial Setup
+    ndim = x.ndim
+    assert ndim == len(scale_shape) == len(tile_shape)
+
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_amax = torch.tensor(fp8_info.max, device=x.device, dtype=torch.float32)
+
+    # 2. Tiling and Scale Calculation
+    if ndim == 2:
+        s0, s1 = scale_shape
+        t0, t1 = tile_shape
+        if scale_major_mode == "K":
+            # Tile x and find the max absolute value in each tile
+            x_tiled = rearrange(x, "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1)
+            abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4)
+            x_scale = abs_max / fp8_amax
+            x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+            # Broadcast scales back to the original tensor shape
+            scales_repeated = repeat(x_scale, "s0 s1 -> (s0 t0) (s1 t1)", t0=t0, t1=t1)
+        else:
+            # Handle column-major tiling
+            x_tiled = rearrange(x, "(s1 t0) (s0 t1) -> s0 s1 t0 t1", s0=s0, s1=s1)
+            abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4)
+            x_scale = abs_max / fp8_amax
+            x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+            # Permute scale axes before repeating to match layout
+            scales_permuted = rearrange(x_scale, "s0 s1 -> s1 s0")
+            scales_repeated = repeat(scales_permuted, "s1 s0 -> (s1 t0) (s0 t1)", t0=t0, t1=t1)
+
+    elif ndim == 3:
+        s0, s1, s2 = scale_shape
+        t0, t1, t2 = tile_shape
+        if scale_major_mode == "K":
+            # Tile x and find the max absolute value in each tile
+            x_tiled = rearrange(
+                x, "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2
+            )
+            abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4)
+            x_scale = abs_max / fp8_amax
+            x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+
+            # Broadcast scales back to the original tensor shape
+            scales_repeated = repeat(
+                x_scale, "s0 s1 s2 -> (s0 t0) (s1 t1) (s2 t2)", t0=t0, t1=t1, t2=t2
+            )
+        else:
+            # Handle layout where the last two axes are swapped
+            x_tiled = rearrange(
+                x, "(s0 t0) (s2 t1) (s1 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2
+            )
+            abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4)
+            x_scale = abs_max / fp8_amax
+            x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs())))
+            # Permute scale axes before repeating to match layout
+            scales_permuted = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1")
+            scales_repeated = repeat(
+                scales_permuted,
+                "s0 s2 s1 -> (s0 t0) (s2 t1) (s1 t2)",
+                t0=t0,
+                t1=t1,
+                t2=t2,
+            )
+    # 3. Final Quantization
+    # Divide the original tensor by the broadcasted scales
+    x_fp32 = x / (scales_repeated + 1e-8)
+
+    # Convert the result to the target FP8 format
+    x_fp8 = x_fp32.to(torch.float8_e4m3fn)
+
+    return x_fp8, x_scale
+
+
+def dequantize_fp8(x, x_scale, scale_major_mode):
+    """
+    Dequantizes a 2D or 3D FP8 tensor back to float32 using its per-tile scales.
+
+    Args:
+        x (torch.Tensor): The 2D or 3D quantized FP8 tensor.
+        x_scale (torch.Tensor): The per-tile scale tensor produced by quantize_fp8.
+        scale_major_mode (str): The tiling order, "K" for row-major like,
+            or another value for column-major like.
+
+    Returns:
+        torch.Tensor: The dequantized float32 tensor.
+    """
+    from einops import rearrange  # pylint: disable=import-outside-toplevel
+
+    # 1. Assertions and Initial Setup
+    ndim = x.ndim
+    assert ndim == len(x_scale.shape)
+
+    # 2. Reshape into tiles and apply the scales
+    if ndim == 2:
+        if scale_major_mode == "K":
+            s0, s1 = x_scale.shape
+        else:
+            s1, s0 = x_scale.shape
+        x = rearrange(x.to(torch.float32), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1)
+        if scale_major_mode == "K":
+            x_scale = rearrange(x_scale, "s0 s1 -> s0 s1 1 1")
+        else:
+            x_scale = rearrange(x_scale, "s0 s1 -> s1 s0 1 1")
+        out = rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)")
+    elif ndim == 3:
+        if scale_major_mode == "K":
+            s0, s1, s2 = x_scale.shape
+        else:
+            s0, s2, s1 = x_scale.shape
+        x = rearrange(
+            x.to(torch.float32),
+            "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2",
+            s0=s0,
+            s1=s1,
+            s2=s2,
+        )
+        if scale_major_mode == "K":
+            x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1")
+        else:
+            x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1")
+        out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)")
+
+    return out
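For orientation (not part of the patch): quantize_fp8 above rounds each per-tile scale up to a power of two, i.e. scale = 2**ceil(log2(amax / 448)), where 448 is torch.finfo(torch.float8_e4m3fn).max and amax is the per-tile absolute maximum clamped to at least 1e-4. A tiny standalone check of that rule, with an assumed example value:

    # Standalone sketch of the power-of-two scale rule used in quantize_fp8.
    import math

    fp8_amax = 448.0                  # torch.finfo(torch.float8_e4m3fn).max
    amax = max(3.7, 1e-4)             # example per-tile absolute maximum (clamped)
    scale = 2.0 ** math.ceil(math.log2(amax / fp8_amax))
    assert scale == 2.0**-6           # 0.015625
    assert amax / scale <= fp8_amax   # scaled values stay within the e4m3 range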
+
+
+###########################################
+########## Reference generation ###########
+###########################################
+def compute_reference_grouped_gemm(
+    a_fp32: torch.Tensor,  # (total_m, k)
+    b_fp32: torch.Tensor,  # (batch_size, n, k)
+    m_indptr: torch.Tensor,
+    dtype_out: str,  # output is (total_m, n)
+):
+    """Compute the reference result with PyTorch, using the original FP32 tensors."""
+    total_m, k = a_fp32.shape
+    batch_size, n, k2 = b_fp32.shape
+    assert k == k2
+
+    # Perform the grouped GEMM computation directly on the original FP32 data
+    results = []
+
+    for i in range(batch_size):
+        start_m = m_indptr[i].item()
+        end_m = m_indptr[i + 1].item()
+
+        # Extract this group's rows of A and its B matrix
+        a_group = a_fp32[start_m:end_m, :]  # [m_sizes[i], k]
+        b_group = b_fp32[i]
+
+        # Multiply with the group's B matrix
+        result_group = torch.mm(a_group, b_group.T)  # [m_sizes[i], n]
+        results.append(result_group)
+
+    result_fp32 = torch.cat(results, dim=0)
+
+    # Convert to the output dtype
+    if dtype_out == "bfloat16":
+        result = result_fp32.to(torch.bfloat16)
+    elif dtype_out == "float16":
+        result = result_fp32.to(torch.float16)
+    else:
+        result = result_fp32
+
+    return result
+
+
+###########################################
+########### Test data generation ##########
+###########################################
+def generate_test_data(
+    m_sizes: list,
+    batch_size: int,
+    n: int,
+    k: int,
+    dtype_a: str,
+    dtype_b: str,
+    dtype_out: str,
+    scale_granularity_m: int,
+    scale_granularity_n: int,
+    scale_granularity_k: int,
+    scale_major_mode: str,
+    device: tvm.runtime.Device,
+):
+    """Generate test data for grouped GEMM operations"""
+    assert batch_size == len(
+        m_sizes
+    ), f"batch_size ({batch_size}) must equal len(m_sizes) ({len(m_sizes)})"
+
+    torch_device = torch.device(f"cuda:{device.index}")
+
+    cum_m = [0] + list(np.cumsum(m_sizes))
+    total_m = cum_m[-1]
+
+    # Generate A and B (asserted to be FP8) as random FP32 data first, then quantize
+    assert dtype_a == "float8_e4m3fn"
+    a_fp32 = torch.randn(total_m, k, device=torch_device, dtype=torch.float32)
+
+    assert dtype_b == "float8_e4m3fn"
+    b_fp32 = torch.randn(batch_size, n, k, device=torch_device, dtype=torch.float32) / math.sqrt(k)
+
+    if scale_major_mode == "K":  # K mode
+        scale_a_shape = (total_m // scale_granularity_m, k // scale_granularity_k)
+        scale_b_shape = (batch_size, n // scale_granularity_n, k // scale_granularity_k)
+    else:  # MN mode
+        scale_a_shape = (k // scale_granularity_k, total_m // scale_granularity_m)
+        scale_b_shape = (batch_size, k // scale_granularity_k, n // scale_granularity_n)
+
+    tile_a_shape = (scale_granularity_m, scale_granularity_k)
+    tile_b_shape = (1, scale_granularity_n, scale_granularity_k)
+
+    # Quantize A and B
+    a_quantized, scale_a = quantize_fp8(a_fp32, scale_a_shape, tile_a_shape, scale_major_mode)
+    b_quantized, scale_b = quantize_fp8(b_fp32, scale_b_shape, tile_b_shape, scale_major_mode)
+
+    if dtype_a == "float8_e4m3fn":
+        a_tvm = tvm.runtime.tensor(
+            a_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device
+        )
+    else:
+        a_tvm = tvm.runtime.from_dlpack(a_quantized)
+
+    if dtype_b == "float8_e4m3fn":
+        b_tvm = tvm.runtime.tensor(
+            b_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device
+        )
+    else:
+        b_tvm = tvm.runtime.from_dlpack(b_quantized)
+
+    scale_a_tvm = tvm.runtime.from_dlpack(scale_a)
+    scale_b_tvm = tvm.runtime.from_dlpack(scale_b)
+
+    # Create m_indptr for the grouped operation
+    m_indptr = torch.tensor(cum_m, device=torch_device, dtype=torch.int32)
+    m_indptr_tvm = tvm.runtime.tensor(m_indptr.cpu().numpy(), device)
+
+    return {
+        "a": a_tvm,
+        "b": b_tvm,
+        "torch_a": a_fp32,
+        "torch_b": b_fp32,
+        "scale_a": scale_a_tvm,
+        "scale_b": scale_b_tvm,
+        "m_indptr": m_indptr_tvm,
+        "m_sizes": m_sizes,
+        "n": n,
+        "k": k,
+        "total_m": total_m,
+        "torch_scale_a": scale_a,
+        "torch_scale_b": scale_b,
+        "torch_m_indptr": m_indptr,
+    }
+
+
+###########################################
+############### Test driver ###############
+###########################################
+@pytest.mark.skipif(not has_flashinfer(), reason="FlashInfer not available")
+@pytest.mark.skipif(not has_cutlass(), reason="CUTLASS SM90+ not available")
+@pytest.mark.parametrize(
+    "dtype_a,dtype_b,dtype_out",
+    [
+        ("float8_e4m3fn", "float8_e4m3fn", "bfloat16"),
+        ("float8_e4m3fn", "float8_e4m3fn", "float16"),
+    ],
+)
+@pytest.mark.parametrize(
+    "scale_granularity_m,scale_granularity_n,scale_granularity_k",
+    [
+        (1, 128, 128),  # Row-wise A, block-wise B
+    ],
+)
+@pytest.mark.parametrize("scale_major_mode", ["K", "MN"])
+@pytest.mark.parametrize("mma_sm", [1, 2])
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {"batch_size": 4, "m_sizes": [128, 256, 192, 320], "n": 512, "k": 1024},
+        {"batch_size": 2, "m_sizes": [64, 128], "n": 256, "k": 512},
+        {"batch_size": 3, "m_sizes": [256, 256, 128], "n": 768, "k": 768},
+        {"batch_size": 2, "m_sizes": [20, 36], "n": 768, "k": 768},
+    ],
+)
+def test_grouped_gemm_correctness(
+    dtype_a,
+    dtype_b,
+    dtype_out,
+    scale_granularity_m,
+    scale_granularity_n,
+    scale_granularity_k,
+    scale_major_mode,
+    mma_sm,
+    test_case,
+):
+    """Test correctness of GroupedGemm operations"""
+    device = tvm.cuda(0)
+    target = tvm.target.Target.from_device(device)
+
+    def _load_module(name: str, static_modules):
+        """Helper function to load compiled modules."""
+        assert len(static_modules) > 0
+        if len(static_modules) == 1:
+            return static_modules[0]
+        static_mod = static_modules[0]
+        for mod in static_modules[1:]:
+            static_mod.import_module(mod)
+        temp = tvm.contrib.utils.tempdir()
+        mod_path = temp.relpath(f"{name}.so")
+        static_mod.export_library(mod_path)
+        return tvm.runtime.load_module(mod_path)
+
+    # Generate the module
+    modules = relax.backend.cuda.flashinfer.gen_grouped_gemm_module(
+        dtype_a=dtype_a,
+        dtype_b=dtype_b,
+        dtype_out=dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        mma_sm=mma_sm,
+        target=target,
+        num_threads=4,
+    )
+
+    # Load the module
+    mod = _load_module("flashinfer_grouped_gemm", modules)
+    grouped_gemm_fn = mod["grouped_gemm_fp8_run"]
+
+    # Generate test data
+    test_data = generate_test_data(
+        batch_size=test_case["batch_size"],
+        m_sizes=test_case["m_sizes"],
+        n=test_case["n"],
+        k=test_case["k"],
+        dtype_a=dtype_a,
+        dtype_b=dtype_b,
+        dtype_out=dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        device=device,
+    )
+
+    # Prepare the output buffer
+    output_shape = (test_data["total_m"], test_data["n"])
+    if dtype_out == "bfloat16":
+        output = tvm.runtime.empty(output_shape, dtype="bfloat16", device=device)
+    elif dtype_out == "float16":
+        output = tvm.runtime.empty(output_shape, dtype="float16", device=device)
+    else:
+        output = tvm.runtime.empty(output_shape, dtype="float32", device=device)
+
+    # Create workspace buffers (required by the interface)
+    int_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="int32", device=device)
+    float_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="float32", device=device)
+
+    grouped_gemm_fn(
+        int_workspace,  # int_workspace_buffer
+        float_workspace,  # float_workspace_buffer
+        test_data["a"],  # A
+        test_data["b"],  # B
+        test_data["scale_a"],  # SFA
+        test_data["scale_b"],  # SFB
+        output,  # D
+        test_data["m_indptr"],  # m_indptr
+        test_data["n"],  # n (scalar)
+        test_data["k"],  # k (scalar)
+        None,  # cuda_stream (use default stream)
+    )
+
+    # Compute the reference result
+    reference = compute_reference_grouped_gemm(
+        test_data["torch_a"],
+        test_data["torch_b"],
+        test_data["torch_m_indptr"],
+        dtype_out,
+    )
+
+    # Convert the TVM output to PyTorch for comparison
+    output_torch = torch.as_tensor(output, device=test_data["torch_a"].device)
+
+    # Check that the shapes match
+    assert (
+        output_torch.shape == reference.shape
+    ), f"Shape mismatch: got {output_torch.shape}, expected {reference.shape}"
+
+    # Compare results with a relative-similarity metric suited to FP8 quantization error
+    diff = calc_diff(output_torch.cpu().double().numpy(), reference.cpu().double().numpy())
+    assert diff < 1e-3, f"diff too large {diff}"
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
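A closing note (not part of the patch) on the accuracy check: calc_diff returns 1 - 2*sum(x*y) / (sum(x*x) + sum(y*y)), a scale-sensitive similarity, which is why the test uses a single 1e-3 threshold for both bfloat16 and float16 outputs instead of per-dtype rtol/atol. A self-contained sketch of its behavior:

    # Sketch of the metric used above: 0 for identical tensors, small for tensors
    # that differ only by quantization-level noise.
    import numpy as np

    def calc_diff(x, y):  # same definition as in the test file
        return 1 - 2 * (x * y).sum() / ((x * x + y * y).sum())

    x = np.random.rand(64, 64)
    assert calc_diff(x, x) < 1e-12
    assert calc_diff(x, 1.01 * x) < 1e-3  # ~1% uniform relative error passes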