Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 31 additions & 52 deletions tilelang/transform/add_bufstore_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,46 @@
from tvm.tir import PyStmtExprMutator, PyStmtExprVisitor, BufferStore, For, AttrStmt, Block, ForKind, IterVar, Var, PrimFunc
from tvm.tir.functor import mutator, visitor
from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc
from tvm.tir.stmt_functor import ir_transform, post_order_visit
from tvm.tir.transform import prim_func_pass


@visitor
class FindVarUse(PyStmtExprVisitor):

def __init__(self):
self.used_var = set()

def visit_var_(self, op: Var):
self.used_var.add(op)
super().visit_var_(op)

def AddWrapperForSingleBufStore():

@mutator
class AddWrapperForSingleStoreMutator(PyStmtExprMutator):
'''
Add a dummy parallel for loop to wrap the single buffer store
Condition:
1. not inside a parallel for loop
2. no custom thread binding, i.e. threadIdx.x, blockIdx.x
'''
def pass_fn(func: PrimFunc, mod, ctx):
pfor = 0
thread_binding_var = set()

def __init__(self):
self.inside_pfor = 0
self.thread_binding_var = set()
def get_used_var(op):
used_var = set()

def visit_block_(self, op: Block):
super().visit_block_(op)
return op
def visit_fn(x):
if isinstance(x, Var):
used_var.add(x)

def visit_attr_stmt_(self, op: AttrStmt):
if op.attr_key == 'thread_extent':
iter_var: IterVar = op.node
self.thread_binding_var.add(iter_var.var)
super().visit_attr_stmt_(op)
return op
post_order_visit(op, visit_fn)
return used_var

def visit_for_(self, op: For):
pfor = op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
self.inside_pfor += pfor
super().visit_for_(op)
self.inside_pfor -= pfor
return op
def is_tile_op_for(op: For):
return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations

def visit_buffer_store_(self, op: BufferStore):
# This pass runs after LetInline, we find var inside the stmt
fv = FindVarUse()
fv.visit_stmt(op)
used_binding = fv.used_var.intersection(self.thread_binding_var)
if not self.inside_pfor and len(used_binding) == 0:
return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, op)
else:
super().visit_buffer_store_(op)
return op
def pre_visit(stmt):
nonlocal pfor
if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
thread_binding_var.add(stmt.node.var)
if isinstance(stmt, For):
pfor += is_tile_op_for(stmt)

def post_visit(stmt):
nonlocal pfor
if isinstance(stmt, For):
pfor -= is_tile_op_for(stmt)
if isinstance(stmt, BufferStore):
Comment on lines +32 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Pop thread-bound var on exit to maintain correct scope.

Without this, scope leaks; while often harmless, it’s easy to fix and future-proof.

Apply this diff:

         def post_visit(stmt):
             nonlocal pfor
             if isinstance(stmt, For):
                 pfor -= is_tile_op_for(stmt)
+            if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
+                if thread_binding_stack:
+                    var = thread_binding_stack.pop()
+                    thread_binding_var.discard(var)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def post_visit(stmt):
nonlocal pfor
if isinstance(stmt, For):
pfor -= is_tile_op_for(stmt)
if isinstance(stmt, BufferStore):
def post_visit(stmt):
nonlocal pfor
if isinstance(stmt, For):
pfor -= is_tile_op_for(stmt)
if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
if thread_binding_stack:
var = thread_binding_stack.pop()
thread_binding_var.discard(var)
if isinstance(stmt, BufferStore):
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 30 to 34, the
post_visit handler currently decrements the thread-bound pfor counter when
exiting For nodes but does nothing when exiting BufferStore nodes, causing scope
leakage; update post_visit so that when isinstance(stmt, BufferStore) you also
decrement/pop the thread-bound pfor (same logic as for For, e.g. subtract the
result of is_tile_op_for(stmt) or otherwise pop the bound) to restore the
correct scope on exit.

used_var = get_used_var(stmt)
used_binding = used_var.intersection(thread_binding_var)
if not pfor and len(used_binding) == 0:
return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt)

def AddWrapperForSingleBufStore():
new_body = ir_transform(func.body, pre_visit, post_visit)

def pass_fn(func: PrimFunc, mod, ctx):
mut = AddWrapperForSingleStoreMutator()
new_body = mut.visit_stmt(func.body)
return func.with_body(new_body)

return prim_func_pass(pass_fn, opt_level=0)
Loading