From 6d6b09b7a65b54068e5e709a02dfbc337cc178af Mon Sep 17 00:00:00 2001 From: kun-zh Date: Fri, 26 Oct 2018 00:38:32 +0800 Subject: [PATCH 1/6] add a pass for the specific hardware accelarator when it is not binded --- include/tvm/ir_pass.h | 9 +++++++++ src/api/api_pass.cc | 1 + src/pass/detect_device.cc | 30 ++++++++++++++++++++++++++++++ src/pass/split_host_device.cc | 4 +++- 4 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 src/pass/detect_device.cc diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 9403a2e6151b..1477795e992d 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -326,6 +326,15 @@ Stmt RewriteUnsafeSelect(Stmt stmt); */ Stmt LowerStorageAccessInfo(Stmt stmt); +/*! + * \brief insert the mark of device for the hardware accelarator when + * it is not binded with thread or block. + * + * \param stmt The stmt to be trasnformed + * \return Transformed stmt. + */ +Stmt DeviceMark(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 1e571ca0dc41..2f9bc4953d2f 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -154,5 +154,6 @@ REGISTER_PASS1(LowerTVMBuiltin); REGISTER_PASS1(CombineContextCall); REGISTER_PASS2(VerifyMemory); REGISTER_PASS2(VerifyGPUCode); +REGISTER_PASS1(DeviceMark); } // namespace ir } // namespace tvm diff --git a/src/pass/detect_device.cc b/src/pass/detect_device.cc new file mode 100644 index 000000000000..f3b483ad677b --- /dev/null +++ b/src/pass/detect_device.cc @@ -0,0 +1,30 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file detect_device.cc + */ + +#include +#include +#include "../pass/ir_util.h" + +namespace tvm { +namespace ir { + +class DetectDevice : public IRMutator { + public: + DetectDevice() {} + Stmt Detect(Stmt stmt) { + Stmt body = AttrStmt::make(make_zero(Int(32)), + ir::attr::pragma_scope_prefix, + StringImm::make("device"), + stmt); + return body; + } +}; + +Stmt DeviceMark(Stmt stmt) { + return DetectDevice().Detect(stmt); +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index 112c2c173df1..5cc6a1565086 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -153,7 +153,9 @@ class HostDeviceSplitter : public IRMutator { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { if (op->attr_key == attr::thread_extent || - op->attr_key == attr::pipeline_exec_scope) { + op->attr_key == attr::pipeline_exec_scope || + (op->attr_key == attr::pragma_scope_prefix && + op->value.as()->value == "device")) { return SplitDeviceFunc(s); } return IRMutator::Mutate_(op, s); From d9fe37eca5221187cf3692b045bc9605ebe6346b Mon Sep 17 00:00:00 2001 From: kun-zh Date: Fri, 26 Oct 2018 23:28:01 +0800 Subject: [PATCH 2/6] add a unit case --- .../unittest/test_pass_detect_device.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 tests/python/unittest/test_pass_detect_device.py diff --git a/tests/python/unittest/test_pass_detect_device.py b/tests/python/unittest/test_pass_detect_device.py new file mode 100644 index 000000000000..619e6b74b7a7 --- /dev/null +++ b/tests/python/unittest/test_pass_detect_device.py @@ -0,0 +1,23 @@ +import tvm + +def test_detect_device(): + m = tvm.var('m') + l = tvm.var('l') + A = tvm.placeholder((m, l), name='A') + + A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') + A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') + + s = tvm.create_schedule(A2.op) + xo, xi = s[A2].split(A2.op.axis[0], factor=8) + s[A1].compute_at(s[A2], xo) + s[A1].set_scope("shared") + + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt = tvm.ir_pass.DeviceMark(stmt) + print stmt + +if __name__ == "__main__": + test_detect_device() + From ac5a74fc67a2a052db4bcdf55fa3b2fa848be74b Mon Sep 17 00:00:00 2001 From: kun-zh Date: Mon, 29 Oct 2018 00:29:30 +0800 Subject: [PATCH 3/6] update test case --- tests/python/unittest/test_pass_detect_device.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_pass_detect_device.py b/tests/python/unittest/test_pass_detect_device.py index 619e6b74b7a7..7ca1907ccc02 100644 --- a/tests/python/unittest/test_pass_detect_device.py +++ b/tests/python/unittest/test_pass_detect_device.py @@ -15,8 +15,11 @@ def test_detect_device(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) - stmt = tvm.ir_pass.DeviceMark(stmt) - print stmt + stmt1 = tvm.ir_pass.Simplify(stmt) + stmt2 = tvm.ir_pass.DeviceMark(stmt1) + assert isinstance(stmt2, tvm.stmt.AttrStmt) + assert stmt2.attr_key == "pragma_" + assert stmt1 == stmt2.body if __name__ == "__main__": test_detect_device() From 817d826bf63b89180d6bee233dd7a97cb7e2127a Mon Sep 17 00:00:00 2001 From: kun-zh Date: Mon, 29 Oct 2018 01:35:23 +0800 Subject: [PATCH 4/6] update code per reviewer comments --- include/tvm/ir.h | 5 +++++ include/tvm/ir_pass.h | 2 +- src/api/api_pass.cc | 2 +- src/pass/detect_device.cc | 6 +++--- src/pass/split_host_device.cc | 3 +-- tests/python/unittest/test_pass_detect_device.py | 4 ++-- 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 14e60146567f..212234303c61 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -237,6 +237,11 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope"; */ constexpr const char* opengl_stage_scope = "opengl_stage_scope"; +/*! + * \brief Mark that it is in the device scope. + */ +constexpr const char* device_scope = "device_scope"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 1477795e992d..bc7c47167bca 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -333,7 +333,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt); * \param stmt The stmt to be trasnformed * \return Transformed stmt. */ -Stmt DeviceMark(Stmt stmt); +Stmt MarkDevice(Stmt stmt); /*! * \brief Make an user callable API LoweredFunc. diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 2f9bc4953d2f..7c7c5ff6dd80 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -154,6 +154,6 @@ REGISTER_PASS1(LowerTVMBuiltin); REGISTER_PASS1(CombineContextCall); REGISTER_PASS2(VerifyMemory); REGISTER_PASS2(VerifyGPUCode); -REGISTER_PASS1(DeviceMark); +REGISTER_PASS1(MarkDevice); } // namespace ir } // namespace tvm diff --git a/src/pass/detect_device.cc b/src/pass/detect_device.cc index f3b483ad677b..80ab2d572586 100644 --- a/src/pass/detect_device.cc +++ b/src/pass/detect_device.cc @@ -15,14 +15,14 @@ class DetectDevice : public IRMutator { DetectDevice() {} Stmt Detect(Stmt stmt) { Stmt body = AttrStmt::make(make_zero(Int(32)), - ir::attr::pragma_scope_prefix, - StringImm::make("device"), + ir::attr::device_scope, + 0, stmt); return body; } }; -Stmt DeviceMark(Stmt stmt) { +Stmt MarkDevice(Stmt stmt) { return DetectDevice().Detect(stmt); } diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index 5cc6a1565086..4cfbc7c90d8c 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -154,8 +154,7 @@ class HostDeviceSplitter : public IRMutator { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || - (op->attr_key == attr::pragma_scope_prefix && - op->value.as()->value == "device")) { + op->attr_key == attr::device_scope) { return SplitDeviceFunc(s); } return IRMutator::Mutate_(op, s); diff --git a/tests/python/unittest/test_pass_detect_device.py b/tests/python/unittest/test_pass_detect_device.py index 7ca1907ccc02..c5dbaae7a593 100644 --- a/tests/python/unittest/test_pass_detect_device.py +++ b/tests/python/unittest/test_pass_detect_device.py @@ -16,9 +16,9 @@ def test_detect_device(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) stmt1 = tvm.ir_pass.Simplify(stmt) - stmt2 = tvm.ir_pass.DeviceMark(stmt1) + stmt2 = tvm.ir_pass.MarkDevice(stmt1) assert isinstance(stmt2, tvm.stmt.AttrStmt) - assert stmt2.attr_key == "pragma_" + assert stmt2.attr_key == "device_scope" assert stmt1 == stmt2.body if __name__ == "__main__": From 073aa0c7ae02527d27ce39375bdc65f19931853e Mon Sep 17 00:00:00 2001 From: kun-zh Date: Mon, 29 Oct 2018 22:24:36 +0800 Subject: [PATCH 5/6] change per comments --- include/tvm/ir_pass.h | 6 +++--- src/api/api_pass.cc | 2 +- src/pass/detect_device.cc | 21 ++++++------------- ....py => test_pass_decorate_device_scope.py} | 8 +++---- 4 files changed, 14 insertions(+), 23 deletions(-) rename tests/python/unittest/{test_pass_detect_device.py => test_pass_decorate_device_scope.py} (85%) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index bc7c47167bca..332becb7aa38 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -327,13 +327,13 @@ Stmt RewriteUnsafeSelect(Stmt stmt); Stmt LowerStorageAccessInfo(Stmt stmt); /*! - * \brief insert the mark of device for the hardware accelarator when - * it is not binded with thread or block. + * \brief Decorate the stmt with a device scope, this is helpful for + * hardware accelerator without thread blocks. * * \param stmt The stmt to be trasnformed * \return Transformed stmt. */ -Stmt MarkDevice(Stmt stmt); +Stmt DecorateDeviceScope(Stmt stmt); /*! * \brief Make an user callable API LoweredFunc. diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 7c7c5ff6dd80..575535f26e81 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -154,6 +154,6 @@ REGISTER_PASS1(LowerTVMBuiltin); REGISTER_PASS1(CombineContextCall); REGISTER_PASS2(VerifyMemory); REGISTER_PASS2(VerifyGPUCode); -REGISTER_PASS1(MarkDevice); +REGISTER_PASS1(DecorateDeviceScope); } // namespace ir } // namespace tvm diff --git a/src/pass/detect_device.cc b/src/pass/detect_device.cc index 80ab2d572586..c5fb0dd1b8f3 100644 --- a/src/pass/detect_device.cc +++ b/src/pass/detect_device.cc @@ -9,21 +9,12 @@ namespace tvm { namespace ir { - -class DetectDevice : public IRMutator { - public: - DetectDevice() {} - Stmt Detect(Stmt stmt) { - Stmt body = AttrStmt::make(make_zero(Int(32)), - ir::attr::device_scope, - 0, - stmt); - return body; - } -}; - -Stmt MarkDevice(Stmt stmt) { - return DetectDevice().Detect(stmt); +Stmt DecorateDeviceScope(Stmt stmt) { + Stmt body = AttrStmt::make(make_zero(Int(32)), + ir::attr::device_scope, + 0, + stmt); + return body; } } // namespace ir diff --git a/tests/python/unittest/test_pass_detect_device.py b/tests/python/unittest/test_pass_decorate_device_scope.py similarity index 85% rename from tests/python/unittest/test_pass_detect_device.py rename to tests/python/unittest/test_pass_decorate_device_scope.py index c5dbaae7a593..1d9eb899a642 100644 --- a/tests/python/unittest/test_pass_detect_device.py +++ b/tests/python/unittest/test_pass_decorate_device_scope.py @@ -1,6 +1,6 @@ import tvm -def test_detect_device(): +def test_decorate_device(): m = tvm.var('m') l = tvm.var('l') A = tvm.placeholder((m, l), name='A') @@ -16,11 +16,11 @@ def test_detect_device(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) stmt1 = tvm.ir_pass.Simplify(stmt) - stmt2 = tvm.ir_pass.MarkDevice(stmt1) + stmt2 = tvm.ir_pass.DecorateDeviceScope(stmt1) assert isinstance(stmt2, tvm.stmt.AttrStmt) assert stmt2.attr_key == "device_scope" assert stmt1 == stmt2.body - + if __name__ == "__main__": - test_detect_device() + test_decorate_device() From 190f1d4f836c2bb57de53584ebfcbd7701b9ece5 Mon Sep 17 00:00:00 2001 From: kun-zh Date: Tue, 30 Oct 2018 22:30:30 +0800 Subject: [PATCH 6/6] fix a bug in inject_virtual_thread --- src/pass/inject_virtual_thread.cc | 2 +- .../unittest/test_pass_inject_vthread.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index f1aed09d47da..3fc2e24fb4f1 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -321,7 +321,7 @@ class VTInjector : public IRMutator { CHECK_EQ(max_loop_depth_, 0); Stmt then_case = this->Mutate(op->then_case); Stmt else_case; - if (else_case.defined()) { + if (op->else_case.defined()) { int temp = max_loop_depth_; max_loop_depth_ = 0; else_case = this->Mutate(op->else_case); diff --git a/tests/python/unittest/test_pass_inject_vthread.py b/tests/python/unittest/test_pass_inject_vthread.py index 502a55574df0..16f4c4652a3d 100644 --- a/tests/python/unittest/test_pass_inject_vthread.py +++ b/tests/python/unittest/test_pass_inject_vthread.py @@ -60,7 +60,26 @@ def get_vthread(name): assert stmt.body.body.body.body.body.body.extents[0].value == 2 assert len(stmt.body.body.body.body.body.body.extents) == 3 +def test_vthread_if_then_else(): + nthread = 2 + tx = tvm.thread_axis("vthread") + ib = tvm.ir_builder.create() + A = ib.pointer("float32", name="A") + with ib.for_range(0, 100) as i: + ib.scope_attr(tx, "virtual_thread", nthread) + B = ib.allocate("float32", 128, name="B", scope="shared") + with ib.if_scope(i == 0): + B[i] = A[i * nthread + tx] + with ib.else_scope(): + B[i] = A[i * nthread + tx] + 1 + with ib.if_scope(i == 0): + B[i] = A[i * nthread + tx] + 2 + stmt = ib.get() + stmt = tvm.ir_pass.InjectVirtualThread(stmt) + assert stmt.body.body.body.first.else_case != None + assert stmt.body.body.body.rest.else_case == None if __name__ == "__main__": test_vthread_extern() test_vthread() + test_vthread_if_then_else()