Skip to content

Commit

Permalink
[REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. (#4161)
Browse files Browse the repository at this point in the history
* [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol.

This PR removes the original node system, and make node as a subclass of Object.
This is a major refactor towards a better unified runtime object system.

List of changes in the refactor:

- We now hide data_ field, use Downcast explicitly to get a sub-class object.
- Removed the node system FFI in python.
- Removed the node C API, instead use PackedFunc for list and get attrs.
- Change relay::Op::set_attr_type_key(attr_key_name) to relay::Op::set_attr_type<AttrType>().
  - This change was necessary because of the new Object registration mechanism.
  - Subsequent changes to the op registrations
  - The change revealed a few previous problems that is now fixed.
- Patched up a few missing node type registration.
  - Now we will raise an error if we register object that is not registered.
- The original node.h and container.h are kept in the same location.
- Calling convention: kObjectHandle now equals the old kNodeHandle, kNodeHandle is removed.
- IRFunctor now dispatches on ObjectRef.
- Update to the new type checking API: is_type, derived_from are replaced by IsInstance.
- Removed .hash member function, instead use C++ convention hasher functors.

* Address review comments
  • Loading branch information
tqchen authored Oct 21, 2019
1 parent 97ea31c commit 7895adb
Show file tree
Hide file tree
Showing 185 changed files with 1,442 additions and 2,387 deletions.
4 changes: 2 additions & 2 deletions golang/src/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ var KTVMType = int32(C.kTVMType)
var KTVMContext = int32(C.kTVMContext)
// KArrayHandle is golang type code for TVM kArrayHandle.
var KArrayHandle = int32(C.kArrayHandle)
// KNodeHandle is golang type code for TVM kNodeHandle.
var KNodeHandle = int32(C.kNodeHandle)
// KObjectHandle is golang type code for TVM kObjectHandle.
var KObjectHandle = int32(C.kObjectHandle)
// KModuleHandle is gonag type code for TVM kModuleHandle.
var KModuleHandle = int32(C.kModuleHandle)
// KFuncHandle is gonalg type code for TVM kFuncHandle.
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/api_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class EnvFunc : public NodeRef {
explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const {
return static_cast<EnvFuncNode*>(node_.get());
return static_cast<const EnvFuncNode*>(get());
}
/*!
* \brief Invoke the function.
Expand Down Expand Up @@ -124,19 +124,19 @@ class TypedEnvFunc<R(Args...)> : public NodeRef {
/*! \brief short hand for this function type */
using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {}
explicit TypedEnvFunc(NodePtr<Node> n) : NodeRef(n) {}
explicit TypedEnvFunc(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief Assign global function to a TypedEnvFunc
* \param other Another global function.
* \return reference to self.
*/
TSelf& operator=(const EnvFunc& other) {
this->node_ = other.node_;
ObjectRef::operator=(other);
return *this;
}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const {
return static_cast<EnvFuncNode*>(node_.get());
return static_cast<const EnvFuncNode*>(get());
}
/*!
* \brief Invoke the function.
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ class IntSet : public NodeRef {
/*! \brief constructor */
IntSet() {}
// constructor from not container.
explicit IntSet(NodePtr<Node> n) : NodeRef(n) {}
explicit IntSet(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down Expand Up @@ -692,7 +692,7 @@ Array<Expr> DetectClipBound(const Expr& e,

// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
return static_cast<const IntSetNode*>(get());
}
} // namespace arith
} // namespace tvm
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class AttrsEqual {
return lhs == rhs;
}
// node comparator
TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const;
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;

protected:
friend class AttrsEqualHandler;
Expand Down Expand Up @@ -203,7 +203,7 @@ class AttrsHash {
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
TVM_DLL size_t operator()(const NodeRef& value) const;
TVM_DLL size_t operator()(const ObjectRef& value) const;

private:
friend class AttrsHashHandler;
Expand Down Expand Up @@ -260,7 +260,7 @@ class BaseAttrsNode : public Node {
* \return The comparison result.
*/
TVM_DLL virtual bool ContentEqual(
const Node* other, AttrsEqual equal) const = 0;
const Object* other, AttrsEqual equal) const = 0;
/*!
* \brief Content aware hash.
* \param hasher The hasher to run the hash.
Expand Down Expand Up @@ -290,7 +290,7 @@ class Attrs : public NodeRef {
private:
/*! \return the internal attribute node */
const BaseAttrsNode* ptr() const {
return static_cast<const BaseAttrsNode*>(node_.get());
return static_cast<const BaseAttrsNode*>(get());
}
};

Expand All @@ -315,7 +315,7 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Node* other, AttrsEqual equal) const final;
bool ContentEqual(const Object* other, AttrsEqual equal) const final;
size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
Expand Down Expand Up @@ -369,7 +369,7 @@ class AttrsEqualVisitor {
public:
bool result_{true};
// constructor
AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal)
AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
Expand All @@ -387,8 +387,8 @@ class AttrsEqualVisitor {
}

private:
const Node* lhs_;
const Node* rhs_;
const Object* lhs_;
const Object* rhs_;
const AttrsEqual& equal_;
};

Expand Down Expand Up @@ -488,7 +488,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->type_key();
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
}
}
}
Expand Down Expand Up @@ -521,7 +521,7 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->type_key();
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
}
}
}
Expand Down Expand Up @@ -827,7 +827,7 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_;
}

