Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ffi/include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -579,8 +579,10 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType*
* \return 0 when success, nonzero when failure happens
* \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree.
The content of string can be accessed via TVMFFIObjectGetByteArrayPtr.

* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues.
*/
TVM_FFI_DLL int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out);
TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out);

//------------------------------------------------------------
// Section: Backend noexcept functions for internal use
Expand Down
2 changes: 1 addition & 1 deletion ffi/include/tvm/ffi/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ inline DLDataType StringToDLDataType(const String& str) {

inline String DLDataTypeToString(DLDataType dtype) {
TVMFFIObjectHandle out;
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(dtype, &out));
TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out));
return String(details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(out)));
}

Expand Down
10 changes: 8 additions & 2 deletions ffi/include/tvm/ffi/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
#define TVM_FFI_BACKTRACE_ON_SEGFAULT 1
#endif

#ifndef TVM_FFI_ALWAYS_LOG_BEFORE_THROW
#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 0
#endif

namespace tvm {
namespace ffi {

Expand Down Expand Up @@ -212,8 +216,10 @@ class ErrorBuilder {
*
* \endcode
*/
#define TVM_FFI_THROW(ErrorKind) \
::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, false).stream()
#define TVM_FFI_THROW(ErrorKind) \
::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, \
TVM_FFI_ALWAYS_LOG_BEFORE_THROW) \
.stream()

/*!
* \brief Explicitly log error in stderr and then throw the error.
Expand Down
4 changes: 2 additions & 2 deletions ffi/src/ffi/dtype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out) {
TVM_FFI_SAFE_CALL_END();
}

int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) {
int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(dtype));
tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(*dtype));
*out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out_str));
TVM_FFI_SAFE_CALL_END();
}
2 changes: 1 addition & 1 deletion python/tvm/ffi/cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ cdef extern from "tvm/ffi/c_api.h":
int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil
int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil
int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil
int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) nogil
int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) nogil
const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) nogil;
int TVMFFINDArrayFromDLPack(DLManagedTensor* src, int32_t require_alignment,
int32_t require_contiguous, TVMFFIObjectHandle* out) nogil
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ffi/cython/dtype.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ cdef class DataType:
def __str__(self):
cdef TVMFFIObjectHandle dtype_str
cdef TVMFFIByteArray* bytes
CHECK_CALL(TVMFFIDataTypeToString(self.cdtype, &dtype_str))
CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &dtype_str))
bytes = TVMFFIBytesGetByteArrayPtr(dtype_str)
res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size))
CHECK_CALL(TVMFFIObjectFree(dtype_str))
Expand Down
2 changes: 2 additions & 0 deletions web/.eslintignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
dist
debug
tvmjs_runtime_wasi.js
lib
2 changes: 1 addition & 1 deletion web/apps/node/example.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm"));
tvmjs.instantiate(wasmSource, tvmjs.createPolyfillWASI())
.then((tvm) => {
tvm.beginScope();
const log_info = tvm.getGlobalFunc("testing.log_info_str");
const log_info = tvm.getGlobalFunc("tvmjs.testing.log_info_str");
log_info("hello world");
// List all the global functions from the runtime.
console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames());
Expand Down
46 changes: 28 additions & 18 deletions web/emcc/tvmjs_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
#define TVM_LOG_STACK_TRACE 0
#define TVM_LOG_DEBUG 0
#define TVM_LOG_CUSTOMIZE 1
#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/ffi/function.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include "../../src/runtime/rpc/rpc_local_session.h"

Expand All @@ -59,27 +58,33 @@ TVM_DLL void TVMWasmFreeSpace(void* data);
* \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer
3A * \return 0 if success.
*/
TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out);
TVM_DLL int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFunctionHandle* out);

/*!
* \brief Get the last error message.
* \return The last error message.
*/
TVM_DLL const char* TVMFFIWasmGetLastError();

// --- APIs to be implemented by the frontend. ---

/*!
* \brief Wasm frontend packed function caller.
* \brief Wasm frontend new ffi call function caller.
*
* \param self The pointer to the ffi::Function.
* \param args The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
* \param ret The return value handle.
* \param resource_handle The handle additional resource handle from front-end.
* \param result The return value handle.
* \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
*/
extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret,
void* resource_handle);

extern int TVMFFIWasmSafeCall(void* self, const TVMFFIAny* args, int32_t num_args,
TVMFFIAny* result);
/*!
* \brief Wasm frontend resource finalizer.
* \param resource_handle The pointer to the external resource.
* \brief Delete ffi::Function.
* \param self The pointer to the ffi::Function.
*/
extern void TVMWasmPackedCFuncFinalizer(void* resource_handle);
extern void TVMFFIWasmFunctionDeleter(void* self);

} // extern "C"

void* TVMWasmAllocSpace(int size) {
Expand All @@ -89,9 +94,14 @@ void* TVMWasmAllocSpace(int size) {

void TVMWasmFreeSpace(void* arr) { delete[] static_cast<int64_t*>(arr); }

int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) {
return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer,
out);
int TVMFFIWasmFunctionCreate(void* self, TVMFunctionHandle* out) {
return TVMFFIFunctionCreate(self, TVMFFIWasmSafeCall, TVMFFIWasmFunctionDeleter, out);
}

const char* TVMFFIWasmGetLastError() {
static thread_local std::string last_error;
last_error = ::tvm::ffi::details::MoveFromSafeCallRaised().what();
return last_error.c_str();
}

namespace tvm {
Expand Down Expand Up @@ -291,7 +301,7 @@ class AsyncLocalSession : public LocalSession {
}
};

TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() {
TVM_FFI_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() {
return CreateRPCSessionModule(std::make_shared<AsyncLocalSession>());
});

Expand Down
55 changes: 24 additions & 31 deletions web/emcc/wasm_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
#define TVM_LOG_DEBUG 0
#define TVM_LOG_CUSTOMIZE 1
#define TVM_FFI_USE_LIBBACKTRACE 0
#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>

#include "src/runtime/c_runtime_api.cc"
Expand Down Expand Up @@ -107,45 +107,24 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s

} // namespace detail

