Skip to content

Commit

Permalink
[PY][FFI] runtime.String to subclass str
Browse files Browse the repository at this point in the history
To make runtime.String to work as naturally as possible in the python side,
we make it sub-class the python's str object. Note that however, we cannot
sub-class Object at the same time due to python's type layout constraint(
cannot subclass from multiple classes with slots).

We introduce a PyNativeObject class to handle this kind of object sub-classing.
  • Loading branch information
tqchen committed Apr 24, 2020
1 parent cf5c63b commit 4c5ac4a
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 89 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ if(MSVC)
endif()
else(MSVC)
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
message("Build in Debug mode")
message(STATUS "Build in Debug mode")
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def _return_object(x):
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.handle = handle
return cls.__from_tvm_object__(cls, obj)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
Expand All @@ -64,6 +68,33 @@ def _return_object(x):
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)


class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []

def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
# pylint: disable=assigning-non-slot
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj



class ObjectBase(object):
"""Base object for all object types"""
__slots__ = ["handle"]
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .object import ObjectBase, _set_class_object
from .object import ObjectBase, PyNativeObject, _set_class_object
from . import object as _object

PackedFuncHandle = ctypes.c_void_p
Expand Down Expand Up @@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_HANDLE
if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
elif isinstance(arg, PyNativeObject):
values[i].v_handle = arg.__tvm_object__.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg.__class__._tvm_tcode
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/_ffi/_cython/object.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle):
object_type = OBJECT_TYPE
handle = ctypes_handle(chandle)
CALL(TVMObjectGetTypeIndex(chandle, &tindex))

if tindex < len(OBJECT_TYPE):
cls = OBJECT_TYPE[tindex]
if cls is not None:
if issubclass(cls, PyNativeObject):
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<ObjectBase>obj).chandle = chandle
return cls.__from_tvm_object__(cls, obj)
obj = cls.__new__(cls)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)

(<ObjectBase>obj).chandle = chandle
return obj


class PyNativeObject:
"""Base class of all TVM objects that also subclass python's builtin types."""
__slots__ = []

def __init_tvm_object_by_constructor__(self, fconstructor, *args):
"""Initialize the internal tvm_object by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return object is directly set into the object
"""
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
obj.__init_handle_by_constructor__(fconstructor, *args)
self.__tvm_object__ = obj


cdef class ObjectBase:
cdef void* chandle

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kTVMNDArrayHandle if
not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
elif isinstance(arg, PyNativeObject):
value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
tcode[0] = kTVMObjectHandle
elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr)
Expand Down
83 changes: 23 additions & 60 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
from tvm.runtime import _ffi_api
from .object import Object, PyNativeObject
from .object_generic import ObjectTypes
from . import _ffi_api


def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
Expand Down Expand Up @@ -112,64 +113,26 @@ def tuple_object(fields=None):


@tvm._ffi.register_object("runtime.String")
class String(Object):
"""The string object.
class String(str, PyNativeObject):
"""TVM runtime.String object, represented as a python str.
Parameters
----------
string : str
The string used to construct a runtime String object
Returns
-------
ret : String
The created object.
content : str
The content string used to construct the object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_ffi_api.String, string)

def __str__(self):
return _ffi_api.GetStdString(self)

def __len__(self):
return _ffi_api.GetStringSize(self)

def __hash__(self):
return _ffi_api.StringHash(self)

def __eq__(self, other):
if isinstance(other, string_types):
return self.__str__() == other

if not isinstance(other, String):
return False

return _ffi_api.CompareString(self, other) == 0

def __ne__(self, other):
return not self.__eq__(other)

def __gt__(self, other):
return _ffi_api.CompareString(self, other) > 0

def __lt__(self, other):
return _ffi_api.CompareString(self, other) < 0

def __getitem__(self, key):
return self.__str__()[key]

def startswith(self, string):
"""Check if the runtime string starts with a given string
Parameters
----------
string : str
The provided string
Returns
-------
ret : boolean
Return true if the runtime string starts with the given string,
otherwise, false.
"""
return self.__str__().startswith(string)
__slots__ = ["__tvm_object__"]

