diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 1d495d9c5e96..131f2e73e08a 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -579,8 +579,10 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* * \return 0 when success, nonzero when failure happens * \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree. The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. + + * \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. */ -TVM_FFI_DLL int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out); +TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out); //------------------------------------------------------------ // Section: Backend noexcept functions for internal use diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h index 99eb227ee1af..a1a6b58afa28 100644 --- a/ffi/include/tvm/ffi/dtype.h +++ b/ffi/include/tvm/ffi/dtype.h @@ -121,7 +121,7 @@ inline DLDataType StringToDLDataType(const String& str) { inline String DLDataTypeToString(DLDataType dtype) { TVMFFIObjectHandle out; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(dtype, &out)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out)); return String(details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(out))); } diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h index 239a0e500b73..de754bd6ea77 100644 --- a/ffi/include/tvm/ffi/error.h +++ b/ffi/include/tvm/ffi/error.h @@ -51,6 +51,10 @@ #define TVM_FFI_BACKTRACE_ON_SEGFAULT 1 #endif +#ifndef TVM_FFI_ALWAYS_LOG_BEFORE_THROW +#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 0 +#endif + namespace tvm { namespace ffi { @@ -212,8 +216,10 @@ class ErrorBuilder { * * \endcode */ -#define TVM_FFI_THROW(ErrorKind) \ - ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, false).stream() +#define TVM_FFI_THROW(ErrorKind) \ + ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, \ + TVM_FFI_ALWAYS_LOG_BEFORE_THROW) \ + .stream() /*! * \brief Explicitly log error in stderr and then throw the error. diff --git a/ffi/src/ffi/dtype.cc b/ffi/src/ffi/dtype.cc index 7661ab4b97b1..cb0bd4959735 100644 --- a/ffi/src/ffi/dtype.cc +++ b/ffi/src/ffi/dtype.cc @@ -320,9 +320,9 @@ int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out) { TVM_FFI_SAFE_CALL_END(); } -int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) { +int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) { TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(dtype)); + tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(*dtype)); *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out_str)); TVM_FFI_SAFE_CALL_END(); } diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index 8fe23cd23b29..8b9c1f3d947b 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -150,7 +150,7 @@ cdef extern from "tvm/ffi/c_api.h": int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil - int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) nogil + int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) nogil const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) nogil; int TVMFFINDArrayFromDLPack(DLManagedTensor* src, int32_t require_alignment, int32_t require_contiguous, TVMFFIObjectHandle* out) nogil diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi index 30f9f274b4af..80ec5d9364b1 100644 --- a/python/tvm/ffi/cython/dtype.pxi +++ b/python/tvm/ffi/cython/dtype.pxi @@ -94,7 +94,7 @@ cdef class DataType: def __str__(self): cdef TVMFFIObjectHandle dtype_str cdef TVMFFIByteArray* bytes - CHECK_CALL(TVMFFIDataTypeToString(self.cdtype, &dtype_str)) + CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &dtype_str)) bytes = TVMFFIBytesGetByteArrayPtr(dtype_str) res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) CHECK_CALL(TVMFFIObjectFree(dtype_str)) diff --git a/web/.eslintignore b/web/.eslintignore index f71ee79871c4..1549e07c251e 100644 --- a/web/.eslintignore +++ b/web/.eslintignore @@ -1,2 +1,4 @@ dist debug +tvmjs_runtime_wasi.js +lib diff --git a/web/apps/node/example.js b/web/apps/node/example.js index 580bbf57ab80..62c9157c7c29 100644 --- a/web/apps/node/example.js +++ b/web/apps/node/example.js @@ -31,7 +31,7 @@ const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); tvmjs.instantiate(wasmSource, tvmjs.createPolyfillWASI()) .then((tvm) => { tvm.beginScope(); - const log_info = tvm.getGlobalFunc("testing.log_info_str"); + const log_info = tvm.getGlobalFunc("tvmjs.testing.log_info_str"); log_info("hello world"); // List all the global functions from the runtime. console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames()); diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index e50e6c37d34c..1e35a1137fb7 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -28,12 +28,11 @@ #define TVM_LOG_STACK_TRACE 0 #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 +#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY -#include +#include #include -#include -#include #include "../../src/runtime/rpc/rpc_local_session.h" @@ -59,27 +58,33 @@ TVM_DLL void TVMWasmFreeSpace(void* data); * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer 3A * \return 0 if success. */ -TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out); +TVM_DLL int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFunctionHandle* out); + +/*! + * \brief Get the last error message. + * \return The last error message. + */ +TVM_DLL const char* TVMFFIWasmGetLastError(); // --- APIs to be implemented by the frontend. --- + /*! - * \brief Wasm frontend packed function caller. + * \brief Wasm frontend new ffi call function caller. * + * \param self The pointer to the ffi::Function. * \param args The arguments - * \param type_codes The type codes of the arguments * \param num_args Number of arguments. - * \param ret The return value handle. - * \param resource_handle The handle additional resource handle from front-end. + * \param result The return value handle. * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. */ -extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, - void* resource_handle); - +extern int TVMFFIWasmSafeCall(void* self, const TVMFFIAny* args, int32_t num_args, + TVMFFIAny* result); /*! - * \brief Wasm frontend resource finalizer. - * \param resource_handle The pointer to the external resource. + * \brief Delete ffi::Function. + * \param self The pointer to the ffi::Function. */ -extern void TVMWasmPackedCFuncFinalizer(void* resource_handle); +extern void TVMFFIWasmFunctionDeleter(void* self); + } // extern "C" void* TVMWasmAllocSpace(int size) { @@ -89,9 +94,14 @@ void* TVMWasmAllocSpace(int size) { void TVMWasmFreeSpace(void* arr) { delete[] static_cast(arr); } -int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) { - return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer, - out); +int TVMFFIWasmFunctionCreate(void* self, TVMFunctionHandle* out) { + return TVMFFIFunctionCreate(self, TVMFFIWasmSafeCall, TVMFFIWasmFunctionDeleter, out); +} + +const char* TVMFFIWasmGetLastError() { + static thread_local std::string last_error; + last_error = ::tvm::ffi::details::MoveFromSafeCallRaised().what(); + return last_error.c_str(); } namespace tvm { @@ -291,7 +301,7 @@ class AsyncLocalSession : public LocalSession { } }; -TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() { return CreateRPCSessionModule(std::make_shared()); }); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index b208daed51d3..728e1c648c28 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -27,9 +27,9 @@ #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 #define TVM_FFI_USE_LIBBACKTRACE 0 +#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY -#include #include #include "src/runtime/c_runtime_api.cc" @@ -107,45 +107,24 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_REGISTER_GLOBAL("testing.echo").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - *ret = args[0]; -}); - -TVM_REGISTER_GLOBAL("testing.call").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - (args[0].cast()).CallPacked(args.Slice(1), ret); -}); - -TVM_REGISTER_GLOBAL("testing.ret_string").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - *ret = args[0].cast(); -}); - -TVM_REGISTER_GLOBAL("testing.log_info_str") +TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.call") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - LOG(INFO) << args[0].cast(); + (args[0].cast()).CallPacked(args.Slice(1), ret); }); -TVM_REGISTER_GLOBAL("testing.log_fatal_str") +TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.log_info_str") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - LOG(FATAL) << args[0].cast(); + LOG(INFO) << args[0].cast(); }); -TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; }); +TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.add_one").set_body_typed([](int x) { return x + 1; }); -TVM_REGISTER_GLOBAL("testing.wrap_callback") +TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.wrap_callback") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ffi::Function pf = args[0].cast(); *ret = ffi::TypedFunction([pf]() { pf(); }); }); -// internal function used for debug and testing purposes -TVM_REGISTER_GLOBAL("testing.object_use_count") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto obj = args[0].cast(); - // subtract the current one because we always copy - // and get another value. - *ret = (obj.use_count() - 1); - }); - void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, std::string dtype) { if (format == "f32-to-bf16" && dtype == "float32") { std::vector buffer(bytes.length() / 2); @@ -167,10 +146,10 @@ void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, } } -TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage); +TVM_FFI_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage); // Concatenate n TVMArrays -TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat") +TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { std::vector data; for (int i = 0; i < args.size(); ++i) { @@ -220,7 +199,7 @@ NDArray ConcatEmbeddings(const std::vector& embeddings) { } // Concatenate n NDArrays -TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings") +TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { std::vector embeddings; for (int i = 0; i < args.size(); ++i) { @@ -230,5 +209,19 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings") *ret = result; }); +TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyFromBytes") + .set_body_typed([](NDArray nd, TVMFFIByteArray* bytes) { + nd.CopyFromBytes(bytes->data, bytes->size); + }); + +TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyToBytes") + .set_body_typed([](NDArray nd) -> ffi::Bytes { + size_t size = GetDataSize(*(nd.operator->())); + std::string bytes; + bytes.resize(size); + nd.CopyToBytes(bytes.data(), size); + return ffi::Bytes(bytes); + }); + } // namespace runtime } // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 3d74d77f14ce..00b1db266a0b 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -26,13 +26,11 @@ #define TVM_LOG_STACK_TRACE 0 #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 +#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY -#include -#include +#include #include -#include -#include #include #include @@ -152,7 +150,10 @@ typedef dmlc::ThreadLocalStore WebGPUThreadStore; WebGPUThreadEntry::WebGPUThreadEntry() : pool(static_cast(kDLWebGPU), WebGPUDeviceAPI::Global()) {} -WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore::Get(); } +WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { + static thread_local WebGPUThreadEntry inst = WebGPUThreadEntry(); + return &inst; +} class WebGPUModuleNode final : public runtime::ModuleNode { public: @@ -241,12 +242,13 @@ Module WebGPUModuleLoadBinary(void* strm) { } // for now webgpu is hosted via a vulkan module. -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary); -TVM_REGISTER_GLOBAL("device_api.webgpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = WebGPUDeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.webgpu") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = WebGPUDeviceAPI::Global(); + *rv = static_cast(ptr); + }); } // namespace runtime } // namespace tvm diff --git a/web/package.json b/web/package.json index 583232d20951..b4fc25e12fcf 100644 --- a/web/package.json +++ b/web/package.json @@ -45,5 +45,9 @@ "typedoc-plugin-missing-exports": "2.0.0", "typescript": "^4.9.5", "ws": "^7.2.5" + }, + "dependencies": { + "audit": "^0.0.6", + "fix": "^0.0.6" } } diff --git a/web/src/asyncify.ts b/web/src/asyncify.ts index 703dbbf80a10..6074a559e00d 100644 --- a/web/src/asyncify.ts +++ b/web/src/asyncify.ts @@ -70,6 +70,15 @@ export class AsyncifyHandler { return this.exports.asyncify_stop_rewind !== undefined; } + /** + * Get the current asynctify state + * + * @returns The current asynctify state + */ + isNormalStackState(): boolean { + return this.state == AsyncifyStateKind.None; + } + /** * Get the current asynctify state * diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index c4941f07d57a..c9a5e263d5f2 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -27,231 +27,165 @@ export type Pointer = number; /** A pointer offset, need to add a base address to get a valid ptr. */ export type PtrOffset = number; -// -- TVM runtime C API -- /** - * const char *TVMGetLastError(); - */ -export type FTVMGetLastError = () => Pointer; - -/** - * void TVMAPISetLastError(const char* msg); - */ -export type FTVMAPISetLastError = (msg: Pointer) => void; - -/** - * int TVMModGetFunction(TVMModuleHandle mod, - * const char* func_name, - * int query_imports, - * TVMFunctionHandle *out); - */ -export type FTVMModGetFunction = ( - mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; -/** - * int TVMModImport(TVMModuleHandle mod, - * TVMModuleHandle dep); - */ -export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; - -/** - * int TVMModFree(TVMModuleHandle mod); - */ -export type FTVMModFree = (mod: Pointer) => number; - -/** - * int TVMFuncFree(TVMFunctionHandle func); - */ -export type FTVMFuncFree = (func: Pointer) => number; - -/** - * int TVMFuncCall(TVMFunctionHandle func, - * TVMValue* arg_values, - * int* type_codes, - * int num_args, - * TVMValue* ret_val, - * int* ret_type_code); - */ -export type FTVMFuncCall = ( - func: Pointer, argValues: Pointer, typeCode: Pointer, - nargs: number, retValue: Pointer, retCode: Pointer) => number; - -/** - * int TVMCFuncSetReturn(TVMRetValueHandle ret, - * TVMValue* value, - * int* type_code, - * int num_ret); - */ -export type FTVMCFuncSetReturn = ( - ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; - -/** - * int TVMCbArgToReturn(TVMValue* value, int* code); - */ -export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; - -/** - * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + * Size of common data types. */ -export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; +export const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + TVMFFIAny = 8 * 2, + DLDataType = I32, + DLDevice = I32 + I32, + ObjectHeader = 8 * 2, +} +//---------------The new TVM FFI--------------- /** - * int TVMFuncRegisterGlobal( - * const char* name, TVMFunctionHandle f, int override); - */ -export type FTVMFuncRegisterGlobal = ( - name: Pointer, f: Pointer, override: number) => number; + * Type Index in new TVM FFI. + * + * We are keeping the same style as C API here. + */ +export const enum TypeIndex { + kTVMFFINone = 0, + /*! \brief POD int value */ + kTVMFFIInt = 1, + /*! \brief POD bool value */ + kTVMFFIBool = 2, + /*! \brief POD float value */ + kTVMFFIFloat = 3, + /*! \brief Opaque pointer object */ + kTVMFFIOpaquePtr = 4, + /*! \brief DLDataType */ + kTVMFFIDataType = 5, + /*! \brief DLDevice */ + kTVMFFIDevice = 6, + /*! \brief DLTensor* */ + kTVMFFIDLTensorPtr = 7, + /*! \brief const char**/ + kTVMFFIRawStr = 8, + /*! \brief TVMFFIByteArray* */ + kTVMFFIByteArrayPtr = 9, + /*! \brief R-value reference to ObjectRef */ + kTVMFFIObjectRValueRef = 10, + /*! \brief Start of statically defined objects. */ + kTVMFFIStaticObjectBegin = 64, + /*! + * \brief Object, all objects starts with TVMFFIObject as its header. + * \note We will also add other fields + */ + kTVMFFIObject = 64, + /*! + * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... } + */ + kTVMFFIStr = 65, + /*! + * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... } + */ + kTVMFFIBytes = 66, + /*! \brief Error object. */ + kTVMFFIError = 67, + /*! \brief Function object. */ + kTVMFFIFunction = 68, + /*! \brief Array object. */ + kTVMFFIArray = 69, + /*! \brief Map object. */ + kTVMFFIMap = 70, + /*! + * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... } + */ + kTVMFFIShape = 71, + /*! + * \brief NDArray object, layout = { TVMFFIObject, DLTensor, ... } + */ + kTVMFFINDArray = 72, + /*! \brief Runtime module object. */ + kTVMFFIModule = 73, +} -/** - *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); - */ -export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; +// -- TVM Wasm Auxiliary C API -- -/** - * int TVMArrayAlloc(const tvm_index_t* shape, - * int ndim, - * int dtype_code, - * int dtype_bits, - * int dtype_lanes, - * int device_type, - * int device_id, - * TVMArrayHandle* out); - */ -export type FTVMArrayAlloc = ( - shape: Pointer, ndim: number, - dtypeCode: number, dtypeBits: number, - dtypeLanes: number, deviceType: number, deviceId: number, - out: Pointer) => number; +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; -/** - * int TVMArrayFree(TVMArrayHandle handle); - */ -export type FTVMArrayFree = (handle: Pointer) => number; +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; -/** - * int TVMArrayCopyFromBytes(TVMArrayHandle handle, - * void* data, - * size_t nbytes); - */ -export type FTVMArrayCopyFromBytes = ( - handle: Pointer, data: Pointer, nbytes: number) => number; +/** const char* TVMFFIWasmGetLastError(); */ +export type FTVMFFIWasmGetLastError = () => Pointer; /** - * int TVMArrayCopyToBytes(TVMArrayHandle handle, - * void* data, - * size_t nbytes); + * int TVMFFIWasmSafeCallType(void* self, const TVMFFIAny* args, + * int32_t num_args, TVMFFIAny* result); */ -export type FTVMArrayCopyToBytes = ( - handle: Pointer, data: Pointer, nbytes: number) => number; +export type FTVMFFIWasmSafeCallType = ( + self: Pointer, args: Pointer, num_args: number, + result: Pointer) => number; /** - * int TVMArrayCopyFromTo(TVMArrayHandle from, - * TVMArrayHandle to, - * TVMStreamHandle stream); + * int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFunctionHandle* out); */ -export type FTVMArrayCopyFromTo = ( - from: Pointer, to: Pointer, stream: Pointer) => number; +export type FTVMFFIWasmFunctionCreate = ( + resource_handle: Pointer, out: Pointer) => number; /** - * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + * void TVMFFIWasmFunctionDeleter(void* self); */ -export type FTVMSynchronize = ( - deviceType: number, deviceId: number, stream: Pointer) => number; +export type FTVMFFIWasmFunctionDeleter = (self: Pointer) => void; /** - * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, - * int* type_codes, - * int num_args, - * TVMValue* out_ret_value, - * int* out_ret_tcode); + * int TVMFFIObjectFree(TVMFFIObjectHandle obj); */ -export type FTVMBackendPackedCFunc = ( - argValues: Pointer, argCodes: Pointer, nargs: number, - outValue: Pointer, outCode: Pointer) => number; - +export type FTVMFFIObjectFree = (obj: Pointer) => number; /** - * int TVMObjectFree(TVMObjectHandle obj); + * int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); */ -export type FTVMObjectFree = (obj: Pointer) => number; +export type FTVMFFITypeKeyToIndex = (type_key: Pointer, out_tindex: Pointer) => number; /** - * int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); + * int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out); */ -export type FTVMObjectGetTypeIndex = (obj: Pointer, out_tindex: Pointer) => number; +export type FTVMFFIAnyViewToOwnedAny = (any_view: Pointer, out: Pointer) => number; /** - * int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + * void TVMFFIErrorSetRaisedByCStr(const char* kind, const char* message); */ -export type FTVMObjectTypeIndex2Key = (type_index: number, out_type_key: Pointer) => number; +export type FTVMFFIErrorSetRaisedByCStr = (kind: Pointer, message: Pointer) => void; /** - * int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + * int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle f, + * int override); */ -export type FTVMObjectTypeKey2Index = (type_key: Pointer, out_tindex: Pointer) => number; - -// -- TVM Wasm Auxiliary C API -- - -/** void* TVMWasmAllocSpace(int size); */ -export type FTVMWasmAllocSpace = (size: number) => Pointer; - -/** void TVMWasmFreeSpace(void* data); */ -export type FTVMWasmFreeSpace = (ptr: Pointer) => void; +export type FTVMFFIFunctionSetGlobal = (name: Pointer, f: Pointer, override: number) => number; /** - * int TVMWasmPackedCFunc(TVMValue* args, - * int* type_codes, - * int num_args, - * TVMRetValueHandle ret, - * void* resource_handle); + * int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out); */ -export type FTVMWasmPackedCFunc = ( - args: Pointer, typeCodes: Pointer, nargs: number, - ret: Pointer, resourceHandle: Pointer) => number; +export type FTVMFFIFunctionGetGlobal = (name: Pointer, out: Pointer) => number; /** - * int TVMWasmFuncCreateFromCFunc(void* resource_handle, - * TVMFunctionHandle *out); + * int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, + * TVMFFIAny* result); */ -export type FTVMWasmFuncCreateFromCFunc = ( - resource: Pointer, out: Pointer) => number; +export type FTVMFFIFunctionCall = (func: Pointer, args: Pointer, num_args: number, + result: Pointer) => number; /** - * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + * int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out); */ -export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; +export type FTVMFFIDataTypeFromString = (str: Pointer, out: Pointer) => number; /** - * Size of common data types. + * int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out); */ -export const enum SizeOf { - U8 = 1, - U16 = 2, - I32 = 4, - I64 = 8, - F32 = 4, - F64 = 8, - TVMValue = 8, - DLDataType = I32, - DLDevice = I32 + I32, -} +export type FTVMFFIDataTypeToString = (dtype: Pointer, out: Pointer) => number; /** - * Argument Type code in TVM FFI. + * TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index); */ -export const enum ArgTypeCode { - Int = 0, - UInt = 1, - Float = 2, - TVMOpaqueHandle = 3, - Null = 4, - TVMDataType = 5, - DLDevice = 6, - TVMDLTensorHandle = 7, - TVMObjectHandle = 8, - TVMModuleHandle = 9, - TVMPackedFuncHandle = 10, - TVMStr = 11, - TVMBytes = 12, - TVMNDArrayHandle = 13, - TVMObjectRValueRefArg = 14, - TVMArgBool = 15, -} +export type FTVMFFIGetTypeInfo = (type_index: number) => Pointer; diff --git a/web/src/environment.ts b/web/src/environment.ts index 42a873f1284e..01e19a1c18f4 100644 --- a/web/src/environment.ts +++ b/web/src/environment.ts @@ -75,7 +75,7 @@ export class Environment implements LibraryProvider { * We maintain a separate table so that we can have un-limited amount * of functions that do not maps to the address space. */ - packedCFuncTable: Array = [ + packedCFuncTable: Array = [ undefined, ]; /** @@ -115,28 +115,27 @@ export class Environment implements LibraryProvider { // eslint-disable-next-line @typescript-eslint/no-unused-vars "emscripten_notify_memory_growth": (index: number): void => {} }; - const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = ( + const wasmSafeCall: ctypes.FTVMFFIWasmSafeCallType = ( + self: Pointer, args: Pointer, - typeCodes: Pointer, - nargs: number, - ret: Pointer, - resourceHandle: Pointer + num_args: number, + result: Pointer ): number => { - const cfunc = this.packedCFuncTable[resourceHandle]; + const cfunc = this.packedCFuncTable[self]; assert(cfunc !== undefined); - return cfunc(args, typeCodes, nargs, ret, resourceHandle); + return cfunc(self, args, num_args, result); }; - const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = ( - resourceHandle: Pointer + const wasmFunctionDeleter: ctypes.FTVMFFIWasmFunctionDeleter = ( + self: Pointer ): void => { - this.packedCFuncTable[resourceHandle] = undefined; - this.packedCFuncTableFreeId.push(resourceHandle); + this.packedCFuncTable[self] = undefined; + this.packedCFuncTableFreeId.push(self); }; const newEnv = { - TVMWasmPackedCFunc: wasmPackedCFunc, - TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer, + "TVMFFIWasmSafeCall": wasmSafeCall, + "TVMFFIWasmFunctionDeleter": wasmFunctionDeleter, "__console_log": (msg: string): void => { this.logger(msg); } diff --git a/web/src/memory.ts b/web/src/memory.ts index b0d4ff3bf194..850f3bd37195 100644 --- a/web/src/memory.ts +++ b/web/src/memory.ts @@ -137,16 +137,6 @@ export class Memory { result.set(this.viewU8.slice(ptr, ptr + numBytes)); return result; } - /** - * Load TVMByteArray from ptr. - * - * @param ptr The address of the header. - */ - loadTVMBytes(ptr: Pointer): Uint8Array { - const data = this.loadPointer(ptr); - const length = this.loadUSize(ptr + this.sizeofPtr()); - return this.loadRawBytes(data, length); - } /** * Load null-terminated C-string from ptr. * @param ptr The head address @@ -178,7 +168,56 @@ export class Memory { } this.viewU8.set(bytes, ptr); } - + // the following functions are related to TVM FFI + /** + * Load the object type index from the object handle. + * @param objectHandle The handle of the object. + * @returns The object type index. + */ + loadObjectTypeIndex(objectHandle: Pointer): number { + return this.loadI32(objectHandle); + } + /** + * Load the type key from the type info pointer. + * @param typeInfoPtr The pointer to the type info. + * @returns The type key. + */ + loadTypeInfoTypeKey(typeInfoPtr: Pointer): string { + const typeKeyPtr = typeInfoPtr + 2 * SizeOf.I32; + return this.loadByteArrayAsString(typeKeyPtr); + } + /** + * Load bytearray as string from ptr. + * @param byteArrayPtr The head address of the bytearray. + */ + loadByteArrayAsString(byteArrayPtr: Pointer): string { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const ptr = this.loadPointer(byteArrayPtr); + const length = this.loadUSize(byteArrayPtr + this.sizeofPtr()); + // NOTE: the views are still valid for read. + const ret = []; + for (let i = 0; i < length; i++) { + ret.push(String.fromCharCode(this.viewU8[ptr + i])); + } + return ret.join(""); + } + /** + * Load bytearray as bytes from ptr. + * @param byteArrayPtr The head address of the bytearray. + */ + loadByteArrayAsBytes(byteArrayPtr: Pointer): Uint8Array { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const ptr = this.loadPointer(byteArrayPtr); + const length = this.loadUSize(byteArrayPtr + this.sizeofPtr()); + const result = new Uint8Array(length); + result.set(this.viewU8.slice(ptr, ptr + length)); + return result; +} + // private functions /** * Update memory view after the memory growth. */ @@ -365,6 +404,21 @@ export class CachedCallStack implements Disposable { this.viewU8.set(bytes, offset); } + /** + * Allocate a byte array for a string and return the offset of the byte array. + * @param data The string to allocate. + * @returns The offset of the byte array. + */ + allocByteArrayForString(data: string): PtrOffset { + const dataUint8: Uint8Array = StringToUint8Array(data); + // Note: size of size_t equals sizeof ptr. + const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2); + const dataOffset = this.allocRawBytes(dataUint8.length); + this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length); + this.storeRawBytes(dataOffset, dataUint8); + this.addressToSetTargetValue.push([headerOffset, dataOffset]); + return headerOffset; + } /** * Allocate then set C-String pointer to the offset. * This function will call into allocBytes to allocate necessary data. diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts index 46848a6dec1c..1e3af6f6438e 100644 --- a/web/src/rpc_server.ts +++ b/web/src/rpc_server.ts @@ -17,7 +17,7 @@ * under the License. */ -import { SizeOf, ArgTypeCode } from "./ctypes"; +import { SizeOf, TypeIndex } from "./ctypes"; import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; import { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu"; import * as compact from "./compact"; @@ -228,21 +228,16 @@ export class RPCServer { // eslint-disable-next-line @typescript-eslint/no-unused-vars const ver = Uint8ArrayToString(reader.readByteArray()); const nargs = reader.readU32(); - const tcodes = []; const args = []; for (let i = 0; i < nargs; ++i) { - tcodes.push(reader.readU32()); - } - - for (let i = 0; i < nargs; ++i) { - const tcode = tcodes[i]; - if (tcode === ArgTypeCode.TVMStr) { + const typeIndex = reader.readU32(); + if (typeIndex === TypeIndex.kTVMFFIRawStr) { const str = Uint8ArrayToString(reader.readByteArray()); args.push(str); - } else if (tcode === ArgTypeCode.TVMBytes) { + } else if (typeIndex === TypeIndex.kTVMFFIByteArrayPtr) { args.push(reader.readByteArray()); } else { - throw new Error("cannot support type code " + tcode); + throw new Error("cannot support type index " + typeIndex); } } this.onInitServer(args, header, body); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 5c47c0e7a52f..47902086f588 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -20,7 +20,7 @@ /** * TVM JS Wasm Runtime library. */ -import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes"; +import { Pointer, PtrOffset, SizeOf, TypeIndex } from "./ctypes"; import { Disposable } from "./types"; import { Memory, CachedCallStack } from "./memory"; import { assert, StringToUint8Array, LinearCongruentialGenerator } from "./support"; @@ -90,8 +90,8 @@ class FFILibrary implements Disposable { checkCall(code: number): void { if (code != 0) { const msgPtr = (this.exports - .TVMGetLastError as ctypes.FTVMGetLastError)(); - throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); + .TVMFFIWasmGetLastError as ctypes.FTVMFFIWasmGetLastError)(); + throw new Error(this.memory.loadCString(msgPtr)); } } @@ -153,6 +153,13 @@ class FFILibrary implements Disposable { * Manages extra runtime context for the runtime. */ class RuntimeContext implements Disposable { + functionListGlobalNamesFunctor: PackedFunc; + moduleGetFunction: PackedFunc; + moduleImport: PackedFunc; + ndarrayEmpty: PackedFunc; + ndarrayCopyFromTo: PackedFunc; + ndarrayCopyFromJSBytes: PackedFunc; + ndarrayCopyToJSBytes: PackedFunc; arrayGetItem: PackedFunc; arrayGetSize: PackedFunc; arrayMake: PackedFunc; @@ -173,10 +180,21 @@ class RuntimeContext implements Disposable { applyPresenceAndFrequencyPenalty: PackedFunc; applySoftmaxWithTemperature: PackedFunc; concatEmbeddings: PackedFunc | undefined; - + bool: PackedFunc; private autoDisposeScope: Array> = []; - constructor(getGlobalFunc: (name: string) => PackedFunc) { + constructor( + getGlobalFunc: (name: string) => PackedFunc + ) { + this.functionListGlobalNamesFunctor = getGlobalFunc( + "ffi.FunctionListGlobalNamesFunctor" + ); + this.moduleGetFunction = getGlobalFunc("runtime.ModuleGetFunction"); + this.moduleImport = getGlobalFunc("runtime.ModuleImport"); + this.ndarrayEmpty = getGlobalFunc("runtime.TVMArrayAllocWithScope"); + this.ndarrayCopyFromTo = getGlobalFunc("runtime.TVMArrayCopyFromTo"); + this.ndarrayCopyFromJSBytes = getGlobalFunc("tvmjs.runtime.NDArrayCopyFromBytes"); + this.ndarrayCopyToJSBytes = getGlobalFunc("tvmjs.runtime.NDArrayCopyToBytes"); this.arrayGetItem = getGlobalFunc("runtime.ArrayGetItem"); this.arrayGetSize = getGlobalFunc("runtime.ArraySize"); this.arrayMake = getGlobalFunc("runtime.Array"); @@ -189,18 +207,14 @@ class RuntimeContext implements Disposable { this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache"); this.paramModuleFromCacheByName = getGlobalFunc("vm.builtin.param_module_from_cache_by_name"); - this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple"); + this.makeShapeTuple = getGlobalFunc("ffi.Shape"); this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView"); this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits"); this.sampleTopPFromProb = getGlobalFunc("vm.builtin.sample_top_p_from_prob"); this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty"); this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty"); this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature"); - try { - this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings"); - } catch { - // TODO: remove soon. Older artifacts do not have this, try-catch for backward compatibility. - } + this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings"); } dispose(): void { @@ -306,35 +320,6 @@ export class Scalar { } } -/** - * Cell holds the PackedFunc object. - */ -class PackedFuncCell implements Disposable { - private handle: Pointer; - private lib: FFILibrary; - - constructor(handle: Pointer, lib: FFILibrary) { - this.handle = handle; - this.lib = lib; - } - - dispose(): void { - if (this.handle != 0) { - this.lib.checkCall( - (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle) - ); - this.handle = 0; - } - } - - getHandle(requireNotNull = true): Pointer { - if (requireNotNull && this.handle === 0) { - throw Error("PackedFunc has already been disposed"); - } - return this.handle; - } -} - const DeviceEnumToStr: Record = { 1: "cpu", 2: "cuda", @@ -392,7 +377,7 @@ export class DLDevice { toString(): string { return ( - DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")" + DeviceEnumToStr[this.deviceType] + ":" + this.deviceId.toString() ); } } @@ -444,12 +429,78 @@ export class DLDataType { } } +/** + * Generic object base + */ +export class TVMObject implements Disposable { + protected handle: Pointer; + protected lib: FFILibrary; + protected ctx: RuntimeContext; + + constructor( + handle: Pointer, + lib: FFILibrary, + ctx: RuntimeContext + ) { + this.handle = handle; + this.lib = lib; + this.ctx = ctx; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(this.handle) + ); + this.handle = 0; + } + } + + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull = true): Pointer { + if (requireNotNull && this.handle === 0) { + throw Error("Object has already been disposed"); + } + return this.handle; + } + + /** get the type index of the object */ + typeIndex(): number { + if (this.handle === 0) { + throw Error("The current Object has already been disposed"); + } + return this.lib.memory.loadObjectTypeIndex(this.handle); + } + + /** get the type key of the object */ + typeKey(): string { + const type_index = this.typeIndex(); + const typeInfoPtr = (this.lib.exports.TVMFFIGetTypeInfo as ctypes.FTVMFFIGetTypeInfo)( + type_index + ); + return this.lib.memory.loadTypeInfoTypeKey(typeInfoPtr); + } +} + +/** + * Cell holds the PackedFunc object. + */ +class PackedFuncCell extends TVMObject { + constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) { + super(handle, lib, ctx); + } +} + /** * n-dimnesional array. */ -export class NDArray implements Disposable { - /** Internal array handle. */ - private handle: Pointer; + +export class NDArray extends TVMObject { /** Number of dimensions. */ ndim: number; /** Data type of the array. */ @@ -463,16 +514,14 @@ export class NDArray implements Disposable { private byteOffset: number; private dltensor: Pointer; private dataPtr: Pointer; - private lib: FFILibrary; - private ctx: RuntimeContext; private dlDataType: DLDataType; - constructor(handle: Pointer, isView: boolean, lib: FFILibrary, ctx: RuntimeContext) { - this.handle = handle; + constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext, isView: boolean) { + // if the array is a view, we need to create a new object with a null handle + // so dispose won't trigger memory free + const objectHandle = isView ? 0 : handle; + super(objectHandle, lib, ctx); this.isView = isView; - this.lib = lib; - this.ctx = ctx; - if (this.isView) { this.dltensor = handle; } else { @@ -535,20 +584,6 @@ export class NDArray implements Disposable { /*relative_byte_offset=*/ new Scalar(0, "int"), ); } - - /** - * Get handle of ndarray, check it is not null. - * - * @param requireNotNull require handle is not null. - * @returns The handle. - */ - getHandle(requireNotNull = true): Pointer { - if (requireNotNull && this.handle === 0) { - throw Error("NDArray has already been disposed"); - } - return this.handle; - } - /** * Get dataPtr of NDarray * @@ -561,14 +596,6 @@ export class NDArray implements Disposable { return this.dataPtr; } - dispose(): void { - if (this.handle != 0 && !this.isView) { - this.lib.checkCall( - (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle) - ); - this.handle = 0; - } - } /** * Copy data from another NDArray or javascript array. * The number of elements must match. @@ -581,13 +608,7 @@ export class NDArray implements Disposable { Int32Array | Int8Array | Uint8Array | Uint8ClampedArray ): this { if (data instanceof NDArray) { - this.lib.checkCall( - (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( - data.getHandle(), - this.getHandle(), - 0 - ) - ); + this.ctx.ndarrayCopyFromTo(data, this); return this; } else { const size = this.shape.reduce((a, b) => { @@ -639,21 +660,7 @@ export class NDArray implements Disposable { if (nbytes != data.length) { throw new Error("Expect the data's length equals nbytes=" + nbytes); } - - const stack = this.lib.getOrAllocCallStack(); - - const tempOffset = stack.allocRawBytes(nbytes); - const tempPtr = stack.ptrFromOffset(tempOffset); - this.lib.memory.storeRawBytes(tempPtr, data); - this.lib.checkCall( - (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)( - this.getHandle(), - tempPtr, - nbytes - ) - ); - - this.lib.recycleCallStack(stack); + this.ctx.ndarrayCopyFromJSBytes(this, data); return this; } /** @@ -664,26 +671,7 @@ export class NDArray implements Disposable { if (this.device.deviceType != DeviceStrToEnum.cpu) { throw new Error("Can only sync copy CPU array, use cpu_arr.copyfrom(gpu_arr) then sync instead."); } - const size = this.shape.reduce((a, b) => { - return a * b; - }, 1); - - const nbytes = this.dlDataType.numStorageBytes() * size; - const stack = this.lib.getOrAllocCallStack(); - - const tempOffset = stack.allocRawBytes(nbytes); - const tempPtr = stack.ptrFromOffset(tempOffset); - this.lib.checkCall( - (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)( - this.getHandle(), - tempPtr, - nbytes - ) - ); - const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes); - - this.lib.recycleCallStack(stack); - return ret; + return this.ctx.ndarrayCopyToJSBytes(this) as Uint8Array; } /** @@ -709,52 +697,22 @@ export class NDArray implements Disposable { } private getDLTensorFromArrayHandle(handle: Pointer): Pointer { - // Note: this depends on the NDArray C ABI. - // keep this function in case of ABI change. - return handle; + return handle + SizeOf.ObjectHeader; } } + /** * Runtime Module. */ -export class Module implements Disposable { - private handle: Pointer; - private lib: FFILibrary; - private makePackedFunc: (ptr: Pointer) => PackedFunc; - +export class Module extends TVMObject { constructor( handle: Pointer, lib: FFILibrary, - makePackedFunc: (ptr: Pointer) => PackedFunc + ctx: RuntimeContext, ) { - this.handle = handle; - this.lib = lib; - this.makePackedFunc = makePackedFunc; - } - - dispose(): void { - if (this.handle != 0) { - this.lib.checkCall( - (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle) - ); - this.handle = 0; - } - } - - /** - * Get handle of module, check it is not null. - * - * @param requireNotNull require handle is not null. - * @returns The handle. - */ - getHandle(requireNotNull = true): Pointer { - if (requireNotNull && this.handle === 0) { - throw Error("Module has already been disposed"); - } - return this.handle; + super(handle, lib, ctx); } - /** * Get a function in the module. * @param name The name of the function. @@ -762,33 +720,7 @@ export class Module implements Disposable { * @returns The result function. */ getFunction(name: string, queryImports = true): PackedFunc { - if (this.handle === 0) { - throw Error("Module has already been disposed"); - } - const stack = this.lib.getOrAllocCallStack(); - const nameOffset = stack.allocRawBytes(name.length + 1); - stack.storeRawBytes(nameOffset, StringToUint8Array(name)); - - const outOffset = stack.allocPtrArray(1); - const outPtr = stack.ptrFromOffset(outOffset); - - stack.commitToWasmMemory(outOffset); - - this.lib.checkCall( - (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)( - this.getHandle(), - stack.ptrFromOffset(nameOffset), - queryImports ? 1 : 0, - outPtr - ) - ); - const handle = this.lib.memory.loadPointer(outPtr); - this.lib.recycleCallStack(stack); - if (handle === 0) { - throw Error("Cannot find function " + name); - } - const ret = this.makePackedFunc(handle); - return ret; + return this.ctx.moduleGetFunction(this, name, queryImports) as PackedFunc; } /** @@ -796,100 +728,16 @@ export class Module implements Disposable { * @param mod The module to be imported. */ importModule(mod: Module): void { - this.lib.checkCall( - (this.lib.exports.TVMModImport as ctypes.FTVMModImport)( - this.getHandle(), - mod.getHandle() - ) - ); + this.ctx.moduleImport(this, mod); } } -/** - * Generic object base - */ -export class TVMObject implements Disposable { - private handle: Pointer; - private lib: FFILibrary; - protected ctx: RuntimeContext; - - constructor( - handle: Pointer, - lib: FFILibrary, - ctx: RuntimeContext - ) { - this.handle = handle; - this.lib = lib; - this.ctx = ctx; - } - - dispose(): void { - if (this.handle != 0) { - this.lib.checkCall( - (this.lib.exports.TVMObjectFree as ctypes.FTVMObjectFree)(this.handle) - ); - this.handle = 0; - } - } - - /** - * Get handle of module, check it is not null. - * - * @param requireNotNull require handle is not null. - * @returns The handle. - */ - getHandle(requireNotNull = true): Pointer { - if (requireNotNull && this.handle === 0) { - throw Error("Module has already been disposed"); - } - return this.handle; - } - - /** get the type index of the object */ - typeIndex(): number { - if (this.handle === 0) { - throw Error("The current Object has already been disposed"); - } - const stack = this.lib.getOrAllocCallStack(); - const outOffset = stack.allocPtrArray(1); - const outPtr = stack.ptrFromOffset(outOffset); - - this.lib.checkCall( - (this.lib.exports.TVMObjectGetTypeIndex as ctypes.FTVMObjectGetTypeIndex)( - this.getHandle(), - outPtr - ) - ); - const result = this.lib.memory.loadU32(outPtr); - this.lib.recycleCallStack(stack); - return result; - } - - /** get the type key of the object */ - typeKey(): string { - const type_index = this.typeIndex(); - const stack = this.lib.getOrAllocCallStack(); - const outOffset = stack.allocPtrArray(1); - const outPtr = stack.ptrFromOffset(outOffset); - this.lib.checkCall( - (this.lib.exports.TVMObjectTypeIndex2Key as ctypes.FTVMObjectTypeIndex2Key)( - type_index, - outPtr - ) - ); - const result = this.lib.memory.loadCString( - this.lib.memory.loadPointer(outPtr) - ); - this.lib.recycleCallStack(stack); - return result; - } -} /** Objectconstructor */ type FObjectConstructor = (handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) => TVMObject; /** All possible object types. */ -type TVMObjectBase = TVMObject | NDArray | Module | PackedFunc; +type TVMObjectBase = TVMObject | PackedFunc; /** Runtime array object. */ export class TVMArray extends TVMObject { @@ -1212,38 +1060,16 @@ export class Instance implements Disposable { * @returns The name list. */ listGlobalFuncNames(): Array { - const stack = this.lib.getOrAllocCallStack(); - - const outSizeOffset = stack.allocPtrArray(2); - - const outSizePtr = stack.ptrFromOffset(outSizeOffset); - const outArrayPtr = stack.ptrFromOffset( - outSizeOffset + this.lib.sizeofPtr() - ); - - this.lib.checkCall( - (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)( - outSizePtr, - outArrayPtr - ) - ); - - const size = this.memory.loadI32(outSizePtr); - const array = this.memory.loadPointer(outArrayPtr); - const names: Array = []; - - for (let i = 0; i < size; ++i) { - names.push( - this.memory.loadCString( - this.memory.loadPointer(array + this.lib.sizeofPtr() * i) - ) - ); - } - - this.lib.recycleCallStack(stack); - return names; + return this.withNewScope(() => { + const functor = this.ctx.functionListGlobalNamesFunctor(); + const numNames = functor(new Scalar(-1, "int")) as number; + const names = new Array(numNames); + for (let i = 0; i < numNames; i++) { + names[i] = functor(new Scalar(i, "int")) as string; + } + return names; + }); } - /** * Register function to be global function in tvm runtime. * @param name The name of the function. @@ -1262,12 +1088,10 @@ export class Instance implements Disposable { const ioverride = override ? 1 : 0; const stack = this.lib.getOrAllocCallStack(); - const nameOffset = stack.allocRawBytes(name.length + 1); - stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + const nameOffset = stack.allocByteArrayForString(name); stack.commitToWasmMemory(); - this.lib.checkCall( - (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)( + (this.lib.exports.TVMFFIFunctionSetGlobal as ctypes.FTVMFFIFunctionSetGlobal)( stack.ptrFromOffset(nameOffset), packedFunc._tvmPackedCell.getHandle(), ioverride @@ -1289,15 +1113,14 @@ export class Instance implements Disposable { private getGlobalFuncInternal(name: string, autoAttachToScope = true): PackedFunc { const stack = this.lib.getOrAllocCallStack(); - const nameOffset = stack.allocRawBytes(name.length + 1); - stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + const nameOffset = stack.allocByteArrayForString(name); const outOffset = stack.allocPtrArray(1); const outPtr = stack.ptrFromOffset(outOffset); stack.commitToWasmMemory(outOffset); this.lib.checkCall( - (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)( + (this.exports.TVMFFIFunctionGetGlobal as ctypes.FTVMFFIFunctionGetGlobal)( stack.ptrFromOffset(nameOffset), outPtr ) @@ -1335,7 +1158,7 @@ export class Instance implements Disposable { private toPackedFuncInternal(func: Function, autoAttachToScope: boolean): PackedFunc { if (this.isPackedFunc(func)) return func as PackedFunc; - const ret = this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + const ret = this.createPackedFuncFromSafeCallType(this.wrapJSFuncAsSafeCallType(func)); if (autoAttachToScope) return this.ctx.attachToCurrentScope(ret); return ret; } @@ -1603,52 +1426,6 @@ export class Instance implements Disposable { } } - /** - * Convert dtype to {@link DLDataType} - * - * @param dtype The input dtype string or DLDataType. - * @returns The converted result. - */ - toDLDataType(dtype: string | DLDataType): DLDataType { - if (dtype instanceof DLDataType) return dtype; - if (typeof dtype === "string") { - let pattern = dtype; - let code, - bits = 32, - lanes = 1; - if (pattern.substring(0, 5) === "float") { - pattern = pattern.substring(5, pattern.length); - code = DLDataTypeCode.Float; - } else if (pattern.substring(0, 3) === "int") { - pattern = pattern.substring(3, pattern.length); - code = DLDataTypeCode.Int; - } else if (pattern.substring(0, 4) === "uint") { - pattern = pattern.substring(4, pattern.length); - code = DLDataTypeCode.UInt; - } else if (pattern.substring(0, 6) === "handle") { - pattern = pattern.substring(5, pattern.length); - code = DLDataTypeCode.OpaqueHandle; - bits = 64; - } else { - throw new Error("Unknown dtype " + dtype); - } - - const arr = pattern.split("x"); - if (arr.length >= 1) { - const parsed = parseInt(arr[0]); - if (parsed + "" === arr[0]) { - bits = parsed; - } - } - if (arr.length >= 2) { - lanes = parseInt(arr[1]); - } - return new DLDataType(code, bits, lanes); - } else { - throw new Error("Unknown dtype " + dtype); - } - } - /** * Create a new {@link Scalar} that can be passed to a PackedFunc. * @param value The number value. @@ -1698,36 +1475,8 @@ export class Instance implements Disposable { dtype: string | DLDataType = "float32", dev: DLDevice = this.device("cpu", 0) ): NDArray { - dtype = this.toDLDataType(dtype); shape = typeof shape === "number" ? [shape] : shape; - - const stack = this.lib.getOrAllocCallStack(); - const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64); - for (let i = 0; i < shape.length; ++i) { - stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]); - } - - const outOffset = stack.allocPtrArray(1); - const outPtr = stack.ptrFromOffset(outOffset); - stack.commitToWasmMemory(outOffset); - - this.lib.checkCall( - (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)( - stack.ptrFromOffset(shapeOffset), - shape.length, - dtype.code, - dtype.bits, - dtype.lanes, - dev.deviceType, - dev.deviceId, - outPtr - ) - ); - const ret = this.ctx.attachToCurrentScope( - new NDArray(this.memory.loadPointer(outPtr), false, this.lib, this.ctx) - ); - this.lib.recycleCallStack(stack); - return ret; + return this.ctx.ndarrayEmpty(this.makeShapeTuple(shape), dtype, dev, null); } /** @@ -1936,15 +1685,13 @@ export class Instance implements Disposable { typeKey: string ): number { const stack = this.lib.getOrAllocCallStack(); - const typeKeyOffset = stack.allocRawBytes(typeKey.length + 1); - stack.storeRawBytes(typeKeyOffset, StringToUint8Array(typeKey)); + const typeKeyOffset = stack.allocByteArrayForString(typeKey); const outOffset = stack.allocPtrArray(1); const outPtr = stack.ptrFromOffset(outOffset); stack.commitToWasmMemory(outOffset); - this.lib.checkCall( - (this.lib.exports.TVMObjectTypeKey2Index as ctypes.FTVMObjectTypeKey2Index)( + (this.lib.exports.TVMFFITypeKeyToIndex as ctypes.FTVMFFITypeKeyToIndex)( stack.ptrFromOffset(typeKeyOffset), outPtr ) @@ -2153,6 +1900,10 @@ export class Instance implements Disposable { (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { return new TVMArray(handle, lib, ctx); }); + this.registerObjectConstructor("runtime.Module", + (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { + return new Module(handle, lib, ctx); + }); } /** Register global packed functions needed by the backend to the env. */ @@ -2224,8 +1975,8 @@ export class Instance implements Disposable { this.registerAsyncServerFunc("testing.asyncAddOne", addOne); } - private createPackedFuncFromCFunc( - func: ctypes.FTVMWasmPackedCFunc + private createPackedFuncFromSafeCallType( + func: ctypes.FTVMFFIWasmSafeCallType ): PackedFunc { let findex = this.env.packedCFuncTable.length; if (this.env.packedCFuncTableFreeId.length != 0) { @@ -2240,7 +1991,7 @@ export class Instance implements Disposable { const outPtr = stack.ptrFromOffset(outOffset); this.lib.checkCall( (this.exports - .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)( + .TVMFFIWasmFunctionCreate as ctypes.FTVMFFIWasmFunctionCreate)( findex, outPtr ) @@ -2256,20 +2007,19 @@ export class Instance implements Disposable { * * @parma stack The call stack * @param args The input arguments. - * @param argsValue The offset of argsValue. - * @param argsCode The offset of argsCode. + * @param packedArgs The offset of packedArgs. */ setPackedArguments( stack: CachedCallStack, args: Array, - argsValue: PtrOffset, - argsCode: PtrOffset + packedArgs: PtrOffset, ): void { for (let i = 0; i < args.length; ++i) { let val = args[i]; const tp = typeof val; - const valueOffset = argsValue + i * SizeOf.TVMValue; - const codeOffset = argsCode + i * SizeOf.I32; + const argOffset = packedArgs + i * SizeOf.TVMFFIAny; + const argTypeIndexOffset = argOffset; + const argValueOffset = argOffset + SizeOf.I32 * 2; // Convert string[] to a TVMArray of, hence treated as a TVMObject if (val instanceof Array && val.every(e => typeof e === "string")) { @@ -2278,97 +2028,100 @@ export class Instance implements Disposable { val = this.makeTVMArray(tvmStringArray); } + // clear off the extra padding valuesbefore ptr storage + stack.storeI32(argTypeIndexOffset + SizeOf.I32, 0); + stack.storeI32(argValueOffset + SizeOf.I32, 0); if (val instanceof NDArray) { if (!val.isView) { - stack.storePtr(valueOffset, val.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFINDArray); + stack.storePtr(argValueOffset, val.getHandle()); } else { - stack.storePtr(valueOffset, val.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMDLTensorHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIDLTensorPtr); + stack.storePtr(argValueOffset, val.getHandle()); } } else if (val instanceof Scalar) { if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { - stack.storeI64(valueOffset, val.value); - stack.storeI32(codeOffset, ArgTypeCode.Int); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIInt); + stack.storeI64(argValueOffset, val.value); } else if (val.dtype.startsWith("float")) { - stack.storeF64(valueOffset, val.value); - stack.storeI32(codeOffset, ArgTypeCode.Float); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFloat); + stack.storeF64(argValueOffset, val.value); } else { assert(val.dtype === "handle", "Expect handle"); - stack.storePtr(valueOffset, val.value); - stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIOpaquePtr); + stack.storePtr(argValueOffset, val.value); } } else if (val instanceof DLDevice) { - stack.storeI32(valueOffset, val.deviceType); - stack.storeI32(valueOffset + SizeOf.I32, val.deviceType); - stack.storeI32(codeOffset, ArgTypeCode.DLDevice); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIDevice); + stack.storeI32(argValueOffset, val.deviceType); + stack.storeI32(argValueOffset + SizeOf.I32, val.deviceId); + } else if (tp === "boolean") { + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIBool); + stack.storeI64(argValueOffset, val ? 1 : 0); } else if (tp === "number") { - stack.storeF64(valueOffset, val); - stack.storeI32(codeOffset, ArgTypeCode.Float); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFloat); + stack.storeF64(argValueOffset, val); // eslint-disable-next-line no-prototype-builtins } else if (tp === "function" && val.hasOwnProperty("_tvmPackedCell")) { - stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + stack.storePtr(argValueOffset, val._tvmPackedCell.getHandle()); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFunction); } else if (val === null || val === undefined) { - stack.storePtr(valueOffset, 0); - stack.storeI32(codeOffset, ArgTypeCode.Null); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFINone); + stack.storePtr(argValueOffset, 0); } else if (tp === "string") { - stack.allocThenSetArgString(valueOffset, val); - stack.storeI32(codeOffset, ArgTypeCode.TVMStr); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIRawStr); + stack.allocThenSetArgString(argValueOffset, val); } else if (val instanceof Uint8Array) { - stack.allocThenSetArgBytes(valueOffset, val); - stack.storeI32(codeOffset, ArgTypeCode.TVMBytes); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIByteArrayPtr); + stack.allocThenSetArgBytes(argValueOffset, val); } else if (val instanceof Function) { val = this.toPackedFuncInternal(val, false); stack.tempArgs.push(val); - stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFunction); + stack.storePtr(argValueOffset, val._tvmPackedCell.getHandle()); } else if (val instanceof Module) { - stack.storePtr(valueOffset, val.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIModule); + stack.storePtr(argValueOffset, val.getHandle()); } else if (val instanceof TVMObject) { - stack.storePtr(valueOffset, val.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMObjectHandle); + stack.storeI32(argTypeIndexOffset, val.typeIndex()); + stack.storePtr(argValueOffset, val.getHandle()); } else { - throw new Error("Unsupported argument type " + tp); + throw new Error("Unsupported argument type " + tp + " value=`" + val.toString() + "`"); } } } - private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc { + private wrapJSFuncAsSafeCallType(func: Function): ctypes.FTVMFFIWasmSafeCallType { const lib = this.lib; return ( - argValues: Pointer, - argCodes: Pointer, - nargs: number, - ret: Pointer, // eslint-disable-next-line @typescript-eslint/no-unused-vars - _handle: Pointer + self: Pointer, + packedArgs: Pointer, + numArgs: number, + ret: Pointer ): number => { const jsArgs = []; // use scope to track js values. this.ctx.beginScope(); - for (let i = 0; i < nargs; ++i) { - const valuePtr = argValues + i * SizeOf.TVMValue; - const codePtr = argCodes + i * SizeOf.I32; - let tcode = lib.memory.loadI32(codePtr); - - if ( - tcode === ArgTypeCode.TVMObjectHandle || - tcode === ArgTypeCode.TVMObjectRValueRefArg || - tcode === ArgTypeCode.TVMPackedFuncHandle || - tcode === ArgTypeCode.TVMNDArrayHandle || - tcode === ArgTypeCode.TVMModuleHandle - ) { + for (let i = 0; i < numArgs; ++i) { + const argPtr = packedArgs + i * SizeOf.TVMFFIAny; + const typeIndex = lib.memory.loadI32(argPtr); + + if (typeIndex >= TypeIndex.kTVMFFIRawStr) { + // NOTE: the following code have limitations in asyncify mode. + // The reason is that the TVMFFIAnyViewToOwnedAny will simply + // get skipped during the rewinding process, causing memory failure + if (!this.asyncifyHandler.isNormalStackState()) { + throw Error("Cannot handle str/object argument callback in asyncify mode"); + } lib.checkCall( - (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)( - valuePtr, - codePtr + (lib.exports.TVMFFIAnyViewToOwnedAny as ctypes.FTVMFFIAnyViewToOwnedAny)( + argPtr, + argPtr ) ); } - tcode = lib.memory.loadI32(codePtr); - jsArgs.push(this.retValueToJS(valuePtr, tcode, true)); + jsArgs.push(this.retValueToJS(argPtr, true)); } let rv: any; @@ -2378,12 +2131,16 @@ export class Instance implements Disposable { // error handling // store error via SetLastError this.ctx.endScope(); - const errMsg = "JSCallbackError: " + error.message; + const errKind = "JSCallbackError" + const errMsg = error.message; const stack = lib.getOrAllocCallStack(); + const errKindOffset = stack.allocRawBytes(errKind.length + 1); + stack.storeRawBytes(errKindOffset, StringToUint8Array(errKind)); const errMsgOffset = stack.allocRawBytes(errMsg.length + 1); stack.storeRawBytes(errMsgOffset, StringToUint8Array(errMsg)); stack.commitToWasmMemory(); - (this.lib.exports.TVMAPISetLastError as ctypes.FTVMAPISetLastError)( + (this.lib.exports.FTVMFFIErrorSetRaisedByCStr as ctypes.FTVMFFIErrorSetRaisedByCStr)( + stack.ptrFromOffset(errKindOffset), stack.ptrFromOffset(errMsgOffset) ); this.lib.recycleCallStack(stack); @@ -2395,18 +2152,14 @@ export class Instance implements Disposable { this.ctx.endScope(); if (rv !== undefined && rv !== null) { const stack = lib.getOrAllocCallStack(); - const valueOffset = stack.allocRawBytes(SizeOf.TVMValue); - const codeOffset = stack.allocRawBytes(SizeOf.I32); - this.setPackedArguments(stack, [rv], valueOffset, codeOffset); - const valuePtr = stack.ptrFromOffset(valueOffset); - const codePtr = stack.ptrFromOffset(codeOffset); + const argOffset = stack.allocRawBytes(SizeOf.TVMFFIAny); + this.setPackedArguments(stack, [rv], argOffset); stack.commitToWasmMemory(); + const argPtr = stack.ptrFromOffset(argOffset); lib.checkCall( - (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)( - ret, - valuePtr, - codePtr, - 1 + (lib.exports.TVMFFIAnyViewToOwnedAny as ctypes.FTVMFFIAnyViewToOwnedAny)( + argPtr, + ret ) ); lib.recycleCallStack(stack); @@ -2416,38 +2169,25 @@ export class Instance implements Disposable { } private makePackedFunc(handle: Pointer): PackedFunc { - const cell = new PackedFuncCell(handle, this.lib); - + const cell = new PackedFuncCell(handle, this.lib, this.ctx); const packedFunc = (...args: any): any => { const stack = this.lib.getOrAllocCallStack(); - - const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length); - const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length); - - this.setPackedArguments(stack, args, valueOffset, tcodeOffset); - - const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue); - const rcodeOffset = stack.allocRawBytes(SizeOf.I32); - const rvaluePtr = stack.ptrFromOffset(rvalueOffset); - const rcodePtr = stack.ptrFromOffset(rcodeOffset); - - // pre-store the rcode to be null, in case caller unwind - // and not have chance to reset this rcode. - stack.storeI32(rcodeOffset, ArgTypeCode.Null); + const argsOffset = stack.allocRawBytes(SizeOf.TVMFFIAny * args.length); + this.setPackedArguments(stack, args, argsOffset); + const retOffset = stack.allocRawBytes(SizeOf.TVMFFIAny); + // pre-store the result to be null + stack.storeI32(retOffset, TypeIndex.kTVMFFINone); stack.commitToWasmMemory(); - this.lib.checkCall( - (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( + (this.exports.TVMFFIFunctionCall as ctypes.FTVMFFIFunctionCall)( cell.getHandle(), - stack.ptrFromOffset(valueOffset), - stack.ptrFromOffset(tcodeOffset), + stack.ptrFromOffset(argsOffset), args.length, - rvaluePtr, - rcodePtr + stack.ptrFromOffset(retOffset) ) ); - const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr), false); + const ret = this.retValueToJS(stack.ptrFromOffset(retOffset), false); this.lib.recycleCallStack(stack); return ret; }; @@ -2463,78 +2203,91 @@ export class Instance implements Disposable { /** * Creaye return value of the packed func. The value us auto-tracked for dispose. - * @param rvaluePtr The location of rvalue - * @param tcode The type code. + * @param resultAnyPtr The location of rvalue * @param callbackArg Whether it is being used in callbackArg. * @returns The JS value. */ - private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any { - switch (tcode) { - case ArgTypeCode.Int: - case ArgTypeCode.UInt: - case ArgTypeCode.TVMArgBool: - return this.memory.loadI64(rvaluePtr); - case ArgTypeCode.Float: - return this.memory.loadF64(rvaluePtr); - case ArgTypeCode.TVMOpaqueHandle: { - return this.memory.loadPointer(rvaluePtr); + private retValueToJS(resultAnyPtr: Pointer, callbackArg: boolean): any { + const typeIndex = this.memory.loadI32(resultAnyPtr); + const valuePtr = resultAnyPtr + SizeOf.I32 * 2; + switch (typeIndex) { + case TypeIndex.kTVMFFINone: return undefined; + case TypeIndex.kTVMFFIBool: + return this.memory.loadI64(valuePtr) != 0; + case TypeIndex.kTVMFFIInt: + return this.memory.loadI64(valuePtr); + case TypeIndex.kTVMFFIFloat: + return this.memory.loadF64(valuePtr); + case TypeIndex.kTVMFFIOpaquePtr: { + return this.memory.loadPointer(valuePtr); } - case ArgTypeCode.TVMNDArrayHandle: { + case TypeIndex.kTVMFFINDArray: { return this.ctx.attachToCurrentScope( - new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib, this.ctx) + new NDArray(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false) ); } - case ArgTypeCode.TVMDLTensorHandle: { + case TypeIndex.kTVMFFIDLTensorPtr: { assert(callbackArg); // no need to attach as we are only looking at view - return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib, this.ctx); + return new NDArray(this.memory.loadPointer(valuePtr), this.lib, this.ctx, true); } - case ArgTypeCode.TVMPackedFuncHandle: { + case TypeIndex.kTVMFFIFunction: { return this.ctx.attachToCurrentScope( - this.makePackedFunc(this.memory.loadPointer(rvaluePtr)) + this.makePackedFunc(this.memory.loadPointer(valuePtr)) ); } - case ArgTypeCode.TVMModuleHandle: { - return this.ctx.attachToCurrentScope( - new Module( - this.memory.loadPointer(rvaluePtr), - this.lib, - (ptr: Pointer) => { - return this.ctx.attachToCurrentScope(this.makePackedFunc(ptr)); - } - ) + case TypeIndex.kTVMFFIDevice: { + const deviceType = this.memory.loadI32(valuePtr); + const deviceId = this.memory.loadI32(valuePtr + SizeOf.I32); + return this.device(deviceType, deviceId); + } + case TypeIndex.kTVMFFIDataType: { + // simply return dtype as tring to keep things simple + this.lib.checkCall( + (this.lib.exports.TVMFFIDataTypeToString as ctypes.FTVMFFIDataTypeToString)(valuePtr, valuePtr) + ); + const strObjPtr = this.memory.loadPointer(valuePtr); + const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); + this.lib.checkCall( + (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) + ); + return result; + } + case TypeIndex.kTVMFFIStr: { + const strObjPtr = this.memory.loadPointer(valuePtr); + const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); + this.lib.checkCall( + (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) ); + return result; } - case ArgTypeCode.TVMObjectHandle: { - const obj = new TVMObject( - this.memory.loadPointer(rvaluePtr), - this.lib, - this.ctx + case TypeIndex.kTVMFFIBytes: { + const bytesObjPtr = this.memory.loadPointer(valuePtr); + const result = this.memory.loadByteArrayAsBytes(bytesObjPtr + SizeOf.ObjectHeader); + this.lib.checkCall( + (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(bytesObjPtr) ); - const func = this.objFactory.get(obj.typeIndex()) - if (func != undefined) { - return this.ctx.attachToCurrentScope( - func(obj.getHandle(), this.lib, this.ctx) + return result; + } + default: { + if (typeIndex >= TypeIndex.kTVMFFIStaticObjectBegin) { + const obj = new TVMObject( + this.memory.loadPointer(valuePtr), + this.lib, + this.ctx ); + const func = this.objFactory.get(obj.typeIndex()) + if (func != undefined) { + return this.ctx.attachToCurrentScope( + func(obj.getHandle(), this.lib, this.ctx) + ); + } else { + return this.ctx.attachToCurrentScope(obj); + } } else { - return this.ctx.attachToCurrentScope(obj); + throw new Error("Unsupported return type code=" + typeIndex); } } - case ArgTypeCode.Null: return undefined; - case ArgTypeCode.DLDevice: { - const deviceType = this.memory.loadI32(rvaluePtr); - const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32); - return this.device(deviceType, deviceId); - } - case ArgTypeCode.TVMStr: { - const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); - return ret; - } - case ArgTypeCode.TVMBytes: { - return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); - } - default: - throw new Error("Unsupported return type code=" + tcode); } } } diff --git a/web/tests/node/test_ndarray.js b/web/tests/node/test_ndarray.js index 8d369216d2d8..495d05070147 100644 --- a/web/tests/node/test_ndarray.js +++ b/web/tests/node/test_ndarray.js @@ -38,7 +38,7 @@ function testArrayCopy(dtype, arrayType) { let data = [1, 2, 3, 4, 5, 6]; let a = tvm.empty([2, 3], dtype).copyFrom(data); - assert(a.device.toString() == "cpu(0)"); + assert(a.device.toString() == "cpu:0"); assert(a.shape[0] == 2 && a.shape[1] == 3); let ret = a.toArray(); diff --git a/web/tests/node/test_object.js b/web/tests/node/test_object.js index 2423ef4ceb46..3db3bd9c8431 100644 --- a/web/tests/node/test_object.js +++ b/web/tests/node/test_object.js @@ -42,10 +42,5 @@ test("object", () => { let t1 = b.get(1); assert(t1.getHandle() == t.getHandle()); - - let ret_string = tvm.getGlobalFunc("testing.ret_string"); - let s1 = ret_string("hello"); - assert(s1 == "hello"); - ret_string.dispose(); }); }); diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js index e1d070f0e473..e2b6c7b7c9b3 100644 --- a/web/tests/node/test_packed_func.js +++ b/web/tests/node/test_packed_func.js @@ -37,7 +37,7 @@ let tvm = new tvmjs.Instance( test("GetGlobal", () => { tvm.beginScope(); let flist = tvm.listGlobalFuncNames(); - let faddOne = tvm.getGlobalFunc("testing.add_one"); + let faddOne = tvm.getGlobalFunc("tvmjs.testing.add_one"); let fecho = tvm.getGlobalFunc("testing.echo"); assert(faddOne(tvm.scalar(1, "int")) == 2); @@ -146,31 +146,6 @@ test("ExceptionPassing", () => { tvm.endScope(); }); - -test("AsyncifyFunc", async () => { - if (!tvm.asyncifyEnabled()) { - console.log("Skip asyncify tests as it is not enabled.."); - return; - } - tvm.beginScope(); - tvm.registerAsyncifyFunc("async_sleep_echo", async function (x) { - await new Promise(resolve => setTimeout(resolve, 10)); - return x; - }); - let fecho = tvm.wrapAsyncifyPackedFunc( - tvm.getGlobalFunc("async_sleep_echo") - ); - let fcall = tvm.wrapAsyncifyPackedFunc( - tvm.getGlobalFunc("testing.call") - ); - assert((await fecho(1)) == 1); - assert((await fecho(2)) == 2); - assert((await fcall(fecho, 2) == 2)); - tvm.endScope(); - assert(fecho._tvmPackedCell.getHandle(false) == 0); - assert(fcall._tvmPackedCell.getHandle(false) == 0); -}); - test("NDArrayCbArg", () => { tvm.beginScope(); let use_count = tvm.getGlobalFunc("testing.object_use_count"); @@ -204,8 +179,32 @@ test("NDArrayCbArg", () => { test("Logging", () => { tvm.beginScope(); - const log_info = tvm.getGlobalFunc("testing.log_info_str"); + const log_info = tvm.getGlobalFunc("tvmjs.testing.log_info_str"); log_info("helow world") log_info.dispose(); tvm.endScope(); }); + +test("AsyncifyFunc", async () => { + if (!tvm.asyncifyEnabled()) { + console.log("Skip asyncify tests as it is not enabled.."); + return; + } + tvm.beginScope(); + tvm.registerAsyncifyFunc("async_sleep_echo", async function (x) { + await new Promise(resolve => setTimeout(resolve, 10)); + return x; + }); + let fecho = tvm.wrapAsyncifyPackedFunc( + tvm.getGlobalFunc("async_sleep_echo") + ); + let fcall = tvm.wrapAsyncifyPackedFunc( + tvm.getGlobalFunc("tvmjs.testing.call") + ); + assert((await fecho(1)) == 1); + assert((await fecho(2)) == 2); + assert((await fcall(fecho, 2) == 2)); + tvm.endScope(); + assert(fecho._tvmPackedCell.getHandle(false) == 0); + assert(fcall._tvmPackedCell.getHandle(false) == 0); +}); diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index e831afd9d3f8..8925da00a489 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -35,7 +35,6 @@ def test_rpc(): return # generate the wasm library target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm") - runtime = Runtime("cpp", {"system-lib": True}) n = te.var("n") A = te.placeholder((n,), name="A")