Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME][NDArray] Allowing External Libraries to Subclass NDArrays #2613

Merged
merged 9 commits into from
Feb 21, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/extension/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/3rdparty/dlpack/include\
-I${TVM_ROOT}/3rdparty/HalideIR/src

PKG_LDFLAGS =-L${TVM_ROOT}/lib
PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s)

ifeq ($(UNAME_S), Darwin)
Expand Down
29 changes: 28 additions & 1 deletion apps/extension/python/tvm_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, handle):
def __del__(self):
# You can also call your own customized
# deleter if you can free it via your own FFI.
tvm.nd.free_extension_handle(self.handle, 17)
tvm.nd.free_extension_handle(self.handle, self.__class__._tvm_tcode)

@property
def _tvm_handle(self):
Expand All @@ -42,3 +42,30 @@ def __getitem__(self, idx):

# Register IntVec extension on python side.
tvm.register_extension(IntVec, IntVec)


nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_tracing = tvm.get_global_func("tvm_ext.nd_get_tracing")

class NDSubClass(tvm.nd.NDArrayBase):
    """Example for subclassing TVM's NDArray infrastructure.

    By inheriting TVM's NDArray, external libraries could
    leverage TVM's FFI without any modification.
    """
    # Should be consistent with the type-trait set in the backend
    _array_type_index = 1

    @staticmethod
    def create(is_tracing):
        """Create an NDSubClass through the C++ backend.

        Parameters
        ----------
        is_tracing : bool
            Initial value of the tracing flag.

        Returns
        -------
        arr : NDSubClass
            The created array.
        """
        return nd_create(is_tracing)

    @property
    def is_tracing(self):
        """bool: whether tracing is enabled on this array."""
        return bool(nd_get_tracing(self))

    def __add__(self, other):
        """Add two arrays; the result's tracing flag is the OR of both operands'."""
        return nd_add_two(self, other)


tvm.register_extension(NDSubClass, NDSubClass)
68 changes: 68 additions & 0 deletions apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>

namespace tvm_ext {
using IntVector = std::vector<int>;
class NDSubClass;
} // namespace tvm_ext

namespace tvm {
Expand All @@ -19,12 +22,57 @@ template<>
struct extension_class_info<tvm_ext::IntVector> {
static const int code = 17;
};
template<>
struct array_type_index<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
} // namespace runtime

using namespace tvm;
using namespace tvm::runtime;

namespace tvm_ext {
class NDSubClass : public tvm::runtime::NDArray {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
public:
class SubContainer : public NDArray::Container {
public:
SubContainer(bool is_tracing) {
array_type_index_ = array_type_index<NDSubClass>::code;
junrushao marked this conversation as resolved.
Show resolved Hide resolved
is_tracing_ = is_tracing;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
junrushao marked this conversation as resolved.
Show resolved Hide resolved
return c->array_type_index_ == array_type_index<NDSubClass>::code;
}
bool is_tracing_{false};
junrushao marked this conversation as resolved.
Show resolved Hide resolved
};
NDSubClass(NDArray::Container *container) {
if (container == nullptr) {
data_ = nullptr;
return;
}
CHECK(SubContainer::Is(container));
container->IncRef();
data_ = container;
}
~NDSubClass() {
this->reset();
}
NDSubClass addWith(const NDSubClass &other) const {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
SubContainer *a = static_cast<SubContainer*>(data_);
SubContainer *b = static_cast<SubContainer*>(other.data_);
CHECK(a != nullptr && b != nullptr);
return NDSubClass(new SubContainer(a->is_tracing_ || b->is_tracing_));
}
bool get_tracing() const {
SubContainer *self = static_cast<SubContainer*>(data_);
CHECK(self != nullptr);
return self->is_tracing_;
}
};
} // namespace tvm_ext

namespace tvm_ext {

TVM_REGISTER_EXT_TYPE(IntVector);
Expand Down Expand Up @@ -64,6 +112,26 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = (*tvm::runtime::Registry::Get("device_api.cpu"))();
});

