Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME][NDArray] Allowing External Libraries to Subclass NDArrays #2613

Merged
merged 9 commits into from
Feb 21, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/extension/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/3rdparty/dlpack/include\
-I${TVM_ROOT}/3rdparty/HalideIR/src

PKG_LDFLAGS =-L${TVM_ROOT}/lib
PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s)

ifeq ($(UNAME_S), Darwin)
Expand Down
29 changes: 28 additions & 1 deletion apps/extension/python/tvm_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, handle):
def __del__(self):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, 17)
tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode)

@property
def _tvm_handle(self):
Expand All @@ -42,3 +42,30 @@ def __getitem__(self, idx):

# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)


nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_addtional_info = tvm.get_global_func("tvm_ext.nd_get_addtional_info")

class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure.

By inheriting TMV's NDArray, external libraries could
leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_info = 1

@staticmethod
def create(addtional_info):
return nd_create(addtional_info)

@property
junrushao marked this conversation as resolved.
Show resolved Hide resolved
def addtional_info(self):
return nd_get_addtional_info(self)

def __add__(self, other):
return nd_add_two(self, other)

tvm.register_extension(NDSubClass, NDSubClass)
85 changes: 84 additions & 1 deletion apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,87 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>

namespace tvm_ext {
using IntVector = std::vector<int>;
class NDSubClass;
} // namespace tvm_ext

namespace tvm {
namespace runtime {
template<>
struct extension_class_info<tvm_ext::IntVector> {
struct extension_type_info<tvm_ext::IntVector> {
static const int code = 17;
};
template<>
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime

using namespace tvm;
using namespace tvm::runtime;

namespace tvm_ext {
/*!
* \brief A subclass of TVM's NDArray.
*
* To use this extension, an external library should
*
* 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_info` for this class.
*
* 2) Define a constructor in the inherited class that accepts
* a pointer to TVM's Container, which is nullable.
*
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`,
* define the class attribute `_array_type_info` consistent to
* the C++ type trait, and register the subclass using `tvm.register_extension`.
*/
class NDSubClass : public tvm::runtime::NDArray {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
public:
class SubContainer : public NDArray::Container {
public:
SubContainer(int addtional_info) :
addtional_info_(addtional_info) {
array_type_info_ = array_type_info<NDSubClass>::code;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
junrushao marked this conversation as resolved.
Show resolved Hide resolved
return c->array_type_info_ == array_type_info<NDSubClass>::code;
}
int addtional_info_{0};
};
NDSubClass(NDArray::Container *container) {
if (container == nullptr) {
data_ = nullptr;
return;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;
}
~NDSubClass() {
this->reset();
}
NDSubClass AddWith(const NDSubClass &other) const {
SubContainer *a = static_cast<SubContainer*>(data_);
SubContainer *b = static_cast<SubContainer*>(other.data_);
CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->addtional_info_ + b->addtional_info_));
}
int get_additional_info() const {
SubContainer *self = static_cast<SubContainer*>(data_);
CHECK(self != nullptr);
return self->addtional_info_;
}
};
} // namespace tvm_ext

namespace tvm_ext {

TVM_REGISTER_EXT_TYPE(IntVector);
Expand Down Expand Up @@ -64,6 +127,26 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
});

TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int addtional_info = args[0];
*rv = NDSubClass(new NDSubClass::SubContainer(addtional_info));
});

TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
NDSubClass b = args[1];
*rv = a.AddWith(b);
});

TVM_REGISTER_GLOBAL("tvm_ext.nd_get_addtional_info")
.set_body([](TVMArgs args, TVMRetValue *rv) {
NDSubClass a = args[0];
*rv = a.get_additional_info();
});

} // namespace tvm_ext

// External function exposed to runtime.
Expand Down
16 changes: 16 additions & 0 deletions apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_sym_add():
c = tvm_ext.sym_add(a, b)
assert c.a == a and c.b == b


def test_ext_vec():
ivec = tvm_ext.ivec_create(1, 2, 3)
assert(isinstance(ivec, tvm_ext.IntVec))
Expand All @@ -44,6 +45,7 @@ def ivec_cb(v2):

tvm.convert(ivec_cb)(ivec)


def test_extract_ext():
fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
assert fdict["mul"](3, 4) == 12
Expand All @@ -68,7 +70,21 @@ def check_llvm():
check_llvm()


