Skip to content

Commit

Permalink
[RUNTIME][NDArray] Allowing External Libraries to Subclass NDArrays (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and wweic committed Mar 12, 2019
1 parent 9e08e45 commit 399a6ce
Show file tree
Hide file tree
Showing 18 changed files with 280 additions and 54 deletions.
2 changes: 1 addition & 1 deletion apps/extension/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/3rdparty/dlpack/include\
-I${TVM_ROOT}/3rdparty/HalideIR/src

PKG_LDFLAGS =-L${TVM_ROOT}/lib
PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s)

ifeq ($(UNAME_S), Darwin)
Expand Down
29 changes: 28 additions & 1 deletion apps/extension/python/tvm_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, handle):
def __del__(self):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, 17)
tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode)

@property
def _tvm_handle(self):
Expand All @@ -42,3 +42,30 @@ def __getitem__(self, idx):

# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)


nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info")

class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure.
By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_code = 1

@staticmethod
def create(addtional_info):
return nd_create(addtional_info)

@property
def addtional_info(self):
return nd_get_addtional_info(self)

def __add__(self, other):
return nd_add_two(self, other)

tvm.register_extension(NDSubClass, NDSubClass)
85 changes: 84 additions & 1 deletion apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,87 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>

namespace tvm_ext {
using IntVector = std::vector<int>;
class NDSubClass;
} // namespace tvm_ext

namespace tvm {
namespace runtime {
template<>
struct extension_class_info<tvm_ext::IntVector> {
struct extension_type_info<tvm_ext::IntVector> {
static const int code = 17;
};
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime

using namespace tvm;
using namespace tvm::runtime;

namespace tvm_ext {
/*!
* \brief A subclass of TVM's NDArray.
*
* To use this extension, an external library should
*
* 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
*
* 2) Define a constructor in the inherited class that accepts
* a pointer to TVM's Container, which is nullable.
*
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`,
* define the class attribute `_array_type_code` consistent to
* the C++ type trait, and register the subclass using `tvm.register_extension`.
*/
class NDSubClass : public tvm::runtime::NDArray {
public:
class SubContainer : public NDArray::Container {
public:
SubContainer(int addtional_info) :
addtional_info_(addtional_info) {
array_type_code_ = array_type_info<NDSubClass>::code;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_code_ == array_type_info<NDSubClass>::code;
}
int addtional_info_{0};
};
NDSubClass(NDArray::Container *container) {
if (container == nullptr) {
data_ = nullptr;
return;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;
}
~NDSubClass() {
this->reset();
}
NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_);
SubContainer *b = static_cast<SubContainer*>(other.data_);
CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_));
}
int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_);
CHECK(self != nullptr);
return self->addtional_info_;
}
};
} // namespace tvm_ext

namespace tvm_ext {

TVM_REGISTER_EXT_TYPE(IntVector);
Expand Down Expand Up @@ -64,6 +127,26 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
});

TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info));
});

TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
NDSubClass b = args[1];
*rv = a.AddWith(b);
});

TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
*rv = a.get_additional_info();
});

} // namespace tvm_ext

// External function exposed to runtime.
Expand Down
16 changes: 16 additions & 0 deletions apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_sym_add():
c = tvm_ext.sym_add(a, b)
assert c.a == a and c.b == b


def test_ext_vec():
ivec = tvm_ext.ivec_create(1, 2, 3)
assert(isinstance(ivec, tvm_ext.IntVec))
Expand All @@ -44,6 +45,7 @@ def ivec_cb(v2):

tvm.convert(ivec_cb)(ivec)


def test_extract_ext():
fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
assert fdict["mul"](3, 4) == 12
Expand All @@ -68,7 +70,21 @@ def check_llvm():
check_llvm()


def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5)
c = a + b
d = a + a
e = b + b
assert(a.addtional_info == 3)
assert(b.addtional_info == 5)
assert(c.addtional_info == 8)
assert(d.addtional_info == 6)
assert(e.addtional_info == 10)


