Skip to content

Commit

Permalink
HG: Commit message of changeset 6281661. (#5622)
Browse files Browse the repository at this point in the history
[Relay] Move compiler_begin/end_op to local static objects
  • Loading branch information
hlu1 authored May 22, 2020
1 parent e55f9ff commit dbb8be7
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 36 deletions.
12 changes: 6 additions & 6 deletions src/relay/analysis/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,13 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
return true;
}

// Cache the operators that are checked recursively to reduce lookup overhead.
static const auto& expand_dims_op = Op::Get("expand_dims");
static const auto& reshape_op = Op::Get("reshape");
static const auto& transpose_op = Op::Get("transpose");
static const auto& squeeze_op = Op::Get("squeeze");

bool IsAllPositiveConstant(const Expr& expr) {
// Cache the operators that are checked recursively to reduce lookup overhead.
static const auto& expand_dims_op = Op::Get("expand_dims");
static const auto& reshape_op = Op::Get("reshape");
static const auto& transpose_op = Op::Get("transpose");
static const auto& squeeze_op = Op::Get("squeeze");

// peel through a few common transform ops.
if (const auto* constant = expr.as<ConstantNode>()) {
const auto& tensor = constant->data;
Expand Down
15 changes: 7 additions & 8 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@
#include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>

#include "pass_util.h"

namespace tvm {
namespace relay {
namespace annotate_target {

static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

const PackedFunc* make_begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
Expand Down Expand Up @@ -66,12 +65,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
std::string arg_target = "default";
const CallNode* call = arg.as<CallNode>();

if (call && call->op == compiler_begin_op) {
if (call && call->op == CompilerBeginOp()) {
// Argument is already compiler begin node meaning that this is not the first time
// running this pass, so we simply remove it and will add a new one later.
CHECK_EQ(call->args.size(), 1U);
const CallNode* end = call->args[0].as<CallNode>();
if (end->op == compiler_end_op) {
if (end->op == CompilerEndOp()) {
arg_target = end->attrs.as<CompilerAttrs>()->compiler;
}
compiler_ends.push_back(call->args[0]);
Expand Down Expand Up @@ -115,12 +114,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
auto op_node = pre->op.as<OpNode>();

// This graph has annotations, meaning that this is not the first time running this pass.
if (op_node && pre->op == compiler_begin_op) {
if (op_node && pre->op == CompilerBeginOp()) {
// Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments.
CHECK_EQ(pre->args.size(), 1U);
return post.as<CallNode>()->args[0];
} else if (op_node && pre->op == compiler_end_op) {
} else if (op_node && pre->op == CompilerEndOp()) {
// Override compiler end with the new target.
CHECK_EQ(pre->args.size(), 1U);
auto input_expr = post.as<CallNode>()->args[0];
Expand All @@ -131,7 +130,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
// Peek the first argument. If it is compiler begin then this node had annotated by
// another target before, so we also consider that target as a supported target.
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == compiler_begin_op) {
if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
supported_targets.push_back(arg_target);
Expand Down
18 changes: 7 additions & 11 deletions src/relay/transforms/merge_compiler_regions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,18 @@
#include <vector>

#include "../analysis/annotated_region_set.h"
#include "pass_util.h"

namespace tvm {
namespace relay {
namespace merge_compiler_region {

// Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

class RegionMerger : public MixedModeVisitor {
public:
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}

void VisitExpr_(const CallNode* call) final {
if (call->op == compiler_end_op) {
if (call->op == CompilerEndOp()) {
auto region = regions_->GetRegion(GetRef<Call>(call));

// Skip this region if it has been merged to the other region.
Expand All @@ -75,7 +71,7 @@ class RegionMerger : public MixedModeVisitor {
// Region inputs must be begin annotation, and the region of
// the begin annotation's argument is the parent region.
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
CHECK_EQ(begin->op, CompilerBeginOp());
auto parent_region = regions_->GetRegion(begin->args[0]);

// Skip this region if it has been merged.
Expand All @@ -90,7 +86,7 @@ class RegionMerger : public MixedModeVisitor {
std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
for (const auto& arg : region->GetInputs()) {
auto begin = Downcast<Call>(arg);
CHECK_EQ(begin->op, compiler_begin_op);
CHECK_EQ(begin->op, CompilerBeginOp());
auto parent_region = regions_->GetRegion(begin->args[0]);
if (parent_region.defined()) {
mergeable_regions.insert(parent_region);
Expand Down Expand Up @@ -147,9 +143,9 @@ class MergeAnnotations : public ExprRewriter {
// Merge annotations which are now internal to a region.
// This happens if we see a compiler begin next to a
// compiler end and they're both in the same region.
if (call->op == compiler_begin_op && call->args[0]->IsInstance<CallNode>()) {
if (call->op == CompilerBeginOp() && call->args[0]->IsInstance<CallNode>()) {
auto arg = Downcast<Call>(call->args[0]);
if (arg->op == compiler_end_op) {
if (arg->op == CompilerEndOp()) {
auto region1 = regions_->GetRegion(GetRef<Call>(call));
auto region2 = regions_->GetRegion(arg);
if (region1 == region2) {
Expand All @@ -167,7 +163,7 @@ class MergeAnnotations : public ExprRewriter {

Expr MergeCompilerRegions(const Expr& expr) {
// Create regions using the annotations.
AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, CompilerBeginOp(), CompilerEndOp());

// Analyze the graph to explore the opportunities of merging regions.
RegionMerger merger(regions);
Expand Down
17 changes: 6 additions & 11 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,12 @@

#include "../analysis/annotated_region_set.h"
#include "../backend/utils.h"
#include "pass_util.h"

namespace tvm {
namespace relay {
namespace partitioning {

// Cache compiler_begin and compiler_end annotation ops for equivalence check to
// reduce registry lookup overhead.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

/*! \brief This struct maintains the required metadata for a region to generate a corresponding
* global function and function call. Global function will be passed to the target specific codegen
* and function call will be used in the transform Relay graph to invoke the function in runtime.
Expand Down Expand Up @@ -123,8 +119,7 @@ class Partitioner : public MixedModeMutator {
BaseFunc f_func = f.second;

// Creating regionset per function in the module.
auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
partitioning::compiler_end_op);
auto region_set = AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp());
regions_sets_[region_set] = f_func;
}
}
Expand All @@ -133,7 +128,7 @@ class Partitioner : public MixedModeMutator {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return post;
} else if (call->op == compiler_begin_op) {
} else if (call->op == CompilerBeginOp()) {
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);

Expand All @@ -143,7 +138,7 @@ class Partitioner : public MixedModeMutator {

// Backtrace the parent to find the first ancestor node that is not a begin or end op
while (const auto* parent_call = parent.as<CallNode>()) {
if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) {
if (parent_call->op == CompilerBeginOp() || parent_call->op == CompilerEndOp()) {
parent = parent_call->args[0];
} else {
break;
Expand Down Expand Up @@ -174,7 +169,7 @@ class Partitioner : public MixedModeMutator {
return std::move(var);
}
} else {
CHECK_EQ(call->op, compiler_end_op);
CHECK_EQ(call->op, CompilerEndOp());
// The annotation node is inserted on edge so it must have only one
// argument.
CHECK_EQ(call->args.size(), 1U);
Expand Down Expand Up @@ -420,7 +415,7 @@ IRModule FlattenTupleOutputs(IRModule module) {
TupleOutFlattener() = default;

Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (call->op == compiler_end_op) {
if (call->op == CompilerEndOp()) {
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
// Arguments of annotation ops should be 1
CHECK_EQ(call->args.size(), 1U);
Expand Down
20 changes: 20 additions & 0 deletions src/relay/transforms/pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,26 @@ inline bool IsAtomic(const Expr& e) {
return e.as<VarNode>() || e.as<OpNode>() || e.as<ConstructorNode>() || e.as<GlobalVarNode>();
}

/*!
* \brief Cache the compiler_begin annotation op to reduce registry lookup overhead
* \param void
* \return compiler_begin op
*/
inline const Op& CompilerBeginOp() {
static Op op = Op::Get("annotation.compiler_begin");
return op;
}

/*!
* \brief Cache the compiler_end annotation op to reduce registry lookup overhead
* \param void
* \return compiler_end op
*/
inline const Op& CompilerEndOp() {
static Op op = Op::Get("annotation.compiler_end");
return op;
}

template <typename ConditionObjectPtr>
struct TreeNode {
typedef std::shared_ptr<TreeNode<ConditionObjectPtr>> pointer;
Expand Down
1 change: 1 addition & 0 deletions src/runtime/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>

#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>
Expand Down

0 comments on commit dbb8be7

Please sign in to comment.