Collaborator

@kurisu6912 kurisu6912 commented Sep 9, 2025

  • [Fix] Fix lower bug when buffer store is not guarded by any tile op
  • fix lint error

This PR fixes the following error by wrapping the bare buffer store in T.Parallel(1):

import tilelang
import tilelang.language as T

tilelang.disable_cache()

@tilelang.jit(out_idx=[0], verbose=True)
def _example() -> tilelang.JITKernel:
    @T.prim_func
    def _example_kernel(
        x: T.Tensor((1,), 'bfloat16')
    ) -> None:
        with T.Kernel(128):
            A = T.alloc_fragment(1, 'float32')
            A[0] = 1  # Error: `A[0] = 1` is a BufferStoreNode, not a tile op, so no layout is inferred for it
            # Correct code:
            # for _ in T.Parallel(1):
            #     A[0] = 1

    return _example_kernel

_example()

Summary by CodeRabbit

  • New Features

    • Compiler now wraps standalone buffer-store operations in a single-iteration parallel loop when they are not already inside a parallel region, so that layout inference can handle them like other tile ops.
    • This wrapping is applied automatically during the lowering pipeline before further simplifications.
  • Chores

    • Internal pipeline updated to include and invoke the new transform.

Contributor

coderabbitai bot commented Sep 9, 2025

Walkthrough

Adds a new TIR prim-func transform that wraps standalone BufferStore statements in a single-iteration parallel For when not already in a parallel region, exports it via tilelang.transform, and inserts the pass into the LowerAndLegalize pipeline before InjectAssumes.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Transform package export: tilelang/transform/__init__.py | Exposes AddWrapperForSingleBufStore by importing it from tilelang.transform.add_bufstore_wrapper. |
| New transform pass: tilelang/transform/add_bufstore_wrapper.py | Adds FindVarUse (visitor), AddWrapperForSingleStoreMutator (PyStmtExprMutator) and AddWrapperForSingleBufStore() (prim_func_pass). Mutator tracks inside_pfor and wraps BufferStore nodes not under a parallel For (or annotated) with a 1-iteration ForKind.PARALLEL loop using a synthetic Var("_", "int"). |
| Pipeline insertion: tilelang/engine/phase.py | Inserts AddWrapperForSingleBufStore() into the LowerAndLegalize pipeline immediately after let-inlining and before InjectAssumes. |

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Pipeline as LowerAndLegalize
    participant Pass as AddWrapperForSingleBufStore
    participant Mut as AddWrapperForSingleStoreMutator
    participant Func as PrimFunc Body

    Pipeline->>Pass: apply(func)   %% new pass inserted before InjectAssumes
    Pass->>Mut: run on func.body
    loop traverse body
        Mut->>Func: visit(For)
        alt For is PARALLEL or annotated
            Note right of Mut #fff4cc: inside_pfor += 1
            Mut-->>Func: visit(body)
            Note right of Mut #fff4cc: inside_pfor -= 1
        else Not parallel
            Mut-->>Func: default traverse
        end
        Mut->>Func: visit(BufferStore)
        alt inside_pfor == 0 and no thread-binding Vars
            Note over Mut #e6ffe6: replace with For(var "_",0,1,PARALLEL, BufferStore)
        else
            Note over Mut #f9e6ff: keep BufferStore unchanged
        end
    end
    Mut-->>Pass: transformed body
    Pass-->>Pipeline: return updated func
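
The traversal in the diagram can be sketched with a toy IR in plain Python. This is only an illustrative model under stated assumptions: Store, Loop and WrapLoneStores are hypothetical stand-ins for BufferStore, For and the mutator, thread-bound variables are modelled as plain strings, and no TVM or TileLang API is used.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Store:
    """Toy stand-in for a BufferStore: buffer[index] = value."""
    buffer: str
    index: str
    value: int


@dataclass
class Loop:
    """Toy stand-in for a For node; `parallel` mimics ForKind.PARALLEL."""
    var: str
    extent: int
    parallel: bool
    body: List["Stmt"]


Stmt = Union[Store, Loop]


class WrapLoneStores:
    """Wrap stores outside any parallel loop in a 1-iteration parallel loop."""

    def __init__(self, thread_vars=frozenset()):
        self.inside_pfor = 0                 # depth of enclosing parallel loops
        self.thread_vars = set(thread_vars)  # hypothetical thread-bound names

    def visit(self, stmt: Stmt) -> Stmt:
        if isinstance(stmt, Loop):
            return self.visit_loop(stmt)
        return self.visit_store(stmt)

    def visit_loop(self, loop: Loop) -> Loop:
        if loop.parallel:
            self.inside_pfor += 1
        try:
            body = [self.visit(s) for s in loop.body]
        finally:
            if loop.parallel:
                self.inside_pfor -= 1
        return Loop(loop.var, loop.extent, loop.parallel, body)

    def visit_store(self, store: Store) -> Stmt:
        uses_thread_var = store.index in self.thread_vars
        if self.inside_pfor == 0 and not uses_thread_var:
            # Equivalent in spirit to: for _ in T.Parallel(1): store
            return Loop("_", 1, True, [store])
        return store


lone = Store("A", "0", 1)                           # not guarded by anything
guarded = Loop("i", 4, True, [Store("A", "i", 1)])  # already in a parallel loop
bound = Store("A", "tx", 1)                         # indexed by a thread-bound var

mut = WrapLoneStores(thread_vars={"tx"})
wrapped = mut.visit(lone)
```

Only the lone store is wrapped; the guarded store and the store that uses a thread-bound variable are returned unchanged, matching the two branches of the diagram.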

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The title accurately summarizes the primary change: a fix for lowering when a BufferStore is not guarded by a tile op, which matches the new transform that wraps such stores in a minimal parallel loop and the pipeline insertion described in the changeset. It is specific to the bug being addressed and clearly related to the changeset content. |

Poem

I hop where lone stores used to be,
A tiny parallel loop wraps me.
From zero up to one we spin,
Safe in a brace, the stores tuck in.
— cheerfully, the rabbit 🐇



github-actions bot commented Sep 9, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @kurisu6912, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on enhancing the robustness of the tilelang lowering process by fixing a specific bug related to buffer store operations. It introduces a new compiler pass to correctly handle buffer stores that are not explicitly guarded by parallel loops, ensuring proper IR transformation and preventing potential issues in subsequent compilation stages.

Highlights

  • Bug Fix for Buffer Stores: Addresses a bug where buffer store operations were not correctly handled during the lowering phase if they were not guarded by a tile operation, specifically by wrapping them in a T.Parallel(1) loop.
  • New Transformation Pass: Introduces a new transformation pass, AddWrapperForSingleBufStore, which identifies and wraps single buffer store operations that are not already within a parallel loop.
  • Integration into Lowering Pipeline: The new AddWrapperForSingleBufStore pass has been integrated into the LowerAndLegalize function within the tilelang engine, ensuring it runs as part of the standard lowering process.
@kurisu6912 kurisu6912 changed the title from "kurisu desugar bufstore patch 1" to "[Fix] Fix lower bug when buffer store is not guarded by any tile op" on Sep 9, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new transformation pass, AddWrapperForSingleBufStore, to fix a lowering bug where a buffer store is not guarded by a tile operation. The new pass wraps such buffer stores in a T.Parallel(1) loop. The implementation is mostly correct, but I've pointed out a potential robustness issue in the new mutator class where an exception during tree traversal could leave it in an inconsistent state. I've suggested a fix using a try...finally block to ensure the internal state is managed correctly.

Comment on lines 12 to 17
    def visit_for_(self, op: For):
        pfor = op.kind == ForKind.PARALLEL
        self.inside_pfor += pfor
        res = super().visit_for_(op)
        self.inside_pfor -= pfor
        return res
Contributor

high

The current implementation of visit_for_ is not robust against exceptions that might occur during the recursive visit. If super().visit_for_(op) raises an exception, self.inside_pfor will not be decremented, leaving the mutator in an inconsistent state for the remainder of the AST traversal. This could lead to incorrect transformations.

Using a try...finally block ensures that the counter is always decremented, making the mutator more robust.

Additionally, using a more descriptive variable name like is_parallel instead of pfor and being explicit with if is_parallel: can improve readability.

    def visit_for_(self, op: For):
        is_parallel = op.kind == ForKind.PARALLEL
        if is_parallel:
            self.inside_pfor += 1
        try:
            res = super().visit_for_(op)
        finally:
            if is_parallel:
                self.inside_pfor -= 1
        return res
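
To illustrate why the try...finally form matters, here is a minimal, self-contained sketch in plain Python. DepthTracker, visit_unsafe, visit_safe and failing_body are hypothetical names for illustration only; nothing here calls TVM or TileLang.

```python
class DepthTracker:
    """Toy visitor state: counts how many parallel scopes enclose the current node."""

    def __init__(self):
        self.inside_pfor = 0

    def visit_unsafe(self, body):
        self.inside_pfor += 1
        result = body()        # if this raises, the decrement below is skipped
        self.inside_pfor -= 1
        return result

    def visit_safe(self, body):
        self.inside_pfor += 1
        try:
            return body()
        finally:
            self.inside_pfor -= 1  # runs on normal return and on exception


def failing_body():
    raise ValueError("simulated failure while visiting a child node")


unsafe = DepthTracker()
try:
    unsafe.visit_unsafe(failing_body)
except ValueError:
    pass

safe = DepthTracker()
try:
    safe.visit_safe(failing_body)
except ValueError:
    pass

print(unsafe.inside_pfor, safe.inside_pfor)  # prints "1 0"
```

After both visits fail, the unsafe tracker is stuck at depth 1, so any later store would be treated as if it were inside a parallel loop; the safe tracker returns to 0.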

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (4)
tilelang/transform/add_bufstore_wrapper.py (2)

12-17: Optional: avoid bool-as-int arithmetic for clarity

Incrementing/decrementing with booleans works but is cryptic. Use explicit ints.

-        pfor = op.kind == ForKind.PARALLEL
-        self.inside_pfor += pfor
+        pfor = op.kind == ForKind.PARALLEL
+        if pfor:
+            self.inside_pfor += 1
         res = super().visit_for_(op)
-        self.inside_pfor -= pfor
+        if pfor:
+            self.inside_pfor -= 1

26-33: Silence Ruff ARG001 by marking unused params

mod and ctx are intentionally unused; prefix with underscores.

-    def pass_fn(func: PrimFunc, mod, ctx):
+    def pass_fn(func: PrimFunc, _mod, _ctx):
         mut = AddWrapperForSingleStoreMutator()
         new_body = mut.visit_stmt(func.body)
         return func.with_body(new_body)
tilelang/engine/phase.py (1)

90-91: Config-gate the new pass (optional)

If you foresee edge cases, consider a PassContext flag (e.g., tl.enable_wrap_single_bufstore, default True) to toggle this quickly.

tilelang/transform/__init__.py (1)

83-84: Expose in __all__ (optional)

If this module maintains an __all__, add AddWrapperForSingleBufStore for IDE discoverability.

+# near other public exports
+__all__ = [*globals().get("__all__", []), "AddWrapperForSingleBufStore"]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bcfc834 and ab60f6e.

📒 Files selected for processing (3)
  • tilelang/engine/phase.py (1 hunks)
  • tilelang/transform/__init__.py (2 hunks)
  • tilelang/transform/add_bufstore_wrapper.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/engine/phase.py (1)
tilelang/transform/add_bufstore_wrapper.py (1)
  • AddWrapperForSingleBufStore (26-33)
tilelang/transform/__init__.py (1)
tilelang/transform/add_bufstore_wrapper.py (1)
  • AddWrapperForSingleBufStore (26-33)
🪛 Ruff (0.12.2)
tilelang/transform/add_bufstore_wrapper.py

28-28: Unused function argument: mod

(ARG001)


28-28: Unused function argument: ctx

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task
🔇 Additional comments (2)
tilelang/engine/phase.py (1)

90-91: Pass ordering looks reasonable; confirm Simplify doesn't drop Parallel(1)

Good spot to insert the pass. Risk: subsequent Simplify may fold away ForKind.PARALLEL with extent 1. Please confirm it’s preserved until the downstream pass that depended on this guard runs.

Add a minimal TIR test ensuring a lone BufferStore is wrapped in T.Parallel(1) and that the wrapper remains after Simplify:

# pseudo-test outline (place under tests/)
# 1) build PrimFunc with A[0] = 1 outside any pfor
# 2) run FrontendLegalize -> AddWrapperForSingleBufStore -> InjectAssumes -> Simplify
# 3) assert there exists a For with kind=PARALLEL and extent==1 around the store
tilelang/transform/__init__.py (1)

9-9: LGTM: scoped import

Importing the module to expose a public alias keeps coupling minimal.

Comment on lines 19 to 23
    def visit_buffer_store_(self, op: BufferStore):
        if not self.inside_pfor:
            return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
        else:
            return super().visit_buffer_store_(op)
Contributor

⚠️ Potential issue

Fix: use int32 dtype and don't skip mutating BufferStore children

  • Var dtype should be "int32" (not "int") in TVM TIR.
  • Current branch returns the original op, skipping recursive mutation of indices/value. Wrap the mutated node instead.

Apply:

-    def visit_buffer_store_(self, op: BufferStore):
-        if not self.inside_pfor:
-            return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
-        else:
-            return super().visit_buffer_store_(op)
+    def visit_buffer_store_(self, op: BufferStore):
+        new_op = super().visit_buffer_store_(op)
+        if not self.inside_pfor:
+            return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, new_op)
+        return new_op
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 19 to 23, the code
creates Var with dtype "int" and returns the original op (skipping recursive
mutation) when not inside_pfor; change Var's dtype to "int32" and call
super().visit_buffer_store_(op) to get the mutated BufferStore, then wrap THAT
mutated node in the For and return it (i.e., compute mutated =
super().visit_buffer_store_(op) and return For(Var("_", "int32"), 0, 1,
ForKind.PARALLEL, mutated)).
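
The difference between wrapping the original node and wrapping the mutated one can be shown with a small toy model. It assumes, as the review comment does, that visiting children returns a new node instead of modifying op in place; Store, Loop, RenameAndWrap and mutate_children are hypothetical names, and mutate_children only stands in for super().visit_buffer_store_(op).

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Store:
    buffer: str
    value: str  # toy expression: just a variable name


@dataclass(frozen=True)
class Loop:
    extent: int
    body: Store


class RenameAndWrap:
    """Renames variable `old` to `new` in a store's value, then wraps the store."""

    def __init__(self, old, new, wrap_mutated):
        self.old = old
        self.new = new
        self.wrap_mutated = wrap_mutated

    def mutate_children(self, op: Store) -> Store:
        # Stand-in for the base-class visit: returns a NEW node, never edits `op`.
        return replace(op, value=self.new) if op.value == self.old else op

    def visit_store(self, op: Store) -> Loop:
        new_op = self.mutate_children(op)
        # Wrapping `op` instead of `new_op` silently drops the rename.
        return Loop(1, new_op if self.wrap_mutated else op)


store = Store("A", "x")
buggy = RenameAndWrap("x", "y", wrap_mutated=False).visit_store(store)
fixed = RenameAndWrap("x", "y", wrap_mutated=True).visit_store(store)
```

In this sketch buggy.body.value is still "x" while fixed.body.value is "y". Whether the same holds for a given PyStmtExprMutator version is something to verify against the actual pass.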

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
tilelang/transform/add_bufstore_wrapper.py (2)

12-17: Fix: always decrement via try/finally, return the mutated For, and guard annotations access

Current code can leave inside_pfor inconsistent on exceptions and discards mutations by returning op. Also be explicit instead of bool arithmetic.

 def visit_for_(self, op: For):
-        pfor = op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
-        self.inside_pfor += pfor
-        super().visit_for_(op)
-        self.inside_pfor -= pfor
-        return op
+        is_parallel = (op.kind == ForKind.PARALLEL) or (
+            'num_stages' in getattr(op, "annotations", {})
+        )
+        if is_parallel:
+            self.inside_pfor += 1
+        try:
+            res = super().visit_for_(op)
+        finally:
+            if is_parallel:
+                self.inside_pfor -= 1
+        return res

19-24: Fix: wrap the mutated BufferStore and use int32 loop var; don’t skip child mutation

Return the result of super(), and ensure Var dtype is int32.

 def visit_buffer_store_(self, op: BufferStore):
-        if not self.inside_pfor:
-            return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
-        else:
-            super().visit_buffer_store_(op)
-            return op
+        new_op = super().visit_buffer_store_(op)
+        if not self.inside_pfor:
+            return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, new_op)
+        return new_op
🧹 Nitpick comments (2)
tilelang/transform/add_bufstore_wrapper.py (2)

29-29: Nit: silence unused args flagged by Ruff (ARG001)

Rename to underscores to satisfy lints without changing signature.

-    def pass_fn(func: PrimFunc, mod, ctx):
+    def pass_fn(func: PrimFunc, _mod, _ctx):

34-34: Optional: name the pass for better diagnostics

Helpful in IRModule/PassContext logs.

-    return prim_func_pass(pass_fn, opt_level=0)
+    return prim_func_pass(pass_fn, opt_level=0, name="AddWrapperForSingleBufStore")
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between ab60f6e and 2498524.

📒 Files selected for processing (1)
  • tilelang/transform/add_bufstore_wrapper.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tilelang/transform/add_bufstore_wrapper.py

29-29: Unused function argument: mod

(ARG001)


29-29: Unused function argument: ctx

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tilelang/transform/__init__.py (1)

9-9: Remove unused noqa to satisfy Ruff RUF100.

Ruff flags this as an unused noqa (non-enabled: F401). Drop the directive.

-from .add_bufstore_wrapper import AddWrapperForSingleBufStore  # noqa: F401
+from .add_bufstore_wrapper import AddWrapperForSingleBufStore
tilelang/engine/phase.py (1)

90-92: Placement looks right; please confirm Simplify won’t eliminate Parallel(1) before LayoutInference.

Good to insert right after LetInline and before InjectAssumes/LayoutInference. Please verify that tir.transform.Simplify does not fold away a unit-extent parallel loop in your TVM version; if it does, consider moving this pass to just before LayoutInference or marking the loop to keep. Also add a regression covering the A[0] = 1 case.

I can draft a small IR-based regression to ensure the wrapper is preserved through LayoutInference—want me to add it?

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 50a59f3 and 4d175e0.

📒 Files selected for processing (2)
  • tilelang/engine/phase.py (1 hunks)
  • tilelang/transform/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/engine/phase.py (1)
tilelang/transform/add_bufstore_wrapper.py (1)
  • AddWrapperForSingleBufStore (27-34)
tilelang/transform/__init__.py (1)
tilelang/transform/add_bufstore_wrapper.py (1)
  • AddWrapperForSingleBufStore (27-34)
🪛 Ruff (0.12.2)
tilelang/transform/__init__.py

9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-nvidia
  • GitHub Check: bot-task

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (2)
tilelang/transform/add_bufstore_wrapper.py (2)

41-47: Make inside_pfor accounting exception-safe and return mutated For

  • If recursion raises, inside_pfor stays incremented.
  • Returning op drops mutations from super().
  • Guard annotations for None.

Apply:

 def visit_for_(self, op: For):
-    pfor = op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
-    self.inside_pfor += pfor
-    super().visit_for_(op)
-    self.inside_pfor -= pfor
-    return op
+    is_parallel = (op.kind == ForKind.PARALLEL) or (
+        getattr(op, "annotations", None) is not None and "num_stages" in op.annotations
+    )
+    if is_parallel:
+        self.inside_pfor += 1
+    try:
+        return super().visit_for_(op)
+    finally:
+        if is_parallel:
+            self.inside_pfor -= 1

48-57: Wrap the mutated BufferStore and use int32 loop var; don’t skip child mutation

  • Current branch wraps the original op and returns op in the else-branch, skipping recursive mutation.
  • TIR loop variable dtype should be "int32", not "int".

Apply:

 def visit_buffer_store_(self, op: BufferStore):
-    # This pass runs after LetInline, we find var inside the stmt
-    fv = FindVarUse()
-    fv.visit_stmt(op)
-    used_binding = fv.used_var.intersection(self.thread_binding_var)
-    if not self.inside_pfor and len(used_binding) == 0:
-        return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
-    else:
-        super().visit_buffer_store_(op)
-        return op
+    # First mutate children, then analyze/wrap the result.
+    new_op = super().visit_buffer_store_(op)
+    fv = FindVarUse()
+    fv.visit_stmt(new_op)
+    used_binding = fv.used_var.intersection(self.thread_binding_var)
+    if not self.inside_pfor and not used_binding:
+        return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, new_op)
+    return new_op
🧹 Nitpick comments (3)
tilelang/transform/add_bufstore_wrapper.py (3)

62-67: Silence Ruff ARG001 for unused mod, ctx

Use throwaway names to avoid lints without changing the pass ABI.

-    def pass_fn(func: PrimFunc, mod, ctx):
+    def pass_fn(func: PrimFunc, _mod, _ctx):

26-29: Optional: make thread-binding tracking robust to duplicates

If the same thread var is nested, a simple set undercounts scope depth. A small counter avoids edge cases.

Sketch:

# in __init__
self._tb_counts: dict[Var, int] = {}

# on enter thread_extent
self._tb_counts[var] = self._tb_counts.get(var, 0) + 1

# on exit
self._tb_counts[var] -= 1
if self._tb_counts[var] == 0:
    del self._tb_counts[var]

# when computing used bindings
active_tbind = set(self._tb_counts.keys())
used_binding = fv.used_var.intersection(active_tbind)

I can submit a follow-up patch if you prefer.
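
A runnable version of the counting idea sketched above, using collections.Counter in plain Python. ThreadBindingScope and its methods are hypothetical names, and thread variables are modelled as strings instead of TIR Var objects.

```python
from collections import Counter


class ThreadBindingScope:
    """Tracks which toy thread variables are in scope, including nesting depth."""

    def __init__(self):
        self._counts = Counter()

    def enter(self, var):
        self._counts[var] += 1

    def exit(self, var):
        self._counts[var] -= 1
        if self._counts[var] == 0:
            del self._counts[var]  # fully out of scope only when depth reaches 0

    def active(self):
        return set(self._counts)


scope = ThreadBindingScope()
scope.enter("tx")
scope.enter("tx")   # same variable bound again in a nested scope
scope.exit("tx")    # leaving the inner scope
still_bound = "tx" in scope.active()
scope.exit("tx")    # leaving the outer scope
```

With a plain set and discard on exit, leaving the inner scope would already drop "tx"; with the counter it stays active until the outer scope is left.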

Also applies to: 34-40, 48-57


60-67: Add minimal tests for the pass behavior

  • Wraps BufferStore outside any parallel region.
  • No-op when already inside ForKind.PARALLEL.
  • No-op when BufferStore uses thread-bound vars from thread_extent.

I can draft concise TIR snippets plus structural checks (e.g., counting inserted ForKind.PARALLEL with extent=1).

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 4d175e0 and 751ea5d.

📒 Files selected for processing (1)
  • tilelang/transform/add_bufstore_wrapper.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tilelang/transform/add_bufstore_wrapper.py

62-62: Unused function argument: mod

(ARG001)


62-62: Unused function argument: ctx

(ARG001)

Comment on lines +30 to +33
    def visit_block_(self, op: Block):
        super().visit_block_(op)
        return op

Contributor

@coderabbitai coderabbitai bot Sep 12, 2025

⚠️ Potential issue

Bug: returning original Block discards recursive mutations

You call super().visit_block_(op) but return op, dropping any mutations to the block’s body.

Apply:

 def visit_block_(self, op: Block):
-    super().visit_block_(op)
-    return op
+    return super().visit_block_(op)
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 30 to 33, the method
calls super().visit_block_(op) but then returns the original op, discarding any
recursive mutations; change the method to capture the return value from
super().visit_block_(op) and return that result (or update op with the mutated
block returned) so any transformations applied by the visitor are preserved.

Collaborator Author

return op is right, but return super().visit_block_(op) is wrong. I don't know why.

Contributor

In your case, the visit_block_ method appears to be designed to:

  1. Call super().visit_block_(op) to recursively visit the block's contents and potentially modify the mutator's internal state (like inside_pfor counter)
  2. Return the original op because the block itself doesn't need transformation - only its contents may be transformed

This pattern is common in TVM where the mutator needs to traverse the tree to build up context or modify state, but not every node type requires actual mutation. The block serves as a container, and the actual mutations happen at deeper levels (like the BufferStore nodes).

Thank you for the correction - I've learned something important about TVM's mutator patterns!


✏️ Learnings added
Learnt from: kurisu6912
PR: tile-ai/tilelang#794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.446Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

Comment on lines +34 to +40
    def visit_attr_stmt_(self, op: AttrStmt):
        if op.attr_key == 'thread_extent':
            iter_var: IterVar = op.node
            self.thread_binding_var.add(iter_var.var)
        super().visit_attr_stmt_(op)
        return op

Contributor

⚠️ Potential issue

Scope leak and lost mutations in thread_extent handling

  • Vars added to self.thread_binding_var are never removed, leaking across sibling scopes.
  • Returning op discards mutations from super().visit_attr_stmt_(op).

Apply:

 def visit_attr_stmt_(self, op: AttrStmt):
-    if op.attr_key == 'thread_extent':
-        iter_var: IterVar = op.node
-        self.thread_binding_var.add(iter_var.var)
-    super().visit_attr_stmt_(op)
-    return op
+    if op.attr_key == "thread_extent":
+        iter_var: IterVar = op.node
+        self.thread_binding_var.add(iter_var.var)
+        try:
+            return super().visit_attr_stmt_(op)
+        finally:
+            # Use discard to avoid KeyError if encountered twice.
+            self.thread_binding_var.discard(iter_var.var)
+    return super().visit_attr_stmt_(op)
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 34 to 40, vars added
to self.thread_binding_var when encountering AttrStmt with attr_key
'thread_extent' are never removed (causing scope leak across siblings) and the
function returns the original op which discards mutations from super(). To fix:
when op.attr_key == 'thread_extent' push the iter_var.var into
self.thread_binding_var immediately before calling super().visit_attr_stmt_(op),
call super and capture its return value, then remove/pop the iter_var.var in a
finally block to ensure it is always removed after visiting the subtree; finally
return the value returned by super().visit_attr_stmt_(op) instead of the
original op.
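
One way to make the add/remove pairing hard to get wrong is a context manager. The sketch below is plain Python under stated assumptions: BindingTracker and bound are hypothetical names, variables are strings, and the with blocks stand in for visiting the subtree of a thread_extent AttrStmt.

```python
from contextlib import contextmanager


class BindingTracker:
    """Toy model of the mutator's thread_binding_var bookkeeping."""

    def __init__(self):
        self.thread_binding_var = set()

    @contextmanager
    def bound(self, var):
        self.thread_binding_var.add(var)
        try:
            yield
        finally:
            # Always leave the scope, even if visiting the subtree raised.
            self.thread_binding_var.discard(var)


tracker = BindingTracker()

with tracker.bound("tx"):
    seen_in_first = set(tracker.thread_binding_var)

# A sibling scope must not see "tx" from the scope above.
with tracker.bound("ty"):
    seen_in_second = set(tracker.thread_binding_var)
```

Here the second scope sees only "ty". Without the removal step it would also see "tx", which is the sibling-scope leak described in the comment.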

@LeiWang1999
Member

LGTM, Merged:)

RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
…ile-ai#794)

* [Fix] Fix lower bug when buffer store is not guarded by any tile op

* fix lint error

* Fix typo in  pass

* fix lint error

* Ignore custom thread binding
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
…guarded by any tile op (tile-ai#817)

* [Refactor] Rewrite AddWrapper pass by ir_transform
PyStmtExprVisitor and PyStmtExprMutator seem buggy

* fix lint error