if __name__ == "__main__":
test_nd_subclass()
test_extern_call()
test_ext_dev()
test_ext_vec()
Expand Down
29 changes: 26 additions & 3 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,30 @@ class NDArray {
Container* data_{nullptr};
// enable internal functions
friend struct Internal;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class TVMArgsSetter;
};

/*!
* \brief The type trait indicates subclass of TVM's NDArray.
* For irrelavant classes, code = -1.
* For TVM NDArray itself, code = 0.
* All subclasses of NDArray should override code > 0.
*/
template<typename T>
struct array_type_info {
/*! \brief the value of the traits */
static const int code = -1;
};

// Overrides the type trait for tvm's NDArray.
template<>
struct array_type_info<NDArray> {
static const int code = 0;
};

/*!
* \brief Save a DLTensor to stream
* \param strm The outpu stream
Expand All @@ -196,7 +216,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
* the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor*
*
* \note: do not use this function directly, use NDArray.
* \note do not use this function directly, use NDArray.
*/
class NDArray::Container {
public:
Expand Down Expand Up @@ -228,16 +248,19 @@ class NDArray::Container {

protected:
friend class NDArray;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class RPCWrappedFunc;
/*!
* \brief Type flag used to indicate subclass.
* Default value 0 means normal NDArray::Conatainer.
*
* We can extend a more specialized NDArray::Container
* and use the array_type_index_ to indicate
* and use the array_type_code_ to indicate
* the specific array subclass.
*/
uint32_t array_type_index_{0};
int32_t array_type_code_{0};
/*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0};
/*!
Expand Down
48 changes: 35 additions & 13 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ inline std::string TVMType2String(TVMType t);
* \tparam T the typename
*/
template<typename T>
struct extension_class_info {
struct extension_type_info {
static const int code = 0;
};

Expand Down Expand Up @@ -455,6 +455,15 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
template<typename TNDArray,
typename = typename std::enable_if<
std::is_base_of<NDArray, TNDArray>::value>::type>
TNDArray AsNDArray() const {
if (type_code_ == kNull) return TNDArray(nullptr);
auto *container = static_cast<NDArray::Container*>(value_.v_handle);
CHECK_EQ(container->array_type_code_, array_type_info<TNDArray>::code);
return TNDArray(container);
}
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
Expand Down Expand Up @@ -561,7 +570,7 @@ class TVMArgValue : public TVMPODValue_ {
inline TNodeRef AsNodeRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef,
typename = typename std::enable_if<
Expand Down Expand Up @@ -727,10 +736,10 @@ class TVMRetValue : public TVMPODValue_ {
}
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
extension_type_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
extension_type_info<T>::code, other);
return *this;
}
/*!
Expand Down Expand Up @@ -1094,7 +1103,7 @@ class TVMArgsSetter {
// extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
extension_type_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
Expand Down Expand Up @@ -1212,40 +1221,53 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {

// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext>
template<typename T, typename TSrc, bool is_ext, bool is_nd>
struct TVMValueCast {
static T Apply(const TSrc* self) {
static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
return self->template AsNodeRef<T>();
}
};

template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> {
struct TVMValueCast<T, TSrc, true, false> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
};

template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, false, true> {
static T Apply(const TSrc* self) {
return self->template AsNDArray<T>();
}
};

} // namespace detail

template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0>
TVMValueCast<T, TVMArgValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}

template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0>
TVMValueCast<T, TVMRetValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}

template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0,
static_assert(extension_type_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code;
type_codes_[i] = extension_type_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}

Expand All @@ -1262,9 +1284,9 @@ struct ExtTypeInfo {

template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code;
const int code = extension_type_info<T>::code;
static_assert(code != 0,
"require extension_class_info traits to be declared with non-zero code");
"require extension_type_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
Expand Down
Loading

0 comments on commit 399a6ce

Please sign in to comment.