// Create an NDSubClass with the given tracing flag.
TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  bool is_tracing = args[0];
  *rv = NDSubClass(new NDSubClass::SubContainer(is_tracing));
});

// Add two NDSubClass arrays; the result traces iff either operand does.
TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  NDSubClass a = args[0];
  NDSubClass b = args[1];
  *rv = a.addWith(b);
});

// Query the tracing flag of an NDSubClass.
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_tracing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  NDSubClass a = args[0];
  // get_tracing() already returns bool; no cast needed.
  *rv = a.get_tracing();
});

} // namespace tvm_ext

// External function exposed to runtime.
Expand Down
16 changes: 16 additions & 0 deletions apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_sym_add():
c = tvm_ext.sym_add(a, b)
assert c.a == a and c.b == b


def test_ext_vec():
ivec = tvm_ext.ivec_create(1, 2, 3)
assert(isinstance(ivec, tvm_ext.IntVec))
Expand All @@ -44,6 +45,7 @@ def ivec_cb(v2):

tvm.convert(ivec_cb)(ivec)


def test_extract_ext():
fdict = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
assert fdict["mul"](3, 4) == 12
Expand All @@ -68,7 +70,21 @@ def check_llvm():
check_llvm()


def test_nd_subclass():
    """Check that NDSubClass round-trips through the FFI and that
    the tracing flag propagates through addition as a logical OR."""
    a = tvm_ext.NDSubClass.create(is_tracing=False)
    b = tvm_ext.NDSubClass.create(is_tracing=True)
    c = a + b
    d = a + a
    e = b + b
    assert not a.is_tracing
    assert b.is_tracing
    assert c.is_tracing
    assert not d.is_tracing
    assert e.is_tracing


if __name__ == "__main__":
test_nd_subclass()
test_extern_call()
test_ext_dev()
test_ext_vec()
Expand Down
27 changes: 25 additions & 2 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,30 @@ class NDArray {
Container* data_{nullptr};
// enable internal functions
friend struct Internal;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class TVMArgsSetter;
};

/*!
* \brief The type trait indicates subclass of TVM's NDArray.
* For irrelavant classes, code = -1.
* For TVM NDArray itself, code = 0.
* All subclasses of NDArray should override code > 0.
*/
template<typename T>
junrushao marked this conversation as resolved.
Show resolved Hide resolved
struct array_type_index {
/*! \brief the value of the traits */
static const int code = -1;
};

// Overrides the type trait for tvm's NDArray.
template<>
struct array_type_index<NDArray> {
static const int code = 0;
};

