diff --git a/include/tvm/api_registry.h b/include/tvm/api_registry.h index e12d841519ca7..dbd0972935932 100644 --- a/include/tvm/api_registry.h +++ b/include/tvm/api_registry.h @@ -79,7 +79,7 @@ class EnvFunc : public NodeRef { explicit EnvFunc(NodePtr n) : NodeRef(n) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief Invoke the function. @@ -124,19 +124,19 @@ class TypedEnvFunc : public NodeRef { /*! \brief short hand for this function type */ using TSelf = TypedEnvFunc; TypedEnvFunc() {} - explicit TypedEnvFunc(NodePtr n) : NodeRef(n) {} + explicit TypedEnvFunc(ObjectPtr n) : NodeRef(n) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. * \return reference to self. */ TSelf& operator=(const EnvFunc& other) { - this->node_ = other.node_; + ObjectRef::operator=(other); return *this; } /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief Invoke the function. diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 8be1c36048133..e81fa0afd2546 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -362,7 +362,7 @@ class IntSet : public NodeRef { /*! \brief constructor */ IntSet() {} // constructor from not container. - explicit IntSet(NodePtr n) : NodeRef(n) {} + explicit IntSet(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -692,7 +692,7 @@ Array DetectClipBound(const Expr& e, // implementation inline const IntSetNode* IntSet::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace arith } // namespace tvm diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 3b64d1f961e26..fb8927a756132 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -163,7 +163,7 @@ class AttrsEqual { return lhs == rhs; } // node comparator - TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const; + TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; protected: friend class AttrsEqualHandler; @@ -203,7 +203,7 @@ class AttrsHash { (static_cast(value.bits()) << 8) | (static_cast(value.lanes()) << 16)); } - TVM_DLL size_t operator()(const NodeRef& value) const; + TVM_DLL size_t operator()(const ObjectRef& value) const; private: friend class AttrsHashHandler; @@ -260,7 +260,7 @@ class BaseAttrsNode : public Node { * \return The comparison result. */ TVM_DLL virtual bool ContentEqual( - const Node* other, AttrsEqual equal) const = 0; + const Object* other, AttrsEqual equal) const = 0; /*! * \brief Content aware hash. * \param hasher The hasher to run the hash. @@ -290,7 +290,7 @@ class Attrs : public NodeRef { private: /*! \return the internal attribute node */ const BaseAttrsNode* ptr() const { - return static_cast(node_.get()); + return static_cast(get()); } }; @@ -315,7 +315,7 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; - bool ContentEqual(const Node* other, AttrsEqual equal) const final; + bool ContentEqual(const Object* other, AttrsEqual equal) const final; size_t ContentHash(AttrsHash hasher) const final; // type info static constexpr const char* _type_key = "DictAttrs"; @@ -369,7 +369,7 @@ class AttrsEqualVisitor { public: bool result_{true}; // constructor - AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal) + AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal) : lhs_(lhs), rhs_(rhs), equal_(equal) { } template @@ -387,8 +387,8 @@ class AttrsEqualVisitor { } private: - const Node* lhs_; - const Node* rhs_; + const Object* lhs_; + const Object* rhs_; const AttrsEqual& equal_; }; @@ -488,7 +488,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { } else if (const ir::UIntImm* op = expr.as()) { *ptr = static_cast(op->value); } else { - LOG(FATAL) << "Expect int value, but get " << expr->type_key(); + LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); } } } @@ -521,7 +521,7 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { } else if (const ir::UIntImm* op = expr.as()) { *ptr = static_cast(op->value); } else { - LOG(FATAL) << "Expect float value, but get " << expr->type_key(); + LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } } } @@ -827,7 +827,7 @@ class AttrsNode : public BaseAttrsNode { return visitor.fields_; } - bool ContentEqual(const Node* other, AttrsEqual equal) const final { + bool ContentEqual(const Object* other, AttrsEqual equal) const final { DerivedType* pself = self(); if (pself == other) return true; if (other == nullptr) return false; @@ -839,7 +839,7 @@ class AttrsNode : public BaseAttrsNode { size_t ContentHash(AttrsHash hasher) const final { ::tvm::detail::AttrsHashVisitor visitor(hasher); - visitor.result_ = std::hash()(this->type_key()); + visitor.result_ = this->GetTypeKeyHash(); self()->__VisitAttrs__(visitor); return visitor.result_; } diff --git a/include/tvm/base.h b/include/tvm/base.h index f358f7f5d447c..a42de10abef22 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -47,9 +47,10 @@ using ::tvm::AttrVisitor; */ #define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \ + explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ + : BaseTypeName(n) {} \ const NodeName* operator->() const { \ - return static_cast(node_.get()); \ + return static_cast(data_.get()); \ } \ operator bool() const { return this->defined(); } \ using ContainerType = NodeName; @@ -75,12 +76,12 @@ using ::tvm::AttrVisitor; */ #define TVM_DEFINE_NODE_REF_COW(NodeName) \ NodeName* CopyOnWrite() { \ - CHECK(node_ != nullptr); \ - if (!node_.unique()) { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ NodePtr n = make_node(*(operator->())); \ - NodePtr(std::move(n)).swap(node_); \ + ObjectPtr(std::move(n)).swap(data_); \ } \ - return static_cast(node_.get()); \ + return static_cast(data_.get()); \ } /*! \brief Macro to make it easy to define node ref type given node */ @@ -160,7 +161,7 @@ std::string SaveJSON(const NodeRef& node); * * \return The shared_ptr of the Node. */ -NodePtr LoadJSON_(std::string json_str); +ObjectPtr LoadJSON_(std::string json_str); /*! * \brief Load the node from json string. @@ -233,6 +234,7 @@ struct NodeFactoryReg { * \note This is necessary to enable serialization of the Node. */ #define TVM_REGISTER_NODE_TYPE(TypeName) \ + TVM_REGISTER_OBJECT_TYPE(TypeName); \ static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \ .set_creator([](const std::string&) { return ::tvm::make_node(); }) diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 1233e9b0b89b8..f18ed9206db36 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -51,7 +51,7 @@ enum BufferType : int { class Buffer : public NodeRef { public: Buffer() {} - explicit Buffer(NodePtr n) : NodeRef(n) {} + explicit Buffer(ObjectPtr n) : NodeRef(n) {} /*! * \brief Return a new buffer that is equivalent with current one * but always add stride field. @@ -171,7 +171,7 @@ class BufferNode : public Node { }; inline const BufferNode* Buffer::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 1d57d82e66c6d..c985fbe175460 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -93,7 +93,7 @@ class TargetNode : public Node { class Target : public NodeRef { public: Target() {} - explicit Target(NodePtr n) : NodeRef(n) {} + explicit Target(ObjectPtr n) : NodeRef(n) {} /*! * \brief Create a Target given a string * \param target_str the string to parse @@ -110,7 +110,7 @@ class Target : public NodeRef { TVM_DLL static tvm::Target Current(bool allow_not_defined = true); const TargetNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } using ContainerType = TargetNode; @@ -256,12 +256,12 @@ class BuildConfigNode : public Node { class BuildConfig : public ::tvm::NodeRef { public: BuildConfig() {} - explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {} + explicit BuildConfig(ObjectPtr n) : NodeRef(n) {} const BuildConfigNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } BuildConfigNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } /*! * \brief Construct a BuildConfig containing a empty build config node. @@ -371,7 +371,7 @@ class GenericFuncNode; class GenericFunc : public NodeRef { public: GenericFunc() {} - explicit GenericFunc(NodePtr n) : NodeRef(n) {} + explicit GenericFunc(ObjectPtr n) : NodeRef(n) {} /*! * \brief Set the default function implementaiton. @@ -478,10 +478,10 @@ class GenericFuncNode : public Node { }; inline GenericFuncNode* GenericFunc::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } -#define TVM_GENERIC_FUNC_REG_VAR_DEF \ +#define TVM_GENERIC_FUNC_REG_VAR_DEF \ static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM /*! diff --git a/include/tvm/c_dsl_api.h b/include/tvm/c_dsl_api.h deleted file mode 100644 index bbbb84926e8ec..0000000000000 --- a/include/tvm/c_dsl_api.h +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/c_dsl_api.h - * - * \brief TVM DSL Node C API, used to interact to DSL compilation. - * - * These are only a few functions needed for DSL construction time. - * These function are only available when link libtvm. - * If only TVM runtime is linked, calling these function will trigger error. - * - * \note Most API functions are registerd as PackedFunc and - * can be grabbed via TVMFuncGetGlobal - */ -#ifndef TVM_C_DSL_API_H_ -#define TVM_C_DSL_API_H_ - -#include "runtime/c_runtime_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/*! \brief handle to node */ -typedef void* NodeHandle; - -/*! - * \brief free the node handle - * \param handle The node handle to be freed. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeFree(NodeHandle handle); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_index the corresponding type index. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeTypeKey2Index(const char* type_key, - int* out_index); - -/*! - * \brief Get runtime type index of the node. - * \param handle the node handle. - * \param out_index the corresponding type index. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeGetTypeIndex(NodeHandle handle, - int* out_index); - -/*! - * \brief get attributes given key - * \param handle The node handle - * \param key The attribute name - * \param out_value The attribute value - * \param out_type_code The type code of the attribute. - * \param out_success Whether get is successful. - * \return 0 when success, -1 when failure happens - * \note API calls always exchanges with type bits=64, lanes=1 - */ -TVM_DLL int TVMNodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* out_value, - int* out_type_code, - int* out_success); - -/*! - * \brief get attributes names in the node. - * \param handle The node handle - * \param out_size The number of functions - * \param out_array The array of function names. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeListAttrNames(NodeHandle handle, - int *out_size, - const char*** out_array); -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif -#endif // TVM_C_DSL_API_H_ diff --git a/include/tvm/channel.h b/include/tvm/channel.h index 143d4295f3e39..346291a6b06a5 100644 --- a/include/tvm/channel.h +++ b/include/tvm/channel.h @@ -35,7 +35,7 @@ class Channel : public NodeRef { public: /*! \brief default constructor */ Channel() {} - explicit Channel(NodePtr n) : NodeRef(n) {} + explicit Channel(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -67,7 +67,7 @@ struct ChannelNode : public Node { // Inline implementations inline const ChannelNode* Channel::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_CHANNEL_H_ diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index c2ae572de8183..ad3da6b347af6 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -127,7 +127,7 @@ class LayoutNode : public Node { */ class Layout : public NodeRef { public: - explicit Layout(NodePtr n) : NodeRef(n) {} + explicit Layout(ObjectPtr n) : NodeRef(n) {} /*! \brief default constructor */ Layout() = default; @@ -152,7 +152,7 @@ class Layout : public NodeRef { * \return the pointer to the internal node container */ const LayoutNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! @@ -160,7 +160,7 @@ class Layout : public NodeRef { * \return the pointer to the internal node container */ LayoutNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } /*! @@ -369,7 +369,7 @@ class BijectiveLayout : public NodeRef { }; inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 201a2b485aa6e..71d1c32911e90 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -49,7 +49,7 @@ class ExprNode : public Node { class Expr : public NodeRef { public: Expr() {} - explicit Expr(NodePtr ptr) : NodeRef(ptr) {} + explicit Expr(ObjectPtr ptr) : NodeRef(ptr) {} /*! * \brief construct from integer. * \param value The value to be constructed. @@ -122,7 +122,7 @@ class Variable : public ExprNode { /*! \brief a named variable in TVM */ class Var : public Expr { public: - explicit Var(NodePtr n) : Expr(n) {} + explicit Var(ObjectPtr n) : Expr(n) {} TVM_DLL explicit Var(std::string name_hint = "v", Type t = Int(32)); /*! @@ -145,7 +145,7 @@ class Var : public Expr { * \return the corresponding Variable. */ const Variable* get() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = Variable; @@ -187,7 +187,7 @@ class Integer : public Expr { /*! * \brief constructor from node. */ - explicit Integer(NodePtr node) : Expr(node) {} + explicit Integer(ObjectPtr node) : Expr(node) {} /*! * \brief Construct integer from int value. */ @@ -197,7 +197,7 @@ class Integer : public Expr { * \param other another expression. */ Integer& operator=(const Integer& other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -205,13 +205,13 @@ class Integer : public Expr { * \return the content of the integer. */ const IntImm* operator->() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! * \brief convert to int64_t */ operator int64_t() const { - CHECK(node_ != nullptr) + CHECK(data_ != nullptr) << " Trying to reference a null Integer"; return (*this)->value; } @@ -346,7 +346,7 @@ class IterVar : public NodeRef { // construct a new iter var without a domain IterVar() {} // construct from shared ptr. - explicit IterVar(NodePtr n) : NodeRef(n) {} + explicit IterVar(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -423,7 +423,7 @@ class IterVarNode : public Node { // inline implementations inline const IterVarNode* IterVar::operator->() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } inline IterVar::operator Expr() const { @@ -481,11 +481,11 @@ class IRPrinter { : stream(stream) {} /*! \brief The node to be printed. */ - TVM_DLL void Print(const NodeRef& node); + TVM_DLL void Print(const ObjectRef& node); /*! \brief Print indent to the stream */ TVM_DLL void PrintIndent(); // Allow registration to be printer. - using FType = IRFunctor; + using FType = IRFunctor; TVM_DLL static FType& vtable(); }; @@ -498,10 +498,7 @@ inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT namespace std { template <> -struct hash<::tvm::IterVar> { - std::size_t operator()(const ::tvm::IterVar& k) const { - return k.hash(); - } +struct hash<::tvm::IterVar> : public ::tvm::NodeHash { }; } #endif // TVM_EXPR_H_ diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 079f05f5a7f27..032410d568e30 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -664,10 +664,10 @@ class CommReducerNode : public Node { }; inline const CommReducerNode* CommReducer::get() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } inline const CommReducerNode* CommReducer::operator->() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! \brief Reduction operator operator */ @@ -1576,7 +1576,7 @@ namespace std { template <> struct hash<::tvm::ir::TensorKey> { std::size_t operator()(const ::tvm::ir::TensorKey& k) const { - size_t lhs = k.f.hash(); + size_t lhs = ::tvm::NodeHash()(k.f); size_t rhs = static_cast(k.value_index); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); return lhs; diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index a7d91eacf8516..3f3517190d3a0 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -88,15 +88,15 @@ class StmtFunctor; #define IR_EXPR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), \ std::forward(args)...); \ }); \ #define IR_STMT_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitStmt_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitStmt_(static_cast(n.get()), \ std::forward(args)...); \ }); \ @@ -104,7 +104,7 @@ template class ExprFunctor { private: using TSelf = ExprFunctor; - using FType = IRFunctor; + using FType = IRFunctor; public: /*! \brief the result type of this functor */ @@ -164,7 +164,7 @@ class ExprFunctor { virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args ...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -213,7 +213,7 @@ template class StmtFunctor { private: using TSelf = StmtFunctor; - using FType = IRFunctor; + using FType = IRFunctor; public: /*! \brief the result type of this functor */ @@ -255,7 +255,7 @@ class StmtFunctor { virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Node* op, Args ...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index b82a19d4689c4..c910a48620c82 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -65,9 +65,9 @@ class TVM_DLL IRMutator { /*! \brief destructor */ virtual ~IRMutator() {} /*! \brief functor type of expr mutation */ - using FMutateExpr = IRFunctor; + using FMutateExpr = IRFunctor; /*! \brief functor type of stmt mutation */ - using FMutateStmt = IRFunctor; + using FMutateStmt = IRFunctor; /*! \return internal vtable of expr */ static FMutateExpr& vtable_expr(); // NOLINT(*) /*! \return internal stmt of expr */ diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index f20b913685871..bebf94585ed6d 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -49,7 +49,7 @@ namespace ir { * // The use case is to count number of Variables in the ir tree. * class MyCounter : public IRVisitor { * public: - * int Count(const NodeRef& n) { + * int Count(const ObjectRef& n) { * ret_ = 0; * this->Visit(n); * return ret_; @@ -94,7 +94,7 @@ class TVM_DLL IRVisitor { /*! \brief destructor */ virtual ~IRVisitor() {} /*! \brief functor type of visitor */ - using FVisit = IRFunctor; + using FVisit = IRFunctor; /*! \return internal vtable*/ static FVisit& vtable(); // overloadable visit function. diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index 4da93b80c2ab6..e2147d036587b 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -44,7 +44,7 @@ class LoweredFuncNode; class LoweredFunc : public ir::FunctionRef { public: LoweredFunc() {} - explicit LoweredFunc(NodePtr n) : FunctionRef(n) {} + explicit LoweredFunc(ObjectPtr n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -136,17 +136,14 @@ class LoweredFuncNode : public ir::FunctionBaseNode { // Implementations of inline functions inline const LoweredFuncNode* LoweredFunc::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm namespace std { template <> -struct hash<::tvm::LoweredFunc> { - std::size_t operator()(const ::tvm::LoweredFunc& k) const { - return k.hash(); - } +struct hash<::tvm::LoweredFunc> : public tvm::NodeHash { }; } diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index c2c639e374f55..2e1a978f4806b 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -38,14 +38,14 @@ namespace tvm { class ArrayNode : public Node { public: /*! \brief the data content */ - std::vector > data; + std::vector data; void VisitAttrs(AttrVisitor* visitor) final { // Visitor to array have no effect. } static constexpr const char* _type_key = "Array"; - TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node); }; /*! \brief map node content */ @@ -54,32 +54,17 @@ class MapNode : public Node { void VisitAttrs(AttrVisitor* visitor) final { // Visitor to map have no effect. } - // hash function - struct Hash { - size_t operator()(const NodePtr& n) const { - return std::hash()(n.get()); - } - }; - // comparator - struct Equal { - bool operator()( - const NodePtr& a, - const NodePtr& b) const { - return a.get() == b.get(); - } - }; - /*! \brief The corresponding conatiner type */ using ContainerType = std::unordered_map< - NodePtr, - NodePtr, - Hash, Equal>; + ObjectRef, + ObjectRef, + ObjectHash, ObjectEqual>; /*! \brief the data content */ ContainerType data; static constexpr const char* _type_key = "Map"; - TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node); }; @@ -90,15 +75,13 @@ class StrMapNode : public Node { // Visitor to map have no effect. } /*! \brief The corresponding conatiner type */ - using ContainerType = std::unordered_map< - std::string, - NodePtr >; + using ContainerType = std::unordered_map; /*! \brief the data content */ ContainerType data; static constexpr const char* _type_key = "StrMap"; - TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node); }; /*! @@ -111,9 +94,9 @@ template::difference_type; - using value_type = typename std::iterator_traits::value_type; - using pointer = typename std::iterator_traits::pointer; - using reference = typename std::iterator_traits::reference; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) using iterator_category = typename std::iterator_traits::iterator_category; explicit IterAdapter(TIter iter) : iter_(iter) {} @@ -138,7 +121,7 @@ class IterAdapter { inline bool operator!=(IterAdapter other) const { return !(*this == other); } - inline const typename Converter::ResultType operator*() const { + inline const value_type operator*() const { return Converter::convert(*iter_); } @@ -162,26 +145,27 @@ class Array : public NodeRef { * \brief default constructor */ Array() { - node_ = make_node(); + data_ = make_node(); } /*! * \brief move constructor * \param other source */ Array(Array && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Array(const Array &other) : NodeRef(other.node_) { // NOLINT(*) + Array(const Array &other) { // NOLINT(*) + data_ = std::move(other.data_); } /*! * \brief constructor from pointer * \param n the container pointer */ - explicit Array(NodePtr n) : NodeRef(n) {} + explicit Array(ObjectPtr n) : NodeRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -214,9 +198,9 @@ class Array : public NodeRef { explicit Array(size_t n, const T& val) { auto tmp_node = make_node(); for (size_t i = 0; i < n; ++i) { - tmp_node->data.push_back(val.node_); + tmp_node->data.push_back(val); } - node_ = std::move(tmp_node); + data_ = std::move(tmp_node); } /*! * \brief move assign operator @@ -224,7 +208,7 @@ class Array : public NodeRef { * \return reference to self. */ Array& operator=(Array && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } /*! @@ -233,7 +217,7 @@ class Array : public NodeRef { * \return reference to self. */ Array& operator=(const Array & other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -246,9 +230,9 @@ class Array : public NodeRef { void assign(IterType begin, IterType end) { auto n = make_node(); for (IterType it = begin; it != end; ++it) { - n->data.push_back((*it).node_); + n->data.push_back(T(*it)); } - node_ = std::move(n); + data_ = std::move(n); } /*! * \brief Read i-th element from array. @@ -256,12 +240,13 @@ class Array : public NodeRef { * \return the i-th element. */ inline const T operator[](size_t i) const { - return T(static_cast(node_.get())->data[i]); + return DowncastNoCheck( + static_cast(data_.get())->data[i]); } /*! \return The size of the array */ inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } /*! * \brief copy on write semantics @@ -272,12 +257,12 @@ class Array : public NodeRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline ArrayNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! * \brief push a new item to the back of the list @@ -285,7 +270,7 @@ class Array : public NodeRef { */ inline void push_back(const T& item) { ArrayNode* n = this->CopyOnWrite(); - n->data.push_back(item.node_); + n->data.push_back(item); } /*! * \brief set i-th element of the array. @@ -294,7 +279,7 @@ class Array : public NodeRef { */ inline void Set(size_t i, const T& value) { ArrayNode* n = this->CopyOnWrite(); - n->data[i] = value.node_; + n->data[i] = value; } /*! \return whether array is empty */ inline bool empty() const { @@ -303,34 +288,34 @@ class Array : public NodeRef { /*! \brief specify container node */ using ContainerType = ArrayNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = T; - static inline T convert(const NodePtr& n) { - return T(n); + static inline T convert(const ObjectRef& n) { + return DowncastNoCheck(n); } }; - using iterator = IterAdapter >::const_iterator>; + using iterator = IterAdapter::const_iterator>; using reverse_iterator = IterAdapter< - Ptr2NodeRef, - std::vector >::const_reverse_iterator>; + ValueConverter, + std::vector::const_reverse_iterator>; /*! \return begin iterator */ inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); + return iterator(static_cast(data_.get())->data.begin()); } /*! \return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! \return rbegin iterator */ inline reverse_iterator rbegin() const { - return reverse_iterator(static_cast(node_.get())->data.rbegin()); + return reverse_iterator(static_cast(data_.get())->data.rbegin()); } /*! \return rend iterator */ inline reverse_iterator rend() const { - return reverse_iterator(static_cast(node_.get())->data.rend()); + return reverse_iterator(static_cast(data_.get())->data.rend()); } }; @@ -355,26 +340,26 @@ class Map : public NodeRef { * \brief default constructor */ Map() { - node_ = make_node(); + data_ = make_node(); } /*! * \brief move constructor * \param other source */ Map(Map && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*) + Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) } /*! * \brief constructor from pointer * \param n the container pointer */ - explicit Map(NodePtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : NodeRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -406,7 +391,7 @@ class Map : public NodeRef { * \return reference to self. */ Map& operator=(Map && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } /*! @@ -415,7 +400,7 @@ class Map : public NodeRef { * \return reference to self. */ Map& operator=(const Map & other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -428,10 +413,9 @@ class Map : public NodeRef { void assign(IterType begin, IterType end) { NodePtr n = make_node(); for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first.node_, - i->second.node_)); + n->data.emplace(std::make_pair(i->first, i->second)); } - node_ = std::move(n); + data_ = std::move(n); } /*! * \brief Read element from map. @@ -439,7 +423,8 @@ class Map : public NodeRef { * \return the corresonding element. */ inline const V operator[](const K& key) const { - return V(static_cast(node_.get())->data.at(key.node_)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } /*! * \brief Read element from map. @@ -447,17 +432,18 @@ class Map : public NodeRef { * \return the corresonding element. */ inline const V at(const K& key) const { - return V(static_cast(node_.get())->data.at(key.node_)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } /*! \return The size of the array */ inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } /*! \return The number of elements of the key */ inline size_t count(const K& key) const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.count(key.node_); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.count(key); } /*! * \brief copy on write semantics @@ -468,12 +454,12 @@ class Map : public NodeRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline MapNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! * \brief set the Map. @@ -482,7 +468,7 @@ class Map : public NodeRef { */ inline void Set(const K& key, const V& value) { MapNode* n = this->CopyOnWrite(); - n->data[key.node_] = value.node_; + n->data[key] = value; } /*! \return whether array is empty */ @@ -492,29 +478,31 @@ class Map : public NodeRef { /*! \brief specify container node */ using ContainerType = MapNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = std::pair; static inline ResultType convert(const std::pair< - NodePtr, - NodePtr >& n) { - return std::make_pair(K(n.first), V(n.second)); + ObjectRef, + ObjectRef>& n) { + return std::make_pair(DowncastNoCheck(n.first), + DowncastNoCheck(n.second)); } }; using iterator = IterAdapter< - Ptr2NodeRef, MapNode::ContainerType::const_iterator>; + ValueConverter, MapNode::ContainerType::const_iterator>; /*! \return begin iterator */ inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); + return iterator(static_cast(data_.get())->data.begin()); } /*! \return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! \return begin iterator */ inline iterator find(const K& key) const { - return iterator(static_cast(node_.get())->data.find(key.node_)); + return iterator( + static_cast(data_.get())->data.find(key)); } }; @@ -524,14 +512,14 @@ class Map : public NodeRef { public: // for code reuse Map() { - node_ = make_node(); + data_ = make_node(); } Map(Map && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } - Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*) + Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) } - explicit Map(NodePtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : NodeRef(n) {} template Map(IterType begin, IterType end) { assign(begin, end); @@ -545,76 +533,77 @@ class Map : public NodeRef { assign(init.begin(), init.end()); } Map& operator=(Map && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } Map& operator=(const Map & other) { - node_ = other.node_; + data_ = other.data_; return *this; } template void assign(IterType begin, IterType end) { auto n = make_node(); for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first, - i->second.node_)); + n->data.emplace(std::make_pair(i->first, i->second)); } - node_ = std::move(n); + data_ = std::move(n); } inline const V operator[](const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } inline const V at(const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } inline size_t count(const std::string& key) const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.count(key); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.count(key); } inline StrMapNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } inline void Set(const std::string& key, const V& value) { StrMapNode* n = this->CopyOnWrite(); - n->data[key] = value.node_; + n->data[key] = value; } inline bool empty() const { return size() == 0; } using ContainerType = StrMapNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = std::pair; static inline ResultType convert(const std::pair< - std::string, - NodePtr >& n) { - return std::make_pair(n.first, V(n.second)); + std::string, + ObjectRef>& n) { + return std::make_pair(n.first, DowncastNoCheck(n.second)); } }; using iterator = IterAdapter< - Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>; + ValueConverter, StrMapNode::ContainerType::const_iterator>; /*! \return begin iterator */ inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); + return iterator(static_cast(data_.get())->data.begin()); } /*! \return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! \return begin iterator */ inline iterator find(const std::string& key) const { - return iterator(static_cast(node_.get())->data.find(key)); + return iterator(static_cast(data_.get())->data.find(key)); } }; diff --git a/include/tvm/node/ir_functor.h b/include/tvm/node/ir_functor.h index 23c5a3fafdabb..e902e8fb6d44a 100644 --- a/include/tvm/node/ir_functor.h +++ b/include/tvm/node/ir_functor.h @@ -34,10 +34,10 @@ namespace tvm { /*! - * \brief A dynamically dispatched functor on NodeRef in the first argument. + * \brief A dynamically dispatched functor on ObjectRef in the first argument. * * \code - * IRFunctor tostr; + * IRFunctor tostr; * tostr.set_dispatch([](const Add* op, std::string prefix) { * return prefix + "Add"; * }); @@ -60,10 +60,10 @@ template class IRFunctor; template -class IRFunctor { +class IRFunctor { private: - using Function = std::function; - using TSelf = IRFunctor; + using Function = std::function; + using TSelf = IRFunctor; /*! \brief internal function table */ std::vector func_; @@ -75,8 +75,8 @@ class IRFunctor { * \param n The node to be dispatched * \return Whether dispatching function is registered for n's type. */ - inline bool can_dispatch(const NodeRef& n) const { - uint32_t type_index = n.type_index(); + inline bool can_dispatch(const ObjectRef& n) const { + uint32_t type_index = n->type_index(); return type_index < func_.size() && func_[type_index] != nullptr; } /*! @@ -85,12 +85,12 @@ class IRFunctor { * \param args The additional arguments * \return The result. */ - inline R operator()(const NodeRef& n, Args... args) const { - uint32_t type_index = n.type_index(); + inline R operator()(const ObjectRef& n, Args... args) const { + uint32_t type_index = n->type_index(); CHECK(type_index < func_.size() && func_[type_index] != nullptr) << "IRFunctor calls un-registered function on type " - << Node::TypeIndex2Key(type_index); + << n->GetTypeKey(); return func_[type_index](n, std::forward(args)...); } /*! @@ -101,19 +101,19 @@ class IRFunctor { */ template inline TSelf& set_dispatch(Function f) { // NOLINT(*) - uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + uint32_t tindex = TNode::RuntimeTypeIndex(); if (func_.size() <= tindex) { func_.resize(tindex + 1, nullptr); } CHECK(func_[tindex] == nullptr) - << "Dispatch for " << Node::TypeIndex2Key(tindex) + << "Dispatch for " << TNode::_type_key << " is already set"; func_[tindex] = f; return *this; } /*! * \brief set the dispacher for type TNode - * This allows f to used detailed const Node pointer to replace NodeRef + * This allows f to used detailed const Node pointer to replace ObjectRef * * \param f The function to be set. * \tparam TNode the type of Node to be dispatched. @@ -121,8 +121,8 @@ class IRFunctor { */ template inline TSelf& set_dispatch(std::function f) { // NOLINT(*) - Function fun = [f](const NodeRef& n, Args... args) { - return f(static_cast(n.node_.get()), + Function fun = [f](const ObjectRef& n, Args... args) { + return f(static_cast(n.get()), std::forward(args)...); }; return this->set_dispatch(fun); @@ -135,7 +135,7 @@ class IRFunctor { */ template inline TSelf& clear_dispatch() { // NOLINT(*) - uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + uint32_t tindex = TNode::RuntimeTypeIndex(); CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; func_[tindex] = nullptr; return *this; @@ -172,7 +172,7 @@ class IRFunctor { * f(e, this); * } * - * using FType = IRFunctor; + * using FType = IRFunctor; * // function to return global function table * static FType& vtable(); * }; @@ -232,15 +232,15 @@ template class IRFunctorStaticRegistry; template -class IRFunctorStaticRegistry { +class IRFunctorStaticRegistry { private: - IRFunctor *irf_; + IRFunctor *irf_; std::shared_ptr free_list; - using TSelf = IRFunctorStaticRegistry; + using TSelf = IRFunctorStaticRegistry; public: - IRFunctorStaticRegistry(IRFunctor *irf) { + IRFunctorStaticRegistry(IRFunctor *irf) { irf_ = irf; free_list = std::make_shared(); } @@ -261,12 +261,12 @@ class IRFunctorStaticRegistry { * the compiler to deduce the template types. */ template -IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( - IRFunctor *irf) { - return IRFunctorStaticRegistry(irf); +IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( + IRFunctor *irf) { + return IRFunctorStaticRegistry(irf); } -#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ +#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName /*! diff --git a/include/tvm/node/memory.h b/include/tvm/node/memory.h deleted file mode 100644 index 1bba57144e194..0000000000000 --- a/include/tvm/node/memory.h +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/node/memory.h - * \brief Node memory management. - */ -#ifndef TVM_NODE_MEMORY_H_ -#define TVM_NODE_MEMORY_H_ - -#include -#include "node.h" - -namespace tvm { -/*! - * \brief Allocate a node object. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The NodePtr to the allocated object. - */ -template -inline NodePtr make_node(Args&&... args); - -// Detail implementations after this -// -// The current design allows swapping the -// allocator pattern when necessary. -// -// Possible future allocator optimizations: -// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) -// - Thread-local object pools: one pool per size and alignment requirement. -// - Can specialize by type of object to give the specific allocator to each object. -// -template -class SimpleNodeAllocator { - public: - template - static T* New(Args&&... args) { - return new T(std::forward(args)...); - } - static NodeBase::FDeleter Deleter() { - return Deleter_; - } - - private: - static void Deleter_(NodeBase* ptr) { - delete static_cast(ptr); - } -}; - -template -inline NodePtr make_node(Args&&... args) { - using Allocator = SimpleNodeAllocator; - static_assert(std::is_base_of::value, - "make_node can only be used to create NodeBase"); - T* node = Allocator::New(std::forward(args)...); - node->deleter_ = Allocator::Deleter(); - return NodePtr(node); -} - -} // namespace tvm -#endif // TVM_NODE_MEMORY_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index cb18e46e9a5c7..8203ee69f686f 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -25,7 +25,9 @@ #include #include -#include +#include +#include +#include #include #include #include @@ -38,13 +40,6 @@ class DataType; class Node; class NodeRef; -namespace runtime { -// forward declaration -class NDArray; -// forward declaration -class ObjectRef; -} // namespace runtime - /*! * \brief Visitor class to each node content. * The content is going to be called for each field. @@ -74,15 +69,17 @@ class TVM_DLL AttrVisitor { //! \endcond }; +/*! \brief Reuse the type index in he runtime. */ +using TypeIndex = runtime::TypeIndex; + /*! * \brief base class of node container in DSL AST. */ -class TVM_DLL Node : public NodeBase { +class Node : public runtime::Object { public: /*! \brief virtual destructor */ virtual ~Node() {} - /*! \return The unique type key of the node */ - virtual const char* type_key() const = 0; + /*! * \brief Apply visitor to each field of the Node * Visitor could mutate the content of the node. @@ -90,272 +87,79 @@ class TVM_DLL Node : public NodeBase { * \param visitor The visitor */ virtual void VisitAttrs(AttrVisitor* visitor) {} - /*! \return the type index of the node */ - virtual uint32_t type_index() const = 0; - /*! - * \brief Whether this node derives from node with type_index=tid. - * Implemented by TVM_DECLARE_NODE_TYPE_INFO - * - * \param tid The type index. - * \return the check result. - */ - virtual bool _DerivedFrom(uint32_t tid) const; - /*! - * \brief get a runtime unique type index given a type key - * \param type_key Type key of a type. - * \return the corresponding type index. - */ - static uint32_t TypeKey2Index(const char* type_key); - /*! - * \brief get type key from type index. - * \param index The type index - * \return the corresponding type key. - */ - static const char* TypeIndex2Key(uint32_t index); - /*! - * \return whether the type is derived from - */ - template - inline bool derived_from() const; - /*! - * \return whether the node is of type T - * \tparam The type to be checked. - */ - template - inline bool is_type() const; - /*! - * \brief Get a NodePtr that holds reference to this Node. - * \return the NodePtr - */ - inline NodePtr GetNodePtr() const; - // node ref can see this - friend class NodeRef; + static constexpr const char* _type_key = "Node"; + static constexpr uint32_t _type_index = TypeIndex::kDynamic; + + TVM_DECLARE_BASE_OBJECT_INFO(Node, runtime::Object); }; -/*! \brief Base class of all node reference object */ -class NodeRef { + +/*! + * \brief Base class of all node reference object + * NodeRef is just a alias of ObjectRef. + */ +class NodeRef : public runtime::ObjectRef { public: /*! \brief type indicate the container type */ using ContainerType = Node; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator==(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool same_as(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator<(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator!=(const NodeRef& other) const; - /*! \return the hash function for NodeRef */ - inline size_t hash() const; - /*! \return whether the expression is null */ - inline bool defined() const; - /*! \return the internal type index of IRNode */ - inline uint32_t type_index() const; + /*! \return the internal node pointer */ - inline const Node* get() const; + const Node* get() const { + return static_cast(ObjectRef::get()); + } /*! \return the internal node pointer */ - inline const Node* operator->() const; - /*! - * \brief Downcast this ir node to its actual type (e.g. Add, or - * Select). This returns nullptr if the node is not of the requested - * type. Example usage: - * - * if (const Add *add = node->as()) { - * // This is an add node - * } - * \tparam T the target type, must be subtype of IRNode - */ - template - inline const T *as() const; + const Node* operator->() const { + return get(); + } /*! * \brief A more powerful version of as that also works with * intermediate base types. * \tparam T the target type, must be subtype of IRNode */ template - inline const T *as_derived() const; + const T *as_derived() const { + return as(); + } /*! \brief default constructor */ NodeRef() = default; - explicit NodeRef(NodePtr node) : node_(node) {} - /*! \brief the internal node object, do not touch */ - NodePtr node_; + explicit NodeRef(runtime::ObjectPtr ptr) : ObjectRef(ptr) {} }; -/*! - * \brief Get a reference type from a Node ptr type - * - * It is always important to get a reference type - * if we want to return a value as reference or keep - * the node alive beyond the scope of the function. - * - * \param ptr The node pointer - * \tparam RefType The reference type - * \tparam NodeType The node type - * \return The corresponding RefType - */ -template -inline RefType GetRef(const NodeType* ptr); - -/*! - * \brief Downcast a base reference type to a more specific type. - * - * \param ref The inptut reference - * \return The corresponding SubRef. - * \tparam SubRef The target specific reference type. - * \tparam BaseRef the current reference type. - */ -template -inline SubRef Downcast(BaseRef ref); - /*! * \brief helper macro to declare type information in a base node. */ -#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ - bool _DerivedFrom(uint32_t tid) const override { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } +#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ + TVM_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) /*! * \brief helper macro to declare type information in a terminal node */ -#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ - const char* type_key() const final { \ - return TypeName::_type_key; \ - } \ - uint32_t type_index() const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - return tidx; \ - } \ - bool _DerivedFrom(uint32_t tid) const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } - -// implementations of inline functions after this -template -inline bool Node::derived_from() const { - // use static field so query only happens once. - static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return this->_DerivedFrom(type_id); -} - - -template -inline bool Node::is_type() const { - // use static field so query only happens once. - static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return type_id == this->type_index(); -} +#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ + TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent); -inline NodePtr Node::GetNodePtr() const { - return NodePtr(const_cast(this)); -} +using runtime::Object; +using runtime::ObjectPtr; +using runtime::ObjectRef; +using runtime::GetRef; +using runtime::Downcast; +using runtime::make_object; +using runtime::ObjectHash; +using runtime::ObjectEqual; -template -inline RefType GetRef(const NodeType* ptr) { - static_assert(std::is_base_of::value, - "Can only cast to the ref of same container type"); - return RefType(ptr->GetNodePtr()); -} - -template -inline SubRef Downcast(BaseRef ref) { - CHECK(ref->template is_type() || - ref->template derived_from()) - << "Downcast from " << ref->type_key() << " to " - << SubRef::ContainerType::_type_key << " failed."; - return SubRef(std::move(ref.node_)); -} - -inline const Node* NodeRef::get() const { - return node_.get(); -} - -inline const Node* NodeRef::operator->() const { - return node_.get(); -} - -inline bool NodeRef::defined() const { - return node_.get() != nullptr; -} - -inline bool NodeRef::operator==(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} +using NodeHash = ObjectHash; +using NodeEqual = ObjectEqual; -inline bool NodeRef::same_as(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} - -inline bool NodeRef::operator<(const NodeRef& other) const { - return node_.get() < other.node_.get(); -} - -inline bool NodeRef::operator!=(const NodeRef& other) const { - return node_.get() != other.node_.get(); -} - -inline size_t NodeRef::hash() const { - return std::hash()(node_.get()); -} - -inline uint32_t NodeRef::type_index() const { - CHECK(node_.get() != nullptr) - << "null type"; - return get()->type_index(); -} - -template -inline const T* NodeRef::as() const { - const Node* ptr = static_cast(get()); - if (ptr && ptr->is_type()) { - return static_cast(ptr); - } - return nullptr; -} - -template -inline const T* NodeRef::as_derived() const { - const Node* ptr = static_cast(get()); - if (ptr && (ptr->is_type() || ptr->derived_from())) { - return static_cast(ptr); - } - return nullptr; +/*! + * \brief Allocate a node object. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The NodePtr to the allocated object. + */ +template +inline NodePtr make_node(Args&&... args) { + return runtime::make_object(std::forward(args)...); } - -/*! \brief The hash function for nodes */ -struct NodeHash { - size_t operator()(const NodeRef& a) const { - return a.hash(); - } -}; - -/*! \brief The equal comparator for nodes */ -struct NodeEqual { - bool operator()(const NodeRef& a, const NodeRef& b) const { - return a.get() == b.get(); - } -}; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/operation.h b/include/tvm/operation.h index b950aa952f04d..b942464d49071 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -651,7 +651,7 @@ inline Tensor compute(Array shape, // inline function. inline const OperationNode* Operation::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_OPERATION_H_ diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 5951594b873c7..48d46fdf2fc66 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -37,6 +37,7 @@ #include "runtime/packed_func.h" namespace tvm { + using runtime::TVMArgs; using runtime::TVMRetValue; using runtime::PackedFunc; @@ -47,86 +48,82 @@ namespace runtime { * \tparam T the type to be checked. */ template -struct NodeTypeChecker { - static inline bool Check(Node* sptr) { - // This is the only place in the project where RTTI is used - // It can be turned off, but will make non strict checking. - // TODO(tqchen) possibly find alternative to turn of RTTI +struct ObjectTypeChecker { + static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; - // always allow nullptr. - if (sptr == nullptr) return true; - return sptr->derived_from(); + if (ptr == nullptr) return true; + return ptr->IsInstance(); } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + static void PrintName(std::ostream& os) { // NOLINT(*) using ContainerType = typename T::ContainerType; os << ContainerType::_type_key; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - ArrayNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const ArrayNode* n = static_cast(ptr); for (const auto& p : n->data) { - if (!NodeTypeChecker::Check(p.get())) { + if (!ObjectTypeChecker::Check(p.get())) { return false; } } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "array<"; - NodeTypeChecker::PrintName(os); - os << ">"; + static void PrintName(std::ostream& os) { // NOLINT(*) + os << "List["; + ObjectTypeChecker::PrintName(os); + os << "]"; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - StrMapNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const StrMapNode* n = static_cast(ptr); for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.second.get())) return false; + if (!ObjectTypeChecker::Check(kv.second.get())) return false; } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map::PrintName(os); - os << '>'; + ObjectTypeChecker::PrintName(os); + os << ']'; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - MapNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const MapNode* n = static_cast(ptr); for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.first.get())) return false; - if (!NodeTypeChecker::Check(kv.second.get())) return false; + if (!ObjectTypeChecker::Check(kv.first.get())) return false; + if (!ObjectTypeChecker::Check(kv.second.get())) return false; } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map<"; - NodeTypeChecker::PrintName(os); + static void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "Map["; + ObjectTypeChecker::PrintName(os); os << ','; - NodeTypeChecker::PrintName(os); - os << '>'; + ObjectTypeChecker::PrintName(os); + os << ']'; } }; template -inline std::string NodeTypeName() { +inline std::string ObjectTypeName() { std::ostringstream os; - NodeTypeChecker::PrintName(os); + ObjectTypeChecker::PrintName(os); return os.str(); } @@ -138,12 +135,12 @@ inline TNodeRef TVMArgValue::AsNodeRef() const { std::is_base_of::value, "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(NodePtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return TNodeRef(sptr); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return TNodeRef(ObjectPtr(ptr)); } inline TVMArgValue::operator tvm::Expr() const { @@ -156,18 +153,20 @@ inline TVMArgValue::operator tvm::Expr() const { if (type_code_ == kDLFloat) { return Expr(static_cast(value_.v_float64)); } - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - if (sptr->is_type()) { - return IterVar(sptr)->var; + + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + + if (ptr->IsInstance()) { + return IterVar(ObjectPtr(ptr))->var; } - if (sptr->is_type()) { - return Tensor(sptr)(); + if (ptr->IsInstance()) { + return Tensor(ObjectPtr(ptr))(); } - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return Expr(sptr); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return Expr(ObjectPtr(ptr)); } inline TVMArgValue::operator tvm::Integer() const { @@ -177,68 +176,36 @@ inline TVMArgValue::operator tvm::Integer() const { CHECK_GE(value_.v_int64, std::numeric_limits::min()); return Integer(static_cast(value_.v_int64)); } - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return Integer(sptr); -} - -inline NodePtr& TVMArgValue::node_sptr() { - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - return *ptr >(); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return Integer(ObjectPtr(ptr)); } - template -inline bool TVMArgValue::IsNodeType() const { - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = - *ptr >(); - return NodeTypeChecker::Check(sptr.get()); +inline bool TVMPODValue_::IsObjectRef() const { + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + return ObjectTypeChecker::Check(ptr); } // extensions for TVMRetValue -inline TVMRetValue& TVMRetValue::operator=( - const NodePtr& other) { - if (other.get() == nullptr) { - SwitchToPOD(kNull); - } else { - SwitchToClass >(kNodeHandle, other); - } - return *this; -} - -inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { - if (!other.defined()) { - SwitchToPOD(kNull); - } else { - SwitchToClass >(kNodeHandle, other.node_); - } - return *this; -} - template inline TNodeRef TVMRetValue::AsNodeRef() const { static_assert( std::is_base_of::value, "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(); - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return TNodeRef(sptr); -} + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); -inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*) - if (other.defined()) { - values_[i].v_handle = const_cast*>(&(other.node_)); - type_codes_[i] = kNodeHandle; - } else { - type_codes_[i] = kNull; - } + Object* ptr = static_cast(value_.v_handle); + + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return TNodeRef(ObjectPtr(ptr)); } // type related stuffs diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 4329c438e8a0f..e54d88d5a393f 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -52,7 +52,7 @@ class PatternNode : public RelayNode { class Pattern : public NodeRef { public: Pattern() {} - explicit Pattern(NodePtr p) : NodeRef(p) {} + explicit Pattern(ObjectPtr p) : NodeRef(p) {} using ContainerType = PatternNode; }; diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f94ba5e26068e..15330b00e9619 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -83,10 +83,12 @@ using NodeEqual = ::tvm::NodeEqual; #define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ class TypeName : public NodeRefBase { \ public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRefBase(n) {} \ + TypeName() {} \ + explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ + : NodeRefBase(n) { \ + } \ const NodeName* operator->() const { \ - return static_cast(node_.get()); \ + return static_cast(get()); \ } \ operator bool() { return this->defined(); } \ using ContainerType = NodeName; \ @@ -127,7 +129,7 @@ class SourceName : public NodeRef { * \return the pointer to the internal node container */ inline const SourceNameNode* operator->() const { - return static_cast(this->node_.get()); + return static_cast(get()); } /*! diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index b1b8d6a7154e2..281b99297e780 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -541,10 +541,11 @@ RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr); // implementataions inline const Type& ExprNode::checked_type() const { - CHECK(checked_type_.defined()) << "internal error: the type checker has " - "not populated the checked_type " - "field for " - << GetRef(this); + CHECK(checked_type_.defined()) + << "internal error: the type checker has " + << "not populated the checked_type " + << "field for " + << GetRef(this); return this->checked_type_; } @@ -557,7 +558,7 @@ inline const TTypeNode* ExprNode::type_as() const { const TTypeNode* node = checked_type_.as(); CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key - << ", but get " << checked_type_->type_key(); + << ", but get " << checked_type_->GetTypeKey(); return node; } diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e0d940c5d1a59..8bc87a27f66fe 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -57,8 +57,8 @@ class ExprFunctor; #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), \ std::forward(args)...); \ }); @@ -66,7 +66,7 @@ template class ExprFunctor { private: using TSelf = ExprFunctor; - using FType = tvm::IRFunctor; + using FType = tvm::IRFunctor; public: /*! \brief the result type of this functor */ @@ -117,7 +117,7 @@ class ExprFunctor { virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index d05099f781acd..a0422fa7f4462 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -78,9 +78,9 @@ class ValueNode : public RelayNode { class Value : public NodeRef { public: Value() {} - explicit Value(NodePtr n) : NodeRef(n) {} + explicit Value(ObjectPtr n) : NodeRef(n) {} const ValueNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } using ContainerType = ValueNode; diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 8b17020a1132d..10d72349d0f51 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -281,10 +281,10 @@ class ModuleNode : public RelayNode { struct Module : public NodeRef { Module() {} - explicit Module(NodePtr p) : NodeRef(p) {} + explicit Module(ObjectPtr<::tvm::Object> p) : NodeRef(p) {} - inline ModuleNode* operator->() const { - return static_cast(node_.get()); + ModuleNode* operator->() const { + return static_cast(get_mutable()); } using ContainerType = ModuleNode; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 0a6d3725655f3..572c194bc2693 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -138,7 +138,7 @@ class Op : public relay::Expr { /*! \brief default constructor */ Op() {} /*! \brief constructor from node pointer */ - explicit Op(NodePtr n) : Expr(n) {} + explicit Op(ObjectPtr n) : Expr(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -221,11 +221,12 @@ class OpRegistry { const Attrs&, const TypeReporter&)> type_rel_func); /*! - * \brief Set the type key of attributes. - * \param type_key The type of of the attrs field. + * \brief Set the the attrs type key and index to be AttrsType. + * \tparam AttrsType the attribute type to b set. * \return reference to self. */ - inline OpRegistry& set_attrs_type_key(const std::string& type_key); + template + inline OpRegistry& set_attrs_type(); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. @@ -397,7 +398,7 @@ class OpMap { // implementations inline const OpNode* Op::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } template @@ -496,10 +497,10 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) return *this; } -inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) - const std::string& type_key) { - get()->attrs_type_key = type_key; - get()->attrs_type_index = Node::TypeKey2Index(type_key.c_str()); +template +inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) + get()->attrs_type_key = AttrsType::_type_key; + get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); return *this; } diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 7f1c47e03592c..c15523cb25de4 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -57,8 +57,8 @@ class PatternFunctor; #define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitPattern_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast(n.get()), \ std::forward(args)...); \ }); @@ -66,7 +66,7 @@ template class PatternFunctor { private: using TSelf = PatternFunctor; - using FType = tvm::IRFunctor; + using FType = tvm::IRFunctor; public: /*! \brief the result type of this functor */ @@ -103,7 +103,7 @@ class PatternFunctor { virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPatternDefault_(const Node* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index a2119c90f750c..08ea3075cb835 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -134,16 +134,16 @@ class PassContext : public NodeRef { * \return const access pointer. */ const PassContextNode* operator->() const { - CHECK(node_.get() != nullptr); - return static_cast(node_.get()); + CHECK(get() != nullptr); + return static_cast(get()); } /*! * \brief mutable accessor. * \return mutable access pointer. */ PassContextNode* operator->() { - CHECK(node_.get() != nullptr); - return static_cast(node_.get()); + CHECK(get() != nullptr); + return static_cast(get_mutable()); } /*! * \brief Construct a PassContext containing the default configurations. diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 16e36785c5338..a5cc3c83383e1 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -58,7 +58,7 @@ class TypeNode : public RelayNode { class Type : public NodeRef { public: Type() {} - explicit Type(NodePtr p) : NodeRef(p) {} + explicit Type(ObjectPtr p) : NodeRef(p) {} using ContainerType = TypeNode; }; @@ -430,10 +430,11 @@ class TypeReporterNode : public Node { class TypeReporter : public NodeRef { public: TypeReporter() {} - explicit TypeReporter(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) { + explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) { } TypeReporterNode* operator->() const { - return static_cast(node_.get()); + return const_cast( + static_cast(get())); } using ContainerType = TypeReporterNode; }; diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index b058fd63a2f5f..267504beb11ae 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -98,13 +98,12 @@ typedef enum { kTVMType = 5U, kTVMContext = 6U, kArrayHandle = 7U, - kNodeHandle = 8U, + kObjectHandle = 8U, kModuleHandle = 9U, kFuncHandle = 10U, kStr = 11U, kBytes = 12U, kNDArrayContainer = 13U, - kObjectHandle = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 6b4f01e4ac9b1..01c08d324fcbc 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -69,7 +69,7 @@ class ObjAllocatorBase { "make_node can only be used to create NodeBase"); T* ptr = Handler::New(static_cast(this), std::forward(args)...); - ptr->type_index_ = T::type_index(); + ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); } diff --git a/include/tvm/runtime/node_base.h b/include/tvm/runtime/node_base.h deleted file mode 100644 index 8b47c18a09a7a..0000000000000 --- a/include/tvm/runtime/node_base.h +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/node_base.h - * \brief Base data structure for Node. - * - * \note Node is not a runtime feature. - * This file only exposes the signature of NodePtr for PackedFunc. - */ -#ifndef TVM_RUNTIME_NODE_BASE_H_ -#define TVM_RUNTIME_NODE_BASE_H_ - -#include -#include - -namespace tvm { - -// forward declarations -template -class NodePtr; -class Node; -class NodeRef; - -/*! - * \brief Base class of Node for runtime destructor purposes. - * - * Node is a reference counted object which is used to construct AST. - * Each node is backed by a custom deleter, which deletes the object. - * Do not call create raw Node pointer, always use tvm::make_node. - * - * \note In most cases, please inheritate tvm::Node. - * \sa Node, NodePtr, make_node - */ -class NodeBase { - public: - /*! - * \brief type of NodeBase deleter - * \param self pointer to the NodeBase. - */ - typedef void (*FDeleter)(NodeBase* self); - - protected: - // default constructor and copy constructor - NodeBase() {} - // override the copy and assign constructors to do nothing. - // This is to make sure only contents, but not deleter and ref_counter - // are copied when a child class copies itself. - NodeBase(const NodeBase& other) { // NOLINT(*) - } - NodeBase(NodeBase&& other) { // NOLINT(*) - } - NodeBase& operator=(const NodeBase& other) { //NOLINT(*) - return *this; - } - NodeBase& operator=(NodeBase&& other) { //NOLINT(*) - return *this; - } - - private: - /*! \brief Internal reference counter */ - std::atomic ref_counter_{0}; - /*! - * \brief deleter of this object to enable customized allocation. - * If the deleter is nullptr, no deletion will be performed. - * The creator of the Node must always set the deleter field properly. - */ - FDeleter deleter_ = nullptr; - // reference counting functions - void IncRef() { - ref_counter_.fetch_add(1, std::memory_order_relaxed); - } - void DecRef() { - if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - if (this->deleter_ != nullptr) { - (*this->deleter_)(this); - } - } - } - int use_count() const { - return ref_counter_.load(std::memory_order_relaxed); - } - // friend declaration - template - friend class NodePtr; - template - friend NodePtr make_node(Args&&...); -}; - -/*! - * \brief Smart pointer for Node containers, - * must be subclass of NodeBase - * \tparam T the content data type. - */ -template -class NodePtr { - public: - /*! \brief default constructor */ - NodePtr() {} - /*! \brief default constructor */ - NodePtr(std::nullptr_t) {} // NOLINT(*) - /*! - * \brief copy constructor - * \param other The value to be moved - */ - NodePtr(const NodePtr& other) // NOLINT(*) - : NodePtr(other.data_) { - } - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - NodePtr(const NodePtr& other) // NOLINT(*) - : NodePtr(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class NodePtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - NodePtr(NodePtr&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - NodePtr(NodePtr&& other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class NodePtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~NodePtr() { - this->reset(); - } - /*! - * \brief Swap this array with another NDArray - * \param other The other NDArray - */ - void swap(NodePtr& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \return Get the content of the pointer - */ - T* get() const { - return static_cast(data_); - } - /*! - * \return The pointer - */ - T* operator->() const { - return get(); - } - /*! - * \return The reference - */ - T& operator*() const { // NOLINT(*) - return *get(); - } - /*! - * \brief copy assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NodePtr& operator=(const NodePtr& other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - NodePtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NodePtr& operator=(NodePtr&& other) { // NOLINT(*) - // copy-and-swap idiom - NodePtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } - } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_ != nullptr ? data_->use_count() : 0; - } - /*! \return whether the reference is unique */ - bool unique() const { - return data_ != nullptr && data_->use_count() == 1; - } - /*! \return Whether two NodePtr do not equals each other */ - bool operator==(const NodePtr& other) const { - return data_ == other.data_; - } - /*! \return Whether two NodePtr equals each other */ - bool operator!=(const NodePtr& other) const { - return data_ != other.data_; - } - /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { - return data_ == nullptr; - } - /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return data_ != nullptr; - } - - private: - /*! \brief internal pointer field */ - NodeBase* data_{nullptr}; - /*! - * \brief constructor from NodeBase - * \param data The node base pointer - */ - explicit NodePtr(NodeBase* data) - : data_(data) { - if (data != nullptr) { - data_->IncRef(); - } - } - // friend declaration - friend class Node; - template - friend class NodePtr; - template - friend NodePtr make_node(Args&&...); -}; -} // namespace tvm - -#endif // TVM_RUNTIME_NODE_BASE_H_ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 0693b1f47b3c7..c3cb6f3f2fc56 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -65,7 +65,7 @@ enum TypeIndex { * - _type_index: * Static type index of the object, if assigned to TypeIndex::kDynamic * the type index will be assigned during runtime. - * Runtime type index can be accessed by ObjectType::type_index(); + * Runtime type index can be accessed by ObjectType::TypeIndex(); * - _type_key: * The unique string identifier of tyep type. * - _type_final: @@ -147,10 +147,23 @@ class Object { * \param self pointer to the Object. */ typedef void (*FDeleter)(Object* self); - /*! \return The internal type index of the object. */ + /*! \return The internal runtime type index of the object. */ uint32_t type_index() const { return type_index_; } + /*! + * \return the type key of the object. + * \note this operation is expensive, can be used for error reporting. + */ + std::string GetTypeKey() const { + return TypeIndex2Key(type_index_); + } + /*! + * \return A hash value of the return of GetTypeKey. + */ + size_t GetTypeKeyHash() const { + return TypeIndex2KeyHash(type_index_); + } /*! * Check if the object is an instance of TargetType. * \tparam TargetType The target type to be checked. @@ -159,6 +172,25 @@ class Object { template inline bool IsInstance() const; + /*! + * \brief Get the type key of the corresponding index from runtime. + * \param tindex The type index. + * \return the result. + */ + TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); + /*! + * \brief Get the type key hash of the corresponding index from runtime. + * \param tindex The type index. + * \return the related key-hash. + */ + TVM_DLL static size_t TypeIndex2KeyHash(uint32_t tindex); + /*! + * \brief Get the type index of the corresponding key from runtime. + * \param key The type key. + * \return the result. + */ + TVM_DLL static uint32_t TypeKey2Index(const char* key); + #if TVM_OBJECT_ATOMIC_REF_COUNTER using RefCounterType = std::atomic; #else @@ -170,9 +202,30 @@ class Object { static constexpr bool _type_final = false; static constexpr uint32_t _type_child_slots = 0; static constexpr bool _type_child_slots_can_overflow = true; - static const uint32_t _GetOrAllocRuntimeTypeIndex() { + static uint32_t _GetOrAllocRuntimeTypeIndex() { return 0; } + static uint32_t RuntimeTypeIndex() { + return 0; + } + + // Default constructor and copy constructor + Object() {} + // Override the copy and assign constructors to do nothing. + // This is to make sure only contents, but not deleter and ref_counter + // are copied when a child class copies itself. + // This will enable us to use make_object(*obj_ptr) + // to copy an existing object. + Object(const Object& other) { // NOLINT(*) + } + Object(Object&& other) { // NOLINT(*) + } + Object& operator=(const Object& other) { //NOLINT(*) + return *this; + } + Object& operator=(Object&& other) { //NOLINT(*) + return *this; + } protected: // The fields of the base object cell. @@ -215,18 +268,6 @@ class Object { uint32_t type_child_slots, bool type_child_slots_can_overflow); - /*! - * \brief Get the type key of the corresponding index from runtime. - * \param tindex The type index. - */ - TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); - - /*! - * \brief Get the type index of the corresponding key from runtime. - * \param key The type key. - */ - TVM_DLL static uint32_t TypeKey2Index(const char* key); - private: // reference counter related operations /*! \brief developer function, increases reference counter. */ @@ -256,6 +297,32 @@ class Object { friend class TVMObjectCAPI; }; +/*! + * \brief Get a reference type from a raw object ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam ObjectType The node type + * \return The corresponding RefType + */ +template +inline RefType GetRef(const ObjectType* ptr); + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The inptut reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template +inline SubRef Downcast(BaseRef ref); + /*! * \brief A custom smart pointer for Object. * \tparam T the content data type. @@ -389,7 +456,7 @@ class ObjectPtr { /*! \brief internal pointer field */ Object* data_{nullptr}; /*! - * \brief constructor from NodeBase + * \brief constructor from Object * \param data The data pointer */ explicit ObjectPtr(Object* data) : data_(data) { @@ -400,6 +467,7 @@ class ObjectPtr { // friend classes friend class Object; friend class ObjectRef; + friend struct ObjectHash; template friend class ObjectPtr; template @@ -407,6 +475,9 @@ class ObjectPtr { friend class TVMPODValue_; friend class TVMArgsSetter; friend class TVMRetValue; + friend class TVMArgValue; + template + friend RefType GetRef(const ObjType* ptr); }; /*! \brief Base class of all object reference */ @@ -416,10 +487,54 @@ class ObjectRef { ObjectRef() = default; /*! \brief Constructor from existing object ptr */ explicit ObjectRef(ObjectPtr data) : data_(data) {} + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool same_as(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator==(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + bool operator!=(const ObjectRef& other) const { + return data_ != other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref by address. + * \return the compare result. + */ + bool operator<(const ObjectRef& other) const { + return data_.get() < other.data_.get(); + } + /*! \return whether the expression is null */ + bool defined() const { + return data_ != nullptr; + } /*! \return the internal object pointer */ - inline const Object* get() const; + const Object* get() const { + return data_.get(); + } /*! \return the internal node pointer */ - inline const Object* operator->() const; + const Object* operator->() const { + return get(); + } + /*! \return whether the reference is unique */ + bool unique() const { + return data_.unique(); + } /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. @@ -434,25 +549,81 @@ class ObjectRef { template inline const ObjectType* as() const; - /*! \brief type indicate the container type */ + /*! \brief type indicate the container type. */ using ContainerType = Object; protected: /*! \brief Internal pointer that backs the reference. */ ObjectPtr data_; + /*! \return return a mutable internal ptr, can be used by sub-classes. */ + Object* get_mutable() const { + return data_.get(); + } + /*! + * \brief Internal helper function downcast a ref without check. + * \note Only used for internal dev purpoes. + * \tparam T The target reference type. + * \return The casted result. + */ + template + static T DowncastNoCheck(ObjectRef ref) { + return T(std::move(ref.data_)); + } + /*! + * \brief Internal helper function get data_ as ObjectPtr of ObjectType. + * \note only used for internal dev purpsoes. + * \tparam ObjectType The corresponding object type. + * \return the corresponding type. + */ + template + static ObjectPtr GetDataPtr(const ObjectRef& ref) { + return ObjectPtr(ref.data_.data_); + } // friend classes. + friend struct ObjectHash; friend class TVMRetValue; friend class TVMArgsSetter; + template + friend SubRef Downcast(BaseRef ref); }; + +/*! \brief ObjectRef hash functor */ +struct ObjectHash { + size_t operator()(const ObjectRef& a) const { + return operator()(a.data_); + } + + template + size_t operator()(const ObjectPtr& a) const { + return std::hash()(a.get()); + } +}; + + +/*! \brief ObjectRef equal functor */ +struct ObjectEqual { + bool operator()(const ObjectRef& a, const ObjectRef& b) const { + return a.same_as(b); + } + + template + size_t operator()(const ObjectPtr& a, const ObjectPtr& b) const { + return a == b; + } +}; + + /*! * \brief helper macro to declare a base object type that can be inheritated. * \param TypeName The name of the current type. * \param ParentType The name of the ParentType */ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static const uint32_t type_index() { \ - if (_type_index != TypeIndex::kDynamic) return _type_index; \ + static const uint32_t RuntimeTypeIndex() { \ + if (_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ + return _type_index; \ + } \ return _GetOrAllocRuntimeTypeIndex(); \ } \ static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ @@ -551,11 +722,11 @@ inline bool Object::IsInstance() const { if (TargetType::_type_final) { // if the target type is a final type // then we only need to check the equivalence. - return self->type_index_ == TargetType::type_index(); + return self->type_index_ == TargetType::RuntimeTypeIndex(); } else { // if target type is a non-leaf type // Check if type index falls into the range of reserved slots. - uint32_t begin = TargetType::type_index(); + uint32_t begin = TargetType::RuntimeTypeIndex(); // The condition will be optimized by constant-folding. if (TargetType::_type_child_slots != 0) { uint32_t end = begin + TargetType::_type_child_slots; @@ -565,22 +736,15 @@ inline bool Object::IsInstance() const { } if (!TargetType::_type_child_slots_can_overflow) return false; // Invariance: parent index is always smaller than the child. - if (self->type_index_ < TargetType::type_index()) return false; + if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false; // The rare slower-path, check type hierachy. - return self->DerivedFrom(TargetType::type_index()); + return self->DerivedFrom(TargetType::RuntimeTypeIndex()); } } else { return false; } } -inline const Object* ObjectRef::get() const { - return data_.data_; -} - -inline const Object* ObjectRef::operator->() const { - return get(); -} template inline const ObjectType* ObjectRef::as() const { @@ -591,7 +755,27 @@ inline const ObjectType* ObjectRef::as() const { return nullptr; } } + +template +inline RefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return RefType(ObjectPtr(const_cast(static_cast(ptr)))); +} + +template +inline SubRef Downcast(BaseRef ref) { + CHECK(ref->template IsInstance()) + << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + return SubRef(std::move(ref.data_)); +} + } // namespace runtime + +template +using NodePtr = runtime::ObjectPtr; + } // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 2bfa3323e4f1f..649a5058a9a5e 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -40,7 +40,6 @@ #include "module.h" #include "ndarray.h" #include "object.h" -#include "node_base.h" // Whether use TVM runtime in header only mode. #ifndef TVM_RUNTIME_HEADER_ONLY @@ -52,6 +51,8 @@ namespace tvm { class Integer; class DataType; class Expr; +class Node; +class NodeRef; namespace runtime { @@ -490,9 +491,12 @@ class TVMPODValue_ { return NDArray(static_cast(value_.v_handle)); } operator ObjectRef() const { - if (type_code_ == kNull) return ObjectRef(ObjectPtr(nullptr)); + if (type_code_ == kNull) { + return ObjectRef(ObjectPtr(nullptr)); + } TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - return ObjectRef(ObjectPtr(static_cast(value_.v_handle))); + return ObjectRef( + ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); @@ -512,9 +516,14 @@ class TVMPODValue_ { CHECK_LT(type_code_, kExtEnd); return static_cast(value_.v_handle)[0]; } + template::value>::type> + inline bool IsObjectRef() const; int type_code() const { return type_code_; } + /*! * \brief return handle as specific pointer type. * \tparam T the data type. @@ -567,6 +576,7 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::IsObjectRef; // conversion operator. operator std::string() const { @@ -616,15 +626,9 @@ class TVMArgValue : public TVMPODValue_ { typename = typename std::enable_if< std::is_class::value>::type> inline operator T() const; - template::value>::type> - inline bool IsNodeType() const; inline operator tvm::DataType() const; inline operator tvm::Expr() const; inline operator tvm::Integer() const; - // get internal node ptr, if it is node - inline NodePtr& node_sptr(); }; /*! @@ -663,6 +667,8 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::IsObjectRef; + TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } @@ -760,11 +766,19 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(ObjectRef other) { - this->Clear(); - type_code_ = kObjectHandle; - // move the handle out - value_.v_handle = other.data_.data_; - other.data_.data_ = nullptr; + return operator=(std::move(other.data_)); + } + template + TVMRetValue& operator=(ObjectPtr other) { + if (other.data_ != nullptr) { + this->Clear(); + type_code_ = kObjectHandle; + // move the handle out + value_.v_handle = other.data_; + other.data_ = nullptr; + } else { + SwitchToPOD(kNull); + } return *this; } TVMRetValue& operator=(PackedFunc f) { @@ -814,7 +828,7 @@ class TVMRetValue : public TVMPODValue_ { } /*! \return The value field, if the data is POD */ const TVMValue& value() const { - CHECK(type_code_ != kNodeHandle && + CHECK(type_code_ != kObjectHandle && type_code_ != kFuncHandle && type_code_ != kModuleHandle && type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; @@ -827,8 +841,6 @@ class TVMRetValue : public TVMPODValue_ { inline operator T() const; template inline TNodeRef AsNodeRef() const; - inline TVMRetValue& operator=(const NodeRef& other); - inline TVMRetValue& operator=(const NodePtr& other); // type related inline operator tvm::DataType() const; inline TVMRetValue& operator=(const tvm::DataType& other); @@ -857,11 +869,6 @@ class TVMRetValue : public TVMPODValue_ { *this = other.operator NDArray(); break; } - case kNodeHandle: { - SwitchToClass >( - kNodeHandle, *other.template ptr >()); - break; - } case kObjectHandle: { *this = other.operator ObjectRef(); break; @@ -908,7 +915,6 @@ class TVMRetValue : public TVMPODValue_ { case kStr: delete ptr(); break; case kFuncHandle: delete ptr(); break; case kModuleHandle: delete ptr(); break; - case kNodeHandle: delete ptr >(); break; case kNDArrayContainer: { static_cast(value_.v_handle)->DecRef(); break; @@ -939,7 +945,6 @@ inline const char* TypeCode2Str(int type_code) { case kBytes: return "bytes"; case kHandle: return "handle"; case kNull: return "NULL"; - case kNodeHandle: return "NodeHandle"; case kArrayHandle: return "ArrayHandle"; case kTVMType: return "TVMType"; case kTVMContext: return "TVMContext"; @@ -1057,8 +1062,6 @@ inline PackedFunc::FType PackedFunc::body() const { return body_; } - - // internal namespace namespace detail { @@ -1163,8 +1166,12 @@ class TVMArgsSetter { type_codes_[i] = kNDArrayContainer; } void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kObjectHandle; + if (value.defined()) { + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kObjectHandle; + } else { + type_codes_[i] = kNull; + } } void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) if (value.type_code() == kStr) { @@ -1181,8 +1188,6 @@ class TVMArgsSetter { typename = typename std::enable_if< extension_type_info::code != 0>::type> inline void operator()(size_t i, const T& value) const; - // NodeRef related extenstions: in tvm/packed_func_ext.h - inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) inline void operator()(size_t i, const tvm::DataType& t) const; private: diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index af3e929ac3fa2..36265667e5b63 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -56,7 +56,7 @@ enum AttachType : int { class Stage : public NodeRef { public: Stage() {} - explicit Stage(NodePtr n) : NodeRef(n) {} + explicit Stage(ObjectPtr n) : NodeRef(n) {} /*! * \brief create a new schedule for op. * \param op The operator in the schedule @@ -280,7 +280,7 @@ class Stage : public NodeRef { class Schedule : public NodeRef { public: Schedule() {} - explicit Schedule(NodePtr n) : NodeRef(n) {} + explicit Schedule(ObjectPtr n) : NodeRef(n) {} /*! * \brief Get a copy of current schedule. * \return The copied schedule. @@ -403,7 +403,7 @@ class Schedule : public NodeRef { class IterVarRelation : public NodeRef { public: IterVarRelation() {} - explicit IterVarRelation(NodePtr n) : NodeRef(n) {} + explicit IterVarRelation(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -417,7 +417,7 @@ class IterVarRelation : public NodeRef { class IterVarAttr : public NodeRef { public: IterVarAttr() {} - explicit IterVarAttr(NodePtr n) : NodeRef(n) {} + explicit IterVarAttr(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -745,25 +745,25 @@ class SingletonNode : public IterVarRelationNode { // implementations inline const StageNode* Stage::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline StageNode* Stage::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } inline const ScheduleNode* Schedule::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline ScheduleNode* Schedule::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } inline const IterVarRelationNode* IterVarRelation::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline const IterVarAttrNode* IterVarAttr::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_SCHEDULE_H_ diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index f37cc7bed7d1c..6471c9c69a627 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -50,7 +50,7 @@ class Tensor : public NodeRef { public: /*! \brief default constructor, used internally */ Tensor() {} - explicit Tensor(NodePtr n) : NodeRef(n) {} + explicit Tensor(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -141,7 +141,7 @@ class Operation : public ir::FunctionRef { public: /*! \brief default constructor */ Operation() {} - explicit Operation(NodePtr n) : FunctionRef(n) {} + explicit Operation(ObjectPtr n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -189,7 +189,7 @@ class TensorNode : public Node { // Implementations of inline functions inline const TensorNode* Tensor::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline size_t Tensor::ndim() const { @@ -250,19 +250,17 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) namespace std { template <> -struct hash<::tvm::Operation> { - std::size_t operator()(const ::tvm::Operation& k) const { - return k.hash(); - } +struct hash<::tvm::Operation> : public ::tvm::NodeHash { }; template <> struct hash<::tvm::Tensor> { std::size_t operator()(const ::tvm::Tensor& k) const { + ::tvm::NodeHash hasher; if (k.defined() && k->op.defined()) { - return k->op.hash(); + return hasher(k->op); } else{ - return k.hash(); + return hasher(k); } } }; diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index b5ca6eb4358b6..152a27f6e2a9a 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -112,7 +112,7 @@ class TensorIntrinNode : public Node { }; inline const TensorIntrinNode* TensorIntrin::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } // Internal node container of tensor intrinsic calling. @@ -170,7 +170,7 @@ class TensorIntrinCallNode : public Node { }; inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm diff --git a/nnvm/include/nnvm/compiler/util.h b/nnvm/include/nnvm/compiler/util.h index fa8b69f9b70ac..9555c0e7b3eaa 100644 --- a/nnvm/include/nnvm/compiler/util.h +++ b/nnvm/include/nnvm/compiler/util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -56,7 +56,7 @@ inline tvm::Array ShapeToArray(TShape shape) { * \return An Array of Expr, where each element is a constant int32 */ inline tvm::Array ShapeToIntArray(TShape shape) { - return tvm::Array(ShapeToArray(shape).node_); + return tvm::Downcast >(ShapeToArray(shape)); } } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index 542455969b8b4..5ce78d1d58d66 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -388,6 +388,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs") *rv = ret; }); +TVM_REGISTER_NODE_TYPE(GraphFuncNode); +TVM_REGISTER_NODE_TYPE(GraphCacheEntryNode); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const GraphFuncNode *op, IRPrinter *p) { p->stream << "GraphFunc(name=" << op->func_name diff --git a/nnvm/src/compiler/compile_engine.h b/nnvm/src/compiler/compile_engine.h index 35287f5a9358e..e8d33cb4be7ed 100644 --- a/nnvm/src/compiler/compile_engine.h +++ b/nnvm/src/compiler/compile_engine.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -92,7 +92,7 @@ class GraphCacheEntry : public ::tvm::NodeRef { GraphCacheEntry() {} explicit GraphCacheEntry(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} GraphCacheEntryNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } using ContainerType = GraphCacheEntryNode; }; diff --git a/nnvm/src/compiler/graph_runtime.h b/nnvm/src/compiler/graph_runtime.h index 3a847de83d9f8..7b324ba100adb 100644 --- a/nnvm/src/compiler/graph_runtime.h +++ b/nnvm/src/compiler/graph_runtime.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc index bbcc62a99ad8c..45f1451663e62 100644 --- a/nnvm/src/compiler/packed_func_ext.cc +++ b/nnvm/src/compiler/packed_func_ext.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -115,7 +115,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute") const Array& out_info) -> Array { TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info); - if ((*ret.ptr<::tvm::NodePtr >())->derived_from()) { + if (ret.IsObjectRef()) { return {ret.operator Tensor()}; } else { return ret; diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 5496a4c674f69..c48ae0061f9e8 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -1237,7 +1237,7 @@ Array GetIntArray(Array arr) { CHECK(!arr[i].defined() || arr[i].as()) << "Expect an int array"; } - return Array(arr.node_); + return Downcast >(arr); } NNVM_REGISTER_OP(slice_like) diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 22fb6c335dcca..2f0b5babda4dc 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # coding: utf-8 -# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement +# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import """Function configuration API.""" from __future__ import absolute_import @@ -32,9 +32,8 @@ from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 -from .node import NodeBase +from .object import ObjectBase, _set_class_node from . import object as _object -from . import node as _node FunctionHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p @@ -108,9 +107,9 @@ def _make_tvm_args(args, temp_args): values = (TVMValue * num_args)() type_codes = (ctypes.c_int * num_args)() for i, arg in enumerate(args): - if isinstance(arg, NodeBase): + if isinstance(arg, ObjectBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.NODE_HANDLE + type_codes[i] = TypeCode.OBJECT_HANDLE elif arg is None: values[i].v_handle = None type_codes[i] = TypeCode.NULL @@ -148,7 +147,7 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, (list, tuple, dict, NodeGeneric)): arg = convert_to_node(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.NODE_HANDLE + type_codes[i] = TypeCode.OBJECT_HANDLE temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): values[i].v_handle = arg.handle @@ -164,9 +163,6 @@ def _make_tvm_args(args, temp_args): values[i].v_handle = arg.handle type_codes[i] = TypeCode.FUNC_HANDLE temp_args.append(arg) - elif isinstance(arg, _CLASS_OBJECT): - values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_HANDLE else: raise TypeError("Don't know how to handle type %s" % type(arg)) return values, type_codes, num_args @@ -226,7 +222,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE) + assert ret_tcode.value == TypeCode.OBJECT_HANDLE handle = ret_val.v_handle return handle @@ -247,7 +243,6 @@ def _handle_return_func(x): return _CLASS_FUNCTION(handle, False) # setup return handle for function type -_node.__init_by_constructor__ = __init_handle_by_constructor__ _object.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module diff --git a/python/tvm/_ffi/_ctypes/node.py b/python/tvm/_ffi/_ctypes/node.py deleted file mode 100644 index 39fe0ef35525b..0000000000000 --- a/python/tvm/_ffi/_ctypes/node.py +++ /dev/null @@ -1,102 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, protected-access -# pylint: disable=no-member, missing-docstring, not-callable -from __future__ import absolute_import - -import ctypes -from ..base import _LIB, check_call, c_str -from ..node_generic import _set_class_node_base -from .types import TVMValue, TypeCode -from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func - -NodeHandle = ctypes.c_void_p -__init_by_constructor__ = None - -"""Maps node type to its constructor""" -NODE_TYPE = {} - -def _register_node(index, cls): - """register node class""" - NODE_TYPE[index] = cls - -def _return_node(x): - """Return node function""" - handle = x.v_handle - if not isinstance(handle, NodeHandle): - handle = NodeHandle(handle) - tindex = ctypes.c_int() - check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex))) - cls = NODE_TYPE.get(tindex.value, NodeBase) - # Avoid calling __init__ of cls, instead directly call __new__ - # This allows child class to implement their own __init__ - node = cls.__new__(cls) - node.handle = handle - return node - - -RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node -C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( - _return_node, TypeCode.NODE_HANDLE) - - -class NodeBase(object): - __slots__ = ["handle"] - # pylint: disable=no-member - def __del__(self): - if _LIB is not None: - check_call(_LIB.TVMNodeFree(self.handle)) - - def __getattr__(self, name): - ret_val = TVMValue() - ret_type_code = ctypes.c_int() - ret_success = ctypes.c_int() - check_call(_LIB.TVMNodeGetAttr( - self.handle, c_str(name), - ctypes.byref(ret_val), - ctypes.byref(ret_type_code), - ctypes.byref(ret_success))) - if not ret_success.value: - raise AttributeError( - "'%s' object has no attribute '%s'" % (str(type(self)), name)) - return RETURN_SWITCH[ret_type_code.value](ret_val) - - def __init_handle_by_constructor__(self, fconstructor, *args): - """Initialize the handle by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. - """ - # assign handle first to avoid error raising - self.handle = None - handle = __init_by_constructor__(fconstructor, args) - if not isinstance(handle, NodeHandle): - handle = NodeHandle(handle) - self.handle = handle - -_set_class_node_base(NodeBase) diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 5ddceb166677b..c3ae56822198d 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -21,6 +21,7 @@ import ctypes from ..base import _LIB, check_call from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func +from ..node_generic import _set_class_node_base ObjectHandle = ctypes.c_void_p @@ -29,6 +30,13 @@ """Maps object type to its constructor""" OBJECT_TYPE = {} +_CLASS_NODE = None + +def _set_class_node(node_class): + global _CLASS_NODE + _CLASS_NODE = node_class + + def _register_object(index, cls): """register object class""" OBJECT_TYPE[index] = cls @@ -40,7 +48,7 @@ def _return_object(x): handle = ObjectHandle(handle) tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) - cls = OBJECT_TYPE.get(tindex.value, ObjectBase) + cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) @@ -83,3 +91,6 @@ def __init_handle_by_constructor__(self, fconstructor, *args): if not isinstance(handle, ObjectHandle): handle = ObjectHandle(handle) self.handle = handle + + +_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 76fa96376b476..3c24c98b0505f 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -31,13 +31,12 @@ cdef enum TVMTypeCode: kTVMType = 5 kTVMContext = 6 kArrayHandle = 7 - kNodeHandle = 8 + kObjectHandle = 8 kModuleHandle = 9 kFuncHandle = 10 kStr = 11 kBytes = 12 kNDArrayContainer = 13 - kObjectHandle = 14 kExtBegin = 15 cdef extern from "tvm/runtime/c_runtime_api.h": @@ -134,18 +133,6 @@ cdef extern from "tvm/runtime/c_runtime_api.h": int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index) -cdef extern from "tvm/c_dsl_api.h": - int TVMNodeFree(NodeHandle handle) - int TVMNodeTypeKey2Index(const char* type_key, - int* out_index) - int TVMNodeGetTypeIndex(NodeHandle handle, - int* out_index) - int TVMNodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* out_value, - int* out_type_code, - int* out_success) - cdef inline py_str(const char* x): if PY_MAJOR_VERSION < 3: return x diff --git a/python/tvm/_ffi/_cython/core.pyx b/python/tvm/_ffi/_cython/core.pyx index a9349338fc6a6..cbf9d5859046a 100644 --- a/python/tvm/_ffi/_cython/core.pyx +++ b/python/tvm/_ffi/_cython/core.pyx @@ -17,7 +17,7 @@ include "./base.pxi" include "./object.pxi" -include "./node.pxi" +# include "./node.pxi" include "./function.pxi" include "./ndarray.pxi" diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index ceacf74071704..a2360427b6c7a 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -41,10 +41,9 @@ cdef int tvm_callback(TVMValue* args, for i in range(num_args): value = args[i] tcode = type_codes[i] - if (tcode == kNodeHandle or + if (tcode == kObjectHandle or tcode == kFuncHandle or tcode == kModuleHandle or - tcode == kObjectHandle or tcode > kExtBegin): CALL(TVMCbArgToReturn(&value, tcode)) @@ -98,9 +97,9 @@ cdef inline int make_arg(object arg, list temp_args) except -1: """Pack arguments into c args tvm call accept""" cdef unsigned long long ptr - if isinstance(arg, NodeBase): - value[0].v_handle = (arg).chandle - tcode[0] = kNodeHandle + if isinstance(arg, ObjectBase): + value[0].v_handle = (arg).chandle + tcode[0] = kObjectHandle elif isinstance(arg, NDArrayBase): value[0].v_handle = (arg).chandle tcode[0] = (kNDArrayContainer if @@ -152,12 +151,9 @@ cdef inline int make_arg(object arg, temp_args.append(tstr) elif isinstance(arg, (list, tuple, dict, NodeGeneric)): arg = convert_to_node(arg) - value[0].v_handle = (arg).chandle - tcode[0] = kNodeHandle - temp_args.append(arg) - elif isinstance(arg, _CLASS_OBJECT): value[0].v_handle = (arg).chandle tcode[0] = kObjectHandle + temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): value[0].v_handle = c_handle(arg.handle) tcode[0] = kModuleHandle @@ -188,9 +184,7 @@ cdef inline bytearray make_ret_bytes(void* chandle): cdef inline object make_ret(TVMValue value, int tcode): """convert result to return value.""" - if tcode == kNodeHandle: - return make_ret_node(value.v_handle) - elif tcode == kObjectHandle: + if tcode == kObjectHandle: return make_ret_object(value.v_handle) elif tcode == kNull: return None @@ -314,6 +308,7 @@ cdef class FunctionBase: _CLASS_FUNCTION = None _CLASS_MODULE = None _CLASS_OBJECT = None +_CLASS_NODE = None def _set_class_module(module_class): """Initialize the module.""" @@ -327,3 +322,7 @@ def _set_class_function(func_class): def _set_class_object(obj_class): global _CLASS_OBJECT _CLASS_OBJECT = obj_class + +def _set_class_node(node_class): + global _CLASS_NODE + _CLASS_NODE = node_class diff --git a/python/tvm/_ffi/_cython/node.pxi b/python/tvm/_ffi/_cython/node.pxi deleted file mode 100644 index 5e0c366e5600e..0000000000000 --- a/python/tvm/_ffi/_cython/node.pxi +++ /dev/null @@ -1,110 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from ... import _api_internal -from ..base import string_types -from ..node_generic import _set_class_node_base - -"""Maps node type to its constructor""" -NODE_TYPE = [] - -def _register_node(int index, object cls): - """register node class""" - while len(NODE_TYPE) <= index: - NODE_TYPE.append(None) - NODE_TYPE[index] = cls - - -cdef inline object make_ret_node(void* chandle): - global NODE_TYPE - cdef int tindex - cdef list node_type - cdef object cls - node_type = NODE_TYPE - CALL(TVMNodeGetTypeIndex(chandle, &tindex)) - if tindex < len(node_type): - cls = node_type[tindex] - if cls is not None: - obj = cls.__new__(cls) - else: - obj = NodeBase.__new__(NodeBase) - else: - obj = NodeBase.__new__(NodeBase) - (obj).chandle = chandle - return obj - - -cdef class NodeBase: - cdef void* chandle - - cdef _set_handle(self, handle): - cdef unsigned long long ptr - if handle is None: - self.chandle = NULL - else: - ptr = handle.value - self.chandle = (ptr) - - property handle: - def __get__(self): - if self.chandle == NULL: - return None - else: - return ctypes_handle(self.chandle) - - def __set__(self, value): - self._set_handle(value) - - def __dealloc__(self): - CALL(TVMNodeFree(self.chandle)) - - def __getattr__(self, name): - cdef TVMValue ret_val - cdef int ret_type_code, ret_succ - CALL(TVMNodeGetAttr(self.chandle, c_str(name), - &ret_val, &ret_type_code, &ret_succ)) - if ret_succ == 0: - raise AttributeError( - "'%s' object has no attribute '%s'" % (type(self), name)) - return make_ret(ret_val, ret_type_code) - - def __init_handle_by_constructor__(self, fconstructor, *args): - """Initialize the handle by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. - """ - # avoid error raised during construction. - self.chandle = NULL - cdef void* chandle - ConstructorCall( - (fconstructor).chandle, - kNodeHandle, args, &chandle) - self.chandle = chandle - -_set_class_node_base(NodeBase) diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 90be6a9c5b741..9561eab94ea2f 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -16,6 +16,8 @@ # under the License. """Maps object type to its constructor""" +from ..node_generic import _set_class_node_base + OBJECT_TYPE = [] def _register_object(int index, object cls): @@ -27,6 +29,7 @@ def _register_object(int index, object cls): cdef inline object make_ret_object(void* chandle): global OBJECT_TYPE + global _CLASS_NODE cdef unsigned tindex cdef list object_type cdef object cls @@ -39,9 +42,11 @@ cdef inline object make_ret_object(void* chandle): if cls is not None: obj = cls.__new__(cls) else: - obj = ObjectBase.__new__(ObjectBase) + # default use node base class + # TODO(tqchen) change to object after Node unifies with Object + obj = _CLASS_NODE.__new__(_CLASS_NODE) else: - obj = ObjectBase.__new__(ObjectBase) + obj = _CLASS_NODE.__new__(_CLASS_NODE) (obj).chandle = chandle return obj @@ -94,3 +99,6 @@ cdef class ObjectBase: (fconstructor).chandle, kObjectHandle, args, &chandle) self.chandle = chandle + + +_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py index baca89d628b8e..c6c151af90539 100644 --- a/python/tvm/_ffi/node.py +++ b/python/tvm/_ffi/node.py @@ -21,21 +21,8 @@ import ctypes import sys from .. import _api_internal +from .object import Object, register_object, _set_class_node from .node_generic import NodeGeneric, convert_to_node, const -from .base import _LIB, check_call, c_str, py_str, _FFI_MODE - -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError -try: - # pylint: disable=wrong-import-position - if _FFI_MODE == "ctypes": - raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _register_node, NodeBase as _NodeBase - else: - from ._cy2.core import _register_node, NodeBase as _NodeBase -except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.node import _register_node, NodeBase as _NodeBase def _new_object(cls): @@ -43,20 +30,22 @@ def _new_object(cls): return cls.__new__(cls) -class NodeBase(_NodeBase): +class NodeBase(Object): """NodeBase is the base class of all TVM language AST object.""" def __repr__(self): return _api_internal._format_str(self) def __dir__(self): - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - check_call(_LIB.TVMNodeListAttrNames( - self.handle, ctypes.byref(size), ctypes.byref(plist))) - names = [] - for i in range(size.value): - names.append(py_str(plist[i])) - return names + fnames = _api_internal._NodeListAttrNames(self) + size = fnames(-1) + return [fnames(i) for i in range(size)] + + def __getattr__(self, name): + try: + return _api_internal._NodeGetAttr(self, name) + except AttributeError: + raise AttributeError( + "%s has no attribute %s" % (str(type(self)), name)) def __hash__(self): return _api_internal._raw_ptr(self) @@ -95,24 +84,6 @@ def same_as(self, other): return self.__hash__() == other.__hash__() -def register_node(type_key=None): - """register node type - - Parameters - ---------- - type_key : str or cls - The type key of the node - """ - node_name = type_key if isinstance(type_key, str) else type_key.__name__ - - def register(cls): - """internal register function""" - tindex = ctypes.c_int() - ret = _LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex)) - if ret == 0: - _register_node(tindex.value, cls) - return cls - - if isinstance(type_key, str): - return register - return register(type_key) +# pylint: disable=invalid-name +register_node = register_object +_set_class_node(NodeBase) diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py index be8b086a50f96..002fd27af0fd4 100644 --- a/python/tvm/_ffi/object.py +++ b/python/tvm/_ffi/object.py @@ -20,25 +20,25 @@ import sys import ctypes -from .base import _FFI_MODE, check_call, _LIB, c_str +from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError try: - # pylint: disable=wrong-import-position + # pylint: disable=wrong-import-position,unused-import if _FFI_MODE == "ctypes": raise ImportError() if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_object + from ._cy3.core import _set_class_object, _set_class_node from ._cy3.core import ObjectBase as _ObjectBase from ._cy3.core import _register_object else: - from ._cy2.core import _set_class_object + from ._cy2.core import _set_class_object, _set_class_node from ._cy2.core import ObjectBase as _ObjectBase from ._cy2.core import _register_object except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.function import _set_class_object + # pylint: disable=wrong-import-position,unused-import + from ._ctypes.function import _set_class_object, _set_class_node from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.object import _register_object @@ -75,8 +75,15 @@ def register(cls): tindex = cls._type_index else: tidx = ctypes.c_uint() - check_call(_LIB.TVMObjectTypeKey2Index( - c_str(object_name), ctypes.byref(tidx))) + if not _RUNTIME_ONLY: + check_call(_LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx))) + else: + # directly skip unknown objects during runtime. + ret = _LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx)) + if ret != 0: + return cls tindex = tidx.value _register_object(tindex, cls) return cls diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 00e19459df76e..2dbb67dfbf739 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -36,13 +36,12 @@ class TypeCode(object): TVM_TYPE = 5 TVM_CONTEXT = 6 ARRAY_HANDLE = 7 - NODE_HANDLE = 8 + OBJECT_HANDLE = 8 MODULE_HANDLE = 9 FUNC_HANDLE = 10 STR = 11 BYTES = 12 NDARRAY_CONTAINER = 13 - OBJECT_HANDLE = 14 EXT_BEGIN = 15 diff --git a/python/tvm/error.py b/python/tvm/error.py index b5a7ed2374b74..a6d4f701d2a62 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -49,6 +49,7 @@ def __init__(self, msg): register_error("ValueError", ValueError) register_error("TypeError", TypeError) +register_error("AttributeError", AttributeError) @register_error diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index b36715249f0a8..ded5d0d13bd71 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -62,6 +62,10 @@ def compile(mod, target=None, target_host=None, params=None): compiler._compile(mod, target, target_host) return vm.Executable(compiler._get_exec()) +def enabled(): + """Whether vm profiler is enabled.""" + return hasattr(_vm, "_VMCompilerProfiler") + class VMCompilerProfiler(vm.VMCompiler): """Build Relay module to run on VM runtime.""" def __init__(self): diff --git a/python/tvm/relay/debug.py b/python/tvm/relay/debug.py index ee30f25d88c16..8887a7eb3c7ca 100644 --- a/python/tvm/relay/debug.py +++ b/python/tvm/relay/debug.py @@ -17,12 +17,8 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" from __future__ import absolute_import -from .base import NodeBase, register_relay_node from ..api import register_func -@register_relay_node -class InterpreterState(NodeBase): - pass # pylint: disable=unused-argument def _debugger_init(expr, stack): diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index f31f02b1eaf4d..c57e2afaa8ebf 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -117,8 +117,7 @@ TVM_REGISTER_API("arith._CreateAnalyzer") }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - auto& sptr = args[1].node_sptr(); - if (sptr->is_type()) { + if (args[1].IsObjectRef()) { self->Bind(args[0], args[1].operator Range()); } else { self->Bind(args[0], args[1].operator Expr()); diff --git a/src/api/api_base.cc b/src/api/api_base.cc index 28ebb4d65005f..c25c35f636e66 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,7 +30,7 @@ namespace tvm { TVM_REGISTER_API("_format_str") .set_body([](TVMArgs args, TVMRetValue *ret) { - CHECK(args[0].type_code() == kNodeHandle); + CHECK(args[0].type_code() == kObjectHandle); std::ostringstream os; os << args[0].operator NodeRef(); *ret = os.str(); @@ -38,9 +38,8 @@ TVM_REGISTER_API("_format_str") TVM_REGISTER_API("_raw_ptr") .set_body([](TVMArgs args, TVMRetValue *ret) { - CHECK(args[0].type_code() == kNodeHandle); - *ret = reinterpret_cast( - args[0].node_sptr().get()); + CHECK(args[0].type_code() == kObjectHandle); + *ret = reinterpret_cast(args[0].value().v_handle); }); TVM_REGISTER_API("_save_json") diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 73e26719cf152..f2ca67e6e2f9c 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -33,7 +33,7 @@ namespace codegen { TVM_REGISTER_API("codegen._Build") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Build({args[0]}, args[1]); } else { *ret = Build(args[0], args[1]); diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index b8ee1441fe120..9312c55323025 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2016 by Contributors * Implementation of API functions related to IR build * \file api_ir.cc */ diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index aa0ce47b4a37b..f3d6c5f6ab626 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -57,25 +57,26 @@ TVM_REGISTER_API("_str") TVM_REGISTER_API("_Array") .set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector > data; + std::vector data; for (int i = 0; i < args.size(); ++i) { if (args[i].type_code() != kNull) { - data.push_back(args[i].node_sptr()); + data.push_back(args[i].operator ObjectRef()); } else { - data.push_back(NodePtr(nullptr)); + data.push_back(ObjectRef(nullptr)); } } auto node = make_node(); node->data = std::move(data); - *ret = node; + *ret = runtime::ObjectRef(node); }); TVM_REGISTER_API("_ArrayGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { int64_t i = args[1]; - auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); CHECK_LT(static_cast(i), n->data.size()) << "out of bound of array"; *ret = n->data[static_cast(i)]; @@ -83,10 +84,11 @@ TVM_REGISTER_API("_ArrayGetItem") TVM_REGISTER_API("_ArraySize") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); *ret = static_cast( - static_cast(sptr.get())->data.size()); + static_cast(ptr)->data.size()); }); TVM_REGISTER_API("_Map") @@ -98,10 +100,10 @@ TVM_REGISTER_API("_Map") for (int i = 0; i < args.num_args; i += 2) { CHECK(args[i].type_code() == kStr) << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kNodeHandle) + CHECK(args[i + 1].type_code() == kObjectHandle) << "value of the map to be NodeRef"; data.emplace(std::make_pair(args[i].operator std::string(), - args[i + 1].node_sptr())); + args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); @@ -110,12 +112,12 @@ TVM_REGISTER_API("_Map") // Container node. MapNode::ContainerType data; for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kNodeHandle) + CHECK(args[i].type_code() == kObjectHandle) << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kNodeHandle) + CHECK(args[i + 1].type_code() == kObjectHandle) << "value of map to be NodeRef"; - data.emplace(std::make_pair(args[i].node_sptr(), - args[i + 1].node_sptr())); + data.emplace(std::make_pair(args[i].operator ObjectRef(), + args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); @@ -125,31 +127,33 @@ TVM_REGISTER_API("_Map") TVM_REGISTER_API("_MapSize") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); *ret = static_cast(n->data.size()); } else { - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); *ret = static_cast(n->data.size()); } }); TVM_REGISTER_API("_MapGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK(args[0].type_code() == kNodeHandle); - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - CHECK(args[1].type_code() == kNodeHandle); - auto* n = static_cast(sptr.get()); - auto it = n->data.find(args[1].node_sptr()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + CHECK(args[1].type_code() == kObjectHandle); + auto* n = static_cast(ptr); + auto it = n->data.find(args[1].operator ObjectRef()); CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; *ret = (*it).second; } else { - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); auto it = n->data.find(args[1].operator std::string()); CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; @@ -159,16 +163,17 @@ TVM_REGISTER_API("_MapGetItem") TVM_REGISTER_API("_MapCount") .set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK(args[0].type_code() == kNodeHandle); - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - auto* n = static_cast(sptr.get()); - CHECK(args[1].type_code() == kNodeHandle); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); + CHECK_EQ(args[0].type_code(), kObjectHandle); *ret = static_cast( - n->data.count(args[1].node_sptr())); + n->data.count(args[1].operator ObjectRef())); } else { - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); *ret = static_cast( n->data.count(args[1].operator std::string())); } @@ -176,9 +181,11 @@ TVM_REGISTER_API("_MapCount") TVM_REGISTER_API("_MapItems") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); auto rkvs = make_node(); for (const auto& kv : n->data) { rkvs->data.push_back(kv.first); @@ -186,10 +193,10 @@ TVM_REGISTER_API("_MapItems") } *ret = rkvs; } else { - auto* n = static_cast(sptr.get()); + auto* n = static_cast(ptr); auto rkvs = make_node(); for (const auto& kv : n->data) { - rkvs->data.push_back(ir::StringImm::make(kv.first).node_); + rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(kv.second); } *ret = rkvs; @@ -426,7 +433,7 @@ TVM_REGISTER_API("_ScheduleCacheRead") TVM_REGISTER_API("_ScheduleCacheWrite") .set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[1].IsNodeType()) { + if (args[1].IsObjectRef()) { *ret = args[0].operator Schedule() .cache_write(args[1].operator Tensor(), args[2]); } else { diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index d2352496c2b44..dd0415afd9eb9 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -35,7 +35,7 @@ namespace ir { TVM_REGISTER_API("ir_pass.Simplify") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { if (args.size() > 1) { *ret = Simplify(args[0].operator Stmt(), args[1]); } else { @@ -52,7 +52,7 @@ TVM_REGISTER_API("ir_pass.Simplify") TVM_REGISTER_API("ir_pass.CanonicalSimplify") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { if (args.size() > 1) { *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]); } else { @@ -69,7 +69,7 @@ TVM_REGISTER_API("ir_pass.CanonicalSimplify") TVM_REGISTER_API("ir_pass.Substitute") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Substitute(args[0].operator Stmt(), args[1].operator Map()); } else { *ret = Substitute(args[0].operator Expr(), args[1].operator Map()); @@ -78,7 +78,7 @@ TVM_REGISTER_API("ir_pass.Substitute") TVM_REGISTER_API("ir_pass.Equal") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); } else { *ret = Equal(args[0].operator Expr(), args[1].operator Expr()); diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 177360bf2ebbf..cf0e0f3c6b7a2 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * Implementation of API functions related to schedule pass. * \file api_schedule.cc */ diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc index 89e999f73edb2..64805c9e8aa00 100644 --- a/src/api/dsl_api.cc +++ b/src/api/dsl_api.cc @@ -18,36 +18,18 @@ */ /*! - * Copyright (c) 2016 by Contributors * Implementation of DSL API * \file dsl_api.cc */ -#include #include -#include #include #include +#include #include #include -#include -#include "../runtime/dsl_api.h" namespace tvm { namespace runtime { -/*! \brief entry to to easily hold returning information */ -struct TVMAPIThreadLocalEntry { - /*! \brief result holder for returning strings */ - std::vector ret_vec_str; - /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; - /*! \brief result holder for retruning string */ - std::string ret_str; -}; - -/*! \brief Thread local store that can be used to hold return values. */ -typedef dmlc::ThreadLocalStore TVMAPIThreadLocalStore; - -using TVMAPINode = NodePtr; struct APIAttrGetter : public AttrVisitor { std::string skey; @@ -138,93 +120,71 @@ struct APIAttrDir : public AttrVisitor { } }; -class DSLAPIImpl : public DSLAPI { - public: - void NodeFree(NodeHandle handle) const final { - delete static_cast(handle); - } - void NodeTypeKey2Index(const char* type_key, - int* out_index) const final { - *out_index = static_cast(Node::TypeKey2Index(type_key)); - } - void NodeGetTypeIndex(NodeHandle handle, - int* out_index) const final { - *out_index = static_cast( - (*static_cast(handle))->type_index()); - } - void NodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* ret_val, - int* ret_type_code, - int* ret_success) const final { - TVMRetValue rv; +struct NodeAPI { + static void GetAttr(TVMArgs args, TVMRetValue* ret) { + NodeRef ref = args[0]; + Node* tnode = const_cast(ref.get()); APIAttrGetter getter; - TVMAPINode* tnode = static_cast(handle); - getter.skey = key; - getter.ret = &rv; + getter.skey = args[1].operator std::string(); + getter.ret = ret; + + bool success; if (getter.skey == "type_key") { - ret_val->v_str = (*tnode)->type_key(); - *ret_type_code = kStr; - *ret_success = 1; - return; - } else if (!(*tnode)->is_type()) { - (*tnode)->VisitAttrs(&getter); - *ret_success = getter.found_ref_object || rv.type_code() != kNull; + *ret = tnode->GetTypeKey(); + success = true; + } else if (!tnode->IsInstance()) { + tnode->VisitAttrs(&getter); + success = getter.found_ref_object || ret->type_code() != kNull; } else { // specially handle dict attr - DictAttrsNode* dnode = static_cast(tnode->get()); - auto it = dnode->dict.find(key); + DictAttrsNode* dnode = static_cast(tnode); + auto it = dnode->dict.find(getter.skey); if (it != dnode->dict.end()) { - *ret_success = 1; - rv = (*it).second; + success = true; + *ret = (*it).second; } else { - *ret_success = 0; + success = false; } } - if (*ret_success) { - if (rv.type_code() == kStr || - rv.type_code() == kTVMType) { - TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get(); - e->ret_str = rv.operator std::string(); - *ret_type_code = kStr; - ret_val->v_str = e->ret_str.c_str(); - } else { - rv.MoveToCHost(ret_val, ret_type_code); - } + if (!success) { + LOG(FATAL) << "AttributeError: " << tnode->GetTypeKey() + << " object has no attributed " << getter.skey; } } - void NodeListAttrNames(NodeHandle handle, - int *out_size, - const char*** out_array) const final { - TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - ret->ret_vec_str.clear(); - TVMAPINode* tnode = static_cast(handle); + + static void ListAttrNames(TVMArgs args, TVMRetValue* ret) { + NodeRef ref = args[0]; + Node* tnode = const_cast(ref.get()); + auto names = std::make_shared >(); APIAttrDir dir; - dir.names = &(ret->ret_vec_str); + dir.names = names.get(); - if (!(*tnode)->is_type()) { - (*tnode)->VisitAttrs(&dir); + if (!tnode->IsInstance()) { + tnode->VisitAttrs(&dir); } else { // specially handle dict attr - DictAttrsNode* dnode = static_cast(tnode->get()); + DictAttrsNode* dnode = static_cast(tnode); for (const auto& kv : dnode->dict) { - ret->ret_vec_str.push_back(kv.first); + names->push_back(kv.first); } } - ret->ret_vec_charp.clear(); - for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { - ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); - } - *out_array = dmlc::BeginPtr(ret->ret_vec_charp); - *out_size = static_cast(ret->ret_vec_str.size()); + + *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) { + int64_t i = args[0]; + if (i == -1) { + *rv = static_cast(names->size()); + } else { + *rv = (*names)[i]; + } + }); } }; -TVM_REGISTER_GLOBAL("dsl_api.singleton") -.set_body([](TVMArgs args, TVMRetValue* rv) { - static DSLAPIImpl impl; - void* ptr = &impl; - *rv = ptr; - }); +TVM_REGISTER_GLOBAL("_NodeGetAttr") +.set_body(NodeAPI::GetAttr); + +TVM_REGISTER_GLOBAL("_NodeListAttrNames") +.set_body(NodeAPI::ListAttrNames); + } // namespace runtime } // namespace tvm diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index acd964935c256..98e25742592d7 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -36,9 +36,7 @@ Analyzer::Analyzer() int_set(this) { } -void Analyzer::Bind(const VarExpr& v, const Expr& expr) { - Var var(v.node_); - +void Analyzer::Bind(const VarExpr& var, const Expr& expr) { Expr new_expr = expr; new_expr = this->canonical_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr); @@ -49,9 +47,8 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) { this->canonical_simplify.Update(var, new_expr); } -void Analyzer::Bind(const VarExpr& v, const Range& range) { +void Analyzer::Bind(const VarExpr& var, const Range& range) { CHECK(range.defined()); - Var var(v.node_); if (is_one(range->extent)) { this->Bind(var, range->min); } else { diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index d80e4969d5c2e..02e8079c9c7b5 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -629,7 +629,7 @@ Mutate_(const Mul* op, const Expr& self) { } if (const auto* bconst = b.as()) { if (a.as()) { - SumExpr ret(std::move(a.node_)); + SumExpr ret = Downcast(std::move(a)); ret.CopyOnWrite()->MulToSelf(bconst->value); return std::move(ret); } else { @@ -931,7 +931,7 @@ Mutate_(const Mod* op, const Expr& self) { int64_t new_base = psum->base % cval; if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) { - SumExpr sum_expr(std::move(a.node_)); + SumExpr sum_expr = Downcast(a); sum_expr.CopyOnWrite()->base = new_base; return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv); } @@ -992,7 +992,7 @@ Mutate_(const FloorMod* op, const Expr& self) { // Simplify the offset constant if necessary. // floormod(x - 5, 3) => floormod(x + 1, 3) int64_t new_base = floormod(psum->base, cval); - SumExpr sum_expr(std::move(a.node_)); + SumExpr sum_expr = Downcast(std::move(a)); sum_expr.CopyOnWrite()->base = new_base; return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv); } else { diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index d5c012d302dcb..168486ee0018a 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -39,7 +39,7 @@ ConstIntBound::ConstIntBound( auto node = make_node(); node->min_value = min_value; node->max_value = max_value; - node_ = std::move(node); + data_ = std::move(node); } inline void PrintBoundValue(std::ostream& os, int64_t val) { diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 3c5f12a7379e4..7da020efc42ad 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -176,7 +176,7 @@ bool DetectClipBound( if (const Variable* v = n.as()) { if (bmap->count(v)) { if (flag == 0) { - var = Var(n.node_); + var = Downcast(n); flag = 1; } else if (flag == 1) { if (!var.same_as(n)) { diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 0e24714daf1f8..313b34ded034d 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -40,7 +40,7 @@ IntervalSet::IntervalSet(Expr min_value, Expr max_value) { auto node = make_node(); node->min_value = std::move(min_value); node->max_value = std::move(max_value); - node_ = std::move(node); + data_ = std::move(node); } IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) { @@ -506,7 +506,7 @@ class IntervalSetEvaluator : } IntervalSet VisitExprDefault_(const Node* op) final { - DLOG(WARNING) << "cannot evaluate set type " << op->type_key(); + DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); return IntervalSet::Everything(); } diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc index 04e166ae52c0e..cda9d585ace1b 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.cc +++ b/src/arithmetic/ir_mutator_with_analyzer.cc @@ -87,7 +87,7 @@ Stmt IRMutatorWithAnalyzer:: Mutate_(const AttrStmt* op, const Stmt& s) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - IterVar iv(op->node.node_); + IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); diff --git a/src/arithmetic/ir_visitor_with_analyzer.h b/src/arithmetic/ir_visitor_with_analyzer.h index 71eea50e4c726..918f2e89501fe 100644 --- a/src/arithmetic/ir_visitor_with_analyzer.h +++ b/src/arithmetic/ir_visitor_with_analyzer.h @@ -47,7 +47,7 @@ class IRVisitorWithAnalyzer final : public IRVisitor { void Visit_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - IterVar iv(op->node.node_); + IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 08454dd0ef5ac..9e363e7cf99a4 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -41,7 +41,7 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) { node->coeff = coeff; node->base = base; // finish construction. - node_ = std::move(node); + data_ = std::move(node); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 3f1c32243a232..66340e9c90219 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -34,6 +34,7 @@ namespace tvm { TVM_REGISTER_NODE_TYPE(TargetNode); +TVM_REGISTER_NODE_TYPE(GenericFuncNode); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const TargetNode *op, IRPrinter *p) { @@ -51,9 +52,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) */ Target CreateTarget(const std::string& target_name, const std::vector& options) { - auto target = Target(make_node()); - auto t = static_cast(target.node_.get()); - + auto t = make_node(); t->target_name = target_name; std::string libs_flag = "-libs="; @@ -137,7 +136,7 @@ Target CreateTarget(const std::string& target_name, return target::stackvm(); } - return target; + return Target(t); } TVM_REGISTER_API("_TargetCreate") @@ -674,7 +673,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); struct GenericFunc::Manager { - std::unordered_map > fmap; + std::unordered_map fmap; // mutex std::mutex mutex; @@ -694,10 +693,11 @@ GenericFunc GenericFunc::Get(const std::string& name) { if (it == m->fmap.end()) { auto f = make_node(); f->name_ = name; - m->fmap[name] = f; - return GenericFunc(f); + auto gf = GenericFunc(f); + m->fmap[name] = gf; + return gf; } else { - return GenericFunc(it->second); + return it->second; } } @@ -707,12 +707,12 @@ void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) auto it = m->fmap.find(name); CHECK(it == m->fmap.end()) << "GenericFunc already registered " << name; func->name_ = name; - m->fmap[name] = func.node_; + m->fmap[name] = func; } GenericFunc& GenericFunc::set_default(const PackedFunc value, - bool allow_override) { - auto node = static_cast(node_.get()); + bool allow_override) { + auto node = static_cast(operator->()); if (!allow_override) { CHECK(node->generic_func_ == nullptr) << "Generic function already registered for " << node->name_; @@ -736,7 +736,7 @@ GenericFunc& GenericFunc::register_func(const std::vector& tags, } void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { - auto node = static_cast(node_.get()); + auto node = static_cast(get()); auto target = Target::Current(true); PackedFunc func; diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index ecf62ab0cfac4..ab203f2aa28a0 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -806,7 +806,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) { void CodeGenC::VisitStmt_(const AttrStmt* op) { if (op->attr_key == ir::attr::thread_extent) { - IterVar iv(op->node.node_); + IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_idmap_.count(iv->var.get())) { BindThreadIndex(iv); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index d009290bb2fe5..de54e242ff401 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -1173,7 +1173,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) { void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent) { - IterVar iv(op->node.node_); + IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv); diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 54616adc214ef..778b6b1a7811a 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -300,7 +300,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmt* op) { PrintStmt(op->body); indent_ -= tab_; } else if (op->attr_key == ir::attr::realize_scope) { - auto v = FunctionRef(op->node.node_); + auto v = Downcast(op->node); alloc_storage_scope_[v] = op->value.as()->value; PrintStmt(op->body); } else { @@ -408,7 +408,7 @@ void CodeGenHybrid::PrintIndent() { std::string CodeGenHybrid::GetVarID(const Variable *v) { if (binds_.count(v)) return binds_[v]; - auto key = std::make_pair(v->GetNodePtr().get(), 0); + auto key = std::make_pair(static_cast(v), 0); if (id_map_.count(key)) { return id_map_[key]; } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 498838fc908f0..866756996f8d9 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file codegen_hybrid.h * \brief Common utilities to generated C style code. */ diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 995dfb392e872..b9391e4895b9e 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -44,17 +44,17 @@ class AttrFunctor; #define ATTR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitAttr_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitAttr_(static_cast(n.get()), \ std::forward(args)...); \ }); \ // A functor for common attribute information. template -class AttrFunctor { +class AttrFunctor { private: - using TSelf = AttrFunctor; - using FType = tvm::IRFunctor; + using TSelf = AttrFunctor; + using FType = tvm::IRFunctor; public: /*! \brief the result type of this functor */ @@ -65,7 +65,7 @@ class AttrFunctor { * \param args Additional arguments. * \return The result of the call */ - virtual R VisitAttr(const NodeRef& n, Args... args) { + virtual R VisitAttr(const ObjectRef& n, Args... args) { static FType vtable = InitVTable(); if (vtable.can_dispatch(n)) { return vtable(n, this, std::forward(args)...); @@ -73,7 +73,7 @@ class AttrFunctor { return VisitAttrDefault_(n.get(), std::forward(args)...); } } - virtual R VisitAttrDefault_(const Node* node, Args... args) = 0; + virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -143,60 +143,60 @@ class AttrFunctor { }; class AttrsEqualHandler : - protected AttrFunctor { + protected AttrFunctor { public: /*! * \brief Check if lhs equals rhs * \param lhs The left operand. * \param rhs The right operand. */ - bool Equal(const NodeRef& lhs, const NodeRef& rhs); + bool Equal(const ObjectRef& lhs, const ObjectRef& rhs); protected: - bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final; - bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final; - bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final; + bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::IntImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::UIntImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloatImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::StringImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Add* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Sub* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Mul* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Div* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Mod* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloorDiv* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloorMod* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Min* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Max* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::GE* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::GT* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::LT* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::LE* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::EQ* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::NE* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::And* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Or* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Not* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Cast* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Call* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Select* lhs, const ObjectRef& other) final; }; class AttrsHashHandler : - protected AttrFunctor { + protected AttrFunctor { public: /*! * \brief Get hash value of node * \param node The node to be hashed. */ - size_t Hash(const NodeRef& node) { + size_t Hash(const ObjectRef& node) { if (!node.defined()) return 0; return this->VisitAttr(node); } protected: - size_t VisitAttrDefault_(const Node* lhs) final; + size_t VisitAttrDefault_(const Object* lhs) final; size_t VisitAttr_(const ir::IntImm* lhs) final; size_t VisitAttr_(const ir::UIntImm* lhs) final; size_t VisitAttr_(const ir::FloatImm* lhs) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index c5b14ac577ec4..a299e17996e08 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -40,7 +40,7 @@ void DictAttrsNode::InitByPackedArgs( for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; runtime::TVMArgValue val = args[i + 1]; - if (val.type_code() == kNodeHandle) { + if (val.type_code() == kObjectHandle) { dict.Set(key, val.operator NodeRef()); } else if (val.type_code() == kStr) { dict.Set(key, Expr(val.operator std::string())); @@ -72,14 +72,14 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); using namespace ir; // Equal handler. -bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) { +bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; return this->VisitAttr(lhs, rhs); } -bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) { - if (lhs->derived_from()) { +bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) { + if (lhs->IsInstance()) { AttrsEqual equal; equal.handler_ = this; return static_cast(lhs)->ContentEqual( @@ -88,58 +88,58 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) return lhs == other.get(); } -bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { if (rhs->data.size() != lhs->data.size()) return false; for (size_t i = 0; i < lhs->data.size(); ++i) { - if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false; + if (!Equal(lhs->data[i], rhs->data[i])) return false; } } return true; } -bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { if (rhs->data.size() != lhs->data.size()) return false; for (const auto& kv : lhs->data) { auto it = rhs->data.find(kv.first); if (it == rhs->data.end()) return false; - if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false; + if (!Equal(kv.second, it->second)) return false; } } return true; } #define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \ - bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \ + bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \ if (const auto* rhs = other.as()) { \ if (!Equal(lhs->a, rhs->a)) return false; \ if (!Equal(lhs->b, rhs->b)) return false; \ @@ -167,7 +167,7 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(NE); TVM_DEFINE_ATTRS_BINOP_EQUAL(And); TVM_DEFINE_ATTRS_BINOP_EQUAL(Or); -bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return Equal(lhs->a, rhs->a); } else { @@ -175,7 +175,7 @@ bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) { } } -bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { if (lhs->type != rhs->type) return false; return Equal(lhs->value, rhs->value); @@ -184,7 +184,7 @@ bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) { } } -bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->name == rhs->name && @@ -196,7 +196,7 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) { } } -bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const ObjectRef& other) { if (const auto* rhs = other.as