Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. #4161

Merged
merged 2 commits into from
Oct 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions golang/src/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ var KTVMType = int32(C.kTVMType)
var KTVMContext = int32(C.kTVMContext)
// KArrayHandle is golang type code for TVM kArrayHandle.
var KArrayHandle = int32(C.kArrayHandle)
// KNodeHandle is golang type code for TVM kNodeHandle.
var KNodeHandle = int32(C.kNodeHandle)
// KObjectHandle is golang type code for TVM kObjectHandle.
var KObjectHandle = int32(C.kObjectHandle)
// KModuleHandle is gonag type code for TVM kModuleHandle.
var KModuleHandle = int32(C.kModuleHandle)
// KFuncHandle is gonalg type code for TVM kFuncHandle.
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/api_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class EnvFunc : public NodeRef {
explicit EnvFunc(NodePtr<Node> n) : NodeRef(n) {}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const {
return static_cast<EnvFuncNode*>(node_.get());
return static_cast<const EnvFuncNode*>(get());
}
/*!
* \brief Invoke the function.
Expand Down Expand Up @@ -124,19 +124,19 @@ class TypedEnvFunc<R(Args...)> : public NodeRef {
/*! \brief short hand for this function type */
using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {}
explicit TypedEnvFunc(NodePtr<Node> n) : NodeRef(n) {}
explicit TypedEnvFunc(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief Assign global function to a TypedEnvFunc
* \param other Another global function.
* \return reference to self.
*/
TSelf& operator=(const EnvFunc& other) {
this->node_ = other.node_;
ObjectRef::operator=(other);
return *this;
}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const {
return static_cast<EnvFuncNode*>(node_.get());
return static_cast<const EnvFuncNode*>(get());
}
/*!
* \brief Invoke the function.
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ class IntSet : public NodeRef {
/*! \brief constructor */
IntSet() {}
// constructor from not container.
explicit IntSet(NodePtr<Node> n) : NodeRef(n) {}
explicit IntSet(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down Expand Up @@ -692,7 +692,7 @@ Array<Expr> DetectClipBound(const Expr& e,

// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
return static_cast<const IntSetNode*>(get());
}
} // namespace arith
} // namespace tvm
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class AttrsEqual {
return lhs == rhs;
}
// node comparator
TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const;
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;

protected:
friend class AttrsEqualHandler;
Expand Down Expand Up @@ -203,7 +203,7 @@ class AttrsHash {
(static_cast<int>(value.bits()) << 8) |
(static_cast<int>(value.lanes()) << 16));
}
TVM_DLL size_t operator()(const NodeRef& value) const;
TVM_DLL size_t operator()(const ObjectRef& value) const;

private:
friend class AttrsHashHandler;
Expand Down Expand Up @@ -260,7 +260,7 @@ class BaseAttrsNode : public Node {
* \return The comparison result.
*/
TVM_DLL virtual bool ContentEqual(
const Node* other, AttrsEqual equal) const = 0;
const Object* other, AttrsEqual equal) const = 0;
/*!
* \brief Content aware hash.
* \param hasher The hasher to run the hash.
Expand Down Expand Up @@ -290,7 +290,7 @@ class Attrs : public NodeRef {
private:
/*! \return the internal attribute node */
const BaseAttrsNode* ptr() const {
return static_cast<const BaseAttrsNode*>(node_.get());
return static_cast<const BaseAttrsNode*>(get());
}
};

Expand All @@ -315,7 +315,7 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Node* other, AttrsEqual equal) const final;
bool ContentEqual(const Object* other, AttrsEqual equal) const final;
size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
Expand Down Expand Up @@ -369,7 +369,7 @@ class AttrsEqualVisitor {
public:
bool result_{true};
// constructor
AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal)
AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal)
: lhs_(lhs), rhs_(rhs), equal_(equal) {
}
template<typename T>
Expand All @@ -387,8 +387,8 @@ class AttrsEqualVisitor {
}

private:
const Node* lhs_;
const Node* rhs_;
const Object* lhs_;
const Object* rhs_;
const AttrsEqual& equal_;
};

Expand Down Expand Up @@ -488,7 +488,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->type_key();
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
}
}
}
Expand Down Expand Up @@ -521,7 +521,7 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
} else if (const ir::UIntImm* op = expr.as<ir::UIntImm>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->type_key();
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
}
}
}
Expand Down Expand Up @@ -827,7 +827,7 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_;
}

