Skip to content

Commit e2b10c5

Browse files
authored
[Language][UX] Semantic check for parallel fragment access (#1338)
1 parent 2ae4f1b commit e2b10c5

File tree

7 files changed

+277
-3
lines changed

7 files changed

+277
-3
lines changed

src/transform/layout_inference.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
821821
int64_t frag_reg_num = 1;
822822
for (auto i : frag.value()->OutputShape()) {
823823
auto pci = as_const_int(i);
824-
ICHECK(pci != nullptr);
824+
ICHECK(pci != nullptr)
825+
<< "Can not use non-constant range to "
826+
"iterate over a fragment/local "
827+
"buffer. Non-constant shape expr is: "
828+
<< i
829+
<< ". This is possibly because you use symbolic shape when "
830+
"accessing a fragment/local buffer.";
825831
frag_reg_num *= *pci;
826832
}
827833
reg_num += frag_reg_num;
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import tilelang
2+
import tilelang.language as T
3+
import pytest
4+
5+
6+
@tilelang.jit
7+
def simple_invalid_loop(dtype: str = "bfloat16",
8+
accum_dtype: str = "float32",
9+
num_threads: int = 128):
10+
A = T.dynamic("A")
11+
12+
@T.prim_func
13+
def main(
14+
data: T.Tensor((128, A), dtype), # type: ignore
15+
):
16+
with T.Kernel(128, threads=num_threads) as (tid,):
17+
data_frag = T.alloc_fragment([128], accum_dtype)
18+
19+
for i in T.Parallel(128):
20+
if i < A:
21+
data_frag[i] = data[tid, i]
22+
23+
for i in T.Parallel(A):
24+
data_frag[i] = 0
25+
26+
return main
27+
28+
29+
@tilelang.jit
30+
def nested_invalid_loop(dtype: str = "bfloat16",
31+
accum_dtype: str = "float32",
32+
num_threads: int = 128):
33+
A = T.dynamic("A")
34+
35+
@T.prim_func
36+
def main(
37+
data: T.Tensor((128, A), dtype), # type: ignore
38+
):
39+
with T.Kernel(128, threads=num_threads) as (tid,):
40+
data_frag = T.alloc_fragment([128], accum_dtype)
41+
42+
for i in T.Parallel(128):
43+
if i < A:
44+
data_frag[i] = data[tid, i]
45+
46+
for i in T.Parallel(A // 64):
47+
for j in T.Parallel(64):
48+
data_frag[i * 64 + j] = 0
49+
50+
return main
51+
52+
53+
@tilelang.jit
54+
def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16",
55+
accum_dtype: str = "float32",
56+
num_threads: int = 128):
57+
A = T.dynamic("A")
58+
59+
@T.prim_func
60+
def main(
61+
data: T.Tensor((128, A), dtype), # type: ignore
62+
):
63+
with T.Kernel(128, threads=num_threads) as (tid,):
64+
data_frag = T.alloc_fragment([128], accum_dtype)
65+
66+
for i in T.Parallel(128):
67+
if i < A:
68+
data_frag[i] = data[tid, i]
69+
70+
for i in T.Parallel(A):
71+
data_frag[64 // 2 + i % 64] = 0
72+
73+
return main
74+
75+
76+
@tilelang.jit
77+
def valid_loop_not_use_loop_var(dtype: str = "bfloat16",
78+
accum_dtype: str = "float32",
79+
num_threads: int = 128):
80+
A = T.dynamic("A")
81+
82+
@T.prim_func
83+
def main(
84+
data: T.Tensor((128, A), dtype), # type: ignore
85+
):
86+
with T.Kernel(128, threads=num_threads) as (tid,):
87+
data_frag = T.alloc_fragment([128], accum_dtype)
88+
89+
for i in T.Parallel(128):
90+
if i < A:
91+
data_frag[i] = data[tid, i]
92+
93+
for i in T.Parallel(A): # noqa: B007
94+
for j in T.Parallel(64):
95+
data_frag[j] = 0 # This is valid because we don't use i
96+
97+
return main
98+
99+
100+
@tilelang.jit
101+
def valid_loop_not_frag(dtype: str = "bfloat16",
102+
accum_dtype: str = "float32",
103+
num_threads: int = 128):
104+
A = T.dynamic("A")
105+
106+
@T.prim_func
107+
def main(
108+
data: T.Tensor((128, A), dtype), # type: ignore
109+
):
110+
with T.Kernel(128, threads=num_threads) as (tid,):
111+
data_shared = T.alloc_shared([128], accum_dtype)
112+
113+
for i in T.Parallel(128):
114+
if i < A:
115+
data_shared[i] = data[tid, i]
116+
117+
for i in T.Parallel(A):
118+
data_shared[i] = 0 # Valid because this is shared memory
119+
120+
return main
121+
122+
123+
@tilelang.jit
124+
def valid_loop_serial(dtype: str = "bfloat16",
125+
accum_dtype: str = "float32",
126+
num_threads: int = 128):
127+
A = T.dynamic("A")
128+
129+
@T.prim_func
130+
def main(
131+
data: T.Tensor((128, A), dtype), # type: ignore
132+
):
133+
with T.Kernel(128, threads=num_threads) as (tid,):
134+
data_shared = T.alloc_shared([128], accum_dtype)
135+
136+
for i in T.Parallel(128):
137+
if i < A:
138+
data_shared[i] = data[tid, i]
139+
140+
for i in T.serial(A):
141+
data_shared[i] = 0 # Valid because this is serial
142+
143+
return main
144+
145+
146+
def test_invalid_loop():
147+
with pytest.raises(ValueError):
148+
simple_invalid_loop()
149+
with pytest.raises(ValueError):
150+
nested_invalid_loop()
151+
with pytest.raises(ValueError):
152+
invalid_loop_with_complex_dataflow()
153+
154+
155+
def test_valid_loop():
156+
valid_loop_not_use_loop_var()
157+
valid_loop_not_frag()
158+
valid_loop_serial()
159+
160+
161+
if __name__ == "__main__":
162+
tilelang.testing.main()

testing/python/language/test_tilelang_language_nested_loop.py renamed to testing/python/analysis/test_tilelang_nested_loop_checker.py

File renamed without changes.

tilelang/analysis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
from .ast_printer import ASTPrinter # noqa: F401
44
from .nested_loop_checker import NestedLoopChecker # noqa: F401
5+
from .fragment_loop_checker import FragmentLoopChecker # noqa: F401
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from __future__ import annotations
2+
from tvm import tir
3+
from tvm.tir import (PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm)
4+
from tvm.tir.transform import prim_func_pass
5+
from tvm.tir.stmt_functor import post_order_visit
6+
7+
8+
@tir.functor.visitor
9+
class _LoopVarUseAnalyzer(PyStmtExprVisitor):
10+
"""Analyze whether a loop variable is used in the given expr."""
11+
12+
def __init__(self, var: Var) -> None:
13+
super().__init__()
14+
self.var = var
15+
self.used = False
16+
17+
def visit_var_(self, op: Var) -> None:
18+
if op == self.var:
19+
self.used = True
20+
# Don't recursively visit children to avoid infinite recursion
21+
22+
23+
def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
24+
"""
25+
Collect local buffer accesses in the loop body.
26+
27+
Args:
28+
statement: The TIR statement to analyze
29+
30+
Returns:
31+
Tuple of buffer accesses in the loop body.
32+
"""
33+
34+
buffer_accesses = []
35+
36+
def visit_buffer_access(node):
37+
if isinstance(node, (BufferLoad, BufferStore)) and node.buffer.scope().startswith("local"):
38+
buffer_accesses.append(node)
39+
40+
post_order_visit(statement, visit_buffer_access)
41+
42+
return buffer_accesses
43+
44+
45+
@tir.functor.visitor
46+
class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
47+
48+
def __init__(self) -> None:
49+
super().__init__()
50+
51+
def visit_for_(self, op: For) -> None:
52+
if op.kind == tir.ForKind.PARALLEL:
53+
# Fuse consecutive parallel loops
54+
# Other nested cases are all invalid in TileLang.
55+
loops = [op]
56+
child = op.body
57+
while isinstance(child, For) and child.kind == tir.ForKind.PARALLEL:
58+
loops.append(child)
59+
child = child.body
60+
61+
loops_with_symbolic_ranges = []
62+
for loop in loops:
63+
if not (isinstance(loop.min, IntImm) and isinstance(loop.extent, IntImm)):
64+
loops_with_symbolic_ranges.append(loop)
65+
66+
if len(loops_with_symbolic_ranges) > 0:
67+
buffer_accesses = collect_local_buffer_accesses(child)
68+
for loop in loops_with_symbolic_ranges:
69+
for buffer_access in buffer_accesses:
70+
indices = buffer_access.indices
71+
analyzer = _LoopVarUseAnalyzer(loop.loop_var)
72+
for index in indices:
73+
analyzer.visit_expr(index)
74+
if analyzer.used:
75+
raise ValueError(
76+
"[Tilelang Semantic Check] "
77+
f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index "
78+
"a local/fragment buffer, which is not allowed in Tilelang.")
79+
80+
return
81+
82+
self.visit_stmt(op.body)
83+
84+
85+
def FragmentLoopChecker():
86+
"""
87+
When using T.Parallel over a local/fragment buffer, there are several restrictions:
88+
to ensure that the parallelization is valid.
89+
90+
1. The range of loop can not be symbolic.
91+
92+
Returns:
93+
A prim_func_pass that applies the transformation
94+
"""
95+
96+
def pass_fn(func: PrimFunc, mod, ctx):
97+
_FragmentLoopCheckVisitor().visit_stmt(func.body)
98+
return func
99+
100+
return prim_func_pass(pass_fn, opt_level=0)

tilelang/analysis/nested_loop_checker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,17 @@ def visit_for_(self, op: For) -> None:
3535

3636
# Otherwise
3737
if self.in_parallel_context:
38-
raise ValueError("Nested parallel loops are not allowed. "
38+
raise ValueError("[Tilelang Semantic Check] "
39+
"Nested parallel loops are not allowed. "
3940
"Please check your loop structure.")
4041
self.in_parallel_context = True
4142
self.visit_stmt(child)
4243
self.in_parallel_context = False
4344
return
4445
elif is_pipelined_for(op):
4546
if self.in_parallel_context:
46-
raise ValueError("Pipelined loop cannot be nested inside a parallel loop. "
47+
raise ValueError("[Tilelang Semantic Check] "
48+
"Pipelined loop cannot be nested inside a parallel loop. "
4749
"Please check your loop structure.")
4850

4951
self.visit_stmt(op.body)

tilelang/engine/phase.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None:
8080
# Check if there are any invalid nested loops.
8181
tilelang.analysis.NestedLoopChecker()(mod)
8282

83+
# Check if there are any invalid symbolic T.Parallel + fragment access.
84+
tilelang.analysis.FragmentLoopChecker()(mod)
85+
8386

8487
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
8588
# Bind the target device information to the module

0 commit comments

Comments
 (0)