Skip to content

Commit

Permalink
[REFACTOR][RUNTIME] Move NDArray to Object System.
Browse files Browse the repository at this point in the history
Previously NDArray has its own object reference counting mechanism.
This PR migrates NDArray to the unified object protocol.

The calling convention of NDArray remained intact.
That means NDArray still has its own type_code and
its handle is still DLTensor compatible.

In order to do so, this PR added a few minimum runtime type
detection in TVMArgValue and RetValue only when the corresponding
type is a base type(ObjectRef) that could also refer to NDArray.

This means that even if we return a base reference object ObjectRef
which refers to the NDArray. The type_code will still be translated
correctly as kNDArrayContainer.
If we assign a non-base type(say Expr) that we know is not compatible
with NDArray during compile time, no runtime type detection will be performed.

This PR also adopts the object protocol for NDArray sub-classing and
removed the legacy NDArray subclass protocol.
Examples in apps/extension are now updated to reflect that.

Making NDArray as an Object brings all the benefits of the object system.
For example, we can now use the Array container to store NDArrays.
  • Loading branch information
tqchen committed Dec 25, 2019
1 parent e91cc5a commit 6069489
Show file tree
Hide file tree
Showing 31 changed files with 675 additions and 615 deletions.
15 changes: 6 additions & 9 deletions apps/extension/python/tvm_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,23 @@ def __getitem__(self, idx):

nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info")
nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info")

@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure.
By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_code = 1

@staticmethod
def create(addtional_info):
return nd_create(addtional_info)
def create(additional_info):
return nd_create(additional_info)

@property
def addtional_info(self):
return nd_get_addtional_info(self)
def additional_info(self):
return nd_get_additional_info(self)

def __add__(self, other):
return nd_add_two(self, other)

tvm.register_extension(NDSubClass, NDSubClass)
71 changes: 32 additions & 39 deletions apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,6 @@
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>

namespace tvm_ext {
class NDSubClass;
} // namespace tvm_ext

namespace tvm {
namespace runtime {
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime

using namespace tvm;
using namespace tvm::runtime;

Expand All @@ -65,41 +52,45 @@ class NDSubClass : public tvm::runtime::NDArray {
public:
class SubContainer : public NDArray::Container {
public:
SubContainer(int addtional_info) :
addtional_info_(addtional_info) {
array_type_code_ = array_type_info<NDSubClass>::code;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_code_ == array_type_info<NDSubClass>::code;
SubContainer(int additional_info) :
additional_info_(additional_info) {
type_index_ = SubContainer::RuntimeTypeIndex();
}
int addtional_info_{0};
int additional_info_{0};

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "tvm_ext.NDSubClass";
TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container);
};
NDSubClass(NDArray::Container *container) {
if (container == nullptr) {
data_ = nullptr;
return;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;

static void SubContainerDeleter(Object* obj) {
auto* ptr = static_cast<SubContainer*>(obj);
delete ptr;
}
~NDSubClass() {
this->reset();

NDSubClass() {}
explicit NDSubClass(ObjectPtr<Object> n) : NDArray(n) {}
explicit NDSubClass(int additional_info) {
SubContainer* ptr = new SubContainer(additional_info);
ptr->SetDeleter(SubContainerDeleter);
data_ = GetObjectPtr<Object>(ptr);
}

NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_);
SubContainer *b = static_cast<SubContainer*>(other.data_);
SubContainer *a = static_cast<SubContainer*>(get_mutable());
SubContainer *b = static_cast<SubContainer*>(other.get_mutable());
CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_));
return NDSubClass(a->additional_info_ + b->additional_info_);
}
int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_);
SubContainer *self = static_cast<SubContainer*>(get_mutable());
CHECK(self != nullptr);
return self->addtional_info_;
return self->additional_info_;
}
using ContainerType = SubContainer;
};

TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer);

