Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Move compiler_begin/end_op to local static objects #5622

Merged
merged 1 commit into from
May 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/relay/analysis/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,13 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
return true;
}

// Cache the operators that are checked recursively to reduce lookup overhead.
static const auto& expand_dims_op = Op::Get("expand_dims");
static const auto& reshape_op = Op::Get("reshape");
static const auto& transpose_op = Op::Get("transpose");
static const auto& squeeze_op = Op::Get("squeeze");

bool IsAllPositiveConstant(const Expr& expr) {
// Cache the operators that are checked recursively to reduce lookup overhead.
static const auto& expand_dims_op = Op::Get("expand_dims");
static const auto& reshape_op = Op::Get("reshape");
static const auto& transpose_op = Op::Get("transpose");
static const auto& squeeze_op = Op::Get("squeeze");

Comment on lines +342 to +347
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change these static ops from non-local to local static objects as well.

// peel through a few common transform ops.
if (const auto* constant = expr.as<ConstantNode>()) {
const auto& tensor = constant->data;
Expand Down
15 changes: 7 additions & 8 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@
#include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>

#include "pass_util.h"

namespace tvm {
namespace relay {
namespace annotate_target {

static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

const PackedFunc* make_begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
Expand Down Expand Up @@ -66,12 +65,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
std::string arg_target = "default";
const CallNode* call = arg.as<CallNode>();

if (call && call->op == compiler_begin_op) {
if (call && call->op == CompilerBeginOp()) {
// Argument is already compiler begin node meaning that this is not the first time
// running this pass, so we simply remove it and will add a new one later.
CHECK_EQ(call->args.size(), 1U);
const CallNode* end = call->args[0].as<CallNode>();
if (end->op == compiler_end_op) {
if (end->op == CompilerEndOp()) {
arg_target = end->attrs.as<CompilerAttrs>()->compiler;
}
compiler_ends.push_back(call->args[0]);
Expand Down Expand Up @@ -115,12 +114,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
auto op_node = pre->op.as<OpNode>();

// This graph has annotations, meaning that this is not the first time running this pass.
if (op_node && pre->op == compiler_begin_op) {
if (op_node && pre->op == CompilerBeginOp()) {
// Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments.
CHECK_EQ(pre->args.size(), 1U);
return post.as<CallNode>()->args[0];
} else if (op_node && pre->op == compiler_end_op) {
} else if (op_node && pre->op == CompilerEndOp()) {
// Override compiler end with the new target.
CHECK_EQ(pre->args.size(), 1U);
auto input_expr = post.as<CallNode>()->args[0];
Expand All @@ -131,7 +130,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
// Peek the first argument. If it is compiler begin then this node had annotated by
// another target before, so we also consider that target as a supported target.
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == compiler_begin_op) {
if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
Expand Down
18 changes: 7 additions & 11 deletions src/relay/transforms/merge_compiler_regions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,18 @@
#include <vector>

#include "../analysis/annotated_region_set.h"
#include "pass_util.h"

namespace tvm {
namespace relay {
namespace merge_compiler_region {

// Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

class RegionMerger : public MixedModeVisitor {
public:
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}

void VisitExpr_(const CallNode* call) final {
if (call->op == compiler_end_op) {
if (call->op == CompilerEndOp()) {
auto region = regions_->GetRegion(GetRef<Call>(call));

// Skip this region if it has been merged to the other region.
Expand All @@ -75,7 +71,7 @@ class RegionMerger : public MixedModeVisitor {
// Region inputs must be begin annotation, and the region of
// the begin annotation's argument is the parent region.
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
CHECK_EQ(begin->op, CompilerBeginOp());
auto parent_region = regions_->GetRegion(begin->args[0]);

// Skip this region if it has been merged.
Expand All @@ -90,7 +86,7 @@ class RegionMerger : public MixedModeVisitor {
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
CHECK_EQ(begin->op, CompilerBeginOp());
auto parent_region = regions_->GetRegion(begin->args[0]);
if (parent_region.defined()) {
mergeable_regions.insert(parent_region);
Expand Down Expand Up @@ -147,9 +143,9 @@ class MergeAnnotations : public ExprRewriter {
// Merge annotations which are now internal to a region.
// This happens if we see a compiler begin next to a
// compiler end and they're both in the same region.
if (call->op == compiler_begin_op && call->args[0]->IsInstance<CallNode>()) {
if (call->op == CompilerBeginOp() && call->args[0]->IsInstance<CallNode>()) {
auto arg = Downcast<Call>(call->args[0]);
if (arg->op == compiler_end_op) {
if (arg->op == CompilerEndOp()) {
auto region1 = regions_->GetRegion(GetRef<Call>(call));
auto region2 = regions_->GetRegion(arg);
if (region1 == region2) {
Expand All @@ -167,7 +163,7 @@ class MergeAnnotations : public ExprRewriter {

Expr MergeCompilerRegions(const Expr& expr) {
// Create regions using the annotations.
AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, CompilerBeginOp(), CompilerEndOp());

// Analyze the graph to explore the opportunities of merging regions.
RegionMerger merger(regions);
Expand Down
17 changes: 6 additions & 11 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,12 @@

#include "../analysis/annotated_region_set.h"
#include "../backend/utils.h"
#include "pass_util.h"

namespace tvm {
namespace relay {
namespace partitioning {

// Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

/*! \brief This struct maintains the required metadata for a region to generate a corresponding
* global function and function call. Global function will be passed to the target specific codegen
* and function call will be used in the transform Relay graph to invoke the function in runtime.
Expand Down Expand Up @@ -123,8 +119,7 @@ class Partitioner : public MixedModeMutator {
BaseFunc f_func = f.second;

// Creating regionset per function in the module.
auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
partitioning::compiler_end_op);
auto region_set = AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp());
regions_sets_[region_set] = f_func;
}
}
Expand All @@ -133,7 +128,7 @@ class Partitioner : public MixedModeMutator {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return post;
} else if (call->op == compiler_begin_op) {
} else if (call->op == CompilerBeginOp()) {
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);

Expand All @@ -143,7 +138,7 @@ class Partitioner : public MixedModeMutator {

// Backtrace the parent to find the first ancestor node that is not a begin or end op
while (const auto* parent_call = parent.as<CallNode>()) {
if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) {
if (parent_call->op == CompilerBeginOp() || parent_call->op == CompilerEndOp()) {
parent = parent_call->args[0];
} else {
break;
Expand Down Expand Up @@ -174,7 +169,7 @@ class Partitioner : public MixedModeMutator {
return std::move(var);
}
} else {
CHECK_EQ(call->op, compiler_end_op);
CHECK_EQ(call->op, CompilerEndOp());
// The annotation node is inserted on edge so it must have only one
// argument.
CHECK_EQ(call->args.size(), 1U);
Expand Down Expand Up @@ -420,7 +415,7 @@ IRModule FlattenTupleOutputs(IRModule module) {
TupleOutFlattener() = default;

Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (call->op == compiler_end_op) {
if (call->op == CompilerEndOp()) {
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
// Arguments of annotation ops should be 1
CHECK_EQ(call->args.size(), 1U);
Expand Down
20 changes: 20 additions & 0 deletions src/relay/transforms/pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,26 @@ inline bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}

/*!
* \brief Cache the compiler_begin annotation op to reduce registry lookup overhead
* \param void
* \return compiler_begin op
*/
inline const Op& CompilerBeginOp() {
static Op op = Op::Get("annotation.compiler_begin");
return op;
}

/*!
* \brief Cache the compiler_end annotation op to reduce registry lookup overhead
* \param void
* \return compiler_end op
*/
inline const Op& CompilerEndOp() {
static Op op = Op::Get("annotation.compiler_end");
return op;
}

template <typename ConditionObjectPtr>
struct TreeNode {
typedef std::shared_ptr<TreeNode<ConditionObjectPtr>> pointer;
Expand Down
1 change: 1 addition & 0 deletions src/runtime/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>

#include <iostream>
tqchen marked this conversation as resolved.
Show resolved Hide resolved
#include <mutex>
#include <string>
#include <unordered_map>
Expand Down