Skip to content

Commit da544d1

Browse files
committed
Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in TIR IR module.
1 parent 92417ed commit da544d1

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

src/target/codegen_cuda.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,11 +1290,7 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
12901290
scope = GetPtrStorageScope(buffer->data);
12911291
}
12921292
if (scope == "local.var" || scope.find("local.descriptor") == 0) {
1293-
if (scope == "local.descriptor.tcgen05_instr") {
1294-
os << vid << ".desc_";
1295-
} else {
1296-
os << vid;
1297-
}
1293+
os << vid;
12981294
return os.str();
12991295
}
13001296
std::string index_str = PrintExpr(index);

testing/python/language/test_tilelang_language_get_warp_info.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,5 +208,4 @@ def test_shuffle_elect_block_leader():
208208

209209

210210
if __name__ == "__main__":
211-
# tilelang.testing.main()
212-
test_get_lane_idx_custom()
211+
tilelang.testing.main()

tilelang/language/tir/ir.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ def unroll(start: PrimExpr,
104104
res : frame.ForFrame
105105
The ForFrame.
106106
"""
107+
# Ensure annotations has {"pragma_unroll_explicit": True} by default
108+
if annotations is None:
109+
annotations = {"pragma_unroll_explicit": False}
110+
else:
111+
# Add "pragma_unroll_explicit": True if not already present
112+
annotations = dict(annotations)
113+
annotations.setdefault("pragma_unroll_explicit", False)
107114
return _ir.unroll(start=start, stop=stop, annotations=annotations)
108115

109116

0 commit comments

Comments
 (0)