From 6d10ac9c477e0e844a3f5ec075d7f540295eefc4 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Tue, 4 Aug 2020 02:20:11 +0530 Subject: [PATCH] [TIR][Transform] HoistIfThenElse added (#6066) * [TIR][Transform] HoistIfThenElse added * lint error resolved * Pass position changed * pylint error resolved * CI issues resolved * Frontend tflite test case failure resolved * [1] Review comment handled * [2] Review comment handled * [3] Review comment handled * Lint error resolved --- include/tvm/tir/transform.h | 8 + python/tvm/driver/build_module.py | 1 + python/tvm/tir/transform/transform.py | 9 + src/tir/transforms/hoist_if_then_else.cc | 365 ++++++++++++++++++ tests/python/unittest/test_te_build_lower.py | 2 +- .../unittest/test_tir_transform_hoist_if.py | 268 +++++++++++++ 6 files changed, 652 insertions(+), 1 deletion(-) create mode 100644 src/tir/transforms/hoist_if_then_else.cc create mode 100644 tests/python/unittest/test_tir_transform_hoist_if.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 5e04838f7cd3a..f31e515c7913c 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -338,6 +338,14 @@ TVM_DLL Pass BF16Legalize(); */ TVM_DLL Pass PointerValueTypeRewrite(); +/*! + * \brief Hoist loop-invariant IfThenElse nodes to + * outside the eligible loops. + * + * \return The pass. 
+ */ +TVM_DLL Pass HoistIfThenElse(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index b10700042260d..663a17a72a16a 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -179,6 +179,7 @@ def lower(sch, tvm.tir.transform.BF16Legalize(), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), + tvm.tir.transform.HoistIfThenElse(), ] pass_list += lower_phase1 diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 86e7a33ad8cb0..d2f5acd199e61 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -499,3 +499,12 @@ def VerifyMemory(): The result pass """ return _ffi_api.VerifyMemory() + +def HoistIfThenElse(): + """Hoist loop-invariant IfThenElse nodes to outside the eligible loops. + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.HoistIfThenElse() diff --git a/src/tir/transforms/hoist_if_then_else.cc b/src/tir/transforms/hoist_if_then_else.cc new file mode 100644 index 0000000000000..f58eb965584d7 --- /dev/null +++ b/src/tir/transforms/hoist_if_then_else.cc @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file hoist_if_then_else.cc + */ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../arith/interval_set.h" +#include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" + +namespace tvm { +namespace tir { + +using VarForMap = std::unordered_map; +using HoistForIfTuple = std::tuple; + +/* + * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant. + * For example, given the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt. + * Then we hoist IfThenElse stmt by one For stmt each step: + * + * Step 1: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Step 2: + * for (i = 0; i < 3; i++) + * if (likely(i*2 < 4)) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * In this pass, we only continue detecting possible hoisting chance when visiting For, + * IfThenElse or AttrStmt Node. For example, for the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Only the For with k variable will be considered and the resulting stmt would be: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * This pass doesn't do hoisting for consecutive IfThenElse stmt. 
The following + * block won't be optimized: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * if (likely(j > 2)) + * A[i+j+k] = B[i+j+k] + * + */ + +// Select potential candidate IRs that can be hoisted. +class HoistCandidateSelector final : public StmtExprVisitor { + public: + HoistCandidateSelector() { InitRecorder(); } + + void VisitStmt_(const ForNode* op) final { + // If already recording complete, + // then stop tracing + if (RecordingComplete()) { + return; + } + + // Check if it is first for loop, then start the recorder + StartOrAddRecord(op); + StmtExprVisitor::VisitStmt_(op); + RemoveRecord(op); + } + + void VisitStmt_(const SeqStmtNode* op) final { + // If SeqStmt is encountered in the middle of recording + // then need to purge all, as it can not be hoisted + if (IsRecordingOn()) { + ResetRecorder(); + } + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AttrStmtNode* op) final { + // Maintain list of all vars in AttrStmt + // To stop hoisting if any of the block variables are used. + // + // NOTE: If in future + // hoisting is required for any specific case, + // then add exception to only those case + // rather than allowing for all. 
+ UpdateAttrVarList(op); + StmtExprVisitor::VisitStmt_(op); + RemoveAttrVarList(op); + } + + void VisitStmt_(const IfThenElseNode* op) final { + if (!IsRecordingOn()) { + StmtExprVisitor::VisitStmt_(op); + return; + } + + is_if_cond_ = true; + StmtExprVisitor::VisitExpr(op->condition); + is_if_cond_ = false; + + if (CheckValidIf()) { + // Check corresponding for loop + bool match_found = false; + size_t match_for_loop_pos = 0; + for (auto var : if_var_list_) { + for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) { + if (ordered_for_list_[i] == var_for_map_[var]) { + if (match_for_loop_pos < i) { + match_for_loop_pos = i; + } + match_found = true; + break; + } + } + } + // If none of the for loop has the matching loop variable as if condition, + // then the if node need to be hoisted on top of all, provided no parent loop exists. + int target_for_pos = match_found ? match_for_loop_pos + 1 : 0; + + // Check if target for loop is not the parent of current if node + if (!IsParentForLoop(target_for_pos)) { + StopAndAddRecord(ordered_for_list_[target_for_pos], op); + if_var_list_.clear(); + return; + } + } + + if_var_list_.clear(); + StmtExprVisitor::VisitStmt_(op); + StopRecording(); + } + + void VisitExpr_(const VarNode* op) final { + if (is_if_cond_) { + if_var_list_.emplace_back(op); + } + } + + HoistForIfTuple hoist_for_if_recorder; + + void ResetRecorder() { + if (is_recorder_on_) { + CHECK_GT(ordered_for_list_.size(), 0); + is_recorder_on_ = false; + } + ordered_for_list_.clear(); + var_for_map_.clear(); + hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); + } + + bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); } + + const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); } + + const IfThenElseNode* GetTargetIfNode() { return std::get<2>(hoist_for_if_recorder); } + + private: + bool CheckValidIf() { + // If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop + 
// hoisting + return ((!if_var_list_.empty()) && (!CheckAttrVar())); + } + + bool IsParentForLoop(int loop_pos) { + // Check if the loop position is higher than the parent loop position + for (auto var : if_var_list_) { + if (GetParentLoopPos(var_for_map_[var]) >= loop_pos) { + return true; + } + } + return false; + } + + int GetParentLoopPos(const Object* node) { + for (size_t i = 0; i < ordered_for_list_.size(); ++i) { + if (ordered_for_list_[i] == node) { + return i; + } + } + return -1; + } + + void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); } + + void StopRecording() { is_recorder_on_ = false; } + + bool IsRecordingOn() { return is_recorder_on_; } + + void StartOrAddRecord(const ForNode* op) { + is_recorder_on_ = true; + if (!var_for_map_.count(op->loop_var.get())) { + var_for_map_.insert({op->loop_var.get(), op}); + } + ordered_for_list_.emplace_back(op); + } + + void RemoveRecord(const ForNode* op) { + StopRecording(); + var_for_map_.erase(op->loop_var.get()); + if (ordered_for_list_.size() > 0) ordered_for_list_.pop_back(); + } + + void StopAndAddRecord(const ForNode* for_node, const IfThenElseNode* if_node) { + hoist_for_if_recorder = std::make_tuple(true, for_node, if_node); + StopRecording(); + } + + void UpdateAttrVarList(const AttrStmtNode* op) { + if (const auto* iv = op->node.as()) { + attr_var_list_.insert(iv->var.get()); + } else if (const auto* iv = op->node.as()) { + attr_var_list_.insert(iv); + } + } + + void RemoveAttrVarList(const AttrStmtNode* op) { + if (const auto* iv = op->node.as()) { + attr_var_list_.erase(iv->var.get()); + } else if (const auto* iv = op->node.as()) { + attr_var_list_.erase(iv); + } + } + + bool CheckAttrVar() { + for (auto var : if_var_list_) { + if (attr_var_list_.count(var)) { + return true; + } + } + return false; + } + + std::vector ordered_for_list_; + std::vector if_var_list_; + std::unordered_set attr_var_list_; + VarForMap var_for_map_; + + bool is_if_cond_{false}; + bool 
is_recorder_on_{false}; +}; + +class IfThenElseHoister : public StmtMutator { + public: + IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {} + + Stmt VisitAndMutate(Stmt stmt) { + hoist_selector_(stmt); + Stmt stmt_copy = std::move(stmt); + + while (hoist_selector_.RecordingComplete()) { + target_for_ = hoist_selector_.GetTargetForNode(); + target_if_ = hoist_selector_.GetTargetIfNode(); + + stmt_copy = operator()(stmt_copy); + + hoist_selector_.ResetRecorder(); + hoist_selector_(stmt_copy); + } + + // Support SSA Form + stmt_copy = ConvertSSA(stmt_copy); + return stmt_copy; + } + + Stmt VisitStmt_(const ForNode* op) final { + if ((!is_updating_) && (target_for_ == op)) { + is_updating_ = true; + is_then_case_ = true; + Stmt then_case = StmtMutator::VisitStmt_(op); + is_then_case_ = false; + Stmt else_case = Stmt(); + if (target_if_->else_case.defined()) { + else_case = StmtMutator::VisitStmt_(op); + } + is_updating_ = false; + return IfThenElse(target_if_->condition, then_case, else_case); + } + return StmtMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const IfThenElseNode* op) final { + if (is_updating_ && (target_if_ == op)) { + if (is_then_case_) { + return StmtMutator::VisitStmt(op->then_case); + } else if (op->else_case.defined()) { + return StmtMutator::VisitStmt(op->else_case); + } + } + return StmtMutator::VisitStmt_(op); + } + + private: + bool is_updating_{false}; + bool is_then_case_{false}; + HoistCandidateSelector hoist_selector_; + const ForNode* target_for_; + const IfThenElseNode* target_if_; +}; + +Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoister().VisitAndMutate(stmt); } + +namespace transform { + +Pass HoistIfThenElse() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = HoistIfThenElse(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.HoistIfThenElse", {}); +} + 
+TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_te_build_lower.py b/tests/python/unittest/test_te_build_lower.py index b1d754605a464..1fc2fcdb08cd6 100644 --- a/tests/python/unittest/test_te_build_lower.py +++ b/tests/python/unittest/test_te_build_lower.py @@ -49,7 +49,7 @@ def test_split_uneven_unique_likely(): sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) stmt = tvm.lower(sch, [a, b, c])["main"].body - assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse) + assert isinstance(stmt.body.body, tvm.tir.stmt.IfThenElse) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py new file mode 100644 index 0000000000000..4ca952af00d40 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm import te + +var_list = [] + +def verify_structure(stmt, expected_struct): + node_dict = {} + struct = {} + def _extract_vars(op): + global var_list + if isinstance(op, tvm.tir.Var): + var_list.append(op.name) + + def _visit(op): + key = op + if isinstance(op, tvm.tir.IfThenElse): + global var_list + tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars) + val = [(op.then_case, op.else_case), ("tir.IfThenElse", tuple(var_list))] + var_list.clear() + elif isinstance(op, tvm.tir.For): + val = [(op.body,), ("tir.For", op.loop_var.name)] + elif isinstance(op, tvm.tir.AttrStmt): + val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))] + else: + return + node_dict[key] = val + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + for key, val in node_dict.items(): + struct[val[1]] = tuple(node_dict[child][1] if child in node_dict + else None for child in val[0]) + + assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \ + % (expected_struct, struct) + var_list.clear() + +def test_hoist_top_for(): + ib = tvm.tir.ir_builder.create() + l = te.var('l') + m = te.var('m') + n = te.var('n') + data = ib.pointer("float32", name="data") + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.tir.Evaluate(m)) + with ib.else_scope(): + ib.emit(tvm.tir.Evaluate(n)) + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), ('tir.For', 'j')), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_hoist_multi_var_if(): + ib = tvm.tir.ir_builder.create() + l = te.var('l') + m = te.var('m') + n = te.var('n') + data = ib.pointer("float32", 
name="data") + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i + j < 2)): + ib.emit(tvm.tir.Evaluate(m)) + with ib.else_scope(): + ib.emit(tvm.tir.Evaluate(n)) + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + expected_struct = {('tir.For', 'k'): (None,), + ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')), + ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), + ('tir.For', 'i'): (('tir.For', 'j'),)} + verify_structure(new_stmt, expected_struct) + +def test_hoist_no_match_for(): + ib = tvm.tir.ir_builder.create() + l = te.var('l') + m = te.var('m') + n = te.var('n') + data = ib.pointer("float32", name="data") + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.tir.Evaluate(m)) + with ib.else_scope(): + ib.emit(tvm.tir.Evaluate(n)) + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + expected_struct = {('tir.For', 'k'): (None,), + ('tir.IfThenElse', ('i', )): (('tir.For', 'k'), ('tir.For', 'k')), + ('tir.For', 'j'): (None,), + ('tir.For', 'i'): (('tir.For', 'j'),)} + verify_structure(new_stmt, expected_struct) + +def test_no_else(): + ib = tvm.tir.ir_builder.create() + l = te.var('l') + m = te.var('m') + n = te.var('n') + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.tir.Evaluate(m)) + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): 
(('tir.For', 'k'),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_attr_stmt(): + ib = tvm.tir.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = te.var('l') + m = te.var('m') + n = te.var('n') + + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(tvm.tir.any(i < 4, j >= 8)): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')), + ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), ('tir.For', 'i'): (('tir.For', 'j'),), + ('tir.AttrStmt', 'thread_extent', 64): (('tir.For', 'i'),), + ('tir.AttrStmt', 'thread_extent', 32): (('tir.AttrStmt', 'thread_extent', 64),)} + verify_structure(new_stmt, expected_struct) + +def test_nested_for(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.tir.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = 
tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + expected_struct = {('tir.For', 'l'): (None,), ('tir.For', 'k'): (('tir.For', 'l'),), + ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')), + ('tir.For', 'j'): (None,), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_if_block(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + n = te.var("n") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.tir.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + with ib.if_scope(j <5): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1 + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 15, "k") as k: + with ib.if_scope(n >= 3): + data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.IfThenElse', ('j',)): (None, None), + ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'j'),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),), + ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)} + verify_structure(new_stmt, expected_struct) + + +def test_multi_if(): + ib = tvm.tir.ir_builder.create() + data = ib.pointer("float32", name="data") + + with ib.for_range(0, 10, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 10, "k") as k: + with 
ib.if_scope(i >= 3): + with ib.if_scope(j >= 3): + data[i * 100 + j * 10 + k] = data[i * 100 + j * 10 + k] + 0.5 + + stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + new_stmt = tvm.tir.transform.HoistIfThenElse()(mod)["main"].body + expected_struct = {('tir.For', 'k'): (None,), + ('tir.IfThenElse', ('j',)): (('tir.For', 'k'), None), + ('tir.For', 'j'): (('tir.IfThenElse', ('j',)),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + + +if __name__ == "__main__": + test_hoist_top_for() + test_hoist_multi_var_if() + test_hoist_no_match_for() + test_no_else() + test_attr_stmt() + test_nested_for() + test_if_block() + test_multi_if() +