
Conversation


@LeiWang1999 LeiWang1999 commented Nov 3, 2025

This pull request introduces significant improvements and new features to the GEMM and MHA kernel implementations, with a particular focus on the TileLang-based CUDA kernels and the WGMMA macro generator. The most important changes include the addition of new latency benchmarking scripts for GEMM and MHA, major enhancements to the WGMMA macro generator to support more flexible instruction sizes, and expanded test coverage for GEMM correctness. These updates improve performance profiling, correctness validation, and hardware compatibility for advanced GPU architectures.

New benchmarking scripts

GEMM and MHA latency profiling:

  • Added a new script latency_gemm.py for profiling GEMM kernel latency, including correctness validation against PyTorch and support for both v1 and v2 GEMM kernels via a command-line flag.
  • Added a new script latency_mha_fwd_bhsd.py for profiling forward MHA kernels with autotuning, flexible configuration, and correctness checks against a PyTorch reference implementation.

WGMMA macro generator enhancements

Flexible instruction sizing and hardware compatibility:

  • Updated wgmma_macro_generator.py to support dynamic calculation of the WGMMA instruction N dimension (wgmma_inst_n), enforcing hardware constraints (a multiple of 8 in the range [8, 256]) and using the greatest common divisor for optimal sizing (see the sketch after this list).
  • Refactored macro generation to loop over the number of instruction N tiles (num_inst_n), updating address calculations and output offsets for both the wgmma and wgmma_rs methods to support tiling and improved parallelism.
  • Changed the warpgroup wait logic to be configurable via a new wg_wait argument, allowing for more flexible synchronization.
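
For intuition, the following is a minimal standalone sketch of the instruction-N selection described above; the helper name and the standalone form are illustrative, not the actual generator code.

from math import gcd

def select_wgmma_inst_n(warp_col_tiles: int) -> int:
    """Pick a hardware-legal WGMMA N size that evenly tiles the warp tile.

    sm_90 wgmma requires N to be a multiple of 8 in [8, 256]. Because
    gcd(warp_col_tiles, 256) divides warp_col_tiles by construction,
    num_inst_n = warp_col_tiles // inst_n is always a whole number.
    """
    inst_n = gcd(warp_col_tiles, 256)
    assert inst_n % 8 == 0 and 8 <= inst_n <= 256, f"illegal WGMMA N: {inst_n}"
    return inst_n

if __name__ == "__main__":
    for n in (64, 128, 192, 256, 384, 512):
        inst_n = select_wgmma_inst_n(n)
        print(f"warp_col_tiles={n:4d} -> inst_n={inst_n:3d}, num_inst_n={n // inst_n}")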

GEMM correctness evaluation

Expanded test coverage:

  • Increased the range of N values tested in GEMM correctness evaluation, improving coverage for larger matrix sizes.
  • Updated the test invocation comment to reflect the correct filename for GEMM correctness evaluation.

Minor improvements

Codebase maintenance:

  • Added missing import for gcd from the math library in wgmma_macro_generator.py.
  • Updated GEMM WGMMA lowering to propagate the new wg_wait argument.

Summary by CodeRabbit

  • New Features

    • Added latency profiling and end-to-end benchmarking for tiled matrix-multiply with ReLU.
    • Added a configurable, autotuned FlashAttention-like forward pass with causal masking and tuning options.
  • Tests

    • Expanded correctness tests to include larger matrix widths (256 and 512).
  • Refactor

    • Improved tensor-core tiling with per-warp-group synchronization and finer-grained instruction tiling.
    • Improved workspace lifetime management to better scope temporary buffers.


coderabbitai bot commented Nov 3, 2025

Walkthrough

Adds two new latency benchmarking scripts, expands GEMM test widths, and updates WGMMA generation and GEMM lowering to support per-instruction tiling and a per-warp-group wait parameter (wg_wait), plus workspace lifetime scoping changes in lowering.

Changes

  • Test Parameterization (maint/gemm_v2/correctness_evaluation.py): Swapped the initial pytest target to correctness_evaluation.py and expanded N_VALUES to include 256 and 512.
  • Latency & FlashAttention Scripts (maint/gemm_v2/latency_gemm.py, maint/gemm_v2/latency_mha_fwd_bhsd.py): Added latency_gemm.py (tiled GEMM + ReLU kernel, compile/test/profile flow) and latency_mha_fwd_bhsd.py (TileLang FlashAttention-like forward kernel with autotuning, CLI, and PyTorch reference).
  • WGMMA Intrinsics Generator (tilelang/intrinsics/wgmma_macro_generator.py): Compute per-instruction tiling (wgmma_inst_m, wgmma_inst_n) using gcd, enforce the multiple-of-8 constraint, restructure loops into num_inst_n/num_inst_m/ki ordering, propagate wg_wait, and update A/B/C offset calculations.
  • GEMM WGMMA Integration (tilelang/tileop/gemm/gemm_wgmma.py): Add a local wg_wait binding and pass wg_wait into the mma_emitter.wgmma() calls in both the S-S and R-S lowering paths.
  • Lowering Workspace Scoping (src/transform/lower_tile_op.cc): Replace the single workspace list with workspace_stack_ (a vector of frames), push/pop per-Block frames, record workspaces into the current frame, and attach frame buffers to block alloc_buffers on exit.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Host as Host Script
  participant Compiler as TileLang Compiler
  participant Kernel as GPU Kernel
  participant WGMMA as WGMMA Intrinsic

  Host->>Compiler: request compile (latency_gemm / flashattn)
  Compiler->>Kernel: emit kernel with wgmma_inst_n and wg_wait
  Kernel->>WGMMA: per-iteration: compute A/B/C offsets (inst_m/inst_n), call wgmma(..., wg_wait)
  alt wg_wait > 0
    WGMMA-->>WGMMA: perform warp-group wait/sync
  end
  WGMMA-->>Kernel: complete MMA ops
  Kernel-->>Host: return results and latency profile
  Note right of Kernel: workspace_stack_ frames scope alloc_buffers per Block

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Files requiring extra attention:
    • tilelang/intrinsics/wgmma_macro_generator.py — correctness of gcd-based tiling, offsets, and loop bounds across all code paths.
    • tilelang/tileop/gemm/gemm_wgmma.py — ensure wg_wait propagation is complete and consistent for all lowering variants.
    • src/transform/lower_tile_op.cc — verify workspace_stack_ push/pop semantics, fallback frame behavior, and no workspace leaks.
    • maint/gemm_v2/latency_mha_fwd_bhsd.py — validate kernel correctness vs. PyTorch reference and autotune configurations.

Suggested reviewers

  • chengyupku
  • xysmlx

Poem

🐰 I hopped through loops of warp and lane,
I counted inst_n with a gcd brain,
I nudged a gentle wg_wait to cue,
Tuned kernels, ReLU, and Flash in view,
A tiny rabbit cheering code anew.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name: Title check
Status: ⚠️ Warning
Explanation: The pull request title '[Langauge] Support n>256 for v2' is vague and does not clearly reflect the actual scope of changes. While the PR does include WGMMA macro generator enhancements that support larger N dimensions (by dynamically calculating wgmma_inst_n with hardware constraints), the title is misleading because it appears to promise N > 256 support when the implementation constrains inst_n to the range [8, 256]. Additionally, the title does not acknowledge the substantial changes including new benchmarking scripts (latency_gemm.py, latency_mha_fwd_bhsd.py), GEMM correctness evaluation updates, and workspace lifetime tracking improvements. There is also a typo: 'Langauge' should be 'Language'. The title only partially captures the main intent (improved WGMMA generator and benchmarking infrastructure) but oversimplifies and misrepresents the technical details.
Resolution: Revise the title to accurately reflect the full scope of the PR. A better title might be: '[Enhancement] WGMMA macro generator improvements with benchmarking utilities' or '[Feature] Add GEMM/MHA benchmarking scripts and improve WGMMA tiling'. Ensure the title reflects that the enhancement supports better parallelism via instruction-level tiling (within hardware constraints), not unbounded N > 256 support. Fix the typo 'Langauge' to 'Language' or remove it entirely.
✅ Passed checks (1 passed)
Check name: Description Check
Status: ✅ Passed
Explanation: Check skipped - CodeRabbit’s high-level summary is enabled.



github-actions bot commented Nov 3, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

🧹 Nitpick comments (1)
tilelang/intrinsics/wgmma_macro_generator.py (1)

291-296: Allow disabling wait via wg_wait = -1 explicitly

You handle wg_wait >= 0, but the docs mention -1 to skip waiting. Consider documenting/validating that value, or treat negatives uniformly.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ba39075 and 1d7a098.

📒 Files selected for processing (5)
  • maint/gemm_v2/correctness_evaluation.py (2 hunks)
  • maint/gemm_v2/latency_gemm.py (1 hunks)
  • maint/gemm_v2/latency_mha_fwd_bhsd.py (1 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (9 hunks)
  • tilelang/tileop/gemm/gemm_wgmma.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
maint/gemm_v2/latency_gemm.py (9)
tilelang/jit/__init__.py (1)
  • jit (233-306)
tilelang/language/allocate.py (2)
  • alloc_shared (27-42)
  • alloc_fragment (59-70)
tilelang/language/fill.py (1)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (11-87)
tilelang/language/gemm.py (2)
  • gemm_v2 (215-434)
  • gemm_v1 (10-211)
tilelang/language/parallel.py (1)
  • Parallel (9-29)
tilelang/jit/kernel.py (1)
  • get_profiler (367-383)
tilelang/utils/tensor.py (1)
  • TensorSupplyType (11-18)
maint/gemm_v2/latency_mha_fwd_bhsd.py (8)
tilelang/autotuner/tuner.py (1)
  • autotune (723-816)
tilelang/jit/__init__.py (1)
  • jit (233-306)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-144)
tilelang/language/copy.py (1)
  • copy (11-87)
tilelang/language/fill.py (2)
  • clear (24-48)
  • fill (9-21)
tilelang/language/gemm.py (2)
  • gemm_v2 (215-434)
  • gemm_v1 (10-211)
tilelang/language/allocate.py (2)
  • alloc_shared (27-42)
  • alloc_fragment (59-70)
tilelang/profiler/__init__.py (1)
  • assert_allclose (77-146)
tilelang/tileop/gemm/gemm_wgmma.py (2)
tilelang/tileop/gemm/gemm_base.py (2)
  • wg_wait (115-116)
  • clear_accum (107-108)
tilelang/intrinsics/wgmma_macro_generator.py (1)
  • wgmma (160-295)
tilelang/intrinsics/wgmma_macro_generator.py (4)
tilelang/tileop/gemm/gemm_base.py (3)
  • clear_accum (107-108)
  • wg_wait (115-116)
  • accum_dtype (59-60)
tilelang/utils/language.py (1)
  • is_fragment (81-91)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1065-1104)
  • ptx_wgmma_rs (1107-1142)
tilelang/language/builtin.py (2)
  • warpgroup_commit_batch (277-283)
  • warpgroup_wait (286-296)
🪛 Ruff (0.14.2)
maint/gemm_v2/latency_mha_fwd_bhsd.py

4-4: from tilelang.autotuner import * used; unable to detect undefined names

(F403)


29-29: autotune may be undefined, or defined from star imports

(F405)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (2)
maint/gemm_v2/correctness_evaluation.py (1)

386-388: Double-check block coverage for new N values

Adding 256 and 512 increases the N sweep; just confirm the block configuration (block_N = 128) still divides all tested shapes (it does). No code changes needed.

tilelang/tileop/gemm/gemm_wgmma.py (1)

90-122: Propagation of wg_wait looks correct

Glad to see the new wait parameter passed to both the SS and RS lowering paths. Nothing else required.

Comment on lines +46 to +57
if use_v2:
T.gemm_v2(A_shared, B_shared, C_local)
else:
T.gemm_v1(A_shared, B_shared, C_local)

# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)

# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])


⚠️ Potential issue | 🟠 Major

use_v2 captured at call-site breaks autotuned matmul reuse

use_v2 is read from the CLI once and closed over by the nested matmul_relu_kernel. The compiled kernel therefore hardcodes whatever flag was set during compilation. When you later call matmul again with a different use_v2 expectation (e.g., toggling between v1/v2 from the CLI or a tuning sweep), the already-compiled kernel silently keeps the previous path. This leads to confusing latency numbers and invalidates correctness checks when you intend to benchmark both variants in one process. Please pass the flag as an explicit parameter to matmul and thread it through the kernel invocation rather than capturing a global.
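
As a rough illustration of the suggested fix (plain Python, no TileLang; build_matmul_kernel and its body are hypothetical), making the flag part of the factory signature gives each value its own cached kernel:

from functools import lru_cache

@lru_cache(maxsize=None)
def build_matmul_kernel(M: int, N: int, K: int, use_v2: bool):
    # The flag is part of the cache key, so v1 and v2 get distinct compiled
    # kernels instead of whichever path happened to be captured at compile time.
    variant = "gemm_v2" if use_v2 else "gemm_v1"

    def kernel(a, b, c):
        # Placeholder body; the real script would emit a TileLang kernel here
        # and branch on use_v2 when building it.
        print(f"running {variant} for a {M}x{N}x{K} GEMM")

    return kernel

if __name__ == "__main__":
    k1 = build_matmul_kernel(1024, 1024, 1024, use_v2=False)
    k2 = build_matmul_kernel(1024, 1024, 1024, use_v2=True)
    assert k1 is not k2  # separate cache entries per flag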

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_gemm.py around lines 46-57, the nested
matmul_relu_kernel currently closes over the CLI variable use_v2 which hardcodes
the chosen path into the compiled kernel; instead make use_v2 an explicit
argument to matmul and add it as a parameter to the kernel signature so the
kernel reads the flag at call time. Update all invocations to pass the boolean
through, modify the kernel body to branch on the kernel parameter (not a
closed-over variable), and ensure any autotuning/compilation cache keys include
this parameter so separate v1/v2 kernels are compiled and reused correctly.

Comment on lines +75 to +88
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
matmul_relu_kernel(a, b, c)

print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

⚠️ Potential issue | 🟠 Major

Avoid allocating 16k×16k fp16 tensors on import

Instantiating three 16384×16384 fp16 tensors at module import consumes ~512 MiB per tensor (~1.5 GiB total) and triggers GPU allocation before __main__ runs. This makes latency_gemm.py unusable as a library and can crash on machines without that much free GPU memory. Move the heavy allocations under if __name__ == "__main__": (or into a main function) so importing the module for reuse or running unit tests doesn’t instantly OOM the GPU.
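
A minimal sketch of the suggested restructuring (shapes follow the script; matmul_relu_kernel is assumed to be the compiled kernel defined earlier in the file):

import torch

def main(M: int = 16384, N: int = 16384, K: int = 16384) -> None:
    # Heavy GPU allocations now happen only when the script is run directly.
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
    c = torch.empty(M, N, device="cuda", dtype=torch.float16)

    matmul_relu_kernel(a, b, c)    # compiled TileLang kernel from above
    ref_c = torch.relu(a @ b)      # PyTorch reference
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")

if __name__ == "__main__":
    main()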

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_gemm.py around lines 75–88 the script allocates large
fp16 CUDA tensors (a, b, c) and runs the kernel at module import time which can
OOM GPUs and prevents safe import; move the heavy allocations, kernel call,
reference computation and assertion inside an if __name__ == "__main__": block
(or a main() function) so importing the module does not perform GPU allocations
or execute the test — i.e., wrap lines that create a, b, c, call
matmul_relu_kernel, print, compute ref_c, and assert_close into a guarded main
section and leave only function/class definitions at top-level.

Comment on lines +95 to +99
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")

⚠️ Potential issue | 🟠 Major

Profiler benchmark should benchmark the TileLang kernel, not the reference

profiler.do_bench(ref_program_processed, warmup=500) times the PyTorch reference instead of the compiled kernel. That double-counts the reference execution and skews the reported latency. Drop the reference callable when benchmarking the kernel; reserve it for assert_allclose.
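
A drop-in sketch of the suggested separation, reusing the profiler calls already shown in this PR's scripts (exact keyword arguments may differ):

# Correctness: compare the compiled kernel against the PyTorch reference.
profiler = matmul_relu_kernel.get_profiler(
    tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=1e-2, atol=1e-2)

# Latency: time only the compiled TileLang kernel (no reference callable here).
latency = profiler.do_bench(warmup=500)
print(f"TileLang latency: {latency:.3f} ms")

# If a reference number is also wanted, benchmark it separately and label it.
ref_latency = profiler.do_bench(ref_program_processed, warmup=500)
print(f"PyTorch reference latency: {ref_latency:.3f} ms")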

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_gemm.py around lines 95 to 99, the current
profiler.do_bench call is timing the PyTorch reference (ref_program_processed)
instead of the compiled TileLang kernel; remove the reference callable from the
do_bench invocation so the profiler only benchmarks the compiled kernel (e.g.,
call profiler.do_bench(warmup=500) or equivalent), and keep the reference
callable solely for correctness checks like assert_allclose after benchmarking.

Comment on lines +24 to +27
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


⚠️ Potential issue | 🟠 Major

get_configs builds configs incorrectly

dict(zip(iter_params, values)) zips dictionary keys with value tuples, leaving only the last key/value pair. You end up with configs like {'threads': 256} and silently drop block_M, block_N, num_stages. Use itertools.product(*iter_params.values()) with proper key association (e.g., {k: v for k, v in zip(iter_params.keys(), values)}) to retain all knobs.

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 24–27, get_configs
currently uses dict(zip(iter_params, values)) which zips dictionary keys (not
key list) with values and ends up keeping only the last pair; replace that
construction with a proper key-to-value association such as
dict(zip(iter_params.keys(), values)) or a dict comprehension {k: v for k, v in
zip(iter_params.keys(), values)} so every knob (block_M, block_N, num_stages,
threads) is preserved for each product combination.

Comment on lines 74 to 78
if use_v2:
T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)


⚠️ Potential issue | 🟠 Major

Causal mask should leave untouched scores when non-causal

When is_causal is false you call T.clear(acc_s) but skip it otherwise. In the causal path the previous tile’s data remains; T.clear should be executed regardless, with the causal logic overwriting masked entries afterward.

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 74-78, the current code
only clears acc_s for the causal path which leaves previous tile data in the
non-causal path; always call T.clear(acc_s) before performing the GEMM (move the
clear before the use_v2 conditional or add it in both branches), then keep the
causal masking logic as-is to overwrite masked entries after the GEMM so
non-causal scores remain untouched.

Comment on lines +115 to +125
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)

for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)

⚠️ Potential issue | 🟠 Major

Log-sum-exp accumulation can underflow to zero

scores_scale multiplies logsum before the first tile has written anything. If logsum[i] stays at 0 for early iterations, the subsequent division acc_o[i, j] /= logsum[i] risks division by zero. Initialize logsum to a tiny epsilon (or branch to skip scaling when the previous sum is zero).
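
A plain-PyTorch sketch of the kind of guard being suggested (not the TileLang code itself; the normalize helper and the epsilon value are illustrative):

import torch

def normalize(acc_o: torch.Tensor, logsum: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Rows whose softmax normalizer collapsed to zero (e.g. fully masked causal
    # rows) come back as zeros instead of NaN/inf from a divide-by-zero.
    safe = torch.clamp(logsum, min=eps)
    out = acc_o / safe.unsqueeze(-1)
    return torch.where(logsum.unsqueeze(-1) > 0, out, torch.zeros_like(out))

if __name__ == "__main__":
    acc_o = torch.ones(2, 4)
    logsum = torch.tensor([3.0, 0.0])   # second row had no valid keys
    print(normalize(acc_o, logsum))     # first row 1/3, second row all zeros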

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 115 to 125, logsum is
multiplied by scores_scale before any tile may have written to it, allowing
logsum to remain zero and later cause division-by-zero when acc_o is divided by
logsum; fix by initializing logsum to a small positive epsilon (e.g. 1e-12)
before the loop or add a conditional/branch that skips scaling and division when
the previous logsum is zero, ensuring any updates use max(logsum, epsilon) or
guard the divide with a check to avoid division by zero.

Comment on lines 210 to 233
if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128)
print(kernel.get_kernel_source())
ref_program_processed = partial(ref_program, is_causal=is_causal)

profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))

⚠️ Potential issue | 🟠 Major

Benchmark still times PyTorch reference

Same as GEMM: calling profiler.do_bench(ref_program_processed, ...) benchmarks the reference. Remove the argument so you’re timing the TileLang kernel.

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 210-233, the call
profiler.do_bench(ref_program_processed, warmup=500) is benchmarking the PyTorch
reference instead of the TileLang kernel; remove the ref_program_processed
argument so profiler.do_bench(warmup=500) runs the TileLang kernel, keep the
reference run only where you intend to benchmark the reference (or rename the
printed labels if you swap the order), and ensure assert_allclose still uses
ref_program_processed.

Comment on lines +211 to +222
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128)
print(kernel.get_kernel_source())

⚠️ Potential issue | 🟠 Major

Heavy allocations executed on import

Like the GEMM script, this file allocates large tensors at module import. Wrap the execution logic in a main() guarded by if __name__ == "__main__": to avoid OOM when imported.

🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 211 to 222, the script
performs heavy tensor allocations and prints the kernel source at import; wrap
this execution in a main() function and protect it with if __name__ ==
"__main__": so imports don't allocate memory. Move the kernel creation,
large-tensor allocations, and print(kernel.get_kernel_source()) into the new
main(), keep any helper/definition code at module scope, and call main() only
inside the if __name__ == "__main__": guard.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (4)
maint/gemm_v2/latency_mha_fwd_bhsd.py (4)

25-28: Single configuration defeats autotuning purpose.

The get_configs implementation is correct (the past review comment incorrectly identified a bug), but each parameter list contains only one value, yielding exactly one configuration: {'block_M': 128, 'block_N': 128, 'num_stages': 2, 'threads': 256}. This defeats the purpose of autotuning—consider expanding the search space with multiple values per parameter.

For example:

 def get_configs():
-    iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
+    iter_params = dict(
+        block_M=[64, 128],
+        block_N=[64, 128],
+        num_stages=[0, 2],
+        threads=[128, 256]
+    )
     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
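
For reference, a quick standalone check of what the expanded grid above would produce (parameter values mirror the suggestion):

import itertools

iter_params = dict(block_M=[64, 128], block_N=[64, 128],
                   num_stages=[0, 2], threads=[128, 256])
# Iterating a dict yields its keys, so dict(zip(iter_params, values)) pairs
# every knob with the corresponding value from the product.
configs = [dict(zip(iter_params, values))
           for values in itertools.product(*iter_params.values())]
print(len(configs))   # 16 candidate configurations
print(configs[0])     # {'block_M': 64, 'block_N': 64, 'num_stages': 0, 'threads': 128}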

54-77: LGTM: Causal masking correctly initializes before GEMM.

The causal and non-causal paths both correctly initialize acc_s before the GEMM operation:

  • Causal: sets entries to 0 (valid positions) or -inf (masked positions) so GEMM can accumulate on the initialized values
  • Non-causal: clears to 0 for clean accumulation

This matches the FlashAttention pattern. The past review comment suggesting both paths should clear first was incorrect.


195-246: LGTM: Benchmarking and entry point are correctly implemented.

The main function properly:

  • Calculates theoretical FLOPS (including 0.5x adjustment for causal attention)
  • Validates correctness against PyTorch with assert_allclose
  • Benchmarks both reference (line 227) and TileLang (line 230) separately with correct labels
  • Supports both manual and autotuned modes
  • Is properly guarded by if __name__ == "__main__": to prevent heavy allocations on import

The past review comments claiming lines 227 and 210-221 were incorrect have been resolved.


95-124: Potential division by zero in causal attention edge case.

The online softmax correctly updates logsum, but in causal attention when a query position has no valid keys (all masked to -inf), the following sequence occurs:

  1. All acc_s[i, :] are -inf
  2. scores_max[i] = -inf
  3. After exp2, all scores become 0
  4. scores_sum[i] = 0
  5. logsum[i] = prev * scores_scale + 0 eventually becomes 0
  6. Line 173 then divides by zero: acc_o[i, j] /= logsum[i]

This edge case can occur in causal attention when past_len + bx * block_M positions have no valid keys to attend to.

The commented-out Check_inf logic (lines 108-112) hints at this issue. Consider adding a guard:

         for i, j in T.Parallel(block_M, dim):
-            acc_o[i, j] /= logsum[i]
+            acc_o[i, j] = T.if_then_else(
+                logsum[i] > 0,
+                acc_o[i, j] / logsum[i],
+                0  # or keep acc_o[i, j] unchanged
+            )

Or initialize logsum to a small epsilon rather than zero (line 157).

🧹 Nitpick comments (1)
maint/gemm_v2/latency_mha_fwd_bhsd.py (1)

4-4: Consider replacing star import for clarity.

The star import from tilelang.autotuner makes it unclear which names are imported and triggers static analysis warnings (F403, F405). Since only autotune is used, consider an explicit import.

Apply this diff:

-from tilelang.autotuner import *
+from tilelang.autotuner import autotune
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1d7a098 and 050e3fe.

📒 Files selected for processing (1)
  • maint/gemm_v2/latency_mha_fwd_bhsd.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
maint/gemm_v2/latency_mha_fwd_bhsd.py (6)
tilelang/autotuner/tuner.py (1)
  • autotune (723-816)
tilelang/jit/__init__.py (1)
  • jit (233-306)
tilelang/language/fill.py (2)
  • clear (24-48)
  • fill (9-21)
tilelang/language/gemm.py (2)
  • gemm_v2 (215-434)
  • gemm_v1 (10-211)
tilelang/language/reduce.py (2)
  • reduce_max (50-68)
  • reduce_sum (87-109)
tilelang/language/allocate.py (2)
  • alloc_shared (27-42)
  • alloc_fragment (59-70)
🪛 Ruff (0.14.2)
maint/gemm_v2/latency_mha_fwd_bhsd.py

4-4: from tilelang.autotuner import * used; unable to detect undefined names

(F403)


30-30: autotune may be undefined, or defined from star imports

(F405)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (3)
maint/gemm_v2/latency_mha_fwd_bhsd.py (3)

30-52: LGTM: Decorator stack and setup are well-configured.

The autotune and JIT decorators are correctly stacked, the output index is properly specified, and the scale factor cleverly combines 1/sqrt(dim) with log2(e) for the exp2 optimization used later in the Softmax macro.
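
A quick numerical check of that rewrite (standalone PyTorch; the dim and sample values are illustrative):

import math
import torch

# exp(x / sqrt(dim)) == exp2(x * scale) when scale folds 1/sqrt(dim) and
# log2(e) into a single constant, letting the compiler fuse the multiply
# into the exp2 evaluation.
dim = 64
scale = (1.0 / math.sqrt(dim)) * math.log2(math.e)

x = torch.randn(1000, dtype=torch.float64) * 10.0   # raw q·k dot products
ref = torch.exp(x / math.sqrt(dim))
opt = torch.exp2(x * scale)
torch.testing.assert_close(opt, ref)
print("exp2 rewrite matches exp")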


78-93: LGTM: MMA1, Rescale, and main kernel correctly implement FlashAttention.

The kernel orchestration properly:

  • Allocates fragment and shared memory buffers
  • Implements tiled online softmax with rescaling
  • Computes the causal loop range correctly: min(ceildiv(seq_kv, block_N), ceildiv((bx+1)*block_M + past_len, block_N))
  • Pipelines the operations for efficiency
  • Stores results through shared memory

The only concern is the division by logsum at line 173, already flagged in a previous comment.

Also applies to: 126-177


180-192: LGTM: Reference implementation is clear and correct.

The PyTorch reference correctly implements attention with optional causal masking. The torch.tril diagonal offset of seq_kv - seq_q properly handles the past_len case, matching the TileLang kernel's causal logic.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)

260-268: Address unused unpacked variable tx.

The variable tx (lane ID) is unpacked at line 268 but never used in the subsequent code. This is flagged by static analysis (Ruff).

Apply this diff to prefix the unused variable with an underscore:

-            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            _tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)

Based on coding guidelines.


366-374: Address unused unpacked variables tx and warp_m.

At line 374, both tx (lane ID) and warp_m are unpacked but never used in the wgmma_rs implementation. This is flagged by static analysis (Ruff).

Apply this diff to prefix the unused variables with underscores:

-            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+            _tx, warp_n, _warp_m = self.extract_thread_binding(thread_binding)

Based on coding guidelines.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 050e3fe and 9b8b6b4.

📒 Files selected for processing (2)
  • src/transform/lower_tile_op.cc (5 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/wgmma_macro_generator.py (6)
tilelang/tileop/gemm/gemm_base.py (2)
  • wg_wait (115-116)
  • accum_dtype (59-60)
tilelang/intrinsics/mma_macro_generator.py (2)
  • get_thread_binding (158-164)
  • extract_thread_binding (174-204)
tilelang/intrinsics/mfma_macro_generator.py (2)
  • get_thread_binding (217-223)
  • extract_thread_binding (225-252)
tilelang/language/kernel.py (2)
  • get_thread_binding (171-176)
  • get_thread_binding (306-310)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_ss (1065-1104)
  • ptx_wgmma_rs (1107-1142)
tilelang/language/builtin.py (4)
  • warpgroup_commit_batch (277-283)
  • warpgroup_wait (286-296)
  • initialize_wgmma_descriptor (601-628)
  • warpgroup_fence_operand (433-490)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py

268-268: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


374-374: Unpacked variable tx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


374-374: Unpacked variable warp_m is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (11)
src/transform/lower_tile_op.cc (5)

13-13: LGTM: Include added for workspace stack.

The #include <vector> is correctly added to support the new std::vector<Array<Buffer>> workspace_stack_ member.


305-306: LGTM: Workspace frame initialization on block entry.

The frame is correctly pushed when entering a block scope, establishing the collection point for workspace buffers created during block processing.


316-322: LGTM: Workspace frame cleanup on block exit.

The workspaces are correctly attached to the block's alloc_buffers and the frame is popped, balancing the push at line 306. The empty check is defensive and ensures robustness.


726-727: LGTM: Workspace stack member declaration.

The workspace_stack_ member is correctly declared as std::vector<Array<Buffer>> with clear documentation explaining its purpose for per-block workspace tracking.


667-680: Verify when the fallback case can occur and document or remove it.

The execution context shows that the callback is passed to tile_op->Lower() (line 696) and should be invoked while a workspace frame is active. Specifically, the frame is pushed at BlockNode::VisitStmt_ line 306 before children are visited, and popped at line 321 after visiting completes. During normal execution within this block scope, workspace_stack_ should not be empty when the callback executes.

However, if the fallback is triggered:

  • A frame is created via emplace_back() at line 677
  • This frame is never popped (the pop at line 321 only occurs in the BlockNode handler)
  • The workspace buffer added to this orphaned frame is never attached to any block's alloc_buffers

This suggests the callback is being invoked from an unexpected execution path outside the normal block context. Either:

  1. Remove the fallback and add ICHECK(!workspace_stack_.empty()) to catch and diagnose the unexpected case, or
  2. Document precisely when this fallback is legitimately triggered and how workspace lifetime is managed in that scenario.
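
To make the push/pop semantics concrete, here is a toy Python model of the frame stack described above (illustrative only; the real implementation is the C++ in lower_tile_op.cc):

class WorkspaceStack:
    def __init__(self):
        self.frames = []          # one list of workspace buffers per enclosing Block

    def enter_block(self):
        self.frames.append([])    # push a frame when a Block is visited

    def add_workspace(self, buf):
        if not self.frames:
            # Fallback path flagged above: this orphaned frame is never popped
            # and its buffer is never attached to any Block's alloc_buffers.
            self.frames.append([])
        self.frames[-1].append(buf)

    def exit_block(self, block_alloc_buffers):
        block_alloc_buffers.extend(self.frames.pop())   # attach and pop on exit

if __name__ == "__main__":
    ws, allocs = WorkspaceStack(), []
    ws.enter_block()
    ws.add_workspace("reduce_workspace")
    ws.exit_block(allocs)
    print(allocs)                 # ['reduce_workspace']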
tilelang/intrinsics/wgmma_macro_generator.py (6)

9-9: LGTM: gcd import for dynamic inst_n calculation.

The gcd import from the math module is correctly added to support the new dynamic calculation of wgmma_inst_n.


74-77: LGTM: New WGMMA instruction dimension attributes.

The wgmma_inst_m and wgmma_inst_n attributes are well-documented and store the per-instruction M and N dimensions for tiling calculations.


112-125: LGTM: Dynamic inst_n calculation addresses previous concerns.

The use of gcd(self.warp_col_tiles, 256) for inst_n is mathematically sound:

  1. By definition, gcd(a, b) divides both a and b, so inst_n divides warp_col_tiles
  2. This ensures num_inst_n = warp_col_tiles // wgmma_inst_n is always ≥ 1 (since warp_col_tiles ≥ 8 per line 132)
  3. The assertions correctly validate that inst_n is a multiple of 8 and within [8, 256]

This design addresses the previous review concern about num_inst_n = 0 causing division-by-zero errors.


167-168: LGTM: wg_wait parameter added for configurable synchronization.

The wg_wait parameter with default value 0 enables configurable warpgroup wait synchronization. The signature is consistent with wgmma_rs (lines 311-312).


384-416: wgmma_rs offset calculations are correctly structured and consistent with wgmma method.

The restructured wgmma_rs loop implementation is sound:

  • A_offset (line 389): Correctly handles register tiling with ki * warp_rows * local_size_a + i * local_size_a, properly advancing through K dimension chunks and warp instances
  • B_offset (lines 390-394): Consistent with wgmma using swizzle-aware descriptor offsets
  • C_offset (line 395): Safe division by num_inst_n matches the 4-warp unit organization
  • Fence operations (lines 395, 417): Properly bracket the warpgroup computation
  • wg_wait (lines 414-415): Correctly applied with conditional check

All variables (local_size_a, warp_rows, warp_cols, local_size_out) are properly extracted from self and used consistently.


280-303: Offset calculations require manual verification

Based on my verification research, I cannot find definitive validation or test cases that explicitly verify the offset calculation formulas shown. To make a conclusive assessment, I would need to examine the full wgmma_macro_generator.py file context—specifically the definitions of warp_m, warp_n, num_inst_m, and num_inst_n, as well as any validation or unit tests for these calculations.

However, my research confirms:

  1. A warpgroup consists of four contiguous warps (128 threads), and the wgmma.mma_async instruction is executed collectively by all 128 threads in a warpgroup
  2. The layout-dependent offset calculations (K-major vs MN-major) align with the canonical GMMA layout atoms and tile_to_shape approach in CUTLASS for accurately calculating offsets and swizzling modes
  3. The test suite includes integration tests for WGMMA-based kernels (FlashAttention, attention sink) that would catch offset errors

The review comment cannot be fully resolved without access to:

  • Full function signatures defining valid ranges for warp_m, warp_n, num_inst_m, num_inst_n
  • Whether the (warp_m // 4) formula and the asymmetrical warp_j = warp_n * num_inst_n + j are both intentional or if one requires correction
  • Unit or integration tests that validate these specific offset calculations

Verify the restructured loop and offset calculations. The changes significantly alter the iteration pattern and memory access paths for WGMMA operands:

  1. Loop order changed to N-outer (j), M-middle (i), K-inner (ki)—verify cache efficiency vs previous implementation
  2. Warp index calculations show asymmetry: warp_i = (warp_m // 4) * num_inst_m + i vs warp_j = warp_n * num_inst_n + j—confirm both formulas match your hardware tiling strategy
  3. Offset calculations for A and B depend on memory layout (K-major vs MN-major)—ensure swizzle atom sizes and ak_atom_size/bk_atom_size are compatible with the instruction shapes

@LeiWang1999 LeiWang1999 merged commit b66a93c into tile-ai:main Nov 5, 2025
5 of 6 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* fix

* lint fix

* fix

* lint fix

* fix

* upd

* support n>256

* Remove unnecessary pass configurations for fast math in MHA forward BHSD latency script.

* lint fix

* lint fix
