Skip to content

Commit 9b2e04d

Browse files
authored
[Refactor] Enhance buffer store transformation in TIR pass (tile-ai#851)
- Updated the `AddWrapperForSingleBufStore` function to improve the handling of buffer stores by adding detailed checks for fragment buffer accesses and ensuring only index 0 is used.
- Introduced new helper functions for collecting buffer accesses and indices, enhancing code readability and maintainability.
- Refined the logic for determining tile operations and thread bindings to ensure accurate transformations without affecting existing parallel structures.
1 parent 171c4dd commit 9b2e04d

File tree

2 files changed

+140
-34
lines changed

2 files changed

+140
-34
lines changed

src/transform/storage_rewrite.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ class StoragePlanRewriter : public StmtExprMutator {
674674
bool IsSpecialTaggedMemory(const StorageScope &scope) {
675675
return !scope.tag.empty() && scope.tag != ".dyn" &&
676676
scope.tag != ".barrier" && scope.tag != ".workspace" &&
677-
scope.tag != ".vtcm";
677+
scope.tag != ".vtcm" && scope.tag != ".var";
678678
}
679679

680680
// Allocate entry of node.

tilelang/transform/add_bufstore_wrapper.py

Lines changed: 139 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,149 @@
1-
from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc
1+
from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm)
22
from tvm.tir.stmt_functor import ir_transform, post_order_visit
33
from tvm.tir.transform import prim_func_pass
44

55

66
def AddWrapperForSingleBufStore():
7+
"""
8+
Creates a TVM pass that wraps single buffer stores with parallel loops.
9+
10+
This transformation adds T.Parallel wrappers around buffer stores that:
11+
1. Access fragment buffers with index 0
12+
2. Are not inside existing tile operations or thread bindings
13+
3. Don't access fragment buffers with non-zero indices
14+
15+
Returns:
16+
A prim_func_pass that applies the transformation
17+
"""
718

819
def pass_fn(func: PrimFunc, mod, ctx):
9-
pfor = 0
10-
thread_binding_var = set()
11-
12-
def get_used_var(op):
13-
used_var = set()
14-
15-
def visit_fn(x):
16-
if isinstance(x, Var):
17-
used_var.add(x)
18-
19-
post_order_visit(op, visit_fn)
20-
return used_var
21-
22-
def is_tile_op_for(op: For):
23-
return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations
24-
25-
def pre_visit(stmt):
26-
nonlocal pfor
27-
if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
28-
thread_binding_var.add(stmt.node.var)
29-
if isinstance(stmt, For):
30-
pfor += is_tile_op_for(stmt)
31-
32-
def post_visit(stmt):
33-
nonlocal pfor
34-
if isinstance(stmt, For):
35-
pfor -= is_tile_op_for(stmt)
36-
if isinstance(stmt, BufferStore):
37-
used_var = get_used_var(stmt)
38-
used_binding = used_var.intersection(thread_binding_var)
39-
if not pfor and len(used_binding) == 0:
40-
return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt)
20+
# Counter for tracking nested tile operations
21+
tile_operation_depth = 0
22+
# Set of variables bound to threads
23+
thread_binding_vars = set()
24+
25+
def get_used_variables(operation) -> set:
    """
    Gather every `Var` node reachable from `operation`.

    Args:
        operation: The TIR statement or expression to walk

    Returns:
        Set containing each `Var` encountered during the walk
    """
    seen_vars = set()
    # The callback always yields None; it only records Var nodes as a side effect.
    post_order_visit(
        operation,
        lambda node: seen_vars.add(node) if isinstance(node, Var) else None,
    )
    return seen_vars
