diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto b/api/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto index 2b9863e91ffa..53bec0436116 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto +++ b/api/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto @@ -58,7 +58,7 @@ enum ProtocolType { TWITTER = 4; } -// [#next-free-field: 6] +// [#next-free-field: 7] message ThriftProxy { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.thrift_proxy.v2alpha1.ThriftProxy"; @@ -82,6 +82,12 @@ message ThriftProxy { // compatibility, if no thrift_filters are specified, a default Thrift router filter // (`envoy.filters.thrift.router`) is used. repeated ThriftFilter thrift_filters = 5; + + // If set to true, Envoy will try to skip decode data after metadata in the Thrift message. + // This mode will only work if the upstream and downstream protocols are the same and the transport + // is the same, the transport type is framed and the protocol is not Twitter. Otherwise Envoy will + // fallback to decode the data. + bool payload_passthrough = 6; } // ThriftFilter configures a Thrift filter. diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v4alpha/thrift_proxy.proto b/api/envoy/extensions/filters/network/thrift_proxy/v4alpha/thrift_proxy.proto index b75d0e39eaf2..8e7bf3c91a03 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v4alpha/thrift_proxy.proto +++ b/api/envoy/extensions/filters/network/thrift_proxy/v4alpha/thrift_proxy.proto @@ -58,7 +58,7 @@ enum ProtocolType { TWITTER = 4; } -// [#next-free-field: 6] +// [#next-free-field: 7] message ThriftProxy { option (udpa.annotations.versioning).previous_message_type = "envoy.extensions.filters.network.thrift_proxy.v3.ThriftProxy"; @@ -82,6 +82,12 @@ message ThriftProxy { // compatibility, if no thrift_filters are specified, a default Thrift router filter // (`envoy.filters.thrift.router`) is used. repeated ThriftFilter thrift_filters = 5; + + // If set to true, Envoy will try to skip decode data after metadata in the Thrift message. + // This mode will only work if the upstream and downstream protocols are the same and the transport + // is the same, the transport type is framed and the protocol is not Twitter. Otherwise Envoy will + // fallback to decode the data. + bool payload_passthrough = 6; } // ThriftFilter configures a Thrift filter. diff --git a/docs/root/version_history/current.rst b/docs/root/version_history/current.rst index a19010abf84d..90389259cc2e 100644 --- a/docs/root/version_history/current.rst +++ b/docs/root/version_history/current.rst @@ -75,6 +75,7 @@ New Features :ref:`CertificateValidationContext `. * signal: added an extension point for custom actions to run on the thread that has encountered a fatal error. Actions are configurable via :ref:`fatal_actions `. * tcp: added a new :ref:`envoy.overload_actions.reject_incoming_connections ` action to reject incoming TCP connections. +* thrift_proxy: added a new :ref: `payload_passthrough ` option to skip decoding body in the Thrift message. * tls: added support for RSA certificates with 4096-bit keys in FIPS mode. * tracing: added SkyWalking tracer. * xds: added support for resource TTLs. A TTL is specified on the :ref:`Resource `. For SotW, a :ref:`Resource ` can be embedded diff --git a/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto b/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto index 88f7b013fec7..4fc04f2d8803 100644 --- a/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto +++ b/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto @@ -58,7 +58,7 @@ enum ProtocolType { TWITTER = 4; } -// [#next-free-field: 6] +// [#next-free-field: 7] message ThriftProxy { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.thrift_proxy.v2alpha1.ThriftProxy"; @@ -82,6 +82,12 @@ message ThriftProxy { // compatibility, if no thrift_filters are specified, a default Thrift router filter // (`envoy.filters.thrift.router`) is used. repeated ThriftFilter thrift_filters = 5; + + // If set to true, Envoy will try to skip decode data after metadata in the Thrift message. + // This mode will only work if the upstream and downstream protocols are the same and the transport + // is the same, the transport type is framed and the protocol is not Twitter. Otherwise Envoy will + // fallback to decode the data. + bool payload_passthrough = 6; } // ThriftFilter configures a Thrift filter. diff --git a/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v4alpha/thrift_proxy.proto b/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v4alpha/thrift_proxy.proto index b75d0e39eaf2..8e7bf3c91a03 100644 --- a/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v4alpha/thrift_proxy.proto +++ b/generated_api_shadow/envoy/extensions/filters/network/thrift_proxy/v4alpha/thrift_proxy.proto @@ -58,7 +58,7 @@ enum ProtocolType { TWITTER = 4; } -// [#next-free-field: 6] +// [#next-free-field: 7] message ThriftProxy { option (udpa.annotations.versioning).previous_message_type = "envoy.extensions.filters.network.thrift_proxy.v3.ThriftProxy"; @@ -82,6 +82,12 @@ message ThriftProxy { // compatibility, if no thrift_filters are specified, a default Thrift router filter // (`envoy.filters.thrift.router`) is used. repeated ThriftFilter thrift_filters = 5; + + // If set to true, Envoy will try to skip decode data after metadata in the Thrift message. + // This mode will only work if the upstream and downstream protocols are the same and the transport + // is the same, the transport type is framed and the protocol is not Twitter. Otherwise Envoy will + // fallback to decode the data. + bool payload_passthrough = 6; } // ThriftFilter configures a Thrift filter. diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index fc2edbb54cb1..52b5030bc57f 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -121,7 +121,8 @@ ConfigImpl::ConfigImpl( : context_(context), stats_prefix_(fmt::format("thrift.{}.", config.stat_prefix())), stats_(ThriftFilterStats::generateStats(stats_prefix_, context_.scope())), transport_(lookupTransport(config.transport())), proto_(lookupProtocol(config.protocol())), - route_matcher_(new Router::RouteMatcher(config.route_config())) { + route_matcher_(new Router::RouteMatcher(config.route_config())), + payload_passthrough_(config.payload_passthrough()) { if (config.thrift_filters().empty()) { ENVOY_LOG(debug, "using default router filter"); diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index 532298c380e3..02c1fcf4d13a 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -81,6 +81,7 @@ class ConfigImpl : public Config, TransportPtr createTransport() override; ProtocolPtr createProtocol() override; Router::Config& routerConfig() override { return *this; } + bool payloadPassthrough() const override { return payload_passthrough_; } private: void processFilter( @@ -94,6 +95,7 @@ class ConfigImpl : public Config, std::unique_ptr route_matcher_; std::list filter_factories_; + const bool payload_passthrough_; }; } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index 737e70736978..7f7129715edf 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -176,6 +176,17 @@ DecoderEventHandler& ConnectionManager::newDecoderEventHandler() { return **rpcs_.begin(); } +bool ConnectionManager::passthroughEnabled() const { + if (!config_.payloadPassthrough()) { + return false; + } + + // This is called right after the metadata has been parsed, and the ActiveRpc being processed must + // be in the rpcs_ list. + ASSERT(!rpcs_.empty()); + return (*rpcs_.begin())->passthroughSupported(); +} + bool ConnectionManager::ResponseDecoder::onData(Buffer::Instance& data) { upstream_buffer_.move(data); @@ -274,6 +285,10 @@ FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { return FilterStatus::Continue; } +bool ConnectionManager::ResponseDecoder::passthroughEnabled() const { + return parent_.parent_.passthroughEnabled(); +} + void ConnectionManager::ActiveRpcDecoderFilter::continueDecoding() { const FilterStatus status = parent_.applyDecoderFilters(this); if (status == FilterStatus::Continue) { @@ -398,6 +413,25 @@ void ConnectionManager::ActiveRpc::finalizeRequest() { } } +bool ConnectionManager::ActiveRpc::passthroughSupported() const { + for (auto& entry : decoder_filters_) { + if (!entry->handle_->passthroughSupported()) { + return false; + } + } + return true; +} + +FilterStatus ConnectionManager::ActiveRpc::passthroughData(Buffer::Instance& data) { + filter_context_ = &data; + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + Buffer::Instance* data = absl::any_cast(filter_context_); + return filter->passthroughData(*data); + }; + + return applyDecoderFilters(nullptr); +} + FilterStatus ConnectionManager::ActiveRpc::messageBegin(MessageMetadataSharedPtr metadata) { ASSERT(metadata->hasSequenceId()); ASSERT(metadata->hasMessageType()); diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index b7408e1a3def..3fc0cf7b2f04 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -39,6 +39,7 @@ class Config { virtual TransportPtr createTransport() PURE; virtual ProtocolPtr createProtocol() PURE; virtual Router::Config& routerConfig() PURE; + virtual bool payloadPassthrough() const PURE; }; /** @@ -76,6 +77,7 @@ class ConnectionManager : public Network::ReadFilter, // DecoderCallbacks DecoderEventHandler& newDecoderEventHandler() override; + bool passthroughEnabled() const override; private: struct ActiveRpc; @@ -102,6 +104,7 @@ class ConnectionManager : public Network::ReadFilter, // DecoderCallbacks DecoderEventHandler& newDecoderEventHandler() override { return *this; } + bool passthroughEnabled() const override; ActiveRpc& parent_; DecoderPtr decoder_; @@ -180,6 +183,7 @@ class ConnectionManager : public Network::ReadFilter, // DecoderEventHandler FilterStatus transportBegin(MessageMetadataSharedPtr metadata) override; FilterStatus transportEnd() override; + FilterStatus passthroughData(Buffer::Instance& data) override; FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override; FilterStatus messageEnd() override; FilterStatus structBegin(absl::string_view name) override; @@ -225,6 +229,7 @@ class ConnectionManager : public Network::ReadFilter, LinkedList::moveIntoListBack(std::move(wrapper), decoder_filters_); } + bool passthroughSupported() const; FilterStatus applyDecoderFilters(ActiveRpcDecoderFilter* filter); void finalizeRequest(); diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index 73a12ff23377..55f23af6ef88 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -2,6 +2,7 @@ #include "envoy/common/exception.h" +#include "common/buffer/buffer_impl.h" #include "common/common/assert.h" #include "common/common/macros.h" @@ -12,8 +13,22 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +// PassthroughData -> PassthroughData +// PassthroughData -> MessageEnd (all body bytes received) +DecoderStateMachine::DecoderStatus DecoderStateMachine::passthroughData(Buffer::Instance& buffer) { + if (body_bytes_ > buffer.length()) { + return {ProtocolState::WaitForData}; + } + + Buffer::OwnedImpl body; + body.move(buffer, body_bytes_); + + return {ProtocolState::MessageEnd, handler_.passthroughData(body)}; +} + // MessageBegin -> StructBegin DecoderStateMachine::DecoderStatus DecoderStateMachine::messageBegin(Buffer::Instance& buffer) { + const auto total = buffer.length(); if (!proto_.readMessageBegin(buffer, *metadata_)) { return {ProtocolState::WaitForData}; } @@ -21,7 +36,14 @@ DecoderStateMachine::DecoderStatus DecoderStateMachine::messageBegin(Buffer::Ins stack_.clear(); stack_.emplace_back(Frame(ProtocolState::MessageEnd)); - return {ProtocolState::StructBegin, handler_.messageBegin(metadata_)}; + const auto status = handler_.messageBegin(metadata_); + + if (callbacks_.passthroughEnabled()) { + body_bytes_ = metadata_->frameSize() - (total - buffer.length()); + return {ProtocolState::PassthroughData, status}; + } + + return {ProtocolState::StructBegin, status}; } // MessageEnd -> Done @@ -293,6 +315,8 @@ DecoderStateMachine::DecoderStatus DecoderStateMachine::handleValue(Buffer::Inst DecoderStateMachine::DecoderStatus DecoderStateMachine::handleState(Buffer::Instance& buffer) { switch (state_) { + case ProtocolState::PassthroughData: + return passthroughData(buffer); case ProtocolState::MessageBegin: return messageBegin(buffer); case ProtocolState::StructBegin: @@ -416,7 +440,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { request_ = std::make_unique(callbacks_.newDecoderEventHandler()); frame_started_ = true; state_machine_ = - std::make_unique(protocol_, metadata_, request_->handler_); + std::make_unique(protocol_, metadata_, request_->handler_, callbacks_); if (request_->handler_.transportBegin(metadata_) == FilterStatus::StopIteration) { return FilterStatus::StopIteration; diff --git a/source/extensions/filters/network/thrift_proxy/decoder.h b/source/extensions/filters/network/thrift_proxy/decoder.h index 1f7675b9f76f..6c8b8bef5755 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.h +++ b/source/extensions/filters/network/thrift_proxy/decoder.h @@ -17,6 +17,7 @@ namespace ThriftProxy { #define ALL_PROTOCOL_STATES(FUNCTION) \ FUNCTION(StopIteration) \ FUNCTION(WaitForData) \ + FUNCTION(PassthroughData) \ FUNCTION(MessageBegin) \ FUNCTION(MessageEnd) \ FUNCTION(StructBegin) \ @@ -56,6 +57,8 @@ class ProtocolStateNameValues { } }; +class DecoderCallbacks; + /** * DecoderStateMachine is the Thrift message state machine as described in * source/extensions/filters/network/thrift_proxy/docs. @@ -63,9 +66,9 @@ class ProtocolStateNameValues { class DecoderStateMachine : public Logger::Loggable { public: DecoderStateMachine(Protocol& proto, MessageMetadataSharedPtr& metadata, - DecoderEventHandler& handler) - : proto_(proto), metadata_(metadata), handler_(handler), state_(ProtocolState::MessageBegin) { - } + DecoderEventHandler& handler, DecoderCallbacks& callbacks) + : proto_(proto), metadata_(metadata), handler_(handler), callbacks_(callbacks), + state_(ProtocolState::MessageBegin) {} /** * Consumes as much data from the configured Buffer as possible and executes the decoding state @@ -129,6 +132,7 @@ class DecoderStateMachine : public Logger::Loggable { // These functions map directly to the matching ProtocolState values. Each returns the next state // or ProtocolState::WaitForData if more data is required. + DecoderStatus passthroughData(Buffer::Instance& buffer); DecoderStatus messageBegin(Buffer::Instance& buffer); DecoderStatus messageEnd(Buffer::Instance& buffer); DecoderStatus structBegin(Buffer::Instance& buffer); @@ -165,8 +169,10 @@ class DecoderStateMachine : public Logger::Loggable { Protocol& proto_; MessageMetadataSharedPtr metadata_; DecoderEventHandler& handler_; + DecoderCallbacks& callbacks_; ProtocolState state_; std::vector stack_; + uint32_t body_bytes_{}; }; using DecoderStateMachinePtr = std::unique_ptr; @@ -179,6 +185,11 @@ class DecoderCallbacks { * @return DecoderEventHandler& a new DecoderEventHandler for a message. */ virtual DecoderEventHandler& newDecoderEventHandler() PURE; + + /** + * @return True if payload passthrough is enabled and is supported by filter chain. + */ + virtual bool passthroughEnabled() const PURE; }; /** diff --git a/source/extensions/filters/network/thrift_proxy/decoder_events.h b/source/extensions/filters/network/thrift_proxy/decoder_events.h index c69db94c0d26..b5981d6bc1df 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder_events.h +++ b/source/extensions/filters/network/thrift_proxy/decoder_events.h @@ -35,6 +35,14 @@ class DecoderEventHandler { */ virtual FilterStatus transportEnd() PURE; + /** + * Indicates raw bytes after metadata in a Thrift transport frame was detected. + * Filters should not modify data except for the router. + * @param data data to send as passthrough + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus passthroughData(Buffer::Instance& data) PURE; + /** * Indicates that the start of a Thrift protocol message was detected. * @param metadata MessageMetadataSharedPtr describing the message diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h index c2ef1a895061..122ff019e711 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -123,6 +123,12 @@ class DecoderFilter : public virtual DecoderEventHandler { * filter should use. Callbacks will not be invoked by the filter after onDestroy() is called. */ virtual void setDecoderFilterCallbacks(DecoderFilterCallbacks& callbacks) PURE; + + /** + * @return True if payload passthrough is supported. Called by the connection manager once after + * messageBegin. + */ + virtual bool passthroughSupported() const PURE; }; using DecoderFilterSharedPtr = std::shared_ptr; diff --git a/source/extensions/filters/network/thrift_proxy/filters/pass_through_filter.h b/source/extensions/filters/network/thrift_proxy/filters/pass_through_filter.h index 6992253a1198..dfcdebb8888c 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/pass_through_filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/pass_through_filter.h @@ -30,6 +30,12 @@ class PassThroughDecoderFilter : public DecoderFilter { ThriftProxy::FilterStatus transportEnd() override { return ThriftProxy::FilterStatus::Continue; } + bool passthroughSupported() const override { return true; } + + ThriftProxy::FilterStatus passthroughData(Buffer::Instance&) override { + return ThriftProxy::FilterStatus::Continue; + } + ThriftProxy::FilterStatus messageBegin(ThriftProxy::MessageMetadataSharedPtr) override { return ThriftProxy::FilterStatus::Continue; } diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h index 2d73f4c9498b..16a2d4111541 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_converter.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -25,6 +25,11 @@ class ProtocolConverter : public virtual DecoderEventHandler { } // DecoderEventHandler + FilterStatus passthroughData(Buffer::Instance& data) override { + buffer_->move(data); + return FilterStatus::Continue; + } + FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override { proto_->writeMessageBegin(*buffer_, *metadata); return FilterStatus::Continue; diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc index aa10f24d5b71..c0529c69109e 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -262,6 +262,12 @@ FilterStatus Router::messageBegin(MessageMetadataSharedPtr metadata) { : callbacks_->downstreamProtocolType(); ASSERT(protocol != ProtocolType::Auto); + if (callbacks_->downstreamTransportType() == TransportType::Framed && + transport == TransportType::Framed && callbacks_->downstreamProtocolType() == protocol && + protocol != ProtocolType::Twitter) { + passthrough_supported_ = true; + } + Tcp::ConnectionPool::Instance* conn_pool = cluster_manager_.tcpConnPoolForCluster( cluster_name, Upstream::ResourcePriority::Default, this); if (!conn_pool) { diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h index 26a94c90c753..a2328d4777f5 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.h +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -178,13 +178,15 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, public: Router(Upstream::ClusterManager& cluster_manager, const std::string& stat_prefix, Stats::Scope& scope) - : cluster_manager_(cluster_manager), stats_(generateStats(stat_prefix, scope)) {} + : cluster_manager_(cluster_manager), stats_(generateStats(stat_prefix, scope)), + passthrough_supported_(false) {} ~Router() override = default; // ThriftFilters::DecoderFilter void onDestroy() override; void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks& callbacks) override; + bool passthroughSupported() const override { return passthrough_supported_; } // ProtocolConverter FilterStatus transportBegin(MessageMetadataSharedPtr metadata) override; @@ -265,6 +267,8 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, std::unique_ptr upstream_request_; Buffer::OwnedImpl upstream_request_buffer_; + + bool passthrough_supported_ : 1; }; } // namespace Router diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h index d6961ce87357..700e6b223839 100644 --- a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h +++ b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h @@ -21,6 +21,7 @@ class ThriftBase : public DecoderEventHandler { ~ThriftBase() override = default; // DecoderEventHandler + FilterStatus passthroughData(Buffer::Instance&) override { return FilterStatus::Continue; } FilterStatus transportBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } FilterStatus transportEnd() override { return FilterStatus::Continue; } FilterStatus messageBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } @@ -246,6 +247,7 @@ class ThriftObjectImpl : public ThriftObject, complete_ = true; return FilterStatus::Continue; } + bool passthroughEnabled() const override { return false; } // ThriftObject bool onData(Buffer::Instance& buffer) override; diff --git a/test/config/utility.cc b/test/config/utility.cc index 54a8247dfb18..dc86c425e960 100644 --- a/test/config/utility.cc +++ b/test/config/utility.cc @@ -1047,25 +1047,17 @@ void ConfigHelper::addListenerFilter(const std::string& filter_yaml) { bool ConfigHelper::loadHttpConnectionManager( envoy::extensions::filters::network::http_connection_manager::v3::HttpConnectionManager& hcm) { - RELEASE_ASSERT(!finalized_, ""); - auto* hcm_filter = getFilterFromListener("http"); - if (hcm_filter) { - auto* config = hcm_filter->mutable_typed_config(); - hcm = MessageUtil::anyConvert< - envoy::extensions::filters::network::http_connection_manager::v3::HttpConnectionManager>( - *config); - return true; - } - return false; + return loadFilter< + envoy::extensions::filters::network::http_connection_manager::v3::HttpConnectionManager>( + "http", hcm); } void ConfigHelper::storeHttpConnectionManager( const envoy::extensions::filters::network::http_connection_manager::v3::HttpConnectionManager& hcm) { - RELEASE_ASSERT(!finalized_, ""); - auto* hcm_config_any = getFilterFromListener("http")->mutable_typed_config(); - - hcm_config_any->PackFrom(hcm); + return storeFilter< + envoy::extensions::filters::network::http_connection_manager::v3::HttpConnectionManager>( + "http", hcm); } void ConfigHelper::addConfigModifier(ConfigModifierFunction function) { diff --git a/test/config/utility.h b/test/config/utility.h index 0f7fc73b433e..a0f19e90a445 100644 --- a/test/config/utility.h +++ b/test/config/utility.h @@ -19,6 +19,7 @@ #include "common/config/api_version.h" #include "common/network/address_impl.h" #include "common/protobuf/protobuf.h" +#include "common/protobuf/utility.h" #include "test/integration/server_stats.h" @@ -219,6 +220,19 @@ class ConfigHelper { // Modifiers will be applied just before ports are modified in finalize void addConfigModifier(HttpModifierFunction function); + // Allows callers to easily modify the filter named 'name' from the first filter chain from the + // first listener. Modifiers will be applied just before ports are modified in finalize + template + void addFilterConfigModifier(const std::string& name, + std::function function) { + addConfigModifier([name, function, this](envoy::config::bootstrap::v3::Bootstrap&) -> void { + FilterType filter_config; + loadFilter(name, filter_config); + function(filter_config); + storeFilter(name, filter_config); + }); + } + // Apply any outstanding config modifiers, stick all the listeners in a discovery response message // and write it to the lds file. void setLds(absl::string_view version_info); @@ -275,6 +289,26 @@ class ConfigHelper { // struct of the first listener. void storeHttpConnectionManager(const HttpConnectionManager& hcm); + // Load the first FilterType struct from the first listener into a parsed proto. + template bool loadFilter(const std::string& name, FilterType& filter) { + RELEASE_ASSERT(!finalized_, ""); + auto* filter_config = getFilterFromListener(name); + if (filter_config) { + auto* config = filter_config->mutable_typed_config(); + filter = MessageUtil::anyConvert(*config); + return true; + } + return false; + } + // Take the contents of the provided FilterType proto and stuff them into the first FilterType + // struct of the first listener. + template void storeFilter(const std::string& name, const FilterType& filter) { + RELEASE_ASSERT(!finalized_, ""); + auto* filter_config_any = getFilterFromListener(name)->mutable_typed_config(); + + filter_config_any->PackFrom(filter); + } + // Finds the filter named 'name' from the first filter chain from the first listener. envoy::config::listener::v3::Filter* getFilterFromListener(const std::string& name); diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index 24eb2f193703..b4bb10f1d2d2 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -267,6 +267,7 @@ envoy_extension_cc_test( ":mocks", ":utility_lib", "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//source/extensions/filters/network/thrift_proxy:config", "//source/extensions/filters/network/thrift_proxy/router:config", "//source/extensions/filters/network/thrift_proxy/router:router_lib", "//test/mocks/network:network_mocks", @@ -333,6 +334,7 @@ envoy_extension_cc_test( "//test/extensions/filters/network/thrift_proxy/driver:generate_fixture", ], extension_name = "envoy.filters.network.thrift_proxy", + shard_count = 4, deps = [ ":integration_lib", ":utility_lib", diff --git a/test/extensions/filters/network/thrift_proxy/config_test.cc b/test/extensions/filters/network/thrift_proxy/config_test.cc index 6bf4afbd3f7a..9b3db4c4ebc8 100644 --- a/test/extensions/filters/network/thrift_proxy/config_test.cc +++ b/test/extensions/filters/network/thrift_proxy/config_test.cc @@ -206,6 +206,24 @@ stat_prefix: ingress EXPECT_EQ("thrift.ingress.", factory.config_stat_prefix_); } +// Test config with payload passthrough enabled. +TEST_F(ThriftFilterConfigTest, ThriftProxyPayloadPassthrough) { + const std::string yaml = R"EOF( +stat_prefix: ingress +payload_passthrough: true +route_config: + name: local_route +thrift_filters: + - name: envoy.filters.thrift.router +)EOF"; + + envoy::extensions::filters::network::thrift_proxy::v3::ThriftProxy config = + parseThriftProxyFromV2Yaml(yaml); + testConfig(config); + + EXPECT_EQ(true, config.payload_passthrough()); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc index 9798db6e2e9c..f58948b262fc 100644 --- a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -348,7 +348,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftOneWay) { initializeFilter(); writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -386,7 +386,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataHandlesStopIterationAndResume) { EXPECT_EQ(&filter_callbacks_.connection_, callbacks->connection()); // Resume processing. - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); callbacks->continueDecoding(); EXPECT_EQ(1U, store_.counter("test.request").value()); @@ -461,7 +461,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataHandlesProtocolError) { EXPECT_EQ(write_buffer_.toString(), buffer.toString()); })); EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); @@ -561,7 +561,7 @@ TEST_F(ThriftConnectionManagerTest, OnEvent) { }); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); filter_->onEvent(Network::ConnectionEvent::RemoteClose); EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); @@ -580,7 +580,7 @@ TEST_F(ThriftConnectionManagerTest, OnEvent) { }); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); filter_->onEvent(Network::ConnectionEvent::LocalClose); EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); @@ -596,7 +596,7 @@ TEST_F(ThriftConnectionManagerTest, OnEvent) { writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); filter_->onEvent(Network::ConnectionEvent::RemoteClose); EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); @@ -612,7 +612,7 @@ TEST_F(ThriftConnectionManagerTest, OnEvent) { writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); filter_->onEvent(Network::ConnectionEvent::LocalClose); EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); @@ -655,7 +655,7 @@ stat_prefix: test EXPECT_NE(nullptr, route->routeEntry()); EXPECT_EQ("cluster", route->routeEntry()->clusterName()); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); callbacks->continueDecoding(); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -679,7 +679,7 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponse) { BinaryProtocolImpl proto; callbacks->startUpstreamResponse(transport, proto); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -713,7 +713,7 @@ TEST_F(ThriftConnectionManagerTest, RequestAndVoidResponse) { BinaryProtocolImpl proto; callbacks->startUpstreamResponse(transport, proto); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -755,7 +755,7 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponseSequenceIdHandling) { .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { EXPECT_EQ(response_buffer.toString(), buffer.toString()); })); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -789,7 +789,7 @@ TEST_F(ThriftConnectionManagerTest, RequestAndExceptionResponse) { BinaryProtocolImpl proto; callbacks->startUpstreamResponse(transport, proto); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -824,7 +824,7 @@ TEST_F(ThriftConnectionManagerTest, RequestAndErrorResponse) { BinaryProtocolImpl proto; callbacks->startUpstreamResponse(transport, proto); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -859,7 +859,7 @@ TEST_F(ThriftConnectionManagerTest, RequestAndInvalidResponse) { BinaryProtocolImpl proto; callbacks->startUpstreamResponse(transport, proto); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -901,7 +901,7 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponseProtocolError) { callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_, write(_, true)); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(ThriftFilters::ResponseStatus::Reset, callbacks->upstreamData(write_buffer_)); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -943,7 +943,7 @@ TEST_F(ThriftConnectionManagerTest, RequestAndTransportApplicationException) { BinaryProtocolImpl proto; callbacks->startUpstreamResponse(transport, proto); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(ThriftFilters::ResponseStatus::Reset, callbacks->upstreamData(write_buffer_)); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -979,7 +979,7 @@ TEST_F(ThriftConnectionManagerTest, RequestAndGarbageResponse) { BinaryProtocolImpl proto; callbacks->startUpstreamResponse(transport, proto); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(ThriftFilters::ResponseStatus::Reset, callbacks->upstreamData(write_buffer_)); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); @@ -1249,7 +1249,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalReply) { EXPECT_EQ(8, buffer.drainBEInt()); EXPECT_EQ("response", buffer.toString()); })); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); @@ -1293,7 +1293,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalErrorReply) { EXPECT_EQ(8, buffer.drainBEInt()); EXPECT_EQ("response", buffer.toString()); })); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); @@ -1328,7 +1328,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendLocalReplyRemoteClosedCo return FilterStatus::StopIteration; })); EXPECT_CALL(filter_callbacks_.connection_, write(_, false)).Times(0); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); // Remote closes the connection. filter_callbacks_.connection_.state_ = Network::Connection::State::Closed; @@ -1417,6 +1417,230 @@ TEST_F(ThriftConnectionManagerTest, TransportEndWhenRemoteClose) { filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); } +// TODO(caitong93): use TEST_P to avoid duplicating test cases +TEST_F(ThriftConnectionManagerTest, PayloadPassthroughOnDataHandlesThriftCall) { + const std::string yaml = R"EOF( +transport: FRAMED +protocol: BINARY +stat_prefix: test +payload_passthrough: true +)EOF"; + + initializeFilter(yaml); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); + EXPECT_CALL(*decoder_filter_, passthroughData(_)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0, buffer_.length()); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, stats_.request_active_.value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, PayloadPassthroughOnDataHandlesThriftOneWay) { + const std::string yaml = R"EOF( +stat_prefix: test +payload_passthrough: true +)EOF"; + + initializeFilter(yaml); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); + EXPECT_CALL(*decoder_filter_, passthroughData(_)); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(0U, stats_.request_active_.value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, PayloadPassthroughRequestAndExceptionResponse) { + const std::string yaml = R"EOF( +stat_prefix: test +payload_passthrough: true +)EOF"; + + initializeFilter(yaml); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); + EXPECT_CALL(*decoder_filter_, passthroughData(_)); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeFramedBinaryTApplicationException(write_buffer_, 0x0F); + + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, stats_.request_active_.value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); + EXPECT_EQ(1U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, PayloadPassthroughRequestAndErrorResponse) { + const std::string yaml = R"EOF( +stat_prefix: test +payload_passthrough: true +)EOF"; + + initializeFilter(yaml); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); + EXPECT_CALL(*decoder_filter_, passthroughData(_)); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeFramedBinaryIDLException(write_buffer_, 0x0F); + + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, stats_.request_active_.value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + // In payload_passthrough mode, Envoy cannot detect response error. + EXPECT_EQ(1U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, PayloadPassthroughRequestAndInvalidResponse) { + const std::string yaml = R"EOF( +stat_prefix: test +payload_passthrough: true +)EOF"; + + initializeFilter(yaml); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); + EXPECT_CALL(*decoder_filter_, passthroughData(_)); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + // Call is not valid in a response + writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); + + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, stats_.request_active_.value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(1U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, PayloadPassthroughRouting) { + const std::string yaml = R"EOF( +transport: FRAMED +protocol: BINARY +payload_passthrough: true +stat_prefix: test +route_config: + name: "routes" + routes: + - match: + method_name: name + route: + cluster: cluster +)EOF"; + + initializeFilter(yaml); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); + EXPECT_CALL(*decoder_filter_, passthroughData(_)); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + EXPECT_CALL(*decoder_filter_, messageBegin(_)).WillOnce(Return(FilterStatus::StopIteration)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0U, store_.counter("test.request").value()); + EXPECT_EQ(1U, stats_.request_active_.value()); + + Router::RouteConstSharedPtr route = callbacks->route(); + EXPECT_NE(nullptr, route); + EXPECT_NE(nullptr, route->routeEntry()); + EXPECT_EQ("cluster", route->routeEntry()->clusterName()); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + callbacks->continueDecoding(); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 4699f3d94a91..73c8ac0b85dc 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -185,6 +185,7 @@ class DecoderStateMachineTestBase { NiceMock proto_; MessageMetadataSharedPtr metadata_; NiceMock handler_; + NiceMock callbacks_; }; class DecoderStateMachineNonValueTest : public DecoderStateMachineTestBase, @@ -236,7 +237,7 @@ TEST_P(DecoderStateMachineNonValueTest, NoData) { ProtocolState state = GetParam(); Buffer::OwnedImpl buffer; - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(state); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), state); @@ -256,7 +257,7 @@ TEST_P(DecoderStateMachineValueTest, NoFieldValueData) { EXPECT_CALL(proto_, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -280,7 +281,7 @@ TEST_P(DecoderStateMachineValueTest, FieldValue) { EXPECT_CALL(proto_, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -295,7 +296,7 @@ TEST_F(DecoderStateMachineTest, NoListValueData) { .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); EXPECT_CALL(proto_, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -310,7 +311,7 @@ TEST_F(DecoderStateMachineTest, EmptyList) { .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); EXPECT_CALL(proto_, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -329,7 +330,7 @@ TEST_P(DecoderStateMachineValueTest, ListValue) { EXPECT_CALL(proto_, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -346,7 +347,7 @@ TEST_P(DecoderStateMachineValueTest, IncompleteListValue) { expectValue(proto_, handler_, field_type, false); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -372,7 +373,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleListValues) { EXPECT_CALL(proto_, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -388,7 +389,7 @@ TEST_F(DecoderStateMachineTest, NoMapKeyData) { SetArgReferee<3>(1), Return(true))); EXPECT_CALL(proto_, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -405,7 +406,7 @@ TEST_F(DecoderStateMachineTest, NoMapValueData) { EXPECT_CALL(proto_, readInt32(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(proto_, readString(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -421,7 +422,7 @@ TEST_F(DecoderStateMachineTest, EmptyMap) { SetArgReferee<3>(0), Return(true))); EXPECT_CALL(proto_, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -442,7 +443,7 @@ TEST_P(DecoderStateMachineValueTest, MapKeyValue) { EXPECT_CALL(proto_, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -463,7 +464,7 @@ TEST_P(DecoderStateMachineValueTest, MapValueValue) { EXPECT_CALL(proto_, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -481,7 +482,7 @@ TEST_P(DecoderStateMachineValueTest, IncompleteMapKey) { expectValue(proto_, handler_, field_type, false); // key - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -506,7 +507,7 @@ TEST_P(DecoderStateMachineValueTest, IncompleteMapValue) { expectValue(proto_, handler_, FieldType::I32); // key expectValue(proto_, handler_, field_type, false); // value - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -534,7 +535,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { EXPECT_CALL(proto_, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -549,7 +550,7 @@ TEST_F(DecoderStateMachineTest, NoSetValueData) { .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); EXPECT_CALL(proto_, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -564,7 +565,7 @@ TEST_F(DecoderStateMachineTest, EmptySet) { .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); EXPECT_CALL(proto_, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -583,7 +584,7 @@ TEST_P(DecoderStateMachineValueTest, SetValue) { EXPECT_CALL(proto_, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -600,7 +601,7 @@ TEST_P(DecoderStateMachineValueTest, IncompleteSetValue) { expectValue(proto_, handler_, field_type, false); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -626,7 +627,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleSetValues) { EXPECT_CALL(proto_, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -650,7 +651,7 @@ TEST_F(DecoderStateMachineTest, EmptyStruct) { EXPECT_CALL(proto_, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto_, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -706,7 +707,7 @@ TEST_P(DecoderStateMachineValueTest, SingleFieldStruct) { EXPECT_CALL(proto_, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler_, messageEnd()).WillOnce(Return(FilterStatus::Continue)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -767,7 +768,7 @@ TEST_F(DecoderStateMachineTest, MultiFieldStruct) { EXPECT_CALL(proto_, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler_, messageEnd()).WillOnce(Return(FilterStatus::Continue)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -823,7 +824,7 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { EXPECT_CALL(proto_, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler_, messageEnd()).WillOnce(Return(FilterStatus::Continue)); - DecoderStateMachine dsm(proto_, metadata_, handler_); + DecoderStateMachine dsm(proto_, metadata_, handler_, callbacks_); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -869,6 +870,7 @@ TEST(DecoderTest, OnData) { EXPECT_EQ(100U, metadata->sequenceId()); return FilterStatus::Continue; })); + EXPECT_CALL(callbacks, passthroughEnabled()).Times(1).WillRepeatedly(Return(false)); EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())).WillOnce(Return(FilterStatus::Continue)); @@ -936,6 +938,7 @@ TEST(DecoderTest, OnDataWithProtocolHint) { EXPECT_EQ(100U, metadata->sequenceId()); return FilterStatus::Continue; })); + EXPECT_CALL(callbacks, passthroughEnabled()).Times(1); EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())).WillOnce(Return(FilterStatus::Continue)); @@ -1178,6 +1181,7 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { EXPECT_EQ(100U, metadata->sequenceId()); return FilterStatus::StopIteration; })); + EXPECT_CALL(callbacks, passthroughEnabled()).Times(1).WillRepeatedly(Return(false)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); @@ -1231,6 +1235,284 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { EXPECT_TRUE(underflow); } +TEST(DecoderTest, OnDataPassthrough) { + NiceMock transport; + NiceMock proto; + NiceMock callbacks; + StrictMock handler; + ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); + + InSequence dummy; + Decoder decoder(transport, proto, callbacks); + Buffer::OwnedImpl buffer(std::string(100, 'a')); + + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(handler, transportBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_TRUE(metadata->hasFrameSize()); + EXPECT_EQ(100U, metadata->frameSize()); + return FilterStatus::Continue; + })); + + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + buffer.drain(20); + return true; + })); + EXPECT_CALL(handler, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(MessageType::Call, metadata->messageType()); + EXPECT_EQ(100U, metadata->sequenceId()); + return FilterStatus::Continue; + })); + + EXPECT_CALL(callbacks, passthroughEnabled()).WillOnce(Return(true)); + EXPECT_CALL(handler, passthroughData(_)) + .WillOnce(Invoke([&](Buffer::Instance& data) -> FilterStatus { + EXPECT_EQ(80, data.length()); + return FilterStatus::Continue; + })); + + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::Continue)); + + bool underflow = false; + EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); +} + +TEST(DecoderTest, OnDataPassthroughResumes) { + NiceMock transport; + NiceMock proto; + NiceMock callbacks; + NiceMock handler; + ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); + + InSequence dummy; + + Decoder decoder(transport, proto, callbacks); + Buffer::OwnedImpl buffer; + buffer.add("x"); + + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(proto, readMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(callbacks, passthroughEnabled()).WillOnce(Return(true)); + EXPECT_CALL(handler, passthroughData(_)).Times(0); + + bool underflow = false; + EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); + + buffer.add(std::string(100, 'a')); + EXPECT_CALL(handler, passthroughData(_)) + .WillOnce(Invoke([&](Buffer::Instance& data) -> FilterStatus { + EXPECT_EQ(100, data.length()); + return FilterStatus::Continue; + })); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); + + EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); // buffer.length() == 1 +} + +TEST(DecoderTest, OnDataPassthroughResumesTransportFrameStart) { + StrictMock transport; + StrictMock proto; + NiceMock callbacks; + NiceMock handler; + ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); + + EXPECT_CALL(transport, name()).Times(AnyNumber()); + EXPECT_CALL(proto, name()).Times(AnyNumber()); + + InSequence dummy; + + Decoder decoder(transport, proto, callbacks); + Buffer::OwnedImpl buffer; + buffer.add(std::string(100, 'a')); + bool underflow = false; + + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)).WillOnce(Return(false)); + EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); + + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(proto, readMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(callbacks, passthroughEnabled()).WillOnce(Return(true)); + EXPECT_CALL(handler, passthroughData(_)) + .WillOnce(Invoke([&](Buffer::Instance& data) -> FilterStatus { + EXPECT_EQ(100, data.length()); + return FilterStatus::Continue; + })); + + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); + + underflow = false; + EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); // buffer.length() == 0 +} + +TEST(DecoderTest, OnDataPassthroughResumesTransportFrameEnd) { + StrictMock transport; + StrictMock proto; + NiceMock callbacks; + NiceMock handler; + ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); + + EXPECT_CALL(transport, name()).Times(AnyNumber()); + EXPECT_CALL(proto, name()).Times(AnyNumber()); + + InSequence dummy; + + Decoder decoder(transport, proto, callbacks); + Buffer::OwnedImpl buffer; + buffer.add(std::string(100, 'a')); + + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(proto, readMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(callbacks, passthroughEnabled()).WillOnce(Return(true)); + EXPECT_CALL(handler, passthroughData(_)) + .WillOnce(Invoke([&](Buffer::Instance& data) -> FilterStatus { + EXPECT_EQ(100, data.length()); + return FilterStatus::Continue; + })); + + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(false)); + + bool underflow = false; + EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); + + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); // buffer.length() == 0 +} + +TEST(DecoderTest, OnDataPassthroughHandlesStopIterationAndResumes) { + StrictMock transport; + EXPECT_CALL(transport, name()).WillRepeatedly(ReturnRef(transport.name_)); + + StrictMock proto; + EXPECT_CALL(proto, name()).WillRepeatedly(ReturnRef(proto.name_)); + + NiceMock callbacks; + StrictMock handler; + ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); + + InSequence dummy; + Decoder decoder(transport, proto, callbacks); + Buffer::OwnedImpl buffer; + bool underflow = true; + + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(handler, transportBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_TRUE(metadata->hasFrameSize()); + EXPECT_EQ(100U, metadata->frameSize()); + + return FilterStatus::StopIteration; + })); + EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(handler, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(MessageType::Call, metadata->messageType()); + EXPECT_EQ(100U, metadata->sequenceId()); + return FilterStatus::StopIteration; + })); + EXPECT_CALL(callbacks, passthroughEnabled()).WillOnce(Return(true)); + EXPECT_CALL(handler, passthroughData(_)).Times(0); + + EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + buffer.add(std::string(100, 'a')); + EXPECT_CALL(handler, passthroughData(_)) + .WillOnce(Invoke([&](Buffer::Instance& data) -> FilterStatus { + EXPECT_EQ(100, data.length()); + return FilterStatus::StopIteration; + })); + + EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); // buffer.length() == 0 + + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::StopIteration)); + EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::StopIteration)); + EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); +} + #define TEST_NAME(X) EXPECT_EQ(ProtocolStateNameValues::name(ProtocolState::X), #X); TEST(ProtocolStateNameValuesTest, ValidNames) { ALL_PROTOCOL_STATES(TEST_NAME) } diff --git a/test/extensions/filters/network/thrift_proxy/integration_test.cc b/test/extensions/filters/network/thrift_proxy/integration_test.cc index c3ea956fb23f..16dddb2b3163 100644 --- a/test/extensions/filters/network/thrift_proxy/integration_test.cc +++ b/test/extensions/filters/network/thrift_proxy/integration_test.cc @@ -19,7 +19,7 @@ namespace NetworkFilters { namespace ThriftProxy { class ThriftConnManagerIntegrationTest - : public testing::TestWithParam>, + : public testing::TestWithParam>, public BaseThriftIntegrationTest { public: static void SetUpTestSuite() { // NOLINT(readability-identifier-naming) @@ -68,7 +68,7 @@ class ThriftConnManagerIntegrationTest } void initializeCall(DriverMode mode) { - std::tie(transport_, protocol_, multiplexed_) = GetParam(); + std::tie(transport_, protocol_, multiplexed_, std::ignore) = GetParam(); absl::optional service_name; if (multiplexed_) { @@ -92,7 +92,7 @@ class ThriftConnManagerIntegrationTest } void initializeOneway() { - std::tie(transport_, protocol_, multiplexed_) = GetParam(); + std::tie(transport_, protocol_, multiplexed_, std::ignore) = GetParam(); absl::optional service_name; if (multiplexed_) { @@ -106,6 +106,21 @@ class ThriftConnManagerIntegrationTest initializeCommon(); } + void tryInitializePassthrough() { + std::tie(std::ignore, std::ignore, std::ignore, payload_passthrough_) = GetParam(); + + if (payload_passthrough_) { + config_helper_.addFilterConfigModifier< + envoy::extensions::filters::network::thrift_proxy::v3::ThriftProxy>( + "thrift", [](Protobuf::Message& filter) { + auto& conn_manager = + dynamic_cast( + filter); + conn_manager.set_payload_passthrough(true); + }); + } + } + // We allocate as many upstreams as there are clusters, with each upstream being allocated // to clusters in the order they're defined in the bootstrap config. void initializeCommon() { @@ -119,6 +134,8 @@ class ThriftConnManagerIntegrationTest } }); + tryInitializePassthrough(); + BaseThriftIntegrationTest::initialize(); } @@ -142,6 +159,7 @@ class ThriftConnManagerIntegrationTest TransportType transport_; ProtocolType protocol_; bool multiplexed_; + bool payload_passthrough_; std::string result_; @@ -150,26 +168,35 @@ class ThriftConnManagerIntegrationTest }; static std::string -paramToString(const TestParamInfo>& params) { +paramToString(const TestParamInfo>& params) { TransportType transport; ProtocolType protocol; bool multiplexed; - std::tie(transport, protocol, multiplexed) = params.param; + bool passthrough; + std::tie(transport, protocol, multiplexed, passthrough) = params.param; std::string transport_name = transportNameForTest(transport); std::string protocol_name = protocolNameForTest(protocol); + std::string result; + if (multiplexed) { - return fmt::format("{}{}Multiplexed", transport_name, protocol_name); + result = fmt::format("{}{}Multiplexed", transport_name, protocol_name); + } else { + result = fmt::format("{}{}", transport_name, protocol_name); } - return fmt::format("{}{}", transport_name, protocol_name); + if (passthrough) { + result = fmt::format("{}Passthrough", result); + } + return result; } -INSTANTIATE_TEST_SUITE_P( - TransportAndProtocol, ThriftConnManagerIntegrationTest, - Combine(Values(TransportType::Framed, TransportType::Unframed, TransportType::Header), - Values(ProtocolType::Binary, ProtocolType::Compact), Values(false, true)), - paramToString); +INSTANTIATE_TEST_SUITE_P(TransportAndProtocol, ThriftConnManagerIntegrationTest, + Combine(Values(TransportType::Framed, TransportType::Unframed, + TransportType::Header), + Values(ProtocolType::Binary, ProtocolType::Compact), + Values(false, true), Values(false, true)), + paramToString); TEST_P(ThriftConnManagerIntegrationTest, Success) { initializeCall(DriverMode::Success); @@ -222,7 +249,12 @@ TEST_P(ThriftConnManagerIntegrationTest, IDLException) { Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_call"); EXPECT_EQ(1U, counter->value()); counter = test_server_->counter("thrift.thrift_stats.response_error"); - EXPECT_EQ(1U, counter->value()); + if (payload_passthrough_ && transport_ == TransportType::Framed && + protocol_ != ProtocolType::Twitter) { + EXPECT_EQ(0U, counter->value()); + } else { + EXPECT_EQ(1U, counter->value()); + } } TEST_P(ThriftConnManagerIntegrationTest, Exception) { @@ -395,7 +427,7 @@ class ThriftTwitterConnManagerIntegrationTest : public ThriftConnManagerIntegrat INSTANTIATE_TEST_SUITE_P(FramedTwitter, ThriftTwitterConnManagerIntegrationTest, Combine(Values(TransportType::Framed), Values(ProtocolType::Twitter), - Values(false, true)), + Values(false, true), Values(false, true)), paramToString); // Because of the protocol upgrade requests and the difficulty of separating them, we test this diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index 05e4e88dbc98..b5846a878a21 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -132,6 +132,7 @@ class MockDecoderCallbacks : public DecoderCallbacks { // ThriftProxy::DecoderCallbacks MOCK_METHOD(DecoderEventHandler&, newDecoderEventHandler, ()); + MOCK_METHOD(bool, passthroughEnabled, (), (const)); }; class MockDecoderEventHandler : public DecoderEventHandler { @@ -140,6 +141,7 @@ class MockDecoderEventHandler : public DecoderEventHandler { ~MockDecoderEventHandler() override; // ThriftProxy::DecoderEventHandler + MOCK_METHOD(FilterStatus, passthroughData, (Buffer::Instance & data)); MOCK_METHOD(FilterStatus, transportBegin, (MessageMetadataSharedPtr metadata)); MOCK_METHOD(FilterStatus, transportEnd, ()); MOCK_METHOD(FilterStatus, messageBegin, (MessageMetadataSharedPtr metadata)); @@ -207,8 +209,10 @@ class MockDecoderFilter : public DecoderFilter { MOCK_METHOD(void, onDestroy, ()); MOCK_METHOD(void, setDecoderFilterCallbacks, (DecoderFilterCallbacks & callbacks)); MOCK_METHOD(void, resetUpstreamConnection, ()); + MOCK_METHOD(bool, passthroughSupported, (), (const)); // ThriftProxy::DecoderEventHandler + MOCK_METHOD(FilterStatus, passthroughData, (Buffer::Instance & data)); MOCK_METHOD(FilterStatus, transportBegin, (MessageMetadataSharedPtr metadata)); MOCK_METHOD(FilterStatus, transportEnd, ()); MOCK_METHOD(FilterStatus, messageBegin, (MessageMetadataSharedPtr metadata)); diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc index 110c10b5d624..e5ba97bcb8b7 100644 --- a/test/extensions/filters/network/thrift_proxy/router_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -7,6 +7,7 @@ #include "common/buffer/buffer_impl.h" #include "extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "extensions/filters/network/thrift_proxy/config.h" #include "extensions/filters/network/thrift_proxy/router/config.h" #include "extensions/filters/network/thrift_proxy/router/router_impl.h" @@ -22,6 +23,8 @@ #include "gtest/gtest.h" using testing::_; +using testing::AtLeast; +using testing::Combine; using testing::ContainsRegex; using testing::Eq; using testing::Invoke; @@ -29,6 +32,7 @@ using testing::NiceMock; using testing::Ref; using testing::Return; using testing::ReturnRef; +using ::testing::TestParamInfo; using testing::Values; namespace Envoy { @@ -102,7 +106,9 @@ class ThriftRouterTestBase { } void startRequest(MessageType msg_type, std::string method = "method", - const bool strip_service_name = false) { + const bool strip_service_name = false, + const TransportType transport_type = TransportType::Framed, + const ProtocolType protocol_type = ProtocolType::Binary) { EXPECT_EQ(FilterStatus::Continue, router_->transportBegin(metadata_)); EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); @@ -115,8 +121,12 @@ class ThriftRouterTestBase { initializeMetadata(msg_type, method); - EXPECT_CALL(callbacks_, downstreamTransportType()).WillOnce(Return(TransportType::Framed)); - EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary)); + EXPECT_CALL(callbacks_, downstreamTransportType()) + .Times(AtLeast(1)) + .WillRepeatedly(Return(transport_type)); + EXPECT_CALL(callbacks_, downstreamProtocolType()) + .Times(AtLeast(1)) + .WillRepeatedly(Return(protocol_type)); EXPECT_EQ(FilterStatus::StopIteration, router_->messageBegin(metadata_)); EXPECT_CALL(callbacks_, connection()).WillRepeatedly(Return(&connection_)); @@ -184,8 +194,12 @@ class ThriftRouterTestBase { EXPECT_EQ(nullptr, router_->metadataMatchCriteria()); EXPECT_EQ(nullptr, router_->downstreamHeaders()); - EXPECT_CALL(callbacks_, downstreamTransportType()).WillOnce(Return(TransportType::Framed)); - EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary)); + EXPECT_CALL(callbacks_, downstreamTransportType()) + .Times(2) + .WillRepeatedly(Return(TransportType::Framed)); + EXPECT_CALL(callbacks_, downstreamProtocolType()) + .Times(2) + .WillRepeatedly(Return(ProtocolType::Binary)); mock_protocol_cb_ = [&](MockProtocol* protocol) -> void { ON_CALL(*protocol, type()).WillByDefault(Return(ProtocolType::Binary)); @@ -355,6 +369,37 @@ INSTANTIATE_TEST_SUITE_P(ContainerFieldTypes, ThriftRouterContainerTest, Values(FieldType::Map, FieldType::List, FieldType::Set), fieldTypeParamToString); +class ThriftRouterPassthroughTest + : public testing::TestWithParam< + std::tuple>, + public ThriftRouterTestBase { +public: +}; + +static std::string downstreamUpstreamTypesToString( + const TestParamInfo>& + params) { + TransportType downstream_transport_type; + ProtocolType downstream_protocol_type; + TransportType upstream_transport_type; + ProtocolType upstream_protocol_type; + + std::tie(downstream_transport_type, downstream_protocol_type, upstream_transport_type, + upstream_protocol_type) = params.param; + + return fmt::format("{}{}{}{}", TransportNames::get().fromType(downstream_transport_type), + ProtocolNames::get().fromType(downstream_protocol_type), + TransportNames::get().fromType(upstream_transport_type), + ProtocolNames::get().fromType(upstream_protocol_type)); +} + +INSTANTIATE_TEST_SUITE_P(DownstreamUpstreamTypes, ThriftRouterPassthroughTest, + Combine(Values(TransportType::Framed, TransportType::Unframed), + Values(ProtocolType::Binary, ProtocolType::Twitter), + Values(TransportType::Framed, TransportType::Unframed), + Values(ProtocolType::Binary, ProtocolType::Twitter)), + downstreamUpstreamTypesToString); + TEST_F(ThriftRouterTest, PoolRemoteConnectionFailure) { initializeRouter(); @@ -949,6 +994,55 @@ TEST_P(ThriftRouterContainerTest, DecoderFilterCallbacks) { destroyRouter(); } +TEST_P(ThriftRouterPassthroughTest, PassthroughEnable) { + TransportType downstream_transport_type; + ProtocolType downstream_protocol_type; + TransportType upstream_transport_type; + ProtocolType upstream_protocol_type; + + std::tie(downstream_transport_type, downstream_protocol_type, upstream_transport_type, + upstream_protocol_type) = GetParam(); + + const std::string yaml_string = R"EOF( + transport: {} + protocol: {} + )EOF"; + + envoy::extensions::filters::network::thrift_proxy::v3::ThriftProtocolOptions configuration; + TestUtility::loadFromYaml(fmt::format(yaml_string, + TransportNames::get().fromType(upstream_transport_type), + ProtocolNames::get().fromType(upstream_protocol_type)), + configuration); + + const auto protocol_option = std::make_shared(configuration); + EXPECT_CALL(*context_.cluster_manager_.thread_local_cluster_.cluster_.info_, + extensionProtocolOptions(_)) + .WillRepeatedly(Return(protocol_option)); + + initializeRouter(); + startRequest(MessageType::Call, "method", false, downstream_transport_type, + downstream_protocol_type); + + bool passthroughSupported = false; + if (downstream_transport_type == upstream_transport_type && + downstream_transport_type == TransportType::Framed && + downstream_protocol_type == upstream_protocol_type && + downstream_protocol_type != ProtocolType::Twitter) { + passthroughSupported = true; + } + ASSERT_EQ(passthroughSupported, router_->passthroughSupported()); + + EXPECT_CALL(callbacks_, sendLocalReply(_, _)) + .WillOnce(Invoke([&](const DirectResponse& response, bool end_stream) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*connection failure.*")); + EXPECT_TRUE(end_stream); + })); + context_.cluster_manager_.tcp_conn_pool_.poolFailure( + ConnectionPool::PoolFailureReason::RemoteConnectionFailure); +} + } // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters