Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wasm: move implementation details into WasmVm abstraction layer. #47

Merged
merged 4 commits into from
Mar 22, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions source/extensions/common/wasm/wasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ Wasm::Wasm(absl::string_view vm, absl::string_view id, absl::string_view initial
}

void Wasm::registerCallbacks() {
#define _REGISTER(_fn) registerCallback(wasm_vm_.get(), "envoy", #_fn, &_fn##Handler);
#define _REGISTER(_fn) wasm_vm_->registerCallback("envoy", #_fn, &_fn##Handler);
if (is_emscripten_) {
_REGISTER(getTotalMemory);
_REGISTER(_emscripten_get_heap_size);
Expand All @@ -962,8 +962,7 @@ void Wasm::registerCallbacks() {
#undef _REGISTER

// Calls with the "_proxy_" prefix.
#define _REGISTER_PROXY(_fn) \
registerCallback(wasm_vm_.get(), "envoy", "_proxy_" #_fn, &_fn##Handler);
#define _REGISTER_PROXY(_fn) wasm_vm_->registerCallback("envoy", "_proxy_" #_fn, &_fn##Handler);
_REGISTER_PROXY(log);

_REGISTER_PROXY(getRequestStreamInfoProtocol);
Expand Down Expand Up @@ -1018,19 +1017,19 @@ void Wasm::registerCallbacks() {
void Wasm::establishEnvironment() {
if (is_emscripten_) {
wasm_vm_->makeModule("global");
emscripten_NaN_ = makeGlobal(wasm_vm_.get(), "global", "NaN", std::nan("0"));
emscripten_NaN_ = wasm_vm_->makeGlobal("global", "NaN", std::nan("0"));
emscripten_Infinity_ =
makeGlobal(wasm_vm_.get(), "global", "Infinity", std::numeric_limits<double>::infinity());
wasm_vm_->makeGlobal("global", "Infinity", std::numeric_limits<double>::infinity());
}
}

void Wasm::getFunctions() {
#define _GET(_fn) getFunction(wasm_vm_.get(), "_" #_fn, &_fn##_);
#define _GET(_fn) wasm_vm_->getFunction("_" #_fn, &_fn##_);
_GET(malloc);
_GET(free);
#undef _GET

#define _GET_PROXY(_fn) getFunction(wasm_vm_.get(), "_proxy_" #_fn, &_fn##_);
#define _GET_PROXY(_fn) wasm_vm_->getFunction("_proxy_" #_fn, &_fn##_);
_GET_PROXY(onStart);
_GET_PROXY(onConfigure);
_GET_PROXY(onTick);
Expand Down
107 changes: 58 additions & 49 deletions source/extensions/common/wasm/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,33 @@ class WasmVm;
using Pairs = std::vector<std::pair<absl::string_view, absl::string_view>>;
using PairsWithStringValues = std::vector<std::pair<absl::string_view, std::string>>;

// 1st arg is always a pointer to Context (Context*).
using WasmCall0Void = std::function<void(Context*)>;
using WasmCall1Void = std::function<void(Context*, uint32_t)>;
using WasmCall1Int = std::function<uint32_t(Context*, uint32_t)>;
using WasmCall2Void = std::function<void(Context*, uint32_t, uint32_t)>;

using WasmContextCall0Void = std::function<void(Context*, uint32_t context_id)>;
using WasmContextCall7Void = std::function<void(Context*, uint32_t context_id, uint32_t, uint32_t,
uint32_t, uint32_t, uint32_t, uint32_t, uint32_t)>;

using WasmContextCall0Int = std::function<uint32_t(Context*, uint32_t context_id)>;
using WasmContextCall2Int =
std::function<uint32_t(Context*, uint32_t context_id, uint32_t, uint32_t)>;
using WasmCall8Void = std::function<void(Context*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t,
uint32_t, uint32_t, uint32_t)>;
using WasmCall1Int = std::function<uint32_t(Context*, uint32_t)>;
using WasmCall3Int = std::function<uint32_t(Context*, uint32_t, uint32_t, uint32_t)>;

// 1st arg is always a context_id (uint32_t).
using WasmContextCall0Void = WasmCall1Void;
using WasmContextCall7Void = WasmCall8Void;
using WasmContextCall0Int = WasmCall1Int;
using WasmContextCall2Int = WasmCall3Int;

// 1st arg is always a pointer to raw_context (void*).
using WasmCallback0Void = void (*)(void*);
using WasmCallback1Void = void (*)(void*, uint32_t);
using WasmCallback2Void = void (*)(void*, uint32_t, uint32_t);
using WasmCallback3Void = void (*)(void*, uint32_t, uint32_t, uint32_t);
using WasmCallback4Void = void (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t);
using WasmCallback5Void = void (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t);
using WasmCallback0Int = uint32_t (*)(void*);
using WasmCallback3Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t);
using WasmCallback5Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t);
using WasmCallback9Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t,
uint32_t, uint32_t, uint32_t, uint32_t);

// A context which will be the target of callbacks for a particular session
// e.g. a handler of a stream.
Expand Down Expand Up @@ -435,6 +450,40 @@ class WasmVm : public Logger::Loggable<Logger::Id::wasm> {
// Get the contents of the user section with the given name or "" if it does not exist and
// optionally a presence indicator.
virtual absl::string_view getUserSection(absl::string_view name, bool* present = nullptr) PURE;

// Get typed function exported by the WASM module.
virtual void getFunction(absl::string_view functionName, WasmCall0Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall1Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall2Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall8Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall1Int* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall3Int* f) PURE;

// Register typed callbacks exported by the host environment.
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback0Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback1Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback2Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback3Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback4Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback5Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback0Int f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback3Int f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback5Int f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback9Int f) PURE;

// Register typed value exported by the host environment.
virtual std::unique_ptr<Global<double>>
makeGlobal(absl::string_view moduleName, absl::string_view name, double initialValue) PURE;
};

// Create a new low-level WASM VM of the give type (e.g. "envoy.wasm.vm.wavm").
Expand Down Expand Up @@ -466,46 +515,6 @@ class WasmVmException : public EnvoyException {

inline Context::Context(Wasm* wasm) : wasm_(wasm), id_(wasm->allocContextId()) {}

// Forward declarations for VM implemenations.
template <typename R, typename... Args>
void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*)(Args...));
template <typename R, typename... Args>
void getFunctionWavm(WasmVm* vm, absl::string_view functionName,
std::function<R(Context*, Args...)>*);

template <typename T>
std::unique_ptr<Global<T>> makeGlobalWavm(WasmVm* vm, absl::string_view moduleName,
absl::string_view name, T initialValue);

template <typename R, typename... Args>
void registerCallback(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*f)(Args...)) {
if (vm->vm() == WasmVmNames::get().Wavm) {
registerCallbackWavm(vm, moduleName, functionName, f);
} else {
throw WasmVmException("unsupported wasm vm");
}
}

