diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index 9dae0006facd..0567c8613fcd 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -172,6 +172,7 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
   scope_.pop_back();
   if (op->else_case.defined()) {
     scope_.push_back(std::vector<StmtEntry>());
+    this->VisitStmt(op->else_case);
     auto v = Summarize(std::move(scope_.back()), nullptr);
     scope_.pop_back();
     s.access.insert(s.access.end(), v.begin(), v.end());
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 7fff6a804e4a..ffdf4b5916c4 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -19,6 +19,21 @@
 import tvm.testing
 
 
+def run_passes(inputs, stmt):
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc(inputs, stmt, None)
+    mod = tvm.IRModule.from_expr(func)
+    mod = tvm.tir.transform.StorageFlatten(64)(mod)
+
+    cuda_target = tvm.target.Target("cuda")
+
+    mod = tvm.tir.transform.Apply(
+        lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target})
+    )(mod)
+
+    mod = tvm.tir.transform.SplitHostDevice()(mod)
+    return tvm.tir.transform.ThreadSync("shared")(mod)
+
+
 @tvm.testing.requires_cuda
 def test_thread_storage_sync():
     m = te.size_var("m")
@@ -38,23 +53,46 @@ def test_thread_storage_sync():
     assert isinstance(bounds, tvm.container.Map)
     stmt = tvm.te.schedule.ScheduleOps(s, bounds)
 
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
-    mod = tvm.IRModule.from_expr(func)
-    mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
+    mod = run_passes([A, A2], stmt)
+    f = mod["test_kernel0"]
+    body_list = tvm.tir.stmt_list(f.body.body.body)
+    assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
 
-    cuda_target = tvm.target.Target("cuda")
-    mod = tvm.tir.transform.Apply(
-        lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target})
-    )(mod._move())
+
+@tvm.testing.requires_cuda
+def test_sync_else_branch():
+    def ir(A, B):
+        ib = tvm.tir.ir_builder.create()
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
 
-    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
-    mod = tvm.IRModule.from_expr(fdevice)
-    cuda_target = tvm.target.Target("cuda")
-    f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"]
-    body_list = tvm.tir.stmt_list(f.body.body.body)
-    assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", 1)
+
+        local = ib.allocate(A.dtype, (8,), name="buf_local", scope="local")
+        shared = ib.allocate(A.dtype, (8,), name="buf_shared", scope="shared")
+
+        with ib.for_range(0, 8) as i:
+            with ib.if_scope(Aptr[i] < 0):
+                local[i] = Aptr[i]
+            with ib.else_scope():
+                shared[i] = Aptr[i]
+
+        with ib.for_range(0, 8) as i:
+            with ib.if_scope(Aptr[i] < 0):
+                Bptr[i] = local[i]
+            with ib.else_scope():
+                Bptr[i] = shared[i]
+
+        return ib.get()
+
+    A = tvm.tir.decl_buffer((8,), "float32")
+    B = tvm.tir.decl_buffer((8,), "float32")
+    stmt = ir(A, B)
+    mod = run_passes([A, B], stmt)
+    assert "@tir.tvm_storage_sync" in str(mod)
 
 
 if __name__ == "__main__":
     test_thread_storage_sync()
+    test_sync_else_branch()
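
Note on the fix (commentary, not part of the patch): `StorageAccessVisitor` pushed a fresh access scope for the else branch but never recursed into it, so `Summarize` always ran over an empty scope and any shared-buffer reads or writes in the else branch were invisible to downstream sync planning; `ThreadSync` therefore never emitted the needed barriers. Below is a sketch of the whole visitor method after the patch. Only the `if (op->else_case.defined())` block is confirmed by the hunk context; the surrounding lines (condition and then-branch handling, the `StmtEntry` bookkeeping, `condition_counter_`) are reconstructed from the upstream file and may differ in detail.

```cpp
// storage_access.cc (sketch): how accesses under an IfThenElse are collected.
void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
  ++condition_counter_;                        // assumed: marks accesses as conditional
  this->VisitExpr(op->condition);
  scope_.push_back(std::vector<StmtEntry>());  // fresh scope for the then-branch
  this->VisitStmt(op->then_case);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), nullptr);
  scope_.pop_back();
  if (op->else_case.defined()) {
    scope_.push_back(std::vector<StmtEntry>());
    // The fix: without this recursion the scope stayed empty, so else-branch
    // accesses to shared memory never reached the sync planner.
    this->VisitStmt(op->else_case);
    auto v = Summarize(std::move(scope_.back()), nullptr);
    scope_.pop_back();
    s.access.insert(s.access.end(), v.begin(), v.end());
  }
  scope_.back().emplace_back(std::move(s));    // merged accesses recorded in parent scope
  --condition_counter_;
}
```

The new test exercises exactly this path: `buf_shared` is written only inside `ib.else_scope()` and read back in a second loop, so a correct analysis must place a `tir.tvm_storage_sync("shared")` between the two loops, which the final `assert "@tir.tvm_storage_sync" in str(mod)` checks.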