Skip to content

Commit

Permalink
[RUNTIME] Refactor object python FFI to new protocol. (#4128)
Browse files Browse the repository at this point in the history
* [RUNTIME] Refactor object python FFI to new protocol.

This is a pre-req to bring the Node system under object protocol.
Most of the code reflects the current code in the Node system.

- Use new instead of init so subclass can define their own constructors
- Allow register via name, besides type idnex
- Introduce necessary runtime C API functions
- Refactored Tensor and Datatype to directly use constructor.

* address review comments
  • Loading branch information
tqchen authored Oct 16, 2019
1 parent e3fbdc8 commit 02c1e11
Show file tree
Hide file tree
Showing 23 changed files with 440 additions and 252 deletions.
26 changes: 22 additions & 4 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ typedef enum {
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kObjectCell = 14U,
kObjectHandle = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
Expand Down Expand Up @@ -549,13 +549,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
TVMStreamHandle dst);

/*!
* \brief Get the tag from an object.
* \brief Get the type_index from an object.
*
* \param obj The object handle.
* \param tag The tag of object.
* \param out_tindex the output type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);

/*!
* \brief Convert type key to type index.
* \param type_key The key of the type.
* \param out_tindex the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);

/*!
* \brief Free the object.
*
* \param obj The object handle.
* \note Internally we decrease the reference counter of the object.
* The object will be freed when every reference to the object are removed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);

#ifdef __cplusplus
} // TVM_EXTERN_C
Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class Object {
template<typename>
friend class ObjectPtr;
friend class TVMRetValue;
friend class TVMObjectCAPI;
};

/*!
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ class TVMPODValue_ {
}
operator ObjectRef() const {
if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator TVMContext() const {
Expand Down Expand Up @@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
}
TVMRetValue& operator=(ObjectRef other) {
this->Clear();
type_code_ = kObjectCell;
type_code_ = kObjectHandle;
// move the handle out
value_.v_handle = other.data_.data_;
other.data_.data_ = nullptr;
Expand Down Expand Up @@ -862,7 +862,7 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
case kObjectCell: {
case kObjectHandle: {
*this = other.operator ObjectRef();
break;
}
Expand Down Expand Up @@ -913,7 +913,7 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
case kObjectCell: {
case kObjectHandle: {
static_cast<Object*>(value_.v_handle)->DecRef();
break;
}
Expand Down Expand Up @@ -946,7 +946,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kObjectCell: return "ObjectCell";
case kObjectHandle: return "ObjectCell";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
Expand Down Expand Up @@ -1164,7 +1164,7 @@ class TVMArgsSetter {
}
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kObjectCell;
type_codes_[i] = kObjectHandle;
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .node import NodeBase
from . import object as _object
from . import node as _node

FunctionHandle = ctypes.c_void_p
Expand Down Expand Up @@ -165,7 +166,7 @@ def _make_tvm_args(args, temp_args):
temp_args.append(arg)
elif isinstance(arg, _CLASS_OBJECT):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_CELL
type_codes[i] = TypeCode.OBJECT_HANDLE
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
Expand Down Expand Up @@ -225,7 +226,7 @@ def __init_handle_by_constructor__(fconstructor, args):
raise get_last_ffi_error()
_ = temp_args
_ = args
assert ret_tcode.value == TypeCode.NODE_HANDLE
assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE)
handle = ret_val.v_handle
return handle

Expand All @@ -247,6 +248,7 @@ def _handle_return_func(x):

# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
Expand Down
85 changes: 85 additions & 0 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import

import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func


ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None

"""Maps object type to its constructor"""
OBJECT_TYPE = {}

def _register_object(index, cls):
"""register object class"""
OBJECT_TYPE[index] = cls


def _return_object(x):
handle = x.v_handle
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, ObjectBase)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
obj.handle = handle
return obj

RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE)


class ObjectBase(object):
"""Base object for all object types"""
__slots__ = ["handle"]

def __del__(self):
if _LIB is not None:
check_call(_LIB.TVMObjectFree(self.handle))

def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
52 changes: 0 additions & 52 deletions python/tvm/_ffi/_ctypes/vmobj.py

This file was deleted.

6 changes: 4 additions & 2 deletions python/tvm/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kObjectCell = 14
kObjectHandle = 14
kExtBegin = 15

cdef extern from "tvm/runtime/c_runtime_api.h":
Expand Down Expand Up @@ -130,7 +130,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
int TVMGetObjectTag(ObjectHandle obj, int* tag)
int TVMObjectFree(ObjectHandle obj)
int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index)


cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/_ffi/_cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# under the License.

include "./base.pxi"
include "./object.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
include "./vmobj.pxi"

12 changes: 6 additions & 6 deletions python/tvm/_ffi/_cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
tcode == kObjectCell or
tcode == kObjectHandle or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))

Expand Down Expand Up @@ -155,12 +155,12 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
temp_args.append(arg)
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObjectCell
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
Expand Down Expand Up @@ -190,6 +190,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value."""
if tcode == kNodeHandle:
return make_ret_node(value.v_handle)
elif tcode == kObjectHandle:
return make_ret_object(value.v_handle)
elif tcode == kNull:
return None
elif tcode == kInt:
Expand All @@ -212,8 +214,6 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
elif tcode == kObjectCell:
return make_ret_object(value.v_handle)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))

Expand Down
Loading

0 comments on commit 02c1e11

Please sign in to comment.