def test_nd_subclass():
a = tvm_ext.NDSubClass.create(addtional_info=3)
b = tvm_ext.NDSubClass.create(addtional_info=5)
c = a + b
d = a + a
e = b + b
assert(a.addtional_info == 3)
assert(b.addtional_info == 5)
assert(c.addtional_info == 8)
assert(d.addtional_info == 6)
assert(e.addtional_info == 10)


if __name__ == "__main__":
test_nd_subclass()
test_extern_call()
test_ext_dev()
test_ext_vec()
Expand Down
29 changes: 26 additions & 3 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,30 @@ class NDArray {
Container* data_{nullptr};
// enable internal functions
friend struct Internal;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class TVMArgsSetter;
};

/*!
* \brief The type trait indicates subclass of TVM's NDArray.
* For irrelavant classes, code = -1.
* For TVM NDArray itself, code = 0.
* All subclasses of NDArray should override code > 0.
*/
template<typename T>
junrushao marked this conversation as resolved.
Show resolved Hide resolved
struct array_type_info {
/*! \brief the value of the traits */
static const int code = -1;
};

// Overrides the type trait for tvm's NDArray.
template<>
struct array_type_info<NDArray> {
static const int code = 0;
};

/*!
* \brief Save a DLTensor to stream
* \param strm The outpu stream
Expand All @@ -196,7 +216,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
* the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor*
*
* \note: do not use this function directly, use NDArray.
* \note do not use this function directly, use NDArray.
*/
class NDArray::Container {
public:
Expand Down Expand Up @@ -228,16 +248,19 @@ class NDArray::Container {

protected:
friend class NDArray;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class RPCWrappedFunc;
/*!
* \brief Type flag used to indicate subclass.
* Default value 0 means normal NDArray::Conatainer.
*
* We can extend a more specialized NDArray::Container
* and use the array_type_index_ to indicate
* and use the array_type_info_ to indicate
* the specific array subclass.
*/
uint32_t array_type_index_{0};
int32_t array_type_info_{0};
junrushao marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0};
/*!
Expand Down
48 changes: 35 additions & 13 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ inline std::string TVMType2String(TVMType t);
* \tparam T the typename
*/
template<typename T>
struct extension_class_info {
struct extension_type_info {
static const int code = 0;
};

Expand Down Expand Up @@ -455,6 +455,15 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
template<typename TNDArray,
typename = typename std::enable_if<
std::is_base_of<NDArray, TNDArray>::value>::type>
TNDArray AsNDArray() const {
if (type_code_ == kNull) return TNDArray(nullptr);
auto *container = static_cast<NDArray::Container*>(value_.v_handle);
CHECK_EQ(container->array_type_info_, array_type_info<TNDArray>::code);
return TNDArray(container);
}
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
Expand Down Expand Up @@ -561,7 +570,7 @@ class TVMArgValue : public TVMPODValue_ {
inline TNodeRef AsNodeRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef,
typename = typename std::enable_if<
Expand Down Expand Up @@ -727,10 +736,10 @@ class TVMRetValue : public TVMPODValue_ {
}
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
extension_type_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
extension_type_info<T>::code, other);
return *this;
}
/*!
Expand Down Expand Up @@ -1094,7 +1103,7 @@ class TVMArgsSetter {
// extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
extension_type_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
Expand Down Expand Up @@ -1212,40 +1221,53 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {

// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext>
template<typename T, typename TSrc, bool is_ext, bool is_nd>
struct TVMValueCast {
static T Apply(const TSrc* self) {
static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
return self->template AsNodeRef<T>();
}
};

template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> {
struct TVMValueCast<T, TSrc, true, false> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
};

template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, false, true> {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
static T Apply(const TSrc* self) {
return self->template AsNDArray<T>();
}
};

} // namespace detail

template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0>
TVMValueCast<T, TVMArgValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}

template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0>
TVMValueCast<T, TVMRetValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}

template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0,
static_assert(extension_type_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code;
type_codes_[i] = extension_type_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}

Expand All @@ -1262,9 +1284,9 @@ struct ExtTypeInfo {

template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code;
const int code = extension_type_info<T>::code;
static_assert(code != 0,
"require extension_class_info traits to be declared with non-zero code");
"require extension_type_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
Expand Down
Loading