diff --git a/WORKSPACE b/WORKSPACE index 390535bf..8fb4ce21 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -5,7 +5,7 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") git_repository( name = "proxy_wasm_cpp_sdk", - commit = "f44562520bca7bfeee77d6284a96d2900f2f13ac", + commit = "35163bbf32fccfbde7b95d909a392dc1dc562596", remote = "https://github.com/proxy-wasm/proxy-wasm-cpp-sdk", ) diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h index 03a15bec..1a495d00 100644 --- a/include/proxy-wasm/context.h +++ b/include/proxy-wasm/context.h @@ -63,6 +63,46 @@ struct PluginBase { std::string log_prefix_; }; +struct BufferBase : public BufferInterface { + BufferBase() = default; + ~BufferBase() override = default; + + // BufferInterface + size_t size() const override { + if (owned_data_) { + return owned_data_size_; + } + return data_.size(); + } + WasmResult copyTo(WasmBase *wasm, size_t start, size_t length, uint64_t ptr_ptr, + uint64_t size_ptr) const override; + WasmResult copyFrom(size_t /* start */, size_t /* length */, string_view /* data */) override { + // Setting a string buffer not supported (no use case). + return WasmResult::BadArgument; + } + + virtual void clear() { + data_ = ""; + owned_data_ = nullptr; + } + BufferBase *set(string_view data) { + clear(); + data_ = data; + return this; + } + BufferBase *set(std::unique_ptr owned_data, uint32_t owned_data_size) { + clear(); + owned_data_ = std::move(owned_data); + owned_data_size_ = owned_data_size; + return this; + } + +protected: + string_view data_; + std::unique_ptr owned_data_; + uint32_t owned_data_size_; +}; + /** * ContextBase is the interface between the VM host and the VM. It has several uses: * @@ -94,7 +134,7 @@ class ContextBase : public RootInterface, ContextBase(); // Testing. ContextBase(WasmBase *wasm); // Vm Context. ContextBase(WasmBase *wasm, std::shared_ptr plugin); // Root Context. - ContextBase(WasmBase *wasm, uint32_t root_context_id, + ContextBase(WasmBase *wasm, uint32_t parent_context_id, std::shared_ptr plugin); // Stream context. virtual ~ContextBase(); @@ -103,8 +143,17 @@ class ContextBase : public RootInterface, // The VM Context used for calling "malloc" has an id_ == 0. bool isVmContext() const { return id_ == 0; } // Root Contexts have the VM Context as a parent. - bool isRootContext() const { return root_context_id_ == 0; } - ContextBase *root_context() const { return root_context_; } + bool isRootContext() const { return parent_context_id_ == 0; } + ContextBase *parent_context() const { return parent_context_; } + ContextBase *root_context() const { + const ContextBase *previous = this; + ContextBase *parent = parent_context_; + while (parent != previous) { + previous = parent; + parent = parent->parent_context_; + } + return parent; + } string_view root_id() const { return isRootContext() ? root_id_ : plugin_->root_id_; } string_view log_prefix() const { return isRootContext() ? root_log_prefix_ : plugin_->log_prefix(); @@ -121,10 +170,11 @@ class ContextBase : public RootInterface, */ // Context - void onCreate(uint32_t parent_context_id) override; + void onCreate() override; bool onDone() override; void onLog() override; void onDelete() override; + void onForeignFunction(uint32_t foreign_function_id, uint32_t data_size) override; // Root bool onStart(std::shared_ptr plugin) override; @@ -173,10 +223,6 @@ class ContextBase : public RootInterface, WasmResult log(uint32_t /* level */, string_view /* message */) override { return unimplemented(); } - WasmResult setTimerPeriod(std::chrono::milliseconds /* period */, - uint32_t * /* timer_token_ptr */) override { - return unimplemented(); - } uint64_t getCurrentTimeNanoseconds() override { struct timespec tpe; clock_gettime(CLOCK_REALTIME, &tpe); @@ -189,6 +235,7 @@ class ContextBase : public RootInterface, unimplemented(); return std::make_pair(1, "unimplmemented"); } + WasmResult setTimerPeriod(std::chrono::milliseconds period, uint32_t *timer_token_ptr) override; // Buffer BufferInterface *getBuffer(WasmBufferType /* type */) override { @@ -316,15 +363,24 @@ class ContextBase : public RootInterface, WasmBase *wasm_{nullptr}; uint32_t id_{0}; - uint32_t root_context_id_{0}; // 0 for roots and the general context. - ContextBase *root_context_{nullptr}; // set in all contexts. - std::string root_id_; // set only in root context. - std::string root_log_prefix_; // set only in root context. + uint32_t parent_context_id_{0}; // 0 for roots and the general context. + ContextBase *parent_context_{nullptr}; // set in all contexts. + std::string root_id_; // set only in root context. + std::string root_log_prefix_; // set only in root context. std::shared_ptr plugin_; bool in_vm_context_created_ = false; bool destroyed_ = false; }; +class DeferAfterCallActions { +public: + DeferAfterCallActions(ContextBase *context) : wasm_(context->wasm()) {} + ~DeferAfterCallActions(); + +private: + WasmBase *const wasm_; +}; + uint32_t resolveQueueForTest(string_view vm_id, string_view queue_name); } // namespace proxy_wasm diff --git a/include/proxy-wasm/context_interface.h b/include/proxy-wasm/context_interface.h index 0c8f2c06..3a92a230 100644 --- a/include/proxy-wasm/context_interface.h +++ b/include/proxy-wasm/context_interface.h @@ -133,10 +133,9 @@ struct RootInterface : public RootGrpcInterface { /** * Call on a host Context to create a corresponding Context in the VM. Note: * onNetworkNewConnection and onRequestHeaders() call onCreate(). - * @param parent_context_id is the parent Context id for the context being created. For a * stream Context this will be a Root Context id (or sub-Context thereof). */ - virtual void onCreate(uint32_t parent_context_id) = 0; + virtual void onCreate() = 0; /** * Call on a Root Context when a VM first starts up. @@ -564,6 +563,15 @@ struct GeneralInterface { * serialized.. */ virtual WasmResult setProperty(string_view key, string_view value) = 0; + + /** + * Custom extension call into the VM. Data is provided as WasmBufferType::CallData. + * @param foreign_function_id a unique identifier for the calling foreign function. These are + * defined and allocated by the foreign function implementor. + * @param data_size is the size of the WasmBufferType::CallData buffer containing data for this + * foreign function call. + */ + virtual void onForeignFunction(uint32_t foreign_function_id, uint32_t data_size) = 0; }; /** diff --git a/include/proxy-wasm/exports.h b/include/proxy-wasm/exports.h index e9470720..3b8d6796 100644 --- a/include/proxy-wasm/exports.h +++ b/include/proxy-wasm/exports.h @@ -20,6 +20,11 @@ #include "include/proxy-wasm/word.h" namespace proxy_wasm { + +class ContextBase; + +extern thread_local ContextBase *current_context_; + namespace exports { // ABI functions exported from envoy to wasm. @@ -107,5 +112,10 @@ void wasi_unstable_proc_exit(void *, Word); void wasi_unstable_proc_exit(void *, Word); Word pthread_equal(void *, Word left, Word right); +// Support for embedders, not exported to Wasm. + +// Any currently executing Wasm call context. +::proxy_wasm::ContextBase *ContextOrEffectiveContext(::proxy_wasm::ContextBase *context); + } // namespace exports } // namespace proxy_wasm diff --git a/include/proxy-wasm/null_plugin.h b/include/proxy-wasm/null_plugin.h index 1f3ff733..6f12323c 100644 --- a/include/proxy-wasm/null_plugin.h +++ b/include/proxy-wasm/null_plugin.h @@ -20,9 +20,12 @@ #include "google/protobuf/message.h" #include "include/proxy-wasm/null_vm_plugin.h" #include "include/proxy-wasm/wasm.h" +#include "include/proxy-wasm/exports.h" namespace proxy_wasm { namespace null_plugin { +template using Optional = optional; +using StringView = string_view; #include "proxy_wasm_enums.h" } // namespace null_plugin } // namespace proxy_wasm @@ -30,11 +33,6 @@ namespace null_plugin { #include "include/proxy-wasm/wasm_api_impl.h" namespace proxy_wasm { -namespace null_plugin { -using StringView = string_view; -template using Optional = optional; -#include "proxy_wasm_api.h" -} // namespace null_plugin /** * Registry for Plugin implementation. @@ -48,6 +46,8 @@ struct NullPluginRegistry { uint32_t (*proxy_on_configure_)(uint32_t root_context_id, uint32_t plugin_configuration_size) = nullptr; void (*proxy_on_tick_)(uint32_t context_id) = nullptr; + void (*proxy_on_foreign_function_)(uint32_t context_id, uint32_t token, + uint32_t data_size) = nullptr; uint32_t (*proxy_on_done_)(uint32_t context_id) = nullptr; void (*proxy_on_delete_)(uint32_t context_id) = nullptr; std::unordered_map root_factories; @@ -74,6 +74,8 @@ class NullPlugin : public NullVmPlugin { bool onConfigure(uint64_t root_context_id, uint64_t plugin_configuration_size); void onTick(uint64_t root_context_id); void onQueueReady(uint64_t root_context_id, uint64_t token); + void onForeignFunction(uint64_t root_context_id, uint64_t foreign_function_id, + uint64_t data_size); void onCreate(uint64_t context_id, uint64_t root_context_id); @@ -110,12 +112,12 @@ class NullPlugin : public NullVmPlugin { void error(string_view message) { wasm_vm_->error(message); } -private: null_plugin::Context *ensureContext(uint64_t context_id, uint64_t root_context_id); null_plugin::RootContext *ensureRootContext(uint64_t context_id); null_plugin::RootContext *getRootContext(uint64_t context_id); null_plugin::ContextBase *getContextBase(uint64_t context_id); +private: NullPluginRegistry *registry_{}; std::unordered_map root_context_map_; std::unordered_map> context_map_; diff --git a/include/proxy-wasm/wasm.h b/include/proxy-wasm/wasm.h index 0bede843..37dfebfa 100644 --- a/include/proxy-wasm/wasm.h +++ b/include/proxy-wasm/wasm.h @@ -76,6 +76,17 @@ class WasmBase : public std::enable_shared_from_this { const std::string &vm_configuration() const; bool allow_precompiled() const { return allow_precompiled_; } + void timerReady(uint32_t root_context_id); + void queueReady(uint32_t root_context_id, uint32_t token); + + void startShutdown(); + WasmResult done(ContextBase *root_context); + void finishShutdown(); + + // Proxy specific extension points. + // + virtual void registerCallbacks(); // Register functions called out from Wasm. + virtual void getFunctions(); // Get functions call into Wasm. virtual CallOnThreadFunction callOnThreadFunction() { unimplemented(); return nullptr; @@ -86,18 +97,17 @@ class WasmBase : public std::enable_shared_from_this { return new ContextBase(this, plugin); return new ContextBase(this); } - - virtual void setTickPeriod(uint32_t root_context_id, std::chrono::milliseconds tick_period) { - tick_period_[root_context_id] = tick_period; + virtual void setTimerPeriod(uint32_t root_context_id, std::chrono::milliseconds period) { + timer_period_[root_context_id] = period; } - void tick(uint32_t root_context_id); - void queueReady(uint32_t root_context_id, uint32_t token); - - void startShutdown(); - WasmResult done(ContextBase *root_context); - void finishShutdown(); + virtual void error(string_view message) { + std::cerr << message << "\n"; + abort(); + } + virtual void unimplemented() { error("unimplemented proxy-wasm API"); } // Support functions. + // void *allocMemory(uint64_t size, uint64_t *address); // Allocate a null-terminated string in the VM and return the pointer to use as a call arguments. uint64_t copyString(string_view s); @@ -107,14 +117,9 @@ class WasmBase : public std::enable_shared_from_this { WasmForeignFunction getForeignFunction(string_view function_name); - virtual void error(string_view message) { - std::cerr << message << "\n"; - abort(); - } - virtual void unimplemented() { error("unimplemented proxy-wasm API"); } - // For testing. - void setContext(ContextBase *context) { contexts_[context->id()] = context; } + // + void setContextForTesting(ContextBase *context) { contexts_[context->id()] = context; } // Returns false if onStart returns false. bool startForTesting(std::unique_ptr root_context, std::shared_ptr plugin); @@ -146,8 +151,6 @@ class WasmBase : public std::enable_shared_from_this { } } - // These are the same as the values of the MetricType enum, here separately for - // convenience. static const uint32_t kMetricTypeMask = 0x3; // Enough to cover the 3 types. static const uint32_t kMetricIdIncrement = 0x4; // Enough to cover the 3 types. bool isCounterMetricId(uint32_t metric_id) { @@ -166,9 +169,8 @@ class WasmBase : public std::enable_shared_from_this { protected: friend class ContextBase; class ShutdownHandle; - void registerCallbacks(); // Register functions called out from WASM. + void establishEnvironment(); // Language specific environments. - void getFunctions(); // Get functions call into WASM. std::string vm_id_; // User-provided vm_id. std::string vm_key_; // vm_id + hash of code. @@ -179,8 +181,8 @@ class WasmBase : public std::enable_shared_from_this { std::shared_ptr vm_context_; // Context unrelated to any specific root or stream // (e.g. for global constructors). std::unordered_map> root_contexts_; - std::unordered_map contexts_; // Contains all contexts. - std::unordered_map tick_period_; // per root_id. + std::unordered_map contexts_; // Contains all contexts. + std::unordered_map timer_period_; // per root_id. std::unique_ptr shutdown_handle_; std::unordered_set pending_done_; // Root contexts not done during shutdown. @@ -224,6 +226,7 @@ class WasmBase : public std::enable_shared_from_this { WasmCallVoid<3> on_grpc_receive_trailing_metadata_; WasmCallVoid<2> on_queue_ready_; + WasmCallVoid<3> on_foreign_function_; WasmCallWord<1> on_done_; WasmCallVoid<1> on_log_; diff --git a/include/proxy-wasm/wasm_api_impl.h b/include/proxy-wasm/wasm_api_impl.h index d0f4c97e..ab148609 100644 --- a/include/proxy-wasm/wasm_api_impl.h +++ b/include/proxy-wasm/wasm_api_impl.h @@ -18,14 +18,6 @@ #include "include/proxy-wasm/compat.h" namespace proxy_wasm { -namespace null_plugin { -class RootContext; -class Context; -} // namespace null_plugin - -null_plugin::RootContext *nullVmGetRoot(string_view root_id); -null_plugin::Context *nullVmGetContext(uint32_t context_id); - namespace null_plugin { #define WS(_x) Word(static_cast(_x)) @@ -264,8 +256,10 @@ inline WasmResult proxy_call_foreign_function(const char *function_name, size_t #undef WS #undef WR -inline RootContext *getRoot(string_view root_id) { return nullVmGetRoot(root_id); } -inline Context *getContext(uint32_t context_id) { return nullVmGetContext(context_id); } +#include "proxy_wasm_api.h" + +RootContext *getRoot(string_view root_id); +Context *getContext(uint32_t context_id); } // namespace null_plugin } // namespace proxy_wasm diff --git a/include/proxy-wasm/wasm_vm.h b/include/proxy-wasm/wasm_vm.h index 38158092..75913206 100644 --- a/include/proxy-wasm/wasm_vm.h +++ b/include/proxy-wasm/wasm_vm.h @@ -102,11 +102,26 @@ enum class Cloneable { InstantiatedModule // VMs can be cloned from an instantiated module. }; +class NullPlugin; + // Integrator specific WasmVm operations. struct WasmVmIntegration { virtual ~WasmVmIntegration() {} virtual WasmVmIntegration *clone() = 0; virtual void error(string_view message) = 0; + // Get a NullVm implementation of a function. + // @param function_name is the name of the function with the implementation specific prefix. + // @param returns_word is true if the function returns a Word and false if it returns void. + // @param number_of_arguments is the number of Word arguments to the function. + // @param plugin is the Null VM plugin on which the function will be called. + // @param ptr_to_function_return is the location to write the function e.g. of type + // WasmCallWord<3>. + // @return true if the function was found. ptr_to_function_return could still be set to nullptr + // (of the correct type) if the function has no implementation. Returning true will prevent a + // "Missing getFunction" error. + virtual bool getNullVmFunction(string_view function_name, bool returns_word, + int number_of_arguments, NullPlugin *plugin, + void *ptr_to_function_return) = 0; }; // Wasm VM instance. Provides the low level WASM interface. diff --git a/include/proxy-wasm/word.h b/include/proxy-wasm/word.h index 1e4ce673..e96fdfb9 100644 --- a/include/proxy-wasm/word.h +++ b/include/proxy-wasm/word.h @@ -17,10 +17,10 @@ #include -#include "proxy_wasm_common.h" - namespace proxy_wasm { +#include "proxy_wasm_common.h" + // Represents a Wasm-native word-sized datum. On 32-bit VMs, the high bits are always zero. // The Wasm/VM API treats all bits as significant. struct Word { diff --git a/src/context.cc b/src/context.cc index df268866..2814188f 100644 --- a/src/context.cc +++ b/src/context.cc @@ -27,15 +27,6 @@ namespace proxy_wasm { namespace { -class DeferAfterCallActions { -public: - DeferAfterCallActions(ContextBase *context) : wasm_(context->wasm()) {} - ~DeferAfterCallActions() { wasm_->doAfterVmCallActions(); } - -private: - WasmBase *const wasm_; -}; - using CallOnThreadFunction = std::function)>; class SharedData { @@ -183,6 +174,24 @@ SharedData global_shared_data; } // namespace +DeferAfterCallActions::~DeferAfterCallActions() { wasm_->doAfterVmCallActions(); } + +WasmResult BufferBase::copyTo(WasmBase *wasm, size_t start, size_t length, uint64_t ptr_ptr, + uint64_t size_ptr) const { + if (owned_data_) { + string_view s(owned_data_.get() + start, length); + if (!wasm->copyToPointerSize(s, ptr_ptr, size_ptr)) { + return WasmResult::InvalidMemoryAccess; + } + return WasmResult::Ok; + } + string_view s = data_.substr(start, length); + if (!wasm->copyToPointerSize(s, ptr_ptr, size_ptr)) { + return WasmResult::InvalidMemoryAccess; + } + return WasmResult::Ok; +} + // Test support. uint32_t resolveQueueForTest(string_view vm_id, string_view queue_name) { @@ -203,9 +212,9 @@ std::string PluginBase::makeLogPrefix() const { return prefix; } -ContextBase::ContextBase() : root_context_(this) {} +ContextBase::ContextBase() : parent_context_(this) {} -ContextBase::ContextBase(WasmBase *wasm) : wasm_(wasm), root_context_(this) { +ContextBase::ContextBase(WasmBase *wasm) : wasm_(wasm), parent_context_(this) { wasm_->contexts_[id_] = this; } @@ -213,11 +222,12 @@ ContextBase::ContextBase(WasmBase *wasm, std::shared_ptr plugin) { initializeRootBase(wasm, plugin); } -ContextBase::ContextBase(WasmBase *wasm, uint32_t root_context_id, +ContextBase::ContextBase(WasmBase *wasm, uint32_t parent_context_id, std::shared_ptr plugin) - : wasm_(wasm), id_(wasm->allocContextId()), root_context_id_(root_context_id), plugin_(plugin) { + : wasm_(wasm), id_(wasm->allocContextId()), parent_context_id_(parent_context_id), + plugin_(plugin) { wasm_->contexts_[id_] = this; - root_context_ = wasm_->contexts_[root_context_id_]; + parent_context_ = wasm_->contexts_[parent_context_id_]; } WasmVm *ContextBase::wasmVm() const { return wasm_->wasm_vm(); } @@ -227,7 +237,7 @@ void ContextBase::initializeRootBase(WasmBase *wasm, std::shared_ptr id_ = wasm->allocContextId(); root_id_ = plugin->root_id_; root_log_prefix_ = makeRootLogPrefix(plugin->vm_id_); - root_context_ = this; + parent_context_ = this; wasm_->contexts_[id_] = this; } @@ -277,10 +287,10 @@ bool ContextBase::onConfigure(std::shared_ptr plugin) { return result; } -void ContextBase::onCreate(uint32_t parent_context_id) { +void ContextBase::onCreate() { if (!in_vm_context_created_ && wasm_->on_context_create_) { DeferAfterCallActions actions(this); - wasm_->on_context_create_(this, id_, parent_context_id); + wasm_->on_context_create_(this, id_, parent_context_ ? parent_context()->id() : 0); in_vm_context_created_ = true; } // NB: If no on_context_create function is registered the in-VM SDK is responsible for @@ -304,7 +314,7 @@ WasmResult ContextBase::registerSharedQueue(string_view queue_name, // Get the id of the root context if this is a stream context because onQueueReady is on the // root. *result = global_shared_data.registerQueue(wasm_->vm_id(), queue_name, - isRootContext() ? id_ : root_context_id_, + isRootContext() ? id_ : parent_context_id_, wasm_->callOnThreadFunction()); return WasmResult::Ok; } @@ -341,6 +351,13 @@ void ContextBase::onTick(uint32_t) { } } +void ContextBase::onForeignFunction(uint32_t foreign_function_id, uint32_t data_size) { + if (wasm_->on_foreign_function_) { + DeferAfterCallActions actions(this); + wasm_->on_foreign_function_(this, id_, foreign_function_id, data_size); + } +} + FilterStatus ContextBase::onNetworkNewConnection() { if (!wasm_->on_new_connection_) { return FilterStatus::Continue; @@ -557,9 +574,16 @@ void ContextBase::onDelete() { } } +WasmResult ContextBase::setTimerPeriod(std::chrono::milliseconds period, + uint32_t *timer_token_ptr) { + wasm()->setTimerPeriod(root_context()->id(), period); + *timer_token_ptr = 0; + return WasmResult::Ok; +} + ContextBase::~ContextBase() { // Do not remove vm or root contexts which have the same lifetime as wasm_. - if (root_context_id_) { + if (parent_context_id_) { wasm_->contexts_.erase(id_); } } diff --git a/src/exports.cc b/src/exports.cc index f6b084b3..95eeddf2 100644 --- a/src/exports.cc +++ b/src/exports.cc @@ -15,20 +15,17 @@ // #include "include/proxy-wasm/wasm.h" +#define WASM_CONTEXT(_c) \ + (ContextOrEffectiveContext(static_cast((void)_c, current_context_))) + namespace proxy_wasm { +// The id of the context which should be used for calls out of the VM in place +// of current_context_. extern thread_local uint32_t effective_context_id_; namespace exports { -// Any currently executing Wasm call context. -#define WASM_CONTEXT(_c) \ - (ContextOrEffectiveContext(static_cast((void)_c, current_context_))) -// The id of the context which should be used for calls out of the VM in place -// of current_context_ above. - -namespace { - ContextBase *ContextOrEffectiveContext(ContextBase *context) { if (effective_context_id_ == 0) { return context; @@ -41,6 +38,8 @@ ContextBase *ContextOrEffectiveContext(ContextBase *context) { return context; } +namespace { + Pairs toPairs(string_view buffer) { Pairs result; const char *b = buffer.data(); @@ -811,10 +810,10 @@ void wasi_unstable_proc_exit(void *raw_context, Word) { Word pthread_equal(void *, Word left, Word right) { return left == right; } -Word set_tick_period_milliseconds(void *raw_context, Word tick_period_milliseconds) { +Word set_tick_period_milliseconds(void *raw_context, Word period_milliseconds) { TimerToken token = 0; return WASM_CONTEXT(raw_context) - ->setTimerPeriod(std::chrono::milliseconds(tick_period_milliseconds), &token); + ->setTimerPeriod(std::chrono::milliseconds(period_milliseconds), &token); } Word get_current_time_nanoseconds(void *raw_context, Word result_uint64_ptr) { diff --git a/src/null/null_plugin.cc b/src/null/null_plugin.cc index 871727c9..d96c62ef 100644 --- a/src/null/null_plugin.cc +++ b/src/null/null_plugin.cc @@ -13,6 +13,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "include/proxy-wasm/null_plugin.h" + +#include #include #include @@ -36,7 +39,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<0> *f) { *f = nullptr; } else if (function_name == "__wasm_call_ctors") { *f = nullptr; - } else { + } else if (!wasm_vm_->integration()->getNullVmFunction(function_name, false, 0, this, f)) { error("Missing getFunction for: " + std::string(function_name)); *f = nullptr; } @@ -59,7 +62,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<1> *f) { SaveRestoreContext saved_context(context); plugin->onDelete(context_id); }; - } else { + } else if (!wasm_vm_->integration()->getNullVmFunction(function_name, false, 1, this, f)) { error("Missing getFunction for: " + std::string(function_name)); *f = nullptr; } @@ -87,7 +90,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<2> *f) { SaveRestoreContext saved_context(context); plugin->onQueueReady(context_id, token); }; - } else { + } else if (!wasm_vm_->integration()->getNullVmFunction(function_name, false, 2, this, f)) { error("Missing getFunction for: " + std::string(function_name)); *f = nullptr; } @@ -115,7 +118,12 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<3> *f) { SaveRestoreContext saved_context(context); plugin->onGrpcReceiveTrailingMetadata(context_id, token, trailers); }; - } else { + } else if (function_name == "proxy_on_foreign_function") { + *f = [plugin](ContextBase *context, Word context_id, Word foreign_function_id, Word data_size) { + SaveRestoreContext saved_context(context); + plugin->onForeignFunction(context_id, foreign_function_id, data_size); + }; + } else if (!wasm_vm_->integration()->getNullVmFunction(function_name, false, 3, this, f)) { error("Missing getFunction for: " + std::string(function_name)); *f = nullptr; } @@ -129,7 +137,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<5> *f) { SaveRestoreContext saved_context(context); plugin->onHttpCallResponse(context_id, token, headers, body_size, trailers); }; - } else { + } else if (!wasm_vm_->integration()->getNullVmFunction(function_name, false, 5, this, f)) { error("Missing getFunction for: " + std::string(function_name)); *f = nullptr; } @@ -151,7 +159,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<1> *f) { SaveRestoreContext saved_context(context); return Word(plugin->onDone(context_id)); }; - } else { + } else if (!wasm_vm_->integration()->getNullVmFunction(function_name, true, 1, this, f)) { error("Missing getFunction for: " + std::string(function_name)); *f = nullptr; } @@ -194,7 +202,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<2> *f) { SaveRestoreContext saved_context(context); return Word(plugin->onResponseMetadata(context_id, elements)); }; - } else { + } else if (!wasm_vm_->integration()->getNullVmFunction(function_name, true, 2, this, f)) { error("Missing getFunction for: " + std::string(function_name)); *f = nullptr; } @@ -236,7 +244,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<3> *f) { SaveRestoreContext saved_context(context); return Word(plugin->onResponseBody(context_id, body_buffer_length, end_of_stream)); }; - } else { + } else if (!wasm_vm_->integration()->getNullVmFunction(function_name, true, 3, this, f)) { error("Missing getFunction for: " + std::string(function_name)); *f = nullptr; } @@ -452,6 +460,14 @@ void NullPlugin::onQueueReady(uint64_t context_id, uint64_t token) { getRootContext(context_id)->onQueueReady(token); } +void NullPlugin::onForeignFunction(uint64_t context_id, uint64_t foreign_function_id, + uint64_t data_size) { + if (registry_->proxy_on_foreign_function_) { + return registry_->proxy_on_foreign_function_(context_id, foreign_function_id, data_size); + } + getContextBase(context_id)->onForeignFunction(foreign_function_id, data_size); +} + void NullPlugin::onLog(uint64_t context_id) { getContext(context_id)->onLog(); } uint64_t NullPlugin::onDone(uint64_t context_id) { @@ -480,4 +496,8 @@ null_plugin::Context *nullVmGetContext(uint32_t context_id) { return static_cast(null_vm->plugin_.get())->getContext(context_id); } +null_plugin::RootContext *getRoot(string_view root_id) { return nullVmGetRoot(root_id); } + +null_plugin::Context *getContext(uint32_t context_id) { return nullVmGetContext(context_id); } + } // namespace proxy_wasm diff --git a/src/wasm.cc b/src/wasm.cc index 731bc434..86f73691 100644 --- a/src/wasm.cc +++ b/src/wasm.cc @@ -205,6 +205,7 @@ void WasmBase::getFunctions() { _GET_PROXY(on_vm_start); _GET_PROXY(on_configure); _GET_PROXY(on_tick); + _GET_PROXY(on_foreign_function); _GET_PROXY(on_context_create); @@ -378,7 +379,7 @@ uint32_t WasmBase::allocContextId() { } } -void WasmBase::tick(uint32_t root_context_id) { +void WasmBase::timerReady(uint32_t root_context_id) { if (on_tick_) { auto it = contexts_.find(root_context_id); if (it == contexts_.end() || !it->second->isRootContext()) {