Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NODE] Macro to define NodeRef methods, constructor style example #3224

Merged
merged 1 commit into from
May 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,7 @@ namespace arith {

// Forward declare Analyzer
class Analyzer;
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class ConstIntBound;

/*!
* \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis.
Expand All @@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node {
v->Visit("max_value", &max_value);
}

TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value);

/*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*!
Expand All @@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
};

TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode);
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class ConstIntBound : public NodeRef {
public:
/*!
* \brief constructor by fields.
* \param min_value The mininum value.
* \param max_value The maximum value.
*/
TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);

static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode);
};

/*!
* \brief Analyzer to get constant integer bound over expression.
Expand Down Expand Up @@ -133,11 +143,6 @@ class ConstIntBoundAnalyzer {
Impl* impl_;
};

/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class ModularSet;
/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
Expand All @@ -162,13 +167,20 @@ class ModularSetNode : public Node {
v->Visit("base", &base);
}

TVM_DLL static ModularSet make(int64_t coeff, int64_t base);

static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
};

TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode);
/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class ModularSet : public NodeRef {
public:
TVM_DLL ModularSet(int64_t coeff, int64_t base);

TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode);
};

/*!
* \brief Analyzer to get modular information over expression.
Expand Down
61 changes: 36 additions & 25 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,24 @@ using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;

/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
}; \
/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;

/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
* \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
Expand All @@ -70,25 +73,33 @@ using ::tvm::AttrVisitor;
*
* \endcode
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
inline NodeName* CopyOnWrite() { \
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
} \
return static_cast<NodeName*>(node_.get()); \
} \
using ContainerType = NodeName; \
};
}

/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add BaseType to this macro?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, I think we will move to avoid use TVM_DEFINE_NODE_REF and directly use class declarations plus TVM_DEFINE_NODE_REF_METHODS.

class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \

/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};

/*!
* \brief save the node as well as all the node it depends on as json.
Expand Down
17 changes: 12 additions & 5 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -58,7 +58,6 @@ TVM_REGISTER_API("arith.DeduceBound")
TVM_REGISTER_API("arith.DomainTouched")
.set_body_typed(DomainTouched);


TVM_REGISTER_API("_IntervalSetGetMin")
.set_body_method(&IntSet::min);

Expand All @@ -71,11 +70,19 @@ TVM_REGISTER_API("_IntSetIsNothing")
TVM_REGISTER_API("_IntSetIsEverything")
.set_body_method(&IntSet::is_everything);

ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
return ConstIntBound(min_value, max_value);
}

TVM_REGISTER_API("arith._make_ConstIntBound")
.set_body_typed(ConstIntBoundNode::make);
.set_body_typed(MakeConstIntBound);

ModularSet MakeModularSet(int64_t coeff, int64_t base) {
return ModularSet(coeff, base);
}

TVM_REGISTER_API("arith._make_ModularSet")
.set_body_typed(ModularSetNode::make);
.set_body_typed(MakeModularSet);

TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Expand Down
10 changes: 5 additions & 5 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ using namespace ir;

TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);

ConstIntBound ConstIntBoundNode::make(
ConstIntBound::ConstIntBound(
int64_t min_value, int64_t max_value) {
auto node = make_node<ConstIntBoundNode>();
node->min_value = min_value;
node->max_value = max_value;
return ConstIntBound(node);
node_ = std::move(node);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
Expand Down Expand Up @@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl :
std::vector<BoundInfo> additional_info_;
// constants: the limit value means umlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
static const constexpr int64_t kPosInf = ConstIntBound::kPosInf;
static_assert(-kNegInf == kPosInf, "invariant of inf");
// internal helper functions
/*!
Expand Down Expand Up @@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl :

ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) {
Entry ret = impl_->VisitExpr(expr);
return ConstIntBoundNode::make(ret.min_value, ret.max_value);
return ConstIntBound(ret.min_value, ret.max_value);
}

void ConstIntBoundAnalyzer::Update(const Var& var,
Expand Down
9 changes: 5 additions & 4 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ using namespace ir;

TVM_REGISTER_NODE_TYPE(ModularSetNode);

ModularSet ModularSetNode::make(int64_t coeff, int64_t base) {
ModularSet::ModularSet(int64_t coeff, int64_t base) {
auto node = make_node<ModularSetNode>();
node->coeff = coeff;
node->base = base;
return ModularSet(node);
// finish construction.
node_ = std::move(node);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
Expand Down Expand Up @@ -366,13 +367,13 @@ class ModularSetAnalyzer::Impl :
* \return Bound that represent everything dtype can represent.
*/
static Entry Nothing() {
return Entry(0, 1);
return Entry(0, 1);
}
};

ModularSet ModularSetAnalyzer::operator()(const Expr& expr) {
Entry ret = impl_->VisitExpr(expr);
return ModularSetNode::make(ret.coeff, ret.base);
return ModularSet(ret.coeff, ret.base);
}

void ModularSetAnalyzer::Update(const Var& var,
Expand Down