/*!
* \brief Introduce additional extension data structures
Expand Down Expand Up @@ -166,8 +157,10 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")

TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info));
int additional_info = args[0];
*rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kNDArrayContainer);

});

TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
Expand All @@ -177,7 +170,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
*rv = a.AddWith(b);
});

TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info")
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
*rv = a.get_additional_info();
Expand Down
15 changes: 8 additions & 7 deletions apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,17 @@ def check_llvm():


def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5)
a = tvm_ext.NDSubClass.create(additional_info=3)
b = tvm_ext.NDSubClass.create(additional_info=5)
assert isinstance(a, tvm_ext.NDSubClass)
c = a + b
d = a + a
e = b + b
assert(a.addtional_info == 3)
assert(b.addtional_info == 5)
assert(c.addtional_info == 8)
assert(d.addtional_info == 6)
assert(e.addtional_info == 10)
assert(a.additional_info == 3)
assert(b.additional_info == 5)
assert(c.additional_info == 8)
assert(d.additional_info == 6)
assert(e.additional_info == 10)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_

#include <tvm/node/node.h>

#include <type_traits>
#include <vector>
#include <initializer_list>
#include <unordered_map>
#include <utility>
#include <string>
#include "node.h"
#include "memory.h"

namespace tvm {

Expand Down
98 changes: 15 additions & 83 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_

#include <sstream>
#include <string>
#include <memory>
#include <limits>
Expand All @@ -43,22 +42,7 @@ using runtime::TVMRetValue;
using runtime::PackedFunc;

namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct ObjectTypeChecker {
static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return true;
return ptr->IsInstance<ContainerType>();
}
static void PrintName(std::ostream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};


template<typename T>
struct ObjectTypeChecker<Array<T> > {
Expand All @@ -73,10 +57,8 @@ struct ObjectTypeChecker<Array<T> > {
}
return true;
}
static void PrintName(std::ostream& os) { // NOLINT(*)
os << "List[";
ObjectTypeChecker<T>::PrintName(os);
os << "]";
static std::string TypeName() {
return "List[" + ObjectTypeChecker<T>::TypeName() + "]";
}
};

Expand All @@ -91,11 +73,9 @@ struct ObjectTypeChecker<Map<std::string, V> > {
}
return true;
}
static void PrintName(std::ostream& os) { // NOLINT(*)
os << "Map[str";
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
static std::string TypeName() {
return "Map[str, " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};

Expand All @@ -111,39 +91,16 @@ struct ObjectTypeChecker<Map<K, V> > {
}
return true;
}
static void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "Map[";
ObjectTypeChecker<K>::PrintName(os);
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << ']';
static std::string TypeName() {
return "Map[" +
ObjectTypeChecker<K>::TypeName() +
", " +
ObjectTypeChecker<V>::TypeName()+ ']';
}
};

template<typename T>
inline std::string ObjectTypeName() {
std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os);
return os.str();
}

// extensions for tvm arg value

template<typename TObjectRef>
inline TObjectRef TVMArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Node>(ptr));
}

inline TVMArgValue::operator tvm::Expr() const {
inline TVMPODValue_::operator tvm::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
Expand All @@ -164,12 +121,12 @@ inline TVMArgValue::operator tvm::Expr() const {
return Tensor(ObjectPtr<Node>(ptr))();
}
CHECK(ObjectTypeChecker<Expr>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>()
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Expr(ObjectPtr<Node>(ptr));
}

inline TVMArgValue::operator tvm::Integer() const {
inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
Expand All @@ -179,35 +136,10 @@ inline TVMArgValue::operator tvm::Integer() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expected type " << ObjectTypeName<Expr>()
<< "Expect type " << ObjectTypeChecker<Expr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Node>(ptr));
}

template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return ObjectTypeChecker<TObjectRef>::Check(ptr);
}

// extensions for TVMRetValue
template<typename TObjectRef>
inline TObjectRef TVMRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef();
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);

Object* ptr = static_cast<Object*>(value_.v_handle);

CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expected type " << ObjectTypeName<TObjectRef>()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(ObjectPtr<Object>(ptr));
}

} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
1 change: 1 addition & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#ifndef TVM_RUNTIME_CONTAINER_H_
#define TVM_RUNTIME_CONTAINER_H_

#include <dmlc/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
Expand Down
Loading

0 comments on commit 6069489

Please sign in to comment.