def __new__(cls, content):
"""Construct from string content."""
val = str.__new__(cls, content)
val.__init_tvm_object_by_constructor__(_ffi_api.String, content)
return val

@staticmethod
def __from_tvm_object__(cls, obj):
"""Construct from a given tvm object."""
content = _ffi_api.GetFFIString(obj)
val = str.__new__(cls, content)
val.__tvm_object__ = obj
return val
14 changes: 6 additions & 8 deletions python/tvm/runtime/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
if _FFI_MODE == "ctypes":
raise ImportError()
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
from tvm._ffi._cy3.core import ObjectBase
from tvm._ffi._cy3.core import ObjectBase, PyNativeObject
except (RuntimeError, ImportError):
# pylint: disable=wrong-import-position,unused-import
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
from tvm._ffi._ctypes.object import ObjectBase
from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject


def _new_object(cls):
Expand All @@ -41,6 +41,7 @@ def _new_object(cls):

class Object(ObjectBase):
"""Base class for all tvm's runtime objects."""
__slots__ = []
def __repr__(self):
return _ffi_node_api.AsRepr(self)

Expand Down Expand Up @@ -78,13 +79,10 @@ def __getstate__(self):
def __setstate__(self, state):
# pylint: disable=assigning-non-slot, assignment-from-no-return
handle = state['handle']
self.handle = None
if handle is not None:
json_str = handle
other = _ffi_node_api.LoadJSON(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None
self.__init_handle_by_constructor__(
_ffi_node_api.LoadJSON, handle)

def _move(self):
"""Create an RValue reference to the object and mark the object as moved.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm._ffi.runtime_ctypes import ObjectRValueRef

from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .object import ObjectBase, PyNativeObject, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
from .module import Module
Expand All @@ -34,7 +34,7 @@ def asobject(self):
raise NotImplementedError()


ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef)
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)


def convert_to_object(value):
Expand Down
19 changes: 2 additions & 17 deletions src/runtime/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file src/runtime/container.cc
* \brief Implementations of common plain old data (POD) containers.
* \brief Implementations of common containers.
*/
#include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
Expand Down Expand Up @@ -81,26 +81,11 @@ TVM_REGISTER_GLOBAL("runtime.String")
return String(std::move(str));
});

TVM_REGISTER_GLOBAL("runtime.GetStringSize")
.set_body_typed([](String str) {
return static_cast<int64_t>(str.size());
});

TVM_REGISTER_GLOBAL("runtime.GetStdString")
TVM_REGISTER_GLOBAL("runtime.GetFFIString")
.set_body_typed([](String str) {
return std::string(str);
});

TVM_REGISTER_GLOBAL("runtime.CompareString")
.set_body_typed([](String lhs, String rhs) {
return lhs.compare(rhs);
});

TVM_REGISTER_GLOBAL("runtime.StringHash")
.set_body_typed([](String str) {
return static_cast<int64_t>(std::hash<String>()(str));
});

TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
Expand Down
5 changes: 5 additions & 0 deletions src/support/ffi_testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ TVM_REGISTER_GLOBAL("testing.nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});

TVM_REGISTER_GLOBAL("testing.echo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0];
});

TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc pf = args[0];
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_runtime_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import tvm
import pickle
from tvm import te
from tvm import nd, relay
from tvm.runtime import container as _container
Expand Down Expand Up @@ -56,6 +57,30 @@ def test_tuple_object():
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))


def test_string():
s = tvm.runtime.String("xyz")

assert isinstance(s, tvm.runtime.String)
assert isinstance(s, str)
assert s.startswith("xy")
assert s + "1" == "xyz1"
y = tvm.testing.echo(s)
assert isinstance(y, tvm.runtime.String)
assert s.__tvm_object__.same_as(y.__tvm_object__)
assert s == y

x = tvm.ir.load_json(tvm.ir.save_json(y))
assert isinstance(x, tvm.runtime.String)
assert x == y

# test pickle
z = pickle.loads(pickle.dumps(s))
assert isinstance(z, tvm.runtime.String)
assert s == z


if __name__ == "__main__":
test_string()

test_adt_constructor()
test_tuple_object()

0 comments on commit 4c5ac4a

Please sign in to comment.