bool ContentEqual(const Node* other, AttrsEqual equal) const final {
bool ContentEqual(const Object* other, AttrsEqual equal) const final {
DerivedType* pself = self();
if (pself == other) return true;
if (other == nullptr) return false;
Expand All @@ -839,7 +839,7 @@ class AttrsNode : public BaseAttrsNode {

size_t ContentHash(AttrsHash hasher) const final {
::tvm::detail::AttrsHashVisitor visitor(hasher);
visitor.result_ = std::hash<std::string>()(this->type_key());
visitor.result_ = this->GetTypeKeyHash();
self()->__VisitAttrs__(visitor);
return visitor.result_;
}
Expand Down
16 changes: 9 additions & 7 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ using ::tvm::AttrVisitor;
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
return static_cast<const NodeName*>(data_.get()); \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get()

} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;
Expand All @@ -75,12 +76,12 @@ using ::tvm::AttrVisitor;
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(node_ != nullptr); \
if (!node_.unique()) { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
NodePtr<Node>(std::move(n)).swap(node_); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(node_.get()); \
return static_cast<NodeName*>(data_.get()); \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_mutable()

}

/*! \brief Macro to make it easy to define node ref type given node */
Expand Down Expand Up @@ -160,7 +161,7 @@ std::string SaveJSON(const NodeRef& node);
*
* \return The shared_ptr of the Node.
*/
NodePtr<Node> LoadJSON_(std::string json_str);
ObjectPtr<Object> LoadJSON_(std::string json_str);

/*!
* \brief Load the node from json string.
Expand Down Expand Up @@ -233,6 +234,7 @@ struct NodeFactoryReg {
* \note This is necessary to enable serialization of the Node.
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
.set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ enum BufferType : int {
class Buffer : public NodeRef {
public:
Buffer() {}
explicit Buffer(NodePtr<Node> n) : NodeRef(n) {}
explicit Buffer(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief Return a new buffer that is equivalent with current one
* but always add stride field.
Expand Down Expand Up @@ -171,7 +171,7 @@ class BufferNode : public Node {
};

inline const BufferNode* Buffer::operator->() const {
return static_cast<const BufferNode*>(node_.get());
return static_cast<const BufferNode*>(get());
}

/*!
Expand Down
16 changes: 8 additions & 8 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class TargetNode : public Node {
class Target : public NodeRef {
public:
Target() {}
explicit Target(NodePtr<Node> n) : NodeRef(n) {}
explicit Target(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
Expand All @@ -110,7 +110,7 @@ class Target : public NodeRef {
TVM_DLL static tvm::Target Current(bool allow_not_defined = true);

const TargetNode* operator->() const {
return static_cast<const TargetNode*>(node_.get());
return static_cast<const TargetNode*>(get());
}

using ContainerType = TargetNode;
Expand Down Expand Up @@ -256,12 +256,12 @@ class BuildConfigNode : public Node {
class BuildConfig : public ::tvm::NodeRef {
public:
BuildConfig() {}
explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}
explicit BuildConfig(ObjectPtr<Object> n) : NodeRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get());
return static_cast<const BuildConfigNode*>(get());
}
BuildConfigNode* operator->() {
return static_cast<BuildConfigNode*>(node_.get());
return static_cast<BuildConfigNode*>(get_mutable());
}
/*!
* \brief Construct a BuildConfig containing a empty build config node.
Expand Down Expand Up @@ -371,7 +371,7 @@ class GenericFuncNode;
class GenericFunc : public NodeRef {
public:
GenericFunc() {}
explicit GenericFunc(NodePtr<Node> n) : NodeRef(n) {}
explicit GenericFunc(ObjectPtr<Object> n) : NodeRef(n) {}

/*!
* \brief Set the default function implementaiton.
Expand Down Expand Up @@ -478,10 +478,10 @@ class GenericFuncNode : public Node {
};

inline GenericFuncNode* GenericFunc::operator->() {
return static_cast<GenericFuncNode*>(node_.get());
return static_cast<GenericFuncNode*>(get_mutable());
}

#define TVM_GENERIC_FUNC_REG_VAR_DEF \
#define TVM_GENERIC_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM

/*!
Expand Down
98 changes: 0 additions & 98 deletions include/tvm/c_dsl_api.h

This file was deleted.

4 changes: 2 additions & 2 deletions include/tvm/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Channel : public NodeRef {
public:
/*! \brief default constructor */
Channel() {}
explicit Channel(NodePtr<Node> n) : NodeRef(n) {}
explicit Channel(ObjectPtr<Object> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down Expand Up @@ -67,7 +67,7 @@ struct ChannelNode : public Node {

// Inline implementations
inline const ChannelNode* Channel::operator->() const {
return static_cast<const ChannelNode*>(node_.get());
return static_cast<const ChannelNode*>(get());
}
} // namespace tvm
#endif // TVM_CHANNEL_H_
Loading