[Langauge] Support n>256 for v2 #1182
Conversation
Walkthrough

Adds two new latency benchmarking scripts, expands GEMM test widths, and updates WGMMA generation and GEMM lowering to support per-instruction tiling and a per-warp-group wait parameter (wg_wait).

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Host as Host Script
    participant Compiler as TileLang Compiler
    participant Kernel as GPU Kernel
    participant WGMMA as WGMMA Intrinsic
    Host->>Compiler: request compile (latency_gemm / flashattn)
    Compiler->>Kernel: emit kernel with wgmma_inst_n and wg_wait
    Kernel->>WGMMA: per-iteration: compute A/B/C offsets (inst_m/inst_n), call wgmma(..., wg_wait)
    alt wg_wait > 0
        WGMMA-->>WGMMA: perform warp-group wait/sync
    end
    WGMMA-->>Kernel: complete MMA ops
    Kernel-->>Host: return results and latency profile
    Note right of Kernel: workspace_stack_ frames scope alloc_buffers per Block
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (1 passed)
✨ Finishing touches
Actionable comments posted: 10
🧹 Nitpick comments (1)
tilelang/intrinsics/wgmma_macro_generator.py (1)
291-296: Allow disabling wait via `wg_wait = -1` explicitly

You handle `wg_wait >= 0`, but the docs mention `-1` to skip waiting. Consider documenting/validating that value, or treat negatives uniformly.
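For illustration, a minimal sketch of the validation this nitpick suggests; the helper name is hypothetical and the "-1 means skip" convention is taken from the comment above, not from the PR code.

```python
def emit_warpgroup_wait(wg_wait: int):
    """Hypothetical helper: decide whether to emit a warpgroup wait."""
    if wg_wait >= 0:
        return f"warpgroup.wait {wg_wait}"   # emit an explicit wait
    if wg_wait == -1:
        return None                          # explicitly skip waiting
    raise ValueError(f"wg_wait must be >= -1, got {wg_wait}")

print(emit_warpgroup_wait(0))    # warpgroup.wait 0
print(emit_warpgroup_wait(-1))   # None (no wait emitted)
```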
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- maint/gemm_v2/correctness_evaluation.py (2 hunks)
- maint/gemm_v2/latency_gemm.py (1 hunks)
- maint/gemm_v2/latency_mha_fwd_bhsd.py (1 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (9 hunks)
- tilelang/tileop/gemm/gemm_wgmma.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
maint/gemm_v2/latency_gemm.py (9)
- tilelang/jit/__init__.py (1): jit (233-306)
- tilelang/language/allocate.py (2): alloc_shared (27-42), alloc_fragment (59-70)
- tilelang/language/fill.py (1): clear (24-48)
- tilelang/language/pipeline.py (1): Pipelined (9-46)
- tilelang/language/copy.py (1): copy (11-87)
- tilelang/language/gemm.py (2): gemm_v2 (215-434), gemm_v1 (10-211)
- tilelang/language/parallel.py (1): Parallel (9-29)
- tilelang/jit/kernel.py (1): get_profiler (367-383)
- tilelang/utils/tensor.py (1): TensorSupplyType (11-18)
maint/gemm_v2/latency_mha_fwd_bhsd.py (8)
- tilelang/autotuner/tuner.py (1): autotune (723-816)
- tilelang/jit/__init__.py (1): jit (233-306)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-144)
- tilelang/language/copy.py (1): copy (11-87)
- tilelang/language/fill.py (2): clear (24-48), fill (9-21)
- tilelang/language/gemm.py (2): gemm_v2 (215-434), gemm_v1 (10-211)
- tilelang/language/allocate.py (2): alloc_shared (27-42), alloc_fragment (59-70)
- tilelang/profiler/__init__.py (1): assert_allclose (77-146)
tilelang/tileop/gemm/gemm_wgmma.py (2)
- tilelang/tileop/gemm/gemm_base.py (2): wg_wait (115-116), clear_accum (107-108)
- tilelang/intrinsics/wgmma_macro_generator.py (1): wgmma (160-295)
tilelang/intrinsics/wgmma_macro_generator.py (4)
- tilelang/tileop/gemm/gemm_base.py (3): clear_accum (107-108), wg_wait (115-116), accum_dtype (59-60)
- tilelang/utils/language.py (1): is_fragment (81-91)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1065-1104), ptx_wgmma_rs (1107-1142)
- tilelang/language/builtin.py (2): warpgroup_commit_batch (277-283), warpgroup_wait (286-296)
🪛 Ruff (0.14.2)
maint/gemm_v2/latency_mha_fwd_bhsd.py
4-4: from tilelang.autotuner import * used; unable to detect undefined names
(F403)
29-29: autotune may be undefined, or defined from star imports
(F405)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (2)
maint/gemm_v2/correctness_evaluation.py (1)
386-388: Double-check block coverage for new N values

Adding `256` and `512` increases the N sweep; just confirm the block configuration (`block_N = 128`) still divides all tested shapes (it does). No code changes needed.

tilelang/tileop/gemm/gemm_wgmma.py (1)
90-122: Propagation of `wg_wait` looks correct

Glad to see the new wait parameter issued to both SS and RS kernels. Nothing else required.
```python
if use_v2:
    T.gemm_v2(A_shared, B_shared, C_local)
else:
    T.gemm_v1(A_shared, B_shared, C_local)

# relu
for i, j in T.Parallel(block_M, block_N):
    C_local[i, j] = T.max(C_local[i, j], 0)

# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
```
use_v2 captured at call-site breaks autotuned matmul reuse
use_v2 is read from the CLI once and closed over by the nested matmul_relu_kernel. The compiled kernel therefore hardcodes whatever flag was set during compilation. When you later call matmul again with a different use_v2 expectation (e.g., toggling between v1/v2 from the CLI or a tuning sweep), the already-compiled kernel silently keeps the previous path. This leads to confusing latency numbers and invalidates correctness checks when you intend to benchmark both variants in one process. Please pass the flag as an explicit parameter to matmul and thread it through the kernel invocation rather than capturing a global.
🤖 Prompt for AI Agents
In maint/gemm_v2/latency_gemm.py around lines 46-57, the nested
matmul_relu_kernel currently closes over the CLI variable use_v2 which hardcodes
the chosen path into the compiled kernel; instead make use_v2 an explicit
argument to matmul and add it as a parameter to the kernel signature so the
kernel reads the flag at call time. Update all invocations to pass the boolean
through, modify the kernel body to branch on the kernel parameter (not a
closed-over variable), and ensure any autotuning/compilation cache keys include
this parameter so separate v1/v2 kernels are compiled and reused correctly.
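A minimal sketch of the suggested restructuring, assuming simplified stand-ins for the script's `matmul` factory and `matmul_relu_kernel` (the TileLang kernel body is elided): `use_v2` becomes an explicit factory argument and part of the compilation cache key rather than a module-level CLI variable captured by the closure.

```python
import functools

@functools.lru_cache(maxsize=None)          # cache keyed by (M, N, K, use_v2)
def matmul(M: int, N: int, K: int, use_v2: bool):
    def matmul_relu_kernel(a, b, c):
        # Branch on the parameter passed to the factory, not on a global flag.
        if use_v2:
            ...  # emit the T.gemm_v2 path here
        else:
            ...  # emit the T.gemm_v1 path here
    return matmul_relu_kernel

kernel_v1 = matmul(16384, 16384, 16384, use_v2=False)
kernel_v2 = matmul(16384, 16384, 16384, use_v2=True)   # compiled and cached separately
```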
```python
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
matmul_relu_kernel(a, b, c)

print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
```
Avoid allocating 16k×16k fp16 tensors on import
Instantiating 16 384² tensors at module import consumes ~1 GiB per tensor and triggers GPU allocation before __main__. This makes latency_gemm.py unusable as a library and can crash on machines without that much free memory. Move the heavy allocations under if __name__ == "__main__": (or into a main function) so importing the module for reuse or running unit tests doesn’t instantly OOM the GPU.
🤖 Prompt for AI Agents
In maint/gemm_v2/latency_gemm.py around lines 75–88 the script allocates large
fp16 CUDA tensors (a, b, c) and runs the kernel at module import time which can
OOM GPUs and prevents safe import; move the heavy allocations, kernel call,
reference computation and assertion inside an if __name__ == "__main__": block
(or a main() function) so importing the module does not perform GPU allocations
or execute the test — i.e., wrap lines that create a, b, c, call
matmul_relu_kernel, print, compute ref_c, and assert_close into a guarded main
section and leave only function/class definitions at top-level.
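A minimal sketch of the guarded layout, assuming `M`, `N`, `K`, and `matmul_relu_kernel` remain defined at module scope as in the script:

```python
import torch

def main():
    # Heavy GPU allocations and the correctness check only run when the script
    # is executed directly, never on import.
    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
    c = torch.empty(M, N, device="cuda", dtype=torch.float16)

    matmul_relu_kernel(a, b, c)
    ref_c = torch.relu(a @ b)
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")

if __name__ == "__main__":
    main()
```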
```python
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

latency = profiler.do_bench()

print(f"Latency: {latency} ms")
```
Profiler benchmark should benchmark the TileLang kernel, not the reference
profiler.do_bench(ref_program_processed, warmup=500) times the PyTorch reference instead of the compiled kernel. That double-counts the reference execution and skews the reported latency. Drop the reference callable when benchmarking the kernel; reserve it for assert_allclose.
🤖 Prompt for AI Agents
In maint/gemm_v2/latency_gemm.py around lines 95 to 99, the current
profiler.do_bench call is timing the PyTorch reference (ref_program_processed)
instead of the compiled TileLang kernel; remove the reference callable from the
do_bench invocation so the profiler only benchmarks the compiled kernel (e.g.,
call profiler.do_bench(warmup=500) or equivalent), and keep the reference
callable solely for correctness checks like assert_allclose after benchmarking.
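A short sketch of the suggested split, reusing names and calls that appear elsewhere in this review (this is illustrative, not the script's exact code): the reference callable is used only for the correctness check, and the bare `do_bench` call times the compiled kernel.

```python
profiler = matmul_relu_kernel.get_profiler(
    tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=1e-2, atol=1e-2)  # correctness only

latency = profiler.do_bench(warmup=500)   # no callable: times the TileLang kernel
print(f"Tile-lang latency: {latency} ms")
```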
```python
def get_configs():
    iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
```
get_configs builds configs incorrectly
dict(zip(iter_params, values)) zips dictionary keys with value tuples, leaving only the last key/value pair. You end up with configs like {'threads': 256} and silently drop block_M, block_N, num_stages. Use itertools.product(*iter_params.values()) with proper key association (e.g., {k: v for k, v in zip(iter_params.keys(), values)}) to retain all knobs.
🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 24–27, get_configs
currently uses dict(zip(iter_params, values)) which zips dictionary keys (not
key list) with values and ends up keeping only the last pair; replace that
construction with a proper key-to-value association such as
dict(zip(iter_params.keys(), values)) or a dict comprehension {k: v for k, v in
zip(iter_params.keys(), values)} so every knob (block_M, block_N, num_stages,
threads) is preserved for each product combination.
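For illustration, a self-contained version of the explicit key association suggested above (the parameter values here are arbitrary examples, not the script's); note that a later comment in this thread observes that iterating a dict already yields its keys, so both spellings produce the same configs.

```python
import itertools

def get_configs():
    # Example search space; the script's actual lists each hold a single value.
    iter_params = dict(block_M=[64, 128], block_N=[128], num_stages=[2], threads=[256])
    return [
        dict(zip(iter_params.keys(), values))
        for values in itertools.product(*iter_params.values())
    ]

print(get_configs())
# [{'block_M': 64, 'block_N': 128, 'num_stages': 2, 'threads': 256},
#  {'block_M': 128, 'block_N': 128, 'num_stages': 2, 'threads': 256}]
```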
```python
if use_v2:
    T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
else:
    T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
```
Causal mask should leave untouched scores when non-causal
When is_causal is false you T.clear(acc_s) but skip it otherwise. For the non-causal path the previous tile’s data remains; T.clear should be executed regardless, with causal logic overwriting masked entries afterward.
🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 74-78, the current code
only clears acc_s for the causal path which leaves previous tile data in the
non-causal path; always call T.clear(acc_s) before performing the GEMM (move the
clear before the use_v2 conditional or add it in both branches), then keep the
causal masking logic as-is to overwrite masked entries after the GEMM so
non-causal scores remain untouched.
```python
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)

for i, j in T.Parallel(block_M, block_N):
    # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
    # max * log_2(e)) This allows the compiler to use the ffma
    # instruction instead of fadd and fmul separately.
    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
```
Log-sum-exp accumulation can underflow to zero
scores_scale multiplies logsum before the first tile has written anything. If logsum[i] stays at 0 for early iterations, the subsequent division acc_o[i, j] /= logsum[i] risks division by zero. Initialize logsum to a tiny epsilon (or branch to skip scaling when the previous sum is zero).
🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 115 to 125, logsum is
multiplied by scores_scale before any tile may have written to it, allowing
logsum to remain zero and later cause division-by-zero when acc_o is divided by
logsum; fix by initializing logsum to a small positive epsilon (e.g. 1e-12)
before the loop or add a conditional/branch that skips scaling and division when
the previous logsum is zero, ensuring any updates use max(logsum, epsilon) or
guard the divide with a check to avoid division by zero.
```python
if (not tune):
    kernel = flashattn(
        batch,
        heads,
        seq_q,
        seq_kv,
        dim,
        is_causal,
        block_M=64,
        block_N=64,
        num_stages=0,
        threads=128)
    print(kernel.get_kernel_source())
    ref_program_processed = partial(ref_program, is_causal=is_causal)

    profiler = kernel.get_profiler()
    profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
    print("All checks pass.")
    latency = profiler.do_bench(ref_program_processed, warmup=500)
    print("Ref: {:.2f} ms".format(latency))
    print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = profiler.do_bench(warmup=500)
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
```
Benchmark still times PyTorch reference
Same as GEMM: calling profiler.do_bench(ref_program_processed, ...) benchmarks the reference. Remove the argument so you’re timing the TileLang kernel.
🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 210-233, the call
profiler.do_bench(ref_program_processed, warmup=500) is benchmarking the PyTorch
reference instead of the TileLang kernel; remove the ref_program_processed
argument so profiler.do_bench(warmup=500) runs the TileLang kernel, keep the
reference run only where you intend to benchmark the reference (or rename the
printed labels if you swap the order), and ensure assert_allclose still uses
ref_program_processed.
```python
kernel = flashattn(
    batch,
    heads,
    seq_q,
    seq_kv,
    dim,
    is_causal,
    block_M=64,
    block_N=64,
    num_stages=0,
    threads=128)
print(kernel.get_kernel_source())
```
Heavy allocations executed on import
Like the GEMM script, this file allocates large tensors at module import. Wrap the execution logic in a main() guarded by if __name__ == "__main__": to avoid OOM when imported.
🤖 Prompt for AI Agents
In maint/gemm_v2/latency_mha_fwd_bhsd.py around lines 211 to 222, the script
performs heavy tensor allocations and prints the kernel source at import; wrap
this execution in a main() function and protect it with if __name__ ==
"__main__": so imports don't allocate memory. Move the kernel creation,
large-tensor allocations, and print(kernel.get_kernel_source()) into the new
main(), keep any helper/definition code at module scope, and call main() only
inside the if __name__ == "__main__": guard.
Actionable comments posted: 0
♻️ Duplicate comments (4)
maint/gemm_v2/latency_mha_fwd_bhsd.py (4)
25-28: Single configuration defeats autotuning purpose.

The `get_configs` implementation is correct (the past review comment incorrectly identified a bug), but each parameter list contains only one value, yielding exactly one configuration: `{'block_M': 128, 'block_N': 128, 'num_stages': 2, 'threads': 256}`. This defeats the purpose of autotuning; consider expanding the search space with multiple values per parameter.

For example:

```diff
 def get_configs():
-    iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
+    iter_params = dict(
+        block_M=[64, 128],
+        block_N=[64, 128],
+        num_stages=[0, 2],
+        threads=[128, 256]
+    )
     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
```
54-77: LGTM: Causal masking correctly initializes before GEMM.

The causal and non-causal paths both correctly initialize `acc_s` before the GEMM operation:

- Causal: sets entries to `0` (valid positions) or `-inf` (masked positions) so GEMM can accumulate on the initialized values
- Non-causal: clears to `0` for clean accumulation

This matches the FlashAttention pattern. The past review comment suggesting both paths should clear first was incorrect.
195-246: LGTM: Benchmarking and entry point are correctly implemented.

The `main` function properly:

- Calculates theoretical FLOPS (including the 0.5x adjustment for causal attention)
- Validates correctness against PyTorch with `assert_allclose`
- Benchmarks both reference (line 227) and TileLang (line 230) separately with correct labels
- Supports both manual and autotuned modes
- Is properly guarded by `if __name__ == "__main__":` to prevent heavy allocations on import

The past review comments claiming lines 227 and 210-221 were incorrect have been resolved.
95-124: Potential division by zero in causal attention edge case.

The online softmax correctly updates `logsum`, but in causal attention, when a query position has no valid keys (all masked to `-inf`), the following sequence occurs:

- All `acc_s[i, :]` are `-inf`, so `scores_max[i] = -inf`
- After `exp2`, all scores become `0`, so `scores_sum[i] = 0`
- `logsum[i] = prev * scores_scale + 0` eventually becomes `0`
- Line 173 then divides by zero: `acc_o[i, j] /= logsum[i]`

This edge case can occur in causal attention when `past_len + bx * block_M` positions have no valid keys to attend to.

The commented-out Check_inf logic (lines 108-112) hints at this issue. Consider adding a guard:

```diff
 for i, j in T.Parallel(block_M, dim):
-    acc_o[i, j] /= logsum[i]
+    acc_o[i, j] = T.if_then_else(
+        logsum[i] > 0,
+        acc_o[i, j] / logsum[i],
+        0  # or keep acc_o[i, j] unchanged
+    )
```

Or initialize `logsum` to a small epsilon rather than zero (line 157).
🧹 Nitpick comments (1)
maint/gemm_v2/latency_mha_fwd_bhsd.py (1)
4-4: Consider replacing star import for clarity.

The star import from `tilelang.autotuner` makes it unclear which names are imported and triggers static analysis warnings (F403, F405). Since only `autotune` is used, consider an explicit import.

Apply this diff:

```diff
-from tilelang.autotuner import *
+from tilelang.autotuner import autotune
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- maint/gemm_v2/latency_mha_fwd_bhsd.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
maint/gemm_v2/latency_mha_fwd_bhsd.py (6)
- tilelang/autotuner/tuner.py (1): autotune (723-816)
- tilelang/jit/__init__.py (1): jit (233-306)
- tilelang/language/fill.py (2): clear (24-48), fill (9-21)
- tilelang/language/gemm.py (2): gemm_v2 (215-434), gemm_v1 (10-211)
- tilelang/language/reduce.py (2): reduce_max (50-68), reduce_sum (87-109)
- tilelang/language/allocate.py (2): alloc_shared (27-42), alloc_fragment (59-70)
🪛 Ruff (0.14.2)
maint/gemm_v2/latency_mha_fwd_bhsd.py
4-4: from tilelang.autotuner import * used; unable to detect undefined names
(F403)
30-30: autotune may be undefined, or defined from star imports
(F405)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (3)
maint/gemm_v2/latency_mha_fwd_bhsd.py (3)
30-52: LGTM: Decorator stack and setup are well-configured.

The autotune and JIT decorators are correctly stacked, the output index is properly specified, and the `scale` factor cleverly combines `1/sqrt(dim)` with `log2(e)` for the exp2 optimization used later in the Softmax macro.
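A quick numeric check of the exp2 rewrite referenced above; this only exercises the `log2(e)` factor, while in the kernel `scale` additionally folds in `1/sqrt(dim)`.

```python
import math

x, m = 1.7, 3.2
log2e = math.log2(math.e)   # the log2(e) factor folded into `scale`

# exp(x - m) == 2**(x*log2(e) - m*log2(e)); the right-hand side maps onto
# exp2 plus an ffma on the GPU instead of a separate fadd and fmul.
assert math.isclose(math.exp(x - m), 2.0 ** (x * log2e - m * log2e))
print("exp2 rewrite matches exp:", 2.0 ** (x * log2e - m * log2e))
```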
78-93: LGTM: MMA1, Rescale, and main kernel correctly implement FlashAttention.

The kernel orchestration properly:

- Allocates fragment and shared memory buffers
- Implements tiled online softmax with rescaling
- Computes the causal loop range correctly: `min(ceildiv(seq_kv, block_N), ceildiv((bx+1)*block_M + past_len, block_N))`
- Pipelines the operations for efficiency
- Stores results through shared memory

The only concern is the division by `logsum` at line 173, already flagged in a previous comment.

Also applies to: 126-177
180-192: LGTM: Reference implementation is clear and correct.

The PyTorch reference correctly implements attention with optional causal masking. The `torch.tril` diagonal offset of `seq_kv - seq_q` properly handles the `past_len` case, matching the TileLang kernel's causal logic.
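A minimal sketch of the masking pattern described above (not the script's actual `ref_program`): `tril` with `diagonal=seq_kv - seq_q` lets every query attend to its own position plus all earlier keys, including the `past_len = seq_kv - seq_q` prefix.

```python
import torch

seq_q, seq_kv = 4, 6                      # past_len = 2
mask = torch.tril(torch.ones(seq_q, seq_kv), diagonal=seq_kv - seq_q).bool()

scores = torch.randn(seq_q, seq_kv)
scores = scores.masked_fill(~mask, float("-inf"))
attn = torch.softmax(scores, dim=-1)      # every row keeps at least one valid key
print(mask.int())
```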
Actionable comments posted: 0
🧹 Nitpick comments (2)
tilelang/intrinsics/wgmma_macro_generator.py (2)
260-268: Address unused unpacked variable `tx`.

The variable `tx` (lane ID) is unpacked at line 268 but never used in the subsequent code. This is flagged by static analysis (Ruff).

Apply this diff to prefix the unused variable with an underscore:

```diff
-tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+_tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
```

Based on coding guidelines.
366-374: Address unused unpacked variables `tx` and `warp_m`.

At line 374, both `tx` (lane ID) and `warp_m` are unpacked but never used in the `wgmma_rs` implementation. This is flagged by static analysis (Ruff).

Apply this diff to prefix the unused variables with underscores:

```diff
-tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
+_tx, warp_n, _warp_m = self.extract_thread_binding(thread_binding)
```

Based on coding guidelines.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/transform/lower_tile_op.cc (5 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/intrinsics/wgmma_macro_generator.py (6)
- tilelang/tileop/gemm/gemm_base.py (2): wg_wait (115-116), accum_dtype (59-60)
- tilelang/intrinsics/mma_macro_generator.py (2): get_thread_binding (158-164), extract_thread_binding (174-204)
- tilelang/intrinsics/mfma_macro_generator.py (2): get_thread_binding (217-223), extract_thread_binding (225-252)
- tilelang/language/kernel.py (2): get_thread_binding (171-176), get_thread_binding (306-310)
- tilelang/language/tir/op.py (2): ptx_wgmma_ss (1065-1104), ptx_wgmma_rs (1107-1142)
- tilelang/language/builtin.py (4): warpgroup_commit_batch (277-283), warpgroup_wait (286-296), initialize_wgmma_descriptor (601-628), warpgroup_fence_operand (433-490)
🪛 Ruff (0.14.3)
tilelang/intrinsics/wgmma_macro_generator.py
268-268: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
374-374: Unpacked variable tx is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
374-374: Unpacked variable warp_m is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (11)
src/transform/lower_tile_op.cc (5)
13-13: LGTM: Include added for workspace stack.

The `#include <vector>` is correctly added to support the new `std::vector<Array<Buffer>> workspace_stack_` member.
305-306: LGTM: Workspace frame initialization on block entry.

The frame is correctly pushed when entering a block scope, establishing the collection point for workspace buffers created during block processing.
316-322: LGTM: Workspace frame cleanup on block exit.

The workspaces are correctly attached to the block's `alloc_buffers` and the frame is popped, balancing the push at line 306. The empty check is defensive and ensures robustness.
726-727: LGTM: Workspace stack member declaration.

The `workspace_stack_` member is correctly declared as `std::vector<Array<Buffer>>` with clear documentation explaining its purpose for per-block workspace tracking.
667-680: Verify when the fallback case can occur and document or remove it.

The execution context shows that the callback is passed to `tile_op->Lower()` (line 696) and should be invoked while a workspace frame is active. Specifically, the frame is pushed at `BlockNode::VisitStmt_` line 306 before children are visited, and popped at line 321 after visiting completes. During normal execution within this block scope, `workspace_stack_` should not be empty when the callback executes.

However, if the fallback is triggered:

- A frame is created via `emplace_back()` at line 677
- This frame is never popped (the pop at line 321 only occurs in the BlockNode handler)
- The workspace buffer added to this orphaned frame is never attached to any block's `alloc_buffers`

This suggests the callback is being invoked from an unexpected execution path outside the normal block context. Either:

- Remove the fallback and add `ICHECK(!workspace_stack_.empty())` to catch and diagnose the unexpected case, or
- Document precisely when this fallback is legitimately triggered and how workspace lifetime is managed in that scenario.
tilelang/intrinsics/wgmma_macro_generator.py (6)
9-9: LGTM: gcd import for dynamic inst_n calculation.

The `gcd` import from the `math` module is correctly added to support the new dynamic calculation of `wgmma_inst_n`.
74-77: LGTM: New WGMMA instruction dimension attributes.

The `wgmma_inst_m` and `wgmma_inst_n` attributes are well-documented and store the per-instruction M and N dimensions for tiling calculations.
112-125: LGTM: Dynamic inst_n calculation addresses previous concerns.

The use of `gcd(self.warp_col_tiles, 256)` for `inst_n` is mathematically sound:

- By definition, `gcd(a, b)` divides both `a` and `b`, so `inst_n` divides `warp_col_tiles`
- This ensures `num_inst_n = warp_col_tiles // wgmma_inst_n` is always ≥ 1 (since `warp_col_tiles ≥ 8` per line 132)
- The assertions correctly validate that `inst_n` is a multiple of 8 and within [8, 256]

This design addresses the previous review concern about `num_inst_n = 0` causing division-by-zero errors.
167-168: LGTM: wg_wait parameter added for configurable synchronization.

The `wg_wait` parameter with default value 0 enables configurable warpgroup wait synchronization. The signature is consistent with `wgmma_rs` (lines 311-312).
384-416: wgmma_rs offset calculations are correctly structured and consistent with the wgmma method.

The restructured `wgmma_rs` loop implementation is sound:

- A_offset (line 389): Correctly handles register tiling with `ki * warp_rows * local_size_a + i * local_size_a`, properly advancing through K dimension chunks and warp instances
- B_offset (lines 390-394): Consistent with `wgmma`, using swizzle-aware descriptor offsets
- C_offset (line 395): Safe division by `num_inst_n` matches the 4-warp unit organization
- Fence operations (lines 395, 417): Properly bracket the warpgroup computation
- wg_wait (lines 414-415): Correctly applied with conditional check

All variables (`local_size_a`, `warp_rows`, `warp_cols`, `local_size_out`) are properly extracted from `self` and used consistently.
280-303: Based on my verification research, I cannot find definitive validation or test cases that explicitly verify the offset calculation formulas shown. To make a conclusive assessment, I would need to examine the full `wgmma_macro_generator.py` file context, specifically the definitions of `warp_m`, `warp_n`, `num_inst_m`, and `num_inst_n`, as well as any validation or unit tests for these calculations.

However, my research confirms:
- A warpgroup consists of four contiguous warps (128 threads), and the wgmma.mma_async instruction is executed collectively by all 128 threads in a warpgroup
- The layout-dependent offset calculations (K-major vs MN-major) align with the canonical GMMA layout atoms and tile_to_shape approach in CUTLASS for accurately calculating offsets and swizzling modes
- The test suite includes integration tests for WGMMA-based kernels (FlashAttention, attention sink) that would catch offset errors
The review comment cannot be fully resolved without access to:
- Full function signatures defining valid ranges for `warp_m`, `warp_n`, `num_inst_m`, `num_inst_n`
- Whether the `(warp_m // 4)` formula and the asymmetrical `warp_j = warp_n * num_inst_n + j` are both intentional, or if one requires correction
- Unit or integration tests that validate these specific offset calculations
Verify the restructured loop and offset calculations. The changes significantly alter the iteration pattern and memory access paths for WGMMA operands:
- Loop order changed to N-outer (j), M-middle (i), K-inner (ki): verify cache efficiency vs the previous implementation
- Warp index calculations show asymmetry: `warp_i = (warp_m // 4) * num_inst_m + i` vs `warp_j = warp_n * num_inst_n + j`; confirm both formulas match your hardware tiling strategy
- Offset calculations for A and B depend on memory layout (K-major vs MN-major); ensure swizzle atom sizes and `ak_atom_size`/`bk_atom_size` are compatible with the instruction shapes
Commits:

- fix
- lint fix
- fix
- lint fix
- fix
- upd
- support n>256
- Remove unnecessary pass configurations for fast math in MHA forward BHSD latency script.
- lint fix
- lint fix
This pull request introduces significant improvements and new features to the GEMM and MHA kernel implementations, with a particular focus on the TileLang-based CUDA kernels and the WGMMA macro generator. The most important changes include the addition of new latency benchmarking scripts for GEMM and MHA, major enhancements to the WGMMA macro generator to support more flexible instruction sizes, and expanded test coverage for GEMM correctness. These updates improve performance profiling, correctness validation, and hardware compatibility for advanced GPU architectures.
New benchmarking scripts
GEMM and MHA latency profiling:
- `latency_gemm.py` for profiling GEMM kernel latency, including correctness validation against PyTorch and support for both v1 and v2 GEMM kernels via a command-line flag.
- `latency_mha_fwd_bhsd.py` for profiling forward MHA kernels with autotuning, flexible configuration, and correctness checks against a PyTorch reference implementation.

WGMMA macro generator enhancements
Flexible instruction sizing and hardware compatibility:
- Updated `wgmma_macro_generator.py` to support dynamic calculation of the WGMMA instruction N dimension (`wgmma_inst_n`), enforcing hardware constraints (multiple of 8, range [8, 256]) and using the greatest common divisor for optimal sizing. [1] [2]
- Added per-instruction N tiling (`num_inst_n`), updating address calculations and output offsets for both the `wgmma` and `wgmma_rs` methods to support tiling and improved parallelism. [1] [2] [3] [4] [5]
- Added a `wg_wait` argument, allowing for more flexible synchronization. [1] [2]

GEMM correctness evaluation
Expanded test coverage:

- Added `256` and `512` to the N sweep in `maint/gemm_v2/correctness_evaluation.py`.
Minor improvements
Codebase maintenance:
- Imported `gcd` from the math library in `wgmma_macro_generator.py`.
- Updated GEMM lowering in `gemm_wgmma.py` to pass the new `wg_wait` argument.

Summary by CodeRabbit
New Features
Tests
Refactor