Skip to content

Commit f287bab

Browse files
committed
[REFACTOR][FFI] Upgrade Web Runtime to new FFI
This PR refactors the web runtime to the new FFI protocol. Tested through rpc tests and local tests.
1 parent 7275cf0 commit f287bab

File tree

22 files changed

+604
-848
lines changed

22 files changed

+604
-848
lines changed

ffi/include/tvm/ffi/c_api.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,8 +579,10 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType*
579579
* \return 0 when success, nonzero when failure happens
580580
* \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree.
581581
The content of string can be accessed via TVMFFIObjectGetByteArrayPtr.
582+
583+
* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues.
582584
*/
583-
TVM_FFI_DLL int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out);
585+
TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out);
584586

585587
//------------------------------------------------------------
586588
// Section: Backend noexcept functions for internal use

ffi/include/tvm/ffi/dtype.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ inline DLDataType StringToDLDataType(const String& str) {
121121

122122
inline String DLDataTypeToString(DLDataType dtype) {
123123
TVMFFIObjectHandle out;
124-
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(dtype, &out));
124+
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out));
125125
return String(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(out)));
126126
}
127127

ffi/include/tvm/ffi/error.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@
5151
#define TVM_FFI_BACKTRACE_ON_SEGFAULT 1
5252
#endif
5353

54+
#ifndef TVM_FFI_ALWAYS_LOG_BEFORE_THROW
55+
#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 0
56+
#endif
57+
5458
namespace tvm {
5559
namespace ffi {
5660

@@ -212,8 +216,10 @@ class ErrorBuilder {
212216
*
213217
* \endcode
214218
*/
215-
#define TVM_FFI_THROW(ErrorKind) \
216-
::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, false).stream()
219+
#define TVM_FFI_THROW(ErrorKind) \
220+
::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, \
221+
TVM_FFI_ALWAYS_LOG_BEFORE_THROW) \
222+
.stream()
217223

218224
/*!
219225
* \brief Explicitly log error in stderr and then throw the error.

ffi/src/ffi/dtype.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,9 @@ int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out) {
320320
TVM_FFI_SAFE_CALL_END();
321321
}
322322

323-
int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) {
323+
int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) {
324324
TVM_FFI_SAFE_CALL_BEGIN();
325-
tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(dtype));
325+
tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(*dtype));
326326
*out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out_str));
327327
TVM_FFI_SAFE_CALL_END();
328328
}

python/tvm/ffi/cython/base.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ cdef extern from "tvm/ffi/c_api.h":
150150
int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil
151151
int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil
152152
int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil
153-
int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) nogil
153+
int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) nogil
154154
const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) nogil;
155155
int TVMFFINDArrayFromDLPack(DLManagedTensor* src, int32_t require_alignment,
156156
int32_t require_contiguous, TVMFFIObjectHandle* out) nogil

python/tvm/ffi/cython/dtype.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ cdef class DataType:
9494
def __str__(self):
9595
cdef TVMFFIObjectHandle dtype_str
9696
cdef TVMFFIByteArray* bytes
97-
CHECK_CALL(TVMFFIDataTypeToString(self.cdtype, &dtype_str))
97+
CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &dtype_str))
9898
bytes = TVMFFIBytesGetByteArrayPtr(dtype_str)
9999
res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size))
100100
CHECK_CALL(TVMFFIObjectFree(dtype_str))

web/.eslintignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
dist
22
debug
3+
tvmjs_runtime_wasi.js
4+
lib

web/apps/node/example.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
3131
tvmjs.instantiate(wasmSource, tvmjs.createPolyfillWASI())
3232
.then((tvm) => {
3333
tvm.beginScope();
34-
const log_info = tvm.getGlobalFunc("testing.log_info_str");
34+
const log_info = tvm.getGlobalFunc("tvmjs.testing.log_info_str");
3535
log_info("hello world");
3636
// List all the global functions from the runtime.
3737
console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames());

web/emcc/tvmjs_support.cc

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,11 @@
2828
#define TVM_LOG_STACK_TRACE 0
2929
#define TVM_LOG_DEBUG 0
3030
#define TVM_LOG_CUSTOMIZE 1
31+
#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1
3132
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
3233

33-
#include <tvm/runtime/c_runtime_api.h>
34+
#include <tvm/ffi/function.h>
3435
#include <tvm/runtime/device_api.h>
35-
#include <tvm/runtime/packed_func.h>
36-
#include <tvm/runtime/registry.h>
3736

3837
#include "../../src/runtime/rpc/rpc_local_session.h"
3938

@@ -59,27 +58,33 @@ TVM_DLL void TVMWasmFreeSpace(void* data);
5958
* \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer
6059
3A * \return 0 if success.
6160
*/
62-
TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out);
61+
TVM_DLL int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFunctionHandle* out);
62+
63+
/*!
64+
* \brief Get the last error message.
65+
* \return The last error message.
66+
*/
67+
TVM_DLL const char* TVMFFIWasmGetLastError();
6368

6469
// --- APIs to be implemented by the frontend. ---
70+
6571
/*!
66-
* \brief Wasm frontend packed function caller.
72+
* \brief Wasm frontend new ffi call function caller.
6773
*
74+
* \param self The pointer to the ffi::Function.
6875
* \param args The arguments
69-
* \param type_codes The type codes of the arguments
7076
* \param num_args Number of arguments.
71-
* \param ret The return value handle.
72-
* \param resource_handle The handle additional resource handle from front-end.
77+
* \param result The return value handle.
7378
* \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
7479
*/
75-
extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret,
76-
void* resource_handle);
77-
80+
extern int TVMFFIWasmSafeCall(void* self, const TVMFFIAny* args, int32_t num_args,
81+
TVMFFIAny* result);
7882
/*!
79-
* \brief Wasm frontend resource finalizer.
80-
* \param resource_handle The pointer to the external resource.
83+
* \brief Delete ffi::Function.
84+
* \param self The pointer to the ffi::Function.
8185
*/
82-
extern void TVMWasmPackedCFuncFinalizer(void* resource_handle);
86+
extern void TVMFFIWasmFunctionDeleter(void* self);
87+
8388
} // extern "C"
8489