/*!
* \brief Save a DLTensor to stream
* \param strm The output stream
Expand All @@ -196,7 +216,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
* the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor*
*
* \note: do not use this function directly, use NDArray.
* \note do not use this function directly, use NDArray.
*/
class NDArray::Container {
public:
Expand Down Expand Up @@ -228,6 +248,9 @@ class NDArray::Container {

protected:
friend class NDArray;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class RPCWrappedFunc;
/*!
* \brief Type flag used to indicate subclass.
Expand All @@ -237,7 +260,7 @@ class NDArray::Container {
* and use the array_type_index_ to indicate
* the specific array subclass.
*/
uint32_t array_type_index_{0};
int32_t array_type_index_{0};
/*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0};
/*!
Expand Down
32 changes: 27 additions & 5 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,15 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
/*!
 * \brief Reinterpret the held handle as an NDArray subclass TNDArray.
 * \tparam TNDArray The target type; must derive from NDArray (enforced
 *  by the enable_if constraint).
 * \return An empty TNDArray when the value is kNull; otherwise the
 *  wrapped container.
 * \note CHECKs that the container's array_type_index_ matches the
 *  array_type_index trait registered for TNDArray.
 */
template<typename TNDArray,
         typename = typename std::enable_if<
           std::is_base_of<NDArray, TNDArray>::value>::type>
TNDArray AsNDArray() const {
  if (type_code_ == kNull) return TNDArray(nullptr);
  auto *container = static_cast<NDArray::Container*>(value_.v_handle);
  CHECK_EQ(container->array_type_index_, array_type_index<TNDArray>::code);
  return TNDArray(container);
}
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
Expand Down Expand Up @@ -561,7 +570,7 @@ class TVMArgValue : public TVMPODValue_ {
inline TNodeRef AsNodeRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef,
typename = typename std::enable_if<
Expand Down Expand Up @@ -1212,32 +1221,45 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {

// extension and node type handling
namespace detail {
/*!
 * \brief Dispatch helper that converts a TVM value holder to T.
 *  The default case handles node references (neither an extension
 *  type nor an NDArray subclass).
 */
template<typename T, typename TSrc, bool is_ext, bool is_nd>
struct TVMValueCast {
  static T Apply(const TSrc* self) {
    static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions");
    return self->template AsNodeRef<T>();
  }
};

// Extension types registered via extension_class_info.
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true, false> {
  static T Apply(const TSrc* self) {
    return self->template AsExtension<T>();
  }
};

// NDArray subclasses registered via array_type_index.
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, false, true> {
  static T Apply(const TSrc* self) {
    return self->template AsNDArray<T>();
  }
};

}  // namespace detail

// Convert an argument value to T, dispatching on whether T is a
// registered extension type or a registered NDArray subclass.
template<typename T, typename>
inline TVMArgValue::operator T() const {
  return detail::
      TVMValueCast<T, TVMArgValue,
                   (extension_class_info<T>::code != 0),
                   (array_type_index<T>::code > 0)>
      ::Apply(this);
}

// Convert a return value to T, dispatching on whether T is a
// registered extension type or a registered NDArray subclass.
template<typename T, typename>
inline TVMRetValue::operator T() const {
  return detail::
      TVMValueCast<T, TVMRetValue,
                   (extension_class_info<T>::code != 0),
                   (array_type_index<T>::code > 0)>
      ::Apply(this);
}

Expand Down
6 changes: 3 additions & 3 deletions python/tvm/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,13 @@ def _handle_return_func(x):
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True, False)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)

_CLASS_MODULE = None
_CLASS_FUNCTION = None
Expand Down
19 changes: 15 additions & 4 deletions python/tvm/_ffi/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle
from ..runtime_ctypes import TVMArrayHandle, TVMNDArrayContainerHandle
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle


Expand All @@ -28,7 +28,7 @@ def _from_dlpack(dltensor):
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
return _make_array(handle, False)
return _make_array(handle, False, False)
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")


Expand Down Expand Up @@ -77,9 +77,15 @@ def to_dlpack(self):
return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)


def _make_array(handle, is_view, is_container):
    """Construct the Python wrapper object for an array handle.

    Parameters
    ----------
    handle : ctypes pointer
        Handle to the underlying TVM array/container.
    is_view : bool
        Whether the resulting object is a non-owning view.
    is_container : bool
        True when the handle points to an NDArray container whose
        array_type_index may select a registered NDArray subclass.

    Returns
    -------
    arr : object
        Instance of _CLASS_NDARRAY or of a registered subclass.
    """
    handle = ctypes.cast(handle, TVMArrayHandle)
    fcreate = _CLASS_NDARRAY
    if is_container and _TVM_ND_CLS:
        # A positive index marks a subclass registered via _reg_ndarray.
        array_type_index = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_index.value
        if array_type_index > 0:
            fcreate = _TVM_ND_CLS[array_type_index]
    return fcreate(handle, is_view)

_TVM_COMPATS = ()

Expand All @@ -91,6 +97,11 @@ def _reg_extension(cls, fcreate):
RETURN_SWITCH[cls._tvm_tcode] = fret
C_TO_PY_ARG_SWITCH[cls._tvm_tcode] = _wrap_arg_func(fret, cls._tvm_tcode)

_TVM_ND_CLS = {}

def _reg_ndarray(cls, fcreate):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_index] = fcreate

_CLASS_NDARRAY = None

Expand Down
Loading