template <typename F> void getFunction(WasmVm* vm, absl::string_view functionName, F* function) {
if (vm->vm() == WasmVmNames::get().Wavm) {
getFunctionWavm(vm, functionName, function);
} else {
throw WasmVmException("unsupported wasm vm");
}
}

template <typename T>
std::unique_ptr<Global<T>> makeGlobal(WasmVm* vm, absl::string_view moduleName,
absl::string_view name, T initialValue) {
if (vm->vm() == WasmVmNames::get().Wavm) {
return makeGlobalWavm(vm, moduleName, name, initialValue);
} else {
throw WasmVmException("unsupported wasm vm");
}
}

inline void* Wasm::allocMemory(uint32_t size, uint32_t* address) {
uint32_t a = malloc_(generalContext(), size);
*address = a;
Expand Down
76 changes: 76 additions & 0 deletions source/extensions/common/wasm/wavm/wavm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ namespace Wasm {

extern thread_local Envoy::Extensions::Common::Wasm::Context* current_context_;

// Forward declarations.
template <typename R, typename... Args>
void getFunctionWavm(WasmVm* vm, absl::string_view functionName,
std::function<R(Context*, Args...)>* function);
template <typename R, typename... Args>
void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*)(Args...));
template <typename T>
std::unique_ptr<Global<T>> makeGlobalWavm(WasmVm* vm, absl::string_view moduleName,
absl::string_view name, T initialValue);

namespace Wavm {

struct Wavm;
Expand Down Expand Up @@ -221,6 +232,71 @@ struct Wavm : public WasmVm {

void getInstantiatedGlobals();

void getFunction(absl::string_view functionName, WasmCall0Void* f) override {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we could wrap this repetition up in a macro?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

getFunctionWavm(this, functionName, f);
};
void getFunction(absl::string_view functionName, WasmCall1Void* f) override {
getFunctionWavm(this, functionName, f);
};
void getFunction(absl::string_view functionName, WasmCall2Void* f) override {
getFunctionWavm(this, functionName, f);
};
void getFunction(absl::string_view functionName, WasmCall8Void* f) override {
getFunctionWavm(this, functionName, f);
};
void getFunction(absl::string_view functionName, WasmCall1Int* f) override {
getFunctionWavm(this, functionName, f);
};
void getFunction(absl::string_view functionName, WasmCall3Int* f) override {
getFunctionWavm(this, functionName, f);
};

void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback0Void f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};
void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback1Void f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};
void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback2Void f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};
void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback3Void f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};
void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback4Void f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};
void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback5Void f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};
void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback0Int f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};
void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback3Int f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};
void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback5Int f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};
void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback9Int f) override {
registerCallbackWavm(this, moduleName, functionName, f);
};

std::unique_ptr<Global<double>> makeGlobal(absl::string_view moduleName, absl::string_view name,
double initialValue) override {
return makeGlobalWavm(this, moduleName, name, initialValue);
};

bool hasInstantiatedModule_ = false;
IR::Module irModule_;
WAVM::Runtime::ModuleRef module_ = nullptr;
Expand Down