From 6069489c038a5a3d1ead125656ea4ac6e7bf8b73 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 23 Dec 2019 20:24:24 -0800 Subject: [PATCH] [REFACTOR][RUNTIME] Move NDArray to Object System. Previously NDArray has its own object reference counting mechanism. This PR migrates NDArray to the unified object protocol. The calling convention of NDArray remained intact. That means NDArray still has its own type_code and its handle is still DLTensor compatible. In order to do so, this PR added a few minimum runtime type detection in TVMArgValue and RetValue only when the corresponding type is a base type(ObjectRef) that could also refer to NDArray. This means that even if we return a base reference object ObjectRef which refers to the NDArray. The type_code will still be translated correctly as kNDArrayContainer. If we assign a non-base type(say Expr) that we know is not compatible with NDArray during compile time, no runtime type detection will be performed. This PR also adopts the object protocol for NDArray sub-classing and removed the legacy NDArray subclass protocol. Examples in apps/extension are now updated to reflect that. Making NDArray as an Object brings all the benefits of the object system. For example, we can now use the Array container to store NDArrays. --- apps/extension/python/tvm_ext/__init__.py | 15 +- apps/extension/src/tvm_ext.cc | 71 ++-- apps/extension/tests/test_ext.py | 15 +- include/tvm/node/container.h | 4 +- include/tvm/packed_func_ext.h | 98 +---- include/tvm/runtime/container.h | 1 + include/tvm/runtime/ndarray.h | 283 ++++++--------- include/tvm/runtime/object.h | 44 ++- include/tvm/runtime/packed_func.h | 342 ++++++++++-------- include/tvm/runtime/registry.h | 15 +- python/setup.py | 1 + python/tvm/_ffi/_ctypes/ndarray.py | 23 +- python/tvm/_ffi/_ctypes/object.py | 8 +- python/tvm/_ffi/_cython/base.pxi | 13 +- python/tvm/_ffi/_cython/ndarray.pxi | 46 ++- python/tvm/_ffi/_cython/object.pxi | 17 +- python/tvm/_ffi/function.py | 8 + python/tvm/_ffi/ndarray.py | 19 +- python/tvm/_ffi/node_generic.py | 12 +- python/tvm/_ffi/runtime_ctypes.py | 9 - python/tvm/ndarray.py | 3 + src/api/api_lang.cc | 18 +- src/runtime/ndarray.cc | 101 ++++-- src/runtime/rpc/rpc_module.cc | 7 +- src/runtime/rpc/rpc_session.cc | 7 +- src/runtime/vm/memory_manager.cc | 27 +- src/runtime/vm/memory_manager.h | 2 +- tests/cpp/packed_func_test.cc | 64 ++++ tests/python/unittest/test_lang_container.py | 10 + .../unittest/test_runtime_module_export.py | 4 +- tests/scripts/task_python_integration.sh | 3 +- 31 files changed, 675 insertions(+), 615 deletions(-) diff --git a/apps/extension/python/tvm_ext/__init__.py b/apps/extension/python/tvm_ext/__init__.py index 7404a717f7788..31b149eb4913f 100644 --- a/apps/extension/python/tvm_ext/__init__.py +++ b/apps/extension/python/tvm_ext/__init__.py @@ -51,26 +51,23 @@ def __getitem__(self, idx): nd_create = tvm.get_global_func("tvm_ext.nd_create") nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two") -nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info") +nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info") +@tvm.register_object("tvm_ext.NDSubClass") class NDSubClass(tvm.nd.NDArrayBase): """Example for subclassing TVM's NDArray infrastructure. By inheriting TMV's NDArray, external libraries could leverage TVM's FFI without any modification. """ - # Should be consistent with the type-trait set in the backend - _array_type_code = 1 @staticmethod - def create(addtional_info): - return nd_create(addtional_info) + def create(additional_info): + return nd_create(additional_info) @property - def addtional_info(self): - return nd_get_addtional_info(self) + def additional_info(self): + return nd_get_additional_info(self) def __add__(self, other): return nd_add_two(self, other) - -tvm.register_extension(NDSubClass, NDSubClass) diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index 788c28da18d34..a68c8e8c53249 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -29,19 +29,6 @@ #include #include -namespace tvm_ext { -class NDSubClass; -} // namespace tvm_ext - -namespace tvm { -namespace runtime { -template<> -struct array_type_info { - static const int code = 1; -}; -} // namespace tvm -} // namespace runtime - using namespace tvm; using namespace tvm::runtime; @@ -65,41 +52,45 @@ class NDSubClass : public tvm::runtime::NDArray { public: class SubContainer : public NDArray::Container { public: - SubContainer(int addtional_info) : - addtional_info_(addtional_info) { - array_type_code_ = array_type_info::code; - } - static bool Is(NDArray::Container *container) { - SubContainer *c = static_cast(container); - return c->array_type_code_ == array_type_info::code; + SubContainer(int additional_info) : + additional_info_(additional_info) { + type_index_ = SubContainer::RuntimeTypeIndex(); } - int addtional_info_{0}; + int additional_info_{0}; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "tvm_ext.NDSubClass"; + TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container); }; - NDSubClass(NDArray::Container *container) { - if (container == nullptr) { - data_ = nullptr; - return; - } - CHECK(SubContainer::Is(container)); - container->IncRef(); - data_ = container; + + static void SubContainerDeleter(Object* obj) { + auto* ptr = static_cast(obj); + delete ptr; } - ~NDSubClass() { - this->reset(); + + NDSubClass() {} + explicit NDSubClass(ObjectPtr n) : NDArray(n) {} + explicit NDSubClass(int additional_info) { + SubContainer* ptr = new SubContainer(additional_info); + ptr->SetDeleter(SubContainerDeleter); + data_ = GetObjectPtr(ptr); } + NDSubClass AddWith(const NDSubClass &other) const { - SubContainer *a = static_cast(data_); - SubContainer *b = static_cast(other.data_); + SubContainer *a = static_cast(get_mutable()); + SubContainer *b = static_cast(other.get_mutable()); CHECK(a != nullptr && b != nullptr); - return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_)); + return NDSubClass(a->additional_info_ + b->additional_info_); } int get_additional_info() const { - SubContainer *self = static_cast(data_); + SubContainer *self = static_cast(get_mutable()); CHECK(self != nullptr); - return self->addtional_info_; + return self->additional_info_; } + using ContainerType = SubContainer; }; +TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer); /*! * \brief Introduce additional extension data structures @@ -166,8 +157,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev") TVM_REGISTER_GLOBAL("tvm_ext.nd_create") .set_body([](TVMArgs args, TVMRetValue *rv) { - int addtional_info = args[0]; - *rv = NDSubClass(new NDSubClass::SubContainer(addtional_info)); + int additional_info = args[0]; + *rv = NDSubClass(additional_info); + CHECK_EQ(rv->type_code(), kNDArrayContainer); + }); TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") @@ -177,7 +170,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") *rv = a.AddWith(b); }); -TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info") +TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info") .set_body([](TVMArgs args, TVMRetValue *rv) { NDSubClass a = args[0]; *rv = a.get_additional_info(); diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py index e481e82fefb37..a5e7e0f694561 100644 --- a/apps/extension/tests/test_ext.py +++ b/apps/extension/tests/test_ext.py @@ -87,16 +87,17 @@ def check_llvm(): def test_nd_subclass(): - a = tvm_ext.NDSubClass.create(addtional_info=3) - b = tvm_ext.NDSubClass.create(addtional_info=5) + a = tvm_ext.NDSubClass.create(additional_info=3) + b = tvm_ext.NDSubClass.create(additional_info=5) + assert isinstance(a, tvm_ext.NDSubClass) c = a + b d = a + a e = b + b - assert(a.addtional_info == 3) - assert(b.addtional_info == 5) - assert(c.addtional_info == 8) - assert(d.addtional_info == 6) - assert(e.addtional_info == 10) + assert(a.additional_info == 3) + assert(b.additional_info == 5) + assert(c.additional_info == 8) + assert(d.additional_info == 6) + assert(e.additional_info == 10) if __name__ == "__main__": diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 41b47d3a679e0..1a276ae695fc2 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -23,14 +23,14 @@ #ifndef TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_ +#include + #include #include #include #include #include #include -#include "node.h" -#include "memory.h" namespace tvm { diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 93b7ac33f155b..c9f7a580621f6 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -25,7 +25,6 @@ #ifndef TVM_PACKED_FUNC_EXT_H_ #define TVM_PACKED_FUNC_EXT_H_ -#include #include #include #include @@ -43,22 +42,7 @@ using runtime::TVMRetValue; using runtime::PackedFunc; namespace runtime { -/*! - * \brief Runtime type checker for node type. - * \tparam T the type to be checked. - */ -template -struct ObjectTypeChecker { - static bool Check(const Object* ptr) { - using ContainerType = typename T::ContainerType; - if (ptr == nullptr) return true; - return ptr->IsInstance(); - } - static void PrintName(std::ostream& os) { // NOLINT(*) - using ContainerType = typename T::ContainerType; - os << ContainerType::_type_key; - } -}; + template struct ObjectTypeChecker > { @@ -73,10 +57,8 @@ struct ObjectTypeChecker > { } return true; } - static void PrintName(std::ostream& os) { // NOLINT(*) - os << "List["; - ObjectTypeChecker::PrintName(os); - os << "]"; + static std::string TypeName() { + return "List[" + ObjectTypeChecker::TypeName() + "]"; } }; @@ -91,11 +73,9 @@ struct ObjectTypeChecker > { } return true; } - static void PrintName(std::ostream& os) { // NOLINT(*) - os << "Map[str"; - os << ','; - ObjectTypeChecker::PrintName(os); - os << ']'; + static std::string TypeName() { + return "Map[str, " + + ObjectTypeChecker::TypeName()+ ']'; } }; @@ -111,39 +91,16 @@ struct ObjectTypeChecker > { } return true; } - static void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "Map["; - ObjectTypeChecker::PrintName(os); - os << ','; - ObjectTypeChecker::PrintName(os); - os << ']'; + static std::string TypeName() { + return "Map[" + + ObjectTypeChecker::TypeName() + + ", " + + ObjectTypeChecker::TypeName()+ ']'; } }; -template -inline std::string ObjectTypeName() { - std::ostringstream os; - ObjectTypeChecker::PrintName(os); - return os.str(); -} - // extensions for tvm arg value - -template -inline TObjectRef TVMArgValue::AsObjectRef() const { - static_assert( - std::is_base_of::value, - "Conversion only works for ObjectRef"); - if (type_code_ == kNull) return TObjectRef(NodePtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - Object* ptr = static_cast(value_.v_handle); - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() - << " but get " << ptr->GetTypeKey(); - return TObjectRef(ObjectPtr(ptr)); -} - -inline TVMArgValue::operator tvm::Expr() const { +inline TVMPODValue_::operator tvm::Expr() const { if (type_code_ == kNull) return Expr(); if (type_code_ == kDLInt) { CHECK_LE(value_.v_int64, std::numeric_limits::max()); @@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const { return Tensor(ObjectPtr(ptr))(); } CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() + << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); return Expr(ObjectPtr(ptr)); } -inline TVMArgValue::operator tvm::Integer() const { +inline TVMPODValue_::operator tvm::Integer() const { if (type_code_ == kNull) return Integer(); if (type_code_ == kDLInt) { CHECK_LE(value_.v_int64, std::numeric_limits::max()); @@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const { TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); Object* ptr = static_cast(value_.v_handle); CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() + << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); return Integer(ObjectPtr(ptr)); } - -template -inline bool TVMPODValue_::IsObjectRef() const { - TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - Object* ptr = static_cast(value_.v_handle); - return ObjectTypeChecker::Check(ptr); -} - -// extensions for TVMRetValue -template -inline TObjectRef TVMRetValue::AsObjectRef() const { - static_assert( - std::is_base_of::value, - "Conversion only works for ObjectRef"); - if (type_code_ == kNull) return TObjectRef(); - TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - - Object* ptr = static_cast(value_.v_handle); - - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() - << " but get " << ptr->GetTypeKey(); - return TObjectRef(ObjectPtr(ptr)); -} - } // namespace runtime } // namespace tvm #endif // TVM_PACKED_FUNC_EXT_H_ diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index dbe827812fc3d..4dc07f4a3a045 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -23,6 +23,7 @@ */ #ifndef TVM_RUNTIME_CONTAINER_H_ #define TVM_RUNTIME_CONTAINER_H_ + #include #include #include diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 9932951798423..1bc49fab6fccb 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -24,11 +24,13 @@ #ifndef TVM_RUNTIME_NDARRAY_H_ #define TVM_RUNTIME_NDARRAY_H_ +#include +#include +#include + #include #include #include -#include "c_runtime_api.h" -#include "serializer.h" namespace tvm { namespace runtime { @@ -37,72 +39,23 @@ namespace runtime { * \brief Managed NDArray. * The array is backed by reference counted blocks. */ -class NDArray { +class NDArray : public ObjectRef { public: - // internal container type + /*! \brief ContainerBase used to back the TVMArrayHandle */ + class ContainerBase; + /*! \brief NDArray internal container type */ class Container; + /*! \brief Container type for Object system. */ + using ContainerType = Container; /*! \brief default constructor */ NDArray() {} /*! - * \brief cosntruct a NDArray that refers to data - * \param data The data this NDArray refers to - */ - explicit inline NDArray(Container* data); - /*! - * \brief copy constructor. - * - * It does not make a copy, but the reference count of the input NDArray is incremented - * - * \param other NDArray that shares internal data with the input NDArray. - */ - inline NDArray(const NDArray& other); // NOLINT(*) - /*! - * \brief move constructor - * \param other The value to be moved - */ - NDArray(NDArray&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! \brief destructor */ - ~NDArray() { - this->reset(); - } - /*! - * \brief Swap this array with another NDArray - * \param other The other NDArray + * \brief constructor. + * \param data ObjectPtr to the data container. */ - void swap(NDArray& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \brief copy assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NDArray& operator=(const NDArray& other) { // NOLINT(*) - // copy-and-swap idiom - NDArray(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NDArray& operator=(NDArray&& other) { // NOLINT(*) - // copy-and-swap idiom - NDArray(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! \return If NDArray is defined */ - bool defined() const { - return data_ != nullptr; - } - /*! \return If both NDArray reference the same container */ - bool same_as(const NDArray& other) const { - return data_ == other.data_; - } + explicit NDArray(ObjectPtr data) + : ObjectRef(data) {} + /*! \brief reset the content of NDArray to be nullptr */ inline void reset(); /*! @@ -191,36 +144,40 @@ class NDArray { DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); TVM_DLL std::vector Shape() const; - // internal namespace struct Internal; + protected: - /*! \brief Internal Data content */ - Container* data_{nullptr}; - // enable internal functions - friend struct Internal; friend class TVMPODValue_; - friend class TVMArgValue; friend class TVMRetValue; friend class TVMArgsSetter; -}; - -/*! - * \brief The type trait indicates subclass of TVM's NDArray. - * For irrelavant classes, code = -1. - * For TVM NDArray itself, code = 0. - * All subclasses of NDArray should override code > 0. - */ -template -struct array_type_info { - /*! \brief the value of the traits */ - static const int code = -1; -}; - -// Overrides the type trait for tvm's NDArray. -template<> -struct array_type_info { - static const int code = 0; + /*! + * \brief Get mutable internal container pointer. + * \return a mutable container pointer. + */ + inline Container* get_mutable() const; + // Helper functions for FFI handling. + /*! + * \brief Construct NDArray's Data field from array handle in FFI. + * \param handle The array handle. + * \return The constructed NDArray. + * + * \note We keep a special calling convention for NDArray by passing + * ContainerBase pointer in FFI. + * As a result, the argument is compatible to DLTensor*. + */ + inline static ObjectPtr FFIDataFromHandle(TVMArrayHandle handle); + /*! + * \brief DecRef resource managed by an FFI array handle. + * \param handle The array handle. + */ + inline static void FFIDecRef(TVMArrayHandle handle); + /*! + * \brief Get FFI Array handle from ndarray. + * \param nd The object with ndarray type. + * \return The result array handle. + */ + inline static TVMArrayHandle FFIGetHandle(const ObjectRef& nd); }; /*! @@ -231,19 +188,14 @@ struct array_type_info { inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); /*! - * \brief Reference counted Container object used to back NDArray. + * \brief The container base structure + * contains all the fields except for the Object header. * - * This object is DLTensor compatible: - * the pointer to the NDArrayContainer can be directly - * interpreted as a DLTensor* - * - * \note do not use this function directly, use NDArray. + * \note We explicitly declare this structure in order to pass + * PackedFunc argument using ContainerBase*. */ -class NDArray::Container { +class NDArray::ContainerBase { public: - // NOTE: the first part of this structure is the same as - // DLManagedTensor, note that, however, the deleter - // is only called when the reference counter goes to 0 /*! * \brief The corresponding dl_tensor field. * \note it is important that the first field is DLTensor @@ -259,42 +211,27 @@ class NDArray::Container { * (e.g. reference to original memory when creating views). */ void* manager_ctx{nullptr}; - /*! - * \brief Customized deleter - * - * \note The customized deleter is helpful to enable - * different ways of memory allocator that are not - * currently defined by the system. - */ - void (*deleter)(Container* self) = nullptr; protected: - friend class NDArray; - friend class TVMPODValue_; - friend class TVMArgValue; - friend class TVMRetValue; - friend class RPCWrappedFunc; - /*! - * \brief Type flag used to indicate subclass. - * Default value 0 means normal NDArray::Conatainer. - * - * We can extend a more specialized NDArray::Container - * and use the array_type_code_ to indicate - * the specific array subclass. - */ - int32_t array_type_code_{0}; - /*! \brief The internal reference counter */ - std::atomic ref_counter_{0}; - /*! * \brief The shape container, * can be used used for shape data. */ std::vector shape_; +}; +/*! + * \brief Object container class taht backs NDArray. + * \note do not use this function directly, use NDArray. + */ +class NDArray::Container : + public Object, + public NDArray::ContainerBase { public: /*! \brief default constructor */ Container() { + // Initialize the type index. + type_index_ = Container::RuntimeTypeIndex(); dl_tensor.data = nullptr; dl_tensor.ndim = 0; dl_tensor.shape = nullptr; @@ -306,6 +243,8 @@ class NDArray::Container { std::vector shape, DLDataType dtype, DLContext ctx) { + // Initialize the type index. + type_index_ = Container::RuntimeTypeIndex(); dl_tensor.data = data; shape_ = std::move(shape); dl_tensor.ndim = static_cast(shape_.size()); @@ -315,49 +254,38 @@ class NDArray::Container { dl_tensor.byte_offset = 0; dl_tensor.ctx = ctx; } - - /*! \brief developer function, increases reference counter */ - void IncRef() { - ref_counter_.fetch_add(1, std::memory_order_relaxed); - } - /*! \brief developer function, decrease reference counter */ - void DecRef() { - if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - if (this->deleter != nullptr) { - (*this->deleter)(this); - } - } + /*! + * \brief Set the deleter field. + * \param deleter The deleter. + */ + void SetDeleter(FDeleter deleter) { + deleter_ = deleter; } -}; -// implementations of inline functions -// the usages of functions are documented in place. -inline NDArray::NDArray(Container* data) - : data_(data) { - if (data != nullptr) { - data_->IncRef(); - } -} + // Expose DecRef and IncRef as public function + // NOTE: they are only for developer purposes only. + using Object::DecRef; + using Object::IncRef; -inline NDArray::NDArray(const NDArray& other) - : data_(other.data_) { - if (data_ != nullptr) { - data_->IncRef(); - } -} + // Information for object protocol. + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const uint32_t _type_child_slots = 0; + static constexpr const uint32_t _type_child_slots_can_overflow = true; + static constexpr const char* _type_key = "NDArray"; + TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object); -inline void NDArray::reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } -} + protected: + friend class RPCWrappedFunc; + friend class NDArray; +}; + +// implementations of inline functions +// the usages of functions are documented in place.a -/*! \brief return the size of data the DLTensor hold, in term of number of bytes +/*! + * \brief return the size of data the DLTensor hold, in term of number of bytes * * \param arr the input DLTensor - * * \return number of bytes of data in the DLTensor. */ inline size_t GetDataSize(const DLTensor& arr) { @@ -371,24 +299,24 @@ inline size_t GetDataSize(const DLTensor& arr) { inline void NDArray::CopyFrom(DLTensor* other) { CHECK(data_ != nullptr); - CopyFromTo(other, &(data_->dl_tensor)); + CopyFromTo(other, &(get_mutable()->dl_tensor)); } inline void NDArray::CopyFrom(const NDArray& other) { CHECK(data_ != nullptr); CHECK(other.data_ != nullptr); - CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor)); + CopyFromTo(&(other.get_mutable()->dl_tensor), &(get_mutable()->dl_tensor)); } inline void NDArray::CopyTo(DLTensor* other) const { CHECK(data_ != nullptr); - CopyFromTo(&(data_->dl_tensor), other); + CopyFromTo(&(get_mutable()->dl_tensor), other); } inline void NDArray::CopyTo(const NDArray& other) const { CHECK(data_ != nullptr); CHECK(other.data_ != nullptr); - CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor)); + CopyFromTo(&(get_mutable()->dl_tensor), &(other.get_mutable()->dl_tensor)); } inline NDArray NDArray::CopyTo(const DLContext& ctx) const { @@ -401,12 +329,39 @@ inline NDArray NDArray::CopyTo(const DLContext& ctx) const { } inline int NDArray::use_count() const { - if (data_ == nullptr) return 0; - return data_->ref_counter_.load(std::memory_order_relaxed); + return data_.use_count(); } inline const DLTensor* NDArray::operator->() const { - return &(data_->dl_tensor); + return &(get_mutable()->dl_tensor); +} + +inline NDArray::Container* NDArray::get_mutable() const { + return static_cast(data_.get()); +} + +inline ObjectPtr NDArray::FFIDataFromHandle(TVMArrayHandle handle) { + return GetObjectPtr(static_cast( + reinterpret_cast(handle))); +} + +inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { + // NOTE: it is necessary to cast to container then to base + // so that the FFI handle uses the ContainerBase address. + return reinterpret_cast( + static_cast( + static_cast( + const_cast(nd.get())))); +} + +inline void NDArray::FFIDecRef(TVMArrayHandle handle) { + static_cast( + reinterpret_cast(handle))->DecRef(); +} + +inline Object* TVMArrayHandleToObjectHandle(TVMArrayHandle handle) { + return static_cast( + reinterpret_cast(handle)); } /*! \brief Magic number for NDArray file */ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 20e6b5a0fb632..96215daf4a7ac 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -24,10 +24,11 @@ #define TVM_RUNTIME_OBJECT_H_ #include +#include #include #include #include -#include "c_runtime_api.h" + /*! * \brief Whether or not use atomic reference counter. @@ -580,6 +581,14 @@ class ObjectRef { static T DowncastNoCheck(ObjectRef ref) { return T(std::move(ref.data_)); } + /*! + * \brief Clear the object ref data field without DecRef + * after we successfully moved the field. + * \param ref The reference data. + */ + static void FFIClearAfterMove(ObjectRef* ref) { + ref->data_.data_ = nullptr; + } /*! * \brief Internal helper function get data_ as ObjectPtr of ObjectType. * \note only used for internal dev purpose. @@ -648,7 +657,7 @@ struct ObjectEqual { return _GetOrAllocRuntimeTypeIndex(); \ } \ static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ - static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \ + static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \ TypeName::_type_key, \ TypeName::_type_index, \ ParentType::_GetOrAllocRuntimeTypeIndex(), \ @@ -668,6 +677,19 @@ struct ObjectEqual { TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ +/*! \brief helper macro to supress unused warning */ +#if defined(__GNUC__) +#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define TVM_ATTRIBUTE_UNUSED +#endif + +#define TVM_STR_CONCAT_(__x, __y) __x##__y +#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) + +#define TVM_OBJECT_REG_VAR_DEF \ + static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid + /*! * \brief Helper macro to register the object type to runtime. * Makes sure that the runtime type table is correctly populated. @@ -675,7 +697,7 @@ struct ObjectEqual { * Use this macro in the cc file for each terminal class. */ #define TVM_REGISTER_OBJECT_TYPE(TypeName) \ - static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \ + TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TypeName::_GetOrAllocRuntimeTypeIndex() @@ -691,14 +713,14 @@ struct ObjectEqual { using ContainerType = ObjectName; #define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \ - TypeName() {} \ - explicit TypeName( \ - ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ - : ParentType(n) {} \ - ObjectName* operator->() { \ - return static_cast(data_.get()); \ - } \ - operator bool() const { return data_ != nullptr; } \ + TypeName() {} \ + explicit TypeName( \ + ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ + : ParentType(n) {} \ + ObjectName* operator->() { \ + return static_cast(data_.get()); \ + } \ + operator bool() const { return data_ != nullptr; } \ using ContainerType = ObjectName; // Implementations details below diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 27dcb4130b4c2..8b9acbe7a44a4 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -387,20 +387,22 @@ inline std::string TVMType2String(TVMType t); #define TVM_CHECK_TYPE_CODE(CODE, T) \ CHECK_EQ(CODE, T) << " expected " \ << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ + /*! - * \brief Type traits to mark if a class is tvm extension type. - * - * To enable extension type in C++ must be registered via marco. - * TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits. - * - * Extension class can be passed and returned via PackedFunc in all tvm runtime. - * Internally extension class is stored as T*. - * - * \tparam T the typename + * \brief Type traits for runtime type check during FFI conversion. + * \tparam T the type to be checked. */ template -struct extension_type_info { - static const int code = 0; +struct ObjectTypeChecker { + static bool Check(const Object* ptr) { + using ContainerType = typename T::ContainerType; + if (ptr == nullptr) return true; + return ptr->IsInstance(); + } + static std::string TypeName() { + using ContainerType = typename T::ContainerType; + return ContainerType::_type_key; + } }; /*! @@ -449,24 +451,17 @@ class TVMPODValue_ { return static_cast(value_.v_handle); } else { if (type_code_ == kNull) return nullptr; - LOG(FATAL) << "Expected " + LOG(FATAL) << "Expect " << "DLTensor* or NDArray but get " << TypeCode2Str(type_code_); return nullptr; } } operator NDArray() const { - if (type_code_ == kNull) return NDArray(); + if (type_code_ == kNull) return NDArray(ObjectPtr(nullptr)); TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer); - return NDArray(static_cast(value_.v_handle)); - } - operator ObjectRef() const { - if (type_code_ == kNull) { - return ObjectRef(ObjectPtr(nullptr)); - } - TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - return ObjectRef( - ObjectPtr(static_cast(value_.v_handle))); + return NDArray(NDArray::FFIDataFromHandle( + static_cast(value_.v_handle))); } operator Module() const { if (type_code_ == kNull) { @@ -480,23 +475,9 @@ class TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); return value_.v_ctx; } - template::value>::type> - TNDArray AsNDArray() const { - if (type_code_ == kNull) return TNDArray(nullptr); - auto *container = static_cast(value_.v_handle); - CHECK_EQ(container->array_type_code_, array_type_info::code); - return TNDArray(container); - } - template::value>::type> - inline bool IsObjectRef() const; int type_code() const { return type_code_; } - /*! * \brief return handle as specific pointer type. * \tparam T the data type. @@ -506,6 +487,16 @@ class TVMPODValue_ { T* ptr() const { return static_cast(value_.v_handle); } + // ObjectRef handling + template::value>::type> + inline bool IsObjectRef() const; + template + inline TObjectRef AsObjectRef() const; + // ObjectRef Specializations + inline operator tvm::Expr() const; + inline operator tvm::Integer() const; protected: friend class TVMArgsSetter; @@ -548,9 +539,11 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; - using TVMPODValue_::operator ObjectRef; using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; + using TVMPODValue_::AsObjectRef; + using TVMPODValue_::operator tvm::Expr; + using TVMPODValue_::operator tvm::Integer; // conversion operator. operator std::string() const { @@ -577,6 +570,9 @@ class TVMArgValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMType); return value_.v_type; } + operator DataType() const { + return DataType(operator DLDataType()); + } operator PackedFunc() const { if (type_code_ == kNull) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); @@ -589,16 +585,10 @@ class TVMArgValue : public TVMPODValue_ { const TVMValue& value() const { return value_; } - // Deferred extension handler. - template - inline TObjectRef AsObjectRef() const; template::value>::type> + std::is_class::value>::type> inline operator T() const; - inline operator DataType() const; - inline operator tvm::Expr() const; - inline operator tvm::Integer() const; }; /*! @@ -636,9 +626,11 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; - using TVMPODValue_::operator ObjectRef; using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; + using TVMPODValue_::AsObjectRef; + using TVMPODValue_::operator tvm::Expr; + using TVMPODValue_::operator tvm::Integer; TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); @@ -660,6 +652,9 @@ class TVMRetValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMType); return value_.v_type; } + operator DataType() const { + return DataType(operator DLDataType()); + } operator PackedFunc() const { if (type_code_ == kNull) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); @@ -712,6 +707,9 @@ class TVMRetValue : public TVMPODValue_ { value_.v_type = t; return *this; } + TVMRetValue& operator=(const DataType& other) { + return operator=(other.operator DLDataType()); + } TVMRetValue& operator=(bool value) { this->SwitchToPOD(kDLInt); value_.v_int64 = value; @@ -726,24 +724,20 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(NDArray other) { - this->Clear(); - type_code_ = kNDArrayContainer; - value_.v_handle = other.data_; - other.data_ = nullptr; + if (other.data_ != nullptr) { + this->Clear(); + type_code_ = kNDArrayContainer; + value_.v_handle = NDArray::FFIGetHandle(other); + ObjectRef::FFIClearAfterMove(&other); + } else { + SwitchToPOD(kNull); + } return *this; } - TVMRetValue& operator=(ObjectRef other) { - return operator=(std::move(other.data_)); - } TVMRetValue& operator=(Module m) { SwitchToObject(kModuleHandle, std::move(m.data_)); return *this; } - template - TVMRetValue& operator=(ObjectPtr other) { - SwitchToObject(kObjectHandle, std::move(other)); - return *this; - } TVMRetValue& operator=(PackedFunc f) { this->SwitchToClass(kFuncHandle, f); return *this; @@ -760,14 +754,6 @@ class TVMRetValue : public TVMPODValue_ { this->Assign(other); return *this; } - template::code != 0>::type> - TVMRetValue& operator=(const T& other) { - this->SwitchToClass( - extension_type_info::code, other); - return *this; - } /*! * \brief Move the value back to front-end via C API. * This marks the current container as null. @@ -793,16 +779,15 @@ class TVMRetValue : public TVMPODValue_ { type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; return value_; } - // ObjectRef related extenstions: in tvm/packed_func_ext.h + // ObjectRef handling + template::value>::type> + inline TVMRetValue& operator=(TObjectRef other); template::value>::type> inline operator T() const; - template - inline TObjectRef AsObjectRef() const; - // type related - inline operator DataType() const; - inline TVMRetValue& operator=(const DataType& other); private: template @@ -829,7 +814,10 @@ class TVMRetValue : public TVMPODValue_ { break; } case kObjectHandle: { - *this = other.operator ObjectRef(); + // Avoid operator ObjectRef as we already know it is not NDArray/Module + SwitchToObject( + kObjectHandle, GetObjectPtr( + static_cast(other.value_.v_handle))); break; } default: { @@ -873,7 +861,7 @@ class TVMRetValue : public TVMPODValue_ { case kStr: delete ptr(); break; case kFuncHandle: delete ptr(); break; case kNDArrayContainer: { - static_cast(value_.v_handle)->DecRef(); + NDArray::FFIDecRef(static_cast(value_.v_handle)); break; } case kModuleHandle: { @@ -905,7 +893,7 @@ inline const char* TypeCode2Str(int type_code) { case kFuncHandle: return "FunctionHandle"; case kModuleHandle: return "ModuleHandle"; case kNDArrayContainer: return "NDArrayContainer"; - case kObjectHandle: return "ObjectCell"; + case kObjectHandle: return "Object"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; } @@ -929,6 +917,10 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) return os; } +inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) + return os << dtype.operator DLDataType(); +} + #endif inline std::string TVMType2String(TVMType t) { @@ -996,10 +988,6 @@ inline TVMType String2TVMType(std::string s) { return t; } -inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) - return os << dtype.operator DLDataType(); -} - inline TVMArgValue TVMArgs::operator[](int i) const { CHECK_LT(i, num_args) << "not enough argument passed, " @@ -1092,50 +1080,31 @@ class TVMArgsSetter { values_[i].v_type = value; type_codes_[i] = kTVMType; } + void operator()(size_t i, DataType dtype) const { + operator()(i, dtype.operator DLDataType()); + } void operator()(size_t i, const char* value) const { values_[i].v_str = value; type_codes_[i] = kStr; } - // setters for container type - // They must be reference(instead of const ref) - // to make sure they are alive in the tuple(instead of getting converted) - void operator()(size_t i, const std::string& value) const { // NOLINT(*) + // setters for container types + void operator()(size_t i, const std::string& value) const { values_[i].v_str = value.c_str(); type_codes_[i] = kStr; } - void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*) + void operator()(size_t i, const TVMByteArray& value) const { values_[i].v_handle = const_cast(&value); type_codes_[i] = kBytes; } - void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*) + void operator()(size_t i, const PackedFunc& value) const { values_[i].v_handle = const_cast(&value); type_codes_[i] = kFuncHandle; } template - void operator()(size_t i, const TypedPackedFunc& value) const { // NOLINT(*) + void operator()(size_t i, const TypedPackedFunc& value) const { operator()(i, value.packed()); } - void operator()(size_t i, const Module& value) const { // NOLINT(*) - if (value.defined()) { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kModuleHandle; - } else { - type_codes_[i] = kNull; - } - } - void operator()(size_t i, const NDArray& value) const { // NOLINT(*) - values_[i].v_handle = value.data_; - type_codes_[i] = kNDArrayContainer; - } - void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) - if (value.defined()) { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kObjectHandle; - } else { - type_codes_[i] = kNull; - } - } - void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) + void operator()(size_t i, const TVMRetValue& value) const { if (value.type_code() == kStr) { values_[i].v_str = value.ptr()->c_str(); type_codes_[i] = kStr; @@ -1145,12 +1114,11 @@ class TVMArgsSetter { type_codes_[i] = value.type_code(); } } - // extension - template::code != 0>::type> - inline void operator()(size_t i, const T& value) const; - inline void operator()(size_t i, const DataType& t) const; + std::is_base_of::value>::type> + inline void operator()(size_t i, const TObjectRef& value) const; private: /*! \brief The values fields */ @@ -1262,57 +1230,129 @@ inline R TypedPackedFunc::operator()(Args... args) const { ::run(packed_, std::forward(args)...); } -// extension and node type handling -namespace detail { -template -struct TVMValueCast { - static T Apply(const TSrc* self) { - static_assert(!is_nd, "The default case accepts only non-extensions"); - return self->template AsObjectRef(); - } -}; - -template -struct TVMValueCast { - static T Apply(const TSrc* self) { - return self->template AsNDArray(); +// ObjectRef related conversion handling +// Object can have three possible type codes: +// kNDArrayContainer, kModuleHandle, kObjectHandle +// +// We use type traits to eliminate un-necessary checks. +template +inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const { + if (value.defined()) { + Object* ptr = value.data_.data_; + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + values_[i].v_handle = NDArray::FFIGetHandle(value); + type_codes_[i] = kNDArrayContainer; + } else if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + values_[i].v_handle = ptr; + type_codes_[i] = kModuleHandle; + } else { + values_[i].v_handle = ptr; + type_codes_[i] = kObjectHandle; + } + } else { + type_codes_[i] = kNull; } -}; - -} // namespace detail - -template -inline TVMArgValue::operator T() const { - return detail:: - TVMValueCast::code > 0)> - ::Apply(this); } -template -inline TVMRetValue::operator T() const { - return detail:: - TVMValueCast::code > 0)> - ::Apply(this); +template +inline bool TVMPODValue_::IsObjectRef() const { + using ContainerType = typename TObjectRef::ContainerType; + // NOTE: the following code can be optimized by constant folding. + if (std::is_base_of::value) { + return type_code_ == kNDArrayContainer && + NDArray::FFIDataFromHandle( + static_cast(value_.v_handle))->IsInstance(); + } + if (std::is_base_of::value) { + return type_code_ == kModuleHandle && + static_cast(value_.v_handle)->IsInstance(); + } + return + (std::is_base_of::value && type_code_ == kNDArrayContainer) || + (std::is_base_of::value && type_code_ == kModuleHandle) || + (type_code_ == kObjectHandle && + ObjectTypeChecker::Check(static_cast(value_.v_handle))); } -// PackedFunc support -inline TVMRetValue& TVMRetValue::operator=(const DataType& t) { - return this->operator=(t.operator DLDataType()); +template +inline TObjectRef TVMPODValue_::AsObjectRef() const { + static_assert( + std::is_base_of::value, + "Conversion only works for ObjectRef"); + using ContainerType = typename TObjectRef::ContainerType; + if (type_code_ == kNull) return TObjectRef(ObjectPtr(nullptr)); + // NOTE: the following code can be optimized by constant folding. + if (std::is_base_of::value && type_code_ == kNDArrayContainer) { + // Casting to a sub-class of NDArray + ObjectPtr data = NDArray::FFIDataFromHandle( + static_cast(value_.v_handle)); + CHECK(data->IsInstance()) + << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); + return TObjectRef(data); + } + if (std::is_base_of::value && type_code_ == kModuleHandle) { + // Casting to a sub-class of Module + ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); + CHECK(data->IsInstance()) + << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); + return TObjectRef(data); + } + if (type_code_ == kObjectHandle) { + // normal object type check. + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expect " << ObjectTypeChecker::TypeName() + << " but get " << ptr->GetTypeKey(); + return TObjectRef(GetObjectPtr(ptr)); + } else if (std::is_base_of::value && + type_code_ == kNDArrayContainer) { + // Casting to a base class that NDArray can sub-class + ObjectPtr data = NDArray::FFIDataFromHandle( + static_cast(value_.v_handle)); + return TObjectRef(data); + } else if (std::is_base_of::value && + type_code_ == kModuleHandle) { + // Casting to a base class that Module can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } else { + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); + } } -inline TVMRetValue::operator DataType() const { - return DataType(operator DLDataType()); +template +inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { + const Object* ptr = other.get(); + if (ptr != nullptr) { + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + return operator=(NDArray(std::move(other.data_))); + } + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + return operator=(Module(std::move(other.data_))); + } + SwitchToObject(kObjectHandle, std::move(other.data_)); + } else { + SwitchToPOD(kNull); + } + return *this; } -inline TVMArgValue::operator DataType() const { - return DataType(operator DLDataType()); +template +inline TVMArgValue::operator T() const { + return AsObjectRef(); } -inline void TVMArgsSetter::operator()( - size_t i, const DataType& t) const { - this->operator()(i, t.operator DLDataType()); +template +inline TVMRetValue::operator T() const { + return AsObjectRef(); } inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 3500a7e4e3982..e51b806ea81ff 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -43,9 +43,9 @@ #ifndef TVM_RUNTIME_REGISTRY_H_ #define TVM_RUNTIME_REGISTRY_H_ +#include #include #include -#include "packed_func.h" namespace tvm { namespace runtime { @@ -283,22 +283,9 @@ class Registry { friend struct Manager; }; -/*! \brief helper macro to supress unused warning */ -#if defined(__GNUC__) -#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) -#else -#define TVM_ATTRIBUTE_UNUSED -#endif - -#define TVM_STR_CONCAT_(__x, __y) __x##__y -#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) - #define TVM_FUNC_REG_VAR_DEF \ static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM -#define TVM_TYPE_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT - /*! * \brief Register a function globally. * \code diff --git a/python/setup.py b/python/setup.py index f8b580cd75529..bc53060f95cf6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -96,6 +96,7 @@ def config_cython(): "../3rdparty/dmlc-core/include", "../3rdparty/dlpack/include", ], + extra_compile_args=["-std=c++11"], library_dirs=library_dirs, libraries=libraries, language="c++")) diff --git a/python/tvm/_ffi/_ctypes/ndarray.py b/python/tvm/_ffi/_ctypes/ndarray.py index af59de6eee1d7..c572947c8d19a 100644 --- a/python/tvm/_ffi/_ctypes/ndarray.py +++ b/python/tvm/_ffi/_ctypes/ndarray.py @@ -20,7 +20,7 @@ import ctypes from ..base import _LIB, check_call, c_str -from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle +from ..runtime_ctypes import TVMArrayHandle from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle @@ -110,12 +110,17 @@ def to_dlpack(self): def _make_array(handle, is_view, is_container): global _TVM_ND_CLS handle = ctypes.cast(handle, TVMArrayHandle) - fcreate = _CLASS_NDARRAY - if is_container and _TVM_ND_CLS: - array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value - if array_type_info > 0: - fcreate = _TVM_ND_CLS[array_type_info] - return fcreate(handle, is_view) + if is_container: + tindex = ctypes.c_uint() + check_call(_LIB.TVMArrayGetTypeIndex(handle, ctypes.byref(tindex))) + cls = _TVM_ND_CLS.get(tindex.value, _CLASS_NDARRAY) + else: + cls = _CLASS_NDARRAY + + ret = cls.__new__(cls) + ret.handle = handle + ret.is_view = is_view + return ret _TVM_COMPATS = () @@ -129,9 +134,9 @@ def _reg_extension(cls, fcreate): _TVM_ND_CLS = {} -def _reg_ndarray(cls, fcreate): +def _register_ndarray(index, cls): global _TVM_ND_CLS - _TVM_ND_CLS[cls._array_type_code] = fcreate + _TVM_ND_CLS[index] = cls _CLASS_NDARRAY = None diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index c3ae56822198d..b8b8aefea131c 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -21,7 +21,7 @@ import ctypes from ..base import _LIB, check_call from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func -from ..node_generic import _set_class_node_base +from .ndarray import _register_ndarray, NDArrayBase ObjectHandle = ctypes.c_void_p @@ -39,6 +39,9 @@ def _set_class_node(node_class): def _register_object(index, cls): """register object class""" + if issubclass(cls, NDArrayBase): + _register_ndarray(index, cls) + return OBJECT_TYPE[index] = cls @@ -91,6 +94,3 @@ def __init_handle_by_constructor__(self, fconstructor, *args): if not isinstance(handle, ObjectHandle): handle = ObjectHandle(handle) self.handle = handle - - -_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 4b7b2c88ffa50..7ccb6279fed00 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -19,7 +19,7 @@ from ..base import get_last_ffi_error from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule -from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t +from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t import ctypes cdef enum TVMTypeCode: @@ -78,14 +78,11 @@ ctypedef void* TVMRetValueHandle ctypedef void* TVMFunctionHandle ctypedef void* ObjectHandle +ctypedef struct TVMObject: + uint32_t type_index_ + int32_t ref_counter_ + void (*deleter_)(TVMObject* self) -ctypedef struct TVMNDArrayContainer: - DLTensor dl_tensor - void* manager_ctx - void (*deleter)(DLManagedTensor* self) - int32_t array_type_info - -ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle ctypedef int (*TVMPackedCFunc)( TVMValue* args, diff --git a/python/tvm/_ffi/_cython/ndarray.pxi b/python/tvm/_ffi/_cython/ndarray.pxi index 5682ae619a46e..9fd3aa43841f0 100644 --- a/python/tvm/_ffi/_cython/ndarray.pxi +++ b/python/tvm/_ffi/_cython/ndarray.pxi @@ -100,17 +100,34 @@ cdef class NDArrayBase: return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) +# Import limited object-related function from C++ side to improve the speed +# NOTE: can only use POD-C compatible object in FFI. +cdef extern from "tvm/runtime/ndarray.h" namespace "tvm::runtime": + cdef void* TVMArrayHandleToObjectHandle(DLTensorHandle handle) + + cdef c_make_array(void* chandle, is_view, is_container): global _TVM_ND_CLS - cdef int32_t array_type_info - fcreate = _CLASS_NDARRAY - if is_container and len(_TVM_ND_CLS) > 0: - array_type_info = (chandle).array_type_info - if array_type_info > 0: - fcreate = _TVM_ND_CLS[array_type_info] - ret = fcreate(None, is_view) - (ret).chandle = chandle - return ret + + if is_container: + tindex = ( + TVMArrayHandleToObjectHandle(chandle)).type_index_ + if tindex < len(_TVM_ND_CLS): + cls = _TVM_ND_CLS[tindex] + if cls is not None: + ret = cls.__new__(cls) + else: + ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) + else: + ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) + (ret).chandle = chandle + (ret).c_is_view = is_view + return ret + else: + ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) + (ret).chandle = chandle + (ret).c_is_view = is_view + return ret cdef _TVM_COMPATS = () @@ -123,11 +140,16 @@ def _reg_extension(cls, fcreate): if fcreate: _TVM_EXT_RET[cls._tvm_tcode] = fcreate -cdef _TVM_ND_CLS = {} +cdef list _TVM_ND_CLS = [] -def _reg_ndarray(cls, fcreate): +cdef _register_ndarray(int index, object cls): + """register object class""" global _TVM_ND_CLS - _TVM_ND_CLS[cls._array_type_code] = fcreate + while len(_TVM_ND_CLS) <= index: + _TVM_ND_CLS.append(None) + + _TVM_ND_CLS[index] = cls + def _make_array(handle, is_view, is_container): cdef unsigned long long ptr diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 9561eab94ea2f..6d20723fd1881 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -16,12 +16,15 @@ # under the License. """Maps object type to its constructor""" -from ..node_generic import _set_class_node_base - -OBJECT_TYPE = [] +cdef list OBJECT_TYPE = [] def _register_object(int index, object cls): """register object class""" + if issubclass(cls, NDArrayBase): + _register_ndarray(index, cls) + return + + global OBJECT_TYPE while len(OBJECT_TYPE) <= index: OBJECT_TYPE.append(None) OBJECT_TYPE[index] = cls @@ -31,14 +34,13 @@ cdef inline object make_ret_object(void* chandle): global OBJECT_TYPE global _CLASS_NODE cdef unsigned tindex - cdef list object_type cdef object cls cdef object handle object_type = OBJECT_TYPE handle = ctypes_handle(chandle) CALL(TVMObjectGetTypeIndex(chandle, &tindex)) - if tindex < len(object_type): - cls = object_type[tindex] + if tindex < len(OBJECT_TYPE): + cls = OBJECT_TYPE[tindex] if cls is not None: obj = cls.__new__(cls) else: @@ -99,6 +101,3 @@ cdef class ObjectBase: (fconstructor).chandle, kObjectHandle, args, &chandle) self.chandle = chandle - - -_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index ed2f7e1f62d60..23d95ebbf66bb 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -22,6 +22,7 @@ import sys import ctypes from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE +from .node_generic import _set_class_objects IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError @@ -32,15 +33,21 @@ if sys.version_info >= (3, 0): from ._cy3.core import _set_class_function, _set_class_module from ._cy3.core import FunctionBase as _FunctionBase + from ._cy3.core import NDArrayBase as _NDArrayBase + from ._cy3.core import ObjectBase as _ObjectBase from ._cy3.core import convert_to_tvm_func else: from ._cy2.core import _set_class_function, _set_class_module from ._cy2.core import FunctionBase as _FunctionBase + from ._cy2.core import NDArrayBase as _NDArrayBase + from ._cy2.core import ObjectBase as _ObjectBase from ._cy2.core import convert_to_tvm_func except IMPORT_EXCEPT: # pylint: disable=wrong-import-position from ._ctypes.function import _set_class_function, _set_class_module from ._ctypes.function import FunctionBase as _FunctionBase + from ._ctypes.ndarray import NDArrayBase as _NDArrayBase + from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.function import convert_to_tvm_func FunctionHandle = ctypes.c_void_p @@ -325,3 +332,4 @@ def _init_api_prefix(module_name, prefix): setattr(target_module, ff.__name__, ff) _set_class_function(Function) +_set_class_objects((_ObjectBase, _NDArrayBase, ModuleBase)) diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index 1773d916722b5..650f01dd5409b 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -35,16 +35,16 @@ if sys.version_info >= (3, 0): from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy3.core import NDArrayBase as _NDArrayBase - from ._cy3.core import _reg_extension, _reg_ndarray + from ._cy3.core import _reg_extension else: from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy2.core import NDArrayBase as _NDArrayBase - from ._cy2.core import _reg_extension, _reg_ndarray + from ._cy2.core import _reg_extension except IMPORT_EXCEPT: # pylint: disable=wrong-import-position from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack from ._ctypes.ndarray import NDArrayBase as _NDArrayBase - from ._ctypes.ndarray import _reg_extension, _reg_ndarray + from ._ctypes.ndarray import _reg_extension def context(dev_type, dev_id=0): @@ -348,13 +348,8 @@ def __init__(self): def _tvm_handle(self): return self.handle.value """ - if issubclass(cls, _NDArrayBase): - assert fcreate is not None - assert hasattr(cls, "_array_type_code") - _reg_ndarray(cls, fcreate) - else: - assert hasattr(cls, "_tvm_tcode") - if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: - raise ValueError("Cannot register create when extension tcode is same as buildin") - _reg_extension(cls, fcreate) + assert hasattr(cls, "_tvm_tcode") + if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: + raise ValueError("Cannot register create when extension tcode is same as buildin") + _reg_extension(cls, fcreate) return cls diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/node_generic.py index e89812685eb22..8ee7fc5f2b5bc 100644 --- a/python/tvm/_ffi/node_generic.py +++ b/python/tvm/_ffi/node_generic.py @@ -23,11 +23,11 @@ from .base import string_types # Node base class -_CLASS_NODE_BASE = None +_CLASS_OBJECTS = None -def _set_class_node_base(cls): - global _CLASS_NODE_BASE - _CLASS_NODE_BASE = cls +def _set_class_objects(cls): + global _CLASS_OBJECTS + _CLASS_OBJECTS = cls def _scalar_type_inference(value): @@ -67,7 +67,7 @@ def convert_to_node(value): node : Node The corresponding node value. """ - if isinstance(value, _CLASS_NODE_BASE): + if isinstance(value, _CLASS_OBJECTS): return value if isinstance(value, bool): return const(value, 'uint1x1') @@ -81,7 +81,7 @@ def convert_to_node(value): if isinstance(value, dict): vlist = [] for item in value.items(): - if (not isinstance(item[0], _CLASS_NODE_BASE) and + if (not isinstance(item[0], _CLASS_OBJECTS) and not isinstance(item[0], string_types)): raise ValueError("key of map must already been a container type") vlist.append(item[0]) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 2dbb67dfbf739..a7947dbc38a2f 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -271,12 +271,3 @@ class TVMArray(ctypes.Structure): ("byte_offset", ctypes.c_uint64)] TVMArrayHandle = ctypes.POINTER(TVMArray) - -class TVMNDArrayContainer(ctypes.Structure): - """TVM NDArray::Container""" - _fields_ = [("dl_tensor", TVMArray), - ("manager_ctx", ctypes.c_void_p), - ("deleter", ctypes.c_void_p), - ("array_type_info", ctypes.c_int32)] - -TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer) diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index 2a7a532e660eb..b7fe780f46298 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -27,7 +27,10 @@ from ._ffi.ndarray import context, empty, from_dlpack from ._ffi.ndarray import _set_class_ndarray from ._ffi.ndarray import register_extension +from ._ffi.object import register_object + +@register_object class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 9cb797fa45e44..8a74fe5cdb7df 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -67,7 +67,7 @@ TVM_REGISTER_API("_Array") } auto node = make_node(); node->data = std::move(data); - *ret = runtime::ObjectRef(node); + *ret = Array(node); }); TVM_REGISTER_API("_ArrayGetItem") @@ -100,28 +100,28 @@ TVM_REGISTER_API("_Map") for (int i = 0; i < args.num_args; i += 2) { CHECK(args[i].type_code() == kStr) << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kObjectHandle) + CHECK(args[i + 1].IsObjectRef()) << "value of the map to be NodeRef"; data.emplace(std::make_pair(args[i].operator std::string(), args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); - *ret = node; + *ret = Map(node); } else { // Container node. MapNode::ContainerType data; for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kObjectHandle) - << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kObjectHandle) + CHECK(args[i].IsObjectRef()) + << "key of str map need to be object"; + CHECK(args[i + 1].IsObjectRef()) << "value of map to be NodeRef"; data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); - *ret = node; + *ret = Map(node); } }); @@ -191,7 +191,7 @@ TVM_REGISTER_API("_MapItems") rkvs->data.push_back(kv.first); rkvs->data.push_back(kv.second); } - *ret = rkvs; + *ret = Array(rkvs); } else { auto* n = static_cast(ptr); auto rkvs = make_node(); @@ -199,7 +199,7 @@ TVM_REGISTER_API("_MapItems") rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(kv.second); } - *ret = rkvs; + *ret = Array(rkvs); } }); diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 9d2d53e03eb85..3078a0ae07739 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -27,8 +27,12 @@ #include #include "runtime_base.h" +extern "C" { // deleter for arrays used by DLPack exporter -extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor); +void NDArrayDLPackDeleter(DLManagedTensor* tensor); +// helper function to get NDArray's type index, only used by ctypes. +TVM_DLL int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex); +} namespace tvm { namespace runtime { @@ -53,8 +57,8 @@ inline size_t GetDataAlignment(const DLTensor& arr) { struct NDArray::Internal { // Default deleter for the container - static void DefaultDeleter(NDArray::Container* ptr) { - using tvm::runtime::NDArray; + static void DefaultDeleter(Object* ptr_obj) { + auto* ptr = static_cast(ptr_obj); if (ptr->manager_ctx != nullptr) { static_cast(ptr->manager_ctx)->DecRef(); } else if (ptr->dl_tensor.data != nullptr) { @@ -68,7 +72,8 @@ struct NDArray::Internal { // that are not allocated inside of TVM. // This enables us to create NDArray from memory allocated by other // frameworks that are DLPack compatible - static void DLPackDeleter(NDArray::Container* ptr) { + static void DLPackDeleter(Object* ptr_obj) { + auto* ptr = static_cast(ptr_obj); DLManagedTensor* tensor = static_cast(ptr->manager_ctx); if (tensor->deleter != nullptr) { (*tensor->deleter)(tensor); @@ -81,12 +86,13 @@ struct NDArray::Internal { DLDataType dtype, DLContext ctx) { VerifyDataType(dtype); - // critical zone + + // critical zone: construct header NDArray::Container* data = new NDArray::Container(); - data->deleter = DefaultDeleter; - NDArray ret(data); - ret.data_ = data; + data->SetDeleter(DefaultDeleter); + // RAII now in effect + NDArray ret(GetObjectPtr(data)); // setup shape data->shape_ = std::move(shape); data->dl_tensor.shape = dmlc::BeginPtr(data->shape_); @@ -98,13 +104,21 @@ struct NDArray::Internal { return ret; } // Implementation of API function - static DLTensor* MoveAsDLTensor(NDArray arr) { - DLTensor* tensor = const_cast(arr.operator->()); - CHECK(reinterpret_cast(arr.data_) == tensor); - arr.data_ = nullptr; - return tensor; + static DLTensor* MoveToFFIHandle(NDArray arr) { + DLTensor* handle = NDArray::FFIGetHandle(arr); + ObjectRef::FFIClearAfterMove(&arr); + return handle; + } + static void FFIDecRef(TVMArrayHandle tensor) { + NDArray::FFIDecRef(tensor); } // Container to DLManagedTensor + static DLManagedTensor* ToDLPack(TVMArrayHandle handle) { + auto* from = static_cast( + reinterpret_cast(handle)); + return ToDLPack(from); + } + static DLManagedTensor* ToDLPack(NDArray::Container* from) { CHECK(from != nullptr); DLManagedTensor* ret = new DLManagedTensor(); @@ -114,29 +128,33 @@ struct NDArray::Internal { ret->deleter = NDArrayDLPackDeleter; return ret; } + // Delete dlpack object. + static void NDArrayDLPackDeleter(DLManagedTensor* tensor) { + static_cast(tensor->manager_ctx)->DecRef(); + delete tensor; + } }; -NDArray NDArray::CreateView(std::vector shape, - DLDataType dtype) { +NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { CHECK(data_ != nullptr); - CHECK(data_->dl_tensor.strides == nullptr) + CHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor"; - NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx); - ret.data_->dl_tensor.byte_offset = - this->data_->dl_tensor.byte_offset; - size_t curr_size = GetDataSize(this->data_->dl_tensor); - size_t view_size = GetDataSize(ret.data_->dl_tensor); + NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx); + ret.get_mutable()->dl_tensor.byte_offset = + this->get_mutable()->dl_tensor.byte_offset; + size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor); + size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor); CHECK_LE(view_size, curr_size) << "Tries to create a view that has bigger memory than current one"; // increase ref count - this->data_->IncRef(); - ret.data_->manager_ctx = this->data_; - ret.data_->dl_tensor.data = this->data_->dl_tensor.data; + get_mutable()->IncRef(); + ret.get_mutable()->manager_ctx = get_mutable(); + ret.get_mutable()->dl_tensor.data = get_mutable()->dl_tensor.data; return ret; } DLManagedTensor* NDArray::ToDLPack() const { - return Internal::ToDLPack(data_); + return Internal::ToDLPack(get_mutable()); } NDArray NDArray::Empty(std::vector shape, @@ -144,9 +162,9 @@ NDArray NDArray::Empty(std::vector shape, DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content - size_t size = GetDataSize(ret.data_->dl_tensor); - size_t alignment = GetDataAlignment(ret.data_->dl_tensor); - ret.data_->dl_tensor.data = + size_t size = GetDataSize(ret.get_mutable()->dl_tensor); + size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor); + ret.get_mutable()->dl_tensor.data = DeviceAPI::Get(ret->ctx)->AllocDataSpace( ret->ctx, size, alignment, ret->dtype); return ret; @@ -154,10 +172,12 @@ NDArray NDArray::Empty(std::vector shape, NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { NDArray::Container* data = new NDArray::Container(); - data->deleter = Internal::DLPackDeleter; + // construct header + data->SetDeleter(Internal::DLPackDeleter); + // fill up content. data->manager_ctx = tensor; data->dl_tensor = tensor->dl_tensor; - return NDArray(data); + return NDArray(GetObjectPtr(data)); } void NDArray::CopyFromTo(DLTensor* from, @@ -184,17 +204,24 @@ void NDArray::CopyFromTo(DLTensor* from, } std::vector NDArray::Shape() const { - return data_->shape_; + return get_mutable()->shape_; } +TVM_REGISTER_OBJECT_TYPE(NDArray::Container); + } // namespace runtime } // namespace tvm using namespace tvm::runtime; void NDArrayDLPackDeleter(DLManagedTensor* tensor) { - static_cast(tensor->manager_ctx)->DecRef(); - delete tensor; + NDArray::Internal::NDArrayDLPackDeleter(tensor); +} + +int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { + API_BEGIN(); + *out_tindex = TVMArrayHandleToObjectHandle(handle)->type_index(); + API_END(); } int TVMArrayAlloc(const tvm_index_t* shape, @@ -213,14 +240,14 @@ int TVMArrayAlloc(const tvm_index_t* shape, DLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; - *out = NDArray::Internal::MoveAsDLTensor( + *out = NDArray::Internal::MoveToFFIHandle( NDArray::Empty(std::vector(shape, shape + ndim), dtype, ctx)); API_END(); } int TVMArrayFree(TVMArrayHandle handle) { API_BEGIN(); - reinterpret_cast(handle)->DecRef(); + NDArray::Internal::FFIDecRef(handle); API_END(); } @@ -235,14 +262,14 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out) { API_BEGIN(); - *out = NDArray::Internal::MoveAsDLTensor(NDArray::FromDLPack(from)); + *out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from)); API_END(); } int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out) { API_BEGIN(); - *out = NDArray::Internal::ToDLPack(reinterpret_cast(from)); + *out = NDArray::Internal::ToDLPack(from); API_END(); } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 1042a4f68e5e1..881788a5292c6 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -59,7 +59,8 @@ class RPCWrappedFunc { const TVMArgValue& arg); // deleter of RPC remote array - static void RemoteNDArrayDeleter(NDArray::Container* ptr) { + static void RemoteNDArrayDeleter(Object* obj) { + auto* ptr = static_cast(obj); RemoteSpace* space = static_cast(ptr->dl_tensor.data); space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx); delete space; @@ -71,12 +72,12 @@ class RPCWrappedFunc { void* nd_handle) { NDArray::Container* data = new NDArray::Container(); data->manager_ctx = nd_handle; - data->deleter = RemoteNDArrayDeleter; + data->SetDeleter(RemoteNDArrayDeleter); RemoteSpace* space = new RemoteSpace(); space->sess = sess; space->data = tensor->data; data->dl_tensor.data = space; - NDArray ret(data); + NDArray ret(GetObjectPtr(data)); // RAII now in effect data->shape_ = std::vector( tensor->shape, tensor->shape + tensor->ndim); diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 16b0e7f695298..77d39754b0953 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -787,9 +787,7 @@ class RPCSession::EventHandler : public dmlc::Stream { TVMValue ret_value_pack[2]; int ret_tcode_pack[2]; rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]); - - NDArray::Container* nd = static_cast(ret_value_pack[0].v_handle); - ret_value_pack[1].v_handle = nd; + ret_value_pack[1].v_handle = ret_value_pack[0].v_handle; ret_tcode_pack[1] = kHandle; SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); } else { @@ -1190,7 +1188,8 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { void* handle = args[0]; - static_cast(handle)->DecRef(); + static_cast( + reinterpret_cast(handle))->DecRef(); } void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index ff2bbe8eaf11d..3e6140ed38304 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -31,7 +31,8 @@ namespace tvm { namespace runtime { namespace vm { -static void BufferDeleter(NDArray::Container* ptr) { +static void BufferDeleter(Object* obj) { + auto* ptr = static_cast(obj); CHECK(ptr->manager_ctx != nullptr); Buffer* buffer = reinterpret_cast(ptr->manager_ctx); MemoryManager::Global()->GetAllocator(buffer->ctx)-> @@ -40,7 +41,8 @@ static void BufferDeleter(NDArray::Container* ptr) { delete ptr; } -void StorageObj::Deleter(NDArray::Container* ptr) { +void StorageObj::Deleter(Object* obj) { + auto* ptr = static_cast(obj); // When invoking AllocNDArray we don't own the underlying allocation // and should not delete the buffer, but instead let it be reclaimed // by the storage object's destructor. @@ -77,16 +79,23 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDa // TODO(@jroesch): generalize later to non-overlapping allocations. CHECK_EQ(offset, 0u); VerifyDataType(dtype); + + // crtical zone: allocate header, cannot throw NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx); - container->deleter = StorageObj::Deleter; + + container->SetDeleter(StorageObj::Deleter); size_t needed_size = GetDataSize(container->dl_tensor); - // TODO(@jroesch): generalize later to non-overlapping allocations. - CHECK(needed_size == this->buffer.size) - << "size mistmatch required " << needed_size << " found " << this->buffer.size; this->IncRef(); container->manager_ctx = reinterpret_cast(this); container->dl_tensor.data = this->buffer.data; - return NDArray(container); + NDArray ret(GetObjectPtr(container)); + + // RAII in effect, now run the check. + // TODO(@jroesch): generalize later to non-overlapping allocations. + CHECK(needed_size == this->buffer.size) + << "size mistmatch required " << needed_size << " found " << this->buffer.size; + + return ret; } MemoryManager* MemoryManager::Global() { @@ -108,14 +117,14 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) { NDArray Allocator::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { VerifyDataType(dtype); NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx); - container->deleter = BufferDeleter; + container->SetDeleter(BufferDeleter); size_t size = GetDataSize(container->dl_tensor); size_t alignment = GetDataAlignment(container->dl_tensor); Buffer *buffer = new Buffer; *buffer = this->Alloc(size, alignment, dtype); container->manager_ctx = reinterpret_cast(buffer); container->dl_tensor.data = buffer->data; - return NDArray(container); + return NDArray(GetObjectPtr(container)); } } // namespace vm diff --git a/src/runtime/vm/memory_manager.h b/src/runtime/vm/memory_manager.h index 78c8fb36bf703..292fb55e59950 100644 --- a/src/runtime/vm/memory_manager.h +++ b/src/runtime/vm/memory_manager.h @@ -120,7 +120,7 @@ class StorageObj : public Object { DLDataType dtype); /*! \brief The deleter for an NDArray when allocated from underlying storage. */ - static void Deleter(NDArray::Container* ptr); + static void Deleter(Object* ptr); ~StorageObj() { auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx); diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index f6f3b8f90e37d..a9b9b83ee250c 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include TEST(PackedFunc, Basic) { @@ -178,6 +179,69 @@ TEST(TypedPackedFunc, HighOrder) { CHECK_EQ(f1(3), 4); } + +TEST(PackedFunc, ObjectConversion) { + using namespace tvm; + using namespace tvm::runtime; + TVMRetValue rv; + auto x = NDArray::Empty( + {}, String2TVMType("float32"), + TVMContext{kDLCPU, 0}); + // assign null + rv = ObjectRef(); + CHECK_EQ(rv.type_code(), kNull); + + // Can assign NDArray to ret type + rv = x; + CHECK_EQ(rv.type_code(), kNDArrayContainer); + // Even if we assign base type it still shows as NDArray + rv = ObjectRef(x); + CHECK_EQ(rv.type_code(), kNDArrayContainer); + // Check convert back + CHECK(rv.operator NDArray().same_as(x)); + CHECK(rv.operator ObjectRef().same_as(x)); + CHECK(!rv.IsObjectRef()); + + auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args[0].type_code(), kNDArrayContainer); + CHECK(args[0].operator NDArray().same_as(x)); + CHECK(args[0].operator ObjectRef().same_as(x)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(args[1].operator Array().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); + pf1(x, ObjectRef()); + pf1(ObjectRef(x), NDArray()); + + // testcases for modules + auto* pf = tvm::runtime::Registry::Get("module.source_module_create"); + CHECK(pf != nullptr); + Module m = (*pf)("", "xyz"); + rv = m; + CHECK_EQ(rv.type_code(), kModuleHandle); + // Even if we assign base type it still shows as NDArray + rv = ObjectRef(m); + CHECK_EQ(rv.type_code(), kModuleHandle); + // Check convert back + CHECK(rv.operator Module().same_as(m)); + CHECK(rv.operator ObjectRef().same_as(m)); + CHECK(!rv.IsObjectRef()); + + auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args[0].type_code(), kModuleHandle); + CHECK(args[0].operator Module().same_as(m)); + CHECK(args[0].operator ObjectRef().same_as(m)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); + pf2(m, ObjectRef()); + pf2(ObjectRef(m), Module()); +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/unittest/test_lang_container.py b/tests/python/unittest/test_lang_container.py index 206e143029cfd..92edbee9072fb 100644 --- a/tests/python/unittest/test_lang_container.py +++ b/tests/python/unittest/test_lang_container.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import numpy as np def test_array(): a = tvm.convert([1,2,3]) @@ -71,6 +72,14 @@ def test_in_container(): assert tvm.make.StringImm('a') in arr assert 'd' not in arr +def test_ndarray_container(): + x = tvm.nd.array([1,2,3]) + arr = tvm.convert([x, x]) + assert arr[0].same_as(x) + assert arr[1].same_as(x) + assert isinstance(arr[0], tvm.nd.NDArray) + + if __name__ == "__main__": test_str_map() test_array() @@ -78,3 +87,4 @@ def test_in_container(): test_array_save_load_json() test_map_save_load_json() test_in_container() + test_ndarray_container() diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index b676cf2d5244b..951ea97bf252a 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -32,7 +32,7 @@ def gen_engine_header(): #include class Engine { }; - + #endif ''' header_file = header_file_dir_path.relpath("gcc_engine.h") @@ -45,7 +45,7 @@ def generate_engine_module(): #include #include #include "gcc_engine.h" - + extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5, float* gcc_input6, float* gcc_input7, float* out) { Engine engine; diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index ebbcf2106617b..143f7f2f98e56 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -35,7 +35,8 @@ rm -rf lib make cd ../.. -python3 -m pytest -v apps/extension/tests +TVM_FFI=cython python3 -m pytest -v apps/extension/tests +TVM_FFI=ctypes python3 -m pytest -v apps/extension/tests TVM_FFI=ctypes python3 -m pytest -v tests/python/integration TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib