[relay][refactor] Cache Op::Get in passes to reduce lookup overhead #4594

Merged · 2 commits · Dec 31, 2019
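All of the changes below apply one pattern: instead of calling Op::Get("...") inside every visit, which performs an operator-registry lookup for each visited call node, each pass resolves the op once in its constructor and keeps a const Op& member for the equality checks. The following is a minimal sketch of that pattern, condensed from the canonicalize_cast change in this diff; the class name is illustrative and not part of the PR.

#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>

namespace tvm {
namespace relay {

class ExampleRewriter : public ExprMutator {
 public:
  // Op::Get hits the operator registry; do it once, at construction time.
  ExampleRewriter() : cast_op_(Op::Get("cast")) {}

  Expr VisitExpr_(const CallNode* call) {
    // Before this change the check was call->op.same_as(Op::Get("cast")),
    // which repeated the registry lookup on every visited call node.
    if (call->op == cast_op_) {
      // ... rewrite the cast call here ...
    }
    return ExprMutator::VisitExpr_(call);
  }

 private:
  // Ops are interned singletons owned by the registry, so caching a
  // reference is cheap and stays valid for the lifetime of the pass.
  const Op& cast_op_;
};

}  // namespace relay
}  // namespace tvm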
5 changes: 2 additions & 3 deletions include/tvm/relay/op.h
@@ -594,12 +594,11 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
   return map_.get<ValueType>(expr, def_value);
 }
 
-
 /*!
- * \brief Check that an expression is a "primtive operator".
+ * \brief Check that an expression is a "primitive operator".
  *
  * Will return true if the expression is an operator which
- * matches the form of primtive operators registered directly
+ * matches the form of primitive operators registered directly
  * by the Relay codebase.
  *
  * That is the arguments are all type variables, and there is a single
15 changes: 9 additions & 6 deletions src/relay/backend/compile_engine.cc
@@ -21,6 +21,8 @@
  * \file relay/backend/compile_engine.cc
  * \brief Internal compialtion engine.
  */
+#include "compile_engine.h"
+
 #include <tvm/schedule.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/operation.h>
@@ -29,6 +31,7 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <topi/tags.h>
 #include <utility>
@@ -38,7 +41,6 @@
 #include <vector>
 #include <unordered_map>
 #include "../ir/type_functor.h"
-#include "compile_engine.h"
 
 namespace tvm {
 namespace relay {
@@ -102,7 +104,7 @@ class ScheduleGetter :
     public ExprFunctor<Array<Tensor>(const Expr&)> {
  public:
   explicit ScheduleGetter(Target target)
-      : target_(target) {}
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {}
 
   std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
     static auto fschedule =
@@ -250,11 +252,9 @@ class ScheduleGetter :
     CHECK(call_node->op.as<OpNode>())
         << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
-    // Check if the op is a device copy op.
-    bool is_copy_op = op.same_as(Op::Get("device_copy"));
     Array<Tensor> outputs;
     // Skip fcompute for device copy operators as it is not registered.
-    if (is_copy_op) {
+    if (op == device_copy_op_) {
       const auto* copy_input = inputs[0].operator->();
       outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
                                          Operation(), 0));
@@ -282,7 +282,7 @@
     }
     // Set the name to `__copy`. It will be detected in graph runtime to perform
     // data copy across devices.
-    if (is_copy_op) {
+    if (op == device_copy_op_) {
       readable_name_stream_.str(std::string());
       readable_name_stream_ << "__copy";
     } else {
@@ -332,6 +332,9 @@
   std::ostringstream readable_name_stream_;
   std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
   Array<Operation> scalars_;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
 };
 
 // Creates shape function from functor.
23 changes: 14 additions & 9 deletions src/relay/backend/interpreter.cc
@@ -246,10 +246,12 @@ class Interpreter :
       public ExprFunctor<Value(const Expr& n)>,
              PatternFunctor<bool(const Pattern& p, const Value& v)> {
  public:
-  Interpreter(Module mod,
-              DLContext context,
-              Target target)
-      : mod_(mod), context_(context), target_(target) {
+  Interpreter(Module mod, DLContext context, Target target)
+      : mod_(mod),
+        context_(context),
+        target_(target),
+        debug_op_(Op::Get("debug")),
+        shape_of_op_(Op::Get("shape_of")) {
     engine_ = CompileEngine::Global();
   }
 
@@ -263,7 +265,7 @@
     stack_.current_frame().locals.Set(id, v);
   }
 
-  inline Value Lookup(const Var& local) {
+  Value Lookup(const Var& local) {
     return stack_.Lookup(local);
   }
 
@@ -307,7 +309,7 @@
     return TupleValueNode::make(values);
   }
 
-  inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
+  Value MakeClosure(const Function& func, Var letrec_name = Var()) {
     tvm::Map<Var, Value> captured_mod;
     Array<Var> free_vars = FreeVars(func);
 
@@ -454,9 +456,9 @@
 
   Value InvokePrimitiveOp(const Function& func,
                           const Array<Value>& args) {
-    auto call_node = func->body.as<CallNode>();
+    const auto* call_node = func->body.as<CallNode>();
 
-    if (call_node && call_node->op == Op::Get("debug")) {
+    if (call_node && call_node->op == debug_op_) {
       auto dattrs = call_node->attrs.as<DebugAttrs>();
       auto interp_state = this->get_state(call_node->args[0]);
 
@@ -540,7 +542,7 @@
     Array<Shape> out_shapes;
     auto ret_type = func->body->checked_type();
     bool is_dyn = IsDynamic(func->checked_type());
-    if (call_node->op == Op::Get("shape_of")) {
+    if (call_node->op == shape_of_op_) {
       // The output shape of shape_of must be static since Relay doesn't support
       // dynamic rank tensors.
       is_dyn = false;
@@ -782,6 +784,9 @@
   Stack stack_;
   // Backend compile engine.
   CompileEngine engine_;
+  // Cache ops that need to be frequently used later to reduce lookup overhead.
+  const Op& debug_op_;
+  const Op& shape_of_op_;
 };
 
 
12 changes: 8 additions & 4 deletions src/relay/pass/canonicalize_cast.cc
@@ -62,6 +62,8 @@ namespace relay {
 // \endcode
 class CastCanonicalizer : public ExprMutator {
  public:
+  CastCanonicalizer() : cast_op_(Op::Get("cast")) {}
+
   Expr VisitExpr_(const CallNode* call) {
     static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
 
@@ -91,15 +93,17 @@
 
  private:
   std::unordered_map<const Node*, size_t> ref_counter_;
+  // cast op is frequently checked for equivalence. Therefore, we cache it to
+  // reduce lookup overhead.
+  const Op& cast_op_;
+
 
   Expr GetNewCallArg(const Expr& e) {
     // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor
-
-    static auto& cast = Op::Get("cast");
     Expr new_expr = this->VisitExpr(e);
 
     if (const CallNode* call = e.as<CallNode>()) {
-      if (call->op.same_as(cast)) {
+      if (call->op == cast_op_) {
         auto attrs = call->attrs.as<CastAttrs>();
         const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
         CHECK(from_type);
@@ -108,7 +112,7 @@
       if (++ref_counter_[call] > 1) {
         const CallNode* new_call = new_expr.as<CallNode>();
         CHECK(new_call);
-        CHECK(new_call->op.same_as(cast));
+        CHECK(new_call->op == cast_op_);
         return CallNode::make(new_call->op, new_call->args, new_call->attrs,
                               new_call->type_args);
       }
10 changes: 8 additions & 2 deletions src/relay/pass/canonicalize_ops.cc
@@ -24,6 +24,7 @@
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/transform.h>
 #include "pattern_util.h"
@@ -33,10 +34,11 @@ namespace relay {
 
 class BiasAddSimplifier : public ExprMutator {
  public:
+  BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}
+
   Expr VisitExpr_(const CallNode* n) {
-    static const Op& bias_add = Op::Get("nn.bias_add");
     auto new_n = ExprMutator::VisitExpr_(n);
-    if (n->op.same_as(bias_add)) {
+    if (n->op == bias_add_op_) {
       Call call = Downcast<Call>(new_n);
       CHECK_EQ(call->args.size(), 2);
       const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
@@ -54,6 +56,10 @@
     }
     return new_n;
   }
+
+ private:
+  // Cache the bias_add for equivalence checking.
+  const Op& bias_add_op_;
 };
 
 Expr CanonicalizeOps(const Expr& e) {
24 changes: 12 additions & 12 deletions src/relay/pass/combine_parallel_op.cc
@@ -27,37 +27,38 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
 #include <algorithm>
 #include <utility>
 #include <unordered_map>
 #include <unordered_set>
-#include "./expr_subst.h"
-#include "./pattern_util.h"
-#include "./combine_parallel_op.h"
+#include "expr_subst.h"
+#include "pattern_util.h"
+#include "combine_parallel_op.h"
+
+
 namespace tvm {
 namespace relay {
 
-BranchGroupFinder::BranchGroupFinder(const std::string& op_name,
+BranchGroupFinder::BranchGroupFinder(const Op& op,
                                      FIsSupportedOp fis_supported_op,
                                      FAreCompatibleOps fare_compatible_ops)
-    : op_name_(op_name),
+    : cached_op_(op),
       fis_supported_op_(fis_supported_op),
       fare_compatible_ops_(fare_compatible_ops) {
 }
 
 std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
-  const Op& op = Op::Get(op_name_);
-
   this->VisitExpr(expr);
 
   std::vector<Group> groups;
   for (const auto& root : op_roots_) {
     const auto& children = children_map_.at(root);
     size_t ngroups = groups.size();
     for (const CallNode* child : children) {
-      if (!child->op.same_as(op)) continue;
+      if (child->op != cached_op_) continue;
 
       auto&& branch = CreateBranch(child);
       // add the branch to a group, or create a new group
@@ -97,9 +98,8 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
 }
 
 void BranchGroupFinder::VisitExpr_(const CallNode* n) {
-  const Op& op = Op::Get(op_name_);
   ExprVisitor::VisitExpr_(n);
-  if (n->op.same_as(op) && fis_supported_op_(n)) {
+  if (n->op == cached_op_ && fis_supported_op_(n)) {
     op_roots_.insert(n->args[0]);
     children_map_[n->args[0]].push_back(n);
   } else {
@@ -110,12 +110,12 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) {
   }
 }
 ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
-    : op_name_(op_name),
+    : cached_op_(Op::Get(op_name)),
       min_num_branches_(min_num_branches) {
 }
 
 Expr ParallelOpCombiner::Combine(const Expr& expr) {
-  auto groups = BranchGroupFinder(op_name_,
+  auto groups = BranchGroupFinder(cached_op_,
                                   [&](const CallNode* n) {
                                     return IsSupportedOp(n);
                                   },
12 changes: 6 additions & 6 deletions src/relay/pass/combine_parallel_op.h
@@ -68,13 +68,13 @@ class BranchGroupFinder : private ExprVisitor {
  public:
   /*
    * \brief Constructor
-   * \param op_name name of op to start each group
+   * \param op The op that indicates the start of each group
    * \param fis_supported_op function that returns true if op
    *                          is supported for combining
    * \param fare_compatible_ops function that returns true if
    *                             two ops are compatible for combining
    */
-  BranchGroupFinder(const std::string& op_name,
+  BranchGroupFinder(const Op& op,
                     FIsSupportedOp fis_supported_op,
                     FAreCompatibleOps fare_compatible_ops);
 
@@ -87,8 +87,8 @@
   std::vector<Group> Find(const Expr& expr);
 
  private:
-  /* \brief name of op to find parallel branches for */
-  std::string op_name_;
+  /* \brief Cache the op for finding parallel branches */
+  const Op& cached_op_;
 
   /* \brief function to return true if op is eligible to be combined,
    * false otherwise
@@ -205,8 +205,8 @@
                             ExprSubstMap* subst_map) = 0;
 
  private:
-  /* \brief name of op to be combined */
-  std::string op_name_;
+  /* \brief Cache the op to be combined */
+  const Op& cached_op_;
 
   /* \brief minimum number of parallel branches to combine */
   uint64_t min_num_branches_;