Skip to content

Commit

Permalink
Refactor to use IsOp utility
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Dec 28, 2019
1 parent a55d119 commit 266b823
Show file tree
Hide file tree
Showing 14 changed files with 133 additions and 102 deletions.
5 changes: 2 additions & 3 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,12 +594,11 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
return map_.get<ValueType>(expr, def_value);
}


/*!
* \brief Check that an expression is a "primtive operator".
* \brief Check that an expression is a "primitive operator".
*
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* matches the form of primitive operators registered directly
* by the Relay codebase.
*
* That is the arguments are all type variables, and there is a single
Expand Down
15 changes: 9 additions & 6 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine.
*/
#include "compile_engine.h"

#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/operation.h>
Expand All @@ -29,6 +31,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
#include <utility>
Expand All @@ -38,7 +41,6 @@
#include <vector>
#include <unordered_map>
#include "../ir/type_functor.h"
#include "compile_engine.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -102,7 +104,7 @@ class ScheduleGetter :
public ExprFunctor<Array<Tensor>(const Expr&)> {
public:
explicit ScheduleGetter(Target target)
: target_(target) {}
: target_(target), device_copy_op_(Op::Get("device_copy")) {}

std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
static auto fschedule =
Expand Down Expand Up @@ -250,11 +252,9 @@ class ScheduleGetter :
CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
// Check if the op is a device copy op.
bool is_copy_op = op.same_as(Op::Get("device_copy"));
Array<Tensor> outputs;
// Skip fcompute for device copy operators as it is not registered.
if (is_copy_op) {
if (op == device_copy_op_) {
const auto* copy_input = inputs[0].operator->();
outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
Operation(), 0));
Expand Down Expand Up @@ -282,7 +282,7 @@ class ScheduleGetter :
}
// Set the name to `__copy`. It will be detected in graph runtime to perform
// data copy across devices.
if (is_copy_op) {
if (op == device_copy_op_) {
readable_name_stream_.str(std::string());
readable_name_stream_ << "__copy";
} else {
Expand Down Expand Up @@ -332,6 +332,9 @@ class ScheduleGetter :
std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
Array<Operation> scalars_;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
};

// Creates shape function from functor.
Expand Down
23 changes: 14 additions & 9 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,12 @@ class Interpreter :
public ExprFunctor<Value(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const Value& v)> {
public:
Interpreter(Module mod,
DLContext context,
Target target)
: mod_(mod), context_(context), target_(target) {
Interpreter(Module mod, DLContext context, Target target)
: mod_(mod),
context_(context),
target_(target),
debug_op_(Op::Get("debug")),
shape_of_op_(Op::Get("shape_of")) {
engine_ = CompileEngine::Global();
}

Expand All @@ -263,7 +265,7 @@ class Interpreter :
stack_.current_frame().locals.Set(id, v);
}

inline Value Lookup(const Var& local) {
Value Lookup(const Var& local) {
return stack_.Lookup(local);
}

Expand Down Expand Up @@ -307,7 +309,7 @@ class Interpreter :
return TupleValueNode::make(values);
}

inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);

Expand Down Expand Up @@ -454,9 +456,9 @@ class Interpreter :

Value InvokePrimitiveOp(const Function& func,
const Array<Value>& args) {
auto call_node = func->body.as<CallNode>();
const auto* call_node = func->body.as<CallNode>();

if (call_node && call_node->op == Op::Get("debug")) {
if (call_node && call_node->op == debug_op_) {
auto dattrs = call_node->attrs.as<DebugAttrs>();
auto interp_state = this->get_state(call_node->args[0]);

Expand Down Expand Up @@ -540,7 +542,7 @@ class Interpreter :
Array<Shape> out_shapes;
auto ret_type = func->body->checked_type();
bool is_dyn = IsDynamic(func->checked_type());
if (call_node->op == Op::Get("shape_of")) {
if (call_node->op == shape_of_op_) {
// The output shape of shape_of must be static since Relay doesn't support
// dynamic rank tensors.
is_dyn = false;
Expand Down Expand Up @@ -782,6 +784,9 @@ class Interpreter :
Stack stack_;
// Backend compile engine.
CompileEngine engine_;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
const Op& shape_of_op_;
};


Expand Down
12 changes: 8 additions & 4 deletions src/relay/pass/canonicalize_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ namespace relay {
// \endcode
class CastCanonicalizer : public ExprMutator {
public:
CastCanonicalizer() : cast_op_(Op::Get("cast")) {}

Expr VisitExpr_(const CallNode* call) {
static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");

Expand Down Expand Up @@ -91,15 +93,17 @@ class CastCanonicalizer : public ExprMutator {

private:
std::unordered_map<const Node*, size_t> ref_counter_;
// cast op is frequently checked for equivalence. Therefore, we cache it to
// reduce lookup overhead.
const Op& cast_op_;


Expr GetNewCallArg(const Expr& e) {
// if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor

static auto& cast = Op::Get("cast");
Expr new_expr = this->VisitExpr(e);

if (const CallNode* call = e.as<CallNode>()) {
if (call->op.same_as(cast)) {
if (call->op == cast_op_) {
auto attrs = call->attrs.as<CastAttrs>();
const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
CHECK(from_type);
Expand All @@ -108,7 +112,7 @@ class CastCanonicalizer : public ExprMutator {
if (++ref_counter_[call] > 1) {
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
CHECK(new_call->op.same_as(cast));
CHECK(new_call->op == cast_op_);
return CallNode::make(new_call->op, new_call->args, new_call->attrs,
new_call->type_args);
}
Expand Down
10 changes: 8 additions & 2 deletions src/relay/pass/canonicalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
Expand All @@ -33,10 +34,11 @@ namespace relay {

class BiasAddSimplifier : public ExprMutator {
public:
BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}

Expr VisitExpr_(const CallNode* n) {
static const Op& bias_add = Op::Get("nn.bias_add");
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(bias_add)) {
if (n->op == bias_add_op_) {
Call call = Downcast<Call>(new_n);
CHECK_EQ(call->args.size(), 2);
const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
Expand All @@ -54,6 +56,10 @@ class BiasAddSimplifier : public ExprMutator {
}
return new_n;
}

private:
// Cache the bias_add for equivalence checking.
const Op& bias_add_op_;
};

Expr CanonicalizeOps(const Expr& e) {
Expand Down
24 changes: 12 additions & 12 deletions src/relay/pass/combine_parallel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,38 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <algorithm>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"
#include "expr_subst.h"
#include "pattern_util.h"
#include "combine_parallel_op.h"


namespace tvm {
namespace relay {

BranchGroupFinder::BranchGroupFinder(const std::string& op_name,
BranchGroupFinder::BranchGroupFinder(const Op& op,
FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops)
: op_name_(op_name),
: cached_op_(op),
fis_supported_op_(fis_supported_op),
fare_compatible_ops_(fare_compatible_ops) {
}

std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
const Op& op = Op::Get(op_name_);

this->VisitExpr(expr);

std::vector<Group> groups;
for (const auto& root : op_roots_) {
const auto& children = children_map_.at(root);
size_t ngroups = groups.size();
for (const CallNode* child : children) {
if (!child->op.same_as(op)) continue;
if (child->op != cached_op_) continue;

auto&& branch = CreateBranch(child);
// add the branch to a group, or create a new group
Expand Down Expand Up @@ -97,9 +98,8 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
}

void BranchGroupFinder::VisitExpr_(const CallNode* n) {
const Op& op = Op::Get(op_name_);
ExprVisitor::VisitExpr_(n);
if (n->op.same_as(op) && fis_supported_op_(n)) {
if (n->op == cached_op_ && fis_supported_op_(n)) {
op_roots_.insert(n->args[0]);
children_map_[n->args[0]].push_back(n);
} else {
Expand All @@ -110,12 +110,12 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) {
}

ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
: op_name_(op_name),
: cached_op_(Op::Get(op_name)),
min_num_branches_(min_num_branches) {
}

Expr ParallelOpCombiner::Combine(const Expr& expr) {
auto groups = BranchGroupFinder(op_name_,
auto groups = BranchGroupFinder(cached_op_,
[&](const CallNode* n) {
return IsSupportedOp(n);
},
Expand Down
12 changes: 6 additions & 6 deletions src/relay/pass/combine_parallel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ class BranchGroupFinder : private ExprVisitor {
public:
/*
* \brief Constructor
* \param op_name name of op to start each group
* \param op The op that indicates the start of each group
* \param fis_supported_op function that returns true if op
* is supported for combining
* \param fare_compatible_ops function that returns true if
* two ops are compatible for combining
*/
BranchGroupFinder(const std::string& op_name,
BranchGroupFinder(const Op& op,
FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops);

Expand All @@ -87,8 +87,8 @@ class BranchGroupFinder : private ExprVisitor {
std::vector<Group> Find(const Expr& expr);

private:
/* \brief name of op to find parallel branches for */
std::string op_name_;
/* \brief Cache the op for finding parallel branches */
const Op& cached_op_;

/* \brief function to return true if op is eligible to be combined,
* false otherwise
Expand Down Expand Up @@ -205,8 +205,8 @@ class ParallelOpCombiner {
ExprSubstMap* subst_map) = 0;

private:
/* \brief name of op to be combined */
std::string op_name_;
/* \brief Cache the op to be combined */
const Op& cached_op_;

/* \brief minimum number of parallel branches to combine */
uint64_t min_num_branches_;
Expand Down
Loading

0 comments on commit 266b823

Please sign in to comment.