[Example] Add example #894
…and improve flash attention implementation * Updated the CopyNode Lower method to correctly include the disable_tma flag in the GetCopyInst call. * Refactored the flash attention implementation to selectively disable TMA for specific copy operations while allowing it for others. * Addressed linting issues for improved code quality
Caution: Review failed. The pull request is closed.
Walkthrough
Adds a new README file under examples/deepseek_v32 with placeholder text "Comming Soon.".
Estimated code review effort: 🎯 1 (Trivial) | ⏱️ ~2 minutes
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
Actionable comments posted: 6
🧹 Nitpick comments (19)
examples/deepseek_v32/sparse_mla_fwd.py (5)
32-33: Prefer direct truthiness for causal flag.
Avoid == True. Use a direct assert.
- assert is_causal == True, "non-casual is not supported"
+ assert is_causal, "non-causal is not supported"
54-61: Drop unused G.
G = kv_group isn't used.
- G = kv_group
183-200: Interface polish: fix spelling, truth check, unused unpack.
- Use is_causal consistently (spelled "causal").
- Avoid == False.
- Prefix unused unpacked value.
- is_casual = True
- assert return_p_sum == False, "This kernel file is for fwd only"
+ is_causal = True
+ assert not return_p_sum, "This kernel file is for fwd only"
@@
- _, seq_len_kv, kv_group, _ = kv.shape
+ _, _seq_len_kv, kv_group, _ = kv.shape
@@
- kernel = sparse_attention_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual)
+ kernel = sparse_attention_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_causal)
203-218: Tidy reference interface: unused arg/var.
Prefix unused is_casual and drop num_kv_per_index.
-def ref_sparse_attention_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
+def ref_sparse_attention_fwd_interface(q, kv, indices, sm_scale=None, _is_causal=True):
@@
- num_kv_per_index = 1
265-280: Silence unused and fix f-strings without placeholders.
Rename unused and remove extraneous f prefixes.
- tl_out, tl_lse = sparse_attention_fwd_interface(q, kv, indices)
+ _tl_out, _tl_lse = sparse_attention_fwd_interface(q, kv, indices)
@@
- print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
- print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
+ print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
+ print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
examples/deepseek_v32/utils.py (3)
2-7: Remove unused imports to pass lint.
Drop torch.nn.functional, triton, triton.language, and contextlib (unused).
-import torch.nn.functional as F
-import triton
-import triton.language as tl
-
-import contextlib
+# (removed unused: F, triton, tl, contextlib)
197-203: Raise explicitly instead of assert False.
assert statements can be stripped with -O. Raise AssertionError for robustness.
- if not (0 <= diff <= eps):
-     print_red_warning(f'{name} Error: {diff}')
-     assert False
+ if not (0 <= diff <= eps):
+     print_red_warning(f'{name} Error: {diff}')
+     raise AssertionError(f"{name} similarity diff {diff} exceeds {eps}")
216-220: Avoid assigning a lambda; define a function.
Satisfy E731 and improve readability.
- fn = lambda: cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len)
+ def fn():
+     return cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (4)
41-45: Prefer direct truthiness for causal flag.
Avoid == True.
- assert is_causal == True, 'non-casual is not supported'
+ assert is_causal, 'non-causal is not supported'
58-66: Drop unused G.
sparse_attention_fwd doesn't use G = kv_group.
- G = kv_group
449-451: Remove extraneous f prefixes in prints.
Fix F541.
- print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
- print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
+ print('fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
+ print('fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
321-324: Consistent naming: use is_causal in public signatures.
Interfaces expose is_casual. Rename to is_causal for consistency with kernels and reference.
examples/deepseek_v32/fp8_mqa_logits.py (7)
10-13: Avoid host sync in assertion and handle zeros robustly.
.item() syncs to CPU; also guard log2(0) safely.
-def ceil_to_ue8m0(x: torch.Tensor):
-    assert x.view(-1).amax().item() > 0
-    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+def ceil_to_ue8m0(x: torch.Tensor):
+    x_abs = x.abs().clamp_min(1e-8)
+    assert torch.any(x_abs > 0)
+    return (2.0 ** torch.ceil(torch.log2(x_abs)))
71-75: Guard block_Q derivation against zero/non-divisible heads.
block_Q = 128 // heads can be 0 for large heads. Ensure at least 1 and preferably a multiple of warp size.
- if block_Q is None:
-     block_Q = 128 // heads
+ if block_Q is None:
+     block_Q = max(1, 128 // max(1, heads))
153-160: Optional: assert compatible launch params to avoid truncated inner loops.
If block_K % threads != 0, the inner serial loop (block_K // threads iterations) skips the remainder, and it runs zero times when block_K < threads. Add a guard or assertion.
 @tilelang.jit
 def clean_logits_(
     threads: int = 512,
     block_K: int = 4096,
 ):
+    assert (block_K % threads) == 0, "block_K must be divisible by threads"
248-256: Make generate_random_cu_seqlens robust if cumulative sum never reaches target.
Very rare, but possible (e.g., many zeros), leading to an index error on where(...)[0][0]. Add a fallback.
- last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0]
- cu_seqlens = cu_seqlens[:last_seq_id]
+ hits = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0]
+ cu_seqlens = cu_seqlens[: hits[0]] if hits.numel() > 0 else cu_seqlens
257-257: Remove unused variable to fix CI (Ruff F841).
- total_seqlen_k = (cu_seqlens // kv_stride).sum()
314-326: Rename unused loop variable to fix CI (Ruff B007).
- for i in range(10):
+ for _ in range(10):
239-239: Drop unused tensor p.
Dead code; remove to keep the example lean.
- p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1)
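The two launch-parameter nitpicks above (the block_Q derivation and the block_K / threads divisibility) can be checked on the host before any kernel is built. The following is a minimal sketch in plain Python; derive_block_q and check_clean_launch are hypothetical helper names used only for illustration and are not part of the PR.

```python
def derive_block_q(heads: int, budget: int = 128) -> int:
    """Rows of Q per block; never returns 0, even when heads > budget."""
    return max(1, budget // max(1, heads))


def check_clean_launch(block_k: int, threads: int) -> int:
    """Return the per-thread serial trip count, rejecting truncated or zero-trip loops."""
    if threads <= 0 or block_k <= 0 or block_k % threads != 0:
        raise ValueError("block_K must be a positive multiple of threads")
    return block_k // threads


print(derive_block_q(64))             # 2
print(derive_block_q(256))            # 1 (plain 128 // 256 would give 0)
print(check_clean_launch(4096, 512))  # 8
```

Raising ValueError instead of asserting keeps the guard active even when Python runs with -O.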
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
examples/deepseek_v32/fp8_mqa_logits.py (1 hunks)
examples/deepseek_v32/sparse_mla_fwd.py (1 hunks)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1 hunks)
examples/deepseek_v32/utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/deepseek_v32/fp8_mqa_logits.py (7)
  examples/deepseek_v32/utils.py (2)
    cal_cu_seqlen_ke_for_q(107-128)
    cal_cu_seqlen_ks_for_q(90-103)
  tilelang/jit/__init__.py (1)
    jit(237-310)
  tilelang/language/allocate.py (3)
    alloc_shared(21-36)
    alloc_fragment(53-64)
    alloc_local(39-50)
  tilelang/language/builtin.py (1)
    no_set_max_nreg(160-163)
  tilelang/language/copy.py (1)
    copy(84-152)
  tilelang/language/pipeline.py (1)
    Pipelined(9-46)
  tilelang/language/parallel.py (1)
    Parallel(8-28)
examples/deepseek_v32/sparse_mla_fwd.py (9)
  examples/deepseek_v32/utils.py (3)
    print_red_warning(183-184)
    calc_sim(187-194)
    assert_similar(197-202)
  tilelang/jit/__init__.py (1)
    jit(237-310)
  tilelang/math/__init__.py (1)
    next_power_of_2(1-2)
  tilelang/language/__init__.py (1)
    symbolic(83-94)
  tilelang/language/allocate.py (2)
    alloc_shared(21-36)
    alloc_fragment(53-64)
  tilelang/language/fill.py (2)
    fill(9-21)
    clear(24-48)
  tilelang/language/copy.py (1)
    copy(84-152)
  tilelang/language/pipeline.py (1)
    Pipelined(9-46)
  tilelang/language/reduce.py (2)
    reduce_max(50-68)
    reduce_sum(87-109)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (5)
  tilelang/engine/callback.py (1)
    register_cuda_postproc_callback(28-58)
  examples/deepseek_v32/sparse_mla_fwd.py (5)
    sparse_attention_fwd(15-174)
    main(75-172)
    sparse_attention_fwd_interface(177-200)
    ref_sparse_attention_fwd_interface(203-238)
    fn(267-268)
  tilelang/language/allocate.py (3)
    alloc_shared(21-36)
    alloc_fragment(53-64)
    alloc_barrier(80-89)
  tilelang/language/copy.py (1)
    copy(84-152)
  tilelang/language/builtin.py (4)
    barrier_arrive(280-287)
    set_max_nreg(118-133)
    barrier_wait(265-277)
    cp_async_barrier_noinc(366-369)
🪛 Ruff (0.13.1)
examples/deepseek_v32/utils.py
202-202: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
218-218: Do not assign a lambda expression, use a def
Rewrite fn as a def
(E731)
examples/deepseek_v32/fp8_mqa_logits.py
257-257: Local variable total_seqlen_k is assigned to but never used
Remove assignment to unused variable total_seqlen_k
(F841)
314-314: Loop control variable i not used within loop body
Rename unused i to _i
(B007)
examples/deepseek_v32/sparse_mla_fwd.py
23-23: Unused function argument: CP0
(ARG001)
32-32: Avoid equality comparisons to True; use is_causal: for truth checks
Replace with is_causal
(E712)
54-54: Local variable G is assigned to but never used
Remove assignment to unused variable G
(F841)
184-184: Avoid equality comparisons to False; use not return_p_sum: for false checks
Replace with not return_p_sum
(E712)
187-187: Unpacked variable seq_len_kv is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
203-203: Unused function argument: is_casual
(ARG001)
216-216: Local variable num_kv_per_index is assigned to but never used
Remove assignment to unused variable num_kv_per_index
(F841)
265-265: Unpacked variable tl_out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
265-265: Unpacked variable tl_lse is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
278-278: f-string without any placeholders
Remove extraneous f prefix
(F541)
279-279: f-string without any placeholders
Remove extraneous f prefix
(F541)
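The B011 finding above rests on the fact that optimized bytecode drops assert statements. As a small illustration (not code from the PR), the sketch below uses the built-in compile() with its optimize argument to mimic python -O; the two source strings are made-up stand-ins for the check in utils.py.

```python
SOURCE_ASSERT = "assert False, 'similarity check failed'"
SOURCE_RAISE = "raise AssertionError('similarity check failed')"


def fails(source: str, optimize: int) -> bool:
    """Compile `source` at the given optimization level and report whether it raises."""
    code = compile(source, "<demo>", "exec", optimize=optimize)
    try:
        exec(code, {})
    except AssertionError:
        return True
    return False


print(fails(SOURCE_ASSERT, optimize=0))  # True
print(fails(SOURCE_ASSERT, optimize=1))  # False: the assert is stripped, as with python -O
print(fails(SOURCE_RAISE, optimize=1))   # True: an explicit raise survives optimization
```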
🪛 GitHub Actions: CI Test on AMD
examples/deepseek_v32/utils.py
[error] 2-2: F401 [*] torch.nn.functional imported but unused.
[error] 3-3: F401 [*] triton imported but unused.
[error] 4-4: F401 [*] triton.language imported but unused.
[error] 6-6: F401 [*] contextlib imported but unused.
examples/deepseek_v32/fp8_mqa_logits.py
[error] 257-257: F841 Local variable total_seqlen_k is assigned to but never used.
[error] 314-314: B007 Loop control variable i not used within loop body.
examples/deepseek_v32/sparse_mla_fwd.py
[error] 4-4: F401 [*] tilelang.tvm imported but unused.
[error] 5-5: F401 [*] utils.print_red_warning imported but unused.
[error] 5-5: F401 [*] utils.calc_sim imported but unused.
[error] 5-5: F401 [*] utils.assert_similar imported but unused.
[error] 32-32: E712 Avoid equality comparisons to True; use if is_causal: for truth checks.
[error] 54-54: F841 Local variable G is assigned to but never used.
[error] 184-184: E712 Avoid equality comparisons to False; use if not return_p_sum: for truth checks.
[error] 216-216: F841 Local variable num_kv_per_index is assigned to but never used.
[error] 278-279: F541 [*] f-string without any placeholders. Remove extraneous f prefix.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: bot-task
- GitHub Check: format-check
🔇 Additional comments (1)
examples/deepseek_v32/fp8_mqa_logits.py (1)
214-231: Reference path matches kernel math (ReLU and weighting).The FP32 reference mirrors the compiled path’s ReLU, weighting, and masking. No issues spotted.
from tilelang import language as T
import torch
from tilelang import tvm
from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar
🛠️ Refactor suggestion | 🟠 Major
Fix import to be package-relative to this module.
from utils import ... will fail if examples/deepseek_v32 isn’t on PYTHONPATH. Use an absolute or relative import.
-from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar
+from examples.deepseek_v32.utils import (
+ cal_cu_seqlen_ke_for_q,
+ cal_cu_seqlen_ks_for_q,
+ assert_similar,
+)
🤖 Prompt for AI Agents
In examples/deepseek_v32/fp8_mqa_logits.py around line 6, the import uses a bare
module name ("from utils import ...") which will fail when examples/deepseek_v32
is not on PYTHONPATH; change it to a package-relative import (e.g., "from .utils
import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar") or an
absolute import (e.g., "from examples.deepseek_v32.utils import ..."), and
ensure the package has an __init__.py if needed so the relative import works.
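To make the import failure concrete, here is a self-contained sketch that builds a throwaway package in a temporary directory and imports it with only the parent directory on sys.path. All names (demo_examples_pkg, demo_helpers, bare, relative) are hypothetical and chosen only to avoid clashing with real modules; the layout is an assumption that mirrors a module sitting next to its utils file.

```python
import importlib
import pathlib
import sys
import tempfile

with tempfile.TemporaryDirectory() as root:
    pkg = pathlib.Path(root) / "demo_examples_pkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("")
    (pkg / "demo_helpers.py").write_text("def helper():\n    return 'ok'\n")
    # Bare import: only resolves if the package directory itself is on sys.path.
    (pkg / "bare.py").write_text("from demo_helpers import helper\n")
    # Relative import: resolves through the package, wherever it is installed.
    (pkg / "relative.py").write_text("from .demo_helpers import helper\n")

    sys.path.insert(0, root)  # only the parent directory is importable
    try:
        relative = importlib.import_module("demo_examples_pkg.relative")
        relative_result = relative.helper()
        try:
            importlib.import_module("demo_examples_pkg.bare")
            bare_failed = False
        except ModuleNotFoundError:
            bare_failed = True
    finally:
        sys.path.remove(root)

print(relative_result, bare_failed)
```

Running a file directly as a script is the opposite case: the script's own directory is placed on sys.path, so the bare form works there while the relative form does not. Which form suits the examples depends on how they are meant to be launched.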
        with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:

            index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype)
            index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
            index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
            s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
            s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
            logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
            weights = T.alloc_fragment([block_Q, heads], accum_dtype)

            seq_len_i = bx * block_Q

            cu_k_s_min = T.alloc_local([1], index_dtype)
            cu_k_e_max = T.alloc_local([1], index_dtype)

            T.no_set_max_nreg()

            cu_k_s_min[0] = 2147483647
            cu_k_e_max[0] = -2147483648

            for bq_i in T.serial(block_Q):
                cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i],
                                                           seq_len_kv))
            for bq_i in T.serial(block_Q):
                cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i],
                                                           seq_len_kv))

            T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
            T.copy(Weights[seq_len_i, 0], weights)

            for nbn_i in T.Pipelined(
                    T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
                T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
                T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)

                T.gemm(
                    index_k_shared,
                    index_q_shared,
                    s,
                    transpose_B=True,
                    clear_accum=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )

                for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
                    s_reshaped[bn_i, bq_i,
                               h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
                                       weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]

                T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)

                for bq_i, bn_i in T.Parallel(block_Q, block_N):
                    Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = (
                        logits[bn_i, bq_i])

    return mqa_attn_return_logits_kernel
OOB reads/writes on tail tiles for Q and K; add dynamic valid extents.
Current code copies, computes, and stores full block_Q/block_N tiles even on the last tiles, risking OOB on:
- Q/Weights copy when seq_len % block_Q != 0
- K/Scale copy and store when (cu_k_e_max - cu_k_s_min) % block_N != 0
Fix by computing valid_Q and per-iteration valid_K, slicing copies, and limiting parallel loops and stores.
@@
- with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
+ with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
@@
- seq_len_i = bx * block_Q
+ seq_len_i = bx * block_Q
+ # valid rows of Q in this block
+ valid_Q = T.min(block_Q, seq_len - seq_len_i)
@@
- T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
- T.copy(Weights[seq_len_i, 0], weights)
+ # Copy only valid_Q rows
+ T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared[0:valid_Q * heads, 0])
+ T.copy(Weights[seq_len_i, 0], weights[0:valid_Q, 0])
@@
- for nbn_i in T.Pipelined(
- T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
- T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
- T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)
+ for nbn_i in T.Pipelined(
+ T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
+ base_k = cu_k_s_min[0] + nbn_i * block_N
+ valid_K = T.min(block_N, cu_k_e_max[0] - base_k)
+ # Copy only valid_K rows
+ T.copy(IndexK[base_k, 0], index_k_shared[0:valid_K, 0])
+ T.copy(IndexKScale[base_k], index_k_scale_fragment[0:valid_K])
@@
- for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
+ # Compute only on valid tiles
+ for bn_i, bq_i, h_i in T.Parallel(valid_K, valid_Q, heads):
s_reshaped[bn_i, bq_i,
h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
@@
- T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
+ # Reduce only the valid region
+ T.reduce_sum(s_reshaped[0:valid_K, 0:valid_Q, 0:heads],
+ logits[0:valid_K, 0:valid_Q],
+ dim=-1,
+ clear=True)
@@
- for bq_i, bn_i in T.Parallel(block_Q, block_N):
- Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = (
- logits[bn_i, bq_i])
+ # Store only valid indices within [cu_k_s_min, cu_k_e_max)
+ for bq_i, bn_i in T.Parallel(valid_Q, valid_K):
+ Logits[seq_len_i + bq_i, base_k + bn_i] = logits[bn_i, bq_i]
This keeps GEMM dimensions unchanged but constrains copies, elementwise, reductions, and stores to the valid subtiles to prevent OOB. As per coding guidelines.
🤖 Prompt for AI Agents
In examples/deepseek_v32/fp8_mqa_logits.py around lines 95 to 150, the kernel
currently copies and writes full block_Q/block_N tiles on tail iterations which
can OOB when seq_len % block_Q != 0 or (cu_k_e_max - cu_k_s_min) % block_N != 0;
compute valid_Q = min(block_Q, seq_len - seq_len_i) and inside the nbn_i loop
compute valid_K = min(block_N, cu_k_e_max[0] - cu_k_s_min[0] - nbn_i*block_N)
(clamped to >=0), use those extents to slice the IndexQ and Weights copies (only
copy valid_Q rows), slice IndexK and IndexKScale copies to valid_K rows,
restrict the subsequent elementwise Parallel loops, the reduction (operate over
valid_Q/valid_K extents) and the final Logits store to only write the valid
subtiles; keep the GEMM input shape unchanged (pad/shared buffers may remain
full) but ensure any iteration ranges and
T.copy/T.Parallel/T.reduce_sum/T.Parallel stores use valid_Q/valid_K to avoid
OOB.
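The valid_Q / valid_K arithmetic described above is ordinary tail-tile bookkeeping. A minimal host-side sketch in plain Python (tile_extents is a hypothetical name, not a TileLang API) shows how the extent shrinks on the last tile:

```python
def tile_extents(total: int, block: int):
    """Yield (start, valid) pairs; `valid` shrinks on the tail tile."""
    num_tiles = -(-total // block)  # ceiling division, analogous to T.ceildiv
    for tile in range(num_tiles):
        start = tile * block
        yield start, min(block, total - start)


# seq_len = 10 with block_Q = 4: the last tile covers only 2 valid rows.
print(list(tile_extents(10, 4)))  # [(0, 4), (4, 4), (8, 2)]
```

The same pattern applies along K, with total taken as cu_k_e_max - cu_k_s_min for the current Q block.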
with T.Kernel(seq_len, threads=threads) as bx:
    tx = T.thread_binding(0, threads, thread="threadIdx.x")
    cu_k_s = T.alloc_local([1], indices_dtype)
    cu_k_e = T.alloc_local([1], indices_dtype)
    cu_k_s[0] = CuSeqLenKS[bx]
    cu_k_e[0] = CuSeqLenKE[bx]

    for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)):
        for k_i in T.serial(block_K // threads):
            idx = n_i * block_K + k_i * threads + tx
            if idx < cu_k_s[0] or idx >= cu_k_e[0]:
                Logits[bx, idx] = -T.infinity(dtype)
Prevent OOB writes in the logits cleaner on the last K block.
On the tail block, idx can be ≥ seq_len_kv. Add a bounds check.
- idx = n_i * block_K + k_i * threads + tx
- if idx < cu_k_s[0] or idx >= cu_k_e[0]:
- Logits[bx, idx] = -T.infinity(dtype)
+ idx = n_i * block_K + k_i * threads + tx
+ if idx < seq_len_kv and (idx < cu_k_s[0] or idx >= cu_k_e[0]):
+ Logits[bx, idx] = -T.infinity(dtype)
🤖 Prompt for AI Agents
In examples/deepseek_v32/fp8_mqa_logits.py around lines 170 to 182, the logits
cleaner can write out-of-bounds on the final K block because idx may be >=
seq_len_kv; update the conditional that sets Logits[bx, idx] =
-T.infinity(dtype) to include a bounds check ensuring idx < seq_len_kv (i.e.,
only perform the write when idx is within sequence length and within
cu_k_s/cu_k_e), so that the cleaner never writes past the end of the K
dimension.
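The index arithmetic of the cleaner can be emulated on the host to see which indices fall past the end of the row. The sketch below is a plain-Python stand-in, not the TileLang kernel: clean_row is a hypothetical function, the thread dimension is simulated with an inner loop, and the row length plays the role of seq_len_kv.

```python
NEG_INF = float("-inf")


def clean_row(row, k_start, k_end, block_k, threads):
    """Set entries outside [k_start, k_end) to -inf; return how many indices were out of range."""
    seq_len_kv = len(row)
    skipped = 0
    num_blocks = -(-seq_len_kv // block_k)  # ceiling division over K blocks
    for n_i in range(num_blocks):
        for k_i in range(block_k // threads):
            for tx in range(threads):  # emulates threadIdx.x
                idx = n_i * block_k + k_i * threads + tx
                if idx >= seq_len_kv:  # the bounds check proposed in the review
                    skipped += 1
                    continue
                if idx < k_start or idx >= k_end:
                    row[idx] = NEG_INF
    return skipped


row = [1.0] * 10
skipped = clean_row(row, k_start=2, k_end=5, block_k=8, threads=4)
print(skipped)  # 6: indices 10..15 of the second block lie past the row
print(row)
```

Without the first check, those six indices would be written in the tail block, which is the out-of-bounds store the comment describes.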
from tilelang.engine.callback import register_cuda_postproc_callback
import argparse
Add missing import for assert_similar used in tests.
Prevents NameError when --test_correctness is used.
-from tilelang.engine.callback import register_cuda_postproc_callback
+from tilelang.engine.callback import register_cuda_postproc_callback
+from .utils import assert_similar
🤖 Prompt for AI Agents
In examples/deepseek_v32/sparse_mla_fwd_pipelined.py around lines 7 to 10, tests
call assert_similar but the symbol is not imported, causing a NameError when
--test_correctness is used; add the missing import (e.g. import assert_similar
from the test utilities module, for example: from tilelang.testing import
assert_similar) at the top of the file alongside the other imports so tests can
reference it.
for bi_i in T.Parallel(BI):
    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i

for bi_i, d_i in T.Parallel(BI, D):
    KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
                              d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
    K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
                                  D + d_i]
Blocker: Potential out-of-bounds KV loads when indices use SKV sentinel.
You load KV[b_i, Indices[...], ...] unconditionally. When Indices == SKV (sentinel), this indexes past seq_len_kv - 1. Gate loads by validity (you already computed mask).
- for bi_i, d_i in T.Parallel(BI, D):
- KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
- d_i]
- for bi_i, d_i in T.Parallel(BI, D_tail):
- K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
- D + d_i]
+ for bi_i, d_i in T.Parallel(BI, D):
+ KV_shared[bi_i, d_i] = T.if_then_else(
+ mask[bi_i],
+ KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i],
+ 0)
+ for bi_i, d_i in T.Parallel(BI, D_tail):
+ K_tail_shared[bi_i, d_i] = T.if_then_else(
+ mask[bi_i],
+ KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i],
+ 0)
🤖 Prompt for AI Agents
In examples/deepseek_v32/sparse_mla_fwd.py around lines 122 to 131, the code
unconditionally loads KV using Indices which can equal the SKV sentinel and thus
index past seq_len_kv-1; guard those KV loads with the previously computed mask
to avoid out-of-bounds accesses. Change the KV and K_tail_shared assignments to
only read from KV when mask[bi_i] is true (e.g., use a conditional/masked load
or clamp the index to a safe value and then zero the result when mask is false),
or replace sentinel indices with a safe index (like 0) before loading and
multiply the loaded vector by mask to zero out invalid lanes; apply the same
protection to both the BI×D and BI×D_tail loops.
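The guarded gather can be illustrated outside the kernel. The following plain-Python sketch assumes, as the comment does, that an index equal to the KV length acts as a sentinel for "no entry"; gather_rows is a hypothetical helper, and Python's lazily evaluated conditional only stands in for the masked load, it says nothing about how T.if_then_else is lowered.

```python
def gather_rows(kv, indices, max_kv_i):
    """Gather kv[idx] for valid indices; emit a zero row for masked ones."""
    dim = len(kv[0])
    limit = min(max_kv_i, len(kv) - 1)  # never trust the mask beyond the real KV length
    out = []
    for idx in indices:
        valid = 0 <= idx <= limit
        # The row is only read when the index is valid, so a sentinel never touches kv.
        out.append(list(kv[idx]) if valid else [0.0] * dim)
    return out


kv = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
skv = len(kv)                 # sentinel value, one past the last valid row
indices = [0, 2, skv]
print(gather_rows(kv, indices, max_kv_i=2))
# [[1.0, 2.0], [5.0, 6.0], [0.0, 0.0]]
```

Zero rows still participate in the later matrix products, so the existing mask on the scores remains necessary to keep them out of the softmax.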
* [Refactor] Enhance CopyNode Lower method to support disable_tma flag and improve flash attention implementation
* Updated the CopyNode Lower method to correctly include the disable_tma flag in the GetCopyInst call.
* Refactored the flash attention implementation to selectively disable TMA for specific copy operations while allowing it for others.
* Addressed linting issues for improved code quality
* sparse mla kernels
* Remove deprecated sparse MLA and utility files to streamline the codebase.
Summary by CodeRabbit