diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc deleted file mode 100644 index d1e24b94a32f..000000000000 --- a/src/tir/pass/hoist_if_then_else.cc +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file hoist_if_then_else.cc - */ -#include -#include -#include -#include - -#include -#include -#include - -#include "../../arith/interval_set.h" -#include "../../runtime/thread_storage_scope.h" - -namespace tvm { -namespace tir { - -using HoistMap = std::unordered_map>; -using VarMap = std::unordered_map>; - -/* - * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant. - * For example, given the following block: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * for (k = 0; k < 5; k++) - * if (likely(i*2 < 4)) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt. - * Then we hoist IfThenElse stmt by one For stmt each step: - * - * Step 1: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * if (likely(i*2 < 4)) - * for (k = 0; k < 5; k++) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * Step 2: - * for (i = 0; i < 3; i++) - * if (likely(i*2 < 4)) - * for (j = 0; j < 4; j++) - * for (k = 0; k < 5; k++) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * In this pass, we only continue detecting possible hoisting chance when visiting For, - * IfThenElse or AttrStmt Node. For example, for the following block: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * A[i + j] = A[i + j] - 1 - * for (k = 0; k < 5; k++) - * if (likely(i*2 < 4)) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * Only the For with k variable will be considered and the resulting stmt would be: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * A[i + j] = A[i + j] - 1 - * if (likely(i*2 < 4)) - * for (k = 0; k < 5; k++) - * A[3*i+2j+k] = B[7*i+3j+k] - * - * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following - * block won't be optimized: - * for (i = 0; i < 3; i++) - * for (j = 0; j < 4; j++) - * for (k = 0; k < 5; k++) - * if (likely(i*2 < 4)) - * A[3*i+2j+k] = B[7*i+3j+k] - * if (likely(j > 2)) - * A[i+j+k] = B[i+j+k] - * - */ -class IfThenElseHoist { - public: - Stmt VisitAndMutate(const Stmt& stmt) { - SelectCandidates(stmt); - LocateTopFor(); - return PostOrderMutate(stmt); - } - - private: - void SelectCandidates(const Stmt& stmt); - void LocateTopFor(); - Stmt PostOrderMutate(const Stmt& stmt); - size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt); - Stmt HoistIf(const Stmt& if_stmt); - - // Map of all For nodes to all child IfThenElse nodes. - HoistMap for2if_map_; - // Map of all IfThenElse nodes to all For nodes which are loop invariant. - HoistMap if2for_map_; - // Map of highest loop invariant For to child IfThenElse. - HoistMap top_for_var_map_; - // Map of original For to list of update For nodes. - HoistMap for_tracking_map_; - // Map of all IfThenElse nodes to condition variable nodes. - VarMap cond_var_map_; - // List of For nodes added in post order DFS visiting. - std::vector ordered_for_list_; -}; - -// Check whether a given IfThenElse stmt is the first one appearing -// in a For stmt. -bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { - std::vector if_node_list; - const ForNode* for_node = for_stmt.as(); - CHECK(for_node); - CHECK(if_stmt.as()); - - PostOrderVisit(for_node->body, [&](const ObjectRef& node) { - if (node.as()) { - if_node_list.push_back(node.get()); - } - }); - return if_node_list.empty() ? false : if_stmt.get() == if_node_list.back(); -} - -// Update upper level For node when current For node is modified. -// With this function we only need to visit and mutate top level For node -// in the main VisitAndMutate function. -Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { - const Object* top_for_node; - const ForNode* parent_for_node = parent_for_stmt.as(); - CHECK(parent_for_node); - CHECK(new_if_stmt.as()); - - PostOrderVisit(parent_for_node->body, [&](const ObjectRef& node) { - if (node.as()) { - top_for_node = node.get(); - } - }); - - PackedFunc replace_target_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { - const ObjectRef& current_for = args[0]; - if (current_for.get() == top_for_node) { - *ret = new_if_stmt; - } - }); - - return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array{"tir.For"}); -} - -// Remove IfThenElse node from a For node. -// A pair of For nodes will be generated. -std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { - Stmt then_for; - Stmt else_for; - CHECK(if_stmt.as()); - - PackedFunc replace_then_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { - const ObjectRef& node = args[0]; - if (node == if_stmt) { - *ret = node.as()->then_case; - } - }); - - PackedFunc replace_else_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { - const ObjectRef& node = args[0]; - if (node == if_stmt) { - *ret = node.as()->else_case; - } - }); - - then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array{"tir.IfThenElse"}); - if (if_stmt.as()->else_case.defined()) { - else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array{"tir.IfThenElse"}); - } - - return std::make_pair(then_for, else_for); -} - -// Locate all For nodes and capture child IfThenElse nodes. -void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { - PostOrderVisit(stmt, [&](const ObjectRef& node) { - const ForNode* for_node = node.as(); - if (!for_node) return; - - std::queue tracker; - tracker.push(for_node->body); - Stmt for_stmt = Downcast(node); - for2if_map_.insert({for_stmt.get(), std::vector()}); - while (!tracker.empty()) { - Stmt head = tracker.front(); - tracker.pop(); - if (head->IsInstance()) { - for (const auto& if_stmt : for2if_map_.at(head.get())) { - for2if_map_[for_stmt.get()].push_back(if_stmt); - } - } else if (head->IsInstance()) { - const AttrStmtNode* attr_node = head.as(); - tracker.push(attr_node->body); - } else if (head->IsInstance()) { - for2if_map_[for_stmt.get()].push_back(head); - const IfThenElseNode* if_node = head.as(); - tracker.push(if_node->then_case); - if (if_node->else_case.defined()) { - tracker.push(if_node->else_case); - } - - // Record condition variables. - if (!cond_var_map_.count(head.get())) { - std::unordered_set new_var_set; - cond_var_map_.insert({head.get(), new_var_set}); - PostOrderVisit(if_node->condition, [&](const ObjectRef& cond_node) { - if (cond_node.as()) { - cond_var_map_[head.get()].insert(cond_node.get()); - } - }); - } - } else { - continue; - } - } - ordered_for_list_.emplace_back(Downcast(node)); - }); -} - -// For each IfThenElse node, find the highest For node which -// meets loop invariant condition. -void IfThenElseHoist::LocateTopFor() { - std::unordered_map if_position_map; - std::unordered_set top_for_var_set; - - // Create IfThenElse -> For map. - for (const Stmt& for_stmt : ordered_for_list_) { - std::vector if_list = for2if_map_[for_stmt.get()]; - const ForNode* for_node = for_stmt.as(); - CHECK(for_node); - top_for_var_map_.insert({for_node->loop_var.get(), if_list}); - for (const Stmt& if_stmt : if_list) { - const Object* if_node = if_stmt.get(); - if2for_map_[if_node].push_back(for_stmt); - } - } - - // Locate the highest For node which is loop invariant. - for (const auto& item : if2for_map_) { - Stmt top_for; - const Object* if_stmt = item.first; - std::vector for_list = item.second; - for (size_t i = 0; i < for_list.size(); ++i) { - const Stmt& for_stmt = for_list.at(i); - const ForNode* for_node = for_stmt.as(); - CHECK(for_node); - std::vector new_for_list{for_stmt}; - for_tracking_map_.insert({for_stmt.get(), new_for_list}); - if (cond_var_map_[if_stmt].count(for_node->loop_var.get())) { - std::vector updated_for_list(for_list.begin(), for_list.begin() + i); - if2for_map_[if_stmt] = updated_for_list; - break; - } else { - top_for = for_stmt; - } - } - if (top_for.as()) { - if_position_map.insert({if_stmt, top_for}); - } - } - - for (const auto& item : if_position_map) { - top_for_var_set.insert(item.second.as()->loop_var.get()); - } - - std::vector removed_for_var_list; - for (const auto& item : top_for_var_map_) { - const Object* top_for_var = item.first; - std::vector if_list = item.second; - if (!top_for_var_set.count(top_for_var)) { - removed_for_var_list.push_back(top_for_var); - } else { - std::vector actual_if_list; - for (const Stmt& if_stmt : if_list) { - if (if_position_map.count(if_stmt.get())) { - actual_if_list.push_back(if_stmt); - } - } - top_for_var_map_[top_for_var] = actual_if_list; - } - } - for (const Object* top_for_var : removed_for_var_list) { - top_for_var_map_.erase(top_for_var); - } -} - -// When we try to mutate a For node, some child For nodes can have already -// been mutated. This function is to get the updated For node and further -// hoisting can be done based on this new node. -// We keep all For nodes tracing in for_tracking_map_. When we get a -// hoisted IfThenElse, we match it with tracing For nodes to pick -// the updated one. -size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt) { - std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; - size_t updated_for_idx = 0; - for (size_t i = 0; i < tracked_for_list.size(); ++i) { - const Stmt& current_for = tracked_for_list.at(tracked_for_list.size() - 1 - i); - if (is_first_if(current_for, if_stmt)) { - updated_for_idx = tracked_for_list.size() - 1 - i; - break; - } - } - return updated_for_idx; -} - -// Hoist an IfThenElse node as high as possible. -// This function iterates on all candidate For nodes. For each For node, -// it first removes IfThenElse nodes. Then it generates a new IfThenElse -// node using mutated For nodes. -Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { - Stmt new_if = if_stmt; - - for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { - const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i); - size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if); - const Stmt& updated_for_node = for_tracking_map_[for_stmt.get()].at(updated_for_idx); - auto generated_for_pair = RemoveIf(updated_for_node, new_if); - const Stmt& then_for = generated_for_pair.first; - const Stmt& else_for = generated_for_pair.second; - - for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for; - - if (else_for.get()) { - for_tracking_map_[for_stmt.get()].push_back(else_for); - } - - const IfThenElseNode* new_if_node = new_if.as(); - CHECK(new_if_node); - new_if = IfThenElse(new_if_node->condition, then_for, else_for); - if (i < if2for_map_[if_stmt.get()].size() - 1) { - const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); - const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(updated_for_idx); - Stmt update_for_stmt = update_for(actual_next_for, new_if); - - for_tracking_map_[original_next_for.get()].at(updated_for_idx) = update_for_stmt; - } - } - return new_if; -} - -// Mutate For nodes in post order DFS manner. -Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { - PackedFunc replace_top_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { - const ObjectRef& current_for = args[0]; - const ForNode* for_node = current_for.as(); - if (!for_node) return; - - if (top_for_var_map_.count(for_node->loop_var.get())) { - std::vector new_if_list; - for (const Stmt& if_stmt : top_for_var_map_[for_node->loop_var.get()]) { - new_if_list.emplace_back(HoistIf(if_stmt)); - } - - const IfThenElseNode* next_if_node; - const IfThenElseNode* current_if_node = new_if_list.back().as(); - Stmt new_for = Stmt(); - for (size_t i = new_if_list.size() - 1; i > 0; --i) { - CHECK(current_if_node); - const Stmt current_if_stmt = IfThenElse( - current_if_node->condition, current_if_node->then_case, current_if_node->else_case); - next_if_node = new_if_list[i - 1].as(); - CHECK(next_if_node); - new_for = IfThenElse(next_if_node->condition, current_if_stmt, next_if_node->else_case); - current_if_node = new_for.as(); - } - - if (!new_for.get()) { - const IfThenElseNode* first_if_node = new_if_list[0].as(); - CHECK(first_if_node); - new_for = IfThenElse(first_if_node->condition, first_if_node->then_case, - first_if_node->else_case); - } - *ret = new_for; - } - }); - return IRTransform(stmt, nullptr, replace_top_for, Array{"tir.For"}); -} - -Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); } - -TVM_REGISTER_GLOBAL("testing.HoistIfThenElse").set_body_typed(HoistIfThenElse); - -} // namespace tir -} // namespace tvm diff --git a/tests/python/unittest/test_tir_pass_hoist_if.py b/tests/python/unittest/test_tir_pass_hoist_if.py deleted file mode 100644 index 80e93a706ee7..000000000000 --- a/tests/python/unittest/test_tir_pass_hoist_if.py +++ /dev/null @@ -1,186 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te - - -var_list = [] - -def verify_structure(stmt, expected_struct): - node_dict = {} - struct = {} - def _extract_vars(op): - global var_list - if isinstance(op, tvm.tir.Var): - var_list.append(op.name) - - def _visit(op): - key = op - if isinstance(op, tvm.tir.IfThenElse): - global var_list - tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars) - val = [(op.then_case, op.else_case), ("tir.IfThenElse", tuple(var_list))] - var_list.clear() - elif isinstance(op, tvm.tir.For): - val = [(op.body,), ("tir.For", op.loop_var.name)] - elif isinstance(op, tvm.tir.AttrStmt): - val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))] - else: - return - node_dict[key] = val - - tvm.tir.stmt_functor.post_order_visit(stmt, _visit) - for key, val in node_dict.items(): - struct[val[1]] = tuple(node_dict[child][1] if child in node_dict - else None for child in val[0]) - - assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \ - % (expected_struct, struct) - var_list.clear() - -def test_basic(): - ib = tvm.tir.ir_builder.create() - l = te.var('l') - m = te.var('m') - n = te.var('n') - - with ib.for_range(0, l, "i") as i: - with ib.for_range(0, m, "j") as j: - with ib.for_range(0, n, "k") as k: - with ib.if_scope(ib.likely(i < 2)): - ib.emit(tvm.tir.Evaluate(m)) - with ib.else_scope(): - ib.emit(tvm.tir.Evaluate(n)) - - stmt = ib.get() - new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),), - ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), ('tir.For', 'j')), - ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} - verify_structure(new_stmt, expected_struct) - -def test_no_else(): - ib = tvm.tir.ir_builder.create() - l = te.var('l') - m = te.var('m') - n = te.var('n') - - with ib.for_range(0, l, "i") as i: - with ib.for_range(0, m, "j") as j: - with ib.for_range(0, n, "k") as k: - with ib.if_scope(ib.likely(i < 2)): - ib.emit(tvm.tir.Evaluate(m)) - - stmt = ib.get() - new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),), - ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), - ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} - verify_structure(new_stmt, expected_struct) - -def test_attr_stmt(): - ib = tvm.tir.ir_builder.create() - dshape = (32, 64) - data = ib.pointer("float32", name="data") - l = te.var('l') - m = te.var('m') - n = te.var('n') - - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", dshape[0]) - ib.scope_attr(bx, "thread_extent", dshape[1]) - with ib.for_range(0, l, "i") as i: - with ib.for_range(0, m, "j") as j: - with ib.for_range(0, n, "k") as k: - with ib.if_scope(tvm.tir.any(i < 4, j >= 8)): - data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5 - with ib.else_scope(): - data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0 - - stmt = ib.get() - new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')), - ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), ('tir.For', 'i'): (('tir.For', 'j'),), - ('tir.AttrStmt', 'thread_extent', 64): (('tir.For', 'i'),), - ('tir.AttrStmt', 'thread_extent', 32): (('tir.AttrStmt', 'thread_extent', 64),)} - verify_structure(new_stmt, expected_struct) - -def test_nested_for(): - ib = tvm.tir.ir_builder.create() - data = ib.pointer("float32", name="data") - - - with ib.for_range(0, 5, "i") as i: - with ib.for_range(0, 10, "j") as j: - with ib.if_scope(i >= 3): - data[i * 3 + j] = data[i * 3 + j] + 0.5 - with ib.for_range(0, 15, "k") as k: - with ib.for_range(0, 20, "l") as l: - with ib.if_scope(tvm.tir.any(i < 4, j >= 8)): - data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 - with ib.else_scope(): - data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 - - stmt = ib.get() - new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.For', 'l'): (('tir.IfThenElse', ('i', 'j')),), - ('tir.For', 'k'): (('tir.For', 'l'),), ('tir.For', 'j'): (None,), ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), - ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} - verify_structure(new_stmt, expected_struct) - -def test_if_block(): - ib = tvm.tir.ir_builder.create() - data = ib.pointer("float32", name="data") - n = te.var("n") - - - with ib.for_range(0, 5, "i") as i: - with ib.for_range(0, 10, "j") as j: - with ib.if_scope(i >= 3): - data[i * 3 + j] = data[i * 3 + j] + 0.5 - with ib.for_range(0, 15, "k") as k: - with ib.for_range(0, 20, "l") as l: - with ib.if_scope(tvm.tir.any(i < 4, j >= 8)): - data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 - with ib.else_scope(): - data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 - with ib.if_scope(j <5): - data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1 - - - with ib.for_range(0, 5, "i") as i: - with ib.for_range(0, 10, "j") as j: - with ib.for_range(0, 15, "k") as k: - with ib.if_scope(n >= 3): - data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6 - - stmt = ib.get() - new_stmt = tvm.testing.HoistIfThenElse(stmt) - expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.IfThenElse', ('j',)): (None, None), - ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'j'),), - ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),), - ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)} - verify_structure(new_stmt, expected_struct) - - -if __name__ == "__main__": - test_basic() - test_no_else() - test_attr_stmt() - test_nested_for() - test_if_block()