diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index 41a8d2605ac4..284ee1cb0a74 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -50,6 +50,15 @@ namespace Wasm { namespace { +class DeferAfterCallActions { +public: + DeferAfterCallActions(Context* context) : wasm_(context->wasm()) {} + ~DeferAfterCallActions() { wasm_->doAfterVmCallActions(); } + +private: + Wasm* const wasm_; +}; + using HashPolicy = envoy::config::route::v3::RouteAction::HashPolicy; class SharedData { @@ -265,6 +274,8 @@ std::string Context::makeRootLogPrefix(absl::string_view vm_id) const { WasmVm* Context::wasmVm() const { return wasm_->wasm_vm(); } Upstream::ClusterManager& Context::clusterManager() const { return wasm_->clusterManager(); } +void Context::addAfterVmCallAction(std::function f) { wasm_->addAfterVmCallAction(f); } + WasmResult Context::setTickPeriod(std::chrono::milliseconds tick_period) { wasm_->setTickPeriod(root_context_id_ ? root_context_id_ : id_, tick_period); return WasmResult::Ok; @@ -1059,6 +1070,7 @@ bool Context::isSsl() { return decoder_callbacks_->connection()->ssl() != nullpt // Calls into the WASM code. // bool Context::onStart(absl::string_view vm_configuration, PluginSharedPtr plugin) { + DeferAfterCallActions actions(this); bool result = 0; if (wasm_->on_context_create_) { plugin_ = plugin; @@ -1094,6 +1106,7 @@ bool Context::onConfigure(absl::string_view plugin_configuration, PluginSharedPt if (!wasm_->on_configure_) { return true; } + DeferAfterCallActions actions(this); configuration_ = plugin_configuration; plugin_ = plugin; auto result = @@ -1111,17 +1124,20 @@ std::pair Context::getStatus() { void Context::onTick() { if (wasm_->on_tick_) { + DeferAfterCallActions actions(this); wasm_->on_tick_(this, id_); } } void Context::onCreate(uint32_t parent_context_id) { if (wasm_->on_context_create_) { + DeferAfterCallActions actions(this); wasm_->on_context_create_(this, id_, parent_context_id); } } Network::FilterStatus Context::onNetworkNewConnection() { + DeferAfterCallActions actions(this); onCreate(root_context_id_); if (!wasm_->on_new_connection_) { return Network::FilterStatus::Continue; @@ -1136,6 +1152,7 @@ Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_str if (!wasm_->on_downstream_data_) { return Network::FilterStatus::Continue; } + DeferAfterCallActions actions(this); end_of_stream_ = end_of_stream; auto result = wasm_->on_downstream_data_(this, id_, static_cast(data_length), static_cast(end_of_stream)); @@ -1147,6 +1164,7 @@ Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_strea if (!wasm_->on_upstream_data_) { return Network::FilterStatus::Continue; } + DeferAfterCallActions actions(this); end_of_stream_ = end_of_stream; auto result = wasm_->on_upstream_data_(this, id_, static_cast(data_length), static_cast(end_of_stream)); @@ -1156,12 +1174,14 @@ Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_strea void Context::onDownstreamConnectionClose(PeerType peer_type) { if (wasm_->on_downstream_connection_close_) { + DeferAfterCallActions actions(this); wasm_->on_downstream_connection_close_(this, id_, static_cast(peer_type)); } } void Context::onUpstreamConnectionClose(PeerType peer_type) { if (wasm_->on_upstream_connection_close_) { + DeferAfterCallActions actions(this); wasm_->on_upstream_connection_close_(this, id_, static_cast(peer_type)); } } @@ -1170,6 +1190,7 @@ void Context::onUpstreamConnectionClose(PeerType peer_type) { template static uint32_t headerSize(const P& p) { return p ? p->size() : 0; } Http::FilterHeadersStatus Context::onRequestHeaders() { + DeferAfterCallActions actions(this); onCreate(root_context_id_); in_vm_context_created_ = true; if (!wasm_->on_request_headers_) { @@ -1185,6 +1206,7 @@ Http::FilterDataStatus Context::onRequestBody(int body_buffer_length, bool end_o if (!wasm_->on_request_body_) { return Http::FilterDataStatus::Continue; } + DeferAfterCallActions actions(this); switch (wasm_ ->on_request_body_(this, id_, static_cast(body_buffer_length), static_cast(end_of_stream)) @@ -1204,6 +1226,7 @@ Http::FilterTrailersStatus Context::onRequestTrailers() { if (!wasm_->on_request_trailers_) { return Http::FilterTrailersStatus::Continue; } + DeferAfterCallActions actions(this); if (wasm_->on_request_trailers_(this, id_, headerSize(request_trailers_)).u64_ == 0) { return Http::FilterTrailersStatus::Continue; } @@ -1214,6 +1237,7 @@ Http::FilterMetadataStatus Context::onRequestMetadata() { if (!wasm_->on_request_metadata_) { return Http::FilterMetadataStatus::Continue; } + DeferAfterCallActions actions(this); if (wasm_->on_request_metadata_(this, id_, headerSize(request_metadata_)).u64_ == 0) { return Http::FilterMetadataStatus::Continue; } @@ -1221,6 +1245,7 @@ Http::FilterMetadataStatus Context::onRequestMetadata() { } Http::FilterHeadersStatus Context::onResponseHeaders() { + DeferAfterCallActions actions(this); if (!in_vm_context_created_) { // If the request is invalid then onRequestHeaders() will not be called and neither will // onCreate() then sendLocalReply be called which will call this function. In this case we @@ -1242,6 +1267,7 @@ Http::FilterDataStatus Context::onResponseBody(int body_buffer_length, bool end_ if (!wasm_->on_response_body_) { return Http::FilterDataStatus::Continue; } + DeferAfterCallActions actions(this); switch (wasm_ ->on_response_body_(this, id_, static_cast(body_buffer_length), static_cast(end_of_stream)) @@ -1261,6 +1287,7 @@ Http::FilterTrailersStatus Context::onResponseTrailers() { if (!wasm_->on_response_trailers_) { return Http::FilterTrailersStatus::Continue; } + DeferAfterCallActions actions(this); if (wasm_->on_response_trailers_(this, id_, headerSize(response_trailers_)).u64_ == 0) { return Http::FilterTrailersStatus::Continue; } @@ -1271,6 +1298,7 @@ Http::FilterMetadataStatus Context::onResponseMetadata() { if (!wasm_->on_response_metadata_) { return Http::FilterMetadataStatus::Continue; } + DeferAfterCallActions actions(this); if (wasm_->on_response_metadata_(this, id_, headerSize(response_metadata_)).u64_ == 0) { return Http::FilterMetadataStatus::Continue; } @@ -1282,11 +1310,13 @@ void Context::onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body if (!wasm_->on_http_call_response_) { return; } + DeferAfterCallActions actions(this); wasm_->on_http_call_response_(this, id_, token, headers, body_size, trailers); } void Context::onQueueReady(uint32_t token) { if (wasm_->on_queue_ready_) { + DeferAfterCallActions actions(this); wasm_->on_queue_ready_(this, id_, token); } } @@ -1295,6 +1325,7 @@ void Context::onGrpcCreateInitialMetadata(uint32_t token, Http::HeaderMap& metad if (!wasm_->on_grpc_create_initial_metadata_) { return; } + DeferAfterCallActions actions(this); grpc_create_initial_metadata_ = &metadata; wasm_->on_grpc_create_initial_metadata_(this, id_, token, headerSize(grpc_create_initial_metadata_)); @@ -1305,6 +1336,7 @@ void Context::onGrpcReceiveInitialMetadata(uint32_t token, Http::HeaderMapPtr&& if (!wasm_->on_grpc_receive_initial_metadata_) { return; } + DeferAfterCallActions actions(this); grpc_receive_initial_metadata_ = std::move(metadata); wasm_->on_grpc_receive_initial_metadata_(this, id_, token, headerSize(grpc_receive_initial_metadata_)); @@ -1315,6 +1347,7 @@ void Context::onGrpcReceiveTrailingMetadata(uint32_t token, Http::HeaderMapPtr&& if (!wasm_->on_grpc_receive_trailing_metadata_) { return; } + DeferAfterCallActions actions(this); grpc_receive_trailing_metadata_ = std::move(metadata); wasm_->on_grpc_receive_trailing_metadata_(this, id_, token, headerSize(grpc_receive_trailing_metadata_)); @@ -1452,6 +1485,7 @@ Context::~Context() { Network::FilterStatus Context::onNewConnection() { return onNetworkNewConnection(); }; Network::FilterStatus Context::onData(Buffer::Instance& data, bool end_stream) { + DeferAfterCallActions actions(this); network_downstream_data_buffer_ = &data; auto result = onDownstreamData(data.length(), end_stream); network_downstream_data_buffer_ = nullptr; @@ -1459,6 +1493,7 @@ Network::FilterStatus Context::onData(Buffer::Instance& data, bool end_stream) { } Network::FilterStatus Context::onWrite(Buffer::Instance& data, bool end_stream) { + DeferAfterCallActions actions(this); network_upstream_data_buffer_ = &data; auto result = onUpstreamData(data.length(), end_stream); network_upstream_data_buffer_ = nullptr; @@ -1471,6 +1506,7 @@ Network::FilterStatus Context::onWrite(Buffer::Instance& data, bool end_stream) } void Context::onEvent(Network::ConnectionEvent event) { + DeferAfterCallActions actions(this); switch (event) { case Network::ConnectionEvent::LocalClose: onDownstreamConnectionClose(PeerType::Local); @@ -1528,6 +1564,7 @@ void Context::onDestroy() { } bool Context::onDone() { + DeferAfterCallActions actions(this); if (wasm_->on_done_) { return wasm_->on_done_(this, id_).u64_ != 0; } @@ -1535,12 +1572,14 @@ bool Context::onDone() { } void Context::onLog() { + DeferAfterCallActions actions(this); if (wasm_->on_log_) { wasm_->on_log_(this, id_); } } void Context::onDelete() { + DeferAfterCallActions actions(this); if (wasm_->on_delete_) { wasm_->on_delete_(this, id_); } diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index 64da6a35a1c8..a6a2ebcb2723 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -330,6 +330,8 @@ class Context : public Logger::Loggable, // Connection virtual bool isSsl(); + void addAfterVmCallAction(std::function f); + protected: friend class Wasm; diff --git a/source/extensions/common/wasm/exports.cc b/source/extensions/common/wasm/exports.cc index df03669761de..9952ef1b86ce 100644 --- a/source/extensions/common/wasm/exports.cc +++ b/source/extensions/common/wasm/exports.cc @@ -196,8 +196,12 @@ Word send_local_response(void* raw_context, Word response_code, Word response_co auto grpc_status_opt = (grpc_status != Grpc::Status::WellKnownGrpcStatus::InvalidCode) ? absl::optional(grpc_status) : absl::optional(); - context->sendLocalResponse(static_cast(response_code.u64_), body.value(), - modify_headers, grpc_status_opt, details.value()); + context->addAfterVmCallAction([context, response_code, body = std::string(body.value()), + modify_headers = std::move(modify_headers), grpc_status_opt, + details = std::string(details.value())] { + context->sendLocalResponse(static_cast(response_code.u64_), body, + modify_headers, grpc_status_opt, details); + }); return wasmResultToWord(WasmResult::Ok); } diff --git a/source/extensions/common/wasm/wasm.h b/source/extensions/common/wasm/wasm.h index c187347b958e..64855f7d92e1 100644 --- a/source/extensions/common/wasm/wasm.h +++ b/source/extensions/common/wasm/wasm.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -141,6 +142,15 @@ class Wasm : public Logger::Loggable, public std::enable_share return true; } + void addAfterVmCallAction(std::function f) { after_vm_call_actions_.push_back(f); } + void doAfterVmCallActions() { + while (!after_vm_call_actions_.empty()) { + auto f = std::move(after_vm_call_actions_.front()); + after_vm_call_actions_.pop_front(); + f(); + } + } + private: friend class Context; class ShutdownHandle; @@ -268,6 +278,9 @@ class Wasm : public Logger::Loggable, public std::enable_share // Foreign Functions. absl::flat_hash_map foreign_functions_; + + // Actions to be done after the call into the VM returns. + std::deque> after_vm_call_actions_; }; using WasmSharedPtr = std::shared_ptr;