Skip to content

Commit b39aaf5

Browse files
authored
[Bugfix][WS] Consider loop min extent when computing phase id (#754)
* Update test parameters and remove debug print statement
  - Adjusted test cases in `test_tilelang_dynamic_symbolic_bench.py` to use smaller matrix sizes (1024x1024) for improved performance and quicker execution.
  - Removed a debug print statement from `phase.py` to clean up the code and enhance clarity.
* Refactor loop stack management in warp_specialized_rewriter
  - Introduced a new `LoopInfo` struct to encapsulate loop variable details (`loop_var`, `extent`, and `min`), enhancing clarity and maintainability.
  - Updated `loop_stack_` to hold `LoopInfo` instead of a pair, improving type safety and readability.
  - Adjusted the linear index calculation to subtract each loop's `min` from its loop variable, ensuring correct behavior in loop transformations.
1 parent fd199a4 commit b39aaf5

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

src/transform/warp_specialized_rewriter.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ using namespace tir;
2424
using namespace runtime;
2525
using arith::IRVisitorWithAnalyzer;
2626

27+
// Record describing one enclosing loop, stored in WSCodeEmitter::loop_stack_.
// Each entry is filled from a For node as {op->loop_var, op->extent, op->min};
// the emitter combines these entries into a linear iteration index using
// (loop_var - min) per loop, scaled by the extents of the inner loops.
struct LoopInfo {
28+
Var loop_var; // iteration variable of the loop (op->loop_var)
29+
PrimExpr extent; // value of op->extent; multiplier when linearizing nested loops
30+
PrimExpr min; // value of op->min; subtracted from loop_var when linearizing
31+
};
32+
2733
enum class Role { kConsumer, kProducer, kBoth };
2834

2935
class ProducerBufferDetector : public StmtExprVisitor {
@@ -838,7 +844,7 @@ class WSCodeEmitter : public StmtMutator {
838844
num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
839845
ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
840846
}
841-
loop_stack_.emplace_back(op->loop_var, op->extent);
847+
loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min});
842848

843849
Array<Array<Integer>> group_info_array;
844850
Array<Integer> order_info_array;
@@ -871,15 +877,14 @@ class WSCodeEmitter : public StmtMutator {
871877

872878
num_stages_ = num_stages;
873879
pipeline_info_ = pipeline_info;
874-
PrimExpr linear_index = loop_stack_[0].first;
880+
PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min;
875881
for (size_t i = 1; i < loop_stack_.size(); ++i) {
876-
linear_index =
877-
linear_index * loop_stack_[i].second + loop_stack_[i].first;
882+
linear_index = linear_index * loop_stack_[i].extent +
883+
(loop_stack_[i].loop_var - loop_stack_[i].min);
878884
}
879885
stage_ = FloorMod(linear_index, num_stages);
880886
parity_ = FloorMod(
881887
parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);
882-
883888
auto result = FilterByRole(op);
884889

885890
Stmt grouped_for_node;
@@ -1137,7 +1142,7 @@ class WSCodeEmitter : public StmtMutator {
11371142
PrimExpr parity_ = 0;
11381143
PrimExpr stage_ = 0;
11391144
int num_stages_ = 1;
1140-
std::vector<std::pair<Var, PrimExpr>> loop_stack_;
1145+
std::vector<LoopInfo> loop_stack_;
11411146
Var thread_var_;
11421147
bool mbarrier_only_ = false;
11431148
PipelineInfo pipeline_info_;

testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -550,10 +550,10 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
550550

551551

552552
def test_all():
553-
run_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32)
554-
run_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32)
555-
run_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32)
556-
run_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32)
553+
run_assert_tl_matmul_block_static(1024, 1024, 1024, 128, 128, 32)
554+
run_assert_tl_matmul_block_dynamic_m(1024, 1024, 1024, 128, 128, 32)
555+
run_assert_tl_matmul_block_dynamic_mn(1024, 1024, 1024, 128, 128, 32)
556+
run_assert_tl_matmul_block_dynamic_mnk(1024, 1024, 1024, 128, 128, 32)
557557

558558

559559
if __name__ == "__main__":

tilelang/engine/phase.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
165165
mod = tilelang.transform.MergeSharedMemoryAllocations(
166166
enable_aggressive_merge=enable_aggressive_merge)(
167167
mod)
168-
print("mod \n", mod)
169168
mod = tilelang.transform.ThreadSync("shared")(mod)
170169
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
171170
# Inject PTX async copy must behind the thread sync pass

0 commit comments

Comments
 (0)