
Conversation

@LeiWang1999 (Member) commented Sep 29, 2025

Summary by CodeRabbit

  • Documentation
    • Added a placeholder README for an upcoming DeepSeek v32 example, indicating content is “coming soon.”
    • Improves discoverability by signaling planned coverage and providing a visible entry point for future guidance.
    • No changes to functionality or user workflows; existing features remain unaffected.

[Refactor] Enhance CopyNode Lower method to support disable_tma flag and improve flash attention implementation

* Updated the CopyNode Lower method to correctly include the disable_tma flag in the GetCopyInst call.
* Refactored the flash attention implementation to selectively disable TMA for specific copy operations while allowing it for others (see the illustrative sketch below).
* Addressed linting issues for improved code quality.
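
A hedged sketch of that selective-TMA pattern (Python, using the tilelang frontend) is shown below. It is illustrative only: the disable_tma keyword on T.copy is inferred from this PR description, and the buffer names are placeholders, so verify against the current tilelang API before reusing it.

# Illustrative sketch, not the PR's actual code: one copy is left eligible for
# TMA lowering while another has TMA explicitly disabled.
# Assumption: T.copy accepts a disable_tma keyword, as this PR's summary suggests.
from tilelang import language as T

def load_qk_tiles(Q, K, Q_shared, K_shared, bx):
    # Default copy: the lowering pass may pick TMA when the layout allows it.
    T.copy(Q[bx, 0], Q_shared)
    # Copy with TMA disabled, e.g. for an access pattern TMA cannot express well.
    T.copy(K[bx, 0], K_shared, disable_tma=True)
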
@coderabbitai bot (Contributor) commented Sep 29, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

Adds a new README file under examples/deepseek_v32 with placeholder text "Comming Soon.".

Changes

Cohort: Docs: New README placeholder
File(s): examples/deepseek_v32/README.md
Summary: Added README with placeholder content: "Comming Soon."

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~2 minutes

Poem

A nibble of text, a promise in seed,
“Comming Soon.” is all that we read.
I twitch my whiskers, hop with delight—
Docs will bloom, perhaps overnight!
Until then, I’ll guard this sprout 🌱
Thump, thump—can’t wait to check it out! 🐇

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Title Check: ⚠️ Warning
  Explanation: "[Example] Add example" is overly generic and does not reference any of the substantive changes in the PR, such as the new FP8 MQA logits, sparse MLA forward kernels, or utility functions. A title should concisely summarize the primary change so that teammates scanning the project history can quickly understand its purpose.
  Resolution: Revise the title to clearly describe the main additions, for example: "Add FP8 MQA logits and sparse MLA forward example modules." A descriptive title improves readability in the project history, helps reviewers immediately grasp the PR's scope, and eases long-term maintenance.
Docstring Coverage: ⚠️ Warning
  Explanation: Docstring coverage is 5.26%, which is below the required threshold of 80.00%.
  Resolution: Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Description Check: ✅ Passed (check skipped because CodeRabbit's high-level summary is enabled)

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8f00a9e and ff9e598.

📒 Files selected for processing (1)
  • examples/deepseek_v32/README.md (1 hunks)


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@LeiWang1999 merged commit 4424fa9 into tile-ai:main Sep 29, 2025
3 checks passed
@coderabbitai bot (Contributor) left a comment


Actionable comments posted: 6

🧹 Nitpick comments (19)
examples/deepseek_v32/sparse_mla_fwd.py (5)

32-33: Prefer direct truthiness for causal flag.

Avoid == True. Use a direct assert.

-    assert is_causal == True, "non-casual is not supported"
+    assert is_causal, "non-causal is not supported"

54-61: Drop unused G.

G = kv_group isn’t used.

-    G = kv_group

183-200: Interface polish: fix spelling, truth check, unused unpack.

  • Use is_causal consistently (spelled “causal”).
  • Avoid == False.
  • Prefix unused unpacked value.
-    is_casual = True
-    assert return_p_sum == False, "This kernel file is for fwd only"
+    is_causal = True
+    assert not return_p_sum, "This kernel file is for fwd only"
@@
-    _, seq_len_kv, kv_group, _ = kv.shape
+    _, _seq_len_kv, kv_group, _ = kv.shape
@@
-    kernel = sparse_attention_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual)
+    kernel = sparse_attention_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_causal)

203-218: Tidy reference interface: unused arg/var.

Prefix unused is_casual and drop num_kv_per_index.

-def ref_sparse_attention_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
+def ref_sparse_attention_fwd_interface(q, kv, indices, sm_scale=None, _is_causal=True):
@@
-    num_kv_per_index = 1

265-280: Silence unused and fix f-strings without placeholders.

Rename unused and remove extraneous f prefixes.

-    tl_out, tl_lse = sparse_attention_fwd_interface(q, kv, indices)
+    _tl_out, _tl_lse = sparse_attention_fwd_interface(q, kv, indices)
@@
-    print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
-    print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
+    print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
+    print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
examples/deepseek_v32/utils.py (3)

2-7: Remove unused imports to pass lint.

Drop torch.nn.functional, triton, triton.language, and contextlib (unused).

-import torch.nn.functional as F
-import triton
-import triton.language as tl
-
-import contextlib
+# (removed unused: F, triton, tl, contextlib)

197-203: Raise explicitly instead of assert False.

assert statements can be stripped with -O. Raise AssertionError for robustness.

-    if not (0 <= diff <= eps):
-        print_red_warning(f'{name} Error: {diff}')
-        assert False
+    if not (0 <= diff <= eps):
+        print_red_warning(f'{name} Error: {diff}')
+        raise AssertionError(f"{name} similarity diff {diff} exceeds {eps}")

216-220: Avoid assigning a lambda; define a function.

Satisfy E731 and improve readability.

-    fn = lambda: cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len)
+    def fn():
+        return cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (4)

41-45: Prefer direct truthiness for causal flag.

Avoid == True.

-    assert is_causal == True, 'non-casual is not supported'
+    assert is_causal, 'non-causal is not supported'

58-66: Drop unused G.

sparse_attention_fwd doesn’t use G = kv_group.

-    G = kv_group

449-451: Remove extraneous f prefixes in prints.

Fix F541.

-    print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
-    print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
+    print('fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
+    print('fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)

321-324: Consistent naming: use is_causal in public signatures.

Interfaces expose is_casual. Rename to is_causal for consistency with kernels and reference.
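
A minimal sketch of the suggested rename is shown below; it mirrors the sparse_attention_fwd_interface signature quoted elsewhere in this review, and the pipelined interface's exact argument list may differ.

# Minimal rename sketch (illustrative; body elided).
def sparse_attention_fwd_interface(q, kv, indices, sm_scale=None, is_causal=True):
    assert is_causal, "non-causal is not supported"
    ...  # unchanged kernel construction and dispatch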

examples/deepseek_v32/fp8_mqa_logits.py (7)

10-13: Avoid host sync in assertion and handle zeros robustly.

.item() syncs to CPU; also guard log2(0) safely.

-def ceil_to_ue8m0(x: torch.Tensor):
-    assert x.view(-1).amax().item() > 0
-    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+def ceil_to_ue8m0(x: torch.Tensor):
+    x_abs = x.abs().clamp_min(1e-8)
+    assert torch.any(x_abs > 0)
+    return (2.0 ** torch.ceil(torch.log2(x_abs)))

71-75: Guard block_Q derivation against zero/non-divisible heads.

block_Q = 128 // heads can be 0 for large heads. Ensure at least 1 and preferably a multiple of warp size.

-    if block_Q is None:
-        block_Q = 128 // heads
+    if block_Q is None:
+        block_Q = max(1, 128 // max(1, heads))

153-160: Optional: assert compatible launch params to avoid empty inner loops.

If block_K % threads != 0, inner serial loop becomes 0. Add a guard or assertion.

 @tilelang.jit
 def clean_logits_(
     threads: int = 512,
     block_K: int = 4096,
 ):
+    assert (block_K % threads) == 0, "block_K must be divisible by threads"

248-256: Make generate_random_cu_seqlens robust if cumulative sum never reaches target.

Very rare, but possible (e.g., many zeros) leading to index error on where(...)[0][0]. Add a fallback.

-        last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0]
-        cu_seqlens = cu_seqlens[:last_seq_id]
+        hits = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0]
+        cu_seqlens = cu_seqlens[: hits[0]] if hits.numel() > 0 else cu_seqlens

257-257: Remove unused variable to fix CI (Ruff F841).

-        total_seqlen_k = (cu_seqlens // kv_stride).sum()

314-326: Rename unused loop variable to fix CI (Ruff B007).

-    for i in range(10):
+    for _ in range(10):

239-239: Drop unused tensor p.

Dead code; remove to keep the example lean.

-    p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6c67a77 and 8f00a9e.

📒 Files selected for processing (4)
  • examples/deepseek_v32/fp8_mqa_logits.py (1 hunks)
  • examples/deepseek_v32/sparse_mla_fwd.py (1 hunks)
  • examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1 hunks)
  • examples/deepseek_v32/utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/deepseek_v32/fp8_mqa_logits.py (7)
examples/deepseek_v32/utils.py (2)
  • cal_cu_seqlen_ke_for_q (107-128)
  • cal_cu_seqlen_ks_for_q (90-103)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/allocate.py (3)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
  • alloc_local (39-50)
tilelang/language/builtin.py (1)
  • no_set_max_nreg (160-163)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/parallel.py (1)
  • Parallel (8-28)
examples/deepseek_v32/sparse_mla_fwd.py (9)
examples/deepseek_v32/utils.py (3)
  • print_red_warning (183-184)
  • calc_sim (187-194)
  • assert_similar (197-202)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/math/__init__.py (1)
  • next_power_of_2 (1-2)
tilelang/language/__init__.py (1)
  • symbolic (83-94)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/fill.py (2)
  • fill (9-21)
  • clear (24-48)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/reduce.py (2)
  • reduce_max (50-68)
  • reduce_sum (87-109)
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (5)
tilelang/engine/callback.py (1)
  • register_cuda_postproc_callback (28-58)
examples/deepseek_v32/sparse_mla_fwd.py (5)
  • sparse_attention_fwd (15-174)
  • main (75-172)
  • sparse_attention_fwd_interface (177-200)
  • ref_sparse_attention_fwd_interface (203-238)
  • fn (267-268)
tilelang/language/allocate.py (3)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
  • alloc_barrier (80-89)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/builtin.py (4)
  • barrier_arrive (280-287)
  • set_max_nreg (118-133)
  • barrier_wait (265-277)
  • cp_async_barrier_noinc (366-369)
🪛 Ruff (0.13.1)
examples/deepseek_v32/utils.py

202-202: Do not assert False (python -O removes these calls), raise AssertionError()

Replace assert False

(B011)


218-218: Do not assign a lambda expression, use a def

Rewrite fn as a def

(E731)

examples/deepseek_v32/fp8_mqa_logits.py

257-257: Local variable total_seqlen_k is assigned to but never used

Remove assignment to unused variable total_seqlen_k

(F841)


314-314: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

examples/deepseek_v32/sparse_mla_fwd.py

23-23: Unused function argument: CP0

(ARG001)


32-32: Avoid equality comparisons to True; use if is_causal: for truth checks

Replace with is_causal

(E712)


54-54: Local variable G is assigned to but never used

Remove assignment to unused variable G

(F841)


184-184: Avoid equality comparisons to False; use if not return_p_sum: for false checks

Replace with not return_p_sum

(E712)


187-187: Unpacked variable seq_len_kv is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


203-203: Unused function argument: is_casual

(ARG001)


216-216: Local variable num_kv_per_index is assigned to but never used

Remove assignment to unused variable num_kv_per_index

(F841)


265-265: Unpacked variable tl_out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


265-265: Unpacked variable tl_lse is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


278-278: f-string without any placeholders

Remove extraneous f prefix

(F541)


279-279: f-string without any placeholders

Remove extraneous f prefix

(F541)

🪛 GitHub Actions: CI Test on AMD
examples/deepseek_v32/utils.py

[error] 2-2: F401 [*] torch.nn.functional imported but unused.


[error] 3-3: F401 [*] triton imported but unused.


[error] 4-4: F401 [*] triton.language imported but unused.


[error] 6-6: F401 [*] contextlib imported but unused.

examples/deepseek_v32/fp8_mqa_logits.py

[error] 257-257: F841 Local variable total_seqlen_k is assigned to but never used.


[error] 314-314: B007 Loop control variable i not used within loop body.

examples/deepseek_v32/sparse_mla_fwd.py

[error] 4-4: F401 [*] tilelang.tvm imported but unused.


[error] 5-5: F401 [*] utils.print_red_warning imported but unused.


[error] 5-5: F401 [*] utils.calc_sim imported but unused.


[error] 5-5: F401 [*] utils.assert_similar imported but unused.


[error] 32-32: E712 Avoid equality comparisons to True; use if is_causal: for truth checks.


[error] 54-54: F841 Local variable G is assigned to but never used.


[error] 184-184: E712 Avoid equality comparisons to False; use if not return_p_sum: for truth checks.


[error] 216-216: F841 Local variable num_kv_per_index is assigned to but never used.


[error] 278-279: F541 [*] f-string without any placeholders. Remove extraneous f prefix.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: bot-task
  • GitHub Check: format-check
🔇 Additional comments (1)
examples/deepseek_v32/fp8_mqa_logits.py (1)

214-231: Reference path matches kernel math (ReLU and weighting).

The FP32 reference mirrors the compiled path’s ReLU, weighting, and masking. No issues spotted.
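
For readers without the file open, here is a rough PyTorch sketch of the reference math described above (ReLU'd per-head scores, per-head weighting, key scaling, and masking outside the valid KV window). Tensor names and shapes are illustrative assumptions, not the example file's actual code.

import torch

def ref_fp8_mqa_logits(q, k, k_scale, weights, ks, ke):
    # q: [S, H, D] queries, k: [SKV, D] shared keys, k_scale: [SKV] per-key scales,
    # weights: [S, H] per-head weights, ks/ke: [S] valid KV range per query row.
    SKV = k.shape[0]
    scores = torch.einsum("shd,nd->shn", q.float(), k.float())            # [S, H, SKV]
    logits = (scores.relu() * weights.float().unsqueeze(-1)).sum(dim=1)   # [S, SKV]
    logits = logits * k_scale.float().unsqueeze(0)
    # Mask positions outside [ks, ke), mirroring the clean_logits_ kernel.
    pos = torch.arange(SKV, device=q.device).unsqueeze(0)                 # [1, SKV]
    valid = (pos >= ks.unsqueeze(1)) & (pos < ke.unsqueeze(1))
    return logits.masked_fill(~valid, float("-inf"))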

from tilelang import language as T
import torch
from tilelang import tvm
from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar

🛠️ Refactor suggestion | 🟠 Major

Fix import to be package-relative to this module.

from utils import ... will fail if examples/deepseek_v32 isn’t on PYTHONPATH. Use an absolute or relative import.

-from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar
+from examples.deepseek_v32.utils import (
+    cal_cu_seqlen_ke_for_q,
+    cal_cu_seqlen_ks_for_q,
+    assert_similar,
+)
🤖 Prompt for AI Agents
In examples/deepseek_v32/fp8_mqa_logits.py around line 6, the import uses a bare
module name ("from utils import ...") which will fail when examples/deepseek_v32
is not on PYTHONPATH; change it to a package-relative import (e.g., "from .utils
import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar") or an
absolute import (e.g., "from examples.deepseek_v32.utils import ..."), and
ensure the package has an __init__.py if needed so the relative import works.

Comment on lines 95 to 150
        with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:

            index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype)
            index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
            index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
            s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
            s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
            logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
            weights = T.alloc_fragment([block_Q, heads], accum_dtype)

            seq_len_i = bx * block_Q

            cu_k_s_min = T.alloc_local([1], index_dtype)
            cu_k_e_max = T.alloc_local([1], index_dtype)

            T.no_set_max_nreg()

            cu_k_s_min[0] = 2147483647
            cu_k_e_max[0] = -2147483648

            for bq_i in T.serial(block_Q):
                cu_k_s_min[0] = T.min(cu_k_s_min[0],
                                      T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
            for bq_i in T.serial(block_Q):
                cu_k_e_max[0] = T.max(cu_k_e_max[0],
                                      T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))

            T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
            T.copy(Weights[seq_len_i, 0], weights)

            for nbn_i in T.Pipelined(
                    T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
                T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
                T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)

                T.gemm(
                    index_k_shared,
                    index_q_shared,
                    s,
                    transpose_B=True,
                    clear_accum=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )

                for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
                    s_reshaped[bn_i, bq_i, h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
                                                   weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]

                T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)

                for bq_i, bn_i in T.Parallel(block_Q, block_N):
                    Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = (
                        logits[bn_i, bq_i])

    return mqa_attn_return_logits_kernel

⚠️ Potential issue | 🔴 Critical

OOB reads/writes on tail tiles for Q and K; add dynamic valid extents.

Current code copies, computes, and stores full block_Q/block_N even on the last tiles, risking OOB on:

  • Q/Weights copy when seq_len % block_Q != 0
  • K/Scale copy and store when (cu_k_e_max - cu_k_s_min) % block_N != 0

Fix by computing valid_Q and per-iteration valid_K, slicing copies, and limiting parallel loops and stores.

@@
-        with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
+        with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
@@
-            seq_len_i = bx * block_Q
+            seq_len_i = bx * block_Q
+            # valid rows of Q in this block
+            valid_Q = T.min(block_Q, seq_len - seq_len_i)
@@
-            T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
-            T.copy(Weights[seq_len_i, 0], weights)
+            # Copy only valid_Q rows
+            T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared[0:valid_Q * heads, 0])
+            T.copy(Weights[seq_len_i, 0], weights[0:valid_Q, 0])
@@
-            for nbn_i in T.Pipelined(
-                    T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
-                T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
-                T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)
+            for nbn_i in T.Pipelined(
+                    T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
+                base_k = cu_k_s_min[0] + nbn_i * block_N
+                valid_K = T.min(block_N, cu_k_e_max[0] - base_k)
+                # Copy only valid_K rows
+                T.copy(IndexK[base_k, 0], index_k_shared[0:valid_K, 0])
+                T.copy(IndexKScale[base_k], index_k_scale_fragment[0:valid_K])
@@
-                for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
+                # Compute only on valid tiles
+                for bn_i, bq_i, h_i in T.Parallel(valid_K, valid_Q, heads):
                     s_reshaped[bn_i, bq_i,
                                h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
                                        weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
@@
-                T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
+                # Reduce only the valid region
+                T.reduce_sum(s_reshaped[0:valid_K, 0:valid_Q, 0:heads],
+                             logits[0:valid_K, 0:valid_Q],
+                             dim=-1,
+                             clear=True)
@@
-                for bq_i, bn_i in T.Parallel(block_Q, block_N):
-                    Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = (
-                        logits[bn_i, bq_i])
+                # Store only valid indices within [cu_k_s_min, cu_k_e_max)
+                for bq_i, bn_i in T.Parallel(valid_Q, valid_K):
+                    Logits[seq_len_i + bq_i, base_k + bn_i] = logits[bn_i, bq_i]

This keeps GEMM dimensions unchanged but constrains copies, elementwise, reductions, and stores to the valid subtiles to prevent OOB. As per coding guidelines.

🤖 Prompt for AI Agents
In examples/deepseek_v32/fp8_mqa_logits.py around lines 95 to 150, the kernel
currently copies and writes full block_Q/block_N tiles on tail iterations which
can OOB when seq_len % block_Q != 0 or (cu_k_e_max - cu_k_s_min) % block_N != 0;
compute valid_Q = min(block_Q, seq_len - seq_len_i) and inside the nbn_i loop
compute valid_K = min(block_N, cu_k_e_max[0] - cu_k_s_min[0] - nbn_i*block_N)
(clamped to >=0), use those extents to slice the IndexQ and Weights copies (only
copy valid_Q rows), slice IndexK and IndexKScale copies to valid_K rows,
restrict the subsequent elementwise Parallel loops, the reduction (operate over
valid_Q/valid_K extents) and the final Logits store to only write the valid
subtiles; keep the GEMM input shape unchanged (pad/shared buffers may remain
full) but ensure any iteration ranges and
T.copy/T.Parallel/T.reduce_sum/T.Parallel stores use valid_Q/valid_K to avoid
OOB.

Comment on lines 170 to 182
        with T.Kernel(seq_len, threads=threads) as bx:
            tx = T.thread_binding(0, threads, thread="threadIdx.x")
            cu_k_s = T.alloc_local([1], indices_dtype)
            cu_k_e = T.alloc_local([1], indices_dtype)
            cu_k_s[0] = CuSeqLenKS[bx]
            cu_k_e[0] = CuSeqLenKE[bx]

            for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)):
                for k_i in T.serial(block_K // threads):
                    idx = n_i * block_K + k_i * threads + tx
                    if idx < cu_k_s[0] or idx >= cu_k_e[0]:
                        Logits[bx, idx] = -T.infinity(dtype)


⚠️ Potential issue | 🔴 Critical

Prevent OOB writes in the logits cleaner on the last K block.

On the tail block, idx can be ≥ seq_len_kv. Add a bounds check.

-                    idx = n_i * block_K + k_i * threads + tx
-                    if idx < cu_k_s[0] or idx >= cu_k_e[0]:
-                        Logits[bx, idx] = -T.infinity(dtype)
+                    idx = n_i * block_K + k_i * threads + tx
+                    if idx < seq_len_kv and (idx < cu_k_s[0] or idx >= cu_k_e[0]):
+                        Logits[bx, idx] = -T.infinity(dtype)
🤖 Prompt for AI Agents
In examples/deepseek_v32/fp8_mqa_logits.py around lines 170 to 182, the logits
cleaner can write out-of-bounds on the final K block because idx may be >=
seq_len_kv; update the conditional that sets Logits[bx, idx] =
-T.infinity(dtype) to include a bounds check ensuring idx < seq_len_kv (i.e.,
only perform the write when idx is within sequence length and within
cu_k_s/cu_k_e), so that the cleaner never writes past the end of the K
dimension.

Comment on lines 7 to 10
from tilelang.engine.callback import register_cuda_postproc_callback
import argparse



⚠️ Potential issue | 🟠 Major

Add missing import for assert_similar used in tests.

Prevents NameError when --test_correctness is used.

-from tilelang.engine.callback import register_cuda_postproc_callback
+from tilelang.engine.callback import register_cuda_postproc_callback
+from .utils import assert_similar
🤖 Prompt for AI Agents
In examples/deepseek_v32/sparse_mla_fwd_pipelined.py around lines 7 to 10, tests
call assert_similar but the symbol is not imported, causing a NameError when
--test_correctness is used; add the missing import (e.g. import assert_similar
from the test utilities module, for example: from tilelang.testing import
assert_similar) at the top of the file alongside the other imports so tests can
reference it.

Comment on lines 122 to 131
                for bi_i in T.Parallel(BI):
                    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i

                for bi_i, d_i in T.Parallel(BI, D):
                    KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
                                              d_i]
                for bi_i, d_i in T.Parallel(BI, D_tail):
                    K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
                                                  D + d_i]


⚠️ Potential issue | 🔴 Critical

Blocker: Potential out-of-bounds KV loads when indices use SKV sentinel.

You load KV[b_i, Indices[...], ...] unconditionally. When Indices == SKV (sentinel), this indexes past seq_len_kv - 1. Gate loads by validity (you already computed mask).

-                for bi_i, d_i in T.Parallel(BI, D):
-                    KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
-                                              d_i]
-                for bi_i, d_i in T.Parallel(BI, D_tail):
-                    K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
-                                                  D + d_i]
+                for bi_i, d_i in T.Parallel(BI, D):
+                    KV_shared[bi_i, d_i] = T.if_then_else(
+                        mask[bi_i],
+                        KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i],
+                        0)
+                for bi_i, d_i in T.Parallel(BI, D_tail):
+                    K_tail_shared[bi_i, d_i] = T.if_then_else(
+                        mask[bi_i],
+                        KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i],
+                        0)
🤖 Prompt for AI Agents
In examples/deepseek_v32/sparse_mla_fwd.py around lines 122 to 131, the code
unconditionally loads KV using Indices which can equal the SKV sentinel and thus
index past seq_len_kv-1; guard those KV loads with the previously computed mask
to avoid out-of-bounds accesses. Change the KV and K_tail_shared assignments to
only read from KV when mask[bi_i] is true (e.g., use a conditional/masked load
or clamp the index to a safe value and then zero the result when mask is false),
or replace sentinel indices with a safe index (like 0) before loading and
multiply the loaded vector by mask to zero out invalid lanes; apply the same
protection to both the BI×D and BI×D_tail loops.
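
As a complement to the committable diff above, here is a hedged sketch of the clamp-and-mask alternative mentioned in this prompt; it reuses only T.min and T.if_then_else, which already appear in this review, but it has not been validated against the actual kernel.

                # Clamp the gather index so both branches stay in bounds, then
                # zero out invalid lanes via the precomputed mask (apply the
                # same pattern to the BI x D_tail loop).
                for bi_i, d_i in T.Parallel(BI, D):
                    KV_shared[bi_i, d_i] = T.if_then_else(
                        mask[bi_i],
                        KV[b_i, T.min(Indices[b_i, s_i, g_i, i_i * BI + bi_i], max_kv_i), g_i, d_i],
                        0)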

@LeiWang1999 deleted the flashmla branch September 29, 2025 11:53
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* [Refactor] Enhance CopyNode Lower method to support disable_tma flag and improve flash attention implementation

* Updated the CopyNode Lower method to correctly include the disable_tma flag in the GetCopyInst call.
* Refactored the flash attention implementation to selectively disable TMA for specific copy operations while allowing it for others.
* Addressed linting issues for improved code quality

* sparse mla kernels

* Remove deprecated sparse MLA and utility files to streamline the codebase.
