Skip to content

Commit

Permalink
[REFACTOR][TE] Inline -> te/schedule/operation_inline.h (apache#5386)
Browse files Browse the repository at this point in the history
Rationale: inline is a transformation used in te to
rewrite its internal expressions. It is not a formal IRModule->IRModule transform pass.

Also removed the python test as the test is covered by stage.compute_inline.
  • Loading branch information
tqchen authored and Trevor Morris committed Jun 18, 2020
1 parent ad399ac commit 8908789
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 84 deletions.
16 changes: 0 additions & 16 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,6 @@ Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
*/
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);

/*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param f The function reference to be inlined
* \param args The arguments variable of the function.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt Inline(Stmt stmt,
FunctionRef f,
Array<Var> args,
PrimExpr body);

/*!
* \brief Verify if there is any argument bound to compact buffer.
*
Expand Down
22 changes: 12 additions & 10 deletions src/tir/pass/inline.cc → src/te/schedule/operation_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,31 @@
*/

/*!
* \file inline.cc
* \file operation_inline.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <utility>
#include "operation_inline.h"

namespace tvm {
namespace tir {
namespace te {

// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class IRInline final : public StmtExprMutator {
class OperationInliner final : public StmtExprMutator {
public:
IRInline(FunctionRef f, Array<Var> args, PrimExpr body)
: f_(f), args_(args), body_(body) {}
OperationInliner(Operation op, Array<Var> args, PrimExpr body)
: operation_(op), args_(args), body_(body) {}

PrimExpr VisitExpr_(const CallNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();

if (op->func == f_) {
if (op->func.same_as(operation_)) {
CHECK_EQ(op->value_index, 0);
expr = body_;
CHECK_EQ(args_.size(), op->args.size());
Expand Down Expand Up @@ -68,20 +70,20 @@ class IRInline final : public StmtExprMutator {
}

private:
FunctionRef f_;
Operation operation_;
Array<Var> args_;
PrimExpr body_;
};

Stmt Inline(Stmt stmt,
FunctionRef f,
Operation f,
Array<Var> args,
PrimExpr body) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
Stmt ret = IRInline(f, args, body)(std::move(stmt));
Stmt ret = OperationInliner(f, args, body)(std::move(stmt));
if (ret.same_as(stmt)) return ret;
return ConvertSSA(ret);
}
} // namespace tir
} // namespace te
} // namespace tvm
51 changes: 51 additions & 0 deletions src/te/schedule/operation_inline.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file operation_inline.h
*/
#ifndef TVM_TE_SCHEDULE_OPERATION_INLINE_H_
#define TVM_TE_SCHEDULE_OPERATION_INLINE_H_

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/te/operation.h>
#include <tvm/te/tensor.h>

namespace tvm {
namespace te {

/*!
* \brief inline all calls of f in stmt.
*
* \param stmt The statement to apply inline optimization.
* \param op The op to be inlined.
* \param args The arguments variable of the function.
* \param body The definition body of the function.
* \return The result stmt
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt Inline(Stmt stmt,
Operation op,
Array<Var> args,
PrimExpr body);

} // namespace te
} // namespace tvm
#endif // TVM_TE_SCHEDULE_OPERATION_INLINE_H_
8 changes: 5 additions & 3 deletions src/te/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <tvm/tir/ir_pass.h>
#include <unordered_set>
#include "message_passing.h"
#include "operation_inline.h"

#include "../../tir/pass/ir_util.h"
#include "../../arith/compute_expr.h"

Expand Down Expand Up @@ -583,7 +585,7 @@ void InjectInline(ScheduleNode* sch) {
<< "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][0]),
PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][0]),
stage->op, args, body).as<tir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][0])) {
changed[j] = true;
Expand All @@ -599,7 +601,7 @@ void InjectInline(ScheduleNode* sch) {
}
} else {
for (size_t k = 0; k < new_body[j].size(); ++k) {
PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][k]),
PrimExpr new_value = Inline(tir::EvaluateNode::make(new_body[j][k]),
stage->op, args, body).as<tir::EvaluateNode>()->value;
if (!new_value.same_as(new_body[j][k])) {
new_body[j].Set(k, new_value);
Expand All @@ -611,7 +613,7 @@ void InjectInline(ScheduleNode* sch) {
if (!new_hybrid_body[j].defined()) {
new_hybrid_body[j] = hybrid->body;
}
Stmt new_stmt = tir::Inline(new_hybrid_body[j], stage->op, args, body);
Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body);
if (!new_stmt.same_as(new_hybrid_body[j])) {
new_hybrid_body[j] = new_stmt;
hybrid_changed[j] = true;
Expand Down
1 change: 0 additions & 1 deletion src/tir/pass/ffi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")

REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
Expand Down
54 changes: 0 additions & 54 deletions tests/python/unittest/test_tir_pass_inline.py

This file was deleted.

0 comments on commit 8908789

Please sign in to comment.