43+
44+
def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]:
    """
    Split the buffers read or written by `statement` according to scope.

    A buffer whose scope is exactly "local.fragment" is reported as a
    fragment buffer; any other buffer whose scope starts with "local" is
    reported as a local buffer. Buffers in every other scope are ignored.

    Args:
        statement: The TIR statement to walk

    Returns:
        Tuple of (local_buffers, fragment_buffers)
    """
    touched = set()

    def record(node):
        if isinstance(node, (BufferLoad, BufferStore)):
            touched.add(node.buffer)

    post_order_visit(statement, record)

    fragment_buffers = [buf for buf in touched if buf.scope() == "local.fragment"]
    local_buffers = [
        buf for buf in touched
        if buf.scope() != "local.fragment" and buf.scope().startswith("local")
    ]
    return local_buffers, fragment_buffers
70+
71+
def collect_buffer_indices(statement) -> dict[Buffer, list]:
    """
    Maps each buffer to the indices of every access made to it.

    A statement may touch the same buffer more than once (for example
    `frag[0] = frag[1] + 1`). The indices of all such accesses are
    accumulated into one flat list per buffer, so that a later access
    cannot hide the indices of an earlier one from validation.

    Args:
        statement: The TIR statement to analyze

    Returns:
        Dictionary mapping each accessed buffer to a flat list holding the
        index expressions of all its loads and stores, in visit order
    """
    buffer_to_indices = {}

    def visit_buffer_access(node):
        if isinstance(node, (BufferLoad, BufferStore)):
            # Extend rather than assign: assigning would keep only the
            # indices of the last access visited for this buffer.
            buffer_to_indices.setdefault(node.buffer, []).extend(node.indices)

    post_order_visit(statement, visit_buffer_access)
    return buffer_to_indices
89+
90+
def is_tile_operation_loop(loop: For) -> bool:
    """
    Report whether a For loop counts as a tile operation.

    Args:
        loop: The For loop to inspect

    Returns:
        True when the loop kind is PARALLEL or the loop carries a
        'num_stages' annotation
    """
    if loop.kind == ForKind.PARALLEL:
        return True
    return 'num_stages' in loop.annotations
101+
102+
def pre_visit(statement):
103+
"""
104+
Pre-order visitor that tracks thread bindings and tile operation depth.
105+
"""
106+
nonlocal tile_operation_depth
107+
108+
if isinstance(statement, AttrStmt) and statement.attr_key == 'thread_extent':
109+
thread_binding_vars.add(statement.node.var)
110+
elif isinstance(statement, For) and is_tile_operation_loop(statement):
111+
tile_operation_depth += 1
112+
113+
def post_visit(statement):
114+
"""
115+
Post-order visitor that applies transformations and updates counters.
116+
"""
117+
nonlocal tile_operation_depth
118+
119+
if isinstance(statement, For) and is_tile_operation_loop(statement):
120+
tile_operation_depth -= 1
121+
122+
elif isinstance(statement, BufferStore):
123+
used_variables = get_used_variables(statement)
124+
thread_bound_variables = used_variables.intersection(thread_binding_vars)
125+
126+
# Only transform if not inside tile operations and no thread bindings
127+
if tile_operation_depth == 0 and len(thread_bound_variables) == 0:
128+
# Skip if no fragment buffers are accessed
129+
_, fragment_buffers = collect_buffer_accesses(statement)
130+
if len(fragment_buffers) == 0:
131+
return statement
132+
133+
# Validate fragment buffer indices - only index 0 is supported
134+
buffer_indices = collect_buffer_indices(statement)
135+
for buffer, indices in buffer_indices.items():
136+
if buffer.scope() == "local.fragment":
137+
for index in indices:
138+
if isinstance(index, IntImm) and index != 0:
139+
raise ValueError(
140+
f"Fragment buffer access with non-zero index [{index}] is not supported. "
141+
"Only fragment[0] access is allowed.")
142+
143+
# Wrap fragment[0] access with T.Parallel loop
144+
return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, statement)
145+
146+
return statement
41147

42148
new_body = ir_transform(func.body, pre_visit, post_visit)
43149

0 commit comments

Comments
 (0)