11 changes: 1 addition & 10 deletions testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
@@ -12,6 +12,7 @@
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type

tilelang.testing.set_random_seed(0)

@@ -186,16 +187,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
# src_code is the generated cuda source
assert src_code is not None

def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)

in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
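The duplicated map_torch_type helper deleted above now comes from tilelang.utils.tensor (the same change is applied to the four test files below). A minimal sketch of the resolution it performs, using illustrative dtype strings:

import torch
from tilelang.utils.tensor import map_torch_type

# fp8 names map to the torch float8 dtypes; everything else falls back to
# getattr(torch, name), exactly as in the deleted local helper.
assert map_torch_type("e4m3_float8") is torch.float8_e4m3fn
assert map_torch_type("e5m2_float8") is torch.float8_e5m2
assert map_torch_type("float16") is torch.float16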
11 changes: 1 addition & 10 deletions testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py
@@ -12,6 +12,7 @@
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type

tilelang.testing.set_random_seed(0)

@@ -186,16 +187,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
# src_code is the generated cuda source
assert src_code is not None

def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)

in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
11 changes: 1 addition & 10 deletions testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py
@@ -8,6 +8,7 @@
import tilelang.language as T
from tilelang import JITKernel
from tilelang.transform.simplify import apply_simplify
from tilelang.utils.tensor import map_torch_type
from typing import Optional

tilelang.testing.set_random_seed(0)
@@ -131,16 +132,6 @@ def evaluate_gemv_simt(

kernel = JITKernel(program, target="cuda")

def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)

in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
11 changes: 1 addition & 10 deletions testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py
@@ -12,6 +12,7 @@
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type

tilelang.testing.set_random_seed(0)

@@ -186,16 +187,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
# src_code is the generated cuda source
assert src_code is not None

def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)

in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
11 changes: 1 addition & 10 deletions testing/python/kernel/test_tilelang_kernel_gemv_simt.py
@@ -8,6 +8,7 @@
import tilelang.language as T
from tilelang import JITKernel
from tilelang.transform.simplify import apply_simplify
from tilelang.utils.tensor import map_torch_type
from typing import Optional

tilelang.testing.set_random_seed(0)
@@ -131,16 +132,6 @@ def evaluate_gemv_simt(

kernel = JITKernel(program, target="cuda")

def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)

in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
@@ -10,6 +10,8 @@

import tilelang.testing

tilelang.testing.set_random_seed(42)


def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
@@ -302,10 +304,10 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None

assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)


def test_mha_bwd():
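Why the switch from torch.allclose to torch.testing.assert_close is worthwhile, in a small standalone sketch (tensor values are made up): a failing bare assert reports nothing, while assert_close raises with a mismatch summary at the same rtol/atol.

import torch

O_ref = torch.tensor([1.000, 2.000])
O = torch.tensor([1.001, 2.500])  # second element is far outside the 1e-2 tolerances

assert not torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)  # only returns a bool, no diagnostics
try:
    torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
except AssertionError as err:
    print(err)  # lists mismatched elements and the greatest absolute/relative difference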
59 changes: 54 additions & 5 deletions tilelang/jit/adapter/cython/adapter.py
@@ -13,8 +13,9 @@
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.utils.target import determine_target
from tilelang.utils.language import retrieve_func_from_module
from tilelang.utils.tensor import map_torch_type
from tilelang.contrib.cc import get_cplus_compiler

import torch
import sys
import sysconfig
import hashlib
@@ -89,7 +90,7 @@ def get_cached_lib(source_code: str) -> Tuple[Optional[ctypes.CDLL], Path]:
logger.debug("Cython jit adapter is up to date, no need to compile...")
need_compile = False
else:
logger.info("Cython jit adapter is out of date, need to compile...")
logger.info("Cython jit adapter is out of date, need to recompile...")
else:
logger.info("No cached version found for cython jit adapter, need to compile...")

@@ -135,6 +136,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
wrapped_source: Optional[str] = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
# Maps buffer variables to their corresponding dtypes
buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None
# Maps buffer variables to their corresponding static shapes
# {
# "A": [(0, 16), (1, 16)] -> represents A.shape = (16, 16)
# }
static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None

def __init__(self,
rt_mod,
@@ -163,6 +171,8 @@ def __init__(self,
self.ir_module = func_or_mod

self.dynamic_symbolic_map = self._process_dynamic_symbolic()
self.buffer_dtype_map = self._process_buffer_dtype()
self.static_shape_map = self._process_static_shape()

self.target = Target.canon_target(determine_target(target))
self.verbose = verbose
@@ -182,12 +192,14 @@ def __init__(self,
raise Exception(
f"Failed to initialize the compiled library for {self.target}: {e}") from e

self.cython_wrapper = CythonKernelWrapper(self.dynamic_symbolic_map, self.result_idx,
self.params, self.lib)
self.cython_wrapper = CythonKernelWrapper(self.result_idx, self.params, self.lib)
self.cython_wrapper.set_dynamic_symbolic_map(self.dynamic_symbolic_map)
self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map)
self.cython_wrapper.set_static_shape_map(self.static_shape_map)

self._post_init()

def _process_dynamic_symbolic(self):
def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
"""Extract information about dynamic shapes from the TIR function.

Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
@@ -205,6 +217,43 @@ def _process_dynamic_symbolic(self):
dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map

def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]:
"""Extract information about buffer dtypes from the TIR function.

Maps buffer variables to their corresponding dtypes.
"""
func = self.prim_func
params = func.params
buffer_map = func.buffer_map
buffer_dtype_map = {}
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
name, dtype = buffer.name, buffer.dtype
buffer_dtype_map[name] = (i, map_torch_type(dtype))
return buffer_dtype_map

def _process_static_shape(self) -> Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]:
"""Extract information about static shapes from the TIR function.

Maps buffer variables to their corresponding static shapes.
"""
func = self.prim_func
params = func.params
buffer_map = func.buffer_map
static_shape_map = {}
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
name = buffer.name
shape = buffer.shape
static_shape = []
for j, s in enumerate(shape):
if isinstance(s, tir.IntImm):
static_shape.append((j, s.value))
static_shape_map[name] = (i, static_shape)
return static_shape_map

def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
"""Low-level function to call the compiled CUDA kernel.

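For reference, a sketch of what the three _process_* helpers above would build for a hypothetical two-buffer kernel; keys are shown as plain strings here, whereas in the adapter the dynamic-symbol keys are tir.Var objects.

import torch

# Hypothetical kernel signature:
#   A: Buffer((m, 16), "float16") with m symbolic
#   B: Buffer((16, 16), "float32")
dynamic_symbolic_map = {"m": (0, 0)}      # symbol -> (buffer index, dim index)
buffer_dtype_map = {
    "A": (0, torch.float16),              # buffer name -> (param index, torch dtype)
    "B": (1, torch.float32),
}
static_shape_map = {
    "A": (0, [(1, 16)]),                  # buffer name -> (param index, [(dim index, extent)])
    "B": (1, [(0, 16), (1, 16)]),         # only tir.IntImm dims are recorded
}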
28 changes: 26 additions & 2 deletions tilelang/jit/adapter/cython/cython_wrapper.pyx
@@ -12,16 +12,29 @@ cdef class CythonKernelWrapper:
# Class attributes to store kernel configuration and library reference
cdef:
object dynamic_symbolic_map # Maps dynamic dimensions to their corresponding tensor indices
object buffer_dtype_map # Maps buffer variables to their corresponding dtypes
object static_shape_map # Maps buffer variables to their corresponding static shapes
list result_idx # Indices of output tensors in the params list
list params # List of parameter specifications (includes both inputs and outputs)
object lib # Reference to the compiled library containing the kernel

def __cinit__(self, dynamic_symbolic_map, result_idx, params, lib):
def __cinit__(self, result_idx, params, lib):
# Initialize wrapper with kernel configuration
self.dynamic_symbolic_map = dynamic_symbolic_map
self.result_idx = result_idx
self.params = params
self.lib = lib

def set_dynamic_symbolic_map(self, dynamic_symbolic_map):
self.dynamic_symbolic_map = dynamic_symbolic_map
return self

def set_buffer_dtype_map(self, buffer_dtype_map):
self.buffer_dtype_map = buffer_dtype_map
return self

def set_static_shape_map(self, static_shape_map):
self.static_shape_map = static_shape_map
return self

cpdef forward(self, list inputs, int64_t stream = -1):
# Validate input dimensions and prepare for kernel execution
@@ -69,6 +82,17 @@ cdef class CythonKernelWrapper:
else:
raise ValueError(f"Unsupported tensor type: {type(tensor_list[i])}")

# Check buffer dtype map
for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
if tensor_list[buffer_idx].dtype != torch_dtype:
raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")

# Check static shape map
for param, (buffer_idx, shape_list) in self.static_shape_map.items():
for shape_idx, shape in shape_list:
if tensor_list[buffer_idx].shape[shape_idx] != shape:
raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape}, got {tensor_list[buffer_idx].shape}")

# Add dynamic dimension values to kernel arguments
for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
call_args.append(tensor_list[buffer_idx].shape[shape_idx])
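The wrapper runs these checks in Cython inside forward(); a standalone Python re-creation of the same validation, with placeholder maps, shows the intended failure modes:

import torch

buffer_dtype_map = {"A": (0, torch.float16)}   # placeholder maps, as built by CythonKernelAdapter
static_shape_map = {"A": (0, [(1, 16)])}

def validate(tensor_list):
    for param, (buffer_idx, torch_dtype) in buffer_dtype_map.items():
        if tensor_list[buffer_idx].dtype != torch_dtype:
            raise ValueError(f"Buffer dtype mismatch for parameter {param}: "
                             f"expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
    for param, (buffer_idx, shape_list) in static_shape_map.items():
        for shape_idx, shape in shape_list:
            if tensor_list[buffer_idx].shape[shape_idx] != shape:
                raise ValueError(f"Static shape mismatch for parameter {param} at dim {shape_idx}: "
                                 f"expected {shape}, got {tensor_list[buffer_idx].shape[shape_idx]}")

validate([torch.empty(8, 16, dtype=torch.float16)])    # passes: dtype and static dim 1 both match
# validate([torch.empty(8, 32, dtype=torch.float16)])  # would raise: dim 1 is 32, expected 16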
2 changes: 1 addition & 1 deletion tilelang/utils/tensor.py
@@ -18,7 +18,7 @@ class TensorSupplyType(Enum):
Auto = 7


def map_torch_type(intype):
def map_torch_type(intype: str) -> torch.dtype:
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,