diff --git a/api/envoy/config/core/v3/proxy_protocol.proto b/api/envoy/config/core/v3/proxy_protocol.proto index c276ce4d58ed..6a1c87fe64b2 100644 --- a/api/envoy/config/core/v3/proxy_protocol.proto +++ b/api/envoy/config/core/v3/proxy_protocol.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.config.core.v3; import "udpa/annotations/status.proto"; +import "validate/validate.proto"; option java_package = "io.envoyproxy.envoy.config.core.v3"; option java_outer_classname = "ProxyProtocolProto"; @@ -12,6 +13,25 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Proxy protocol] +message ProxyProtocolPassThroughTLVs { + enum PassTLVsMatchType { + // Pass all TLVs. + INCLUDE_ALL = 0; + + // Pass specific TLVs defined in tlv_type. + INCLUDE = 1; + } + + // The strategy to pass through TLVs. Default is INCLUDE_ALL. + // If INCLUDE_ALL is set, all TLVs will be passed through no matter the tlv_type field. + PassTLVsMatchType match_type = 1; + + // The TLV types that are applied based on match_type. + // TLV type is defined as uint8_t in proxy protocol. See `the spec + // `_ for details. + repeated uint32 tlv_type = 2 [(validate.rules).repeated = {items {uint32 {lt: 256}}}]; +} + message ProxyProtocolConfig { enum Version { // PROXY protocol version 1. Human readable format. @@ -23,4 +43,8 @@ message ProxyProtocolConfig { // The PROXY protocol version to use. See https://www.haproxy.org/download/2.1/doc/proxy-protocol.txt for details Version version = 1; + + // This config controls which TLVs can be passed to filter state if it is Proxy Protocol + // V2 header. If there is no setting for this field, no TLVs will be passed through. + ProxyProtocolPassThroughTLVs pass_through_tlvs = 2; } diff --git a/api/envoy/extensions/filters/listener/proxy_protocol/v3/BUILD b/api/envoy/extensions/filters/listener/proxy_protocol/v3/BUILD index ee92fb652582..1c1a6f6b4423 100644 --- a/api/envoy/extensions/filters/listener/proxy_protocol/v3/BUILD +++ b/api/envoy/extensions/filters/listener/proxy_protocol/v3/BUILD @@ -5,5 +5,8 @@ load("@envoy_api//bazel:api_build_system.bzl", "api_proto_package") licenses(["notice"]) # Apache 2 api_proto_package( - deps = ["@com_github_cncf_udpa//udpa/annotations:pkg"], + deps = [ + "//envoy/config/core/v3:pkg", + "@com_github_cncf_udpa//udpa/annotations:pkg", + ], ) diff --git a/api/envoy/extensions/filters/listener/proxy_protocol/v3/proxy_protocol.proto b/api/envoy/extensions/filters/listener/proxy_protocol/v3/proxy_protocol.proto index 50472e568830..3fc5306831af 100644 --- a/api/envoy/extensions/filters/listener/proxy_protocol/v3/proxy_protocol.proto +++ b/api/envoy/extensions/filters/listener/proxy_protocol/v3/proxy_protocol.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package envoy.extensions.filters.listener.proxy_protocol.v3; +import "envoy/config/core/v3/proxy_protocol.proto"; + import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -58,4 +60,8 @@ message ProxyProtocol { // signature will timeout (Envoy is unable to differentiate these requests // from incomplete proxy protocol requests). bool allow_requests_without_proxy_protocol = 2; + + // This config controls which TLVs can be passed to filter state if it is Proxy Protocol + // V2 header. If there is no setting for this field, no TLVs will be passed through. + config.core.v3.ProxyProtocolPassThroughTLVs pass_through_tlvs = 3; } diff --git a/envoy/network/proxy_protocol.h b/envoy/network/proxy_protocol.h index 12f7323c4ccc..04267db8f7e3 100644 --- a/envoy/network/proxy_protocol.h +++ b/envoy/network/proxy_protocol.h @@ -1,13 +1,25 @@ #pragma once +#include +#include +#include + #include "envoy/network/address.h" namespace Envoy { namespace Network { +struct ProxyProtocolTLV { + const uint8_t type; + const std::vector value; +}; + +using ProxyProtocolTLVVector = std::vector; + struct ProxyProtocolData { const Network::Address::InstanceConstSharedPtr src_addr_; const Network::Address::InstanceConstSharedPtr dst_addr_; + const ProxyProtocolTLVVector tlv_vector_{}; std::string asStringForHash() const { return std::string(src_addr_ ? src_addr_->asString() : "null") + (dst_addr_ ? dst_addr_->asString() : "null"); diff --git a/source/extensions/common/proxy_protocol/proxy_protocol_header.cc b/source/extensions/common/proxy_protocol/proxy_protocol_header.cc index 025ced0dd606..0bbe17dd6d9e 100644 --- a/source/extensions/common/proxy_protocol/proxy_protocol_header.cc +++ b/source/extensions/common/proxy_protocol/proxy_protocol_header.cc @@ -43,7 +43,7 @@ void generateV1Header(const Network::Address::Ip& source_address, void generateV2Header(const std::string& src_addr, const std::string& dst_addr, uint32_t src_port, uint32_t dst_port, Network::Address::IpVersion ip_version, - Buffer::Instance& out) { + uint16_t extension_length, Buffer::Instance& out) { out.add(PROXY_PROTO_V2_SIGNATURE, PROXY_PROTO_V2_SIGNATURE_LEN); const uint8_t version_and_command = PROXY_PROTO_V2_VERSION << 4 | PROXY_PROTO_V2_ONBEHALF_OF; @@ -61,11 +61,15 @@ void generateV2Header(const std::string& src_addr, const std::string& dst_addr, address_family_and_protocol |= PROXY_PROTO_V2_TRANSPORT_STREAM; out.add(&address_family_and_protocol, 1); - uint8_t addr_length[2]{0, 0}; + // Number of following bytes part of the header in V2 protocol. + uint16_t addr_length; + uint16_t addr_length_n; // Network byte order + switch (ip_version) { case Network::Address::IpVersion::v4: { - addr_length[1] = PROXY_PROTO_V2_ADDR_LEN_INET; - out.add(addr_length, 2); + addr_length = PROXY_PROTO_V2_ADDR_LEN_INET + extension_length; + addr_length_n = htons(addr_length); + out.add(&addr_length_n, 2); const uint32_t net_src_addr = Network::Address::Ipv4Instance(src_addr, src_port).ip()->ipv4()->address(); const uint32_t net_dst_addr = @@ -75,8 +79,9 @@ void generateV2Header(const std::string& src_addr, const std::string& dst_addr, break; } case Network::Address::IpVersion::v6: { - addr_length[1] = PROXY_PROTO_V2_ADDR_LEN_INET6; - out.add(addr_length, 2); + addr_length = PROXY_PROTO_V2_ADDR_LEN_INET6 + extension_length; + addr_length_n = htons(addr_length); + out.add(&addr_length_n, 2); const absl::uint128 net_src_addr = Network::Address::Ipv6Instance(src_addr, src_port).ip()->ipv6()->address(); const absl::uint128 net_dst_addr = @@ -93,10 +98,51 @@ void generateV2Header(const std::string& src_addr, const std::string& dst_addr, out.add(&net_dst_port, 2); } +void generateV2Header(const std::string& src_addr, const std::string& dst_addr, uint32_t src_port, + uint32_t dst_port, Network::Address::IpVersion ip_version, + Buffer::Instance& out) { + generateV2Header(src_addr, dst_addr, src_port, dst_port, ip_version, 0, out); +} + void generateV2Header(const Network::Address::Ip& source_address, const Network::Address::Ip& dest_address, Buffer::Instance& out) { generateV2Header(source_address.addressAsString(), dest_address.addressAsString(), - source_address.port(), dest_address.port(), source_address.version(), out); + source_address.port(), dest_address.port(), source_address.version(), 0, out); +} + +bool generateV2Header(const Network::ProxyProtocolData& proxy_proto_data, Buffer::Instance& out, + bool pass_all_tlvs, const absl::flat_hash_set& pass_through_tlvs) { + uint64_t extension_length = 0; + for (auto&& tlv : proxy_proto_data.tlv_vector_) { + if (!pass_all_tlvs && !pass_through_tlvs.contains(tlv.type)) { + continue; + } + extension_length += PROXY_PROTO_V2_TLV_TYPE_LENGTH_LEN + tlv.value.size(); + if (extension_length > std::numeric_limits::max()) { + ENVOY_LOG_MISC( + warn, "Generating Proxy Protocol V2 header: TLVs exceed length limit {}, already got {}", + std::numeric_limits::max(), extension_length); + return false; + } + } + + ASSERT(extension_length <= std::numeric_limits::max()); + const auto& src = *proxy_proto_data.src_addr_->ip(); + const auto& dst = *proxy_proto_data.dst_addr_->ip(); + generateV2Header(src.addressAsString(), dst.addressAsString(), src.port(), dst.port(), + src.version(), static_cast(extension_length), out); + + // Generate the TLV vector. + for (auto&& tlv : proxy_proto_data.tlv_vector_) { + if (!pass_all_tlvs && !pass_through_tlvs.contains(tlv.type)) { + continue; + } + out.add(&tlv.type, 1); + uint16_t size = htons(static_cast(tlv.value.size())); + out.add(&size, sizeof(uint16_t)); + out.add(&tlv.value.front(), tlv.value.size()); + } + return true; } void generateProxyProtoHeader(const envoy::config::core::v3::ProxyProtocolConfig& config, diff --git a/source/extensions/common/proxy_protocol/proxy_protocol_header.h b/source/extensions/common/proxy_protocol/proxy_protocol_header.h index 013c842ced20..a4a09f46f98c 100644 --- a/source/extensions/common/proxy_protocol/proxy_protocol_header.h +++ b/source/extensions/common/proxy_protocol/proxy_protocol_header.h @@ -5,6 +5,8 @@ #include "envoy/network/address.h" #include "envoy/network/connection.h" +#include "absl/container/flat_hash_set.h" + namespace Envoy { namespace Extensions { namespace Common { @@ -39,6 +41,8 @@ constexpr uint32_t PROXY_PROTO_V2_ADDR_LEN_INET = 12; constexpr uint32_t PROXY_PROTO_V2_ADDR_LEN_INET6 = 36; constexpr uint32_t PROXY_PROTO_V2_ADDR_LEN_UNIX = 216; +constexpr uint32_t PROXY_PROTO_V2_TLV_TYPE_LENGTH_LEN = 3; + // Generates the v1 PROXY protocol header and adds it to the specified buffer void generateV1Header(const std::string& src_addr, const std::string& dst_addr, uint32_t src_port, uint32_t dst_port, Network::Address::IpVersion ip_version, @@ -48,6 +52,9 @@ void generateV1Header(const Network::Address::Ip& source_address, // Generates the v2 PROXY protocol header and adds it to the specified buffer // TCP is assumed as the transport protocol +void generateV2Header(const std::string& src_addr, const std::string& dst_addr, uint32_t src_port, + uint32_t dst_port, Network::Address::IpVersion ip_version, + uint16_t extension_length, Buffer::Instance& out); void generateV2Header(const std::string& src_addr, const std::string& dst_addr, uint32_t src_port, uint32_t dst_port, Network::Address::IpVersion ip_version, Buffer::Instance& out); @@ -61,6 +68,10 @@ void generateProxyProtoHeader(const envoy::config::core::v3::ProxyProtocolConfig // Generates the v2 PROXY protocol local command header and adds it to the specified buffer void generateV2LocalHeader(Buffer::Instance& out); +// Generates the v2 PROXY protocol header including the TLV vector into the specified buffer. +bool generateV2Header(const Network::ProxyProtocolData& proxy_proto_data, Buffer::Instance& out, + bool pass_all_tlvs, const absl::flat_hash_set& pass_through_tlvs); + } // namespace ProxyProtocol } // namespace Common } // namespace Extensions diff --git a/source/extensions/filters/listener/proxy_protocol/BUILD b/source/extensions/filters/listener/proxy_protocol/BUILD index 7435277bdb7d..7ddeb6e56d9c 100644 --- a/source/extensions/filters/listener/proxy_protocol/BUILD +++ b/source/extensions/filters/listener/proxy_protocol/BUILD @@ -26,12 +26,15 @@ envoy_cc_library( "//source/common/buffer:buffer_lib", "//source/common/common:assert_lib", "//source/common/common:empty_string", + "//source/common/common:hex_lib", "//source/common/common:minimal_logger_lib", "//source/common/common:safe_memcpy_lib", "//source/common/common:utility_lib", "//source/common/network:address_lib", + "//source/common/network:proxy_protocol_filter_state_lib", "//source/common/network:utility_lib", "//source/extensions/common/proxy_protocol:proxy_protocol_header_lib", + "@envoy_api//envoy/config/core/v3:pkg_cc_proto", "@envoy_api//envoy/extensions/filters/listener/proxy_protocol/v3:pkg_cc_proto", ], ) diff --git a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.cc b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.cc index 22267d1e9c37..ce4df5066bd9 100644 --- a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.cc +++ b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.cc @@ -9,6 +9,7 @@ #include "envoy/common/exception.h" #include "envoy/common/platform.h" +#include "envoy/config/core/v3/proxy_protocol.pb.h" #include "envoy/event/dispatcher.h" #include "envoy/network/listen_socket.h" #include "envoy/stats/scope.h" @@ -17,12 +18,15 @@ #include "source/common/common/assert.h" #include "source/common/common/empty_string.h" #include "source/common/common/fmt.h" +#include "source/common/common/hex.h" #include "source/common/common/safe_memcpy.h" #include "source/common/common/utility.h" #include "source/common/network/address_impl.h" +#include "source/common/network/proxy_protocol_filter_state.h" #include "source/common/network/utility.h" #include "source/extensions/common/proxy_protocol/proxy_protocol_header.h" +using envoy::config::core::v3::ProxyProtocolPassThroughTLVs; using Envoy::Extensions::Common::ProxyProtocol::PROXY_PROTO_V1_SIGNATURE; using Envoy::Extensions::Common::ProxyProtocol::PROXY_PROTO_V1_SIGNATURE_LEN; using Envoy::Extensions::Common::ProxyProtocol::PROXY_PROTO_V2_ADDR_LEN_INET; @@ -47,10 +51,21 @@ Config::Config( Stats::Scope& scope, const envoy::extensions::filters::listener::proxy_protocol::v3::ProxyProtocol& proto_config) : stats_{ALL_PROXY_PROTOCOL_STATS(POOL_COUNTER(scope))}, - allow_requests_without_proxy_protocol_(proto_config.allow_requests_without_proxy_protocol()) { + allow_requests_without_proxy_protocol_(proto_config.allow_requests_without_proxy_protocol()), + pass_all_tlvs_(proto_config.has_pass_through_tlvs() + ? proto_config.pass_through_tlvs().match_type() == + ProxyProtocolPassThroughTLVs::INCLUDE_ALL + : false) { for (const auto& rule : proto_config.rules()) { tlv_types_[0xFF & rule.tlv_type()] = rule.on_tlv_present(); } + + if (proto_config.has_pass_through_tlvs() && + proto_config.pass_through_tlvs().match_type() == ProxyProtocolPassThroughTLVs::INCLUDE) { + for (const auto& tlv_type : proto_config.pass_through_tlvs().tlv_type()) { + pass_through_tlvs_.insert(0xFF & tlv_type); + } + } } const KeyValuePair* Config::isTlvTypeNeeded(uint8_t type) const { @@ -62,6 +77,13 @@ const KeyValuePair* Config::isTlvTypeNeeded(uint8_t type) const { return nullptr; } +bool Config::isPassThroughTlvTypeNeeded(uint8_t tlv_type) const { + if (pass_all_tlvs_) { + return true; + } + return pass_through_tlvs_.contains(tlv_type); +} + size_t Config::numberOfNeededTlvTypes() const { return tlv_types_.size(); } bool Config::allowRequestsWithoutProxyProtocol() const { @@ -119,6 +141,29 @@ ReadOrParseState Filter::parseBuffer(Network::ListenerFilterBuffer& buffer) { } } + if (proxy_protocol_header_.has_value() && + !cb_->filterState().hasData( + Network::ProxyProtocolFilterState::key())) { + if (!proxy_protocol_header_.value().local_command_) { + auto buf = reinterpret_cast(buffer.rawSlice().mem_); + ENVOY_LOG( + trace, + "Parsed proxy protocol header, length: {}, buffer: {}, TLV length: {}, TLV buffer: {}", + proxy_protocol_header_.value().wholeHeaderLength(), + Envoy::Hex::encode(buf, proxy_protocol_header_.value().wholeHeaderLength()), + proxy_protocol_header_.value().extensions_length_, + Envoy::Hex::encode(buf + proxy_protocol_header_.value().headerLengthWithoutExtension(), + proxy_protocol_header_.value().extensions_length_)); + } + + cb_->filterState().setData( + Network::ProxyProtocolFilterState::key(), + std::make_unique(Network::ProxyProtocolData{ + proxy_protocol_header_.value().remote_address_, + proxy_protocol_header_.value().local_address_, parsed_tlvs_}), + StreamInfo::FilterState::StateType::Mutable, StreamInfo::FilterState::LifeSpan::Connection); + } + if (proxy_protocol_header_.has_value() && !proxy_protocol_header_.value().local_command_) { // If this is a local_command, we are not to override address // Error check the source and destination fields. Most errors are caught by the address @@ -360,10 +405,11 @@ bool Filter::parseTlvs(const uint8_t* buf, size_t len) { } // Only save to dynamic metadata if this type of TLV is needed. + absl::string_view tlv_value(reinterpret_cast(buf + idx), tlv_value_length); auto key_value_pair = config_->isTlvTypeNeeded(tlv_type); if (nullptr != key_value_pair) { ProtobufWkt::Value metadata_value; - metadata_value.set_string_value(reinterpret_cast(buf + idx), tlv_value_length); + metadata_value.set_string_value(tlv_value.data(), tlv_value.size()); std::string metadata_key = key_value_pair->metadata_namespace().empty() ? "envoy.filters.listener.proxy_protocol" @@ -374,7 +420,15 @@ bool Filter::parseTlvs(const uint8_t* buf, size_t len) { metadata.mutable_fields()->insert({key_value_pair->key(), metadata_value}); cb_->setDynamicMetadata(metadata_key, metadata); } else { - ENVOY_LOG(trace, "proxy_protocol: Skip TLV of type {} since it's not needed", tlv_type); + ENVOY_LOG(trace, + "proxy_protocol: Skip TLV of type {} since it's not needed for dynamic metadata", + tlv_type); + } + + // Save TLVs to the filter state. + if (config_->isPassThroughTlvTypeNeeded(tlv_type)) { + ENVOY_LOG(trace, "proxy_protocol: Storing parsed TLV of type {} to filter state.", tlv_type); + parsed_tlvs_.push_back({tlv_type, {tlv_value.begin(), tlv_value.end()}}); } idx += tlv_value_length; @@ -390,9 +444,9 @@ ReadOrParseState Filter::readExtensions(Network::ListenerFilterBuffer& buffer) { return ReadOrParseState::TryAgainLater; } - if (proxy_protocol_header_.value().local_command_ || 0 == config_->numberOfNeededTlvTypes()) { - // Ignores the extensions if this is a local command or there's no TLV needs to be saved - // to metadata. Those will drained from the buffer in the end. + if (proxy_protocol_header_.value().local_command_) { + // Ignores the extensions if this is a local command. + // Those will drained from the buffer in the end. return ReadOrParseState::Done; } diff --git a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h index 4d594244755f..e7f6974626f1 100644 --- a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h +++ b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h @@ -60,6 +60,11 @@ class Config : public Logger::Loggable { */ size_t numberOfNeededTlvTypes() const; + /** + * Return true if the type of TLV is needed for pass-through. + */ + bool isPassThroughTlvTypeNeeded(uint8_t type) const; + /** * Filter configuration that determines if we should pass-through requests without * proxy protocol. Should only be configured to true for trusted downstreams. @@ -69,6 +74,8 @@ class Config : public Logger::Loggable { private: absl::flat_hash_map tlv_types_; const bool allow_requests_without_proxy_protocol_; + const bool pass_all_tlvs_; + absl::flat_hash_set pass_through_tlvs_{}; }; using ConfigSharedPtr = std::shared_ptr; @@ -133,6 +140,9 @@ class Filter : public Network::ListenerFilter, Logger::Loggable proxy_protocol_header_; size_t max_proxy_protocol_len_{MAX_PROXY_PROTO_LEN_V2}; + + // Store the parsed proxy protocol TLVs. + Network::ProxyProtocolTLVVector parsed_tlvs_; }; } // namespace ProxyProtocol diff --git a/source/extensions/transport_sockets/proxy_protocol/BUILD b/source/extensions/transport_sockets/proxy_protocol/BUILD index 2917ee384df4..4e91f1ad363b 100644 --- a/source/extensions/transport_sockets/proxy_protocol/BUILD +++ b/source/extensions/transport_sockets/proxy_protocol/BUILD @@ -31,6 +31,7 @@ envoy_cc_library( "//envoy/network:connection_interface", "//envoy/network:transport_socket_interface", "//source/common/buffer:buffer_lib", + "//source/common/common:hex_lib", "//source/common/common:scalar_to_byte_vector_lib", "//source/common/common:utility_lib", "//source/common/network:address_lib", diff --git a/source/extensions/transport_sockets/proxy_protocol/config.cc b/source/extensions/transport_sockets/proxy_protocol/config.cc index 9e62bc11fb1d..cf8f1e08b387 100644 --- a/source/extensions/transport_sockets/proxy_protocol/config.cc +++ b/source/extensions/transport_sockets/proxy_protocol/config.cc @@ -26,8 +26,8 @@ UpstreamProxyProtocolSocketConfigFactory::createTransportSocketFactory( outer_config.transport_socket(), context.messageValidationVisitor(), inner_config_factory); auto inner_transport_factory = inner_config_factory.createTransportSocketFactory(*inner_factory_config, context); - return std::make_unique(std::move(inner_transport_factory), - outer_config.config()); + return std::make_unique( + std::move(inner_transport_factory), outer_config.config(), context.scope()); } ProtobufTypes::MessagePtr UpstreamProxyProtocolSocketConfigFactory::createEmptyConfigProto() { diff --git a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc index 73d9c2272a45..537483186b21 100644 --- a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc +++ b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.cc @@ -6,6 +6,7 @@ #include "envoy/network/transport_socket.h" #include "source/common/buffer/buffer_impl.h" +#include "source/common/common/hex.h" #include "source/common/common/scalar_to_byte_vector.h" #include "source/common/common/utility.h" #include "source/common/network/address_impl.h" @@ -13,16 +14,34 @@ using envoy::config::core::v3::ProxyProtocolConfig; using envoy::config::core::v3::ProxyProtocolConfig_Version; +using envoy::config::core::v3::ProxyProtocolPassThroughTLVs; namespace Envoy { namespace Extensions { namespace TransportSockets { namespace ProxyProtocol { +UpstreamProxyProtocolStats GenerateUpstreamProxyProtocolStats(Stats::Scope& stats_scope) { + const char prefix[]{"upstream.proxyprotocol."}; + return {ALL_PROXY_PROTOCOL_TRANSPORT_SOCKET_STATS(POOL_COUNTER_PREFIX(stats_scope, prefix))}; +} + UpstreamProxyProtocolSocket::UpstreamProxyProtocolSocket( Network::TransportSocketPtr&& transport_socket, - Network::TransportSocketOptionsConstSharedPtr options, ProxyProtocolConfig_Version version) - : PassthroughSocket(std::move(transport_socket)), options_(options), version_(version) {} + Network::TransportSocketOptionsConstSharedPtr options, ProxyProtocolConfig config, + Stats::Scope& scope) + : PassthroughSocket(std::move(transport_socket)), options_(options), version_(config.version()), + stats_(GenerateUpstreamProxyProtocolStats(scope)), + pass_all_tlvs_(config.has_pass_through_tlvs() ? config.pass_through_tlvs().match_type() == + ProxyProtocolPassThroughTLVs::INCLUDE_ALL + : false) { + if (config.has_pass_through_tlvs() && + config.pass_through_tlvs().match_type() == ProxyProtocolPassThroughTLVs::INCLUDE) { + for (const auto& tlv_type : config.pass_through_tlvs().tlv_type()) { + pass_through_tlvs_.insert(0xFF & tlv_type); + } + } +} void UpstreamProxyProtocolSocket::setTransportSocketCallbacks( Network::TransportSocketCallbacks& callbacks) { @@ -66,13 +85,26 @@ void UpstreamProxyProtocolSocket::generateHeaderV1() { Common::ProxyProtocol::generateV1Header(*src_addr->ip(), *dst_addr->ip(), header_buffer_); } +namespace { +std::string toHex(const Buffer::Instance& buffer) { + std::string bufferStr = buffer.toString(); + return Hex::encode(reinterpret_cast(bufferStr.data()), bufferStr.length()); +} +} // namespace + void UpstreamProxyProtocolSocket::generateHeaderV2() { if (!options_ || !options_->proxyProtocolOptions().has_value()) { Common::ProxyProtocol::generateV2LocalHeader(header_buffer_); } else { const auto options = options_->proxyProtocolOptions().value(); - Common::ProxyProtocol::generateV2Header(*options.src_addr_->ip(), *options.dst_addr_->ip(), - header_buffer_); + if (!Common::ProxyProtocol::generateV2Header(options, header_buffer_, pass_all_tlvs_, + pass_through_tlvs_)) { + // There is a warn log in generateV2Header method. + stats_.v2_tlvs_exceed_max_length_.inc(); + } + + ENVOY_LOG(trace, "generated proxy protocol v2 header, length: {}, buffer: {}", + header_buffer_.length(), toHex(header_buffer_)); } } @@ -108,8 +140,9 @@ void UpstreamProxyProtocolSocket::onConnected() { } UpstreamProxyProtocolSocketFactory::UpstreamProxyProtocolSocketFactory( - Network::UpstreamTransportSocketFactoryPtr transport_socket_factory, ProxyProtocolConfig config) - : PassthroughFactory(std::move(transport_socket_factory)), config_(config) {} + Network::UpstreamTransportSocketFactoryPtr transport_socket_factory, ProxyProtocolConfig config, + Stats::Scope& scope) + : PassthroughFactory(std::move(transport_socket_factory)), config_(config), scope_(scope) {} Network::TransportSocketPtr UpstreamProxyProtocolSocketFactory::createTransportSocket( Network::TransportSocketOptionsConstSharedPtr options, @@ -118,8 +151,8 @@ Network::TransportSocketPtr UpstreamProxyProtocolSocketFactory::createTransportS if (inner_socket == nullptr) { return nullptr; } - return std::make_unique(std::move(inner_socket), options, - config_.version()); + return std::make_unique(std::move(inner_socket), options, config_, + scope_); } void UpstreamProxyProtocolSocketFactory::hashKey( diff --git a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h index 583f79bca273..1a92423c9a1c 100644 --- a/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h +++ b/source/extensions/transport_sockets/proxy_protocol/proxy_protocol.h @@ -3,6 +3,7 @@ #include "envoy/config/core/v3/proxy_protocol.pb.h" #include "envoy/network/connection.h" #include "envoy/network/transport_socket.h" +#include "envoy/stats/stats.h" #include "source/common/buffer/buffer_impl.h" #include "source/common/common/logger.h" @@ -16,12 +17,23 @@ namespace Extensions { namespace TransportSockets { namespace ProxyProtocol { +#define ALL_PROXY_PROTOCOL_TRANSPORT_SOCKET_STATS(COUNTER) \ + /* Upstream events counter. */ \ + COUNTER(v2_tlvs_exceed_max_length) + +/** + * Wrapper struct for upstream ProxyProtocol stats. @see stats_macros.h + */ +struct UpstreamProxyProtocolStats { + ALL_PROXY_PROTOCOL_TRANSPORT_SOCKET_STATS(GENERATE_COUNTER_STRUCT) +}; + class UpstreamProxyProtocolSocket : public TransportSockets::PassthroughSocket, public Logger::Loggable { public: UpstreamProxyProtocolSocket(Network::TransportSocketPtr&& transport_socket, Network::TransportSocketOptionsConstSharedPtr options, - ProxyProtocolConfig_Version version); + ProxyProtocolConfig config, Stats::Scope& scope); void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override; Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override; @@ -37,13 +49,16 @@ class UpstreamProxyProtocolSocket : public TransportSockets::PassthroughSocket, Network::TransportSocketCallbacks* callbacks_{}; Buffer::OwnedImpl header_buffer_{}; ProxyProtocolConfig_Version version_{ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1}; + UpstreamProxyProtocolStats stats_; + const bool pass_all_tlvs_; + absl::flat_hash_set pass_through_tlvs_{}; }; class UpstreamProxyProtocolSocketFactory : public PassthroughFactory { public: UpstreamProxyProtocolSocketFactory( Network::UpstreamTransportSocketFactoryPtr transport_socket_factory, - ProxyProtocolConfig config); + ProxyProtocolConfig config, Stats::Scope& scope); // Network::UpstreamTransportSocketFactory Network::TransportSocketPtr @@ -54,6 +69,7 @@ class UpstreamProxyProtocolSocketFactory : public PassthroughFactory { private: ProxyProtocolConfig config_; + Stats::Scope& scope_; }; } // namespace ProxyProtocol diff --git a/test/extensions/common/proxy_protocol/proxy_protocol_header_test.cc b/test/extensions/common/proxy_protocol/proxy_protocol_header_test.cc index af6d7e25f7b0..170810998f27 100644 --- a/test/extensions/common/proxy_protocol/proxy_protocol_header_test.cc +++ b/test/extensions/common/proxy_protocol/proxy_protocol_header_test.cc @@ -1,6 +1,8 @@ #include "envoy/network/address.h" #include "source/common/buffer/buffer_impl.h" +#include "source/common/common/logger.h" +#include "source/common/network/address_impl.h" #include "source/extensions/common/proxy_protocol/proxy_protocol_header.h" #include "test/mocks/network/connection.h" @@ -15,6 +17,8 @@ namespace Common { namespace ProxyProtocol { namespace { +using namespace std::literals::string_literals; + TEST(ProxyProtocolHeaderTest, GeneratesV1IPv4Header) { const auto expectedHeaderStr = "PROXY TCP4 174.2.2.222 172.0.0.1 50000 80\r\n"; const Buffer::OwnedImpl expectedBuff(expectedHeaderStr); @@ -116,6 +120,97 @@ TEST(ProxyProtocolHeaderTest, GeneratesV2LocalHeader) { EXPECT_TRUE(TestUtility::buffersEqual(expectedBuff, buff)); } +TEST(ProxyProtocolHeaderTest, GeneratesV2IPv4HeaderWithTLVPassAll) { + const uint8_t v2_protocol[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, + 0x0a, 0x21, 0x11, 0x00, 0x11, 0x01, 0x02, 0x03, 0x04, 0x00, 0x01, + 0x01, 0x02, 0x03, 0x05, 0x02, 0x01, 0x05, 0x00, 0x02, 0x06, 0x07}; + + const Buffer::OwnedImpl expectedBuff(v2_protocol, sizeof(v2_protocol)); + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("1.2.3.4", 773)); + auto dst_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("0.1.1.2", 513)); + Network::ProxyProtocolTLV tlv{0x5, {0x06, 0x07}}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, {tlv}}; + Buffer::OwnedImpl buff{}; + + ASSERT_TRUE(generateV2Header(proxy_proto_data, buff, true, {})); + + EXPECT_TRUE(TestUtility::buffersEqual(expectedBuff, buff)); +} + +TEST(ProxyProtocolHeaderTest, GeneratesV2IPv4HeaderWithTLVPassEmpty) { + const uint8_t v2_protocol[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x02, 0x01}; + + const Buffer::OwnedImpl expectedBuff(v2_protocol, sizeof(v2_protocol)); + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("1.2.3.4", 773)); + auto dst_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("0.1.1.2", 513)); + Network::ProxyProtocolTLV tlv{0x5, {0x06, 0x07}}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, {tlv}}; + Buffer::OwnedImpl buff{}; + + ASSERT_TRUE(generateV2Header(proxy_proto_data, buff, false, {})); + + EXPECT_TRUE(TestUtility::buffersEqual(expectedBuff, buff)); +} + +TEST(ProxyProtocolHeaderTest, GeneratesV2IPv4HeaderWithTLVPassSpecific) { + const uint8_t v2_protocol[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, + 0x0a, 0x21, 0x11, 0x00, 0x11, 0x01, 0x02, 0x03, 0x04, 0x00, 0x01, + 0x01, 0x02, 0x03, 0x05, 0x02, 0x01, 0x05, 0x00, 0x02, 0x06, 0x07}; + + const Buffer::OwnedImpl expectedBuff(v2_protocol, sizeof(v2_protocol)); + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("1.2.3.4", 773)); + auto dst_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("0.1.1.2", 513)); + Network::ProxyProtocolTLV tlv{0x5, {0x06, 0x07}}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, {tlv}}; + Buffer::OwnedImpl buff{}; + + ASSERT_TRUE(generateV2Header(proxy_proto_data, buff, false, {0x5})); + + EXPECT_TRUE(TestUtility::buffersEqual(expectedBuff, buff)); +} + +TEST(ProxyProtocolHeaderTest, GeneratesV2IPv6HeaderWithTLV) { + const uint8_t v2_protocol[] = { + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x21, 0x00, + 0x29, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x04, 0x00, 0x01, 0x01, 0x00, 0x02, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x05, 0x00, 0x02, 0x06, 0x07}; + const Buffer::OwnedImpl expectedBuff(v2_protocol, sizeof(v2_protocol)); + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv6Instance("1:2:3::4", 8)); + auto dst_addr = Network::Address::InstanceConstSharedPtr( + new Network::Address::Ipv6Instance("1:100:200:3::", 2)); + Network::ProxyProtocolTLV tlv{0x5, {0x06, 0x07}}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, {tlv}}; + + Buffer::OwnedImpl buff{}; + ASSERT_TRUE(generateV2Header(proxy_proto_data, buff, true, {})); + + EXPECT_TRUE(TestUtility::buffersEqual(expectedBuff, buff)); +} + +TEST(ProxyProtocolHeaderTest, GeneratesV2WithTLVExceedingLengthLimit) { + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("1.2.3.4", 773)); + auto dst_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("0.1.1.2", 513)); + const std::string long_tlv(65536, 'a'); + Network::ProxyProtocolTLV tlv{0x5, std::vector(long_tlv.begin(), long_tlv.end())}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, {tlv}}; + Buffer::OwnedImpl buff{}; + + EXPECT_LOG_CONTAINS("warn", "Generating Proxy Protocol V2 header: TLVs exceed length limit 65535", + generateV2Header(proxy_proto_data, buff, true, {})); +} + } // namespace } // namespace ProxyProtocol } // namespace Common diff --git a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc index e19f8f5f23c8..eadb9b637b8d 100644 --- a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc +++ b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc @@ -4,6 +4,7 @@ #include "envoy/common/platform.h" #include "envoy/config/core/v3/base.pb.h" +#include "envoy/config/core/v3/proxy_protocol.pb.h" #include "envoy/stats/scope.h" #include "source/common/api/os_sys_calls_impl.h" @@ -11,6 +12,7 @@ #include "source/common/event/dispatcher_impl.h" #include "source/common/network/connection_balancer_impl.h" #include "source/common/network/listen_socket_impl.h" +#include "source/common/network/proxy_protocol_filter_state.h" #include "source/common/network/raw_buffer_socket.h" #include "source/common/network/tcp_listener_impl.h" #include "source/common/network/utility.h" @@ -31,6 +33,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +using envoy::config::core::v3::ProxyProtocolPassThroughTLVs; using Envoy::Extensions::Common::ProxyProtocol::PROXY_PROTO_V1_SIGNATURE_LEN; using Envoy::Extensions::Common::ProxyProtocol::PROXY_PROTO_V2_SIGNATURE_LEN; using testing::_; @@ -1556,6 +1559,129 @@ TEST_P(ProxyProtocolTest, V2IncompleteTLV) { expectProxyProtoError(); } +TEST_P(ProxyProtocolTest, V2ExtractTLVToFilterState) { + // A well-formed ipv4/tcp with a pair of TLV extensions is accepted + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x1a, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02}; + constexpr uint8_t tlv1[] = {0x0, 0x0, 0x1, 0xff}; + constexpr uint8_t tlv_type_authority[] = {0x02, 0x00, 0x07, 0x66, 0x6f, + 0x6f, 0x2e, 0x63, 0x6f, 0x6d}; + constexpr uint8_t data[] = {'D', 'A', 'T', 'A'}; + + envoy::extensions::filters::listener::proxy_protocol::v3::ProxyProtocol proto_config; + auto rule = proto_config.add_rules(); + rule->set_tlv_type(0x02); + rule->mutable_on_tlv_present()->set_key("PP2 type authority"); + + auto pass_through_tlvs = proto_config.mutable_pass_through_tlvs(); + pass_through_tlvs->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE_ALL); + + connect(true, &proto_config); + write(buffer, sizeof(buffer)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + write(tlv1, sizeof(tlv1)); + write(tlv_type_authority, sizeof(tlv_type_authority)); + write(data, sizeof(data)); + expectData("DATA"); + + auto& filter_state = server_connection_->streamInfo().filterState(); + EXPECT_TRUE(filter_state->hasData( + Network::ProxyProtocolFilterState::key())); + + const auto& proxy_proto_data = filter_state + ->getDataReadOnly( + Network::ProxyProtocolFilterState::key()) + ->value(); + + EXPECT_EQ(2, proxy_proto_data.tlv_vector_.size()); + EXPECT_EQ(0x0, proxy_proto_data.tlv_vector_[0].type); + EXPECT_EQ(0xFF, proxy_proto_data.tlv_vector_[0].value[0]); + EXPECT_EQ(1, proxy_proto_data.tlv_vector_[0].value.size()); + EXPECT_EQ(0x02, proxy_proto_data.tlv_vector_[1].type); + EXPECT_EQ("foo.com", std::string(proxy_proto_data.tlv_vector_[1].value.begin(), + proxy_proto_data.tlv_vector_[1].value.end())); + + disconnect(); +} + +TEST_P(ProxyProtocolTest, V2ExtractTLVToFilterStateIncludeEmpty) { + // A well-formed ipv4/tcp with a pair of TLV extensions is accepted + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x1a, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02}; + constexpr uint8_t tlv1[] = {0x0, 0x0, 0x1, 0xff}; + constexpr uint8_t tlv_type_authority[] = {0x02, 0x00, 0x07, 0x66, 0x6f, + 0x6f, 0x2e, 0x63, 0x6f, 0x6d}; + constexpr uint8_t data[] = {'D', 'A', 'T', 'A'}; + envoy::extensions::filters::listener::proxy_protocol::v3::ProxyProtocol proto_config; + + auto pass_through_tlvs = proto_config.mutable_pass_through_tlvs(); + pass_through_tlvs->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE); + + connect(true, &proto_config); + write(buffer, sizeof(buffer)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + write(tlv1, sizeof(tlv1)); + write(tlv_type_authority, sizeof(tlv_type_authority)); + write(data, sizeof(data)); + expectData("DATA"); + + auto& filter_state = server_connection_->streamInfo().filterState(); + EXPECT_TRUE(filter_state->hasData( + Network::ProxyProtocolFilterState::key())); + + const auto& proxy_proto_data = filter_state + ->getDataReadOnly( + Network::ProxyProtocolFilterState::key()) + ->value(); + + EXPECT_EQ(0, proxy_proto_data.tlv_vector_.size()); + disconnect(); +} + +TEST_P(ProxyProtocolTest, V2ExtractTLVToFilterStateIncludeTlV) { + // A well-formed ipv4/tcp with a pair of TLV extensions is accepted + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x1a, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02}; + constexpr uint8_t tlv1[] = {0x0, 0x0, 0x1, 0xff}; + constexpr uint8_t tlv_type_authority[] = {0x02, 0x00, 0x07, 0x66, 0x6f, + 0x6f, 0x2e, 0x63, 0x6f, 0x6d}; + constexpr uint8_t data[] = {'D', 'A', 'T', 'A'}; + envoy::extensions::filters::listener::proxy_protocol::v3::ProxyProtocol proto_config; + + auto pass_through_tlvs = proto_config.mutable_pass_through_tlvs(); + pass_through_tlvs->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE); + pass_through_tlvs->add_tlv_type(0x02); + + connect(true, &proto_config); + write(buffer, sizeof(buffer)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + write(tlv1, sizeof(tlv1)); + write(tlv_type_authority, sizeof(tlv_type_authority)); + write(data, sizeof(data)); + expectData("DATA"); + + auto& filter_state = server_connection_->streamInfo().filterState(); + EXPECT_TRUE(filter_state->hasData( + Network::ProxyProtocolFilterState::key())); + + const auto& proxy_proto_data = filter_state + ->getDataReadOnly( + Network::ProxyProtocolFilterState::key()) + ->value(); + + EXPECT_EQ(1, proxy_proto_data.tlv_vector_.size()); + EXPECT_EQ(0x02, proxy_proto_data.tlv_vector_[0].type); + EXPECT_EQ("foo.com", std::string(proxy_proto_data.tlv_vector_[0].value.begin(), + proxy_proto_data.tlv_vector_[0].value.end())); + disconnect(); +} + TEST_P(ProxyProtocolTest, MalformedProxyLine) { connect(false); diff --git a/test/extensions/transport_sockets/proxy_protocol/BUILD b/test/extensions/transport_sockets/proxy_protocol/BUILD index efb4f98e3739..cacd476c36e7 100644 --- a/test/extensions/transport_sockets/proxy_protocol/BUILD +++ b/test/extensions/transport_sockets/proxy_protocol/BUILD @@ -32,11 +32,13 @@ envoy_extension_cc_test( srcs = ["proxy_protocol_integration_test.cc"], extension_names = ["envoy.transport_sockets.upstream_proxy_protocol"], deps = [ + "//source/extensions/filters/listener/proxy_protocol:config", "//source/extensions/filters/network/tcp_proxy:config", "//source/extensions/transport_sockets/proxy_protocol:upstream_config", "//test/integration:http_integration_lib", "//test/integration:integration_lib", "@envoy_api//envoy/config/core/v3:pkg_cc_proto", + "@envoy_api//envoy/extensions/filters/listener/proxy_protocol/v3:pkg_cc_proto", "@envoy_api//envoy/extensions/transport_sockets/proxy_protocol/v3:pkg_cc_proto", ], ) diff --git a/test/extensions/transport_sockets/proxy_protocol/proxy_protocol_integration_test.cc b/test/extensions/transport_sockets/proxy_protocol/proxy_protocol_integration_test.cc index 8e523717c301..af6675d5f3fe 100644 --- a/test/extensions/transport_sockets/proxy_protocol/proxy_protocol_integration_test.cc +++ b/test/extensions/transport_sockets/proxy_protocol/proxy_protocol_integration_test.cc @@ -1,11 +1,13 @@ #include "envoy/config/core/v3/base.pb.h" #include "envoy/config/core/v3/health_check.pb.h" #include "envoy/config/core/v3/proxy_protocol.pb.h" +#include "envoy/extensions/filters/listener/proxy_protocol/v3/proxy_protocol.pb.h" #include "envoy/extensions/transport_sockets/proxy_protocol/v3/upstream_proxy_protocol.pb.h" #include "test/integration/http_integration.h" #include "test/integration/integration.h" +using envoy::config::core::v3::ProxyProtocolPassThroughTLVs; namespace Envoy { namespace { @@ -398,5 +400,238 @@ TEST_P(ProxyProtocolHttpIntegrationTest, TestProxyProtocolHealthCheck) { ASSERT_TRUE(fake_upstream_health_connection->waitForDisconnect()); } +class ProxyProtocolTLVsIntegrationTest : public testing::TestWithParam, + public BaseIntegrationTest { +public: + ProxyProtocolTLVsIntegrationTest() + : BaseIntegrationTest(GetParam(), ConfigHelper::tcpProxyConfig()){}; + + void TearDown() override { + test_server_.reset(); + fake_upstream_connection_.reset(); + fake_upstreams_.clear(); + } + + void setup(bool pass_all_tlvs, const std::vector& tlvs_listener, + const std::vector& tlvs_upstream) { + pass_all_tlvs_ = pass_all_tlvs; + tlvs_listener_.assign(tlvs_listener.begin(), tlvs_listener.end()); + tlvs_upstream_.assign(tlvs_upstream.begin(), tlvs_upstream.end()); + } + + void initialize() override { + config_helper_.addConfigModifier([this](envoy::config::bootstrap::v3::Bootstrap& bootstrap) { + envoy::extensions::filters::listener::proxy_protocol::v3::ProxyProtocol proxy_protocol; + auto pass_through_tlvs = proxy_protocol.mutable_pass_through_tlvs(); + if (pass_all_tlvs_) { + pass_through_tlvs->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE_ALL); + } else { + pass_through_tlvs->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE); + for (const auto& tlv_type : tlvs_listener_) { + pass_through_tlvs->add_tlv_type(tlv_type); + } + } + + auto* listener = bootstrap.mutable_static_resources()->mutable_listeners(0); + auto* ppv_filter = listener->add_listener_filters(); + ppv_filter->set_name("envoy.listener.proxy_protocol"); + ppv_filter->mutable_typed_config()->PackFrom(proxy_protocol); + }); + + config_helper_.addConfigModifier([this](envoy::config::bootstrap::v3::Bootstrap& bootstrap) { + auto* transport_socket = + bootstrap.mutable_static_resources()->mutable_clusters(0)->mutable_transport_socket(); + transport_socket->set_name("envoy.transport_sockets.upstream_proxy_protocol"); + envoy::config::core::v3::TransportSocket inner_socket; + inner_socket.set_name("envoy.transport_sockets.raw_buffer"); + + envoy::config::core::v3::ProxyProtocolConfig proxy_protocol; + proxy_protocol.set_version(envoy::config::core::v3::ProxyProtocolConfig::V2); + auto pass_through_tlvs = proxy_protocol.mutable_pass_through_tlvs(); + + if (pass_all_tlvs_) { + pass_through_tlvs->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE_ALL); + } else { + pass_through_tlvs->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE); + for (const auto& tlv_type : tlvs_upstream_) { + pass_through_tlvs->add_tlv_type(tlv_type); + } + } + + envoy::extensions::transport_sockets::proxy_protocol::v3::ProxyProtocolUpstreamTransport + proxy_proto_transport; + proxy_proto_transport.mutable_transport_socket()->MergeFrom(inner_socket); + proxy_proto_transport.mutable_config()->MergeFrom(proxy_protocol); + transport_socket->mutable_typed_config()->PackFrom(proxy_proto_transport); + }); + + BaseIntegrationTest::initialize(); + } + + FakeRawConnectionPtr fake_upstream_connection_; + +private: + bool pass_all_tlvs_ = false; + std::vector tlvs_listener_; + std::vector tlvs_upstream_; +}; + +INSTANTIATE_TEST_SUITE_P(IpVersions, ProxyProtocolTLVsIntegrationTest, + testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), + TestUtility::ipTestParamsToString); + +// This test adding the listener proxy protocol filter and upstream proxy filter, the TLVs +// are passed by listener and re-generated in transport socket based on API config. +TEST_P(ProxyProtocolTLVsIntegrationTest, TestV2TLVProxyProtocolPassSepcificTLVs) { + setup(false, {0x05, 0x06}, {0x06}); + initialize(); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + std::string observed_data; + if (GetParam() == Envoy::Network::Address::IpVersion::v4) { + // 2 TLVs are included: + // 0x05, 0x00, 0x02, 0x06, 0x07 + // 0x06, 0x00, 0x02, 0x11, 0x12 + const uint8_t v2_protocol[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x16, 0x7f, 0x00, 0x00, 0x01, + 0x7f, 0x00, 0x00, 0x01, 0x03, 0x05, 0x02, 0x01, 0x05, 0x00, + 0x02, 0x06, 0x07, 0x06, 0x00, 0x02, 0x11, 0x12}; + Buffer::OwnedImpl buffer(v2_protocol, sizeof(v2_protocol)); + ASSERT_TRUE(tcp_client->write(buffer.toString())); + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForData(33, &observed_data)); + + // - signature + // - version and command type, address family and protocol, length of addresses + // - src address, dest address + auto header_start = "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a\ + \x21\x11\x00\x16\ + \x7f\x00\x00\x01\x7f\x00\x00\x01"; + EXPECT_THAT(observed_data, testing::StartsWith(header_start)); + + // Only tlv: 0x06, 0x00, 0x02, 0x11, 0x12 is sent to upstream. + EXPECT_EQ(static_cast(observed_data[28]), 0x06); + EXPECT_EQ(static_cast(observed_data[29]), 0x00); + EXPECT_EQ(static_cast(observed_data[30]), 0x02); + EXPECT_EQ(static_cast(observed_data[31]), 0x11); + EXPECT_EQ(static_cast(observed_data[32]), 0x12); + } else if (GetParam() == Envoy::Network::Address::IpVersion::v6) { + // 2 TLVs are included: + // 0x05, 0x00, 0x02, 0x06, 0x07 + // 0x06, 0x00, 0x02, 0x09, 0x0A + const uint8_t v2_protocol_ipv6[] = { + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, + 0x21, 0x00, 0x2E, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0x00, 0x02, + 0x05, 0x00, 0x02, 0x06, 0x07, 0x06, 0x00, 0x02, 0x09, 0x0A}; + Buffer::OwnedImpl buffer(v2_protocol_ipv6, sizeof(v2_protocol_ipv6)); + ASSERT_TRUE(tcp_client->write(buffer.toString())); + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection_)); + + ASSERT_TRUE(fake_upstream_connection_->waitForData(57, &observed_data)); + // - signature + // - version and command type, address family and protocol, length of addresses + // - src address + // - dest address + auto header_start = "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a\ + \x21\x21\x00\x2E\ + \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\ + \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"; + EXPECT_THAT(observed_data, testing::StartsWith(header_start)); + + // Only tlv: 0x06, 0x00, 0x02, 0x09, 0x0A is sent to upstream. + EXPECT_EQ(static_cast(observed_data[52]), 0x06); + EXPECT_EQ(static_cast(observed_data[53]), 0x00); + EXPECT_EQ(static_cast(observed_data[54]), 0x02); + EXPECT_EQ(static_cast(observed_data[55]), 0x09); + EXPECT_EQ(static_cast(observed_data[56]), 0x0A); + } + + tcp_client->close(); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); +} + +TEST_P(ProxyProtocolTLVsIntegrationTest, TestV2TLVProxyProtocolPassAll) { + setup(true, {}, {}); + initialize(); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + ; + std::string observed_data; + if (GetParam() == Envoy::Network::Address::IpVersion::v4) { + // 2 TLVs are included: + // 0x05, 0x00, 0x02, 0x06, 0x07 + // 0x06, 0x00, 0x02, 0x11, 0x12 + const uint8_t v2_protocol[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x16, 0x7f, 0x00, 0x00, 0x01, + 0x7f, 0x00, 0x00, 0x01, 0x03, 0x05, 0x02, 0x01, 0x05, 0x00, + 0x02, 0x06, 0x07, 0x06, 0x00, 0x02, 0x11, 0x12}; + Buffer::OwnedImpl buffer(v2_protocol, sizeof(v2_protocol)); + ASSERT_TRUE(tcp_client->write(buffer.toString())); + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForData(38, &observed_data)); + + // - signature + // - version and command type, address family and protocol, length of addresses + // - src address, dest address + auto header_start = "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a\ + \x21\x11\x00\x16\ + \x7f\x00\x00\x01\x7f\x00\x00\x01"; + EXPECT_THAT(observed_data, testing::StartsWith(header_start)); + + // Only tlv: 0x06, 0x00, 0x02, 0x11, 0x12 is sent to upstream. + EXPECT_EQ(static_cast(observed_data[28]), 0x05); + EXPECT_EQ(static_cast(observed_data[29]), 0x00); + EXPECT_EQ(static_cast(observed_data[30]), 0x02); + EXPECT_EQ(static_cast(observed_data[31]), 0x06); + EXPECT_EQ(static_cast(observed_data[32]), 0x07); + EXPECT_EQ(static_cast(observed_data[33]), 0x06); + EXPECT_EQ(static_cast(observed_data[34]), 0x00); + EXPECT_EQ(static_cast(observed_data[35]), 0x02); + EXPECT_EQ(static_cast(observed_data[36]), 0x11); + EXPECT_EQ(static_cast(observed_data[37]), 0x12); + } else if (GetParam() == Envoy::Network::Address::IpVersion::v6) { + // 2 TLVs are included: + // 0x05, 0x00, 0x02, 0x06, 0x07 + // 0x06, 0x00, 0x02, 0x09, 0x0A + const uint8_t v2_protocol_ipv6[] = { + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, + 0x21, 0x00, 0x2E, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0x00, 0x02, + 0x05, 0x00, 0x02, 0x06, 0x07, 0x06, 0x00, 0x02, 0x09, 0x0A}; + Buffer::OwnedImpl buffer(v2_protocol_ipv6, sizeof(v2_protocol_ipv6)); + ASSERT_TRUE(tcp_client->write(buffer.toString())); + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection_)); + + ASSERT_TRUE(fake_upstream_connection_->waitForData(62, &observed_data)); + // - signature + // - version and command type, address family and protocol, length of addresses + // - src address + // - dest address + auto header_start = "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a\ + \x21\x21\x00\x2E\ + \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\ + \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"; + EXPECT_THAT(observed_data, testing::StartsWith(header_start)); + + // Only tlv: 0x06, 0x00, 0x02, 0x09, 0x0A is sent to upstream. + EXPECT_EQ(static_cast(observed_data[52]), 0x05); + EXPECT_EQ(static_cast(observed_data[53]), 0x00); + EXPECT_EQ(static_cast(observed_data[54]), 0x02); + EXPECT_EQ(static_cast(observed_data[55]), 0x06); + EXPECT_EQ(static_cast(observed_data[56]), 0x07); + EXPECT_EQ(static_cast(observed_data[57]), 0x06); + EXPECT_EQ(static_cast(observed_data[58]), 0x00); + EXPECT_EQ(static_cast(observed_data[59]), 0x02); + EXPECT_EQ(static_cast(observed_data[60]), 0x09); + EXPECT_EQ(static_cast(observed_data[61]), 0x0A); + } + + tcp_client->close(); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); +} + } // namespace } // namespace Envoy diff --git a/test/extensions/transport_sockets/proxy_protocol/proxy_protocol_test.cc b/test/extensions/transport_sockets/proxy_protocol/proxy_protocol_test.cc index ac11d61b95cb..53d1646b50f9 100644 --- a/test/extensions/transport_sockets/proxy_protocol/proxy_protocol_test.cc +++ b/test/extensions/transport_sockets/proxy_protocol/proxy_protocol_test.cc @@ -25,6 +25,7 @@ using testing::ReturnRef; using envoy::config::core::v3::ProxyProtocolConfig; using envoy::config::core::v3::ProxyProtocolConfig_Version; +using envoy::config::core::v3::ProxyProtocolPassThroughTLVs; namespace Envoy { namespace Extensions { @@ -34,13 +35,13 @@ namespace { class ProxyProtocolTest : public testing::Test { public: - void initialize(ProxyProtocolConfig_Version version, + void initialize(ProxyProtocolConfig& config, Network::TransportSocketOptionsConstSharedPtr socket_options) { auto inner_socket = std::make_unique>(); inner_socket_ = inner_socket.get(); ON_CALL(transport_callbacks_, ioHandle()).WillByDefault(ReturnRef(io_handle_)); - proxy_protocol_socket_ = std::make_unique(std::move(inner_socket), - socket_options, version); + proxy_protocol_socket_ = std::make_unique( + std::move(inner_socket), socket_options, config, *stats_store_.rootScope()); proxy_protocol_socket_->setTransportSocketCallbacks(transport_callbacks_); proxy_protocol_socket_->onConnected(); } @@ -49,6 +50,7 @@ class ProxyProtocolTest : public testing::Test { NiceMock io_handle_; std::unique_ptr proxy_protocol_socket_; NiceMock transport_callbacks_; + Stats::TestUtil::TestStore stats_store_; }; // Test injects PROXY protocol header only once @@ -60,7 +62,9 @@ TEST_F(ProxyProtocolTest, InjectesHeaderOnlyOnce) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV1Header("174.2.2.222", "172.0.0.1", 50000, 80, Network::Address::IpVersion::v4, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1, nullptr); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1); + initialize(config, nullptr); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -90,7 +94,9 @@ TEST_F(ProxyProtocolTest, BytesProcessedIncludesProxyProtocolHeader) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV1Header("174.2.2.222", "172.0.0.1", 50000, 80, Network::Address::IpVersion::v4, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1, nullptr); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1); + initialize(config, nullptr); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -123,7 +129,9 @@ TEST_F(ProxyProtocolTest, ReturnsKeepOpenWhenWriteErrorIsAgain) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV1Header("174.2.2.222", "172.0.0.1", 50000, 80, Network::Address::IpVersion::v4, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1, nullptr); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1); + initialize(config, nullptr); auto msg = Buffer::OwnedImpl("some data"); { @@ -159,7 +167,9 @@ TEST_F(ProxyProtocolTest, ReturnsCloseWhenWriteErrorIsNotAgain) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV1Header("174.2.2.222", "172.0.0.1", 50000, 80, Network::Address::IpVersion::v4, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1, nullptr); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1); + initialize(config, nullptr); auto msg = Buffer::OwnedImpl("some data"); { @@ -185,7 +195,9 @@ TEST_F(ProxyProtocolTest, V1IPV4LocalAddressWhenTransportOptionsAreNull) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV1Header("174.2.2.222", "172.0.0.1", 50000, 80, Network::Address::IpVersion::v4, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1, nullptr); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1); + initialize(config, nullptr); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -208,8 +220,9 @@ TEST_F(ProxyProtocolTest, V1IPV4LocalAddressesWhenHeaderOptionsAreNull) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV1Header("174.2.2.222", "172.0.0.1", 50000, 80, Network::Address::IpVersion::v4, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1, - std::make_shared()); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1); + initialize(config, std::make_shared()); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -232,8 +245,9 @@ TEST_F(ProxyProtocolTest, V1IPV6LocalAddressesWhenHeaderOptionsAreNull) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV1Header("a:b:c:d::", "e:b:c:f::", 50000, 8080, Network::Address::IpVersion::v6, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1, - std::make_shared()); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1); + initialize(config, std::make_shared()); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -265,7 +279,9 @@ TEST_F(ProxyProtocolTest, V1IPV4DownstreamAddresses) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV1Header("202.168.0.13", "174.2.2.222", 52000, 80, Network::Address::IpVersion::v4, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1, socket_options); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1); + initialize(config, socket_options); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -297,7 +313,9 @@ TEST_F(ProxyProtocolTest, V1IPV6DownstreamAddresses) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV1Header("1::2:3", "a:b:c:d::", 52000, 80, Network::Address::IpVersion::v6, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1, socket_options); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V1); + initialize(config, socket_options); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -319,7 +337,10 @@ TEST_F(ProxyProtocolTest, V2IPV4LocalCommandWhenTransportOptionsAreNull) { ->setRemoteAddress(Network::Utility::resolveUrl("tcp://0.1.1.2:513")); Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV2LocalHeader(expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2, nullptr); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + initialize(config, nullptr); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -341,8 +362,10 @@ TEST_F(ProxyProtocolTest, V2IPV4LocalCommandWhenHeaderOptionsAreNull) { ->setRemoteAddress(Network::Utility::resolveUrl("tcp://0.1.1.2:513")); Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV2LocalHeader(expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2, - std::make_shared()); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + initialize(config, std::make_shared()); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -374,7 +397,9 @@ TEST_F(ProxyProtocolTest, V2IPV4DownstreamAddresses) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV2Header("1.2.3.4", "0.1.1.2", 773, 513, Network::Address::IpVersion::v4, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2, socket_options); + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + initialize(config, socket_options); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -406,7 +431,10 @@ TEST_F(ProxyProtocolTest, V2IPV6DownstreamAddresses) { Buffer::OwnedImpl expected_buff{}; Common::ProxyProtocol::generateV2Header("1:2:3::4", "1:100:200:3::", 8, 2, Network::Address::IpVersion::v6, expected_buff); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2, socket_options); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + initialize(config, socket_options); EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { @@ -435,23 +463,250 @@ TEST_F(ProxyProtocolTest, OnConnectedCallsInnerOnConnected) { ->setLocalAddress(Network::Utility::resolveUrl("tcp://[1:100:200:3::]:50000")); transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ ->setRemoteAddress(Network::Utility::resolveUrl("tcp://[e:b:c:f::]:8080")); - initialize(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2, socket_options); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + initialize(config, socket_options); EXPECT_CALL(*inner_socket_, onConnected()); proxy_protocol_socket_->onConnected(); } +// Test injects V2 PROXY protocol for downstream IPV4 addresses and TLVs +TEST_F(ProxyProtocolTest, V2IPV4DownstreamAddressesAndTLVs) { + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("1.2.3.4", 773)); + auto dst_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("0.1.1.2", 513)); + // TLV type 0x5 is PP2_TYPE_UNIQUE_ID + Network::ProxyProtocolTLVVector tlv_vector{Network::ProxyProtocolTLV{0x5, {'a', 'b', 'c'}}}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, tlv_vector}; + Network::TransportSocketOptionsConstSharedPtr socket_options = + std::make_shared( + "", std::vector{}, std::vector{}, std::vector{}, + absl::optional(proxy_proto_data)); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setLocalAddress(Network::Utility::resolveUrl("tcp://0.1.1.2:50000")); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setRemoteAddress(Network::Utility::resolveUrl("tcp://3.3.3.3:80")); + Buffer::OwnedImpl expected_buff{}; + absl::flat_hash_set pass_tlvs_set{}; + Common::ProxyProtocol::generateV2Header(proxy_proto_data, expected_buff, true, pass_tlvs_set); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + auto pass_through_tlvs = config.mutable_pass_through_tlvs(); + pass_through_tlvs->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE_ALL); + initialize(config, socket_options); + + EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) + .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { + auto length = buffer.length(); + buffer.drain(length); + return Api::IoCallUint64Result(length, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + })); + auto msg = Buffer::OwnedImpl("some data"); + EXPECT_CALL(*inner_socket_, doWrite(BufferEqual(&msg), false)); + + proxy_protocol_socket_->doWrite(msg, false); +} + +// Test injects V2 PROXY protocol for downstream IPV4 addresses and TLVs with passing specific TLV. +TEST_F(ProxyProtocolTest, V2IPV4PassSpecificTLVs) { + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("1.2.3.4", 773)); + auto dst_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("0.1.1.2", 513)); + // TLV type 0x5 is PP2_TYPE_UNIQUE_ID + Network::ProxyProtocolTLVVector tlv_vector{Network::ProxyProtocolTLV{0x5, {'a', 'b', 'c'}}}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, tlv_vector}; + Network::TransportSocketOptionsConstSharedPtr socket_options = + std::make_shared( + "", std::vector{}, std::vector{}, std::vector{}, + absl::optional(proxy_proto_data)); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setLocalAddress(Network::Utility::resolveUrl("tcp://0.1.1.2:50000")); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setRemoteAddress(Network::Utility::resolveUrl("tcp://3.3.3.3:80")); + Buffer::OwnedImpl expected_buff{}; + absl::flat_hash_set pass_tlvs_set{0x05}; + Common::ProxyProtocol::generateV2Header(proxy_proto_data, expected_buff, false, pass_tlvs_set); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + auto pass_through_tlvs = config.mutable_pass_through_tlvs(); + pass_through_tlvs->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE); + pass_through_tlvs->add_tlv_type(0x05); + initialize(config, socket_options); + + EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) + .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { + auto length = buffer.length(); + buffer.drain(length); + return Api::IoCallUint64Result(length, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + })); + auto msg = Buffer::OwnedImpl("some data"); + EXPECT_CALL(*inner_socket_, doWrite(BufferEqual(&msg), false)); + + proxy_protocol_socket_->doWrite(msg, false); +} + +// Test injects V2 PROXY protocol for downstream IPV4 addresses and TLVs with empty passing TLV set. +TEST_F(ProxyProtocolTest, V2IPV4PassEmptyTLVs) { + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("1.2.3.4", 773)); + auto dst_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("0.1.1.2", 513)); + // TLV type 0x5 is PP2_TYPE_UNIQUE_ID + Network::ProxyProtocolTLVVector tlv_vector{Network::ProxyProtocolTLV{0x5, {'a', 'b', 'c'}}}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, tlv_vector}; + Network::TransportSocketOptionsConstSharedPtr socket_options = + std::make_shared( + "", std::vector{}, std::vector{}, std::vector{}, + absl::optional(proxy_proto_data)); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setLocalAddress(Network::Utility::resolveUrl("tcp://0.1.1.2:50000")); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setRemoteAddress(Network::Utility::resolveUrl("tcp://3.3.3.3:80")); + Buffer::OwnedImpl expected_buff{}; + absl::flat_hash_set pass_tlvs_set{}; + Common::ProxyProtocol::generateV2Header(proxy_proto_data, expected_buff, false, pass_tlvs_set); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + config.mutable_pass_through_tlvs()->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE); + initialize(config, socket_options); + + EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) + .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { + auto length = buffer.length(); + buffer.drain(length); + return Api::IoCallUint64Result(length, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + })); + auto msg = Buffer::OwnedImpl("some data"); + EXPECT_CALL(*inner_socket_, doWrite(BufferEqual(&msg), false)); + + proxy_protocol_socket_->doWrite(msg, false); +} + +// Test injects V2 PROXY protocol for downstream IPV4 addresses with exceeding TLV max length. +TEST_F(ProxyProtocolTest, V2IPV4TLVsExceedLengthLimit) { + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("1.2.3.4", 773)); + auto dst_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("0.1.1.2", 513)); + + const std::string long_tlv(65536, 'a'); + Network::ProxyProtocolTLV tlv{0x5, std::vector(long_tlv.begin(), long_tlv.end())}; + + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, {tlv}}; + Network::TransportSocketOptionsConstSharedPtr socket_options = + std::make_shared( + "", std::vector{}, std::vector{}, std::vector{}, + absl::optional(proxy_proto_data)); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setLocalAddress(Network::Utility::resolveUrl("tcp://0.1.1.2:50000")); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setRemoteAddress(Network::Utility::resolveUrl("tcp://3.3.3.3:80")); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + config.mutable_pass_through_tlvs()->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE_ALL); + initialize(config, socket_options); + + auto msg = Buffer::OwnedImpl("some data"); + proxy_protocol_socket_->doWrite(msg, false); + EXPECT_EQ(stats_store_.counter("upstream.proxyprotocol.v2_tlvs_exceed_max_length").value(), 1); +} + +// Test injects V2 PROXY protocol for downstream IPV6 addresses and TLVs +TEST_F(ProxyProtocolTest, V2IPV6DownstreamAddressesAndTLVs) { + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv6Instance("1:2:3::4", 8)); + auto dst_addr = Network::Address::InstanceConstSharedPtr( + new Network::Address::Ipv6Instance("1:100:200:3::", 2)); + // TLV type 0x5 is PP2_TYPE_UNIQUE_ID + Network::ProxyProtocolTLVVector tlv_vector{Network::ProxyProtocolTLV{0x5, {'a', 'b', 'c'}}}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, tlv_vector}; + Network::TransportSocketOptionsConstSharedPtr socket_options = + std::make_shared( + "", std::vector{}, std::vector{}, std::vector{}, + absl::optional(proxy_proto_data)); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setLocalAddress(Network::Utility::resolveUrl("tcp://[1:100:200:3::]:50000")); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setRemoteAddress(Network::Utility::resolveUrl("tcp://[e:b:c:f::]:8080")); + Buffer::OwnedImpl expected_buff{}; + absl::flat_hash_set pass_through_tlvs{}; + Common::ProxyProtocol::generateV2Header(proxy_proto_data, expected_buff, true, pass_through_tlvs); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + config.mutable_pass_through_tlvs()->set_match_type(ProxyProtocolPassThroughTLVs::INCLUDE_ALL); + initialize(config, socket_options); + + EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) + .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { + auto length = buffer.length(); + buffer.drain(length); + return Api::IoCallUint64Result(length, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + })); + auto msg = Buffer::OwnedImpl("some data"); + EXPECT_CALL(*inner_socket_, doWrite(BufferEqual(&msg), false)); + + proxy_protocol_socket_->doWrite(msg, false); +} + +// Test injects V2 PROXY protocol for downstream IPV6 addresses and TLVs without pass TLV config. +TEST_F(ProxyProtocolTest, V2IPV6DownstreamAddressesAndTLVsWithoutPassConfig) { + auto src_addr = + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv6Instance("1:2:3::4", 8)); + auto dst_addr = Network::Address::InstanceConstSharedPtr( + new Network::Address::Ipv6Instance("1:100:200:3::", 2)); + // TLV type 0x5 is PP2_TYPE_UNIQUE_ID + Network::ProxyProtocolTLVVector tlv_vector{Network::ProxyProtocolTLV{0x5, {'a', 'b', 'c'}}}; + Network::ProxyProtocolData proxy_proto_data{src_addr, dst_addr, tlv_vector}; + Network::TransportSocketOptionsConstSharedPtr socket_options = + std::make_shared( + "", std::vector{}, std::vector{}, std::vector{}, + absl::optional(proxy_proto_data)); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setLocalAddress(Network::Utility::resolveUrl("tcp://[1:100:200:3::]:50000")); + transport_callbacks_.connection_.stream_info_.downstream_connection_info_provider_ + ->setRemoteAddress(Network::Utility::resolveUrl("tcp://[e:b:c:f::]:8080")); + Buffer::OwnedImpl expected_buff{}; + absl::flat_hash_set pass_through_tlvs{}; + Common::ProxyProtocol::generateV2Header(proxy_proto_data, expected_buff, false, + pass_through_tlvs); + + ProxyProtocolConfig config; + config.set_version(ProxyProtocolConfig_Version::ProxyProtocolConfig_Version_V2); + initialize(config, socket_options); + + EXPECT_CALL(io_handle_, write(BufferStringEqual(expected_buff.toString()))) + .WillOnce(Invoke([&](Buffer::Instance& buffer) -> Api::IoCallUint64Result { + auto length = buffer.length(); + buffer.drain(length); + return Api::IoCallUint64Result(length, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + })); + auto msg = Buffer::OwnedImpl("some data"); + EXPECT_CALL(*inner_socket_, doWrite(BufferEqual(&msg), false)); + + proxy_protocol_socket_->doWrite(msg, false); +} + class ProxyProtocolSocketFactoryTest : public testing::Test { public: void initialize() { auto inner_factory = std::make_unique>(); inner_factory_ = inner_factory.get(); - factory_ = std::make_unique(std::move(inner_factory), - ProxyProtocolConfig()); + factory_ = std::make_unique( + std::move(inner_factory), ProxyProtocolConfig(), *stats_store_.rootScope()); } NiceMock* inner_factory_; std::unique_ptr factory_; + Stats::TestUtil::TestStore stats_store_; }; // Test createTransportSocket returns nullptr if inner call returns nullptr