[Fix] Fix lower bug when buffer store is not guarded by any tile op #794
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough
Adds a new TIR prim-func transform that wraps standalone BufferStore statements in a single-iteration parallel For when they are not already in a parallel region, and exports it.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Pipeline as LowerAndLegalize
participant Pass as AddWrapperForSingleBufStore
participant Mut as AddWrapperForSingleStoreMutator
participant Func as PrimFunc Body
Pipeline->>Pass: apply(func) %% new pass inserted before InjectAssumes
Pass->>Mut: run on func.body
loop traverse body
Mut->>Func: visit(For)
alt For is PARALLEL or annotated
Note right of Mut #fff4cc: inside_pfor += 1
Mut-->>Func: visit(body)
Note right of Mut #fff4cc: inside_pfor -= 1
else Not parallel
Mut-->>Func: default traverse
end
Mut->>Func: visit(BufferStore)
alt inside_pfor == 0 and no thread-binding Vars
Note over Mut #e6ffe6: replace with For(var "_",0,1,PARALLEL, BufferStore)
else
Note over Mut #f9e6ff: keep BufferStore unchanged
end
end
Mut-->>Pass: transformed body
Pass-->>Pipeline: return updated func
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks: 2 passed, 1 warning
Summary of Changes
Hello @kurisu6912, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request primarily focuses on enhancing the robustness of the tilelang lowering process by fixing a specific bug related to buffer store operations. It introduces a new compiler pass to correctly handle buffer stores that are not explicitly guarded by parallel loops, ensuring proper IR transformation and preventing potential issues in subsequent compilation stages.
Highlights
- Bug Fix for Buffer Stores: Addresses a bug where buffer store operations were not correctly handled during the lowering phase if they were not guarded by a tile operation, specifically by wrapping them in a T.Parallel(1) loop.
- New Transformation Pass: Introduces a new transformation pass, AddWrapperForSingleBufStore, which identifies and wraps single buffer store operations that are not already within a parallel loop.
- Integration into Lowering Pipeline: The new AddWrapperForSingleBufStore pass has been integrated into the LowerAndLegalize function within the tilelang engine, ensuring it runs as part of the standard lowering process.
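To make the wrapping rule concrete, here is a minimal sketch on a toy statement tree. The `Store` and `Loop` classes and the `wrap_unguarded_stores` function are hypothetical stand-ins invented for illustration; they are not TVM or tilelang APIs, and the sketch ignores the `num_stages` annotation and thread-binding checks that later revisions of the pass discuss.

```python
from dataclasses import dataclass
from typing import List, Union


# Toy stand-ins for TIR nodes; these are NOT the TVM classes.
@dataclass(frozen=True)
class Store:
    target: str


@dataclass(frozen=True)
class Loop:
    kind: str              # "serial" or "parallel"
    extent: int
    body: List["Stmt"]


Stmt = Union[Store, Loop]


def wrap_unguarded_stores(stmts: List[Stmt], inside_parallel: bool = False) -> List[Stmt]:
    """Wrap each Store with no enclosing parallel Loop in a parallel Loop of extent 1."""
    out: List[Stmt] = []
    for s in stmts:
        if isinstance(s, Loop):
            # A parallel loop guards everything nested below it.
            guarded = inside_parallel or s.kind == "parallel"
            out.append(Loop(s.kind, s.extent, wrap_unguarded_stores(s.body, guarded)))
        elif inside_parallel:
            out.append(s)                        # already guarded, keep as-is
        else:
            out.append(Loop("parallel", 1, [s]))  # add the single-iteration wrapper
    return out


program = [
    Store("A[0]"),                               # unguarded store
    Loop("parallel", 8, [Store("B[i]")]),        # guarded by a parallel loop
    Loop("serial", 4, [Store("C[k]")]),          # a serial loop does not guard
]
result = wrap_unguarded_stores(program)
```

In this toy model only the first and third stores receive a wrapper; the store already inside a parallel loop is left unchanged.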
Code Review
This pull request introduces a new transformation pass, AddWrapperForSingleBufStore, to fix a lowering bug where a buffer store is not guarded by a tile operation. The new pass wraps such buffer stores in a T.Parallel(1) loop. The implementation is mostly correct, but I've pointed out a potential robustness issue in the new mutator class where an exception during tree traversal could leave it in an inconsistent state. I've suggested a fix using a try...finally block to ensure the internal state is managed correctly.
def visit_for_(self, op: For):
    pfor = op.kind == ForKind.PARALLEL
    self.inside_pfor += pfor
    res = super().visit_for_(op)
    self.inside_pfor -= pfor
    return res
The current implementation of visit_for_ is not robust against exceptions that might occur during the recursive visit. If super().visit_for_(op) raises an exception, self.inside_pfor will not be decremented, leaving the mutator in an inconsistent state for the remainder of the AST traversal. This could lead to incorrect transformations.
Using a try...finally block ensures that the counter is always decremented, making the mutator more robust.
Additionally, using a more descriptive variable name like is_parallel instead of pfor and being explicit with if is_parallel: can improve readability.
def visit_for_(self, op: For):
    is_parallel = op.kind == ForKind.PARALLEL
    if is_parallel:
        self.inside_pfor += 1
    try:
        res = super().visit_for_(op)
    finally:
        if is_parallel:
            self.inside_pfor -= 1
    return res
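The exception-safety concern raised in this comment can be reproduced without TVM. The sketch below uses a hypothetical `DepthTracker` class (not part of tilelang) whose `inside_pfor` attribute mimics the mutator's counter, and compares a plain increment/decrement with the `try...finally` variant when the body visit raises.

```python
class DepthTracker:
    """Hypothetical stand-in for the mutator's inside_pfor counter."""

    def __init__(self) -> None:
        self.inside_pfor = 0

    def visit_unsafe(self, is_parallel: bool, visit_body) -> None:
        if is_parallel:
            self.inside_pfor += 1
        visit_body()                 # an exception here skips the decrement
        if is_parallel:
            self.inside_pfor -= 1

    def visit_safe(self, is_parallel: bool, visit_body) -> None:
        if is_parallel:
            self.inside_pfor += 1
        try:
            visit_body()
        finally:
            if is_parallel:
                self.inside_pfor -= 1  # runs even when visit_body raises


def failing_body() -> None:
    raise ValueError("simulated failure while visiting the loop body")


unsafe, safe = DepthTracker(), DepthTracker()
for visit in (unsafe.visit_unsafe, safe.visit_safe):
    try:
        visit(True, failing_body)
    except ValueError:
        pass

print(unsafe.inside_pfor, safe.inside_pfor)  # prints "1 0"
```

Whether this matters in practice depends on whether the mutator instance is reused after a failed traversal; a pass that builds a fresh mutator per function would be less exposed.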
Actionable comments posted: 1
🧹 Nitpick comments (4)
tilelang/transform/add_bufstore_wrapper.py (2)
12-17: Optional: avoid bool-as-int arithmetic for clarity
Incrementing/decrementing with booleans works but is cryptic. Use explicit ints.
-        pfor = op.kind == ForKind.PARALLEL
-        self.inside_pfor += pfor
+        pfor = op.kind == ForKind.PARALLEL
+        if pfor:
+            self.inside_pfor += 1
         res = super().visit_for_(op)
-        self.inside_pfor -= pfor
+        if pfor:
+            self.inside_pfor -= 1
26-33: Silence Ruff ARG001 by marking unused params
mod and ctx are intentionally unused; prefix with underscores.
-    def pass_fn(func: PrimFunc, mod, ctx):
+    def pass_fn(func: PrimFunc, _mod, _ctx):
         mut = AddWrapperForSingleStoreMutator()
         new_body = mut.visit_stmt(func.body)
         return func.with_body(new_body)
tilelang/engine/phase.py (1)
90-91: Config-gate the new pass (optional)
If you foresee edge cases, consider a PassContext flag (e.g., tl.enable_wrap_single_bufstore, default True) to toggle this quickly.
tilelang/transform/__init__.py (1)
83-84: Expose in __all__ (optional)
If this module maintains an __all__, add AddWrapperForSingleBufStore for IDE discoverability.
+# near other public exports
+__all__ = [*globals().get("__all__", []), "AddWrapperForSingleBufStore"]
🔇 Additional comments (2)
tilelang/engine/phase.py (1)
90-91: Pass ordering looks reasonable; confirm Simplify doesn't drop Parallel(1)
Good spot to insert the pass. Risk: a subsequent Simplify may fold away a ForKind.PARALLEL loop with extent 1. Please confirm it is preserved until the downstream pass that depends on this guard runs.
Add a minimal TIR test ensuring a lone BufferStore becomes a T.parallel(1) and remains after Simplify:
# pseudo-test outline (place under tests/)
# 1) build PrimFunc with A[0] = 1 outside any pfor
# 2) run FrontendLegalize -> AddWrapperForSingleBufStore -> InjectAssumes -> Simplify
# 3) assert there exists a For with kind=PARALLEL and extent==1 around the store
tilelang/transform/__init__.py (1)
9-9: LGTM: scoped import
Importing the module to expose a public alias keeps coupling minimal.
def visit_buffer_store_(self, op: BufferStore):
    if not self.inside_pfor:
        return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
    else:
        return super().visit_buffer_store_(op)
Fix: use int32 dtype and don't skip mutating BufferStore children
- Var dtype should be "int32" (not "int") in TVM TIR.
- Current branch returns the original op, skipping recursive mutation of indices/value. Wrap the mutated node instead.
Apply:
-    def visit_buffer_store_(self, op: BufferStore):
-        if not self.inside_pfor:
-            return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
-        else:
-            return super().visit_buffer_store_(op)
+    def visit_buffer_store_(self, op: BufferStore):
+        new_op = super().visit_buffer_store_(op)
+        if not self.inside_pfor:
+            return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, new_op)
+        return new_op
Actionable comments posted: 0
♻️ Duplicate comments (2)
tilelang/transform/add_bufstore_wrapper.py (2)
12-17: Fix: always decrement via try/finally, return the mutated For, and guard annotations access
Current code can leave inside_pfor inconsistent on exceptions and discards mutations by returning op. Also be explicit instead of bool arithmetic.
 def visit_for_(self, op: For):
-    pfor = op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
-    self.inside_pfor += pfor
-    super().visit_for_(op)
-    self.inside_pfor -= pfor
-    return op
+    is_parallel = (op.kind == ForKind.PARALLEL) or (
+        'num_stages' in getattr(op, "annotations", {})
+    )
+    if is_parallel:
+        self.inside_pfor += 1
+    try:
+        res = super().visit_for_(op)
+    finally:
+        if is_parallel:
+            self.inside_pfor -= 1
+    return res
19-24: Fix: wrap the mutated BufferStore and use int32 loop var; don't skip child mutation
Return the result of super(), and ensure Var dtype is int32.
 def visit_buffer_store_(self, op: BufferStore):
-    if not self.inside_pfor:
-        return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
-    else:
-        super().visit_buffer_store_(op)
-        return op
+    new_op = super().visit_buffer_store_(op)
+    if not self.inside_pfor:
+        return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, new_op)
+    return new_op
🧹 Nitpick comments (2)
tilelang/transform/add_bufstore_wrapper.py (2)
29-29: Nit: silence unused args flagged by Ruff (ARG001)
Rename to underscores to satisfy lints without changing signature.
-    def pass_fn(func: PrimFunc, mod, ctx):
+    def pass_fn(func: PrimFunc, _mod, _ctx):
34-34: Optional: name the pass for better diagnostics
Helpful in IRModule/PassContext logs.
-    return prim_func_pass(pass_fn, opt_level=0)
+    return prim_func_pass(pass_fn, opt_level=0, name="AddWrapperForSingleBufStore")
Actionable comments posted: 0
🧹 Nitpick comments (2)
tilelang/transform/__init__.py (1)
9-9: Remove unused noqa to satisfy Ruff RUF100.
Ruff flags this as an unused noqa (non-enabled: F401). Drop the directive.
-from .add_bufstore_wrapper import AddWrapperForSingleBufStore  # noqa: F401
+from .add_bufstore_wrapper import AddWrapperForSingleBufStore
tilelang/engine/phase.py (1)
90-92: Placement looks right; please confirm Simplify won't eliminate Parallel(1) before LayoutInference.
Good to insert right after LetInline and before InjectAssumes/LayoutInference. Please verify that tir.transform.Simplify does not fold away a unit-extent parallel loop in your TVM version; if it does, consider moving this pass to just before LayoutInference or marking the loop to keep. Also add a regression test covering the A[0] = 1 case.
I can draft a small IR-based regression test to ensure the wrapper is preserved through LayoutInference. Want me to add it?
Actionable comments posted: 2
♻️ Duplicate comments (2)
tilelang/transform/add_bufstore_wrapper.py (2)
41-47: Make inside_pfor accounting exception-safe and return mutated For
- If recursion raises, inside_pfor stays incremented.
- Returning op drops mutations from super().
- Guard annotations for None.
Apply:
 def visit_for_(self, op: For):
-    pfor = op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
-    self.inside_pfor += pfor
-    super().visit_for_(op)
-    self.inside_pfor -= pfor
-    return op
+    is_parallel = (op.kind == ForKind.PARALLEL) or (
+        getattr(op, "annotations", None) is not None and "num_stages" in op.annotations
+    )
+    if is_parallel:
+        self.inside_pfor += 1
+    try:
+        return super().visit_for_(op)
+    finally:
+        if is_parallel:
+            self.inside_pfor -= 1
48-57: Wrap the mutated BufferStore and use int32 loop var; don't skip child mutation
- Current branch wraps the original op and returns op in the else-branch, skipping recursive mutation.
- TIR loop variable dtype should be "int32", not "int".
Apply:
 def visit_buffer_store_(self, op: BufferStore):
-    # This pass runs after LetInline, we find var inside the stmt
-    fv = FindVarUse()
-    fv.visit_stmt(op)
-    used_binding = fv.used_var.intersection(self.thread_binding_var)
-    if not self.inside_pfor and len(used_binding) == 0:
-        return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
-    else:
-        super().visit_buffer_store_(op)
-        return op
+    # First mutate children, then analyze/wrap the result.
+    new_op = super().visit_buffer_store_(op)
+    fv = FindVarUse()
+    fv.visit_stmt(new_op)
+    used_binding = fv.used_var.intersection(self.thread_binding_var)
+    if not self.inside_pfor and not used_binding:
+        return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, new_op)
+    return new_op
🧹 Nitpick comments (3)
tilelang/transform/add_bufstore_wrapper.py (3)
62-67: Silence Ruff ARG001 for unused mod, ctx
Use throwaway names to avoid lints without changing the pass ABI.
-    def pass_fn(func: PrimFunc, mod, ctx):
+    def pass_fn(func: PrimFunc, _mod, _ctx):
26-29: Optional: make thread-binding tracking robust to duplicates
If the same thread var is nested, a simple set undercounts scope depth. A small counter avoids edge cases.
Sketch:
# in __init__
self._tb_counts: dict[Var, int] = {}
# on enter thread_extent
self._tb_counts[var] = self._tb_counts.get(var, 0) + 1
# on exit
self._tb_counts[var] -= 1
if self._tb_counts[var] == 0:
    del self._tb_counts[var]
# when computing used bindings
active_tbind = set(self._tb_counts.keys())
used_binding = fv.used_var.intersection(active_tbind)
I can submit a follow-up patch if you prefer.
Also applies to: 34-40, 48-57
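The counter-based tracking suggested in this nitpick can be tried out as a small self-contained sketch. `BindingScope` and its methods are hypothetical names chosen for illustration, and variables are plain strings here rather than TIR `Var` objects.

```python
from collections import Counter


class BindingScope:
    """Hypothetical helper: counts how many enclosing scopes bind each variable."""

    def __init__(self) -> None:
        self._counts: Counter = Counter()

    def enter(self, var: str) -> None:
        self._counts[var] += 1

    def exit(self, var: str) -> None:
        self._counts[var] -= 1
        if self._counts[var] == 0:
            del self._counts[var]   # drop the entry once no scope binds it

    def active(self) -> set:
        return set(self._counts)


scope = BindingScope()
scope.enter("tx")                    # outer scope binding tx
scope.enter("tx")                    # nested scope binding the same variable
scope.exit("tx")                     # leaving the inner scope keeps tx bound
still_bound = "tx" in scope.active()
scope.exit("tx")                     # leaving the outer scope releases it
released = "tx" not in scope.active()
```

With a plain set, the first `exit` would already remove `tx`, which is the undercounting case the comment describes.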
60-67: Add minimal tests for the pass behavior
- Wraps BufferStore outside any parallel region.
- No-op when already inside ForKind.PARALLEL.
- No-op when BufferStore uses thread-bound vars from thread_extent.
I can draft concise TIR snippets plus structural checks (e.g., counting inserted ForKind.PARALLEL with extent=1).
def visit_block_(self, op: Block):
    super().visit_block_(op)
    return op
Bug: returning original Block discards recursive mutations
You call super().visit_block_(op) but return op, dropping any mutations to the block's body.
Apply:
 def visit_block_(self, op: Block):
-    super().visit_block_(op)
-    return op
+    return super().visit_block_(op)
return op is right, but return super().visit_block_(op) is wrong. I don't know why.
In your case, the visit_block_ method appears to be designed to:
- Call super().visit_block_(op) to recursively visit the block's contents and potentially modify the mutator's internal state (like the inside_pfor counter)
- Return the original op because the block itself doesn't need transformation - only its contents may be transformed
This pattern is common in TVM where the mutator needs to traverse the tree to build up context or modify state, but not every node type requires actual mutation. The block serves as a container, and the actual mutations happen at deeper levels (like the BufferStore nodes).
Thank you for the correction - I've learned something important about TVM's mutator patterns!
✏️ Learnings added
Learnt from: kurisu6912
PR: tile-ai/tilelang#794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.446Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
def visit_attr_stmt_(self, op: AttrStmt):
    if op.attr_key == 'thread_extent':
        iter_var: IterVar = op.node
        self.thread_binding_var.add(iter_var.var)
    super().visit_attr_stmt_(op)
    return op
Scope leak and lost mutations in thread_extent handling
- Vars added to self.thread_binding_var are never removed, leaking across sibling scopes.
- Returning op discards mutations from super().visit_attr_stmt_(op).
Apply:
 def visit_attr_stmt_(self, op: AttrStmt):
-    if op.attr_key == 'thread_extent':
-        iter_var: IterVar = op.node
-        self.thread_binding_var.add(iter_var.var)
-    super().visit_attr_stmt_(op)
-    return op
+    if op.attr_key == "thread_extent":
+        iter_var: IterVar = op.node
+        self.thread_binding_var.add(iter_var.var)
+        try:
+            return super().visit_attr_stmt_(op)
+        finally:
+            # Use discard to avoid KeyError if encountered twice.
+            self.thread_binding_var.discard(iter_var.var)
+    return super().visit_attr_stmt_(op)
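The scope leak described in this comment can be illustrated on a toy tree. In the sketch below, the node shapes `("bind", var, children)` and `("store", name)` and the function `bound_vars_at_stores` are made up for illustration and are not TIR; the `restore` flag switches between the leaky behavior and the scoped one.

```python
from typing import List, Set, Tuple


def bound_vars_at_stores(nodes: list, restore: bool) -> List[Tuple[str, Set[str]]]:
    """Walk a toy tree and record which variables are bound at each store."""
    bound: Set[str] = set()
    seen: List[Tuple[str, Set[str]]] = []

    def walk(items: list) -> None:
        for node in items:
            if node[0] == "store":
                seen.append((node[1], set(bound)))  # snapshot the current bindings
                continue
            _, var, children = node
            bound.add(var)                          # enter the binding scope
            try:
                walk(children)
            finally:
                if restore:
                    bound.discard(var)              # leave the binding scope

    walk(nodes)
    return seen


tree = [
    ("bind", "tx", [("store", "inner")]),
    ("store", "sibling"),                           # outside the tx scope
]
leaky = bound_vars_at_stores(tree, restore=False)
scoped = bound_vars_at_stores(tree, restore=True)
```

Without the restore step, the sibling store still sees `tx` as bound even though it sits outside that scope. Note that `discard` on a set would release a variable too early if the same variable were bound by nested scopes.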
LGTM, Merged :)
… by any tile op (#817)
* [Refactor] Rewrite AddWrapper pass by ir_transform
  PyStmtExprVisitor and PyStmtExprMutator seem buggy
* fix lint error
…ile-ai#794)
* [Fix] Fix lower bug when buffer store is not guarded by any tile op
* fix lint error
* Fix typo in pass
* fix lint error
* Ignore custom thread binding
…guarded by any tile op (tile-ai#817)
* [Refactor] Rewrite AddWrapper pass by ir_transform
  PyStmtExprVisitor and PyStmtExprMutator seem buggy
* fix lint error
This PR fixes the following error by wrapping the unguarded buffer store in a T.Parallel(1) loop.