diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index 6d246c0d10aa..a05bb8fd8da8 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -338,13 +338,13 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) { return true; } -// Cache the operators that are checked recursively to reduce lookup overhead. -static const auto& expand_dims_op = Op::Get("expand_dims"); -static const auto& reshape_op = Op::Get("reshape"); -static const auto& transpose_op = Op::Get("transpose"); -static const auto& squeeze_op = Op::Get("squeeze"); - bool IsAllPositiveConstant(const Expr& expr) { + // Cache the operators that are checked recursively to reduce lookup overhead. + static const auto& expand_dims_op = Op::Get("expand_dims"); + static const auto& reshape_op = Op::Get("reshape"); + static const auto& transpose_op = Op::Get("transpose"); + static const auto& squeeze_op = Op::Get("squeeze"); + // peel through a few common transform ops. if (const auto* constant = expr.as()) { const auto& tensor = constant->data; diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 0d970050b7d1..bf2788f42f4a 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -29,13 +29,12 @@ #include #include +#include "pass_util.h" + namespace tvm { namespace relay { namespace annotate_target { -static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); -static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); - const PackedFunc* make_begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); @@ -66,12 +65,12 @@ class AnnotateTargetRewriter : public ExprRewriter { std::string arg_target = "default"; const CallNode* call = arg.as(); - if (call && call->op == compiler_begin_op) { + if (call && call->op == CompilerBeginOp()) { // Argument is already compiler begin node meaning that this is not the first time // running this pass, so we simply remove it and will add a new one later. CHECK_EQ(call->args.size(), 1U); const CallNode* end = call->args[0].as(); - if (end->op == compiler_end_op) { + if (end->op == CompilerEndOp()) { arg_target = end->attrs.as()->compiler; } compiler_ends.push_back(call->args[0]); @@ -115,12 +114,12 @@ class AnnotateTargetRewriter : public ExprRewriter { auto op_node = pre->op.as(); // This graph has annotations, meaning that this is not the first time running this pass. - if (op_node && pre->op == compiler_begin_op) { + if (op_node && pre->op == CompilerBeginOp()) { // Bypass compiler begin due to lack of target information. It will be processed // when the following op handling arguments. CHECK_EQ(pre->args.size(), 1U); return post.as()->args[0]; - } else if (op_node && pre->op == compiler_end_op) { + } else if (op_node && pre->op == CompilerEndOp()) { // Override compiler end with the new target. CHECK_EQ(pre->args.size(), 1U); auto input_expr = post.as()->args[0]; @@ -131,7 +130,7 @@ class AnnotateTargetRewriter : public ExprRewriter { // Peek the first argument. If it is compiler begin then this node had annotated by // another target before, so we also consider that target as a supported target. const CallNode* first_arg_call = pre->args[0].as(); - if (first_arg_call && first_arg_call->op == compiler_begin_op) { + if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { std::string arg_target = first_arg_call->attrs.as()->compiler; if (arg_target != "default") { supported_targets.push_back(arg_target); diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index 6fbd0d513e79..b3a606e7bc4f 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -43,22 +43,18 @@ #include #include "../analysis/annotated_region_set.h" +#include "pass_util.h" namespace tvm { namespace relay { namespace merge_compiler_region { -// Cache compiler_begin and compiler_end annotation ops for equivalence check to -// reduce registry lookup overhead. -static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); -static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); - class RegionMerger : public MixedModeVisitor { public: explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {} void VisitExpr_(const CallNode* call) final { - if (call->op == compiler_end_op) { + if (call->op == CompilerEndOp()) { auto region = regions_->GetRegion(GetRef(call)); // Skip this region if it has been merged to the other region. @@ -75,7 +71,7 @@ class RegionMerger : public MixedModeVisitor { // Region inputs must be begin annotation, and the region of // the begin annotation's argument is the parent region. auto begin = Downcast(arg); - CHECK_EQ(begin->op, compiler_begin_op); + CHECK_EQ(begin->op, CompilerBeginOp()); auto parent_region = regions_->GetRegion(begin->args[0]); // Skip this region if it has been merged. @@ -90,7 +86,7 @@ class RegionMerger : public MixedModeVisitor { std::unordered_set mergeable_regions; for (const auto& arg : region->GetInputs()) { auto begin = Downcast(arg); - CHECK_EQ(begin->op, compiler_begin_op); + CHECK_EQ(begin->op, CompilerBeginOp()); auto parent_region = regions_->GetRegion(begin->args[0]); if (parent_region.defined()) { mergeable_regions.insert(parent_region); @@ -147,9 +143,9 @@ class MergeAnnotations : public ExprRewriter { // Merge annotations which are now internal to a region. // This happens if we see a compiler begin next to a // compiler end and they're both in the same region. - if (call->op == compiler_begin_op && call->args[0]->IsInstance()) { + if (call->op == CompilerBeginOp() && call->args[0]->IsInstance()) { auto arg = Downcast(call->args[0]); - if (arg->op == compiler_end_op) { + if (arg->op == CompilerEndOp()) { auto region1 = regions_->GetRegion(GetRef(call)); auto region2 = regions_->GetRegion(arg); if (region1 == region2) { @@ -167,7 +163,7 @@ class MergeAnnotations : public ExprRewriter { Expr MergeCompilerRegions(const Expr& expr) { // Create regions using the annotations. - AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); + AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, CompilerBeginOp(), CompilerEndOp()); // Analyze the graph to explore the opportunities of merging regions. RegionMerger merger(regions); diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 0d25e0a02d49..9481e07d494b 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -44,16 +44,12 @@ #include "../analysis/annotated_region_set.h" #include "../backend/utils.h" +#include "pass_util.h" namespace tvm { namespace relay { namespace partitioning { -// Cache compiler_begin and compiler_end annotation ops for equivalence check to -// reduce registry lookup overhead. -static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); -static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); - /*! \brief This struct maintains the required metadata for a region to generate a corresponding * global function and function call. Global function will be passed to the target specific codegen * and function call will be used in the transform Relay graph to invoke the function in runtime. @@ -123,8 +119,7 @@ class Partitioner : public MixedModeMutator { BaseFunc f_func = f.second; // Creating regionset per function in the module. - auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, - partitioning::compiler_end_op); + auto region_set = AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp()); regions_sets_[region_set] = f_func; } } @@ -133,7 +128,7 @@ class Partitioner : public MixedModeMutator { auto op_node = call->op.as(); if (op_node == nullptr || call->attrs.as() == nullptr) { return post; - } else if (call->op == compiler_begin_op) { + } else if (call->op == CompilerBeginOp()) { // The annotation node is inserted on edge so it must have only one argument. CHECK_EQ(call->args.size(), 1U); @@ -143,7 +138,7 @@ class Partitioner : public MixedModeMutator { // Backtrace the parent to find the first ancestor node that is not a begin or end op while (const auto* parent_call = parent.as()) { - if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) { + if (parent_call->op == CompilerBeginOp() || parent_call->op == CompilerEndOp()) { parent = parent_call->args[0]; } else { break; @@ -174,7 +169,7 @@ class Partitioner : public MixedModeMutator { return std::move(var); } } else { - CHECK_EQ(call->op, compiler_end_op); + CHECK_EQ(call->op, CompilerEndOp()); // The annotation node is inserted on edge so it must have only one // argument. CHECK_EQ(call->args.size(), 1U); @@ -420,7 +415,7 @@ IRModule FlattenTupleOutputs(IRModule module) { TupleOutFlattener() = default; Expr Rewrite_(const CallNode* call, const Expr& post) final { - if (call->op == compiler_end_op) { + if (call->op == CompilerEndOp()) { std::string target = call->attrs.as()->compiler; // Arguments of annotation ops should be 1 CHECK_EQ(call->args.size(), 1U); diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index cbdd4b4a626b..35bbb234dbc1 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -115,6 +115,26 @@ inline bool IsAtomic(const Expr& e) { return e.as() || e.as() || e.as() || e.as(); } +/*! + * \brief Cache the compiler_begin annotation op to reduce registry lookup overhead + * \param void + * \return compiler_begin op + */ +inline const Op& CompilerBeginOp() { + static Op op = Op::Get("annotation.compiler_begin"); + return op; +} + +/*! + * \brief Cache the compiler_end annotation op to reduce registry lookup overhead + * \param void + * \return compiler_end op + */ +inline const Op& CompilerEndOp() { + static Op op = Op::Get("annotation.compiler_end"); + return op; +} + template struct TreeNode { typedef std::shared_ptr> pointer; diff --git a/src/runtime/object.cc b/src/runtime/object.cc index c8e6671d5ee6..e5d5ca9bebf5 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include