4 changes: 2 additions & 2 deletions maint/gemm_v2/correctness_evaluation.py
@@ -1,4 +1,4 @@
# pytest gemm_ss_wgmma.py -n 32
# pytest correctness_evaluation.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
@@ -384,7 +384,7 @@ def run_gemm_rr(


M_VALUES = [64, 128, 256]
N_VALUES = [16, 32, 64, 128]
N_VALUES = [16, 32, 64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = ([
99 changes: 99 additions & 0 deletions maint/gemm_v2/latency_gemm.py
@@ -0,0 +1,99 @@
import tilelang
import tilelang.language as T
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()

use_v2 = args.use_v2


# @tilelang.jit(target="cuda")
# The target can currently be "cuda", "hip", or "cpu".
# If not specified, it is inferred from the input tensors at compile time.
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)

# Clear local accumulation
T.clear(C_local)

for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is syntactic sugar for a parallelized copy
T.copy(A[by * block_M, ko * block_K], A_shared)

# Copy tile of B
T.copy(B[ko * block_K, bx * block_N], B_shared)

# Perform a tile-level GEMM on the shared buffers
# Currently this dispatches to the cute/hip backends on Nvidia/AMD GPUs
if use_v2:
T.gemm_v2(A_shared, B_shared, C_local)
else:
T.gemm_v1(A_shared, B_shared, C_local)

# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)

# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])

Comment on lines +46 to +57
⚠️ Potential issue | 🟠 Major

use_v2 captured at call-site breaks autotuned matmul reuse

use_v2 is read from the CLI once and closed over by the nested matmul_relu_kernel. The compiled kernel therefore hardcodes whatever flag was set during compilation. When you later call matmul again with a different use_v2 expectation (e.g., toggling between v1/v2 from the CLI or a tuning sweep), the already-compiled kernel silently keeps the previous path. This leads to confusing latency numbers and invalidates correctness checks when you intend to benchmark both variants in one process. Please pass the flag as an explicit parameter to matmul and thread it through the kernel invocation rather than capturing a global.

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_gemm.py around lines 46-57, the nested
matmul_relu_kernel currently closes over the CLI variable use_v2 which hardcodes
the chosen path into the compiled kernel; instead make use_v2 an explicit
argument to matmul and add it as a parameter to the kernel signature so the
kernel reads the flag at call time. Update all invocations to pass the boolean
through, modify the kernel body to branch on the kernel parameter (not a
closed-over variable), and ensure any autotuning/compilation cache keys include
this parameter so separate v1/v2 kernels are compiled and reused correctly.
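
A minimal sketch of the restructuring the review asks for, under the assumption that tilelang.jit specializes and caches the decorated factory per distinct argument tuple, so each value of use_v2 yields its own compiled kernel. The kernel body is copied from the diff above; only the use_v2 plumbing is new.

import tilelang
import tilelang.language as T


@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, use_v2=False,
           dtype="float16", accum_dtype="float"):

    @T.prim_func
    def matmul_relu_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                # use_v2 is a parameter of this specialization, not a
                # module-level global captured by the closure.
                if use_v2:
                    T.gemm_v2(A_shared, B_shared, C_local)
                else:
                    T.gemm_v1(A_shared, B_shared, C_local)

            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel


# Both variants can then coexist in one process:
kernel_v1 = matmul(16384, 16384, 16384, 128, 128, 64, use_v2=False)
kernel_v2 = matmul(16384, 16384, 16384, 128, 128, 64, use_v2=True)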

return matmul_relu_kernel


M = 16384 # M = T.dynamic("m") if you want to use dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 64

# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Run the compiled kernel on the input tensors
matmul_relu_kernel(a, b, c)

print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
Comment on lines +75 to +88
⚠️ Potential issue | 🟠 Major

Avoid allocating 16k×16k fp16 tensors on import

Instantiating 16384×16384 fp16 tensors at module import consumes ~512 MiB per tensor (~1.5 GiB for a, b, and c together) and triggers GPU allocation before __main__ runs. This makes latency_gemm.py unusable as a library and can crash on machines without that much free GPU memory. Move the heavy allocations under if __name__ == "__main__": (or into a main function) so importing the module for reuse or running unit tests doesn’t instantly OOM the GPU.

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_gemm.py around lines 75–88 the script allocates large
fp16 CUDA tensors (a, b, c) and runs the kernel at module import time which can
OOM GPUs and prevents safe import; move the heavy allocations, kernel call,
reference computation and assertion inside an if __name__ == "__main__": block
(or a main() function) so importing the module does not perform GPU allocations
or execute the test — i.e., wrap lines that create a, b, c, call
matmul_relu_kernel, print, compute ref_c, and assert_close into a guarded main
section and leave only function/class definitions at top-level.
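
A minimal sketch of that guard, reusing the matmul factory defined in the diff above; nothing beyond the script's own names and sizes is assumed.

def main():
    import torch

    M = N = K = 16384
    kernel = matmul(M, N, K, 128, 128, 64)

    # Heavy GPU allocations happen only when the script is executed directly.
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
    c = torch.empty(M, N, device="cuda", dtype=torch.float16)

    kernel(a, b, c)
    torch.testing.assert_close(c, torch.relu(a @ b), rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")


if __name__ == "__main__":
    main()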


# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)

# 5. Profile the kernel's latency
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")
Comment on lines +95 to +99
⚠️ Potential issue | 🟠 Major

Profiler benchmark should benchmark the TileLang kernel, not the reference

profiler.do_bench(ref_program_processed, warmup=500) times the PyTorch reference instead of the compiled kernel. That double-counts the reference execution and skews the reported latency. Drop the reference callable when benchmarking the kernel; reserve it for assert_allclose.

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_gemm.py around lines 95 to 99, the current
profiler.do_bench call is timing the PyTorch reference (ref_program_processed)
instead of the compiled TileLang kernel; remove the reference callable from the
do_bench invocation so the profiler only benchmarks the compiled kernel (e.g.,
call profiler.do_bench(warmup=500) or equivalent), and keep the reference
callable solely for correctness checks like assert_allclose after benchmarking.
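
A fragment-level sketch against the tail of latency_gemm.py, using only profiler calls that already appear in these scripts; the commented lines show how a separately labeled reference timing would look if one is still wanted (ref_program here is hypothetical for this script).

profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

# Time only the compiled TileLang kernel: no reference callable is passed.
latency = profiler.do_bench(warmup=500)
print(f"TileLang latency: {latency:.2f} ms")

# If a reference number is also wanted, time it separately and label it as such,
# mirroring latency_mha_fwd_bhsd.py:
# ref_latency = profiler.do_bench(ref_program, warmup=500)
# print(f"Ref latency: {ref_latency:.2f} ms")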

246 changes: 246 additions & 0 deletions maint/gemm_v2/latency_mha_fwd_bhsd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial

parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=16, help='heads')
parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=512, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--use_v2", action="store_true")

args = parser.parse_args()

use_v2 = args.use_v2


def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

Comment on lines +25 to +28
⚠️ Potential issue | 🟠 Major

get_configs builds configs incorrectly

dict(zip(iter_params, values)) zips dictionary keys with value tuples, leaving only the last key/value pair. You end up with configs like {'threads': 256} and silently drop block_M, block_N, num_stages. Use itertools.product(*iter_params.values()) with proper key association (e.g., {k: v for k, v in zip(iter_params.keys(), values)}) to retain all knobs.

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 24–27, get_configs
currently uses dict(zip(iter_params, values)) which zips dictionary keys (not
key list) with values and ends up keeping only the last pair; replace that
construction with a proper key-to-value association such as
dict(zip(iter_params.keys(), values)) or a dict comprehension {k: v for k, v in
zip(iter_params.keys(), values)} so every knob (block_M, block_N, num_stages,
threads) is preserved for each product combination.
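
Spelled out, the explicitly keyed construction the review suggests: each tuple from the Cartesian product is paired element-by-element with the knob names.

import itertools


def get_configs():
    iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
    # Pair each value tuple from the product with the corresponding knob name.
    return [
        {k: v for k, v in zip(iter_params.keys(), values)}
        for values in itertools.product(*iter_params.values())
    ]


# get_configs() == [{'block_M': 128, 'block_N': 128, 'num_stages': 2, 'threads': 256}]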


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"

past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"

@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
if use_v2:
T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if use_v2:
T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)

for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
Comment on lines +114 to +124
⚠️ Potential issue | 🟠 Major

Log-sum-exp accumulation can underflow to zero

scores_scale multiplies logsum before the first tile has written anything. If logsum[i] stays at 0 for early iterations, the subsequent division acc_o[i, j] /= logsum[i] risks division by zero. Initialize logsum to a tiny epsilon (or branch to skip scaling when the previous sum is zero).

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 115 to 125, logsum is
multiplied by scores_scale before any tile may have written to it, allowing
logsum to remain zero and later cause division-by-zero when acc_o is divided by
logsum; fix by initializing logsum to a small positive epsilon (e.g. 1e-12)
before the loop or add a conditional/branch that skips scaling and division when
the previous logsum is zero, ensuring any updates use max(logsum, epsilon) or
guard the divide with a check to avoid division by zero.
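
A fragment-level sketch of the clamp variant against the main prim_func below; the epsilon value, and using T.max on a fragment element the way the ReLU epilogue in latency_gemm.py does, are assumptions rather than the PR's code.

# Guarded final normalization: a row whose running sum stayed at zero
# (e.g. a fully masked row) no longer divides by exactly zero.
for i, j in T.Parallel(block_M, dim):
    acc_o[i, j] /= T.max(logsum[i], 1e-6)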


@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]

@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)

T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))

loop_range = (
T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M +
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))

for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

return main


def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output


def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 64,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5

if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128)
print(kernel.get_kernel_source())
Comment on lines +210 to +221
⚠️ Potential issue | 🟠 Major

Heavy allocations executed on import

Like the GEMM script, this file allocates large tensors at module import. Wrap the execution logic in a main() guarded by if __name__ == "__main__": to avoid OOM when imported.

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 211 to 222, the script
performs heavy tensor allocations and prints the kernel source at import; wrap
this execution in a main() function and protect it with if __name__ ==
"__main__": so imports don't allocate memory. Move the kernel creation,
large-tensor allocations, and print(kernel.get_kernel_source()) into the new
main(), keep any helper/definition code at module scope, and call main() only
inside the if __name__ == "__main__": guard.

ref_program_processed = partial(ref_program, is_causal=is_causal)

profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print(f"Ref: {latency:.2f} ms")
print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=500)
print(f"Tile-lang: {latency:.2f} ms")
print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops")
else:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")


if __name__ == "__main__":
tilelang.disable_cache()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)