-
Notifications
You must be signed in to change notification settings - Fork 333
[Refactor] Reopen #794 Fix lower bug when buffer store is not guarded by any tile op #817
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Refactor] Reopen #794 Fix lower bug when buffer store is not guarded by any tile op #817
Conversation
PyStmtExprVisitor and PyStmtExprMutator seem buggy
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughRefactors AddWrapperForSingleBufStore to use an ir_transform-based pre/post visitor instead of class-based mutator/visitor. Tracks tile-loop depth and thread-bound Vars, collects Vars used by BufferStore, and conditionally wraps eligible BufferStore nodes with a dummy parallel For. API signature unchanged. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Caller
participant AddWrapper as AddWrapperForSingleBufStore
participant IR as ir_transform
participant Node as IR Node
participant BS as BufferStore
Caller->>AddWrapper: invoke(func)
AddWrapper->>IR: ir_transform(func.body, pre_visit, post_visit)
rect rgba(220,235,255,0.4)
note over IR: Traversal with state (pfor, thread_binding_var)
IR->>Node: pre_visit(node) -- update state on For/AttrStmt
IR->>Node: traverse children
IR->>Node: post_visit(node) -- inspect/transform node
end
alt node is BufferStore and eligible
IR->>BS: collect used Vars (post-order)
IR->>IR: wrap BufferStore with dummy parallel For
else
IR-->>Node: return node unchanged
end
IR-->>AddWrapper: new_body
AddWrapper-->>Caller: updated func with transformed body
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related PRs
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @kurisu6912, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a critical bug related to buffer store guarding within tile operations by undertaking a significant refactor of the AddWrapper pass. The previous implementation was prone to segment faults due to unreliable PyStmtExprVisitor and PyStmtExprMutator classes. The changes introduce a more stable and maintainable solution by re-implementing the pass using ir_transform, ensuring correct behavior and preventing runtime crashes.
Highlights
- Refactor AddWrapper Pass: The
AddWrapperpass has been completely rewritten to address stability issues and improve robustness. - Bug Fix: This pull request reopens and fixes a bug (originally from #794) where buffer stores were not properly guarded by tile operations, which previously led to lower-level issues and potential segment faults.
- Implementation Change: The previous implementation, which relied on
PyStmtExprVisitorandPyStmtExprMutator, was found to be buggy and caused segment faults. The new approach utilizestvm.tir.stmt_functor.ir_transformfor a more stable and maintainable solution.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors the AddWrapperForSingleBufStore pass to use ir_transform instead of the buggy PyStmtExprVisitor and PyStmtExprMutator. This is a solid improvement for stability and code clarity. The new implementation correctly mirrors the logic of the old one. I have one suggestion to improve the code structure and add back documentation that was removed.
| def AddWrapperForSingleBufStore(): | ||
|
|
||
| def pass_fn(func: PrimFunc, mod, ctx): | ||
| mut = AddWrapperForSingleStoreMutator() | ||
| new_body = mut.visit_stmt(func.body) | ||
| pfor = 0 | ||
| thread_binding_var = set() | ||
|
|
||
| def get_used_var(op): | ||
| used_var = set() | ||
| def visit_fn(x): | ||
| if isinstance(x, Var): | ||
| used_var.add(x) | ||
| post_order_visit(op, visit_fn) | ||
| return used_var | ||
|
|
||
| def is_tile_op_for(op: For): | ||
| return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations | ||
|
|
||
| def pre_visit(stmt): | ||
| nonlocal pfor | ||
| if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent': | ||
| thread_binding_var.add(stmt.node.var) | ||
| if isinstance(stmt, For): | ||
| pfor += is_tile_op_for(stmt) | ||
|
|
||
| def post_visit(stmt): | ||
| nonlocal pfor | ||
| if isinstance(stmt, For): | ||
| pfor -= is_tile_op_for(stmt) | ||
| if isinstance(stmt, BufferStore): | ||
| used_var = get_used_var(stmt) | ||
| used_binding = used_var.intersection(thread_binding_var) | ||
| if not pfor and len(used_binding) == 0: | ||
| return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt) | ||
|
|
||
| new_body = ir_transform(func.body, pre_visit, post_visit) | ||
|
|
||
| return func.with_body(new_body) | ||
|
|
||
| return prim_func_pass(pass_fn, opt_level=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The refactoring to use ir_transform is a great improvement. To further enhance readability and modularity, I suggest a few changes:
- Restore Documentation: Add a docstring to
AddWrapperForSingleBufStoreto explain its purpose. The previous implementation had a docstring that is useful to retain. - Restructure Helper Functions: The helper functions
get_used_varandis_tile_op_forare currently defined insidepass_fn. Since they are pure functions and don't depend on the state withinpass_fn, they can be moved up a level to be insideAddWrapperForSingleBufStorebut outsidepass_fn. This improves separation of concerns. - Explicit Return: It's good practice for
post_visitfunctions inir_transformto explicitlyreturn Nonewhen no change is made to the statement. This makes the code's intent clearer.
def AddWrapperForSingleBufStore():
"""Add a dummy parallel for loop to wrap a single buffer store.
This pass is to handle the case where a buffer store is not guarded by any tile op,
which may cause problems in the downstream analysis.
Condition to wrap a BufferStore:
1. not inside a parallel for loop
2. no custom thread binding, i.e. threadIdx.x, blockIdx.x
"""
def _get_used_var(op):
used_var = set()
def visit_fn(x):
if isinstance(x, Var):
used_var.add(x)
post_order_visit(op, visit_fn)
return used_var
def _is_tile_op_for(op: For):
return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
def pass_fn(func: PrimFunc, mod, ctx):
pfor = 0
thread_binding_var = set()
def pre_visit(stmt):
nonlocal pfor
if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
thread_binding_var.add(stmt.node.var)
if isinstance(stmt, For):
pfor += _is_tile_op_for(stmt)
def post_visit(stmt):
nonlocal pfor
if isinstance(stmt, For):
pfor -= _is_tile_op_for(stmt)
if isinstance(stmt, BufferStore):
used_var = _get_used_var(stmt)
used_binding = used_var.intersection(thread_binding_var)
if not pfor and len(used_binding) == 0:
return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt)
return None
new_body = ir_transform(func.body, pre_visit, post_visit)
return func.with_body(new_body)
return prim_func_pass(pass_fn, opt_level=0)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Nitpick comments (4)
tilelang/transform/add_bufstore_wrapper.py (4)
8-8: Silence lint: unusedmodandctx.Align with Ruff ARG001 by prefixing underscores.
Apply this diff:
- def pass_fn(func: PrimFunc, mod, ctx): + def pass_fn(func: PrimFunc, _mod, _ctx):
9-11: Track thread binding scope with a stack (prepares proper pop in post).You’re collecting thread-bound vars but not tracking scope. Add a small stack to enable precise removal on exit (see post_visit suggestion).
Apply this diff:
- pfor = 0 - thread_binding_var = set() + pfor = 0 + thread_binding_var = set() + thread_binding_stack = []
12-19: Var collection OK; minor nit: restrict to PrimExpr subtree only if needed.Current walker is fine. If false positives ever show up (e.g., buffer data vars), consider walking only indices/value of BufferStore.
20-22: Broaden tile-op detection keys.Depending on TVM/tilelang version, software pipeline annotations may use different keys. Safer to check a small set.
Apply this diff:
- def is_tile_op_for(op: For): - return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations + def is_tile_op_for(op: For): + anns = getattr(op, "annotations", {}) + return ( + op.kind == ForKind.PARALLEL + or any(k in anns for k in ("num_stages", "software_pipeline_stage", "software_pipeline_order", "software_pipeline_kinds")) + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/transform/add_bufstore_wrapper.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tilelang/transform/add_bufstore_wrapper.py
8-8: Unused function argument: mod
(ARG001)
8-8: Unused function argument: ctx
(ARG001)
🪛 GitHub Actions: CI Test on AMD
tilelang/transform/add_bufstore_wrapper.py
[error] 1-1: clang-format reformatted the file. Please review and stage the changes.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
🔇 Additional comments (2)
tilelang/transform/add_bufstore_wrapper.py (2)
1-2: Good move to ir_transform/post_order_visit.This refactor removes the custom mutator/visitor complexity; imports look correct.
Note: CI flagged formatting. Please run the repo’s formatter to satisfy “clang-format reformatted the file” in CI.
40-41: ir_transform usage LGTM.Pre/post semantics look correct; returning a replacement only in post keeps traversal sane.
| def pre_visit(stmt): | ||
| nonlocal pfor | ||
| if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent': | ||
| thread_binding_var.add(stmt.node.var) | ||
| if isinstance(stmt, For): | ||
| pfor += is_tile_op_for(stmt) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thread_extent node access can crash; handle Var vs IterVar and push scope.
stmt.node.var raises if node is already a Var. Normalize and push to a stack for scoped removal later.
Apply this diff:
def pre_visit(stmt):
nonlocal pfor
if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
- thread_binding_var.add(stmt.node.var)
+ node = stmt.node
+ var = node.var if hasattr(node, "var") else (node if isinstance(node, Var) else None)
+ if var is not None:
+ thread_binding_var.add(var)
+ thread_binding_stack.append(var)
if isinstance(stmt, For):
pfor += is_tile_op_for(stmt)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def pre_visit(stmt): | |
| nonlocal pfor | |
| if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent': | |
| thread_binding_var.add(stmt.node.var) | |
| if isinstance(stmt, For): | |
| pfor += is_tile_op_for(stmt) | |
| def pre_visit(stmt): | |
| nonlocal pfor | |
| if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent': | |
| node = stmt.node | |
| var = node.var if hasattr(node, "var") else (node if isinstance(node, Var) else None) | |
| if var is not None: | |
| thread_binding_var.add(var) | |
| thread_binding_stack.append(var) | |
| if isinstance(stmt, For): | |
| pfor += is_tile_op_for(stmt) |
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 23 to 29, accessing
stmt.node.var can crash when stmt.node is already a Var; normalize by checking
the node type (if it's an IterVar extract its .var, otherwise use the Var
directly), then push that Var onto the thread_binding_var stack/collection so it
can be removed when leaving the scope (ensure the corresponding pop/removal
happens in the post_visit or scope-exit handler).
| nonlocal pfor | ||
| if isinstance(stmt, For): | ||
| pfor -= is_tile_op_for(stmt) | ||
| if isinstance(stmt, BufferStore): | ||
| used_var = get_used_var(stmt) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pop thread-bound var on exit to maintain correct scope.
Without this, scope leaks; while often harmless, it’s easy to fix and future-proof.
Apply this diff:
def post_visit(stmt):
nonlocal pfor
if isinstance(stmt, For):
pfor -= is_tile_op_for(stmt)
+ if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
+ if thread_binding_stack:
+ var = thread_binding_stack.pop()
+ thread_binding_var.discard(var)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def post_visit(stmt): | |
| nonlocal pfor | |
| if isinstance(stmt, For): | |
| pfor -= is_tile_op_for(stmt) | |
| if isinstance(stmt, BufferStore): | |
| def post_visit(stmt): | |
| nonlocal pfor | |
| if isinstance(stmt, For): | |
| pfor -= is_tile_op_for(stmt) | |
| if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent': | |
| if thread_binding_stack: | |
| var = thread_binding_stack.pop() | |
| thread_binding_var.discard(var) | |
| if isinstance(stmt, BufferStore): |
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 30 to 34, the
post_visit handler currently decrements the thread-bound pfor counter when
exiting For nodes but does nothing when exiting BufferStore nodes, causing scope
leakage; update post_visit so that when isinstance(stmt, BufferStore) you also
decrement/pop the thread-bound pfor (same logic as for For, e.g. subtract the
result of is_tile_op_for(stmt) or otherwise pop the bound) to restore the
correct scope on exit.
| used_binding = used_var.intersection(thread_binding_var) | ||
| if not pfor and len(used_binding) == 0: | ||
| return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt) | ||
|
|
||
| new_body = ir_transform(func.body, pre_visit, post_visit) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Invalid Var dtype in For; use int32 (and simplify condition).
Var("_", "int") is invalid; use int32 (or target-appropriate). Also, if not pfor and ... is fine, but explicit comparison reads clearer.
Apply this diff:
if isinstance(stmt, BufferStore):
used_var = get_used_var(stmt)
used_binding = used_var.intersection(thread_binding_var)
- if not pfor and len(used_binding) == 0:
- return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt)
+ if pfor == 0 and not used_binding:
+ return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, stmt)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| used_var = get_used_var(stmt) | |
| used_binding = used_var.intersection(thread_binding_var) | |
| if not pfor and len(used_binding) == 0: | |
| return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt) | |
| used_var = get_used_var(stmt) | |
| used_binding = used_var.intersection(thread_binding_var) | |
| if pfor == 0 and not used_binding: | |
| return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, stmt) |
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 35 to 39, replace the
invalid Var("_", "int") with Var("_", "int32") (or the target-appropriate 32-bit
int) and make the condition explicit by using an explicit boolean comparison;
e.g. change the if to use "pfor == False" instead of "not pfor" while keeping
the len(used_binding) == 0 check so the branch returns For(Var("_", "int32"), 0,
1, ForKind.PARALLEL, stmt).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
♻️ Duplicate comments (5)
tilelang/transform/add_bufstore_wrapper.py (5)
6-7: Restore a clear docstring for this pass.Carry over the previous docstring for intent/conditions; keeps future readers aligned.
def AddWrapperForSingleBufStore(): + """Wrap a standalone BufferStore with a dummy parallel For to avoid downstream issues. + + Wrap when: + 1) Not inside a tile-like loop (parallel or annotated with 'num_stages'). + 2) The BufferStore does not use any thread-bound vars (thread_extent). + """
32-42: Make the no-op explicit.Return
Noneexplicitly when no rewrite occurs; improves readability.def post_visit(stmt): @@ if isinstance(stmt, BufferStore): used_var = get_used_var(stmt) used_binding = used_var.intersection(thread_binding_var) if not pfor and len(used_binding) == 0: return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt) + return None
12-21: Hoist pure helpers out of pass_fn for separation of concerns.Move
get_used_var(and optionallyis_tile_op_for) one level up withinAddWrapperForSingleBufStoreto reduce nesting.
9-11: Fix thread_extent handling: normalize Var vs IterVar and manage scope with a stack.Current
stmt.node.varcan raise whennodeis already a Var; also leaks bindings across scopes. Normalize and push/pop per scope.pfor = 0 - thread_binding_var = set() + thread_binding_var = set() + thread_binding_stack = [] @@ def pre_visit(stmt): nonlocal pfor if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent': - thread_binding_var.add(stmt.node.var) + node = stmt.node + var = node.var if hasattr(node, "var") else (node if isinstance(node, Var) else None) + if var is not None: + thread_binding_var.add(var) + thread_binding_stack.append(var) if isinstance(stmt, For): - pfor += is_tile_op_for(stmt) + if is_tile_op_for(stmt): + pfor += 1 @@ def post_visit(stmt): nonlocal pfor - if isinstance(stmt, For): - pfor -= is_tile_op_for(stmt) + if isinstance(stmt, For) and is_tile_op_for(stmt): + pfor -= 1 + if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent': + if thread_binding_stack: + var = thread_binding_stack.pop() + thread_binding_var.discard(var)Also applies to: 25-31, 32-36
39-41: Use a valid dtype for loop var and clarify condition.
Var("_", "int")is invalid in TIR; useint32. Also simplify the condition.- if not pfor and len(used_binding) == 0: - return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt) + if pfor == 0 and not used_binding: + return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, stmt)
🧹 Nitpick comments (2)
tilelang/transform/add_bufstore_wrapper.py (2)
8-8: Silence lint: prefix unused pass args.Matches prim_func_pass signature while fixing Ruff ARG001.
- def pass_fn(func: PrimFunc, mod, ctx): + def pass_fn(func: PrimFunc, _mod, _ctx):
22-23: Defensive check: annotations may be None.Avoid potential TypeError on membership test.
- def is_tile_op_for(op: For): - return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations + def is_tile_op_for(op: For): + ann = getattr(op, "annotations", None) + return op.kind == ForKind.PARALLEL or (ann is not None and 'num_stages' in ann)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/transform/add_bufstore_wrapper.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tilelang/transform/add_bufstore_wrapper.py
8-8: Unused function argument: mod
(ARG001)
8-8: Unused function argument: ctx
(ARG001)
…guarded by any tile op (tile-ai#817) * [Refactor] Rewrite AddWrapper pass by ir_transform PyStmtExprVisitor and PyStmtExprMutator seem buggy * fix lint error
This is re-open of #794 , the implementation is buggy. This is because
PyStmtExprVisitorandPyStmtExprMutatorare buggy, sometime causes segment fault. I re-implement the pass usingir_transform.Summary by CodeRabbit
Refactor
Chores
Notes