Skip to content

Commit

Permalink
Fix storage_access not visiting else branch (apache#8525)
Browse files Browse the repository at this point in the history
* Fix storage_access not visiting else branch

* fix conflict with apache#8516 in the test

* update thread sync test following apache#8516 update
  • Loading branch information
masahi authored and ylc committed Sep 29, 2021
1 parent 8d884f1 commit a4ddd26
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/tir/transforms/storage_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
scope_.pop_back();
if (op->else_case.defined()) {
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->else_case);
auto v = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
Expand Down
64 changes: 51 additions & 13 deletions tests/python/unittest/test_tir_transform_thread_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@
import tvm.testing


def run_passes(inputs, stmt):
    """Build an IRModule from ``stmt`` and run it through the thread-sync pipeline.

    The statement is converted to a PrimFunc with
    ``SchedulePostProcToPrimFunc(inputs, stmt, None)``, wrapped in an
    IRModule, and then transformed, in order, by:

    1. ``StorageFlatten(64)``
    2. an ``Apply`` pass that attaches the ``"global_symbol": "test"`` and
       CUDA ``"target"`` attributes to every function
    3. ``SplitHostDevice()``
    4. ``ThreadSync("shared")``

    Parameters
    ----------
    inputs : list
        Arguments forwarded to ``SchedulePostProcToPrimFunc`` as the function
        inputs (the tests pass tensors or buffers).
    stmt : tvm.tir.Stmt
        Body statement to wrap into a PrimFunc.

    Returns
    -------
    mod : tvm.IRModule
        The module produced by the final ``ThreadSync("shared")`` pass.
    """
    prim_func = tvm.te.schedule.SchedulePostProcToPrimFunc(inputs, stmt, None)
    target = tvm.target.Target("cuda")

    def annotate(f):
        # Name the function and bind it to the CUDA target.
        return f.with_attr({"global_symbol": "test", "target": target})

    pipeline = (
        tvm.tir.transform.StorageFlatten(64),
        tvm.tir.transform.Apply(annotate),
        tvm.tir.transform.SplitHostDevice(),
        tvm.tir.transform.ThreadSync("shared"),
    )

    mod = tvm.IRModule.from_expr(prim_func)
    for apply_pass in pipeline:
        mod = apply_pass(mod)
    return mod


@tvm.testing.requires_cuda
def test_thread_storage_sync():
m = te.size_var("m")
Expand All @@ -38,23 +53,46 @@ def test_thread_storage_sync():
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)

func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
mod = run_passes([A, A2], stmt)
f = mod["test_kernel0"]
body_list = tvm.tir.stmt_list(f.body.body.body)
assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))

cuda_target = tvm.target.Target("cuda")

mod = tvm.tir.transform.Apply(
lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target})
)(mod._move())
@tvm.testing.requires_cuda
def test_sync_else_branch():
def ir(A, B):
ib = tvm.tir.ir_builder.create()
Aptr = ib.buffer_ptr(A)
Bptr = ib.buffer_ptr(B)

fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
cuda_target = tvm.target.Target("cuda")
f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"]
body_list = tvm.tir.stmt_list(f.body.body.body)
assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", 1)

local = ib.allocate(A.dtype, (8,), name="buf_local", scope="local")
shared = ib.allocate(A.dtype, (8,), name="buf_shared", scope="shared")

with ib.for_range(0, 8) as i:
with ib.if_scope(Aptr[i] < 0):
local[i] = Aptr[i]
with ib.else_scope():
shared[i] = Aptr[i]

with ib.for_range(0, 8) as i:
with ib.if_scope(Aptr[i] < 0):
Bptr[i] = local[i]
with ib.else_scope():
Bptr[i] = shared[i]

return ib.get()

A = tvm.tir.decl_buffer((8,), "float32")
B = tvm.tir.decl_buffer((8,), "float32")
stmt = ir(A, B)
mod = run_passes([A, B], stmt)
assert "@tir.tvm_storage_sync" in str(mod)


if __name__ == "__main__":
    # Run every test in declaration order when executed as a script.
    for test_case in (test_thread_storage_sync, test_sync_else_branch):
        test_case()

0 comments on commit a4ddd26

Please sign in to comment.