diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 592a79fb86d1c..3425910fe0138 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -148,22 +148,6 @@ Stmt Substitute(Stmt stmt, const Map& value_map); */ PrimExpr Substitute(PrimExpr expr, const Map& value_map); -/*! - * \brief inline all calls of f in stmt. - * - * \param stmt The statement to apply inline optimization. - * \param f The function reference to be inlined - * \param args The arguments variable of the function. - * \param body The definition body of the function. - * \return The result stmt - * - * \note All the passes in this file uses SSA form and outputs SSA form. - */ -Stmt Inline(Stmt stmt, - FunctionRef f, - Array args, - PrimExpr body); - /*! * \brief Verify if there is any argument bound to compact buffer. * diff --git a/src/tir/pass/inline.cc b/src/te/schedule/operation_inline.cc similarity index 84% rename from src/tir/pass/inline.cc rename to src/te/schedule/operation_inline.cc index 1b322964b8733..dfa9f604e9ac4 100644 --- a/src/tir/pass/inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -18,29 +18,30 @@ */ /*! - * \file inline.cc + * \file operation_inline.cc */ #include #include #include #include +#include "operation_inline.h" namespace tvm { -namespace tir { +namespace te { // inliner to inline a function // the result may not be SSA, // ConvertSSA need to be applied after this pass -class IRInline final : public StmtExprMutator { +class OperationInliner final : public StmtExprMutator { public: - IRInline(FunctionRef f, Array args, PrimExpr body) - : f_(f), args_(args), body_(body) {} + OperationInliner(Operation op, Array args, PrimExpr body) + : operation_(op), args_(args), body_(body) {} PrimExpr VisitExpr_(const CallNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - if (op->func == f_) { + if (op->func.same_as(operation_)) { CHECK_EQ(op->value_index, 0); expr = body_; CHECK_EQ(args_.size(), op->args.size()); @@ -68,20 +69,20 @@ class IRInline final : public StmtExprMutator { } private: - FunctionRef f_; + Operation operation_; Array args_; PrimExpr body_; }; Stmt Inline(Stmt stmt, - FunctionRef f, + Operation f, Array args, PrimExpr body) { CHECK_EQ(f->num_outputs(), 1) << "can only inline output single value operation"; - Stmt ret = IRInline(f, args, body)(std::move(stmt)); + Stmt ret = OperationInliner(f, args, body)(std::move(stmt)); if (ret.same_as(stmt)) return ret; return ConvertSSA(ret); } -} // namespace tir +} // namespace te } // namespace tvm diff --git a/src/te/schedule/operation_inline.h b/src/te/schedule/operation_inline.h new file mode 100644 index 0000000000000..d7d55cc660272 --- /dev/null +++ b/src/te/schedule/operation_inline.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file operation_inline.h + */ +#ifndef TVM_TE_SCHEDULE_OPERATION_INLINE_H_ +#define TVM_TE_SCHEDULE_OPERATION_INLINE_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace te { + +/*! + * \brief inline all calls of f in stmt. + * + * \param stmt The statement to apply inline optimization. + * \param op The op to be inlined. + * \param args The arguments variable of the function. + * \param body The definition body of the function. + * \return The result stmt + * + * \note All the passes in this file uses SSA form and outputs SSA form. + */ +Stmt Inline(Stmt stmt, + Operation op, + Array args, + PrimExpr body); + +} // namespace te +} // namespace tvm +#endif // TVM_TE_SCHEDULE_OPERATION_INLINE_H_ diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 99f2fb9efd877..48a27d17b7008 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -26,6 +26,8 @@ #include #include #include "message_passing.h" +#include "operation_inline.h" + #include "../../tir/pass/ir_util.h" #include "../../arith/compute_expr.h" @@ -583,7 +585,7 @@ void InjectInline(ScheduleNode* sch) { << "The Reduce inputs of ComputeOp should " << "have the same attribute except value_index"; } - PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][0]), + PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][0]), stage->op, args, body).as()->value; if (!new_value.same_as(new_body[j][0])) { changed[j] = true; @@ -599,7 +601,7 @@ void InjectInline(ScheduleNode* sch) { } } else { for (size_t k = 0; k < new_body[j].size(); ++k) { - PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][k]), + PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][k]), stage->op, args, body).as()->value; if (!new_value.same_as(new_body[j][k])) { new_body[j].Set(k, new_value); @@ -611,7 +613,7 @@ void InjectInline(ScheduleNode* sch) { if (!new_hybrid_body[j].defined()) { new_hybrid_body[j] = hybrid->body; } - Stmt new_stmt = tir::Inline(new_hybrid_body[j], stage->op, args, body); + Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body); if (!new_stmt.same_as(new_hybrid_body[j])) { new_hybrid_body[j] = new_stmt; hybrid_changed[j] = true; diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 60b5bd9a7f9cb..2d970e0203e5c 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -97,7 +97,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") REGISTER_PASS(ConvertSSA); REGISTER_PASS(VerifySSA); -REGISTER_PASS(Inline); REGISTER_PASS(IRTransform); REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); diff --git a/tests/python/unittest/test_tir_pass_inline.py b/tests/python/unittest/test_tir_pass_inline.py deleted file mode 100644 index ad0591d3a7c16..0000000000000 --- a/tests/python/unittest/test_tir_pass_inline.py +++ /dev/null @@ -1,54 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te - -def test_inline(): - m = te.size_var('m') - A = te.placeholder((m,), name='A') - T = te.compute((m,), lambda i,: A[i] + 10, name='T') - stmt = tvm.tir.Evaluate(T[10] + 11 * T[100]) - stmt = tvm.tir.ir_pass.Inline( - stmt, T.op, [x.var for x in T.op.axis], T.op.body[0]) - print(stmt) - assert(tvm.tir.ir_pass.VerifySSA(stmt)) - - try: - # pass in int array(wrong argument type) - # must raise an error - stmt = tvm.tir.ir_pass.Inline( - T.op, [1,2,3], T.op.body, stmt) - assert False - except tvm.error.TVMError: - pass - -def test_inline2(): - m = te.size_var('m') - A = te.placeholder((m,), name='A') - T = te.compute((m,), lambda i,: A[i] + 10, name='T') - stmt = tvm.tir.Evaluate(te.exp(T[10]) + 11 * T[100]) - stmt = tvm.tir.ir_pass.Inline( - stmt, T.op, [x.var for x in T.op.axis], T.op.body[0]) - def check(op): - if isinstance(op, tvm.tir.Call): - assert op.func != T.op - tvm.tir.ir_pass.PostOrderVisit(stmt, check) - - -if __name__ == "__main__": - test_inline2() - test_inline()