TVM_REGISTER_GLOBAL("testing.echo").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
*ret = args[0];
});

TVM_REGISTER_GLOBAL("testing.call").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
(args[0].cast<ffi::Function>()).CallPacked(args.Slice(1), ret);
});

TVM_REGISTER_GLOBAL("testing.ret_string").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
*ret = args[0].cast<String>();
});

TVM_REGISTER_GLOBAL("testing.log_info_str")
TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.call")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
LOG(INFO) << args[0].cast<String>();
(args[0].cast<ffi::Function>()).CallPacked(args.Slice(1), ret);
});

TVM_REGISTER_GLOBAL("testing.log_fatal_str")
TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.log_info_str")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
LOG(FATAL) << args[0].cast<String>();
LOG(INFO) << args[0].cast<String>();
});

TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; });
TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.add_one").set_body_typed([](int x) { return x + 1; });

TVM_REGISTER_GLOBAL("testing.wrap_callback")
TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.wrap_callback")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
ffi::Function pf = args[0].cast<ffi::Function>();
*ret = ffi::TypedFunction<void()>([pf]() { pf(); });
});

// internal function used for debug and testing purposes
TVM_REGISTER_GLOBAL("testing.object_use_count")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
auto obj = args[0].cast<ffi::ObjectRef>();
// subtract the current one because we always copy
// and get another value.
*ret = (obj.use_count() - 1);
});

void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, std::string dtype) {
if (format == "f32-to-bf16" && dtype == "float32") {
std::vector<uint16_t> buffer(bytes.length() / 2);
Expand All @@ -167,10 +146,10 @@ void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format,
}
}

TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);
TVM_FFI_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);

// Concatenate n TVMArrays
TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat")
TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
std::vector<Any> data;
for (int i = 0; i < args.size(); ++i) {
Expand Down Expand Up @@ -220,7 +199,7 @@ NDArray ConcatEmbeddings(const std::vector<NDArray>& embeddings) {
}

// Concatenate n NDArrays
TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings")
TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
std::vector<NDArray> embeddings;
for (int i = 0; i < args.size(); ++i) {
Expand All @@ -230,5 +209,19 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings")
*ret = result;
});

TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyFromBytes")
.set_body_typed([](NDArray nd, TVMFFIByteArray* bytes) {
nd.CopyFromBytes(bytes->data, bytes->size);
});

TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyToBytes")
.set_body_typed([](NDArray nd) -> ffi::Bytes {
size_t size = GetDataSize(*(nd.operator->()));
std::string bytes;
bytes.resize(size);
nd.CopyToBytes(bytes.data(), size);
return ffi::Bytes(bytes);
});

} // namespace runtime
} // namespace tvm
22 changes: 12 additions & 10 deletions web/emcc/webgpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,11 @@
#define TVM_LOG_STACK_TRACE 0
#define TVM_LOG_DEBUG 0
#define TVM_LOG_CUSTOMIZE 1
#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1
#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>

#include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/ffi/function.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <iostream>
#include <string>
Expand Down Expand Up @@ -152,7 +150,10 @@ typedef dmlc::ThreadLocalStore<WebGPUThreadEntry> WebGPUThreadStore;
WebGPUThreadEntry::WebGPUThreadEntry()
: pool(static_cast<DLDeviceType>(kDLWebGPU), WebGPUDeviceAPI::Global()) {}

WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore::Get(); }
WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() {
static thread_local WebGPUThreadEntry inst = WebGPUThreadEntry();
return &inst;
}

class WebGPUModuleNode final : public runtime::ModuleNode {
public:
Expand Down Expand Up @@ -241,12 +242,13 @@ Module WebGPUModuleLoadBinary(void* strm) {
}

// for now webgpu is hosted via a vulkan module.
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary);
TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary);

TVM_REGISTER_GLOBAL("device_api.webgpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
DeviceAPI* ptr = WebGPUDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
TVM_FFI_REGISTER_GLOBAL("device_api.webgpu")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) {
DeviceAPI* ptr = WebGPUDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

} // namespace runtime
} // namespace tvm
4 changes: 4 additions & 0 deletions web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,9 @@
"typedoc-plugin-missing-exports": "2.0.0",
"typescript": "^4.9.5",
"ws": "^7.2.5"
},
"dependencies": {
"audit": "^0.0.6",
"fix": "^0.0.6"
}
}
9 changes: 9 additions & 0 deletions web/src/asyncify.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ export class AsyncifyHandler {
return this.exports.asyncify_stop_rewind !== undefined;
}

/**
* Get the current asynctify state
*
* @returns The current asynctify state
*/
isNormalStackState(): boolean {
return this.state == AsyncifyStateKind.None;
}

/**
* Get the current asynctify state
*
Expand Down
Loading
Loading