Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME] Refactor object python FFI to new protocol. #4128

Merged
merged 2 commits into from
Oct 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ typedef enum {
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kObjectCell = 14U,
kObjectHandle = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
Expand Down Expand Up @@ -549,13 +549,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
TVMStreamHandle dst);

/*!
* \brief Get the tag from an object.
* \brief Get the type_index from an object.
*
* \param obj The object handle.
* \param tag The tag of object.
* \param out_tindex the output type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);

/*!
* \brief Convert type key to type index.
* \param type_key The key of the type.
* \param out_tindex the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);

/*!
* \brief Free the object.
*
* \param obj The object handle.
* \note Internally we decrease the reference counter of the object.
* The object will be freed when every reference to the object are removed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);

#ifdef __cplusplus
} // TVM_EXTERN_C
Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class Object {
template<typename>
friend class ObjectPtr;
friend class TVMRetValue;
friend class TVMObjectCAPI;
};

/*!
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ class TVMPODValue_ {
}
operator ObjectRef() const {
if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator TVMContext() const {
Expand Down Expand Up @@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
}
TVMRetValue& operator=(ObjectRef other) {
this->Clear();
type_code_ = kObjectCell;
type_code_ = kObjectHandle;
// move the handle out
value_.v_handle = other.data_.data_;
other.data_.data_ = nullptr;
Expand Down Expand Up @@ -862,7 +862,7 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
case kObjectCell: {
case kObjectHandle: {
*this = other.operator ObjectRef();
break;
}
Expand Down Expand Up @@ -913,7 +913,7 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
case kObjectCell: {
case kObjectHandle: {
static_cast<Object*>(value_.v_handle)->DecRef();
break;
}
Expand Down Expand Up @@ -946,7 +946,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kObjectCell: return "ObjectCell";
case kObjectHandle: return "ObjectCell";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
Expand Down Expand Up @@ -1164,7 +1164,7 @@ class TVMArgsSetter {
}
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kObjectCell;
type_codes_[i] = kObjectHandle;
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
from .node import NodeBase
from . import object as _object
from . import node as _node

FunctionHandle = ctypes.c_void_p
Expand Down Expand Up @@ -165,7 +166,7 @@ def _make_tvm_args(args, temp_args):
temp_args.append(arg)
elif isinstance(arg, _CLASS_OBJECT):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_CELL
type_codes[i] = TypeCode.OBJECT_HANDLE
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
Expand Down Expand Up @@ -225,7 +226,7 @@ def __init_handle_by_constructor__(fconstructor, args):
raise get_last_ffi_error()
_ = temp_args
_ = args
assert ret_tcode.value == TypeCode.NODE_HANDLE
assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE)
handle = ret_val.v_handle
return handle

Expand All @@ -247,6 +248,7 @@ def _handle_return_func(x):

# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
Expand Down
85 changes: 85 additions & 0 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import

import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func


ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None

"""Maps object type to its constructor"""
OBJECT_TYPE = {}

def _register_object(index, cls):
"""register object class"""
OBJECT_TYPE[index] = cls


def _return_object(x):
handle = x.v_handle
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
tindex = ctypes.c_uint()
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, ObjectBase)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
obj.handle = handle
return obj

RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE)


class ObjectBase(object):
"""Base object for all object types"""
__slots__ = ["handle"]

def __del__(self):
if _LIB is not None:
check_call(_LIB.TVMObjectFree(self.handle))

def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
52 changes: 0 additions & 52 deletions python/tvm/_ffi/_ctypes/vmobj.py

This file was deleted.

6 changes: 4 additions & 2 deletions python/tvm/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kObjectCell = 14
kObjectHandle = 14
kExtBegin = 15

cdef extern from "tvm/runtime/c_runtime_api.h":
Expand Down Expand Up @@ -130,7 +130,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
int TVMArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out)
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
int TVMGetObjectTag(ObjectHandle obj, int* tag)
int TVMObjectFree(ObjectHandle obj)
int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index)


cdef extern from "tvm/c_dsl_api.h":
int TVMNodeFree(NodeHandle handle)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/_ffi/_cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# under the License.

include "./base.pxi"
include "./object.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
include "./vmobj.pxi"

12 changes: 6 additions & 6 deletions python/tvm/_ffi/_cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
tcode == kObjectCell or
tcode == kObjectHandle or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))

Expand Down Expand Up @@ -155,12 +155,12 @@ cdef inline int make_arg(object arg,
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
temp_args.append(arg)
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObjectCell
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
Expand Down Expand Up @@ -190,6 +190,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value."""
if tcode == kNodeHandle:
return make_ret_node(value.v_handle)
elif tcode == kObjectHandle:
return make_ret_object(value.v_handle)
elif tcode == kNull:
return None
elif tcode == kInt:
Expand All @@ -212,8 +214,6 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
elif tcode == kObjectCell:
return make_ret_object(value.v_handle)
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))

Expand Down
Loading