diff --git a/apps/extension/python/tvm_ext/__init__.py b/apps/extension/python/tvm_ext/__init__.py index 7404a717f7788..31b149eb4913f 100644 --- a/apps/extension/python/tvm_ext/__init__.py +++ b/apps/extension/python/tvm_ext/__init__.py @@ -51,26 +51,23 @@ def __getitem__(self, idx): nd_create = tvm.get_global_func("tvm_ext.nd_create") nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two") -nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info") +nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info") +@tvm.register_object("tvm_ext.NDSubClass") class NDSubClass(tvm.nd.NDArrayBase): """Example for subclassing TVM's NDArray infrastructure. By inheriting TMV's NDArray, external libraries could leverage TVM's FFI without any modification. """ - # Should be consistent with the type-trait set in the backend - _array_type_code = 1 @staticmethod - def create(addtional_info): - return nd_create(addtional_info) + def create(additional_info): + return nd_create(additional_info) @property - def addtional_info(self): - return nd_get_addtional_info(self) + def additional_info(self): + return nd_get_additional_info(self) def __add__(self, other): return nd_add_two(self, other) - -tvm.register_extension(NDSubClass, NDSubClass) diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index 788c28da18d34..a68c8e8c53249 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -29,19 +29,6 @@ #include #include -namespace tvm_ext { -class NDSubClass; -} // namespace tvm_ext - -namespace tvm { -namespace runtime { -template<> -struct array_type_info { - static const int code = 1; -}; -} // namespace tvm -} // namespace runtime - using namespace tvm; using namespace tvm::runtime; @@ -65,41 +52,45 @@ class NDSubClass : public tvm::runtime::NDArray { public: class SubContainer : public NDArray::Container { public: - SubContainer(int addtional_info) : - addtional_info_(addtional_info) { - array_type_code_ = array_type_info::code; - } - static bool Is(NDArray::Container *container) { - SubContainer *c = static_cast(container); - return c->array_type_code_ == array_type_info::code; + SubContainer(int additional_info) : + additional_info_(additional_info) { + type_index_ = SubContainer::RuntimeTypeIndex(); } - int addtional_info_{0}; + int additional_info_{0}; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "tvm_ext.NDSubClass"; + TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container); }; - NDSubClass(NDArray::Container *container) { - if (container == nullptr) { - data_ = nullptr; - return; - } - CHECK(SubContainer::Is(container)); - container->IncRef(); - data_ = container; + + static void SubContainerDeleter(Object* obj) { + auto* ptr = static_cast(obj); + delete ptr; } - ~NDSubClass() { - this->reset(); + + NDSubClass() {} + explicit NDSubClass(ObjectPtr n) : NDArray(n) {} + explicit NDSubClass(int additional_info) { + SubContainer* ptr = new SubContainer(additional_info); + ptr->SetDeleter(SubContainerDeleter); + data_ = GetObjectPtr(ptr); } + NDSubClass AddWith(const NDSubClass &other) const { - SubContainer *a = static_cast(data_); - SubContainer *b = static_cast(other.data_); + SubContainer *a = static_cast(get_mutable()); + SubContainer *b = static_cast(other.get_mutable()); CHECK(a != nullptr && b != nullptr); - return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_)); + return NDSubClass(a->additional_info_ + b->additional_info_); } int get_additional_info() const { - SubContainer *self = static_cast(data_); + SubContainer *self = static_cast(get_mutable()); CHECK(self != nullptr); - return self->addtional_info_; + return self->additional_info_; } + using ContainerType = SubContainer; }; +TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer); /*! * \brief Introduce additional extension data structures @@ -166,8 +157,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev") TVM_REGISTER_GLOBAL("tvm_ext.nd_create") .set_body([](TVMArgs args, TVMRetValue *rv) { - int addtional_info = args[0]; - *rv = NDSubClass(new NDSubClass::SubContainer(addtional_info)); + int additional_info = args[0]; + *rv = NDSubClass(additional_info); + CHECK_EQ(rv->type_code(), kNDArrayContainer); + }); TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") @@ -177,7 +170,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") *rv = a.AddWith(b); }); -TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info") +TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info") .set_body([](TVMArgs args, TVMRetValue *rv) { NDSubClass a = args[0]; *rv = a.get_additional_info(); diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py index e481e82fefb37..a5e7e0f694561 100644 --- a/apps/extension/tests/test_ext.py +++ b/apps/extension/tests/test_ext.py @@ -87,16 +87,17 @@ def check_llvm(): def test_nd_subclass(): - a = tvm_ext.NDSubClass.create(addtional_info=3) - b = tvm_ext.NDSubClass.create(addtional_info=5) + a = tvm_ext.NDSubClass.create(additional_info=3) + b = tvm_ext.NDSubClass.create(additional_info=5) + assert isinstance(a, tvm_ext.NDSubClass) c = a + b d = a + a e = b + b - assert(a.addtional_info == 3) - assert(b.addtional_info == 5) - assert(c.addtional_info == 8) - assert(d.addtional_info == 6) - assert(e.addtional_info == 10) + assert(a.additional_info == 3) + assert(b.additional_info == 5) + assert(c.additional_info == 8) + assert(d.additional_info == 6) + assert(e.additional_info == 10) if __name__ == "__main__": diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 41b47d3a679e0..1a276ae695fc2 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -23,14 +23,14 @@ #ifndef TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_ +#include + #include #include #include #include #include #include -#include "node.h" -#include "memory.h" namespace tvm { diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 93b7ac33f155b..c9f7a580621f6 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -25,7 +25,6 @@ #ifndef TVM_PACKED_FUNC_EXT_H_ #define TVM_PACKED_FUNC_EXT_H_ -#include #include #include #include @@ -43,22 +42,7 @@ using runtime::TVMRetValue; using runtime::PackedFunc; namespace runtime { -/*! - * \brief Runtime type checker for node type. - * \tparam T the type to be checked. - */ -template -struct ObjectTypeChecker { - static bool Check(const Object* ptr) { - using ContainerType = typename T::ContainerType; - if (ptr == nullptr) return true; - return ptr->IsInstance(); - } - static void PrintName(std::ostream& os) { // NOLINT(*) - using ContainerType = typename T::ContainerType; - os << ContainerType::_type_key; - } -}; + template struct ObjectTypeChecker > { @@ -73,10 +57,8 @@ struct ObjectTypeChecker > { } return true; } - static void PrintName(std::ostream& os) { // NOLINT(*) - os << "List["; - ObjectTypeChecker::PrintName(os); - os << "]"; + static std::string TypeName() { + return "List[" + ObjectTypeChecker::TypeName() + "]"; } }; @@ -91,11 +73,9 @@ struct ObjectTypeChecker > { } return true; } - static void PrintName(std::ostream& os) { // NOLINT(*) - os << "Map[str"; - os << ','; - ObjectTypeChecker::PrintName(os); - os << ']'; + static std::string TypeName() { + return "Map[str, " + + ObjectTypeChecker::TypeName()+ ']'; } }; @@ -111,39 +91,16 @@ struct ObjectTypeChecker > { } return true; } - static void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "Map["; - ObjectTypeChecker::PrintName(os); - os << ','; - ObjectTypeChecker::PrintName(os); - os << ']'; + static std::string TypeName() { + return "Map[" + + ObjectTypeChecker::TypeName() + + ", " + + ObjectTypeChecker::TypeName()+ ']'; } }; -template -inline std::string ObjectTypeName() { - std::ostringstream os; - ObjectTypeChecker::PrintName(os); - return os.str(); -} - // extensions for tvm arg value - -template -inline TObjectRef TVMArgValue::AsObjectRef() const { - static_assert( - std::is_base_of::value, - "Conversion only works for ObjectRef"); - if (type_code_ == kNull) return TObjectRef(NodePtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - Object* ptr = static_cast(value_.v_handle); - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() - << " but get " << ptr->GetTypeKey(); - return TObjectRef(ObjectPtr(ptr)); -} - -inline TVMArgValue::operator tvm::Expr() const { +inline TVMPODValue_::operator tvm::Expr() const { if (type_code_ == kNull) return Expr(); if (type_code_ == kDLInt) { CHECK_LE(value_.v_int64, std::numeric_limits::max()); @@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const { return Tensor(ObjectPtr(ptr))(); } CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() + << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); return Expr(ObjectPtr(ptr)); } -inline TVMArgValue::operator tvm::Integer() const { +inline TVMPODValue_::operator tvm::Integer() const { if (type_code_ == kNull) return Integer(); if (type_code_ == kDLInt) { CHECK_LE(value_.v_int64, std::numeric_limits::max()); @@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const { TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); Object* ptr = static_cast(value_.v_handle); CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() + << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); return Integer(ObjectPtr(ptr)); } - -template -inline bool TVMPODValue_::IsObjectRef() const { - TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - Object* ptr = static_cast(value_.v_handle); - return ObjectTypeChecker::Check(ptr); -} - -// extensions for TVMRetValue -template -inline TObjectRef TVMRetValue::AsObjectRef() const { - static_assert( - std::is_base_of::value, - "Conversion only works for ObjectRef"); - if (type_code_ == kNull) return TObjectRef(); - TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - - Object* ptr = static_cast(value_.v_handle); - - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() - << " but get " << ptr->GetTypeKey(); - return TObjectRef(ObjectPtr(ptr)); -} - } // namespace runtime } // namespace tvm #endif // TVM_PACKED_FUNC_EXT_H_ diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index dbe827812fc3d..4dc07f4a3a045 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -23,6 +23,7 @@ */ #ifndef TVM_RUNTIME_CONTAINER_H_ #define TVM_RUNTIME_CONTAINER_H_ + #include #include #include diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 9932951798423..1bc49fab6fccb 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -24,11 +24,13 @@ #ifndef TVM_RUNTIME_NDARRAY_H_ #define TVM_RUNTIME_NDARRAY_H_ +#include +#include +#include + #include #include #include -#include "c_runtime_api.h" -#include "serializer.h" namespace tvm { namespace runtime { @@ -37,72 +39,23 @@ namespace runtime { * \brief Managed NDArray. * The array is backed by reference counted blocks. */ -class NDArray { +class NDArray : public ObjectRef { public: - // internal container type + /*! \brief ContainerBase used to back the TVMArrayHandle */ + class ContainerBase; + /*! \brief NDArray internal container type */ class Container; + /*! \brief Container type for Object system. */ + using ContainerType = Container; /*! \brief default constructor */ NDArray() {} /*! - * \brief cosntruct a NDArray that refers to data - * \param data The data this NDArray refers to - */ - explicit inline NDArray(Container* data); - /*! - * \brief copy constructor. - * - * It does not make a copy, but the reference count of the input NDArray is incremented - * - * \param other NDArray that shares internal data with the input NDArray. - */ - inline NDArray(const NDArray& other); // NOLINT(*) - /*! - * \brief move constructor - * \param other The value to be moved - */ - NDArray(NDArray&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! \brief destructor */ - ~NDArray() { - this->reset(); - } - /*! - * \brief Swap this array with another NDArray - * \param other The other NDArray + * \brief constructor. + * \param data ObjectPtr to the data container. */ - void swap(NDArray& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \brief copy assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NDArray& operator=(const NDArray& other) { // NOLINT(*) - // copy-and-swap idiom - NDArray(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NDArray& operator=(NDArray&& other) { // NOLINT(*) - // copy-and-swap idiom - NDArray(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! \return If NDArray is defined */ - bool defined() const { - return data_ != nullptr; - } - /*! \return If both NDArray reference the same container */ - bool same_as(const NDArray& other) const { - return data_ == other.data_; - } + explicit NDArray(ObjectPtr data) + : ObjectRef(data) {} + /*! \brief reset the content of NDArray to be nullptr */ inline void reset(); /*! @@ -191,36 +144,40 @@ class NDArray { DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); TVM_DLL std::vector Shape() const; - // internal namespace struct Internal; + protected: - /*! \brief Internal Data content */ - Container* data_{nullptr}; - // enable internal functions - friend struct Internal; friend class TVMPODValue_; - friend class TVMArgValue; friend class TVMRetValue; friend class TVMArgsSetter; -}; - -/*! - * \brief The type trait indicates subclass of TVM's NDArray. - * For irrelavant classes, code = -1. - * For TVM NDArray itself, code = 0. - * All subclasses of NDArray should override code > 0. - */ -template -struct array_type_info { - /*! \brief the value of the traits */ - static const int code = -1; -}; - -// Overrides the type trait for tvm's NDArray. -template<> -struct array_type_info { - static const int code = 0; + /*! + * \brief Get mutable internal container pointer. + * \return a mutable container pointer. + */ + inline Container* get_mutable() const; + // Helper functions for FFI handling. + /*! + * \brief Construct NDArray's Data field from array handle in FFI. + * \param handle The array handle. + * \return The constructed NDArray. + * + * \note We keep a special calling convention for NDArray by passing + * ContainerBase pointer in FFI. + * As a result, the argument is compatible to DLTensor*. + */ + inline static ObjectPtr FFIDataFromHandle(TVMArrayHandle handle); + /*! + * \brief DecRef resource managed by an FFI array handle. + * \param handle The array handle. + */ + inline static void FFIDecRef(TVMArrayHandle handle); + /*! + * \brief Get FFI Array handle from ndarray. + * \param nd The object with ndarray type. + * \return The result array handle. + */ + inline static TVMArrayHandle FFIGetHandle(const ObjectRef& nd); }; /*! @@ -231,19 +188,14 @@ struct array_type_info { inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); /*! - * \brief Reference counted Container object used to back NDArray. + * \brief The container base structure + * contains all the fields except for the Object header. * - * This object is DLTensor compatible: - * the pointer to the NDArrayContainer can be directly - * interpreted as a DLTensor* - * - * \note do not use this function directly, use NDArray. + * \note We explicitly declare this structure in order to pass + * PackedFunc argument using ContainerBase*. */ -class NDArray::Container { +class NDArray::ContainerBase { public: - // NOTE: the first part of this structure is the same as - // DLManagedTensor, note that, however, the deleter - // is only called when the reference counter goes to 0 /*! * \brief The corresponding dl_tensor field. * \note it is important that the first field is DLTensor @@ -259,42 +211,27 @@ class NDArray::Container { * (e.g. reference to original memory when creating views). */ void* manager_ctx{nullptr}; - /*! - * \brief Customized deleter - * - * \note The customized deleter is helpful to enable - * different ways of memory allocator that are not - * currently defined by the system. - */ - void (*deleter)(Container* self) = nullptr; protected: - friend class NDArray; - friend class TVMPODValue_; - friend class TVMArgValue; - friend class TVMRetValue; - friend class RPCWrappedFunc; - /*! - * \brief Type flag used to indicate subclass. - * Default value 0 means normal NDArray::Conatainer. - * - * We can extend a more specialized NDArray::Container - * and use the array_type_code_ to indicate - * the specific array subclass. - */ - int32_t array_type_code_{0}; - /*! \brief The internal reference counter */ - std::atomic ref_counter_{0}; - /*! * \brief The shape container, * can be used used for shape data. */ std::vector shape_; +}; +/*! + * \brief Object container class taht backs NDArray. + * \note do not use this function directly, use NDArray. + */ +class NDArray::Container : + public Object, + public NDArray::ContainerBase { public: /*! \brief default constructor */ Container() { + // Initialize the type index. + type_index_ = Container::RuntimeTypeIndex(); dl_tensor.data = nullptr; dl_tensor.ndim = 0; dl_tensor.shape = nullptr; @@ -306,6 +243,8 @@ class NDArray::Container { std::vector shape, DLDataType dtype, DLContext ctx) { + // Initialize the type index. + type_index_ = Container::RuntimeTypeIndex(); dl_tensor.data = data; shape_ = std::move(shape); dl_tensor.ndim = static_cast(shape_.size()); @@ -315,49 +254,38 @@ class NDArray::Container { dl_tensor.byte_offset = 0; dl_tensor.ctx = ctx; } - - /*! \brief developer function, increases reference counter */ - void IncRef() { - ref_counter_.fetch_add(1, std::memory_order_relaxed); - } - /*! \brief developer function, decrease reference counter */ - void DecRef() { - if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - if (this->deleter != nullptr) { - (*this->deleter)(this); - } - } + /*! + * \brief Set the deleter field. + * \param deleter The deleter. + */ + void SetDeleter(FDeleter deleter) { + deleter_ = deleter; } -}; -// implementations of inline functions -// the usages of functions are documented in place. -inline NDArray::NDArray(Container* data) - : data_(data) { - if (data != nullptr) { - data_->IncRef(); - } -} + // Expose DecRef and IncRef as public function + // NOTE: they are only for developer purposes only. + using Object::DecRef; + using Object::IncRef; -inline NDArray::NDArray(const NDArray& other) - : data_(other.data_) { - if (data_ != nullptr) { - data_->IncRef(); - } -} + // Information for object protocol. + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const uint32_t _type_child_slots = 0; + static constexpr const uint32_t _type_child_slots_can_overflow = true; + static constexpr const char* _type_key = "NDArray"; + TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object); -inline void NDArray::reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } -} + protected: + friend class RPCWrappedFunc; + friend class NDArray; +}; + +// implementations of inline functions +// the usages of functions are documented in place.a -/*! \brief return the size of data the DLTensor hold, in term of number of bytes +/*! + * \brief return the size of data the DLTensor hold, in term of number of bytes * * \param arr the input DLTensor - * * \return number of bytes of data in the DLTensor. */ inline size_t GetDataSize(const DLTensor& arr) { @@ -371,24 +299,24 @@ inline size_t GetDataSize(const DLTensor& arr) { inline void NDArray::CopyFrom(DLTensor* other) { CHECK(data_ != nullptr); - CopyFromTo(other, &(data_->dl_tensor)); + CopyFromTo(other, &(get_mutable()->dl_tensor)); } inline void NDArray::CopyFrom(const NDArray& other) { CHECK(data_ != nullptr); CHECK(other.data_ != nullptr); - CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor)); + CopyFromTo(&(other.get_mutable()->dl_tensor), &(get_mutable()->dl_tensor)); } inline void NDArray::CopyTo(DLTensor* other) const { CHECK(data_ != nullptr); - CopyFromTo(&(data_->dl_tensor), other); + CopyFromTo(&(get_mutable()->dl_tensor), other); } inline void NDArray::CopyTo(const NDArray& other) const { CHECK(data_ != nullptr); CHECK(other.data_ != nullptr); - CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor)); + CopyFromTo(&(get_mutable()->dl_tensor), &(other.get_mutable()->dl_tensor)); } inline NDArray NDArray::CopyTo(const DLContext& ctx) const { @@ -401,12 +329,39 @@ inline NDArray NDArray::CopyTo(const DLContext& ctx) const { } inline int NDArray::use_count() const { - if (data_ == nullptr) return 0; - return data_->ref_counter_.load(std::memory_order_relaxed); + return data_.use_count(); } inline const DLTensor* NDArray::operator->() const { - return &(data_->dl_tensor); + return &(get_mutable()->dl_tensor); +} + +inline NDArray::Container* NDArray::get_mutable() const { + return static_cast(data_.get()); +} + +inline ObjectPtr NDArray::FFIDataFromHandle(TVMArrayHandle handle) { + return GetObjectPtr(static_cast( + reinterpret_cast(handle))); +} + +inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { + // NOTE: it is necessary to cast to container then to base + // so that the FFI handle uses the ContainerBase address. + return reinterpret_cast( + static_cast( + static_cast( + const_cast(nd.get())))); +} + +inline void NDArray::FFIDecRef(TVMArrayHandle handle) { + static_cast( + reinterpret_cast(handle))->DecRef(); +} + +inline Object* TVMArrayHandleToObjectHandle(TVMArrayHandle handle) { + return static_cast( + reinterpret_cast(handle)); } /*! \brief Magic number for NDArray file */ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 20e6b5a0fb632..96215daf4a7ac 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -24,10 +24,11 @@ #define TVM_RUNTIME_OBJECT_H_ #include +#include #include #include #include -#include "c_runtime_api.h" + /*! * \brief Whether or not use atomic reference counter. @@ -580,6 +581,14 @@ class ObjectRef { static T DowncastNoCheck(ObjectRef ref) { return T(std::move(ref.data_)); } + /*! + * \brief Clear the object ref data field without DecRef + * after we successfully moved the field. + * \param ref The reference data. + */ + static void FFIClearAfterMove(ObjectRef* ref) { + ref->data_.data_ = nullptr; + } /*! * \brief Internal helper function get data_ as ObjectPtr of ObjectType. * \note only used for internal dev purpose. @@ -648,7 +657,7 @@ struct ObjectEqual { return _GetOrAllocRuntimeTypeIndex(); \ } \ static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ - static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \ + static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \ TypeName::_type_key, \ TypeName::_type_index, \ ParentType::_GetOrAllocRuntimeTypeIndex(), \ @@ -668,6 +677,19 @@ struct ObjectEqual { TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ +/*! \brief helper macro to supress unused warning */ +#if defined(__GNUC__) +#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define TVM_ATTRIBUTE_UNUSED +#endif + +#define TVM_STR_CONCAT_(__x, __y) __x##__y +#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) + +#define TVM_OBJECT_REG_VAR_DEF \ + static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid + /*! * \brief Helper macro to register the object type to runtime. * Makes sure that the runtime type table is correctly populated. @@ -675,7 +697,7 @@ struct ObjectEqual { * Use this macro in the cc file for each terminal class. */ #define TVM_REGISTER_OBJECT_TYPE(TypeName) \ - static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \ + TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TypeName::_GetOrAllocRuntimeTypeIndex() @@ -691,14 +713,14 @@ struct ObjectEqual { using ContainerType = ObjectName; #define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \ - TypeName() {} \ - explicit TypeName( \ - ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ - : ParentType(n) {} \ - ObjectName* operator->() { \ - return static_cast(data_.get()); \ - } \ - operator bool() const { return data_ != nullptr; } \ + TypeName() {} \ + explicit TypeName( \ + ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ + : ParentType(n) {} \ + ObjectName* operator->() { \ + return static_cast(data_.get()); \ + } \ + operator bool() const { return data_ != nullptr; } \ using ContainerType = ObjectName; // Implementations details below diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 27dcb4130b4c2..5650db6f909cd 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -387,20 +387,22 @@ inline std::string TVMType2String(TVMType t); #define TVM_CHECK_TYPE_CODE(CODE, T) \ CHECK_EQ(CODE, T) << " expected " \ << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ + /*! - * \brief Type traits to mark if a class is tvm extension type. - * - * To enable extension type in C++ must be registered via marco. - * TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits. - * - * Extension class can be passed and returned via PackedFunc in all tvm runtime. - * Internally extension class is stored as T*. - * - * \tparam T the typename + * \brief Type traits for runtime type check during FFI conversion. + * \tparam T the type to be checked. */ template -struct extension_type_info { - static const int code = 0; +struct ObjectTypeChecker { + static bool Check(const Object* ptr) { + using ContainerType = typename T::ContainerType; + if (ptr == nullptr) return true; + return ptr->IsInstance(); + } + static std::string TypeName() { + using ContainerType = typename T::ContainerType; + return ContainerType::_type_key; + } }; /*! @@ -449,24 +451,17 @@ class TVMPODValue_ { return static_cast(value_.v_handle); } else { if (type_code_ == kNull) return nullptr; - LOG(FATAL) << "Expected " + LOG(FATAL) << "Expect " << "DLTensor* or NDArray but get " << TypeCode2Str(type_code_); return nullptr; } } operator NDArray() const { - if (type_code_ == kNull) return NDArray(); + if (type_code_ == kNull) return NDArray(ObjectPtr(nullptr)); TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer); - return NDArray(static_cast(value_.v_handle)); - } - operator ObjectRef() const { - if (type_code_ == kNull) { - return ObjectRef(ObjectPtr(nullptr)); - } - TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - return ObjectRef( - ObjectPtr(static_cast(value_.v_handle))); + return NDArray(NDArray::FFIDataFromHandle( + static_cast(value_.v_handle))); } operator Module() const { if (type_code_ == kNull) { @@ -480,23 +475,9 @@ class TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); return value_.v_ctx; } - template::value>::type> - TNDArray AsNDArray() const { - if (type_code_ == kNull) return TNDArray(nullptr); - auto *container = static_cast(value_.v_handle); - CHECK_EQ(container->array_type_code_, array_type_info::code); - return TNDArray(container); - } - template::value>::type> - inline bool IsObjectRef() const; int type_code() const { return type_code_; } - /*! * \brief return handle as specific pointer type. * \tparam T the data type. @@ -506,6 +487,16 @@ class TVMPODValue_ { T* ptr() const { return static_cast(value_.v_handle); } + // ObjectRef handling + template::value>::type> + inline bool IsObjectRef() const; + template + inline TObjectRef AsObjectRef() const; + // ObjectRef Specializations + inline operator tvm::Expr() const; + inline operator tvm::Integer() const; protected: friend class TVMArgsSetter; @@ -548,9 +539,11 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; - using TVMPODValue_::operator ObjectRef; using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; + using TVMPODValue_::AsObjectRef; + using TVMPODValue_::operator tvm::Expr; + using TVMPODValue_::operator tvm::Integer; // conversion operator. operator std::string() const { @@ -577,6 +570,9 @@ class TVMArgValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMType); return value_.v_type; } + operator DataType() const { + return DataType(operator DLDataType()); + } operator PackedFunc() const { if (type_code_ == kNull) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); @@ -589,16 +585,10 @@ class TVMArgValue : public TVMPODValue_ { const TVMValue& value() const { return value_; } - // Deferred extension handler. - template - inline TObjectRef AsObjectRef() const; template::value>::type> + std::is_class::value>::type> inline operator T() const; - inline operator DataType() const; - inline operator tvm::Expr() const; - inline operator tvm::Integer() const; }; /*! @@ -636,9 +626,11 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; - using TVMPODValue_::operator ObjectRef; using TVMPODValue_::operator Module; using TVMPODValue_::IsObjectRef; + using TVMPODValue_::AsObjectRef; + using TVMPODValue_::operator tvm::Expr; + using TVMPODValue_::operator tvm::Integer; TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); @@ -660,6 +652,9 @@ class TVMRetValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMType); return value_.v_type; } + operator DataType() const { + return DataType(operator DLDataType()); + } operator PackedFunc() const { if (type_code_ == kNull) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); @@ -712,6 +707,9 @@ class TVMRetValue : public TVMPODValue_ { value_.v_type = t; return *this; } + TVMRetValue& operator=(const DataType& other) { + return operator=(other.operator DLDataType()); + } TVMRetValue& operator=(bool value) { this->SwitchToPOD(kDLInt); value_.v_int64 = value; @@ -726,24 +724,20 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(NDArray other) { - this->Clear(); - type_code_ = kNDArrayContainer; - value_.v_handle = other.data_; - other.data_ = nullptr; + if (other.data_ != nullptr) { + this->Clear(); + type_code_ = kNDArrayContainer; + value_.v_handle = NDArray::FFIGetHandle(other); + ObjectRef::FFIClearAfterMove(&other); + } else { + SwitchToPOD(kNull); + } return *this; } - TVMRetValue& operator=(ObjectRef other) { - return operator=(std::move(other.data_)); - } TVMRetValue& operator=(Module m) { SwitchToObject(kModuleHandle, std::move(m.data_)); return *this; } - template - TVMRetValue& operator=(ObjectPtr other) { - SwitchToObject(kObjectHandle, std::move(other)); - return *this; - } TVMRetValue& operator=(PackedFunc f) { this->SwitchToClass(kFuncHandle, f); return *this; @@ -760,14 +754,6 @@ class TVMRetValue : public TVMPODValue_ { this->Assign(other); return *this; } - template::code != 0>::type> - TVMRetValue& operator=(const T& other) { - this->SwitchToClass( - extension_type_info::code, other); - return *this; - } /*! * \brief Move the value back to front-end via C API. * This marks the current container as null. @@ -793,16 +779,15 @@ class TVMRetValue : public TVMPODValue_ { type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; return value_; } - // ObjectRef related extenstions: in tvm/packed_func_ext.h + // ObjectRef handling + template::value>::type> + inline TVMRetValue& operator=(TObjectRef other); template::value>::type> inline operator T() const; - template - inline TObjectRef AsObjectRef() const; - // type related - inline operator DataType() const; - inline TVMRetValue& operator=(const DataType& other); private: template @@ -829,7 +814,10 @@ class TVMRetValue : public TVMPODValue_ { break; } case kObjectHandle: { - *this = other.operator ObjectRef(); + // Avoid operator ObjectRef as we already know it is not NDArray/Module + SwitchToObject( + kObjectHandle, GetObjectPtr( + static_cast(other.value_.v_handle))); break; } default: { @@ -873,7 +861,7 @@ class TVMRetValue : public TVMPODValue_ { case kStr: delete ptr(); break; case kFuncHandle: delete ptr(); break; case kNDArrayContainer: { - static_cast(value_.v_handle)->DecRef(); + NDArray::FFIDecRef(static_cast(value_.v_handle)); break; } case kModuleHandle: { @@ -905,7 +893,7 @@ inline const char* TypeCode2Str(int type_code) { case kFuncHandle: return "FunctionHandle"; case kModuleHandle: return "ModuleHandle"; case kNDArrayContainer: return "NDArrayContainer"; - case kObjectHandle: return "ObjectCell"; + case kObjectHandle: return "Object"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; } @@ -929,6 +917,10 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) return os; } +inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) + return os << dtype.operator DLDataType(); +} + #endif inline std::string TVMType2String(TVMType t) { @@ -996,10 +988,6 @@ inline TVMType String2TVMType(std::string s) { return t; } -inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) - return os << dtype.operator DLDataType(); -} - inline TVMArgValue TVMArgs::operator[](int i) const { CHECK_LT(i, num_args) << "not enough argument passed, " @@ -1092,50 +1080,31 @@ class TVMArgsSetter { values_[i].v_type = value; type_codes_[i] = kTVMType; } + void operator()(size_t i, DataType dtype) const { + operator()(i, dtype.operator DLDataType()); + } void operator()(size_t i, const char* value) const { values_[i].v_str = value; type_codes_[i] = kStr; } - // setters for container type - // They must be reference(instead of const ref) - // to make sure they are alive in the tuple(instead of getting converted) - void operator()(size_t i, const std::string& value) const { // NOLINT(*) + // setters for container types + void operator()(size_t i, const std::string& value) const { values_[i].v_str = value.c_str(); type_codes_[i] = kStr; } - void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*) + void operator()(size_t i, const TVMByteArray& value) const { values_[i].v_handle = const_cast(&value); type_codes_[i] = kBytes; } - void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*) + void operator()(size_t i, const PackedFunc& value) const { values_[i].v_handle = const_cast(&value); type_codes_[i] = kFuncHandle; } template - void operator()(size_t i, const TypedPackedFunc& value) const { // NOLINT(*) + void operator()(size_t i, const TypedPackedFunc& value) const { operator()(i, value.packed()); } - void operator()(size_t i, const Module& value) const { // NOLINT(*) - if (value.defined()) { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kModuleHandle; - } else { - type_codes_[i] = kNull; - } - } - void operator()(size_t i, const NDArray& value) const { // NOLINT(*) - values_[i].v_handle = value.data_; - type_codes_[i] = kNDArrayContainer; - } - void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) - if (value.defined()) { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kObjectHandle; - } else { - type_codes_[i] = kNull; - } - } - void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) + void operator()(size_t i, const TVMRetValue& value) const { if (value.type_code() == kStr) { values_[i].v_str = value.ptr()->c_str(); type_codes_[i] = kStr; @@ -1145,12 +1114,11 @@ class TVMArgsSetter { type_codes_[i] = value.type_code(); } } - // extension - template::code != 0>::type> - inline void operator()(size_t i, const T& value) const; - inline void operator()(size_t i, const DataType& t) const; + std::is_base_of::value>::type> + inline void operator()(size_t i, const TObjectRef& value) const; private: /*! \brief The values fields */ @@ -1262,57 +1230,131 @@ inline R TypedPackedFunc::operator()(Args... args) const { ::run(packed_, std::forward(args)...); } -// extension and node type handling -namespace detail { -template -struct TVMValueCast { - static T Apply(const TSrc* self) { - static_assert(!is_nd, "The default case accepts only non-extensions"); - return self->template AsObjectRef(); - } -}; - -template -struct TVMValueCast { - static T Apply(const TSrc* self) { - return self->template AsNDArray(); +// ObjectRef related conversion handling +// Object can have three possible type codes: +// kNDArrayContainer, kModuleHandle, kObjectHandle +// +// We use type traits to eliminate un-necessary checks. +template +inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const { + if (value.defined()) { + Object* ptr = value.data_.data_; + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + values_[i].v_handle = NDArray::FFIGetHandle(value); + type_codes_[i] = kNDArrayContainer; + } else if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + values_[i].v_handle = ptr; + type_codes_[i] = kModuleHandle; + } else { + values_[i].v_handle = ptr; + type_codes_[i] = kObjectHandle; + } + } else { + type_codes_[i] = kNull; } -}; - -} // namespace detail - -template -inline TVMArgValue::operator T() const { - return detail:: - TVMValueCast::code > 0)> - ::Apply(this); } -template -inline TVMRetValue::operator T() const { - return detail:: - TVMValueCast::code > 0)> - ::Apply(this); +template +inline bool TVMPODValue_::IsObjectRef() const { + using ContainerType = typename TObjectRef::ContainerType; + // NOTE: the following code can be optimized by constant folding. + if (std::is_base_of::value) { + return type_code_ == kNDArrayContainer && + TVMArrayHandleToObjectHandle( + static_cast(value_.v_handle))->IsInstance(); + } + if (std::is_base_of::value) { + return type_code_ == kModuleHandle && + static_cast(value_.v_handle)->IsInstance(); + } + return + (std::is_base_of::value && type_code_ == kNDArrayContainer) || + (std::is_base_of::value && type_code_ == kModuleHandle) || + (type_code_ == kObjectHandle && + ObjectTypeChecker::Check(static_cast(value_.v_handle))); } -// PackedFunc support -inline TVMRetValue& TVMRetValue::operator=(const DataType& t) { - return this->operator=(t.operator DLDataType()); +template +inline TObjectRef TVMPODValue_::AsObjectRef() const { + static_assert( + std::is_base_of::value, + "Conversion only works for ObjectRef"); + using ContainerType = typename TObjectRef::ContainerType; + if (type_code_ == kNull) return TObjectRef(ObjectPtr(nullptr)); + // NOTE: the following code can be optimized by constant folding. + if (std::is_base_of::value) { + // Casting to a sub-class of NDArray + TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer); + ObjectPtr data = NDArray::FFIDataFromHandle( + static_cast(value_.v_handle)); + CHECK(data->IsInstance()) + << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); + return TObjectRef(data); + } + if (std::is_base_of::value) { + // Casting to a sub-class of Module + TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); + ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); + CHECK(data->IsInstance()) + << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); + return TObjectRef(data); + } + if (type_code_ == kObjectHandle) { + // normal object type check. + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expect " << ObjectTypeChecker::TypeName() + << " but get " << ptr->GetTypeKey(); + return TObjectRef(GetObjectPtr(ptr)); + } else if (std::is_base_of::value && + type_code_ == kNDArrayContainer) { + // Casting to a base class that NDArray can sub-class + ObjectPtr data = NDArray::FFIDataFromHandle( + static_cast(value_.v_handle)); + return TObjectRef(data); + } else if (std::is_base_of::value && + type_code_ == kModuleHandle) { + // Casting to a base class that Module can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } else { + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); + } } -inline TVMRetValue::operator DataType() const { - return DataType(operator DLDataType()); +template +inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { + const Object* ptr = other.get(); + if (ptr != nullptr) { + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + return operator=(NDArray(std::move(other.data_))); + } + if (std::is_base_of::value || + (std::is_base_of::value && + ptr->IsInstance())) { + return operator=(Module(std::move(other.data_))); + } + SwitchToObject(kObjectHandle, std::move(other.data_)); + } else { + SwitchToPOD(kNull); + } + return *this; } -inline TVMArgValue::operator DataType() const { - return DataType(operator DLDataType()); +template +inline TVMArgValue::operator T() const { + return AsObjectRef(); } -inline void TVMArgsSetter::operator()( - size_t i, const DataType& t) const { - this->operator()(i, t.operator DLDataType()); +template +inline TVMRetValue::operator T() const { + return AsObjectRef(); } inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 3500a7e4e3982..e51b806ea81ff 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -43,9 +43,9 @@ #ifndef TVM_RUNTIME_REGISTRY_H_ #define TVM_RUNTIME_REGISTRY_H_ +#include #include #include -#include "packed_func.h" namespace tvm { namespace runtime { @@ -283,22 +283,9 @@ class Registry { friend struct Manager; }; -/*! \brief helper macro to supress unused warning */ -#if defined(__GNUC__) -#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) -#else -#define TVM_ATTRIBUTE_UNUSED -#endif - -#define TVM_STR_CONCAT_(__x, __y) __x##__y -#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) - #define TVM_FUNC_REG_VAR_DEF \ static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM -#define TVM_TYPE_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::ExtTypeVTable* __mk_ ## TVMT - /*! * \brief Register a function globally. * \code diff --git a/python/setup.py b/python/setup.py index f8b580cd75529..bc53060f95cf6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -96,6 +96,7 @@ def config_cython(): "../3rdparty/dmlc-core/include", "../3rdparty/dlpack/include", ], + extra_compile_args=["-std=c++11"], library_dirs=library_dirs, libraries=libraries, language="c++")) diff --git a/python/tvm/_ffi/_ctypes/ndarray.py b/python/tvm/_ffi/_ctypes/ndarray.py index af59de6eee1d7..c572947c8d19a 100644 --- a/python/tvm/_ffi/_ctypes/ndarray.py +++ b/python/tvm/_ffi/_ctypes/ndarray.py @@ -20,7 +20,7 @@ import ctypes from ..base import _LIB, check_call, c_str -from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle +from ..runtime_ctypes import TVMArrayHandle from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle @@ -110,12 +110,17 @@ def to_dlpack(self): def _make_array(handle, is_view, is_container): global _TVM_ND_CLS handle = ctypes.cast(handle, TVMArrayHandle) - fcreate = _CLASS_NDARRAY - if is_container and _TVM_ND_CLS: - array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value - if array_type_info > 0: - fcreate = _TVM_ND_CLS[array_type_info] - return fcreate(handle, is_view) + if is_container: + tindex = ctypes.c_uint() + check_call(_LIB.TVMArrayGetTypeIndex(handle, ctypes.byref(tindex))) + cls = _TVM_ND_CLS.get(tindex.value, _CLASS_NDARRAY) + else: + cls = _CLASS_NDARRAY + + ret = cls.__new__(cls) + ret.handle = handle + ret.is_view = is_view + return ret _TVM_COMPATS = () @@ -129,9 +134,9 @@ def _reg_extension(cls, fcreate): _TVM_ND_CLS = {} -def _reg_ndarray(cls, fcreate): +def _register_ndarray(index, cls): global _TVM_ND_CLS - _TVM_ND_CLS[cls._array_type_code] = fcreate + _TVM_ND_CLS[index] = cls _CLASS_NDARRAY = None diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index c3ae56822198d..b8b8aefea131c 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -21,7 +21,7 @@ import ctypes from ..base import _LIB, check_call from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func -from ..node_generic import _set_class_node_base +from .ndarray import _register_ndarray, NDArrayBase ObjectHandle = ctypes.c_void_p @@ -39,6 +39,9 @@ def _set_class_node(node_class): def _register_object(index, cls): """register object class""" + if issubclass(cls, NDArrayBase): + _register_ndarray(index, cls) + return OBJECT_TYPE[index] = cls @@ -91,6 +94,3 @@ def __init_handle_by_constructor__(self, fconstructor, *args): if not isinstance(handle, ObjectHandle): handle = ObjectHandle(handle) self.handle = handle - - -_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 4b7b2c88ffa50..7ccb6279fed00 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -19,7 +19,7 @@ from ..base import get_last_ffi_error from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule -from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t +from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t import ctypes cdef enum TVMTypeCode: @@ -78,14 +78,11 @@ ctypedef void* TVMRetValueHandle ctypedef void* TVMFunctionHandle ctypedef void* ObjectHandle +ctypedef struct TVMObject: + uint32_t type_index_ + int32_t ref_counter_ + void (*deleter_)(TVMObject* self) -ctypedef struct TVMNDArrayContainer: - DLTensor dl_tensor - void* manager_ctx - void (*deleter)(DLManagedTensor* self) - int32_t array_type_info - -ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle ctypedef int (*TVMPackedCFunc)( TVMValue* args, diff --git a/python/tvm/_ffi/_cython/ndarray.pxi b/python/tvm/_ffi/_cython/ndarray.pxi index 5682ae619a46e..9fd3aa43841f0 100644 --- a/python/tvm/_ffi/_cython/ndarray.pxi +++ b/python/tvm/_ffi/_cython/ndarray.pxi @@ -100,17 +100,34 @@ cdef class NDArrayBase: return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) +# Import limited object-related function from C++ side to improve the speed +# NOTE: can only use POD-C compatible object in FFI. +cdef extern from "tvm/runtime/ndarray.h" namespace "tvm::runtime": + cdef void* TVMArrayHandleToObjectHandle(DLTensorHandle handle) + + cdef c_make_array(void* chandle, is_view, is_container): global _TVM_ND_CLS - cdef int32_t array_type_info - fcreate = _CLASS_NDARRAY - if is_container and len(_TVM_ND_CLS) > 0: - array_type_info = (chandle).array_type_info - if array_type_info > 0: - fcreate = _TVM_ND_CLS[array_type_info] - ret = fcreate(None, is_view) - (ret).chandle = chandle - return ret + + if is_container: + tindex = ( + TVMArrayHandleToObjectHandle(chandle)).type_index_ + if tindex < len(_TVM_ND_CLS): + cls = _TVM_ND_CLS[tindex] + if cls is not None: + ret = cls.__new__(cls) + else: + ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) + else: + ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) + (ret).chandle = chandle + (ret).c_is_view = is_view + return ret + else: + ret = _CLASS_NDARRAY.__new__(_CLASS_NDARRAY) + (ret).chandle = chandle + (ret).c_is_view = is_view + return ret cdef _TVM_COMPATS = () @@ -123,11 +140,16 @@ def _reg_extension(cls, fcreate): if fcreate: _TVM_EXT_RET[cls._tvm_tcode] = fcreate -cdef _TVM_ND_CLS = {} +cdef list _TVM_ND_CLS = [] -def _reg_ndarray(cls, fcreate): +cdef _register_ndarray(int index, object cls): + """register object class""" global _TVM_ND_CLS - _TVM_ND_CLS[cls._array_type_code] = fcreate + while len(_TVM_ND_CLS) <= index: + _TVM_ND_CLS.append(None) + + _TVM_ND_CLS[index] = cls + def _make_array(handle, is_view, is_container): cdef unsigned long long ptr diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 9561eab94ea2f..6d20723fd1881 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -16,12 +16,15 @@ # under the License. """Maps object type to its constructor""" -from ..node_generic import _set_class_node_base - -OBJECT_TYPE = [] +cdef list OBJECT_TYPE = [] def _register_object(int index, object cls): """register object class""" + if issubclass(cls, NDArrayBase): + _register_ndarray(index, cls) + return + + global OBJECT_TYPE while len(OBJECT_TYPE) <= index: OBJECT_TYPE.append(None) OBJECT_TYPE[index] = cls @@ -31,14 +34,13 @@ cdef inline object make_ret_object(void* chandle): global OBJECT_TYPE global _CLASS_NODE cdef unsigned tindex - cdef list object_type cdef object cls cdef object handle object_type = OBJECT_TYPE handle = ctypes_handle(chandle) CALL(TVMObjectGetTypeIndex(chandle, &tindex)) - if tindex < len(object_type): - cls = object_type[tindex] + if tindex < len(OBJECT_TYPE): + cls = OBJECT_TYPE[tindex] if cls is not None: obj = cls.__new__(cls) else: @@ -99,6 +101,3 @@ cdef class ObjectBase: (fconstructor).chandle, kObjectHandle, args, &chandle) self.chandle = chandle - - -_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index ed2f7e1f62d60..23d95ebbf66bb 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -22,6 +22,7 @@ import sys import ctypes from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE +from .node_generic import _set_class_objects IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError @@ -32,15 +33,21 @@ if sys.version_info >= (3, 0): from ._cy3.core import _set_class_function, _set_class_module from ._cy3.core import FunctionBase as _FunctionBase + from ._cy3.core import NDArrayBase as _NDArrayBase + from ._cy3.core import ObjectBase as _ObjectBase from ._cy3.core import convert_to_tvm_func else: from ._cy2.core import _set_class_function, _set_class_module from ._cy2.core import FunctionBase as _FunctionBase + from ._cy2.core import NDArrayBase as _NDArrayBase + from ._cy2.core import ObjectBase as _ObjectBase from ._cy2.core import convert_to_tvm_func except IMPORT_EXCEPT: # pylint: disable=wrong-import-position from ._ctypes.function import _set_class_function, _set_class_module from ._ctypes.function import FunctionBase as _FunctionBase + from ._ctypes.ndarray import NDArrayBase as _NDArrayBase + from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.function import convert_to_tvm_func FunctionHandle = ctypes.c_void_p @@ -325,3 +332,4 @@ def _init_api_prefix(module_name, prefix): setattr(target_module, ff.__name__, ff) _set_class_function(Function) +_set_class_objects((_ObjectBase, _NDArrayBase, ModuleBase)) diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index 1773d916722b5..650f01dd5409b 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -35,16 +35,16 @@ if sys.version_info >= (3, 0): from ._cy3.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy3.core import NDArrayBase as _NDArrayBase - from ._cy3.core import _reg_extension, _reg_ndarray + from ._cy3.core import _reg_extension else: from ._cy2.core import _set_class_ndarray, _make_array, _from_dlpack from ._cy2.core import NDArrayBase as _NDArrayBase - from ._cy2.core import _reg_extension, _reg_ndarray + from ._cy2.core import _reg_extension except IMPORT_EXCEPT: # pylint: disable=wrong-import-position from ._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack from ._ctypes.ndarray import NDArrayBase as _NDArrayBase - from ._ctypes.ndarray import _reg_extension, _reg_ndarray + from ._ctypes.ndarray import _reg_extension def context(dev_type, dev_id=0): @@ -348,13 +348,8 @@ def __init__(self): def _tvm_handle(self): return self.handle.value """ - if issubclass(cls, _NDArrayBase): - assert fcreate is not None - assert hasattr(cls, "_array_type_code") - _reg_ndarray(cls, fcreate) - else: - assert hasattr(cls, "_tvm_tcode") - if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: - raise ValueError("Cannot register create when extension tcode is same as buildin") - _reg_extension(cls, fcreate) + assert hasattr(cls, "_tvm_tcode") + if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: + raise ValueError("Cannot register create when extension tcode is same as buildin") + _reg_extension(cls, fcreate) return cls diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/node_generic.py index e89812685eb22..8ee7fc5f2b5bc 100644 --- a/python/tvm/_ffi/node_generic.py +++ b/python/tvm/_ffi/node_generic.py @@ -23,11 +23,11 @@ from .base import string_types # Node base class -_CLASS_NODE_BASE = None +_CLASS_OBJECTS = None -def _set_class_node_base(cls): - global _CLASS_NODE_BASE - _CLASS_NODE_BASE = cls +def _set_class_objects(cls): + global _CLASS_OBJECTS + _CLASS_OBJECTS = cls def _scalar_type_inference(value): @@ -67,7 +67,7 @@ def convert_to_node(value): node : Node The corresponding node value. """ - if isinstance(value, _CLASS_NODE_BASE): + if isinstance(value, _CLASS_OBJECTS): return value if isinstance(value, bool): return const(value, 'uint1x1') @@ -81,7 +81,7 @@ def convert_to_node(value): if isinstance(value, dict): vlist = [] for item in value.items(): - if (not isinstance(item[0], _CLASS_NODE_BASE) and + if (not isinstance(item[0], _CLASS_OBJECTS) and not isinstance(item[0], string_types)): raise ValueError("key of map must already been a container type") vlist.append(item[0]) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 2dbb67dfbf739..a7947dbc38a2f 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -271,12 +271,3 @@ class TVMArray(ctypes.Structure): ("byte_offset", ctypes.c_uint64)] TVMArrayHandle = ctypes.POINTER(TVMArray) - -class TVMNDArrayContainer(ctypes.Structure): - """TVM NDArray::Container""" - _fields_ = [("dl_tensor", TVMArray), - ("manager_ctx", ctypes.c_void_p), - ("deleter", ctypes.c_void_p), - ("array_type_info", ctypes.c_int32)] - -TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer) diff --git a/python/tvm/ndarray.py b/python/tvm/ndarray.py index 2a7a532e660eb..b7fe780f46298 100644 --- a/python/tvm/ndarray.py +++ b/python/tvm/ndarray.py @@ -27,7 +27,10 @@ from ._ffi.ndarray import context, empty, from_dlpack from ._ffi.ndarray import _set_class_ndarray from ._ffi.ndarray import register_extension +from ._ffi.object import register_object + +@register_object class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 9cb797fa45e44..8a74fe5cdb7df 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -67,7 +67,7 @@ TVM_REGISTER_API("_Array") } auto node = make_node(); node->data = std::move(data); - *ret = runtime::ObjectRef(node); + *ret = Array(node); }); TVM_REGISTER_API("_ArrayGetItem") @@ -100,28 +100,28 @@ TVM_REGISTER_API("_Map") for (int i = 0; i < args.num_args; i += 2) { CHECK(args[i].type_code() == kStr) << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kObjectHandle) + CHECK(args[i + 1].IsObjectRef()) << "value of the map to be NodeRef"; data.emplace(std::make_pair(args[i].operator std::string(), args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); - *ret = node; + *ret = Map(node); } else { // Container node. MapNode::ContainerType data; for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kObjectHandle) - << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kObjectHandle) + CHECK(args[i].IsObjectRef()) + << "key of str map need to be object"; + CHECK(args[i + 1].IsObjectRef()) << "value of map to be NodeRef"; data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); - *ret = node; + *ret = Map(node); } }); @@ -191,7 +191,7 @@ TVM_REGISTER_API("_MapItems") rkvs->data.push_back(kv.first); rkvs->data.push_back(kv.second); } - *ret = rkvs; + *ret = Array(rkvs); } else { auto* n = static_cast(ptr); auto rkvs = make_node(); @@ -199,7 +199,7 @@ TVM_REGISTER_API("_MapItems") rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(kv.second); } - *ret = rkvs; + *ret = Array(rkvs); } }); diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 9d2d53e03eb85..99bb994fde88b 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -27,8 +27,12 @@ #include #include "runtime_base.h" +extern "C" { // deleter for arrays used by DLPack exporter -extern "C" void NDArrayDLPackDeleter(DLManagedTensor* tensor); +void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor); +// helper function to get NDArray's type index, only used by ctypes. +TVM_DLL int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex); +} namespace tvm { namespace runtime { @@ -53,8 +57,8 @@ inline size_t GetDataAlignment(const DLTensor& arr) { struct NDArray::Internal { // Default deleter for the container - static void DefaultDeleter(NDArray::Container* ptr) { - using tvm::runtime::NDArray; + static void DefaultDeleter(Object* ptr_obj) { + auto* ptr = static_cast(ptr_obj); if (ptr->manager_ctx != nullptr) { static_cast(ptr->manager_ctx)->DecRef(); } else if (ptr->dl_tensor.data != nullptr) { @@ -68,7 +72,8 @@ struct NDArray::Internal { // that are not allocated inside of TVM. // This enables us to create NDArray from memory allocated by other // frameworks that are DLPack compatible - static void DLPackDeleter(NDArray::Container* ptr) { + static void DLPackDeleter(Object* ptr_obj) { + auto* ptr = static_cast(ptr_obj); DLManagedTensor* tensor = static_cast(ptr->manager_ctx); if (tensor->deleter != nullptr) { (*tensor->deleter)(tensor); @@ -81,12 +86,13 @@ struct NDArray::Internal { DLDataType dtype, DLContext ctx) { VerifyDataType(dtype); - // critical zone + + // critical zone: construct header NDArray::Container* data = new NDArray::Container(); - data->deleter = DefaultDeleter; - NDArray ret(data); - ret.data_ = data; + data->SetDeleter(DefaultDeleter); + // RAII now in effect + NDArray ret(GetObjectPtr(data)); // setup shape data->shape_ = std::move(shape); data->dl_tensor.shape = dmlc::BeginPtr(data->shape_); @@ -98,13 +104,21 @@ struct NDArray::Internal { return ret; } // Implementation of API function - static DLTensor* MoveAsDLTensor(NDArray arr) { - DLTensor* tensor = const_cast(arr.operator->()); - CHECK(reinterpret_cast(arr.data_) == tensor); - arr.data_ = nullptr; - return tensor; + static DLTensor* MoveToFFIHandle(NDArray arr) { + DLTensor* handle = NDArray::FFIGetHandle(arr); + ObjectRef::FFIClearAfterMove(&arr); + return handle; + } + static void FFIDecRef(TVMArrayHandle tensor) { + NDArray::FFIDecRef(tensor); } // Container to DLManagedTensor + static DLManagedTensor* ToDLPack(TVMArrayHandle handle) { + auto* from = static_cast( + reinterpret_cast(handle)); + return ToDLPack(from); + } + static DLManagedTensor* ToDLPack(NDArray::Container* from) { CHECK(from != nullptr); DLManagedTensor* ret = new DLManagedTensor(); @@ -114,29 +128,33 @@ struct NDArray::Internal { ret->deleter = NDArrayDLPackDeleter; return ret; } + // Delete dlpack object. + static void NDArrayDLPackDeleter(DLManagedTensor* tensor) { + static_cast(tensor->manager_ctx)->DecRef(); + delete tensor; + } }; -NDArray NDArray::CreateView(std::vector shape, - DLDataType dtype) { +NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { CHECK(data_ != nullptr); - CHECK(data_->dl_tensor.strides == nullptr) + CHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor"; - NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx); - ret.data_->dl_tensor.byte_offset = - this->data_->dl_tensor.byte_offset; - size_t curr_size = GetDataSize(this->data_->dl_tensor); - size_t view_size = GetDataSize(ret.data_->dl_tensor); + NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx); + ret.get_mutable()->dl_tensor.byte_offset = + this->get_mutable()->dl_tensor.byte_offset; + size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor); + size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor); CHECK_LE(view_size, curr_size) << "Tries to create a view that has bigger memory than current one"; // increase ref count - this->data_->IncRef(); - ret.data_->manager_ctx = this->data_; - ret.data_->dl_tensor.data = this->data_->dl_tensor.data; + get_mutable()->IncRef(); + ret.get_mutable()->manager_ctx = get_mutable(); + ret.get_mutable()->dl_tensor.data = get_mutable()->dl_tensor.data; return ret; } DLManagedTensor* NDArray::ToDLPack() const { - return Internal::ToDLPack(data_); + return Internal::ToDLPack(get_mutable()); } NDArray NDArray::Empty(std::vector shape, @@ -144,9 +162,9 @@ NDArray NDArray::Empty(std::vector shape, DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content - size_t size = GetDataSize(ret.data_->dl_tensor); - size_t alignment = GetDataAlignment(ret.data_->dl_tensor); - ret.data_->dl_tensor.data = + size_t size = GetDataSize(ret.get_mutable()->dl_tensor); + size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor); + ret.get_mutable()->dl_tensor.data = DeviceAPI::Get(ret->ctx)->AllocDataSpace( ret->ctx, size, alignment, ret->dtype); return ret; @@ -154,10 +172,12 @@ NDArray NDArray::Empty(std::vector shape, NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { NDArray::Container* data = new NDArray::Container(); - data->deleter = Internal::DLPackDeleter; + // construct header + data->SetDeleter(Internal::DLPackDeleter); + // fill up content. data->manager_ctx = tensor; data->dl_tensor = tensor->dl_tensor; - return NDArray(data); + return NDArray(GetObjectPtr(data)); } void NDArray::CopyFromTo(DLTensor* from, @@ -184,17 +204,24 @@ void NDArray::CopyFromTo(DLTensor* from, } std::vector NDArray::Shape() const { - return data_->shape_; + return get_mutable()->shape_; } +TVM_REGISTER_OBJECT_TYPE(NDArray::Container); + } // namespace runtime } // namespace tvm using namespace tvm::runtime; -void NDArrayDLPackDeleter(DLManagedTensor* tensor) { - static_cast(tensor->manager_ctx)->DecRef(); - delete tensor; +void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor) { + NDArray::Internal::NDArrayDLPackDeleter(tensor); +} + +int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { + API_BEGIN(); + *out_tindex = TVMArrayHandleToObjectHandle(handle)->type_index(); + API_END(); } int TVMArrayAlloc(const tvm_index_t* shape, @@ -213,14 +240,14 @@ int TVMArrayAlloc(const tvm_index_t* shape, DLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; - *out = NDArray::Internal::MoveAsDLTensor( + *out = NDArray::Internal::MoveToFFIHandle( NDArray::Empty(std::vector(shape, shape + ndim), dtype, ctx)); API_END(); } int TVMArrayFree(TVMArrayHandle handle) { API_BEGIN(); - reinterpret_cast(handle)->DecRef(); + NDArray::Internal::FFIDecRef(handle); API_END(); } @@ -235,14 +262,14 @@ int TVMArrayCopyFromTo(TVMArrayHandle from, int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out) { API_BEGIN(); - *out = NDArray::Internal::MoveAsDLTensor(NDArray::FromDLPack(from)); + *out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from)); API_END(); } int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out) { API_BEGIN(); - *out = NDArray::Internal::ToDLPack(reinterpret_cast(from)); + *out = NDArray::Internal::ToDLPack(from); API_END(); } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 1042a4f68e5e1..881788a5292c6 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -59,7 +59,8 @@ class RPCWrappedFunc { const TVMArgValue& arg); // deleter of RPC remote array - static void RemoteNDArrayDeleter(NDArray::Container* ptr) { + static void RemoteNDArrayDeleter(Object* obj) { + auto* ptr = static_cast(obj); RemoteSpace* space = static_cast(ptr->dl_tensor.data); space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx); delete space; @@ -71,12 +72,12 @@ class RPCWrappedFunc { void* nd_handle) { NDArray::Container* data = new NDArray::Container(); data->manager_ctx = nd_handle; - data->deleter = RemoteNDArrayDeleter; + data->SetDeleter(RemoteNDArrayDeleter); RemoteSpace* space = new RemoteSpace(); space->sess = sess; space->data = tensor->data; data->dl_tensor.data = space; - NDArray ret(data); + NDArray ret(GetObjectPtr(data)); // RAII now in effect data->shape_ = std::vector( tensor->shape, tensor->shape + tensor->ndim); diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 16b0e7f695298..77d39754b0953 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -787,9 +787,7 @@ class RPCSession::EventHandler : public dmlc::Stream { TVMValue ret_value_pack[2]; int ret_tcode_pack[2]; rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]); - - NDArray::Container* nd = static_cast(ret_value_pack[0].v_handle); - ret_value_pack[1].v_handle = nd; + ret_value_pack[1].v_handle = ret_value_pack[0].v_handle; ret_tcode_pack[1] = kHandle; SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); } else { @@ -1190,7 +1188,8 @@ void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { void* handle = args[0]; - static_cast(handle)->DecRef(); + static_cast( + reinterpret_cast(handle))->DecRef(); } void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index ff2bbe8eaf11d..3e6140ed38304 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -31,7 +31,8 @@ namespace tvm { namespace runtime { namespace vm { -static void BufferDeleter(NDArray::Container* ptr) { +static void BufferDeleter(Object* obj) { + auto* ptr = static_cast(obj); CHECK(ptr->manager_ctx != nullptr); Buffer* buffer = reinterpret_cast(ptr->manager_ctx); MemoryManager::Global()->GetAllocator(buffer->ctx)-> @@ -40,7 +41,8 @@ static void BufferDeleter(NDArray::Container* ptr) { delete ptr; } -void StorageObj::Deleter(NDArray::Container* ptr) { +void StorageObj::Deleter(Object* obj) { + auto* ptr = static_cast(obj); // When invoking AllocNDArray we don't own the underlying allocation // and should not delete the buffer, but instead let it be reclaimed // by the storage object's destructor. @@ -77,16 +79,23 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDa // TODO(@jroesch): generalize later to non-overlapping allocations. CHECK_EQ(offset, 0u); VerifyDataType(dtype); + + // crtical zone: allocate header, cannot throw NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx); - container->deleter = StorageObj::Deleter; + + container->SetDeleter(StorageObj::Deleter); size_t needed_size = GetDataSize(container->dl_tensor); - // TODO(@jroesch): generalize later to non-overlapping allocations. - CHECK(needed_size == this->buffer.size) - << "size mistmatch required " << needed_size << " found " << this->buffer.size; this->IncRef(); container->manager_ctx = reinterpret_cast(this); container->dl_tensor.data = this->buffer.data; - return NDArray(container); + NDArray ret(GetObjectPtr(container)); + + // RAII in effect, now run the check. + // TODO(@jroesch): generalize later to non-overlapping allocations. + CHECK(needed_size == this->buffer.size) + << "size mistmatch required " << needed_size << " found " << this->buffer.size; + + return ret; } MemoryManager* MemoryManager::Global() { @@ -108,14 +117,14 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) { NDArray Allocator::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { VerifyDataType(dtype); NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, ctx); - container->deleter = BufferDeleter; + container->SetDeleter(BufferDeleter); size_t size = GetDataSize(container->dl_tensor); size_t alignment = GetDataAlignment(container->dl_tensor); Buffer *buffer = new Buffer; *buffer = this->Alloc(size, alignment, dtype); container->manager_ctx = reinterpret_cast(buffer); container->dl_tensor.data = buffer->data; - return NDArray(container); + return NDArray(GetObjectPtr(container)); } } // namespace vm diff --git a/src/runtime/vm/memory_manager.h b/src/runtime/vm/memory_manager.h index 78c8fb36bf703..292fb55e59950 100644 --- a/src/runtime/vm/memory_manager.h +++ b/src/runtime/vm/memory_manager.h @@ -120,7 +120,7 @@ class StorageObj : public Object { DLDataType dtype); /*! \brief The deleter for an NDArray when allocated from underlying storage. */ - static void Deleter(NDArray::Container* ptr); + static void Deleter(Object* ptr); ~StorageObj() { auto alloc = MemoryManager::Global()->GetAllocator(buffer.ctx); diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index f6f3b8f90e37d..a9b9b83ee250c 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include TEST(PackedFunc, Basic) { @@ -178,6 +179,69 @@ TEST(TypedPackedFunc, HighOrder) { CHECK_EQ(f1(3), 4); } + +TEST(PackedFunc, ObjectConversion) { + using namespace tvm; + using namespace tvm::runtime; + TVMRetValue rv; + auto x = NDArray::Empty( + {}, String2TVMType("float32"), + TVMContext{kDLCPU, 0}); + // assign null + rv = ObjectRef(); + CHECK_EQ(rv.type_code(), kNull); + + // Can assign NDArray to ret type + rv = x; + CHECK_EQ(rv.type_code(), kNDArrayContainer); + // Even if we assign base type it still shows as NDArray + rv = ObjectRef(x); + CHECK_EQ(rv.type_code(), kNDArrayContainer); + // Check convert back + CHECK(rv.operator NDArray().same_as(x)); + CHECK(rv.operator ObjectRef().same_as(x)); + CHECK(!rv.IsObjectRef()); + + auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args[0].type_code(), kNDArrayContainer); + CHECK(args[0].operator NDArray().same_as(x)); + CHECK(args[0].operator ObjectRef().same_as(x)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(args[1].operator Array().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); + pf1(x, ObjectRef()); + pf1(ObjectRef(x), NDArray()); + + // testcases for modules + auto* pf = tvm::runtime::Registry::Get("module.source_module_create"); + CHECK(pf != nullptr); + Module m = (*pf)("", "xyz"); + rv = m; + CHECK_EQ(rv.type_code(), kModuleHandle); + // Even if we assign base type it still shows as NDArray + rv = ObjectRef(m); + CHECK_EQ(rv.type_code(), kModuleHandle); + // Check convert back + CHECK(rv.operator Module().same_as(m)); + CHECK(rv.operator ObjectRef().same_as(m)); + CHECK(!rv.IsObjectRef()); + + auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args[0].type_code(), kModuleHandle); + CHECK(args[0].operator Module().same_as(m)); + CHECK(args[0].operator ObjectRef().same_as(m)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); + pf2(m, ObjectRef()); + pf2(ObjectRef(m), Module()); +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/unittest/test_lang_container.py b/tests/python/unittest/test_lang_container.py index 206e143029cfd..92edbee9072fb 100644 --- a/tests/python/unittest/test_lang_container.py +++ b/tests/python/unittest/test_lang_container.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import numpy as np def test_array(): a = tvm.convert([1,2,3]) @@ -71,6 +72,14 @@ def test_in_container(): assert tvm.make.StringImm('a') in arr assert 'd' not in arr +def test_ndarray_container(): + x = tvm.nd.array([1,2,3]) + arr = tvm.convert([x, x]) + assert arr[0].same_as(x) + assert arr[1].same_as(x) + assert isinstance(arr[0], tvm.nd.NDArray) + + if __name__ == "__main__": test_str_map() test_array() @@ -78,3 +87,4 @@ def test_in_container(): test_array_save_load_json() test_map_save_load_json() test_in_container() + test_ndarray_container() diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index b676cf2d5244b..951ea97bf252a 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -32,7 +32,7 @@ def gen_engine_header(): #include class Engine { }; - + #endif ''' header_file = header_file_dir_path.relpath("gcc_engine.h") @@ -45,7 +45,7 @@ def generate_engine_module(): #include #include #include "gcc_engine.h" - + extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5, float* gcc_input6, float* gcc_input7, float* out) { Engine engine; diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index ebbcf2106617b..143f7f2f98e56 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -35,7 +35,8 @@ rm -rf lib make cd ../.. -python3 -m pytest -v apps/extension/tests +TVM_FFI=cython python3 -m pytest -v apps/extension/tests +TVM_FFI=ctypes python3 -m pytest -v apps/extension/tests TVM_FFI=ctypes python3 -m pytest -v tests/python/integration TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib