From a05984ddf3338036402a9bc86696f46950ba0044 Mon Sep 17 00:00:00 2001 From: Anrui Liu Date: Wed, 27 Aug 2025 00:10:59 -0400 Subject: [PATCH 1/5] [FlashInfer] Add gen_grouped_gemm_fp8 tvm binding --- 3rdparty/libbacktrace | 1 + python/tvm/relax/backend/cuda/flashinfer.py | 96 +++- .../relax/test_group_gemm_flashinfer.py | 500 ++++++++++++++++++ 3 files changed, 596 insertions(+), 1 deletion(-) create mode 160000 3rdparty/libbacktrace create mode 100644 tests/python/relax/test_group_gemm_flashinfer.py diff --git a/3rdparty/libbacktrace b/3rdparty/libbacktrace new file mode 160000 index 000000000000..08f7c7e69f8e --- /dev/null +++ b/3rdparty/libbacktrace @@ -0,0 +1 @@ +Subproject commit 08f7c7e69f8ea61a0c4151359bc8023be8e9217b diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index f1af2f3d1573..7bad25122101 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -20,6 +20,7 @@ import json import os import subprocess +from typing import Optional, Tuple from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import List @@ -116,7 +117,7 @@ def get_object_file_path(src: Path) -> Path: # Determine compute version compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split(".")) - if compute_version in ["90"]: + if compute_version in ["90", "100"]: compute_version += "a" cuda_cflags += [ "-gencode", @@ -488,3 +489,96 @@ def gen_sampling_module(target: Target, num_threads: int = 8): object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) modules = _load_flashinfer_modules(object_files) return modules + +def gen_grouped_gemm_module( + dtype_a: str, + dtype_b: str, + dtype_out: str, + scale_granularity_m: int, + scale_granularity_n: int, + scale_granularity_k: int, + scale_major_mode: str, + mma_sm: int, + target: Target, + num_threads: int = 8, +) -> List[tvm.runtime.Module]: + """Generate a FlashInfer module for FP8 grouped GEMM. + + Parameters + ---------- + dtype_a : str + The data type of matrix A (e.g., "float8_e4m3fn"). + dtype_b : str + The data type of matrix B (e.g., "float8_e4m3fn"). + dtype_out : str + The data type of the output matrix (e.g., "bfloat16"). + scale_granularity_m : int + The scaling granularity in the M dimension. + scale_granularity_n : int + The scaling granularity in the N dimension. + scale_granularity_k : int + The scaling granularity in the K dimension. + scale_major_mode : str + The scale storage mode ("K" or "MN"). + mma_sm : int + The MMA scheduling mode (1 or 2). + target : Target + The target device to compile for. + num_threads : int + The number of threads to use for compilation. + + Returns + ------- + List[tvm.runtime.Module] + A list of compiled static library modules for FlashInfer FP8 grouped GEMM kernels. + + Note + _____ + when apply grouped gemm on A: (total_m, k), B: (batch_size, n, k), m_indptr: (batch_size, ) + requires all m in m_indptr to be multiple of 4 + """ + try: + from flashinfer.jit import ( + gen_grouped_gemm_fp8_tvm_binding, + get_grouped_gemm_fp8_uri, + ) + except ImportError: + raise ImportError( + "FlashInfer is not installed. Please follow instructions " + "in https://docs.flashinfer.ai to install FlashInfer." + ) + try: + import torch + except ImportError: + raise ImportError("PyTorch is not installed. 
Please install PyTorch to use FlashInfer.") + + torch_dtype_a = getattr(torch, dtype_a) + torch_dtype_b = getattr(torch, dtype_b) + torch_dtype_out = getattr(torch, dtype_out) + + uri = get_grouped_gemm_fp8_uri( + dtype_a=torch_dtype_a, + dtype_b=torch_dtype_b, + dtype_out=torch_dtype_out, + scale_granularity_m=scale_granularity_m, + scale_granularity_n=scale_granularity_n, + scale_granularity_k=scale_granularity_k, + scale_major_mode=scale_major_mode, + mma_sm=mma_sm, + ) + + uri, source_paths = gen_grouped_gemm_fp8_tvm_binding( + uri=uri, + dtype_a=torch_dtype_a, + dtype_b=torch_dtype_b, + dtype_out=torch_dtype_out, + scale_granularity_m=scale_granularity_m, + scale_granularity_n=scale_granularity_n, + scale_granularity_k=scale_granularity_k, + scale_major_mode=scale_major_mode, + mma_sm=mma_sm, + ) + + object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) + modules = _load_flashinfer_modules(object_files) + return modules \ No newline at end of file diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py new file mode 100644 index 000000000000..c669fbfad3af --- /dev/null +++ b/tests/python/relax/test_group_gemm_flashinfer.py @@ -0,0 +1,500 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test for FlashInfer GroupedGemm TVM integration""" + +import math +import numpy as np +import pytest +import torch +from einops import rearrange, reduce, repeat + +import tvm +import tvm.testing +from tvm import relax +from tvm.contrib import utils +from tvm.relax.backend.cuda import flashinfer + +DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024 +fp8_dtype = "float8_e4m3fn" + + +########################################### +################# Helpers ################# +########################################### +def has_flashinfer(): + """Check if FlashInfer is available""" + try: + from tvm.relax.backend.cuda import ( # pylint: disable=import-outside-toplevel + flashinfer, + ) + + return True + except ImportError: + return False + + +def has_cutlass(): + """Check if CUTLASS is available for SM90+ operations""" + if not tvm.get_global_func("device_api.cuda", True): + return False + try: + import pynvml # pylint: disable=import-outside-toplevel + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle) + return major >= 9 # SM90+ + except: + return False + +def calc_diff(x: np.ndarray, y: np.ndarray): + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + +def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): + """ + Quantizes a 2D or 3D tensor to FP8. + + Args: + x (torch.Tensor): The 2D or 3D input tensor. 
+ scale_shape (tuple): The shape of the scale tensor. + tile_shape (tuple): The shape of the tiles. + scale_major_mode (str): The tiling order, "K" for row-major like, + or another value for column-major like. + + Returns: + tuple: A tuple containing the quantized FP8 tensor and the + calculated float32 scales. + """ + # 1. Assertions and Initial Setup + ndim = x.ndim + assert ndim == len(scale_shape) == len(tile_shape) + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_amax = torch.tensor(fp8_info.max, device=x.device, dtype=torch.float32) + + # 2. Tiling and Scale Calculation + if ndim == 2: + s0, s1 = scale_shape + t0, t1 = tile_shape + if scale_major_mode == "K": + # Tile x and find the max absolute value in each tile + x_tiled = rearrange(x, "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) + abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4) + x_scale = abs_max / fp8_amax + x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) + + # Broadcast scales back to the original tensor shape + scales_repeated = repeat(x_scale, "s0 s1 -> (s0 t0) (s1 t1)", t0=t0, t1=t1) + else: + # Handle column-major tiling + x_tiled = rearrange(x, "(s1 t0) (s0 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) + abs_max = reduce(x_tiled.abs(), "s0 s1 t0 t1 -> s0 s1", "max").clamp(1e-4) + x_scale = abs_max / fp8_amax + x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) + + # Permute scale axes before repeating to match layout + scales_permuted = rearrange(x_scale, "s0 s1 -> s1 s0") + scales_repeated = repeat(scales_permuted, "s1 s0 -> (s1 t0) (s0 t1)", t0=t0, t1=t1) + + elif ndim == 3: + s0, s1, s2 = scale_shape + t0, t1, t2 = tile_shape + if scale_major_mode == "K": + # Tile x and find the max absolute value in each tile + x_tiled = rearrange( + x, "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 + ) + abs_max = reduce( + x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max" + ).clamp(1e-4) + x_scale = abs_max / fp8_amax + x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) + + # Broadcast scales back to the original tensor shape + scales_repeated = repeat( + x_scale, "s0 s1 s2 -> (s0 t0) (s1 t1) (s2 t2)", t0=t0, t1=t1, t2=t2 + ) + else: + # Handle layout where the last two axes are swapped + x_tiled = rearrange( + x, "(s0 t0) (s2 t1) (s1 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 + ) + abs_max = reduce( + x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max" + ).clamp(1e-4) + x_scale = abs_max / fp8_amax + x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) + # Permute scale axes before repeating to match layout + scales_permuted = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1") + scales_repeated = repeat( + scales_permuted, + "s0 s2 s1 -> (s0 t0) (s2 t1) (s1 t2)", + t0=t0, + t1=t1, + t2=t2, + ) + # 3. Final Quantization + # Divide the original tensor by the broadcasted scales + x_fp32 = x / (scales_repeated + 1e-8) + + # Convert the result to the target FP8 format + x_fp8 = x_fp32.to(torch.float8_e4m3fn) + + return x_fp8, x_scale + + +def dequantize_fp8(x, x_scale, scale_major_mode): + """ + Quantizes a 2D or 3D tensor to FP8. + + Args: + x (torch.Tensor): The 2D or 3D input tensor. + scale_shape (tuple): The shape of the scale tensor. + tile_shape (tuple): The shape of the tiles. + scale_major_mode (str): The tiling order, "K" for row-major like, + or another value for column-major like. + + Returns: + tuple: A tuple containing the quantized FP8 tensor and the + calculated float32 scales. + """ + # 1. 
Assertions and Initial Setup + ndim = x.ndim + assert ndim == len(x_scale.shape) + + # 2. Tiling and Scale Calculation + if ndim == 2: + if scale_major_mode == "K": + s0, s1 = x_scale.shape + else: + s1, s0 = x_scale.shape + x = rearrange(x.to(torch.float32), "(s0 t0) (s1 t1) -> s0 s1 t0 t1", s0=s0, s1=s1) + if scale_major_mode == "K": + x_scale = rearrange(x_scale, "s0 s1 -> s0 s1 1 1") + else: + x_scale = rearrange(x_scale, "s0 s1 -> s1 s0 1 1") + out = rearrange(x * x_scale, "s0 s1 t0 t1 -> (s0 t0) (s1 t1)") + elif ndim == 3: + if scale_major_mode == "K": + s0, s1, s2 = x_scale.shape + else: + s0, s2, s1 = x_scale.shape + x = rearrange( + x.to(torch.float32), + "(s0 t0) (s1 t1) (s2 t2)-> s0 s1 s2 t0 t1 t2", + s0=s0, + s1=s1, + s2=s2, + ) + if scale_major_mode == "K": + x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s1 s2 1 1 1") + else: + x_scale = rearrange(x_scale, "s0 s1 s2 -> s0 s2 s1 1 1 1") + out = rearrange(x * x_scale, "s0 s1 s2 t0 t1 t2 -> (s0 t0) (s1 t1) (s2 t2)") + + return out + + +########################################### +########### Refernce generation ########### +########################################### +def compute_reference_grouped_gemm( + a_fp32: torch.Tensor, # (total_m, k) + b_fp32: torch.Tensor, # (batch_size, n, k) + m_indptr: torch.Tensor, + dtype_out: str, # (total_m, n) +): + """Compute reference result using PyTorch operations""" + """Compute reference result using original FP32 tensors""" + + total_m, k = a_fp32.shape + batch_size, n, k2 = b_fp32.shape + assert k == k2 + + # Perform grouped GEMM computation directly on original FP32 data + results = [] + + for i in range(batch_size): + start_m = m_indptr[i].item() + end_m = m_indptr[i + 1].item() + + # Extract group's portion of A + a_group = a_fp32[start_m:end_m, :] # [m_sizes[i], k] + b_group = b_fp32[i] + + # Multiply with shared B matrix + result_group = torch.mm(a_group, b_group.T) # [m_sizes[i], n] + results.append(result_group) + + result_fp32 = torch.cat(results, dim=0) + + # Convert to output dtype + if dtype_out == "bfloat16": + result = result_fp32.to(torch.bfloat16) + elif dtype_out == "float16": + result = result_fp32.to(torch.float16) + else: + result = result_fp32 + + return result + + +########################################### +########### Test data generation ########## +########################################### +def generate_test_data( + m_sizes: list, + batch_size: int, + n: int, + k: int, + dtype_a: str, + dtype_b: str, + dtype_out: str, + scale_granularity_m: int, + scale_granularity_n: int, + scale_granularity_k: int, + scale_major_mode: str, + device: tvm.runtime.Device, +): + """Generate test data for grouped GEMM operations""" + assert batch_size == len( + m_sizes + ), f"batch_size ({batch_size}) must equal len(m_sizes) ({len(m_sizes)})" + + torch_device = torch.device(f"cuda:{device.device_id}") + + cum_m = [0] + list(np.cumsum(m_sizes)) + total_m = cum_m[-1] + + # Generate input matrices A and B (where we assert of form fp8) random data in fp32 first, then convert + assert dtype_a == "float8_e4m3fn" + a_fp32 = torch.randn(total_m, k, device=torch_device, dtype=torch.float32) + + assert dtype_b == "float8_e4m3fn" + b_fp32 = torch.randn(batch_size, n, k, device=torch_device, dtype=torch.float32) / math.sqrt(k) + + if scale_major_mode == "K": # K mode: + scale_a_shape = (total_m // scale_granularity_m, k // scale_granularity_k) + scale_b_shape = (batch_size, n // scale_granularity_n, k // scale_granularity_k) + + else: # MN mode + scale_a_shape = (k // scale_granularity_k, 
total_m // scale_granularity_m) + scale_b_shape = (batch_size, k // scale_granularity_k, n // scale_granularity_n) + + tile_a_shape = (scale_granularity_m, scale_granularity_k) + tile_b_shape = (1, scale_granularity_n, scale_granularity_k) + + # quantize A, B + a_quantized, scale_a = quantize_fp8(a_fp32, scale_a_shape, tile_a_shape, scale_major_mode) + b_quantized, scale_b = quantize_fp8(b_fp32, scale_b_shape, tile_b_shape, scale_major_mode) + + if dtype_a == "float8_e4m3fn": + a_tvm = tvm.nd.array( + a_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device + ) + else: + a_tvm = tvm.nd.from_dlpack(a_quantized) + + if dtype_b == "float8_e4m3fn": + b_tvm = tvm.nd.array( + b_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device + ) + else: + b_tvm = tvm.nd.from_dlpack(b_quantized) + + scale_a_tvm = tvm.nd.from_dlpack(scale_a) + scale_b_tvm = tvm.nd.from_dlpack(scale_b) + + # Create m_indptr for grouped operation + m_indptr = torch.tensor(cum_m, device=torch_device, dtype=torch.int32) + m_indptr_tvm = tvm.nd.array(m_indptr.cpu().numpy(), device) + + return { + "a": a_tvm, + "b": b_tvm, + "torch_a": a_fp32, + "torch_b": b_fp32, + "scale_a": scale_a_tvm, + "scale_b": scale_b_tvm, + "m_indptr": m_indptr_tvm, + "m_sizes": m_sizes, + "n": n, + "k": k, + "total_m": total_m, + "torch_scale_a": scale_a, + "torch_scale_b": scale_b, + "torch_m_indptr": m_indptr, + } + + +########################################### +############### Test driver ############### +########################################### +@pytest.mark.skipif(not has_flashinfer(), reason="FlashInfer not available") +@pytest.mark.skipif(not has_cutlass(), reason="CUTLASS SM90+ not available") +@pytest.mark.parametrize( + "dtype_a,dtype_b,dtype_out", + [ + ("float8_e4m3fn", "float8_e4m3fn", "bfloat16"), + ("float8_e4m3fn", "float8_e4m3fn", "float16"), + ], +) +@pytest.mark.parametrize( + "scale_granularity_m,scale_granularity_n,scale_granularity_k", + [ + (1, 128, 128), # Row-wise A, block-wise B + ], +) +@pytest.mark.parametrize("scale_major_mode", ["K", "MN"]) +@pytest.mark.parametrize("mma_sm", [1, 2]) +@pytest.mark.parametrize( + "test_case", + [ + {"batch_size": 4, "m_sizes": [128, 256, 192, 320], "n": 512, "k": 1024}, + {"batch_size": 2, "m_sizes": [64, 128], "n": 256, "k": 512}, + {"batch_size": 3, "m_sizes": [256, 256, 128], "n": 768, "k": 768}, + {"batch_size": 2, "m_sizes": [20, 36], "n": 768, "k": 768}, + ], +) +def test_grouped_gemm_correctness( + dtype_a, + dtype_b, + dtype_out, + scale_granularity_m, + scale_granularity_n, + scale_granularity_k, + scale_major_mode, + mma_sm, + test_case, +): + """Test correctness of GroupedGemm operations""" + device = tvm.cuda(0) + target = tvm.target.Target.from_device(device) + + def _load_module(name: str, static_modules): + """Helper function to load compiled modules.""" + assert len(static_modules) > 0 + if len(static_modules) == 1: + return static_modules[0] + static_mod = static_modules[0] + for mod in static_modules[1:]: + static_mod.import_module(mod) + temp = tvm.contrib.utils.tempdir() + mod_path = temp.relpath(f"{name}.so") + static_mod.export_library(mod_path) + return tvm.runtime.load_module(mod_path) + + # Generate the module + modules = relax.backend.cuda.flashinfer.gen_grouped_gemm_module( + dtype_a=dtype_a, + dtype_b=dtype_b, + dtype_out=dtype_out, + scale_granularity_m=scale_granularity_m, + scale_granularity_n=scale_granularity_n, + scale_granularity_k=scale_granularity_k, + scale_major_mode=scale_major_mode, + mma_sm=mma_sm, + 
target=target, + num_threads=4, + ) + + # Load the module + mod = _load_module("flashinfer_grouped_gemm", modules) + grouped_gemm_fn = mod["grouped_gemm_fp8_run"] + + # Generate test data + test_data = generate_test_data( + batch_size=test_case["batch_size"], + m_sizes=test_case["m_sizes"], + n=test_case["n"], + k=test_case["k"], + dtype_a=dtype_a, + dtype_b=dtype_b, + dtype_out=dtype_out, + scale_granularity_m=scale_granularity_m, + scale_granularity_n=scale_granularity_n, + scale_granularity_k=scale_granularity_k, + scale_major_mode=scale_major_mode, + device=device, + ) + + # Prepare output buffer + output_shape = (test_data["total_m"], test_data["n"]) + if dtype_out == "bfloat16": + output = tvm.nd.empty(output_shape, dtype="bfloat16", device=device) + elif dtype_out == "float16": + output = tvm.nd.empty(output_shape, dtype="float16", device=device) + else: + output = tvm.nd.empty(output_shape, dtype="float32", device=device) + + # Create workspace buffers (required by the interface) + int_workspace = tvm.nd.empty((DEFAULT_WORKSPACE_SIZE,), dtype="int32", device=device) + float_workspace = tvm.nd.empty((DEFAULT_WORKSPACE_SIZE,), dtype="float32", device=device) + + grouped_gemm_fn( + int_workspace, # int_workspace_buffer + float_workspace, # float_workspace_buffer + test_data["a"], # A + test_data["b"], # B + test_data["scale_a"], # SFA + test_data["scale_b"], # SFB + output, # D + test_data["m_indptr"], # m_indptr + test_data["n"], # n (scalar) + test_data["k"], # k (scalar) + None, # cuda_stream (use default stream) + ) + + # Compute reference result + reference = compute_reference_grouped_gemm( + test_data['torch_a'], + test_data['torch_b'], + test_data["torch_m_indptr"], + dtype_out, + ) + + # Convert TVM output to PyTorch for comparison + output_torch = torch.as_tensor(output, device=test_data["torch_a"].device) + output_torch + + # Compare results with appropriate tolerance + if dtype_out == "bfloat16": + rtol, atol = 1e-2, 1e-2 + elif dtype_out == "float16": + rtol, atol = 1e-3, 1e-3 + else: + rtol, atol = 1e-4, 1e-4 + + # Check shapes match + assert ( + output_torch.shape == reference.shape + ), f"Shape mismatch: got {output_torch.shape}, expected {reference.shape}" + + + diff = calc_diff( + output_torch.cpu().double().numpy(), + reference.cpu().double().numpy() + ) + assert diff < 1e-3, f"diff too large {diff}" + + +if __name__ == "__main__": + tvm.testing.main() + From 95d73727997d215944bc660d6f5a6fbd65cdff18 Mon Sep 17 00:00:00 2001 From: Anrui Liu Date: Fri, 19 Sep 2025 16:00:52 -0400 Subject: [PATCH 2/5] revert 3rdparty/libbacktrace --- 3rdparty/libbacktrace | 1 - 1 file changed, 1 deletion(-) delete mode 160000 3rdparty/libbacktrace diff --git a/3rdparty/libbacktrace b/3rdparty/libbacktrace deleted file mode 160000 index 08f7c7e69f8e..000000000000 --- a/3rdparty/libbacktrace +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 08f7c7e69f8ea61a0c4151359bc8023be8e9217b From 66a858db82ff8d36bc07e0eadf9859de0bc2700e Mon Sep 17 00:00:00 2001 From: Anrui Liu Date: Sun, 21 Sep 2025 18:45:57 -0400 Subject: [PATCH 3/5] move einops inside function call, and reformat --- python/tvm/relax/backend/cuda/flashinfer.py | 13 ++-- .../relax/test_group_gemm_flashinfer.py | 69 ++++++++++--------- 2 files changed, 44 insertions(+), 38 deletions(-) diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index 7bad25122101..b5c39f6973cc 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -490,6 
+490,7 @@ def gen_sampling_module(target: Target, num_threads: int = 8): modules = _load_flashinfer_modules(object_files) return modules + def gen_grouped_gemm_module( dtype_a: str, dtype_b: str, @@ -538,7 +539,7 @@ def gen_grouped_gemm_module( requires all m in m_indptr to be multiple of 4 """ try: - from flashinfer.jit import ( + from flashinfer.jit import ( gen_grouped_gemm_fp8_tvm_binding, get_grouped_gemm_fp8_uri, ) @@ -548,14 +549,14 @@ def gen_grouped_gemm_module( "in https://docs.flashinfer.ai to install FlashInfer." ) try: - import torch + import torch except ImportError: raise ImportError("PyTorch is not installed. Please install PyTorch to use FlashInfer.") torch_dtype_a = getattr(torch, dtype_a) torch_dtype_b = getattr(torch, dtype_b) torch_dtype_out = getattr(torch, dtype_out) - + uri = get_grouped_gemm_fp8_uri( dtype_a=torch_dtype_a, dtype_b=torch_dtype_b, @@ -566,7 +567,7 @@ def gen_grouped_gemm_module( scale_major_mode=scale_major_mode, mma_sm=mma_sm, ) - + uri, source_paths = gen_grouped_gemm_fp8_tvm_binding( uri=uri, dtype_a=torch_dtype_a, @@ -578,7 +579,7 @@ def gen_grouped_gemm_module( scale_major_mode=scale_major_mode, mma_sm=mma_sm, ) - + object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) modules = _load_flashinfer_modules(object_files) - return modules \ No newline at end of file + return modules diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py index c669fbfad3af..1d20f682db81 100644 --- a/tests/python/relax/test_group_gemm_flashinfer.py +++ b/tests/python/relax/test_group_gemm_flashinfer.py @@ -21,8 +21,6 @@ import numpy as np import pytest import torch -from einops import rearrange, reduce, repeat - import tvm import tvm.testing from tvm import relax @@ -62,12 +60,15 @@ def has_cutlass(): except: return False + def calc_diff(x: np.ndarray, y: np.ndarray): denominator = (x * x + y * y).sum() sim = 2 * (x * y).sum() / denominator return 1 - sim + def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): + from einops import rearrange, reduce, repeat """ Quantizes a 2D or 3D tensor to FP8. @@ -121,9 +122,7 @@ def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): x_tiled = rearrange( x, "(s0 t0) (s1 t1) (s2 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 ) - abs_max = reduce( - x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max" - ).clamp(1e-4) + abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4) x_scale = abs_max / fp8_amax x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) @@ -136,9 +135,7 @@ def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): x_tiled = rearrange( x, "(s0 t0) (s2 t1) (s1 t2) -> s0 s1 s2 t0 t1 t2", s0=s0, s1=s1, s2=s2 ) - abs_max = reduce( - x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max" - ).clamp(1e-4) + abs_max = reduce(x_tiled.abs(), "s0 s1 s2 t0 t1 t2 -> s0 s1 s2", "max").clamp(1e-4) x_scale = abs_max / fp8_amax x_scale = torch.pow(2.0, torch.ceil(torch.log2(x_scale.abs()))) # Permute scale axes before repeating to match layout @@ -161,6 +158,7 @@ def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): def dequantize_fp8(x, x_scale, scale_major_mode): + from einops import rearrange, reduce, repeat """ Quantizes a 2D or 3D tensor to FP8. 
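Both quantize_fp8 and dequantize_fp8 in this test rely on the same per-tile rule: take the tile's absolute maximum, divide by the float8_e4m3fn maximum (448), and round the resulting scale up to a power of two so the scaled values always fit in the FP8 range. A minimal standalone sketch of that rule for a single 2D tile (illustrative only; this is not the helper from the test, and it assumes a PyTorch build with float8_e4m3fn support):

import torch

def fp8_tile_roundtrip(tile: torch.Tensor):
    """Quantize one tile to float8_e4m3fn with a power-of-two scale, then dequantize."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max                   # 448.0 for e4m3
    amax = tile.abs().max().clamp(min=1e-4)                          # per-tile abs-max, clamped as in the test helper
    scale = torch.pow(2.0, torch.ceil(torch.log2(amax / fp8_max)))   # round the scale up to a power of two
    tile_fp8 = (tile / scale).to(torch.float8_e4m3fn)                # scaled values fit inside [-448, 448]
    tile_deq = tile_fp8.to(torch.float32) * scale                    # dequantized approximation of the input
    return tile_fp8, scale, tile_deq

# e.g. a tile whose abs-max is 3.0 gets scale 2**ceil(log2(3 / 448)) = 2**-7 = 0.0078125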
@@ -219,7 +217,7 @@ def compute_reference_grouped_gemm( a_fp32: torch.Tensor, # (total_m, k) b_fp32: torch.Tensor, # (batch_size, n, k) m_indptr: torch.Tensor, - dtype_out: str, # (total_m, n) + dtype_out: str, # (total_m, n) ): """Compute reference result using PyTorch operations""" """Compute reference result using original FP32 tensors""" @@ -243,7 +241,7 @@ def compute_reference_grouped_gemm( result_group = torch.mm(a_group, b_group.T) # [m_sizes[i], n] results.append(result_group) - result_fp32 = torch.cat(results, dim=0) + result_fp32 = torch.cat(results, dim=0) # Convert to output dtype if dtype_out == "bfloat16": @@ -278,7 +276,8 @@ def generate_test_data( m_sizes ), f"batch_size ({batch_size}) must equal len(m_sizes) ({len(m_sizes)})" - torch_device = torch.device(f"cuda:{device.device_id}") + # print(f"Device object: {device}") + torch_device = torch.device(f"cuda:{device.index}") cum_m = [0] + list(np.cumsum(m_sizes)) total_m = cum_m[-1] @@ -306,25 +305,25 @@ def generate_test_data( b_quantized, scale_b = quantize_fp8(b_fp32, scale_b_shape, tile_b_shape, scale_major_mode) if dtype_a == "float8_e4m3fn": - a_tvm = tvm.nd.array( + a_tvm = tvm.runtime.tensor( a_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device ) else: - a_tvm = tvm.nd.from_dlpack(a_quantized) + a_tvm = tvm.runtime.from_dlpack(a_quantized) if dtype_b == "float8_e4m3fn": - b_tvm = tvm.nd.array( + b_tvm = tvm.runtime.tensor( b_quantized.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device ) else: - b_tvm = tvm.nd.from_dlpack(b_quantized) + b_tvm = tvm.runtime.from_dlpack(b_quantized) - scale_a_tvm = tvm.nd.from_dlpack(scale_a) - scale_b_tvm = tvm.nd.from_dlpack(scale_b) + scale_a_tvm = tvm.runtime.from_dlpack(scale_a) + scale_b_tvm = tvm.runtime.from_dlpack(scale_b) # Create m_indptr for grouped operation m_indptr = torch.tensor(cum_m, device=torch_device, dtype=torch.int32) - m_indptr_tvm = tvm.nd.array(m_indptr.cpu().numpy(), device) + m_indptr_tvm = tvm.runtime.tensor(m_indptr.cpu().numpy(), device) return { "a": a_tvm, @@ -438,15 +437,15 @@ def _load_module(name: str, static_modules): # Prepare output buffer output_shape = (test_data["total_m"], test_data["n"]) if dtype_out == "bfloat16": - output = tvm.nd.empty(output_shape, dtype="bfloat16", device=device) + output = tvm.runtime.empty(output_shape, dtype="bfloat16", device=device) elif dtype_out == "float16": - output = tvm.nd.empty(output_shape, dtype="float16", device=device) + output = tvm.runtime.empty(output_shape, dtype="float16", device=device) else: - output = tvm.nd.empty(output_shape, dtype="float32", device=device) + output = tvm.runtime.empty(output_shape, dtype="float32", device=device) # Create workspace buffers (required by the interface) - int_workspace = tvm.nd.empty((DEFAULT_WORKSPACE_SIZE,), dtype="int32", device=device) - float_workspace = tvm.nd.empty((DEFAULT_WORKSPACE_SIZE,), dtype="float32", device=device) + int_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="int32", device=device) + float_workspace = tvm.runtime.empty((DEFAULT_WORKSPACE_SIZE,), dtype="float32", device=device) grouped_gemm_fn( int_workspace, # int_workspace_buffer @@ -464,8 +463,8 @@ def _load_module(name: str, static_modules): # Compute reference result reference = compute_reference_grouped_gemm( - test_data['torch_a'], - test_data['torch_b'], + test_data["torch_a"], + test_data["torch_b"], test_data["torch_m_indptr"], dtype_out, ) @@ -487,14 +486,20 @@ def _load_module(name: str, static_modules): output_torch.shape == 
reference.shape ), f"Shape mismatch: got {output_torch.shape}, expected {reference.shape}" - - diff = calc_diff( - output_torch.cpu().double().numpy(), - reference.cpu().double().numpy() - ) + diff = calc_diff(output_torch.cpu().double().numpy(), reference.cpu().double().numpy()) assert diff < 1e-3, f"diff too large {diff}" if __name__ == "__main__": - tvm.testing.main() - + test_grouped_gemm_correctness( + "float8_e4m3fn", + "float8_e4m3fn", + "bfloat16", + 1, + 128, + 128, + "K", + 1, + {"batch_size": 2, "m_sizes": [20, 36], "n": 768, "k": 768}, + ) + # tvm.testing.main() From b2498ecc7fc39be502313997391e7aa524e70fb4 Mon Sep 17 00:00:00 2001 From: Anrui Liu Date: Sun, 21 Sep 2025 18:58:32 -0400 Subject: [PATCH 4/5] reformat --- tests/python/relax/test_group_gemm_flashinfer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py index 1d20f682db81..e124a6404993 100644 --- a/tests/python/relax/test_group_gemm_flashinfer.py +++ b/tests/python/relax/test_group_gemm_flashinfer.py @@ -69,6 +69,7 @@ def calc_diff(x: np.ndarray, y: np.ndarray): def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): from einops import rearrange, reduce, repeat + """ Quantizes a 2D or 3D tensor to FP8. @@ -159,6 +160,7 @@ def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): def dequantize_fp8(x, x_scale, scale_major_mode): from einops import rearrange, reduce, repeat + """ Quantizes a 2D or 3D tensor to FP8. From b1d615624186120f9a6459444f3642e4fc4c9155 Mon Sep 17 00:00:00 2001 From: Anrui Liu Date: Sun, 21 Sep 2025 19:18:13 -0400 Subject: [PATCH 5/5] remove unused imports, to pass pylint --- python/tvm/relax/backend/cuda/flashinfer.py | 5 ++--- tests/python/relax/test_group_gemm_flashinfer.py | 15 ++------------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index b5c39f6973cc..4e0fc3e8541a 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -20,7 +20,6 @@ import json import os import subprocess -from typing import Optional, Tuple from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import List @@ -539,7 +538,7 @@ def gen_grouped_gemm_module( requires all m in m_indptr to be multiple of 4 """ try: - from flashinfer.jit import ( + from flashinfer.jit import ( # pylint: disable=import-outside-toplevel gen_grouped_gemm_fp8_tvm_binding, get_grouped_gemm_fp8_uri, ) @@ -549,7 +548,7 @@ def gen_grouped_gemm_module( "in https://docs.flashinfer.ai to install FlashInfer." ) try: - import torch + import torch # pylint: disable=import-outside-toplevel except ImportError: raise ImportError("PyTorch is not installed. Please install PyTorch to use FlashInfer.") diff --git a/tests/python/relax/test_group_gemm_flashinfer.py b/tests/python/relax/test_group_gemm_flashinfer.py index e124a6404993..8333e4b2d66b 100644 --- a/tests/python/relax/test_group_gemm_flashinfer.py +++ b/tests/python/relax/test_group_gemm_flashinfer.py @@ -159,7 +159,7 @@ def quantize_fp8(x, scale_shape, tile_shape, scale_major_mode): def dequantize_fp8(x, x_scale, scale_major_mode): - from einops import rearrange, reduce, repeat + from einops import rearrange """ Quantizes a 2D or 3D tensor to FP8. 
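The "K" versus "MN" scale_major_mode setting only changes how those per-tile scales are laid out, following the shape computation in generate_test_data above: in "K" mode the K scale axis is innermost, while in "MN" mode the M/N and K scale axes are swapped. A small worked example with one of the test cases (illustrative shapes only):

# Test case {"batch_size": 2, "m_sizes": [64, 128], "n": 256, "k": 512}
# with scale_granularity (m, n, k) = (1, 128, 128):
total_m, n, k, batch_size = 192, 256, 512, 2

# scale_major_mode == "K": K is the innermost scale axis
scale_a_k_mode = (total_m // 1, k // 128)            # (192, 4)
scale_b_k_mode = (batch_size, n // 128, k // 128)    # (2, 2, 4)

# scale_major_mode == "MN": the M/N and K scale axes are swapped
scale_a_mn_mode = (k // 128, total_m // 1)           # (4, 192)
scale_b_mn_mode = (batch_size, k // 128, n // 128)   # (2, 4, 2)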
@@ -493,15 +493,4 @@ def _load_module(name: str, static_modules): if __name__ == "__main__": - test_grouped_gemm_correctness( - "float8_e4m3fn", - "float8_e4m3fn", - "bfloat16", - 1, - 128, - 128, - "K", - 1, - {"batch_size": 2, "m_sizes": [20, 36], "n": 768, "k": 768}, - ) - # tvm.testing.main() + tvm.testing.main()
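Taken together, the intended usage of the new binding mirrors the test driver above: compile the kernels for a given FP8 scaling configuration, link and load the resulting static modules, and call the exported grouped_gemm_fp8_run function with the quantized operands, their scale tensors, an output buffer, the m_indptr offsets, and the two workspace buffers. A condensed sketch based on the test code in this patch (assumes FlashInfer is installed and an SM90+ GPU is available):

import tvm
from tvm.contrib import utils
from tvm.relax.backend.cuda import flashinfer

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)

# Build the FP8 grouped GEMM kernels for one scaling configuration.
modules = flashinfer.gen_grouped_gemm_module(
    dtype_a="float8_e4m3fn",
    dtype_b="float8_e4m3fn",
    dtype_out="bfloat16",
    scale_granularity_m=1,
    scale_granularity_n=128,
    scale_granularity_k=128,
    scale_major_mode="K",
    mma_sm=1,
    target=target,
)

# Link the static libraries into one loadable module, as in the test's _load_module helper.
static_mod = modules[0]
for extra in modules[1:]:
    static_mod.import_module(extra)
temp = utils.tempdir()
static_mod.export_library(temp.relpath("flashinfer_grouped_gemm.so"))
rt_mod = tvm.runtime.load_module(temp.relpath("flashinfer_grouped_gemm.so"))
grouped_gemm = rt_mod["grouped_gemm_fp8_run"]

# Call signature, as exercised by the test (all tensors are TVM arrays on the CUDA device):
#   grouped_gemm(int_workspace, float_workspace,   # scratch buffers
#                a_fp8, b_fp8,                     # A: (total_m, k), B: (batch_size, n, k)
#                scale_a, scale_b,                 # per-tile scales, layout set by scale_major_mode
#                out,                              # D: (total_m, n) in dtype_out
#                m_indptr,                         # (batch_size + 1,) int32; each group size a multiple of 4
#                n, k,                             # scalar dimensions
#                None)                             # CUDA stream (None = default stream)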