From 5ff99b66cbd695fddde03469c21c4948b91097db Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 23 May 2019 10:17:03 -0700 Subject: [PATCH] [NODE] Macro to define NodeRef methods, constructor style example (#3224) --- include/tvm/arithmetic.h | 44 ++++++++++++++-------- include/tvm/base.h | 61 ++++++++++++++++++------------- src/api/api_arith.cc | 17 ++++++--- src/arithmetic/const_int_bound.cc | 10 ++--- src/arithmetic/modular_set.cc | 9 +++-- 5 files changed, 86 insertions(+), 55 deletions(-) diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 9a8d9d372956..6eec767611e0 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -48,11 +48,7 @@ namespace arith { // Forward declare Analyzer class Analyzer; -/*! - * \brief reference class to ConstIntBoundNode - * \sa ConstIntBoundNode - */ -class ConstIntBound; + /*! * \brief Constant integer up and lower bound(inclusive). * Useful for value bound analysis. @@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node { v->Visit("max_value", &max_value); } - TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value); - /*! \brief Number to represent +inf */ static const constexpr int64_t kPosInf = std::numeric_limits::max(); /*! @@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node { TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node); }; -TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode); +/*! + * \brief reference class to ConstIntBoundNode + * \sa ConstIntBoundNode + */ +class ConstIntBound : public NodeRef { + public: + /*! + * \brief constructor by fields. + * \param min_value The mininum value. + * \param max_value The maximum value. + */ + TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value); + + static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; + static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; + TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode); +}; /*! * \brief Analyzer to get constant integer bound over expression. @@ -133,11 +143,6 @@ class ConstIntBoundAnalyzer { Impl* impl_; }; -/*! - * \brief reference of ModularSetNode - * \sa ModularSetNode - */ -class ModularSet; /*! * \brief Range of a linear integer function. * Use to do specify the possible index values. @@ -162,13 +167,20 @@ class ModularSetNode : public Node { v->Visit("base", &base); } - TVM_DLL static ModularSet make(int64_t coeff, int64_t base); - static constexpr const char* _type_key = "arith.ModularSet"; TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node); }; -TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode); +/*! + * \brief reference of ModularSetNode + * \sa ModularSetNode + */ +class ModularSet : public NodeRef { + public: + TVM_DLL ModularSet(int64_t coeff, int64_t base); + + TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode); +}; /*! * \brief Analyzer to get modular information over expression. diff --git a/include/tvm/base.h b/include/tvm/base.h index ae2d91ff8523..049a427ffce8 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -39,21 +39,24 @@ using ::tvm::Node; using ::tvm::NodeRef; using ::tvm::AttrVisitor; -/*! \brief Macro to make it easy to define node ref type given node */ -#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ - class TypeName : public ::tvm::NodeRef { \ - public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - using ContainerType = NodeName; \ - }; \ +/*! + * \brief Macro to define common node ref methods. + * \param TypeName The name of the NodeRef. + * \param BaseTypeName The Base type. + * \param NodeName The node container type. + */ +#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ + TypeName() {} \ + explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + operator bool() const { return this->defined(); } \ + using ContainerType = NodeName; /*! - * \brief Macro to make it easy to define node ref type that - * has a CopyOnWrite member function. + * \brief Macro to define CopyOnWrite function in a NodeRef. + * \param NodeName The Type of the Node. * * CopyOnWrite will generate a unique copy of the internal node. * The node will be copied if it is referenced by multiple places. @@ -70,25 +73,33 @@ using ::tvm::AttrVisitor; * * \endcode */ -#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ - class TypeName : public BaseType { \ - public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - inline NodeName* CopyOnWrite() { \ +#define TVM_DEFINE_NODE_REF_COW(NodeName) \ + NodeName* CopyOnWrite() { \ CHECK(node_ != nullptr); \ if (!node_.unique()) { \ NodePtr n = make_node(*(operator->())); \ NodePtr(std::move(n)).swap(node_); \ } \ return static_cast(node_.get()); \ - } \ - using ContainerType = NodeName; \ - }; + } +/*! \brief Macro to make it easy to define node ref type given node */ +#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ + class TypeName : public ::tvm::NodeRef { \ + public: \ + TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \ + }; \ + +/*! + * \brief Macro to make it easy to define node ref type that + * has a CopyOnWrite member function. + */ +#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ + class TypeName : public BaseType { \ + public: \ + TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \ + TVM_DEFINE_NODE_REF_COW(NodeName); \ + }; /*! * \brief save the node as well as all the node it depends on as json. diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index fce73aabf6a7..55a706420f06 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound") TVM_REGISTER_API("arith.DomainTouched") .set_body_typed(DomainTouched); - TVM_REGISTER_API("_IntervalSetGetMin") .set_body_method(&IntSet::min); @@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing") TVM_REGISTER_API("_IntSetIsEverything") .set_body_method(&IntSet::is_everything); +ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { + return ConstIntBound(min_value, max_value); +} + TVM_REGISTER_API("arith._make_ConstIntBound") -.set_body_typed(ConstIntBoundNode::make); +.set_body_typed(MakeConstIntBound); + +ModularSet MakeModularSet(int64_t coeff, int64_t base) { + return ModularSet(coeff, base); +} TVM_REGISTER_API("arith._make_ModularSet") -.set_body_typed(ModularSetNode::make); +.set_body_typed(MakeModularSet); TVM_REGISTER_API("arith._CreateAnalyzer") .set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index bfd06c8ba255..72b85084d59d 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -34,12 +34,12 @@ using namespace ir; TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); -ConstIntBound ConstIntBoundNode::make( +ConstIntBound::ConstIntBound( int64_t min_value, int64_t max_value) { auto node = make_node(); node->min_value = min_value; node->max_value = max_value; - return ConstIntBound(node); + node_ = std::move(node); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl : std::vector additional_info_; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. - static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; - static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; + static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; + static const constexpr int64_t kPosInf = ConstIntBound::kPosInf; static_assert(-kNegInf == kPosInf, "invariant of inf"); // internal helper functions /*! @@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl : ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) { Entry ret = impl_->VisitExpr(expr); - return ConstIntBoundNode::make(ret.min_value, ret.max_value); + return ConstIntBound(ret.min_value, ret.max_value); } void ConstIntBoundAnalyzer::Update(const Var& var, diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 7701e04844fa..57e82943b84c 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -35,11 +35,12 @@ using namespace ir; TVM_REGISTER_NODE_TYPE(ModularSetNode); -ModularSet ModularSetNode::make(int64_t coeff, int64_t base) { +ModularSet::ModularSet(int64_t coeff, int64_t base) { auto node = make_node(); node->coeff = coeff; node->base = base; - return ModularSet(node); + // finish construction. + node_ = std::move(node); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -366,13 +367,13 @@ class ModularSetAnalyzer::Impl : * \return Bound that represent everything dtype can represent. */ static Entry Nothing() { - return Entry(0, 1); + return Entry(0, 1); } }; ModularSet ModularSetAnalyzer::operator()(const Expr& expr) { Entry ret = impl_->VisitExpr(expr); - return ModularSetNode::make(ret.coeff, ret.base); + return ModularSet(ret.coeff, ret.base); } void ModularSetAnalyzer::Update(const Var& var,