
Commit 0fb8da7

[Enhancement] Introduce padding annotation and improve memory access validation (#511)
* Added a new attribute `kPaddingMap` in `builtin.h` for managing padding annotations.
* Enhanced `SafeMemorysRewriter` to utilize an annotated padding map for buffer stores, improving memory access safety.
* Implemented checks in `layout_inference.cc` to ensure buffers are correctly referenced during layout mapping.
* Introduced a new test file for validating the padding annotation functionality in TileLang.
1 parent f09d300 commit 0fb8da7
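
At the user level, the change works as follows: annotate a non-global buffer with a padding value via `T.annotate_padding`, and the boundary checks the compiler inserts for out-of-bounds global reads feeding that buffer fall back to the annotated value instead of zero. Below is a condensed sketch of that usage, adapted from the test added in this commit; the wrapper function name and tile sizes are illustrative.

import tilelang
import tilelang.language as T


def copy_with_pad(M, N, block_M, block_N, dtype="float16", pad_value=0):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            # Out-of-bounds reads of A now fill A_shared with pad_value instead of 0.
            T.annotate_padding({A_shared: pad_value})
            for i, j in T.Parallel(block_M, block_N):
                A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A_shared[i, j]

    return main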

File tree: 5 files changed (+189, −6 lines)


src/op/builtin.h

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,11 @@
 
 namespace tvm {
 namespace tl {
+
+namespace attr {
+static constexpr const char *kPaddingMap = "padding_map";
+} // namespace attr
+
 static constexpr const char *kDebugMergeSharedMemoryAllocations =
     "tl.debug_merge_shared_memory_allocations";
 static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";

src/transform/layout_inference.cc

Lines changed: 2 additions & 0 deletions
@@ -482,6 +482,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
       auto map =
           op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
       for (const auto &[var, layout] : map) {
+        ICHECK(buffer_data_to_buffer_.count(var))
+            << "buffer " << var << " is not found in the block";
         auto buffer = buffer_data_to_buffer_[var];
         ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
         annotated_layout_map_.Set(buffer, layout);

src/transform/legalize_safe_memory_access.cc

Lines changed: 55 additions & 5 deletions
@@ -138,16 +138,33 @@ class SafeMemorysRewriter : public StmtExprMutator {
   arith::Analyzer *analyzer_;
 
 public:
-  explicit SafeMemorysRewriter(arith::Analyzer *analyzer)
-      : analyzer_(analyzer) {}
+  explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_padding_map,
+                               arith::Analyzer *analyzer)
+      : annotated_padding_map_(annotated_padding_map), analyzer_(analyzer) {}
 
 private:
   Stmt VisitStmt_(const BufferStoreNode *op) final {
     // Check if the buffer is in global scope
     auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+
     GlobalMemChecker checker(analyzer_);
     checker(store);
     Array<PrimExpr> conditions = checker.GetConditions();
+
+    // Skip boundary check if the store value is an IfThenElse
+    if (const IfThenElseNode *if_node = store->value.as<IfThenElseNode>()) {
+      if (conditions.size() > 0) {
+        LOG(WARNING)
+            << "Skipping boundary check for store with IfThenElse value: "
+            << store->value
+            << "\nAs manual boundary check detected, potential out-of-bounds "
+               "access may occur."
+            << "\nAuto detect boundaries are " << conditions;
+        return store;
+      }
+      return store;
+    }
+
     if (conditions.size() == 0) {
       return store;
     }
@@ -164,7 +181,7 @@ class SafeMemorysRewriter : public StmtExprMutator {
       for (auto cond : conditions) {
         ICHECK(cond.dtype() == DataType::Bool(1))
             << "condition is not a boolean: " << cond;
-        value = if_then_else(cond, value, make_zero(value->dtype));
+        value = if_then_else(cond, value, GetPadding(store->buffer));
       }
       store.CopyOnWrite()->value = value;
       return store;
@@ -173,7 +190,7 @@ class SafeMemorysRewriter : public StmtExprMutator {
       for (auto cond : conditions) {
        ICHECK(cond.dtype() == DataType::Bool(1))
            << "condition is not a boolean: " << cond;
-       value = if_then_else(cond, value, make_zero(value->dtype));
+       value = if_then_else(cond, value, GetPadding(store->buffer));
      }
      store.CopyOnWrite()->value = value;
      return store;
@@ -227,6 +244,15 @@ class SafeMemorysRewriter : public StmtExprMutator {
     String scope = buffer.scope();
     return scope == "global";
   }
+  // Get the padding of the buffer
+  PrimExpr GetPadding(const Buffer &buffer) {
+    if (annotated_padding_map_.count(buffer)) {
+      return annotated_padding_map_[buffer];
+    }
+    return make_zero(buffer->dtype);
+  }
+
+  Map<Buffer, PrimExpr> annotated_padding_map_;
 };
 
 // Class to legalize safe memory access by transforming them appropriately
@@ -239,6 +265,9 @@ class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
     SafeMemoryLegalizer substituter(&analyzer);
     // Get a mutable copy of the function node
     PrimFuncNode *fptr = f.CopyOnWrite();
+    for (const auto &[_, buffer] : f->buffer_map) {
+      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
+    }
     // Apply the legalizer to the function body
     fptr->body = substituter.VisitStmt(f->body);
     return f;
@@ -255,7 +284,7 @@ class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
     For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
     auto has_inner_loop = HasInnerLoop(for_node->body);
     if (!has_inner_loop) {
-      SafeMemorysRewriter rewriter(analyzer_);
+      SafeMemorysRewriter rewriter(annotated_padding_map_, analyzer_);
       for_node.CopyOnWrite()->body = rewriter(for_node->body);
       // // Detect Buffer Load Node in the loop body, collect the indices and
       // buffer size
@@ -279,11 +308,32 @@ class SafeMemoryLegalizer : IRMutatorWithAnalyzer {
     return IRMutatorWithAnalyzer::VisitStmt_(op);
   }
 
+  Stmt VisitStmt_(const BlockNode *op) final {
+    for (auto buffer : op->alloc_buffers) {
+      buffer_data_to_buffer_.Set(buffer->data, buffer);
+    }
+    if (op->annotations.count(attr::kPaddingMap)) {
+      auto map = op->annotations.Get(attr::kPaddingMap)
+                     .as<Map<Var, PrimExpr>>()
+                     .value();
+      for (const auto &[var, padding] : map) {
+        ICHECK(buffer_data_to_buffer_.count(var))
+            << "buffer " << var << " is not found in the block";
+        auto buffer = buffer_data_to_buffer_[var];
+        annotated_padding_map_.Set(buffer, padding);
+      }
+    }
+    return IRMutatorWithAnalyzer::VisitStmt_(op);
+  }
+
   static bool HasInnerLoop(const Stmt &stmt) {
     LeafForFinder finder;
     finder(stmt);
     return finder.leaf_for_nodes.size() > 0;
   }
+
+  Map<Var, Buffer> buffer_data_to_buffer_;
+  Map<Buffer, PrimExpr> annotated_padding_map_;
 };
 
 // Create a pass that legalizes vectorized loops in the IRModule
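
The net effect of the rewrite is that a bounds-guarded global read now yields the buffer's annotated padding (via `GetPadding`) rather than `make_zero` when the guard fails. A minimal, framework-free Python sketch of the resulting semantics for the shifted-copy pattern used in this commit's test; the helper name is illustrative and not part of the codebase.

def copy_with_padding(a, row_offset=10, pad_value=10.0):
    """Reference semantics: out[i][j] = a[i - row_offset][j] if in bounds, else pad_value."""
    rows, cols = len(a), len(a[0])
    out = [[pad_value] * cols for _ in range(rows)]
    for i in range(rows):
        src = i - row_offset
        if 0 <= src < rows:          # the bounds condition the rewriter derives
            out[i] = list(a[src])    # in-bounds: copy the source row
        # out-of-bounds rows keep pad_value (previously they were filled with zero)
    return out


if __name__ == "__main__":
    a = [[float(i * 4 + j) for j in range(4)] for i in range(4)]
    b = copy_with_padding(a, row_offset=2, pad_value=-1.0)
    assert b[0] == [-1.0] * 4 and b[2] == a[0]
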
New test file for the padding annotation

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+
+import tilelang
+import tilelang.language as T
+import tilelang.testing
+import torch
+
+tilelang.disable_cache()
+
+
+# add decorator @tilelang.jit if you want to return a torch function
+# @tilelang.jit
+def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((M, N), dtype),
+            B: T.Tensor((M, N), dtype),
+    ):
+        # Initialize Kernel Context
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_N), dtype)
+
+            T.annotate_padding({A_shared: pad_value})
+            for i, j in T.Parallel(block_M, block_N):
+                A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
+
+            for i, j in T.Parallel(block_M, block_N):
+                B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
+
+    return main
+
+
+def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0):
+    program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value)
+    kernel = tilelang.compile(
+        program,
+        out_idx=[1],
+        target="cuda",
+        pass_configs={
+            "tl.disable_warp_specialized": True,
+            "tl.disable_tma_lower": True
+        })
+    print(kernel.get_kernel_source())
+    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
+    b = kernel(a)
+    ref_b = torch.zeros_like(a)
+    for i in range(M):
+        if i >= 10:
+            ref_b[i, :] = a[i - 10, :]
+        else:
+            ref_b[i, :] = pad_value
+    torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2)
+
+
+def test_tilelang_copy():
+    run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, pad_value=10)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()

tilelang/language/__init__.py

Lines changed: 65 additions & 1 deletion
@@ -81,12 +81,76 @@ def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
         f"tl::{device_func}<{panel_size}>") if enable else None
 
 
-def annotate_layout(layout_map):
+def annotate_layout(layout_map: Dict):
+    """Annotate the layout of the buffer
+
+    Args:
+        layout_map (Dict): a dictionary of buffer to layout
+
+    Returns:
+        block_attr: a block attribute
+
+    Example:
+        @T.prim_func
+        def main(
+                A: T.Tensor((M, N), dtype),
+                B: T.Tensor((M, N), dtype),
+        ):
+            # Initialize Kernel Context
+            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+                A_shared = T.alloc_shared((block_M, block_N), dtype)
+
+                T.annotate_layout({A_shared: layout})
+                for i, j in T.Parallel(block_M, block_N):
+                    A_shared[i, j] = A[by * block_M + i, bx * block_N + j]
+
+                for i, j in T.Parallel(block_M, block_N):
+                    B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
+
+        return main
+    """
     # layout_map is a dictionary of buffer to layout
     layout_map = {buffer.data: layout for buffer, layout in layout_map.items()}
     return block_attr({"layout_map": layout_map})
 
 
+def annotate_padding(padding_map: Dict):
+    """Annotate the padding of the buffer
+
+    Args:
+        padding_map (Dict): a dictionary of buffer to padding value
+
+    Returns:
+        block_attr: a block attribute
+
+    Example:
+        @T.prim_func
+        def main(
+                A: T.Tensor((M, N), dtype),
+                B: T.Tensor((M, N), dtype),
+        ):
+            # Initialize Kernel Context
+            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+                A_shared = T.alloc_shared((block_M, block_N), dtype)
+
+                T.annotate_padding({A_shared: pad_value})
+                for i, j in T.Parallel(block_M, block_N):
+                    A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
+
+                for i, j in T.Parallel(block_M, block_N):
+                    B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
+
+        return main
+    """
+    # padding_map is a dictionary of buffer to padding value
+    _padding_map = {}
+    for buffer, padding_value in padding_map.items():
+        # padding can only be attached to non-global buffers (e.g. shared memory)
+        assert buffer.scope() != "global", "padding cannot be applied to global buffers"
+        _padding_map[buffer.data] = padding_value
+    return block_attr({"padding_map": _padding_map})
+
+
 def import_source(source: Optional[str] = None):
     # source is the source code to be imported
     return block_attr({"pragma_import_c": source}) if source is not None else None