8590
void* TVMWasmAllocSpace(int size) {
@@ -89,9 +94,14 @@ void* TVMWasmAllocSpace(int size) {
8994

9095
void TVMWasmFreeSpace(void* arr) { delete[] static_cast<int64_t*>(arr); }
9196

92-
int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) {
93-
return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer,
94-
out);
97+
int TVMFFIWasmFunctionCreate(void* self, TVMFunctionHandle* out) {
98+
return TVMFFIFunctionCreate(self, TVMFFIWasmSafeCall, TVMFFIWasmFunctionDeleter, out);
99+
}
100+
101+
const char* TVMFFIWasmGetLastError() {
102+
static thread_local std::string last_error;
103+
last_error = ::tvm::ffi::details::MoveFromSafeCallRaised().what();
104+
return last_error.c_str();
95105
}
96106

97107
namespace tvm {
@@ -291,7 +301,7 @@ class AsyncLocalSession : public LocalSession {
291301
}
292302
};
293303

294-
TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() {
304+
TVM_FFI_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() {
295305
return CreateRPCSessionModule(std::make_shared<AsyncLocalSession>());
296306
});
297307

web/emcc/wasm_runtime.cc

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
#define TVM_LOG_DEBUG 0
2828
#define TVM_LOG_CUSTOMIZE 1
2929
#define TVM_FFI_USE_LIBBACKTRACE 0
30+
#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1
3031
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
3132

32-
#include <tvm/runtime/c_runtime_api.h>
3333
#include <tvm/runtime/logging.h>
3434

3535
#include "src/runtime/c_runtime_api.cc"
@@ -107,45 +107,24 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s
107107

108108
} // namespace detail
109109

110-
TVM_REGISTER_GLOBAL("testing.echo").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
111-
*ret = args[0];
112-
});
113-
114-
TVM_REGISTER_GLOBAL("testing.call").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
115-
(args[0].cast<ffi::Function>()).CallPacked(args.Slice(1), ret);
116-
});
117-
118-
TVM_REGISTER_GLOBAL("testing.ret_string").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
119-
*ret = args[0].cast<String>();
120-
});
121-
122-
TVM_REGISTER_GLOBAL("testing.log_info_str")
110+
TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.call")
123111
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
124-
LOG(INFO) << args[0].cast<String>();
112+
(args[0].cast<ffi::Function>()).CallPacked(args.Slice(1), ret);
125113
});
126114

127-
TVM_REGISTER_GLOBAL("testing.log_fatal_str")
115+
TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.log_info_str")
128116
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
129-
LOG(FATAL) << args[0].cast<String>();
117+
LOG(INFO) << args[0].cast<String>();
130118
});
131119

132-
TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; });
120+
TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.add_one").set_body_typed([](int x) { return x + 1; });
133121

134-
TVM_REGISTER_GLOBAL("testing.wrap_callback")
122+
TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.wrap_callback")
135123
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
136124
ffi::Function pf = args[0].cast<ffi::Function>();
137125
*ret = ffi::TypedFunction<void()>([pf]() { pf(); });
138126
});
139127

140-
// internal function used for debug and testing purposes
141-
TVM_REGISTER_GLOBAL("testing.object_use_count")
142-
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
143-
auto obj = args[0].cast<ffi::ObjectRef>();
144-
// subtract the current one because we always copy
145-
// and get another value.
146-
*ret = (obj.use_count() - 1);
147-
});
148-
149128
void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, std::string dtype) {
150129
if (format == "f32-to-bf16" && dtype == "float32") {
151130
std::vector<uint16_t> buffer(bytes.length() / 2);
@@ -167,10 +146,10 @@ void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format,
167146
}
168147
}
169148

170-
TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);
149+
TVM_FFI_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);
171150

172151
// Concatenate n TVMArrays
173-
TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat")
152+
TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat")
174153
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
175154
std::vector<Any> data;
176155
for (int i = 0; i < args.size(); ++i) {
@@ -220,7 +199,7 @@ NDArray ConcatEmbeddings(const std::vector<NDArray>& embeddings) {
220199
}
221200

222201
// Concatenate n NDArrays
223-
TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings")
202+
TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings")
224203
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
225204
std::vector<NDArray> embeddings;
226205
for (int i = 0; i < args.size(); ++i) {
@@ -230,5 +209,19 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings")
230209
*ret = result;
231210
});
232211

212+
TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyFromBytes")
213+
.set_body_typed([](NDArray nd, TVMFFIByteArray* bytes) {
214+
nd.CopyFromBytes(bytes->data, bytes->size);
215+
});
216+
217+
TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyToBytes")
218+
.set_body_typed([](NDArray nd) -> ffi::Bytes {
219+
size_t size = GetDataSize(*(nd.operator->()));
220+
std::string bytes;
221+
bytes.resize(size);
222+
nd.CopyToBytes(bytes.data(), size);
223+
return ffi::Bytes(bytes);
224+
});
225+
233226
} // namespace runtime
234227
} // namespace tvm

0 commit comments

Comments
 (0)