diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index b4f59526a900..4df2daffeb61 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -143,6 +143,18 @@ typedef enum { kTVMFFIMap = 72, /*! \brief Runtime dynamic loaded module object. */ kTVMFFIModule = 73, + /*! + * \brief Opaque python object. + * + * This is a special type index to indicate we are storing an opaque PyObject. + * Such object may interact with callback functions that are registered to support + * python-related operations. + * + * We only translate the objects that we do not recognize into this type index. + * + * \sa TVMFFIObjectCreateOpaque + */ + kTVMFFIOpaquePyObject = 74, kTVMFFIStaticObjectEnd, // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) /*! \brief Start of type indices that are allocated at runtime. */ @@ -344,11 +356,19 @@ typedef struct { TVMFFISafeCallType safe_call; } TVMFFIFunctionCell; +/*! + * \brief Object cell for opaque object following header. + */ +typedef struct { + /*! \brief The handle of the opaque object, for python it is PyObject* */ + void* handle; +} TVMFFIOpaqueObjectCell; + //------------------------------------------------------------ // Section: Basic object API //------------------------------------------------------------ /*! - * \brief Increas the strong reference count of an object handle + * \brief Increase the strong reference count of an object handle * \param obj The object handle. * \note Internally we increase the reference counter of the object. * \return 0 when success, nonzero when failure happens @@ -362,6 +382,33 @@ TVM_FFI_DLL int TVMFFIObjectIncRef(TVMFFIObjectHandle obj); */ TVM_FFI_DLL int TVMFFIObjectDecRef(TVMFFIObjectHandle obj); +/*! + * \brief Create an Opaque object by passing in handle, type_index and deleter. + * + * The opaque object's lifetime is managed as an Object, so it can be retained + * and released like other objects. + * When the opaque object is kTVMFFIOpaquePyObject, it can be converted back to + * the python type when returned or passed as arguments to a python function. + * + * We can support ffi::Function that interacts with these objects, + * most likely callback registered from python. + * + * For language bindings, we only convert types that we do not recognize into this type. + * On the C++ side, the most common way to represent such OpaqueObject is to simply + * use ffi::ObjectRef or ffi::Any. + * + * \param handle The resource handle of the opaque object. + * \param type_index The type index of the object. + * \param deleter deleter to recycle + * \param out The output of the opaque object. + * \return 0 when success, nonzero when failure happens + * + * \note The caller must ensure the type_index is a valid opaque object type index. + * \sa kTVMFFIOpaquePyObject + */ +TVM_FFI_DLL int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index, + void (*deleter)(void* handle), TVMFFIObjectHandle* out); + /*! * \brief Convert type key to type index. * \param type_key The key of the type. @@ -374,82 +421,73 @@ TVM_FFI_DLL int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* o // Section: Basic function calling API for function implementation //----------------------------------------------------------------------- /*! - * \brief Create a FFIFunc by passing in callbacks from C callback. - * - * The registered function then can be pulled by the backend by the name. - * + * \brief Create a FFIFunc by passing in callbacks from a C callback. + * The registered function can then be retrieved by the backend using its name. * \param self The resource handle of the C callback. - * \param safe_call The C callback implementation - * \param deleter deleter to recycle + * \param safe_call The C callback implementation. + * \param deleter The deleter to recycle. * \param out The output of the function. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self), TVMFFIObjectHandle* out); /*! - * \brief Get a global function registered in system. - * + * \brief Get a global function registered in the system. * \param name The name of the function. - * \param out the result function pointer, NULL if it does not exist. - * \return 0 when success, nonzero when failure happens + * \param out The result function pointer, NULL if it does not exist. + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out); /*! - * \brief Convert a AnyView to an owned Any. + * \brief Convert an AnyView to an owned Any. * \param any The AnyView to convert. - * \param out The output Any, must be an empty object - * \return 0 when success, nonzero when failure happens + * \param out The output Any, must be an empty object. + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out); /*! * \brief Call a FFIFunc by passing in arguments. - * * \param func The resource handle of the C callback. * \param args The input arguments to the call. * \param num_args The number of input arguments. * \param result The output result, caller must ensure result->type_index is set to kTVMFFINone. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, TVMFFIAny* result); /*! - * \brief Move the last error from the environment to result. - * + * \brief Move the last error from the environment to the result. * \param result The result error. - * * \note This function clears the error stored in the TLS. */ TVM_FFI_DLL void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result); /*! - * \brief Set raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. - * + * \brief Set a raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. * \param error The error object handle */ TVM_FFI_DLL void TVMFFIErrorSetRaised(TVMFFIObjectHandle error); /*! - * \brief Set raised error in TLS, which can be fetched by TVMFFIMoveFromRaised. - * + * \brief Set a raised error in TLS, which can be fetched by TVMFFIMoveFromRaised. * \param kind The kind of the error. * \param message The error message. - * \note This is a convenient method for C API side to set error directly from string. + * \note This is a convenient method for the C API side to set an error directly from a string. */ TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStr(const char* kind, const char* message); /*! * \brief Create an initial error object. - * * \param kind The kind of the error. * \param message The error message. * \param traceback The traceback of the error. * \return The created error object handle. - * \note This function is different from other functions as it is used in error handling loop. - * So we do not follow normal error handling patterns via returning error code. + * \note This function is different from other functions as it is used in the error handling loop. + * So we do not follow normal error handling patterns via returning an error code. */ TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, const TVMFFIByteArray* message, @@ -461,29 +499,29 @@ TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, /*! * \brief Produce a managed NDArray from a DLPack tensor. * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. + * \param require_alignment The minimum alignment required of the data + byte_offset. * \param require_contiguous Boolean flag indicating if we need to check for contiguity. * \param out The output NDArray handle. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t require_alignment, int32_t require_contiguous, TVMFFIObjectHandle* out); /*! - * \brief Produce a DLMangedTensor from the array that shares data memory with the array. + * \brief Produce a DLManagedTensor from the array that shares data memory with the array. * \param from The source array. * \param out The DLManagedTensor handle. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out); /*! * \brief Produce a managed NDArray from a DLPack tensor. * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. + * \param require_alignment The minimum alignment required of the data + byte_offset. * \param require_contiguous Boolean flag indicating if we need to check for contiguity. * \param out The output NDArray handle. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, int32_t require_alignment, @@ -491,10 +529,10 @@ TVM_FFI_DLL int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, TVMFFIObjectHandle* out); /*! - * \brief Produce a DLMangedTensor from the array that shares data memory with the array. + * \brief Produce a DLManagedTensor from the array that shares data memory with the array. * \param from The source array. * \param out The DLManagedTensor handle. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, DLManagedTensorVersioned** out); @@ -508,7 +546,7 @@ TVM_FFI_DLL int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, * \brief Convert a string to a DLDataType. * \param str The string to convert. * \param out The output DLDataType. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out); @@ -516,7 +554,7 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* * \brief Convert a DLDataType to a string. * \param dtype The DLDataType to convert. * \param out The output string. -* \return 0 when success, nonzero when failure happens +* \return 0 on success, nonzero on failure. * \note out is a String object that needs to be freed by the caller via TVMFFIObjectDecRef. The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. @@ -530,25 +568,25 @@ TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out); // The reflec //------------------------------------------------------------ /*! - * \brief Getter that can take address of a field and set the result. + * \brief Getter that can take the address of a field and set the result. * \param field The raw address of the field. * \param result Stores the result. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ typedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result); /*! - * \brief Getter that can take address of a field and set to value. + * \brief Getter that can take the address of a field and set it to a value. * \param field The raw address of the field. * \param value The value to set. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ typedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value); /*! - * \brief Function that create a new instance of the type. + * \brief Function that creates a new instance of the type. * \param result The new object handle - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ typedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle* result); @@ -808,68 +846,55 @@ typedef struct TVMFFITypeInfo { /*! * \brief Register the function to runtime's global table. - * - * The registered function then can be pulled by the backend by the name. - * + * The registered function can then be retrieved by the backend using its name. * \param name The name of the function. * \param f The function to be registered. - * \param allow_override Whether allow override already registered function. - * \return 0 when success, nonzero when failure happens + * \param allow_override Whether to allow overriding an already registered function. + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle f, int allow_override); /*! * \brief Register the function to runtime's global table with method info. - * - * This is same as TVMFFIFunctionSetGlobal but with method info that can provide extra + * This is the same as TVMFFIFunctionSetGlobal but with method info that can provide extra * metadata used in the runtime. - * * \param method_info The method info to be registered. - * \param override Whether allow override already registered function. - * \return 0 when success, nonzero when failure happens + * \param override Whether to allow overriding an already registered function. + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFIFunctionSetGlobalFromMethodInfo(const TVMFFIMethodInfo* method_info, int allow_override); /*! * \brief Register type field information for runtime reflection. - * \param type_index The type index - * \param info The field info to be registered. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info); /*! * \brief Register type method information for runtime reflection. - * \param type_index The type index - * \param info The method info to be registered. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info); /*! * \brief Register type creator information for runtime reflection. - * \param type_index The type index - * \param metadata The extra information to be registered. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata* metadata); /*! * \brief Register extra type attributes that can be looked up during runtime. - * \param type_index The type index - * \param attr_value The attribute value to be registered. - * \return 0 when success, nonzero when failure happens + * \return 0 on success, nonzero on failure. */ TVM_FFI_DLL int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray* attr_name, const TVMFFIAny* attr_value); /*! * \brief Get the type attribute column by name. - * \param attr_name The name of the attribute. * \return The pointer to the type attribute column. - * \return NULL if the attribute was not registered in the system + * \return NULL if the attribute was not registered in the system. */ TVM_FFI_DLL const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* attr_name); @@ -890,22 +915,19 @@ TVM_FFI_DLL const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByte * or we should stop at the ffi boundary when detected * \return The traceback string * - * \note filename/func can be nullptr, then these info are skipped, they are useful - * for cases when debug symbols is not available. + * \note filename/func can be nullptr, then this info is skipped, they are useful + * for cases when debug symbols are not available. */ TVM_FFI_DLL const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func, int cross_ffi_boundary); /*! * \brief Initialize the type info during runtime. - * - * When the function is first time called for a type, - * it will register the type to the type table in the runtime. - * - * If the static_tindex is non-negative, the function will - * allocate a runtime type index. - * Otherwise, we will populate the type table and return the static index. - * + * When the function is first called for a type, + * it will register the type to the type table in the runtime. + * If the static_tindex is non-negative, the function will + * allocate a runtime type index. + * Otherwise, we will populate the type table and return the static index. * \param type_key The type key. * \param static_type_index Static type index if any, can be -1, which means this is a dynamic index * \param num_child_slots Number of slots reserved for its children. @@ -923,10 +945,7 @@ TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, /*! * \brief Get dynamic type info by type index. - * - * \param type_index The type index - * \param result The output type information - * \return The type info + * \return The type info. */ TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index); @@ -974,7 +993,7 @@ inline TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) { /*! * \brief Get the data pointer of a ErrorInfo from an Error object. * \param obj The object handle. - * \return The data pointer. + * \return The cell pointer. */ inline TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) { return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); @@ -983,16 +1002,26 @@ inline TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) { /*! * \brief Get the data pointer of a function cell from a function object. * \param obj The object handle. - * \return The data pointer. + * \return The cell pointer. */ inline TVMFFIFunctionCell* TVMFFIFunctionGetCellPtr(TVMFFIObjectHandle obj) { return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); } +/*! + * \brief Get the data pointer of a opaque object cell from a opaque object. + * \param obj The object handle. + * \return The cell pointer. + */ +inline TVMFFIOpaqueObjectCell* TVMFFIOpaqueObjectGetCellPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + + sizeof(TVMFFIObject)); +} + /*! * \brief Get the data pointer of a shape array from a shape object. * \param obj The object handle. - * \return The data pointer. + * \return The cell pointer. */ inline TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) { return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); diff --git a/ffi/python/tvm_ffi/convert.py b/ffi/python/tvm_ffi/convert.py index 5b25ddae259b..94c82991101b 100644 --- a/ffi/python/tvm_ffi/convert.py +++ b/ffi/python/tvm_ffi/convert.py @@ -56,13 +56,13 @@ def convert(value: Any) -> Any: return None elif hasattr(value, "__dlpack__"): return core.from_dlpack( - value, - required_alignment=core.__dlpack_auto_import_required_alignment__, + value, required_alignment=core.__dlpack_auto_import_required_alignment__ ) elif isinstance(value, Exception): return core._convert_to_ffi_error(value) else: - raise TypeError(f"don't know how to convert type {type(value)} to object") + # in this case, it is an opaque python object + return core._convert_to_opaque_object(value) core._set_func_convert_to_object(convert) diff --git a/ffi/python/tvm_ffi/cython/base.pxi b/ffi/python/tvm_ffi/cython/base.pxi index 4a47efd773d9..4acf5f0a1717 100644 --- a/ffi/python/tvm_ffi/cython/base.pxi +++ b/ffi/python/tvm_ffi/cython/base.pxi @@ -53,6 +53,7 @@ cdef extern from "tvm/ffi/c_api.h": kTVMFFIArray = 71 kTVMFFIMap = 72 kTVMFFIModule = 73 + kTVMFFIOpaquePyObject = 74 ctypedef void* TVMFFIObjectHandle @@ -111,6 +112,9 @@ cdef extern from "tvm/ffi/c_api.h": const char* data size_t size + ctypedef struct TVMFFIOpaqueObjectCell: + void* handle + ctypedef struct TVMFFIShapeCell: const int64_t* data size_t size @@ -172,6 +176,8 @@ cdef extern from "tvm/ffi/c_api.h": const TVMFFITypeMetadata* metadata int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil + int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index, + void (*deleter)(void*), TVMFFIObjectHandle* out) nogil int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) nogil @@ -203,6 +209,7 @@ cdef extern from "tvm/ffi/c_api.h": TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) nogil TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil + TVMFFIOpaqueObjectCell* TVMFFIOpaqueObjectGetCellPtr(TVMFFIObjectHandle obj) nogil TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) nogil DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi index 064473e134c4..fc273b5cee0f 100644 --- a/ffi/python/tvm_ffi/cython/function.pxi +++ b/ffi/python/tvm_ffi/cython/function.pxi @@ -46,6 +46,8 @@ cdef inline object make_ret(TVMFFIAny result): if type_index == kTVMFFINDArray: # specially handle NDArray as it needs a special dltensor field return make_ndarray_from_any(result) + elif type_index == kTVMFFIOpaquePyObject: + return make_ret_opaque_object(result) elif type_index >= kTVMFFIStaticObjectBegin: return make_ret_object(result) elif type_index == kTVMFFINone: @@ -182,7 +184,10 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args, out[i].v_ptr = (arg).chandle temp_args.append(arg) else: - raise TypeError("Unsupported argument type: %s" % type(arg)) + arg = _convert_to_opaque_object(arg) + out[i].type_index = kTVMFFIOpaquePyObject + out[i].v_ptr = (arg).chandle + temp_args.append(arg) cdef inline int FuncCall3(void* chandle, @@ -431,9 +436,9 @@ def _get_global_func(name, allow_missing): # handle callbacks -cdef void tvm_ffi_callback_deleter(void* fhandle) noexcept with gil: - local_pyfunc = (fhandle) - Py_DECREF(local_pyfunc) +cdef void tvm_ffi_pyobject_deleter(void* fhandle) noexcept with gil: + local_pyobject = (fhandle) + Py_DECREF(local_pyobject) cdef int tvm_ffi_callback(void* context, @@ -468,12 +473,27 @@ def _convert_to_ffi_func(object pyfunc): CHECK_CALL(TVMFFIFunctionCreate( (pyfunc), tvm_ffi_callback, - tvm_ffi_callback_deleter, + tvm_ffi_pyobject_deleter, &chandle)) ret = Function.__new__(Function) (ret).chandle = chandle return ret + +def _convert_to_opaque_object(object pyobject): + """Convert a python object to TVM FFI opaque object""" + cdef TVMFFIObjectHandle chandle + Py_INCREF(pyobject) + CHECK_CALL(TVMFFIObjectCreateOpaque( + (pyobject), + kTVMFFIOpaquePyObject, + tvm_ffi_pyobject_deleter, + &chandle)) + ret = OpaquePyObject.__new__(OpaquePyObject) + (ret).chandle = chandle + return ret + + _STR_CONSTRUCTOR = _get_global_func("ffi.String", False) _BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False) _OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True) diff --git a/ffi/python/tvm_ffi/cython/object.pxi b/ffi/python/tvm_ffi/cython/object.pxi index 1203f0c68289..fda7f56b23be 100644 --- a/ffi/python/tvm_ffi/cython/object.pxi +++ b/ffi/python/tvm_ffi/cython/object.pxi @@ -194,6 +194,17 @@ cdef class Object: (other).chandle = NULL +cdef class OpaquePyObject(Object): + """Opaque PyObject container""" + def pyobject(self): + """Get the underlying python object""" + cdef object obj + cdef PyObject* py_handle + py_handle = (TVMFFIOpaqueObjectGetCellPtr(self.chandle).handle) + obj = py_handle + return obj + + class PyNativeObject: """Base class of all TVM objects that also subclass python's builtin types.""" __slots__ = [] @@ -252,6 +263,12 @@ cdef inline str _type_index_to_key(int32_t tindex): return py_str(PyBytes_FromStringAndSize(type_key.data, type_key.size)) +cdef inline object make_ret_opaque_object(TVMFFIAny result): + obj = OpaquePyObject.__new__(OpaquePyObject) + (obj).chandle = result.v_obj + return obj.pyobject() + + cdef inline object make_ret_object(TVMFFIAny result): global OBJECT_TYPE cdef int32_t tindex diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index f96636fd4994..9f554e3356f9 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -385,6 +386,29 @@ class TypeTable { Map type_attr_name_to_column_index_; }; +/** + * \brief Opaque implementation + */ +class OpaqueObjectImpl : public Object, public TVMFFIOpaqueObjectCell { + public: + OpaqueObjectImpl(void* handle, void (*deleter)(void* handle)) : deleter_(deleter) { + this->handle = handle; + } + + void SetTypeIndex(int32_t type_index) { + details::ObjectUnsafe::GetHeader(this)->type_index = type_index; + } + + ~OpaqueObjectImpl() { + if (deleter_ != nullptr) { + deleter_(handle); + } + } + + private: + void (*deleter_)(void* handle); +}; + } // namespace ffi } // namespace tvm @@ -400,6 +424,22 @@ int TVMFFIObjectIncRef(TVMFFIObjectHandle handle) { TVM_FFI_SAFE_CALL_END(); } +int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index, void (*deleter)(void* handle), + TVMFFIObjectHandle* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + if (type_index != kTVMFFIOpaquePyObject) { + TVM_FFI_THROW(RuntimeError) << "Only kTVMFFIOpaquePyObject is supported for now"; + } + // create initial opaque object + tvm::ffi::ObjectPtr p = + tvm::ffi::make_object(handle, deleter); + // need to set the type index after creation, because the set to RuntimeTypeIndex() + // happens after the constructor is called + p->SetTypeIndex(type_index); + *out = tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(p)); + TVM_FFI_SAFE_CALL_END(); +} + int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex) { TVM_FFI_SAFE_CALL_BEGIN(); out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKeyToIndex(type_key); diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc index f6bedcb6f371..1d7de990f01a 100644 --- a/ffi/tests/cpp/test_object.cc +++ b/ffi/tests/cpp/test_object.cc @@ -222,4 +222,29 @@ TEST(Object, WeakObjectPtrAssignment) { EXPECT_EQ(lock3->value, 777); } +TEST(Object, OpaqueObject) { + thread_local int deleter_trigger_counter = 0; + struct DummyOpaqueObject { + int value; + DummyOpaqueObject(int value) : value(value) {} + + static void Deleter(void* handle) { + deleter_trigger_counter++; + delete static_cast(handle); + } + }; + TVMFFIObjectHandle handle = nullptr; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIObjectCreateOpaque(new DummyOpaqueObject(10), kTVMFFIOpaquePyObject, + DummyOpaqueObject::Deleter, &handle)); + ObjectPtr a = + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + EXPECT_EQ(a->type_index(), kTVMFFIOpaquePyObject); + EXPECT_EQ(static_cast(TVMFFIOpaqueObjectGetCellPtr(a.get())->handle)->value, + 10); + EXPECT_EQ(a.use_count(), 1); + EXPECT_EQ(deleter_trigger_counter, 0); + a.reset(); + EXPECT_EQ(deleter_trigger_counter, 1); +} + } // namespace diff --git a/ffi/tests/python/test_container.py b/ffi/tests/python/test_container.py index 657adbef663e..9f2fb09df216 100644 --- a/ffi/tests/python/test_container.py +++ b/ffi/tests/python/test_container.py @@ -66,6 +66,28 @@ def test_int_map(): assert tuple(amap.values()) == (2, 3) +def test_array_map_of_opaque_object(): + class MyObject: + def __init__(self, value): + self.value = value + + a = tvm_ffi.convert([MyObject("hello"), MyObject(1)]) + assert isinstance(a, tvm_ffi.Array) + assert len(a) == 2 + assert isinstance(a[0], MyObject) + assert a[0].value == "hello" + assert isinstance(a[1], MyObject) + assert a[1].value == 1 + + y = tvm_ffi.convert({"a": MyObject(1), "b": MyObject("hello")}) + assert isinstance(y, tvm_ffi.Map) + assert len(y) == 2 + assert isinstance(y["a"], MyObject) + assert y["a"].value == 1 + assert isinstance(y["b"], MyObject) + assert y["b"].value == "hello" + + def test_str_map(): data = [] for i in reversed(range(10)): diff --git a/ffi/tests/python/test_function.py b/ffi/tests/python/test_function.py index cb81f47c7d58..4b0db45b4bd3 100644 --- a/ffi/tests/python/test_function.py +++ b/ffi/tests/python/test_function.py @@ -17,6 +17,7 @@ import gc import ctypes +import sys import numpy as np import tvm_ffi @@ -161,3 +162,27 @@ def check1(): check0() check1() + + +def test_echo_with_opaque_object(): + class MyObject: + def __init__(self, value): + self.value = value + + fecho = tvm_ffi.get_global_func("testing.echo") + x = MyObject("hello") + assert sys.getrefcount(x) == 2 + y = fecho(x) + assert isinstance(y, MyObject) + assert y is x + assert sys.getrefcount(x) == 3 + + def py_callback(z): + """python callback with opaque object""" + assert z is x + return z + + fcallback = tvm_ffi.convert(py_callback) + z = fcallback(x) + assert z is x + assert sys.getrefcount(x) == 4 diff --git a/ffi/tests/python/test_object.py b/ffi/tests/python/test_object.py index 63867b9de155..1b07de8e9d69 100644 --- a/ffi/tests/python/test_object.py +++ b/ffi/tests/python/test_object.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import pytest +import sys import tvm_ffi @@ -68,3 +69,23 @@ def test_derived_object(): obj0.v_i64 = 21 assert obj0.v_i64 == 21 + + +class MyObject: + def __init__(self, value): + self.value = value + + +def test_opaque_object(): + obj0 = MyObject("hello") + assert sys.getrefcount(obj0) == 2 + obj0_converted = tvm_ffi.convert(obj0) + assert sys.getrefcount(obj0) == 3 + assert isinstance(obj0_converted, tvm_ffi.core.OpaquePyObject) + obj0_cpy = obj0_converted.pyobject() + assert obj0_cpy is obj0 + assert sys.getrefcount(obj0) == 4 + obj0_converted = None + assert sys.getrefcount(obj0) == 3 + obj0_cpy = None + assert sys.getrefcount(obj0) == 2