diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index f1aed09d47da..3fc2e24fb4f1 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -321,7 +321,7 @@ class VTInjector : public IRMutator { CHECK_EQ(max_loop_depth_, 0); Stmt then_case = this->Mutate(op->then_case); Stmt else_case; - if (else_case.defined()) { + if (op->else_case.defined()) { int temp = max_loop_depth_; max_loop_depth_ = 0; else_case = this->Mutate(op->else_case); diff --git a/tests/python/unittest/test_pass_inject_vthread.py b/tests/python/unittest/test_pass_inject_vthread.py index 502a55574df0..16f4c4652a3d 100644 --- a/tests/python/unittest/test_pass_inject_vthread.py +++ b/tests/python/unittest/test_pass_inject_vthread.py @@ -60,7 +60,26 @@ def get_vthread(name): assert stmt.body.body.body.body.body.body.extents[0].value == 2 assert len(stmt.body.body.body.body.body.body.extents) == 3 +def test_vthread_if_then_else(): + nthread = 2 + tx = tvm.thread_axis("vthread") + ib = tvm.ir_builder.create() + A = ib.pointer("float32", name="A") + with ib.for_range(0, 100) as i: + ib.scope_attr(tx, "virtual_thread", nthread) + B = ib.allocate("float32", 128, name="B", scope="shared") + with ib.if_scope(i == 0): + B[i] = A[i * nthread + tx] + with ib.else_scope(): + B[i] = A[i * nthread + tx] + 1 + with ib.if_scope(i == 0): + B[i] = A[i * nthread + tx] + 2 + stmt = ib.get() + stmt = tvm.ir_pass.InjectVirtualThread(stmt) + assert stmt.body.body.body.first.else_case != None + assert stmt.body.body.body.rest.else_case == None if __name__ == "__main__": test_vthread_extern() test_vthread() + test_vthread_if_then_else()