bool ContentEqual(const Node* other, AttrsEqual equal) const final {
bool ContentEqual(const Object* other, AttrsEqual equal) const final {
DerivedType* pself = self();
if (pself == other) return true;
if (other == nullptr) return false;
Expand All @@ -839,7 +839,7 @@ class AttrsNode : public BaseAttrsNode {

size_t ContentHash(AttrsHash hasher) const final {
::tvm::detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = std::hash<std::string>()(this->type_key());
visitor.result_ = this->GetTypeKeyHash();
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
Expand Down
16 changes: 9 additions & 7 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ using ::tvm::AttrVisitor;
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
return static_cast<const NodeName*>(data_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
Expand All @@ -75,12 +76,12 @@ using ::tvm::AttrVisitor;
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(node_.get()); \
return static_cast<NodeName*>(data_.get()); \
}

/*! \brief Macro to make it easy to define node ref type given node */
Expand Down Expand Up @@ -160,7 +161,7 @@ std::string SaveJSON(const NodeRef& node);
*
* \return The shared_ptr of the Node.
*/
NodePtr<Node> LoadJSON_(std::string json_str);
ObjectPtr<Object> LoadJSON_(std::string json_str);

/*!
* \brief Load the node from json string.
Expand Down Expand Up @@ -233,6 +234,7 @@ struct NodeFactoryReg {
* \note This is necessary to enable serialization of the Node.
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
.set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ enum BufferType : int {
class Buffer : public NodeRef {
public:
Buffer() {}
explicit Buffer(NodePtr<Node> n) : NodeRef(n) {}
explicit Buffer(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief Return a new buffer that is equivalent with current one
* but always add stride field.
Expand Down Expand Up @@ -171,7 +171,7 @@ class BufferNode : public Node {
};

inline const BufferNode* Buffer::operator->() const {
return static_cast<const BufferNode*>(node_.get());
return static_cast<const BufferNode*>(get());
}

/*!
Expand Down
16 changes: 8 additions & 8 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class TargetNode : public Node {
class Target : public NodeRef {
public:
Target() {}
explicit Target(NodePtr<Node> n) : NodeRef(n) {}
explicit Target(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
Expand All @@ -110,7 +110,7 @@ class Target : public NodeRef {
TVM_DLL static tvm::Target Current(bool allow_not_defined = true);

const TargetNode* operator->() const {
return static_cast<const TargetNode*>(node_.get());
return static_cast<const TargetNode*>(get());
}

using ContainerType = TargetNode;
Expand Down Expand Up @@ -256,12 +256,12 @@ class BuildConfigNode : public Node {
class BuildConfig : public ::tvm::NodeRef {
public:
BuildConfig() {}
explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
explicit BuildConfig(ObjectPtr<Object> n) : NodeRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get());
return static_cast<const BuildConfigNode*>(get());
}
BuildConfigNode* operator->() {
return static_cast<BuildConfigNode*>(node_.get());
return static_cast<BuildConfigNode*>(get_mutable());
}
/*!
* \brief Construct a BuildConfig containing a empty build config node.
Expand Down Expand Up @@ -371,7 +371,7 @@ class GenericFuncNode;
class GenericFunc : public NodeRef {
public:
GenericFunc() {}
explicit GenericFunc(NodePtr<Node> n) : NodeRef(n) {}
explicit GenericFunc(ObjectPtr<Object> n) : NodeRef(n) {}

/*!
* \brief Set the default function implementaiton.
Expand Down Expand Up @@ -478,10 +478,10 @@ class GenericFuncNode : public Node {
};

inline GenericFuncNode* GenericFunc::operator->() {
return static_cast<GenericFuncNode*>(node_.get());
return static_cast<GenericFuncNode*>(get_mutable());
}

#define TVM_GENERIC_FUNC_REG_VAR_DEF \
#define TVM_GENERIC_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM

/*!
Expand Down
98 changes: 0 additions & 98 deletions include/tvm/c_dsl_api.h

This file was deleted.

4 changes: 2 additions & 2 deletions include/tvm/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Channel : public NodeRef {
public:
/*! \brief default constructor */
Channel() {}
explicit Channel(NodePtr<Node> n) : NodeRef(n) {}
explicit Channel(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down Expand Up @@ -67,7 +67,7 @@ struct ChannelNode : public Node {

// Inline implementations
inline const ChannelNode* Channel::operator->() const {
return static_cast<const ChannelNode*>(node_.get());
return static_cast<const ChannelNode*>(get());
}
} // namespace tvm
#endif // TVM_CHANNEL_H_
Loading

0 comments on commit 7895adb

Please sign in to comment.