diff --git a/source/extensions/common/wasm/wasm.cc b/source/extensions/common/wasm/wasm.cc index c45cd7d6f215..aa0a475390c6 100644 --- a/source/extensions/common/wasm/wasm.cc +++ b/source/extensions/common/wasm/wasm.cc @@ -953,7 +953,7 @@ Wasm::Wasm(absl::string_view vm, absl::string_view id, absl::string_view initial } void Wasm::registerCallbacks() { -#define _REGISTER(_fn) registerCallback(wasm_vm_.get(), "envoy", #_fn, &_fn##Handler); +#define _REGISTER(_fn) wasm_vm_->registerCallback("envoy", #_fn, &_fn##Handler); if (is_emscripten_) { _REGISTER(getTotalMemory); _REGISTER(_emscripten_get_heap_size); @@ -962,8 +962,7 @@ void Wasm::registerCallbacks() { #undef _REGISTER // Calls with the "_proxy_" prefix. -#define _REGISTER_PROXY(_fn) \ - registerCallback(wasm_vm_.get(), "envoy", "_proxy_" #_fn, &_fn##Handler); +#define _REGISTER_PROXY(_fn) wasm_vm_->registerCallback("envoy", "_proxy_" #_fn, &_fn##Handler); _REGISTER_PROXY(log); _REGISTER_PROXY(getRequestStreamInfoProtocol); @@ -1018,19 +1017,19 @@ void Wasm::registerCallbacks() { void Wasm::establishEnvironment() { if (is_emscripten_) { wasm_vm_->makeModule("global"); - emscripten_NaN_ = makeGlobal(wasm_vm_.get(), "global", "NaN", std::nan("0")); + emscripten_NaN_ = wasm_vm_->makeGlobal("global", "NaN", std::nan("0")); emscripten_Infinity_ = - makeGlobal(wasm_vm_.get(), "global", "Infinity", std::numeric_limits::infinity()); + wasm_vm_->makeGlobal("global", "Infinity", std::numeric_limits::infinity()); } } void Wasm::getFunctions() { -#define _GET(_fn) getFunction(wasm_vm_.get(), "_" #_fn, &_fn##_); +#define _GET(_fn) wasm_vm_->getFunction("_" #_fn, &_fn##_); _GET(malloc); _GET(free); #undef _GET -#define _GET_PROXY(_fn) getFunction(wasm_vm_.get(), "_proxy_" #_fn, &_fn##_); +#define _GET_PROXY(_fn) wasm_vm_->getFunction("_proxy_" #_fn, &_fn##_); _GET_PROXY(onStart); _GET_PROXY(onConfigure); _GET_PROXY(onTick); diff --git a/source/extensions/common/wasm/wasm.h b/source/extensions/common/wasm/wasm.h index 1bba65d0ba84..ed01d79c90fd 100644 --- a/source/extensions/common/wasm/wasm.h +++ b/source/extensions/common/wasm/wasm.h @@ -31,18 +31,33 @@ class WasmVm; using Pairs = std::vector>; using PairsWithStringValues = std::vector>; +// 1st arg is always a pointer to Context (Context*). using WasmCall0Void = std::function; using WasmCall1Void = std::function; -using WasmCall1Int = std::function; using WasmCall2Void = std::function; - -using WasmContextCall0Void = std::function; -using WasmContextCall7Void = std::function; - -using WasmContextCall0Int = std::function; -using WasmContextCall2Int = - std::function; +using WasmCall8Void = std::function; +using WasmCall1Int = std::function; +using WasmCall3Int = std::function; + +// 1st arg is always a context_id (uint32_t). +using WasmContextCall0Void = WasmCall1Void; +using WasmContextCall7Void = WasmCall8Void; +using WasmContextCall0Int = WasmCall1Int; +using WasmContextCall2Int = WasmCall3Int; + +// 1st arg is always a pointer to raw_context (void*). +using WasmCallback0Void = void (*)(void*); +using WasmCallback1Void = void (*)(void*, uint32_t); +using WasmCallback2Void = void (*)(void*, uint32_t, uint32_t); +using WasmCallback3Void = void (*)(void*, uint32_t, uint32_t, uint32_t); +using WasmCallback4Void = void (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t); +using WasmCallback5Void = void (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t); +using WasmCallback0Int = uint32_t (*)(void*); +using WasmCallback3Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t); +using WasmCallback5Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t); +using WasmCallback9Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t); // A context which will be the target of callbacks for a particular session // e.g. a handler of a stream. @@ -435,6 +450,40 @@ class WasmVm : public Logger::Loggable { // Get the contents of the user section with the given name or "" if it does not exist and // optionally a presence indicator. virtual absl::string_view getUserSection(absl::string_view name, bool* present = nullptr) PURE; + + // Get typed function exported by the WASM module. + virtual void getFunction(absl::string_view functionName, WasmCall0Void* f) PURE; + virtual void getFunction(absl::string_view functionName, WasmCall1Void* f) PURE; + virtual void getFunction(absl::string_view functionName, WasmCall2Void* f) PURE; + virtual void getFunction(absl::string_view functionName, WasmCall8Void* f) PURE; + virtual void getFunction(absl::string_view functionName, WasmCall1Int* f) PURE; + virtual void getFunction(absl::string_view functionName, WasmCall3Int* f) PURE; + + // Register typed callbacks exported by the host environment. + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback0Void f) PURE; + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback1Void f) PURE; + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback2Void f) PURE; + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback3Void f) PURE; + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback4Void f) PURE; + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback5Void f) PURE; + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback0Int f) PURE; + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback3Int f) PURE; + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback5Int f) PURE; + virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName, + WasmCallback9Int f) PURE; + + // Register typed value exported by the host environment. + virtual std::unique_ptr> + makeGlobal(absl::string_view moduleName, absl::string_view name, double initialValue) PURE; }; // Create a new low-level WASM VM of the give type (e.g. "envoy.wasm.vm.wavm"). @@ -466,46 +515,6 @@ class WasmVmException : public EnvoyException { inline Context::Context(Wasm* wasm) : wasm_(wasm), id_(wasm->allocContextId()) {} -// Forward declarations for VM implemenations. -template -void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, - R (*)(Args...)); -template -void getFunctionWavm(WasmVm* vm, absl::string_view functionName, - std::function*); - -template -std::unique_ptr> makeGlobalWavm(WasmVm* vm, absl::string_view moduleName, - absl::string_view name, T initialValue); - -template -void registerCallback(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, - R (*f)(Args...)) { - if (vm->vm() == WasmVmNames::get().Wavm) { - registerCallbackWavm(vm, moduleName, functionName, f); - } else { - throw WasmVmException("unsupported wasm vm"); - } -} - -template void getFunction(WasmVm* vm, absl::string_view functionName, F* function) { - if (vm->vm() == WasmVmNames::get().Wavm) { - getFunctionWavm(vm, functionName, function); - } else { - throw WasmVmException("unsupported wasm vm"); - } -} - -template -std::unique_ptr> makeGlobal(WasmVm* vm, absl::string_view moduleName, - absl::string_view name, T initialValue) { - if (vm->vm() == WasmVmNames::get().Wavm) { - return makeGlobalWavm(vm, moduleName, name, initialValue); - } else { - throw WasmVmException("unsupported wasm vm"); - } -} - inline void* Wasm::allocMemory(uint32_t size, uint32_t* address) { uint32_t a = malloc_(generalContext(), size); *address = a; diff --git a/source/extensions/common/wasm/wavm/wavm.cc b/source/extensions/common/wasm/wavm/wavm.cc index 703cac28aeb5..6a733ce1113a 100644 --- a/source/extensions/common/wasm/wavm/wavm.cc +++ b/source/extensions/common/wasm/wavm/wavm.cc @@ -54,6 +54,17 @@ namespace Wasm { extern thread_local Envoy::Extensions::Common::Wasm::Context* current_context_; +// Forward declarations. +template +void getFunctionWavm(WasmVm* vm, absl::string_view functionName, + std::function* function); +template +void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, + R (*)(Args...)); +template +std::unique_ptr> makeGlobalWavm(WasmVm* vm, absl::string_view moduleName, + absl::string_view name, T initialValue); + namespace Wavm { struct Wavm; @@ -221,6 +232,40 @@ struct Wavm : public WasmVm { void getInstantiatedGlobals(); +#define _GET_FUNCTION(_type) \ + void getFunction(absl::string_view functionName, _type* f) override { \ + getFunctionWavm(this, functionName, f); \ + }; + _GET_FUNCTION(WasmCall0Void); + _GET_FUNCTION(WasmCall1Void); + _GET_FUNCTION(WasmCall2Void); + _GET_FUNCTION(WasmCall8Void); + _GET_FUNCTION(WasmCall1Int); + _GET_FUNCTION(WasmCall3Int); +#undef _GET_FUNCTION + +#define _REGISTER_CALLBACK(_type) \ + void registerCallback(absl::string_view moduleName, absl::string_view functionName, \ + _type f) override { \ + registerCallbackWavm(this, moduleName, functionName, f); \ + }; + _REGISTER_CALLBACK(WasmCallback0Void); + _REGISTER_CALLBACK(WasmCallback1Void); + _REGISTER_CALLBACK(WasmCallback2Void); + _REGISTER_CALLBACK(WasmCallback3Void); + _REGISTER_CALLBACK(WasmCallback4Void); + _REGISTER_CALLBACK(WasmCallback5Void); + _REGISTER_CALLBACK(WasmCallback0Int); + _REGISTER_CALLBACK(WasmCallback3Int); + _REGISTER_CALLBACK(WasmCallback5Int); + _REGISTER_CALLBACK(WasmCallback9Int); +#undef _REGISTER_CALLBACK + + std::unique_ptr> makeGlobal(absl::string_view moduleName, absl::string_view name, + double initialValue) override { + return makeGlobalWavm(this, moduleName, name, initialValue); + }; + bool hasInstantiatedModule_ = false; IR::Module irModule_; WAVM::Runtime::ModuleRef module_ = nullptr;