diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index feed44498917..94fc6422891f 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -41,10 +41,10 @@ #include "../ffi/src/ffi/extra/library_module_dynamic_lib.cc" #include "../ffi/src/ffi/extra/library_module_system_lib.cc" #include "../ffi/src/ffi/extra/module.cc" +#include "../ffi/src/ffi/extra/testing.cc" #include "../ffi/src/ffi/function.cc" #include "../ffi/src/ffi/ndarray.cc" #include "../ffi/src/ffi/object.cc" -#include "../ffi/src/ffi/testing.cc" #include "../ffi/src/ffi/traceback.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/device_api.cc" diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index b5823b76a7c9..466571c2889f 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -57,7 +57,6 @@ set(tvm_ffi_objs_sources "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/function.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/ndarray.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc" ) @@ -74,6 +73,8 @@ if (TVM_FFI_USE_EXTRA_CXX_API) "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/stream_context.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_c_api.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/testing.cc" ) endif() diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index c8d46d455227..39b7de69fa75 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -300,6 +300,185 @@ typedef struct { TVMFFISafeCallType safe_call; } TVMFFIFunctionCell; +//------------------------------------------------------------ +// Section: Basic object API +//------------------------------------------------------------ +/*! + * \brief Free an object handle by decreasing reference + * \param obj The object handle. + * \note Internally we decrease the reference counter of the object. + * The object will be freed when every reference to the object are removed. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj); + +/*! + * \brief Convert type key to type index. + * \param type_key The key of the type. + * \param out_tindex the corresponding type index. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); + +//----------------------------------------------------------------------- +// Section: Basic function calling API for function implementation +//----------------------------------------------------------------------- +/*! + * \brief Create a FFIFunc by passing in callbacks from C callback. + * + * The registered function then can be pulled by the backend by the name. + * + * \param self The resource handle of the C callback. + * \param safe_call The C callback implementation + * \param deleter deleter to recycle + * \param out The output of the function. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, + void (*deleter)(void* self), TVMFFIObjectHandle* out); + +/*! + * \brief Get a global function registered in system. + * + * \param name The name of the function. + * \param out the result function pointer, NULL if it does not exist. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out); + +/*! + * \brief Convert a AnyView to an owned Any. + * \param any The AnyView to convert. + * \param out The output Any, must be an empty object + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out); + +/*! + * \brief Call a FFIFunc by passing in arguments. + * + * \param func The resource handle of the C callback. + * \param args The input arguments to the call. + * \param num_args The number of input arguments. + * \param result The output result, caller must ensure result->type_index is set to kTVMFFINone. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, + TVMFFIAny* result); + +/*! + * \brief Move the last error from the environment to result. + * + * \param result The result error. + * + * \note This function clears the error stored in the TLS. + */ +TVM_FFI_DLL void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result); + +/*! + * \brief Set raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. + * + * \param error The error object handle + */ +TVM_FFI_DLL void TVMFFIErrorSetRaised(TVMFFIObjectHandle error); + +/*! + * \brief Set raised error in TLS, which can be fetched by TVMFFIMoveFromRaised. + * + * \param kind The kind of the error. + * \param message The error message. + * \note This is a convenient method for C API side to set error directly from string. + */ +TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStr(const char* kind, const char* message); + +/*! + * \brief Create an initial error object. + * + * \param kind The kind of the error. + * \param message The error message. + * \param traceback The traceback of the error. + * \return The created error object handle. + * \note This function is different from other functions as it is used in error handling loop. + * So we do not follow normal error handling patterns via returning error code. + */ +TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, + const TVMFFIByteArray* message, + const TVMFFIByteArray* traceback); + +//------------------------------------------------------------ +// Section: DLPack support APIs +//------------------------------------------------------------ +/*! + * \brief Produce a managed NDArray from a DLPack tensor. + * \param from The source DLPack tensor. + * \param require_alignment The minimum alignment requored of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \param out The output NDArray handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t require_alignment, + int32_t require_contiguous, TVMFFIObjectHandle* out); + +/*! + * \brief Produce a DLMangedTensor from the array that shares data memory with the array. + * \param from The source array. + * \param out The DLManagedTensor handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out); + +/*! + * \brief Produce a managed NDArray from a DLPack tensor. + * \param from The source DLPack tensor. + * \param require_alignment The minimum alignment requored of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \param out The output NDArray handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, + int32_t require_alignment, + int32_t require_contiguous, + TVMFFIObjectHandle* out); + +/*! + * \brief Produce a DLMangedTensor from the array that shares data memory with the array. + * \param from The source array. + * \param out The DLManagedTensor handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, + DLManagedTensorVersioned** out); + +//--------------------------------------------------------------- +// Section: dtype string support APIs. +// These APIs are used to simplify the dtype printings during FFI +//--------------------------------------------------------------- + +/*! + * \brief Convert a string to a DLDataType. + * \param str The string to convert. + * \param out The output DLDataType. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out); + +/*! +* \brief Convert a DLDataType to a string. +* \param dtype The DLDataType to convert. +* \param out The output string. +* \return 0 when success, nonzero when failure happens +* \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree. +The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. + +* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. +*/ +TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out); + +//------------------------------------------------------------ +// Section: Type reflection support APIs +// +// The reflec +//------------------------------------------------------------ /*! * \brief Getter that can take address of a field and set the result. * \param field The raw address of the field. @@ -577,63 +756,6 @@ typedef struct TVMFFITypeInfo { const TVMFFITypeMetadata* metadata; } TVMFFITypeInfo; -//------------------------------------------------------------ -// Section: User APIs to interact with the FFI -//------------------------------------------------------------ -/*! - * \brief Free an object handle by decreasing reference - * \param obj The object handle. - * \note Internally we decrease the reference counter of the object. - * The object will be freed when every reference to the object are removed. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_tindex the corresponding type index. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); - -//----------------------------------------------------------------------- -// Section: Function calling APIs and support API for func implementation -//----------------------------------------------------------------------- -/*! - * \brief Create a FFIFunc by passing in callbacks from C callback. - * - * The registered function then can be pulled by the backend by the name. - * - * \param self The resource handle of the C callback. - * \param safe_call The C callback implementation - * \param deleter deleter to recycle - * \param out The output of the function. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIFunctionCreate(void* self, TVMFFISafeCallType safe_call, - void (*deleter)(void* self), TVMFFIObjectHandle* out); - -/*! - * \brief Convert a AnyView to an owned Any. - * \param any The AnyView to convert. - * \param out The output Any, must be an empty object - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out); - -/*! - * \brief Call a FFIFunc by passing in arguments. - * - * \param func The resource handle of the C callback. - * \param args The input arguments to the call. - * \param num_args The number of input arguments. - * \param result The output result, caller must ensure result->type_index is set to kTVMFFINone. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, - TVMFFIAny* result); - /*! * \brief Register the function to runtime's global table. * @@ -660,72 +782,6 @@ TVM_FFI_DLL int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjec TVM_FFI_DLL int TVMFFIFunctionSetGlobalFromMethodInfo(const TVMFFIMethodInfo* method_info, int override); -/*! - * \brief Get a global function. - * - * \param name The name of the function. - * \param out the result function pointer, NULL if it does not exist. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out); - -/*! - * \brief Move the last error from the environment to result. - * - * \param result The result error. - * - * \note This function clears the error stored in the TLS. - */ -TVM_FFI_DLL void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle* result); - -/*! - * \brief Set raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. - * - * \param error The error object handle - */ -TVM_FFI_DLL void TVMFFIErrorSetRaised(TVMFFIObjectHandle error); - -/*! - * \brief Set raised error in TLS, which can be fetched by TVMFFIMoveFromRaised. - * - * \param kind The kind of the error. - * \param message The error message. - * \note This is a convenient method for C API side to set error directly from string. - */ -TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStr(const char* kind, const char* message); - -/*! - * \brief Create an initial error object. - * - * \param kind The kind of the error. - * \param message The error message. - * \param traceback The traceback of the error. - * \return The created error object handle. - * \note This function is different from other functions as it is used in error handling loop. - * So we do not follow normal error handling patterns via returning error code. - */ -TVM_FFI_DLL TVMFFIObjectHandle TVMFFIErrorCreate(const TVMFFIByteArray* kind, - const TVMFFIByteArray* message, - const TVMFFIByteArray* traceback); - -/*! - * \brief Check if there are any signals raised in the surrounding env. - * \return 0 when success, nonzero when failure happens - * \note Under python this function redirects to PyErr_CheckSignals - */ -TVM_FFI_DLL int TVMFFIEnvCheckSignals(); - -/*! - * \brief Register a symbol into the from the surrounding env. - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const TVMFFIByteArray* name, void* symbol); - -//------------------------------------------------------------ -// Section: Type reflection support APIs -//------------------------------------------------------------ /*! * \brief Register type field information for runtime reflection. * \param type_index The type index @@ -767,75 +823,6 @@ TVM_FFI_DLL int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray */ TVM_FFI_DLL const TVMFFITypeAttrColumn* TVMFFIGetTypeAttrColumn(const TVMFFIByteArray* attr_name); -//------------------------------------------------------------ -// Section: DLPack support APIs -//------------------------------------------------------------ -/*! - * \brief Produce a managed NDArray from a DLPack tensor. - * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \param out The output NDArray handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFINDArrayFromDLPack(DLManagedTensor* from, int32_t require_alignment, - int32_t require_contiguous, TVMFFIObjectHandle* out); - -/*! - * \brief Produce a DLMangedTensor from the array that shares data memory with the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFINDArrayToDLPack(TVMFFIObjectHandle from, DLManagedTensor** out); - -/*! - * \brief Produce a managed NDArray from a DLPack tensor. - * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \param out The output NDArray handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFINDArrayFromDLPackVersioned(DLManagedTensorVersioned* from, - int32_t require_alignment, - int32_t require_contiguous, - TVMFFIObjectHandle* out); - -/*! - * \brief Produce a DLMangedTensor from the array that shares data memory with the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFINDArrayToDLPackVersioned(TVMFFIObjectHandle from, - DLManagedTensorVersioned** out); - -//--------------------------------------------------------------- -// Section: dtype string support APIs. -// These APIs are used to simplify the dtype printings during FFI -//--------------------------------------------------------------- - -/*! - * \brief Convert a string to a DLDataType. - * \param str The string to convert. - * \param out The output DLDataType. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out); - -/*! - * \brief Convert a DLDataType to a string. - * \param dtype The DLDataType to convert. - * \param out The output string. - * \return 0 when success, nonzero when failure happens - * \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree. - The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. - - * \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. - */ -TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out); - //------------------------------------------------------------ // Section: Backend noexcept functions for internal use // diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h index 1211ab0eeb1b..17cb3af6d0eb 100644 --- a/ffi/include/tvm/ffi/extra/c_env_api.h +++ b/ffi/include/tvm/ffi/extra/c_env_api.h @@ -59,8 +59,23 @@ TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, */ TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id); +/*! + * \brief Check if there are any signals raised in the surrounding env. + * \return 0 when success, nonzero when failure happens + * \note Under python this function redirects to PyErr_CheckSignals + */ +TVM_FFI_DLL int TVMFFIEnvCheckSignals(); + +/*! + * \brief Register a symbol into the from the surrounding env such as python + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const char* name, void* symbol); + // ---------------------------------------------------------------------------- -// Module symbol management +// Module symbol management in callee side // ---------------------------------------------------------------------------- /*! * \brief FFI function to lookup a function from a module's imports. @@ -73,8 +88,8 @@ TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, in * \note The returned function is a weak reference that is cached/owned by the module. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_FFI_DLL int TVMFFIEnvLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, - TVMFFIObjectHandle* out); +TVM_FFI_DLL int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, + TVMFFIObjectHandle* out); /* * \brief Register a symbol value that will be initialized when a library with the symbol is loaded. @@ -86,7 +101,7 @@ TVM_FFI_DLL int TVMFFIEnvLookupFromImports(TVMFFIObjectHandle library_ctx, const * \param symbol The symbol to register. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIEnvRegisterContextSymbol(const char* name, void* symbol); +TVM_FFI_DLL int TVMFFIEnvModRegisterContextSymbol(const char* name, void* symbol); /*! * \brief Register a symbol that will be initialized when a system library is loaded. @@ -95,7 +110,7 @@ TVM_FFI_DLL int TVMFFIEnvRegisterContextSymbol(const char* name, void* symbol); * \param symbol The symbol to register. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* symbol); +TVM_FFI_DLL int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* symbol); #ifdef __cplusplus } // extern "C" diff --git a/ffi/src/ffi/extra/env_c_api.cc b/ffi/src/ffi/extra/env_c_api.cc new file mode 100644 index 000000000000..121cc9a3ccde --- /dev/null +++ b/ffi/src/ffi/extra/env_c_api.cc @@ -0,0 +1,148 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/extra/env_c_api.cc + * \brief Environment C API implementation. + */ +#include +#include + +namespace tvm { +namespace ffi { +/*! + * \brief Execution environment specific API registry. + * + * This registry stores C API function pointers about + * execution environment(e.g. python) specific API function that + * we need for specific low-level handling(e.g. signal checking). + * + * We only stores the C API function when absolutely necessary (e.g. when signal handler + * cannot trap back into python). Always consider use the Function FFI when possible + * in other cases. + */ +class EnvCAPIRegistry { + public: + /*! + * \brief Callback to check if signals have been sent to the process and + * if so invoke the registered signal handler in the frontend environment. + * + * When running FFI in another language (Python), the signal handler + * may not be immediately executed, but instead the signal is marked + * in the interpreter state (to ensure non-blocking of the signal handler). + * + * \return 0 if no error happens, -1 if error happens. + */ + typedef int (*F_PyErr_CheckSignals)(); + + /*! \brief Callback to increment/decrement the python ref count */ + typedef void (*F_Py_IncDefRef)(void*); + + /*! + * \brief PyErr_CheckSignal function + */ + F_PyErr_CheckSignals pyerr_check_signals = nullptr; + + /*! + \brief PyGILState_Ensure function + */ + void* (*py_gil_state_ensure)() = nullptr; + + /*! + \brief PyGILState_Release function + */ + void (*py_gil_state_release)(void*) = nullptr; + + static EnvCAPIRegistry* Global() { + static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); + return inst; + } + + // register environment(e.g. python) specific api functions + void Register(const String& symbol_name, void* fptr) { + if (symbol_name == "PyErr_CheckSignals") { + Update(symbol_name, &pyerr_check_signals, fptr); + } else if (symbol_name == "PyGILState_Ensure") { + Update(symbol_name, &py_gil_state_ensure, fptr); + } else if (symbol_name == "PyGILState_Release") { + Update(symbol_name, &py_gil_state_release, fptr); + } else { + TVM_FFI_THROW(ValueError) << "Unknown env API " + symbol_name; + } + } + + int EnvCheckSignals() { + // check python signal to see if there are exception raised + if (pyerr_check_signals != nullptr) { + // The C++ env comes without gil, so we need to grab gil here + WithGIL context(this); + if ((*pyerr_check_signals)() != 0) { + // The error will let FFI know that the frontend environment + // already set an error. + return -1; + } + } + return 0; + } + + private: + // update the internal API table + template + void Update(const String& symbol_name, FType* target, void* ptr) { + FType ptr_casted = reinterpret_cast(ptr); + target[0] = ptr_casted; + } + + struct WithGIL { + explicit WithGIL(EnvCAPIRegistry* self) : self(self) { + TVM_FFI_ICHECK(self->py_gil_state_ensure); + TVM_FFI_ICHECK(self->py_gil_state_release); + gil_state = self->py_gil_state_ensure(); + } + ~WithGIL() { + if (self && gil_state) { + self->py_gil_state_release(gil_state); + } + } + WithGIL(const WithGIL&) = delete; + WithGIL(WithGIL&&) = delete; + WithGIL& operator=(const WithGIL&) = delete; + WithGIL& operator=(WithGIL&&) = delete; + + EnvCAPIRegistry* self = nullptr; + void* gil_state = nullptr; + }; +}; +} // namespace ffi +} // namespace tvm + +int TVMFFIEnvCheckSignals() { return tvm::ffi::EnvCAPIRegistry::Global()->EnvCheckSignals(); } + +/*! + * \brief Register a symbol into the from the surrounding env. + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +int TVMFFIEnvRegisterCAPI(const char* name, void* symbol) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::String s_name(name); + tvm::ffi::EnvCAPIRegistry::Global()->Register(s_name, symbol); + TVM_FFI_SAFE_CALL_END(); +} diff --git a/ffi/src/ffi/extra/library_module.cc b/ffi/src/ffi/extra/library_module.cc index 34286d6d0eb2..71c6da6f7cc4 100644 --- a/ffi/src/ffi/extra/library_module.cc +++ b/ffi/src/ffi/extra/library_module.cc @@ -191,7 +191,7 @@ Module CreateLibraryModule(ObjectPtr lib) { } // namespace ffi } // namespace tvm -int TVMFFIEnvRegisterContextSymbol(const char* name, void* symbol) { +int TVMFFIEnvModRegisterContextSymbol(const char* name, void* symbol) { TVM_FFI_SAFE_CALL_BEGIN(); tvm::ffi::String s_name(name); tvm::ffi::ContextSymbolRegistry::Global()->Register(s_name, symbol); diff --git a/ffi/src/ffi/extra/library_module_system_lib.cc b/ffi/src/ffi/extra/library_module_system_lib.cc index 64b95a122d56..cdc932cba292 100644 --- a/ffi/src/ffi/extra/library_module_system_lib.cc +++ b/ffi/src/ffi/extra/library_module_system_lib.cc @@ -123,7 +123,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } // namespace ffi } // namespace tvm -int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* ptr) { +int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* ptr) { tvm::ffi::SystemLibSymbolRegistry::Global()->RegisterSymbol(name, ptr); return 0; } diff --git a/ffi/src/ffi/extra/module.cc b/ffi/src/ffi/extra/module.cc index a7f6d4460079..d8ec77f98c97 100644 --- a/ffi/src/ffi/extra/module.cc +++ b/ffi/src/ffi/extra/module.cc @@ -130,8 +130,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ } // namespace ffi } // namespace tvm -int TVMFFIEnvLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, - TVMFFIObjectHandle* out) { +int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, const char* func_name, + TVMFFIObjectHandle* out) { TVM_FFI_SAFE_CALL_BEGIN(); *out = tvm::ffi::ModuleObj::InternalUnsafe::GetFunctionFromImports( reinterpret_cast(library_ctx), func_name); diff --git a/ffi/src/ffi/extra/module_internal.h b/ffi/src/ffi/extra/module_internal.h index f43d3a3d2c42..472d531f4b51 100644 --- a/ffi/src/ffi/extra/module_internal.h +++ b/ffi/src/ffi/extra/module_internal.h @@ -57,7 +57,7 @@ struct ModuleObj::InternalUnsafe { static Array* GetImports(ModuleObj* module) { return &(module->imports_); } static void* GetFunctionFromImports(ModuleObj* module, const char* name) { - // backend implementation for TVMFFIEnvLookupFromImports + // backend implementation for TVMFFIEnvModLookupFromImports static std::mutex mutex_; std::lock_guard lock(mutex_); String s_name(name); diff --git a/ffi/src/ffi/testing.cc b/ffi/src/ffi/extra/testing.cc similarity index 98% rename from ffi/src/ffi/testing.cc rename to ffi/src/ffi/extra/testing.cc index 8fddf54ebd63..3d27d5ccb6a4 100644 --- a/ffi/src/ffi/testing.cc +++ b/ffi/src/ffi/extra/testing.cc @@ -19,6 +19,7 @@ // This file is used for testing the FFI API. #include #include +#include #include #include diff --git a/ffi/src/ffi/function.cc b/ffi/src/ffi/function.cc index a2d3fb9db353..8db03bf28eb0 100644 --- a/ffi/src/ffi/function.cc +++ b/ffi/src/ffi/function.cc @@ -137,110 +137,6 @@ class GlobalFunctionTable { private: Map table_; }; - -/*! - * \brief Execution environment specific API registry. - * - * This registry stores C API function pointers about - * execution environment(e.g. python) specific API function that - * we need for specific low-level handling(e.g. signal checking). - * - * We only stores the C API function when absolutely necessary (e.g. when signal handler - * cannot trap back into python). Always consider use the Function FFI when possible - * in other cases. - */ -class EnvCAPIRegistry { - public: - /*! - * \brief Callback to check if signals have been sent to the process and - * if so invoke the registered signal handler in the frontend environment. - * - * When running FFI in another language (Python), the signal handler - * may not be immediately executed, but instead the signal is marked - * in the interpreter state (to ensure non-blocking of the signal handler). - * - * \return 0 if no error happens, -1 if error happens. - */ - typedef int (*F_PyErr_CheckSignals)(); - - /*! \brief Callback to increment/decrement the python ref count */ - typedef void (*F_Py_IncDefRef)(void*); - - /*! - * \brief PyErr_CheckSignal function - */ - F_PyErr_CheckSignals pyerr_check_signals = nullptr; - - /*! - \brief PyGILState_Ensure function - */ - void* (*py_gil_state_ensure)() = nullptr; - - /*! - \brief PyGILState_Release function - */ - void (*py_gil_state_release)(void*) = nullptr; - - static EnvCAPIRegistry* Global() { - static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); - return inst; - } - - // register environment(e.g. python) specific api functions - void Register(const String& symbol_name, void* fptr) { - if (symbol_name == "PyErr_CheckSignals") { - Update(symbol_name, &pyerr_check_signals, fptr); - } else if (symbol_name == "PyGILState_Ensure") { - Update(symbol_name, &py_gil_state_ensure, fptr); - } else if (symbol_name == "PyGILState_Release") { - Update(symbol_name, &py_gil_state_release, fptr); - } else { - TVM_FFI_THROW(ValueError) << "Unknown env API " + symbol_name; - } - } - - int EnvCheckSignals() { - // check python signal to see if there are exception raised - if (pyerr_check_signals != nullptr) { - // The C++ env comes without gil, so we need to grab gil here - WithGIL context(this); - if ((*pyerr_check_signals)() != 0) { - // The error will let FFI know that the frontend environment - // already set an error. - return -1; - } - } - return 0; - } - - private: - // update the internal API table - template - void Update(const String& symbol_name, FType* target, void* ptr) { - FType ptr_casted = reinterpret_cast(ptr); - target[0] = ptr_casted; - } - - struct WithGIL { - explicit WithGIL(EnvCAPIRegistry* self) : self(self) { - TVM_FFI_ICHECK(self->py_gil_state_ensure); - TVM_FFI_ICHECK(self->py_gil_state_release); - gil_state = self->py_gil_state_ensure(); - } - ~WithGIL() { - if (self && gil_state) { - self->py_gil_state_release(gil_state); - } - } - WithGIL(const WithGIL&) = delete; - WithGIL(WithGIL&&) = delete; - WithGIL& operator=(const WithGIL&) = delete; - WithGIL& operator=(WithGIL&&) = delete; - - EnvCAPIRegistry* self = nullptr; - void* gil_state = nullptr; - }; -}; } // namespace ffi } // namespace tvm @@ -296,21 +192,6 @@ int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_arg return reinterpret_cast(func)->safe_call(func, args, num_args, result); } -int TVMFFIEnvCheckSignals() { return tvm::ffi::EnvCAPIRegistry::Global()->EnvCheckSignals(); } - -/*! - * \brief Register a symbol into the from the surrounding env. - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -int TVMFFIEnvRegisterCAPI(const TVMFFIByteArray* name, void* symbol) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::String s_name(name->data, name->size); - tvm::ffi::EnvCAPIRegistry::Global()->Register(s_name, symbol); - TVM_FFI_SAFE_CALL_END(); -} - TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index e44fe465bc96..4e6c2f53641a 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -54,7 +54,7 @@ TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, * \param ptr The symbol address. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* ptr); +TVM_DLL int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* ptr); /*! * \brief Backend function to allocate temporal workspace. diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index 24c729095989..e61eaf322db2 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -183,7 +183,7 @@ cdef extern from "tvm/ffi/c_api.h": void TVMFFIErrorSetRaised(TVMFFIObjectHandle error) nogil TVMFFIObjectHandle TVMFFIErrorCreate(TVMFFIByteArray* kind, TVMFFIByteArray* message, TVMFFIByteArray* traceback) nogil - int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil + int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil @@ -208,6 +208,7 @@ cdef extern from "tvm/ffi/c_api.h": cdef extern from "tvm/ffi/extra/c_env_api.h": ctypedef void* TVMFFIStreamHandle + int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, @@ -279,11 +280,8 @@ cdef _init_env_api(): # Initialize env api for signal handling # Also registers the gil state release and ensure as PyErr_CheckSignals # function is called with gil released and we need to regrab the gil - cdef ByteArrayArg pyerr_check_signals_arg = ByteArrayArg(c_str("PyErr_CheckSignals")) - cdef ByteArrayArg pygilstate_ensure_arg = ByteArrayArg(c_str("PyGILState_Ensure")) - cdef ByteArrayArg pygilstate_release_arg = ByteArrayArg(c_str("PyGILState_Release")) - CHECK_CALL(TVMFFIEnvRegisterCAPI(pyerr_check_signals_arg.cptr(), PyErr_CheckSignals)) - CHECK_CALL(TVMFFIEnvRegisterCAPI(pygilstate_ensure_arg.cptr(), PyGILState_Ensure)) - CHECK_CALL(TVMFFIEnvRegisterCAPI(pygilstate_release_arg.cptr(), PyGILState_Release)) + CHECK_CALL(TVMFFIEnvRegisterCAPI(c_str("PyErr_CheckSignals"), PyErr_CheckSignals)) + CHECK_CALL(TVMFFIEnvRegisterCAPI(c_str("PyGILState_Ensure"), PyGILState_Ensure)) + CHECK_CALL(TVMFFIEnvRegisterCAPI(c_str("PyGILState_Release"), PyGILState_Release)) _init_env_api() diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index 3ab232e95997..4148cc6c88e1 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -43,7 +43,6 @@ def load_torch_get_current_cuda_stream(): def fallback_get_current_cuda_stream(device_id): """Fallback with python api""" return torch.cuda.current_stream(device_id).cuda_stream - return fallback_get_current_cuda_stream try: result = cpp_extension.load_inline( name="get_current_cuda_stream", diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index 31006069a26b..28dc313ba3e6 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -242,7 +242,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ using namespace tvm::runtime; int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFFIObjectHandle* func) { - return TVMFFIEnvLookupFromImports(mod_node, func_name, func); + return TVMFFIEnvModLookupFromImports(mod_node, func_name, func); } void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 12b58da2df2a..16c617ce3fcb 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -70,7 +70,7 @@ bool RuntimeEnabled(const String& target_str) { #define TVM_INIT_CONTEXT_FUNC(FuncName) \ TVM_FFI_CHECK_SAFE_CALL( \ - TVMFFIEnvRegisterContextSymbol("__" #FuncName, reinterpret_cast(FuncName))) + TVMFFIEnvModRegisterContextSymbol("__" #FuncName, reinterpret_cast(FuncName))) TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; diff --git a/src/support/errno_handling.h b/src/support/errno_handling.h index 748e390e78db..b2b59ed16df5 100644 --- a/src/support/errno_handling.h +++ b/src/support/errno_handling.h @@ -24,6 +24,7 @@ #ifndef TVM_SUPPORT_ERRNO_HANDLING_H_ #define TVM_SUPPORT_ERRNO_HANDLING_H_ #include +#include #include "ssize.h" diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 96075450183c..bd45ce32e053 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -311,10 +311,10 @@ std::string PackImportsToC(const ffi::Module& mod, bool system_lib, } os << "\n};\n"; if (system_lib) { - os << "extern int TVMFFIEnvRegisterSystemLibSymbol(const char*, void*);\n"; + os << "extern int TVMFFIEnvModRegisterSystemLibSymbol(const char*, void*);\n"; os << "static int " << mdev_blob_name << "_reg_ = " - << "TVMFFIEnvRegisterSystemLibSymbol(\"" << mdev_blob_name << "\", (void*)" << mdev_blob_name - << ");\n"; + << "TVMFFIEnvModRegisterSystemLibSymbol(\"" << mdev_blob_name << "\", (void*)" + << mdev_blob_name << ");\n"; } os << "#ifdef __cplusplus\n" << "}\n" diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index fc2acfddfb81..17056c331295 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -151,11 +151,11 @@ std::unique_ptr CodeGenBlob(const std::string& data, bool system_l llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage, llvm::Twine("__cxx_global_var_init"), module.get()); - // Create TVMFFIEnvRegisterSystemLibSymbol function + // Create TVMFFIEnvModRegisterSystemLibSymbol function llvm::Function* tvm_backend_fn = llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false), llvm::GlobalValue::ExternalLinkage, - llvm::Twine("TVMFFIEnvRegisterSystemLibSymbol"), module.get()); + llvm::Twine("TVMFFIEnvModRegisterSystemLibSymbol"), module.get()); // Set necessary fn sections auto get_static_init_section_specifier = [&triple]() -> std::string { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index eebbd5b64fd4..5ce8b1ec6584 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -146,10 +146,10 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, if (system_lib_prefix_.has_value() && !target_c_runtime) { // We will need this in environment for backward registration. // Defined in include/tvm/runtime/c_backend_api.h: - // int TVMFFIEnvRegisterSystemLibSymbol(const char* name, void* ptr); + // int TVMFFIEnvModRegisterSystemLibSymbol(const char* name, void* ptr); f_tvm_register_system_symbol_ = llvm::Function::Create( llvm::FunctionType::get(t_int_, {llvmGetPointerTo(t_char_, 0), t_void_p_}, false), - llvm::Function::ExternalLinkage, "TVMFFIEnvRegisterSystemLibSymbol", module_.get()); + llvm::Function::ExternalLinkage, "TVMFFIEnvModRegisterSystemLibSymbol", module_.get()); } else { f_tvm_register_system_symbol_ = nullptr; } diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 6e2664a93bff..31f494322684 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -54,10 +54,10 @@ #include "ffi/src/ffi/extra/library_module.cc" #include "ffi/src/ffi/extra/library_module_system_lib.cc" #include "ffi/src/ffi/extra/module.cc" +#include "ffi/src/ffi/extra/testing.cc" #include "ffi/src/ffi/function.cc" #include "ffi/src/ffi/ndarray.cc" #include "ffi/src/ffi/object.cc" -#include "ffi/src/ffi/testing.cc" #include "ffi/src/ffi/traceback.cc" #include "src/runtime/memory/memory_manager.cc" #include "src/runtime/nvtx.cc"