From da23a7edda3189df417fdbbd40cf931e30e385ca Mon Sep 17 00:00:00 2001 From: James Buckland Date: Mon, 10 Aug 2020 09:19:23 -0400 Subject: [PATCH 01/11] [tls] Move handshaking behavior into SslSocketInfo. Signed-off-by: James Buckland --- include/envoy/network/BUILD | 7 ++ include/envoy/network/connection.h | 2 +- include/envoy/network/post_io_action.h | 17 +++ include/envoy/network/transport_socket.h | 15 +-- include/envoy/ssl/BUILD | 8 ++ include/envoy/ssl/connection.h | 10 +- include/envoy/ssl/ssl_socket_state.h | 9 ++ include/envoy/stream_info/stream_info.h | 8 +- source/common/network/connection_impl.h | 2 +- source/common/network/raw_buffer_socket.h | 2 +- source/common/stream_info/stream_info_impl.h | 15 +-- source/common/tcp_proxy/tcp_proxy.cc | 2 +- source/common/tcp_proxy/tcp_proxy.h | 2 +- .../grpc/grpc_access_log_utils.cc | 2 +- .../quic_filter_manager_connection_impl.cc | 2 +- .../quic_filter_manager_connection_impl.h | 2 +- .../transport_sockets/alts/tsi_socket.h | 2 +- .../transport_sockets/common/passthrough.cc | 6 +- .../transport_sockets/common/passthrough.h | 4 +- source/extensions/transport_sockets/tls/BUILD | 1 + .../transport_sockets/tls/ssl_socket.cc | 103 ++++++++++-------- .../transport_sockets/tls/ssl_socket.h | 30 +++-- source/server/api_listener_impl.h | 2 +- test/common/stream_info/test_util.h | 13 +-- test/mocks/network/connection.h | 6 +- test/mocks/network/transport_socket.h | 2 +- test/mocks/ssl/mocks.h | 1 + test/mocks/stream_info/mocks.h | 12 +- 28 files changed, 174 insertions(+), 113 deletions(-) create mode 100644 include/envoy/network/post_io_action.h create mode 100644 include/envoy/ssl/ssl_socket_state.h diff --git a/include/envoy/network/BUILD b/include/envoy/network/BUILD index 3a8e67613c58..6e3bb429fe5b 100644 --- a/include/envoy/network/BUILD +++ b/include/envoy/network/BUILD @@ -130,12 +130,19 @@ envoy_cc_library( hdrs = ["transport_socket.h"], deps = [ ":io_handle_interface", + ":post_io_action_interface", ":proxy_protocol_options_lib", "//include/envoy/buffer:buffer_interface", "//include/envoy/ssl:connection_interface", ], ) +envoy_cc_library( + name = "post_io_action_interface", + hdrs = ["post_io_action.h"], + deps = [], +) + envoy_cc_library( name = "connection_balancer_interface", hdrs = ["connection_balancer.h"], diff --git a/include/envoy/network/connection.h b/include/envoy/network/connection.h index 6e9667f77804..cf632a75c512 100644 --- a/include/envoy/network/connection.h +++ b/include/envoy/network/connection.h @@ -222,7 +222,7 @@ class Connection : public Event::DeferredDeletable, public FilterManager { * @return the const SSL connection data if this is an SSL connection, or nullptr if it is not. */ // TODO(snowp): Remove this in favor of StreamInfo::downstreamSslConnection. - virtual Ssl::ConnectionInfoConstSharedPtr ssl() const PURE; + virtual Ssl::ConnectionInfoSharedPtr ssl() const PURE; /** * @return requested server name (e.g. SNI in TLS), if any. diff --git a/include/envoy/network/post_io_action.h b/include/envoy/network/post_io_action.h new file mode 100644 index 000000000000..3b828bc1d5e7 --- /dev/null +++ b/include/envoy/network/post_io_action.h @@ -0,0 +1,17 @@ +#pragma once + +namespace Envoy { +namespace Network { + +/** + * Action that should occur on a connection after I/O. + */ +enum class PostIoAction { + // Close the connection. + Close, + // Keep the connection open. + KeepOpen +}; + +} // namespace Network +} // namespace Envoy diff --git a/include/envoy/network/transport_socket.h b/include/envoy/network/transport_socket.h index 9e117b116134..884da7aa22a0 100644 --- a/include/envoy/network/transport_socket.h +++ b/include/envoy/network/transport_socket.h @@ -5,6 +5,7 @@ #include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" #include "envoy/network/io_handle.h" +#include "envoy/network/post_io_action.h" #include "envoy/network/proxy_protocol.h" #include "envoy/ssl/connection.h" @@ -16,16 +17,6 @@ namespace Network { class Connection; enum class ConnectionEvent; -/** - * Action that should occur on a connection after I/O. - */ -enum class PostIoAction { - // Close the connection. - Close, - // Keep the connection open. - KeepOpen -}; - /** * Result of each I/O event. */ @@ -151,9 +142,9 @@ class TransportSocket { virtual void onConnected() PURE; /** - * @return the const SSL connection data if this is an SSL connection, or nullptr if it is not. + * @return the SSL connection data if this is an SSL connection, or nullptr if it is not. */ - virtual Ssl::ConnectionInfoConstSharedPtr ssl() const PURE; + virtual Ssl::ConnectionInfoSharedPtr ssl() const PURE; }; using TransportSocketPtr = std::unique_ptr; diff --git a/include/envoy/ssl/BUILD b/include/envoy/ssl/BUILD index b8e7d530174f..7b7e3219ec77 100644 --- a/include/envoy/ssl/BUILD +++ b/include/envoy/ssl/BUILD @@ -13,7 +13,9 @@ envoy_cc_library( hdrs = ["connection.h"], external_deps = ["abseil_optional"], deps = [ + ":ssl_socket_state", "//include/envoy/common:time_interface", + "//include/envoy/network:post_io_action_interface", ], ) @@ -68,3 +70,9 @@ envoy_cc_library( deps = [ ], ) + +envoy_cc_library( + name = "ssl_socket_state", + hdrs = ["ssl_socket_state.h"], + deps = [], +) diff --git a/include/envoy/ssl/connection.h b/include/envoy/ssl/connection.h index 8241c48ad8d7..a76429ba7be5 100644 --- a/include/envoy/ssl/connection.h +++ b/include/envoy/ssl/connection.h @@ -5,6 +5,8 @@ #include "envoy/common/pure.h" #include "envoy/common/time.h" +#include "envoy/network/post_io_action.h" +#include "envoy/ssl/ssl_socket_state.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -148,9 +150,15 @@ class ConnectionInfo { * exists. */ virtual absl::optional x509Extension(absl::string_view extension_name) const PURE; + + /** + * Performs a TLS handshake on ths SSL object and returns an action indicating + * whether the callsite should close the connection or keep it open. + */ + virtual Network::PostIoAction doHandshake(SocketState& state) PURE; }; -using ConnectionInfoConstSharedPtr = std::shared_ptr; +using ConnectionInfoSharedPtr = std::shared_ptr; } // namespace Ssl } // namespace Envoy diff --git a/include/envoy/ssl/ssl_socket_state.h b/include/envoy/ssl/ssl_socket_state.h new file mode 100644 index 000000000000..aa60fbc178ab --- /dev/null +++ b/include/envoy/ssl/ssl_socket_state.h @@ -0,0 +1,9 @@ +#pragma once + +namespace Envoy { +namespace Ssl { + +enum class SocketState { PreHandshake, HandshakeInProgress, HandshakeComplete, ShutdownSent }; + +} // namespace Ssl +} // namespace Envoy diff --git a/include/envoy/stream_info/stream_info.h b/include/envoy/stream_info/stream_info.h index c64e0837266d..7c641cc0b5a9 100644 --- a/include/envoy/stream_info/stream_info.h +++ b/include/envoy/stream_info/stream_info.h @@ -462,25 +462,25 @@ class StreamInfo { * @param connection_info sets the downstream ssl connection. */ virtual void - setDownstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& ssl_connection_info) PURE; + setDownstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& ssl_connection_info) PURE; /** * @return the downstream SSL connection. This will be nullptr if the downstream * connection does not use SSL. */ - virtual Ssl::ConnectionInfoConstSharedPtr downstreamSslConnection() const PURE; + virtual Ssl::ConnectionInfoSharedPtr downstreamSslConnection() const PURE; /** * @param connection_info sets the upstream ssl connection. */ virtual void - setUpstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& ssl_connection_info) PURE; + setUpstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& ssl_connection_info) PURE; /** * @return the upstream SSL connection. This will be nullptr if the upstream * connection does not use SSL. */ - virtual Ssl::ConnectionInfoConstSharedPtr upstreamSslConnection() const PURE; + virtual Ssl::ConnectionInfoSharedPtr upstreamSslConnection() const PURE; /** * @return const Router::RouteEntry* Get the route entry selected for this request. Note: this diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index b464e2af96d1..00494c6035db 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -76,7 +76,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback return socket_->localAddress(); } absl::optional unixSocketPeerCredentials() const override; - Ssl::ConnectionInfoConstSharedPtr ssl() const override { return transport_socket_->ssl(); } + Ssl::ConnectionInfoSharedPtr ssl() const override { return transport_socket_->ssl(); } State state() const override; void write(Buffer::Instance& data, bool end_stream) override; void setBufferLimits(uint32_t limit) override; diff --git a/source/common/network/raw_buffer_socket.h b/source/common/network/raw_buffer_socket.h index fe87bbeda605..212f700912ca 100644 --- a/source/common/network/raw_buffer_socket.h +++ b/source/common/network/raw_buffer_socket.h @@ -20,7 +20,7 @@ class RawBufferSocket : public TransportSocket, protected Logger::LoggableupstreamHost(host); getStreamInfo().onUpstreamHostSelected(host); diff --git a/source/common/tcp_proxy/tcp_proxy.h b/source/common/tcp_proxy/tcp_proxy.h index 871be2ad16f8..413c87eb4b54 100644 --- a/source/common/tcp_proxy/tcp_proxy.h +++ b/source/common/tcp_proxy/tcp_proxy.h @@ -262,7 +262,7 @@ class Filter : public Network::ReadFilter, void onPoolReadyBase(Upstream::HostDescriptionConstSharedPtr& host, const Network::Address::InstanceConstSharedPtr& local_address, - Ssl::ConnectionInfoConstSharedPtr ssl_info); + Ssl::ConnectionInfoSharedPtr ssl_info); // Upstream::LoadBalancerContext const Router::MetadataMatchCriteria* metadataMatchCriteria() override { diff --git a/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc b/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc index 74b061cbad7c..894a9cff7fda 100644 --- a/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc +++ b/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc @@ -149,7 +149,7 @@ void Utility::extractCommonAccessLogProperties( } if (stream_info.downstreamSslConnection() != nullptr) { auto* tls_properties = common_access_log.mutable_tls_properties(); - const Ssl::ConnectionInfoConstSharedPtr downstream_ssl_connection = + const Ssl::ConnectionInfoSharedPtr downstream_ssl_connection = stream_info.downstreamSslConnection(); tls_properties->set_tls_sni_hostname(stream_info.requestedServerName()); diff --git a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc index e005a3dd7691..eb0e3309e90a 100644 --- a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc +++ b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc @@ -119,7 +119,7 @@ QuicFilterManagerConnectionImpl::localAddress() const { return quic_connection_->connectionSocket()->localAddress(); } -Ssl::ConnectionInfoConstSharedPtr QuicFilterManagerConnectionImpl::ssl() const { +Ssl::ConnectionInfoSharedPtr QuicFilterManagerConnectionImpl::ssl() const { // TODO(danzh): construct Ssl::ConnectionInfo from crypto stream return nullptr; } diff --git a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h index 54c8e87a259d..3ea110492902 100644 --- a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h +++ b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h @@ -56,7 +56,7 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase { Network::ConnectionImplBase::setConnectionStats(stats); quic_connection_->setConnectionStats(stats); } - Ssl::ConnectionInfoConstSharedPtr ssl() const override; + Ssl::ConnectionInfoSharedPtr ssl() const override; Network::Connection::State state() const override { if (quic_connection_ != nullptr && quic_connection_->connected()) { return Network::Connection::State::Open; diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h index 0acba405022d..8c66c4673736 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.h +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -58,7 +58,7 @@ class TsiSocket : public Network::TransportSocket, std::string protocol() const override; absl::string_view failureReason() const override; bool canFlushClose() override { return handshake_complete_; } - Envoy::Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; } + Envoy::Ssl::ConnectionInfoSharedPtr ssl() const override { return nullptr; } Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override; void closeSocket(Network::ConnectionEvent event) override; Network::IoResult doRead(Buffer::Instance& buffer) override; diff --git a/source/extensions/transport_sockets/common/passthrough.cc b/source/extensions/transport_sockets/common/passthrough.cc index 60d632adb24a..a6228e2b5fd0 100644 --- a/source/extensions/transport_sockets/common/passthrough.cc +++ b/source/extensions/transport_sockets/common/passthrough.cc @@ -38,10 +38,8 @@ Network::IoResult PassthroughSocket::doWrite(Buffer::Instance& buffer, bool end_ void PassthroughSocket::onConnected() { transport_socket_->onConnected(); } -Ssl::ConnectionInfoConstSharedPtr PassthroughSocket::ssl() const { - return transport_socket_->ssl(); -} +Ssl::ConnectionInfoSharedPtr PassthroughSocket::ssl() const { return transport_socket_->ssl(); } } // namespace TransportSockets } // namespace Extensions -} // namespace Envoy \ No newline at end of file +} // namespace Envoy diff --git a/source/extensions/transport_sockets/common/passthrough.h b/source/extensions/transport_sockets/common/passthrough.h index bbf832c73419..b6ec5d09ba3f 100644 --- a/source/extensions/transport_sockets/common/passthrough.h +++ b/source/extensions/transport_sockets/common/passthrough.h @@ -21,7 +21,7 @@ class PassthroughSocket : public Network::TransportSocket { Network::IoResult doRead(Buffer::Instance& buffer) override; Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override; void onConnected() override; - Ssl::ConnectionInfoConstSharedPtr ssl() const override; + Ssl::ConnectionInfoSharedPtr ssl() const override; protected: Network::TransportSocketPtr transport_socket_; @@ -29,4 +29,4 @@ class PassthroughSocket : public Network::TransportSocket { } // namespace TransportSockets } // namespace Extensions -} // namespace Envoy \ No newline at end of file +} // namespace Envoy diff --git a/source/extensions/transport_sockets/tls/BUILD b/source/extensions/transport_sockets/tls/BUILD index 1cd091050d15..c7e760764a11 100644 --- a/source/extensions/transport_sockets/tls/BUILD +++ b/source/extensions/transport_sockets/tls/BUILD @@ -48,6 +48,7 @@ envoy_cc_library( "//include/envoy/network:connection_interface", "//include/envoy/network:transport_socket_interface", "//include/envoy/ssl:ssl_socket_extended_info_interface", + "//include/envoy/ssl:ssl_socket_state", "//include/envoy/ssl/private_key:private_key_callbacks_interface", "//include/envoy/ssl/private_key:private_key_interface", "//include/envoy/stats:stats_macros", diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index ab2644ccc808..6fcd830c9a72 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -38,16 +38,16 @@ class NotReadySslSocket : public Network::TransportSocket { return {PostIoAction::Close, 0, false}; } void onConnected() override {} - Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; } + Ssl::ConnectionInfoSharedPtr ssl() const override { return nullptr; } }; } // namespace SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, const Network::TransportSocketOptionsSharedPtr& transport_socket_options) : transport_socket_options_(transport_socket_options), - ctx_(std::dynamic_pointer_cast(ctx)), state_(SocketState::PreHandshake) { + ctx_(std::dynamic_pointer_cast(ctx)), state_(Ssl::SocketState::PreHandshake) { bssl::UniquePtr ssl = ctx_->newSsl(transport_socket_options_.get()); - info_ = std::make_shared(std::move(ssl), ctx_); + info_ = std::make_shared(std::move(ssl), ctx_, this); if (state == InitialState::Client) { SSL_set_connect_state(rawSsl()); @@ -96,9 +96,9 @@ SslSocket::ReadResult SslSocket::sslReadIntoSlice(Buffer::RawSlice& slice) { } Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) { - if (state_ != SocketState::HandshakeComplete && state_ != SocketState::ShutdownSent) { + if (state_ != Ssl::SocketState::HandshakeComplete && state_ != Ssl::SocketState::ShutdownSent) { PostIoAction action = doHandshake(); - if (action == PostIoAction::Close || state_ != SocketState::HandshakeComplete) { + if (action == PostIoAction::Close || state_ != Ssl::SocketState::HandshakeComplete) { // end_stream is false because either a hard error occurred (action == Close) or // the handshake isn't complete, so a half-close cannot occur yet. return {action, 0, false}; @@ -158,7 +158,7 @@ Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) { void SslSocket::onPrivateKeyMethodComplete() { ASSERT(isThreadSafe()); - ASSERT(state_ == SocketState::HandshakeInProgress); + ASSERT(state_ == Ssl::SocketState::HandshakeInProgress); // Resume handshake. PostIoAction action = doHandshake(); @@ -168,39 +168,19 @@ void SslSocket::onPrivateKeyMethodComplete() { } } -PostIoAction SslSocket::doHandshake() { - ASSERT(state_ != SocketState::HandshakeComplete && state_ != SocketState::ShutdownSent); - int rc = SSL_do_handshake(rawSsl()); - if (rc == 1) { - ENVOY_CONN_LOG(debug, "handshake complete", callbacks_->connection()); - state_ = SocketState::HandshakeComplete; - ctx_->logHandshake(rawSsl()); - callbacks_->raiseEvent(Network::ConnectionEvent::Connected); +Network::Connection::State SslSocket::connectionState() const { + return callbacks_->connection().state(); +} - // It's possible that we closed during the handshake callback. - return callbacks_->connection().state() == Network::Connection::State::Open - ? PostIoAction::KeepOpen - : PostIoAction::Close; - } else { - int err = SSL_get_error(rawSsl(), rc); - switch (err) { - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - ENVOY_CONN_LOG(debug, "handshake expecting {}", callbacks_->connection(), - err == SSL_ERROR_WANT_READ ? "read" : "write"); - return PostIoAction::KeepOpen; - case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION: - ENVOY_CONN_LOG(debug, "handshake continued asynchronously", callbacks_->connection()); - state_ = SocketState::HandshakeInProgress; - return PostIoAction::KeepOpen; - default: - ENVOY_CONN_LOG(debug, "handshake error: {}", callbacks_->connection(), err); - drainErrorQueue(); - return PostIoAction::Close; - } - } +void SslSocket::onSuccess(SSL* ssl) { + ctx_->logHandshake(ssl); + callbacks_->raiseEvent(Network::ConnectionEvent::Connected); } +void SslSocket::onFailure() { drainErrorQueue(); } + +PostIoAction SslSocket::doHandshake() { return info_->doHandshake(state_); } + void SslSocket::drainErrorQueue() { bool saw_error = false; bool saw_counted_error = false; @@ -229,10 +209,10 @@ void SslSocket::drainErrorQueue() { } Network::IoResult SslSocket::doWrite(Buffer::Instance& write_buffer, bool end_stream) { - ASSERT(state_ != SocketState::ShutdownSent || write_buffer.length() == 0); - if (state_ != SocketState::HandshakeComplete && state_ != SocketState::ShutdownSent) { + ASSERT(state_ != Ssl::SocketState::ShutdownSent || write_buffer.length() == 0); + if (state_ != Ssl::SocketState::HandshakeComplete && state_ != Ssl::SocketState::ShutdownSent) { PostIoAction action = doHandshake(); - if (action == PostIoAction::Close || state_ != SocketState::HandshakeComplete) { + if (action == PostIoAction::Close || state_ != Ssl::SocketState::HandshakeComplete) { return {action, 0, false}; } } @@ -285,18 +265,18 @@ Network::IoResult SslSocket::doWrite(Buffer::Instance& write_buffer, bool end_st return {PostIoAction::KeepOpen, total_bytes_written, false}; } -void SslSocket::onConnected() { ASSERT(state_ == SocketState::PreHandshake); } +void SslSocket::onConnected() { ASSERT(state_ == Ssl::SocketState::PreHandshake); } -Ssl::ConnectionInfoConstSharedPtr SslSocket::ssl() const { return info_; } +Ssl::ConnectionInfoSharedPtr SslSocket::ssl() const { return info_; } void SslSocket::shutdownSsl() { - ASSERT(state_ != SocketState::PreHandshake); - if (state_ != SocketState::ShutdownSent && + ASSERT(state_ != Ssl::SocketState::PreHandshake); + if (state_ != Ssl::SocketState::ShutdownSent && callbacks_->connection().state() != Network::Connection::State::Closed) { int rc = SSL_shutdown(rawSsl()); ENVOY_CONN_LOG(debug, "SSL shutdown: rc={}", callbacks_->connection(), rc); drainErrorQueue(); - state_ = SocketState::ShutdownSent; + state_ = Ssl::SocketState::ShutdownSent; } } @@ -309,8 +289,9 @@ Envoy::Ssl::ClientValidationStatus SslExtendedSocketInfoImpl::certificateValidat return certificate_validation_status_; } -SslSocketInfo::SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx) - : ssl_(std::move(ssl)) { +SslSocketInfo::SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, + HandshakeCallbacks* handshake_callbacks) + : ssl_(std::move(ssl)), handshake_callbacks_(handshake_callbacks) { SSL_set_ex_data(ssl_.get(), ctx->sslExtendedSocketInfoIndex(), &(this->extended_socket_info_)); } @@ -478,7 +459,8 @@ void SslSocket::closeSocket(Network::ConnectionEvent) { // Attempt to send a shutdown before closing the socket. It's possible this won't go out if // there is no room on the socket. We can extend the state machine to handle this at some point // if needed. - if (state_ == SocketState::HandshakeInProgress || state_ == SocketState::HandshakeComplete) { + if (state_ == Ssl::SocketState::HandshakeInProgress || + state_ == Ssl::SocketState::HandshakeComplete) { shutdownSsl(); } } @@ -527,6 +509,33 @@ absl::optional SslSocketInfo::x509Extension(absl::string_view exten return Utility::getX509ExtensionValue(*cert, extension_name); } +Network::PostIoAction SslSocketInfo::doHandshake(Ssl::SocketState& state) { + ASSERT(state != Ssl::SocketState::HandshakeComplete && state != Ssl::SocketState::ShutdownSent); + int rc = SSL_do_handshake(ssl()); + if (rc == 1) { + state = Ssl::SocketState::HandshakeComplete; + handshake_callbacks_->onSuccess(ssl()); + + // It's possible that we closed during the handshake callback. + return handshake_callbacks_->connectionState() == Network::Connection::State::Open + ? PostIoAction::KeepOpen + : PostIoAction::Close; + } else { + int err = SSL_get_error(ssl(), rc); + switch (err) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + return PostIoAction::KeepOpen; + case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION: + state = Ssl::SocketState::HandshakeInProgress; + return PostIoAction::KeepOpen; + default: + handshake_callbacks_->onFailure(); + return PostIoAction::Close; + } + } +} + absl::string_view SslSocket::failureReason() const { return failure_reason_; } const std::string& SslSocketInfo::serialNumberPeerCertificate() const { diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index 27416ce7f635..a340874f8e93 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -8,6 +8,7 @@ #include "envoy/secret/secret_callbacks.h" #include "envoy/ssl/private_key/private_key_callbacks.h" #include "envoy/ssl/ssl_socket_extended_info.h" +#include "envoy/ssl/ssl_socket_state.h" #include "envoy/stats/scope.h" #include "envoy/stats/stats_macros.h" @@ -39,7 +40,6 @@ struct SslSocketFactoryStats { }; enum class InitialState { Client, Server }; -enum class SocketState { PreHandshake, HandshakeInProgress, HandshakeComplete, ShutdownSent }; class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { public: @@ -51,9 +51,18 @@ class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { Envoy::Ssl::ClientValidationStatus::NotValidated}; }; +class HandshakeCallbacks { +public: + virtual ~HandshakeCallbacks() = default; + virtual Network::Connection::State connectionState() const PURE; + virtual void onSuccess(SSL* ssl) PURE; + virtual void onFailure() PURE; +}; + class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { public: - SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx); + SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, + HandshakeCallbacks* handshake_callbacks_); // Ssl::ConnectionInfo bool peerCertificatePresented() const override; @@ -77,11 +86,13 @@ class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { std::string ciphersuiteString() const override; const std::string& tlsVersion() const override; absl::optional x509Extension(absl::string_view extension_name) const override; + Network::PostIoAction doHandshake(Ssl::SocketState& state) override; SSL* ssl() const { return ssl_.get(); } bssl::UniquePtr ssl_; private: + HandshakeCallbacks* handshake_callbacks_; mutable std::vector cached_uri_san_local_certificate_; mutable std::string cached_sha_256_peer_certificate_digest_; mutable std::string cached_sha_1_peer_certificate_digest_; @@ -99,10 +110,11 @@ class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { mutable SslExtendedSocketInfoImpl extended_socket_info_; }; -using SslSocketInfoConstSharedPtr = std::shared_ptr; +using SslSocketInfoSharedPtr = std::shared_ptr; class SslSocket : public Network::TransportSocket, public Envoy::Ssl::PrivateKeyConnectionCallbacks, + public HandshakeCallbacks, protected Logger::Loggable { public: SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, @@ -112,14 +124,18 @@ class SslSocket : public Network::TransportSocket, void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override; std::string protocol() const override; absl::string_view failureReason() const override; - bool canFlushClose() override { return state_ == SocketState::HandshakeComplete; } + bool canFlushClose() override { return state_ == Ssl::SocketState::HandshakeComplete; } void closeSocket(Network::ConnectionEvent close_type) override; Network::IoResult doRead(Buffer::Instance& read_buffer) override; Network::IoResult doWrite(Buffer::Instance& write_buffer, bool end_stream) override; void onConnected() override; - Ssl::ConnectionInfoConstSharedPtr ssl() const override; + Ssl::ConnectionInfoSharedPtr ssl() const override; // Ssl::PrivateKeyConnectionCallbacks void onPrivateKeyMethodComplete() override; + // HandshakeCallbacks + Network::Connection::State connectionState() const override; + void onSuccess(SSL* ssl) override; + void onFailure() override; SSL* rawSslForTest() const { return rawSsl(); } @@ -145,9 +161,9 @@ class SslSocket : public Network::TransportSocket, ContextImplSharedPtr ctx_; uint64_t bytes_to_retry_{}; std::string failure_reason_; - SocketState state_; + Ssl::SocketState state_; - SslSocketInfoConstSharedPtr info_; + SslSocketInfoSharedPtr info_; }; class ClientSslSocketFactory : public Network::TransportSocketFactory, diff --git a/source/server/api_listener_impl.h b/source/server/api_listener_impl.h index bf806341bed0..b0e2b4045df8 100644 --- a/source/server/api_listener_impl.h +++ b/source/server/api_listener_impl.h @@ -117,7 +117,7 @@ class ApiListenerImplBase : public ApiListener, return parent_.parent_.address(); } void setConnectionStats(const Network::Connection::ConnectionStats&) override {} - Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; } + Ssl::ConnectionInfoSharedPtr ssl() const override { return nullptr; } absl::string_view requestedServerName() const override { return EMPTY_STRING; } State state() const override { return Network::Connection::State::Open; } void write(Buffer::Instance&, bool) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } diff --git a/test/common/stream_info/test_util.h b/test/common/stream_info/test_util.h index 5767592c7406..b960122eb0e6 100644 --- a/test/common/stream_info/test_util.h +++ b/test/common/stream_info/test_util.h @@ -90,20 +90,19 @@ class TestStreamInfo : public StreamInfo::StreamInfo { return downstream_remote_address_; } - void - setDownstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& connection_info) override { + void setDownstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& connection_info) override { downstream_connection_info_ = connection_info; } - Ssl::ConnectionInfoConstSharedPtr downstreamSslConnection() const override { + Ssl::ConnectionInfoSharedPtr downstreamSslConnection() const override { return downstream_connection_info_; } - void setUpstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& connection_info) override { + void setUpstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& connection_info) override { upstream_connection_info_ = connection_info; } - Ssl::ConnectionInfoConstSharedPtr upstreamSslConnection() const override { + Ssl::ConnectionInfoSharedPtr upstreamSslConnection() const override { return upstream_connection_info_; } void setRouteName(absl::string_view route_name) override { @@ -251,8 +250,8 @@ class TestStreamInfo : public StreamInfo::StreamInfo { Network::Address::InstanceConstSharedPtr downstream_local_address_; Network::Address::InstanceConstSharedPtr downstream_direct_remote_address_; Network::Address::InstanceConstSharedPtr downstream_remote_address_; - Ssl::ConnectionInfoConstSharedPtr downstream_connection_info_; - Ssl::ConnectionInfoConstSharedPtr upstream_connection_info_; + Ssl::ConnectionInfoSharedPtr downstream_connection_info_; + Ssl::ConnectionInfoSharedPtr upstream_connection_info_; const Router::RouteEntry* route_entry_{}; envoy::config::core::v3::Metadata metadata_{}; Envoy::StreamInfo::FilterStateSharedPtr filter_state_{ diff --git a/test/mocks/network/connection.h b/test/mocks/network/connection.h index b66d9525d268..0c22705aacef 100644 --- a/test/mocks/network/connection.h +++ b/test/mocks/network/connection.h @@ -72,7 +72,7 @@ class MockConnection : public Connection, public MockConnectionBase { unixSocketPeerCredentials, (), (const)); MOCK_METHOD(const Address::InstanceConstSharedPtr&, localAddress, (), (const)); MOCK_METHOD(void, setConnectionStats, (const ConnectionStats& stats)); - MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); + MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); @@ -118,7 +118,7 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase unixSocketPeerCredentials, (), (const)); MOCK_METHOD(const Address::InstanceConstSharedPtr&, localAddress, (), (const)); MOCK_METHOD(void, setConnectionStats, (const ConnectionStats& stats)); - MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); + MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); @@ -167,7 +167,7 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC unixSocketPeerCredentials, (), (const)); MOCK_METHOD(const Address::InstanceConstSharedPtr&, localAddress, (), (const)); MOCK_METHOD(void, setConnectionStats, (const ConnectionStats& stats)); - MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); + MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); diff --git a/test/mocks/network/transport_socket.h b/test/mocks/network/transport_socket.h index ee53570c20ac..a6785c6b402d 100644 --- a/test/mocks/network/transport_socket.h +++ b/test/mocks/network/transport_socket.h @@ -25,7 +25,7 @@ class MockTransportSocket : public TransportSocket { MOCK_METHOD(IoResult, doRead, (Buffer::Instance & buffer)); MOCK_METHOD(IoResult, doWrite, (Buffer::Instance & buffer, bool end_stream)); MOCK_METHOD(void, onConnected, ()); - MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); + MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, ssl, (), (const)); TransportSocketCallbacks* callbacks_{}; }; diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 7567e5807cff..27152e35a175 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -59,6 +59,7 @@ class MockConnectionInfo : public ConnectionInfo { MOCK_METHOD(std::string, ciphersuiteString, (), (const)); MOCK_METHOD(const std::string&, tlsVersion, (), (const)); MOCK_METHOD(absl::optional, x509Extension, (absl::string_view), (const)); + MOCK_METHOD(Network::PostIoAction, doHandshake, (SocketState& state), ()); }; class MockClientContext : public ClientContext { diff --git a/test/mocks/stream_info/mocks.h b/test/mocks/stream_info/mocks.h index 2c5b09562e96..9fb08bbabe9e 100644 --- a/test/mocks/stream_info/mocks.h +++ b/test/mocks/stream_info/mocks.h @@ -70,10 +70,10 @@ class MockStreamInfo : public StreamInfo { MOCK_METHOD(void, setDownstreamRemoteAddress, (const Network::Address::InstanceConstSharedPtr&)); MOCK_METHOD(const Network::Address::InstanceConstSharedPtr&, downstreamRemoteAddress, (), (const)); - MOCK_METHOD(void, setDownstreamSslConnection, (const Ssl::ConnectionInfoConstSharedPtr&)); - MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, downstreamSslConnection, (), (const)); - MOCK_METHOD(void, setUpstreamSslConnection, (const Ssl::ConnectionInfoConstSharedPtr&)); - MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, upstreamSslConnection, (), (const)); + MOCK_METHOD(void, setDownstreamSslConnection, (const Ssl::ConnectionInfoSharedPtr&)); + MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, downstreamSslConnection, (), (const)); + MOCK_METHOD(void, setUpstreamSslConnection, (const Ssl::ConnectionInfoSharedPtr&)); + MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, upstreamSslConnection, (), (const)); MOCK_METHOD(const Router::RouteEntry*, routeEntry, (), (const)); MOCK_METHOD(envoy::config::core::v3::Metadata&, dynamicMetadata, ()); MOCK_METHOD(const envoy::config::core::v3::Metadata&, dynamicMetadata, (), (const)); @@ -122,8 +122,8 @@ class MockStreamInfo : public StreamInfo { Network::Address::InstanceConstSharedPtr downstream_local_address_; Network::Address::InstanceConstSharedPtr downstream_direct_remote_address_; Network::Address::InstanceConstSharedPtr downstream_remote_address_; - Ssl::ConnectionInfoConstSharedPtr downstream_connection_info_; - Ssl::ConnectionInfoConstSharedPtr upstream_connection_info_; + Ssl::ConnectionInfoSharedPtr downstream_connection_info_; + Ssl::ConnectionInfoSharedPtr upstream_connection_info_; std::string requested_server_name_; std::string route_name_; std::string upstream_transport_failure_reason_; From b0f8414a9b09fd6d0f84ee0e7a60ea37e1fb14e8 Mon Sep 17 00:00:00 2001 From: James Buckland Date: Mon, 10 Aug 2020 10:50:45 -0400 Subject: [PATCH 02/11] [misc] Run fix_format. Signed-off-by: James Buckland --- test/mocks/ssl/mocks.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 27152e35a175..2ad24444fd2d 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -59,7 +59,7 @@ class MockConnectionInfo : public ConnectionInfo { MOCK_METHOD(std::string, ciphersuiteString, (), (const)); MOCK_METHOD(const std::string&, tlsVersion, (), (const)); MOCK_METHOD(absl::optional, x509Extension, (absl::string_view), (const)); - MOCK_METHOD(Network::PostIoAction, doHandshake, (SocketState& state), ()); + MOCK_METHOD(Network::PostIoAction, doHandshake, (SocketState & state), ()); }; class MockClientContext : public ClientContext { From 393b621e6854426c5459cbb8e0626e01a63bc909 Mon Sep 17 00:00:00 2001 From: James Buckland Date: Mon, 10 Aug 2020 11:38:50 -0400 Subject: [PATCH 03/11] [misc] fix typo in connection.h Signed-off-by: James Buckland --- include/envoy/ssl/connection.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/envoy/ssl/connection.h b/include/envoy/ssl/connection.h index a76429ba7be5..a8bdd876aa31 100644 --- a/include/envoy/ssl/connection.h +++ b/include/envoy/ssl/connection.h @@ -152,7 +152,7 @@ class ConnectionInfo { virtual absl::optional x509Extension(absl::string_view extension_name) const PURE; /** - * Performs a TLS handshake on ths SSL object and returns an action indicating + * Performs a TLS handshake on the SSL object and returns an action indicating * whether the callsite should close the connection or keep it open. */ virtual Network::PostIoAction doHandshake(SocketState& state) PURE; From 3d56dc8c387de275ab087d9030e0d1948d0d4439 Mon Sep 17 00:00:00 2001 From: James Buckland Date: Mon, 10 Aug 2020 16:25:18 -0400 Subject: [PATCH 04/11] [tls] Move state_ into SslSocketInfo. Signed-off-by: James Buckland --- include/envoy/ssl/connection.h | 3 +- .../transport_sockets/tls/ssl_socket.cc | 41 ++++++++++--------- .../transport_sockets/tls/ssl_socket.h | 9 ++-- test/mocks/ssl/mocks.h | 2 +- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/include/envoy/ssl/connection.h b/include/envoy/ssl/connection.h index a8bdd876aa31..32c7a6d91765 100644 --- a/include/envoy/ssl/connection.h +++ b/include/envoy/ssl/connection.h @@ -6,7 +6,6 @@ #include "envoy/common/pure.h" #include "envoy/common/time.h" #include "envoy/network/post_io_action.h" -#include "envoy/ssl/ssl_socket_state.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -155,7 +154,7 @@ class ConnectionInfo { * Performs a TLS handshake on the SSL object and returns an action indicating * whether the callsite should close the connection or keep it open. */ - virtual Network::PostIoAction doHandshake(SocketState& state) PURE; + virtual Network::PostIoAction doHandshake() PURE; }; using ConnectionInfoSharedPtr = std::shared_ptr; diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index 6fcd830c9a72..17a8bced501c 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -45,7 +45,7 @@ class NotReadySslSocket : public Network::TransportSocket { SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, const Network::TransportSocketOptionsSharedPtr& transport_socket_options) : transport_socket_options_(transport_socket_options), - ctx_(std::dynamic_pointer_cast(ctx)), state_(Ssl::SocketState::PreHandshake) { + ctx_(std::dynamic_pointer_cast(ctx)) { bssl::UniquePtr ssl = ctx_->newSsl(transport_socket_options_.get()); info_ = std::make_shared(std::move(ssl), ctx_, this); @@ -96,9 +96,10 @@ SslSocket::ReadResult SslSocket::sslReadIntoSlice(Buffer::RawSlice& slice) { } Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) { - if (state_ != Ssl::SocketState::HandshakeComplete && state_ != Ssl::SocketState::ShutdownSent) { + if (info_->state() != Ssl::SocketState::HandshakeComplete && + info_->state() != Ssl::SocketState::ShutdownSent) { PostIoAction action = doHandshake(); - if (action == PostIoAction::Close || state_ != Ssl::SocketState::HandshakeComplete) { + if (action == PostIoAction::Close || info_->state() != Ssl::SocketState::HandshakeComplete) { // end_stream is false because either a hard error occurred (action == Close) or // the handshake isn't complete, so a half-close cannot occur yet. return {action, 0, false}; @@ -158,7 +159,7 @@ Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) { void SslSocket::onPrivateKeyMethodComplete() { ASSERT(isThreadSafe()); - ASSERT(state_ == Ssl::SocketState::HandshakeInProgress); + ASSERT(info_->state() == Ssl::SocketState::HandshakeInProgress); // Resume handshake. PostIoAction action = doHandshake(); @@ -179,7 +180,7 @@ void SslSocket::onSuccess(SSL* ssl) { void SslSocket::onFailure() { drainErrorQueue(); } -PostIoAction SslSocket::doHandshake() { return info_->doHandshake(state_); } +PostIoAction SslSocket::doHandshake() { return info_->doHandshake(); } void SslSocket::drainErrorQueue() { bool saw_error = false; @@ -209,10 +210,11 @@ void SslSocket::drainErrorQueue() { } Network::IoResult SslSocket::doWrite(Buffer::Instance& write_buffer, bool end_stream) { - ASSERT(state_ != Ssl::SocketState::ShutdownSent || write_buffer.length() == 0); - if (state_ != Ssl::SocketState::HandshakeComplete && state_ != Ssl::SocketState::ShutdownSent) { + ASSERT(info_->state() != Ssl::SocketState::ShutdownSent || write_buffer.length() == 0); + if (info_->state() != Ssl::SocketState::HandshakeComplete && + info_->state() != Ssl::SocketState::ShutdownSent) { PostIoAction action = doHandshake(); - if (action == PostIoAction::Close || state_ != Ssl::SocketState::HandshakeComplete) { + if (action == PostIoAction::Close || info_->state() != Ssl::SocketState::HandshakeComplete) { return {action, 0, false}; } } @@ -265,18 +267,18 @@ Network::IoResult SslSocket::doWrite(Buffer::Instance& write_buffer, bool end_st return {PostIoAction::KeepOpen, total_bytes_written, false}; } -void SslSocket::onConnected() { ASSERT(state_ == Ssl::SocketState::PreHandshake); } +void SslSocket::onConnected() { ASSERT(info_->state() == Ssl::SocketState::PreHandshake); } Ssl::ConnectionInfoSharedPtr SslSocket::ssl() const { return info_; } void SslSocket::shutdownSsl() { - ASSERT(state_ != Ssl::SocketState::PreHandshake); - if (state_ != Ssl::SocketState::ShutdownSent && + ASSERT(info_->state() != Ssl::SocketState::PreHandshake); + if (info_->state() != Ssl::SocketState::ShutdownSent && callbacks_->connection().state() != Network::Connection::State::Closed) { int rc = SSL_shutdown(rawSsl()); ENVOY_CONN_LOG(debug, "SSL shutdown: rc={}", callbacks_->connection(), rc); drainErrorQueue(); - state_ = Ssl::SocketState::ShutdownSent; + info_->state() = Ssl::SocketState::ShutdownSent; } } @@ -291,7 +293,8 @@ Envoy::Ssl::ClientValidationStatus SslExtendedSocketInfoImpl::certificateValidat SslSocketInfo::SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, HandshakeCallbacks* handshake_callbacks) - : ssl_(std::move(ssl)), handshake_callbacks_(handshake_callbacks) { + : ssl_(std::move(ssl)), handshake_callbacks_(handshake_callbacks), + state_(Ssl::SocketState::PreHandshake) { SSL_set_ex_data(ssl_.get(), ctx->sslExtendedSocketInfoIndex(), &(this->extended_socket_info_)); } @@ -459,8 +462,8 @@ void SslSocket::closeSocket(Network::ConnectionEvent) { // Attempt to send a shutdown before closing the socket. It's possible this won't go out if // there is no room on the socket. We can extend the state machine to handle this at some point // if needed. - if (state_ == Ssl::SocketState::HandshakeInProgress || - state_ == Ssl::SocketState::HandshakeComplete) { + if (info_->state() == Ssl::SocketState::HandshakeInProgress || + info_->state() == Ssl::SocketState::HandshakeComplete) { shutdownSsl(); } } @@ -509,11 +512,11 @@ absl::optional SslSocketInfo::x509Extension(absl::string_view exten return Utility::getX509ExtensionValue(*cert, extension_name); } -Network::PostIoAction SslSocketInfo::doHandshake(Ssl::SocketState& state) { - ASSERT(state != Ssl::SocketState::HandshakeComplete && state != Ssl::SocketState::ShutdownSent); +Network::PostIoAction SslSocketInfo::doHandshake() { + ASSERT(state_ != Ssl::SocketState::HandshakeComplete && state_ != Ssl::SocketState::ShutdownSent); int rc = SSL_do_handshake(ssl()); if (rc == 1) { - state = Ssl::SocketState::HandshakeComplete; + state_ = Ssl::SocketState::HandshakeComplete; handshake_callbacks_->onSuccess(ssl()); // It's possible that we closed during the handshake callback. @@ -527,7 +530,7 @@ Network::PostIoAction SslSocketInfo::doHandshake(Ssl::SocketState& state) { case SSL_ERROR_WANT_WRITE: return PostIoAction::KeepOpen; case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION: - state = Ssl::SocketState::HandshakeInProgress; + state_ = Ssl::SocketState::HandshakeInProgress; return PostIoAction::KeepOpen; default: handshake_callbacks_->onFailure(); diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index a340874f8e93..75491253bd09 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -86,13 +86,17 @@ class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { std::string ciphersuiteString() const override; const std::string& tlsVersion() const override; absl::optional x509Extension(absl::string_view extension_name) const override; - Network::PostIoAction doHandshake(Ssl::SocketState& state) override; + Network::PostIoAction doHandshake() override; + + Ssl::SocketState& state() { return state_; } SSL* ssl() const { return ssl_.get(); } bssl::UniquePtr ssl_; private: HandshakeCallbacks* handshake_callbacks_; + + Ssl::SocketState state_; mutable std::vector cached_uri_san_local_certificate_; mutable std::string cached_sha_256_peer_certificate_digest_; mutable std::string cached_sha_1_peer_certificate_digest_; @@ -124,7 +128,7 @@ class SslSocket : public Network::TransportSocket, void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override; std::string protocol() const override; absl::string_view failureReason() const override; - bool canFlushClose() override { return state_ == Ssl::SocketState::HandshakeComplete; } + bool canFlushClose() override { return info_->state() == Ssl::SocketState::HandshakeComplete; } void closeSocket(Network::ConnectionEvent close_type) override; Network::IoResult doRead(Buffer::Instance& read_buffer) override; Network::IoResult doWrite(Buffer::Instance& write_buffer, bool end_stream) override; @@ -161,7 +165,6 @@ class SslSocket : public Network::TransportSocket, ContextImplSharedPtr ctx_; uint64_t bytes_to_retry_{}; std::string failure_reason_; - Ssl::SocketState state_; SslSocketInfoSharedPtr info_; }; diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 2ad24444fd2d..af2c6abf1222 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -59,7 +59,7 @@ class MockConnectionInfo : public ConnectionInfo { MOCK_METHOD(std::string, ciphersuiteString, (), (const)); MOCK_METHOD(const std::string&, tlsVersion, (), (const)); MOCK_METHOD(absl::optional, x509Extension, (absl::string_view), (const)); - MOCK_METHOD(Network::PostIoAction, doHandshake, (SocketState & state), ()); + MOCK_METHOD(Network::PostIoAction, doHandshake, (), ()); }; class MockClientContext : public ClientContext { From 1568afc37091aa67203b40934226a94a92184157 Mon Sep 17 00:00:00 2001 From: James Buckland Date: Mon, 10 Aug 2020 17:39:24 -0400 Subject: [PATCH 05/11] [tls] Mark some ConnectionInfoSharedPtrs as const. Signed-off-by: James Buckland --- include/envoy/ssl/connection.h | 1 + include/envoy/stream_info/stream_info.h | 8 ++++---- source/common/stream_info/stream_info_impl.h | 15 +++++++++------ source/common/tcp_proxy/tcp_proxy.cc | 2 +- source/common/tcp_proxy/tcp_proxy.h | 2 +- .../access_loggers/grpc/grpc_access_log_utils.cc | 2 +- test/common/stream_info/test_util.h | 13 +++++++------ test/mocks/stream_info/mocks.h | 12 ++++++------ 8 files changed, 30 insertions(+), 25 deletions(-) diff --git a/include/envoy/ssl/connection.h b/include/envoy/ssl/connection.h index 32c7a6d91765..16a176b82a4c 100644 --- a/include/envoy/ssl/connection.h +++ b/include/envoy/ssl/connection.h @@ -157,6 +157,7 @@ class ConnectionInfo { virtual Network::PostIoAction doHandshake() PURE; }; +using ConnectionInfoConstSharedPtr = std::shared_ptr; using ConnectionInfoSharedPtr = std::shared_ptr; } // namespace Ssl diff --git a/include/envoy/stream_info/stream_info.h b/include/envoy/stream_info/stream_info.h index 7c641cc0b5a9..c64e0837266d 100644 --- a/include/envoy/stream_info/stream_info.h +++ b/include/envoy/stream_info/stream_info.h @@ -462,25 +462,25 @@ class StreamInfo { * @param connection_info sets the downstream ssl connection. */ virtual void - setDownstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& ssl_connection_info) PURE; + setDownstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& ssl_connection_info) PURE; /** * @return the downstream SSL connection. This will be nullptr if the downstream * connection does not use SSL. */ - virtual Ssl::ConnectionInfoSharedPtr downstreamSslConnection() const PURE; + virtual Ssl::ConnectionInfoConstSharedPtr downstreamSslConnection() const PURE; /** * @param connection_info sets the upstream ssl connection. */ virtual void - setUpstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& ssl_connection_info) PURE; + setUpstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& ssl_connection_info) PURE; /** * @return the upstream SSL connection. This will be nullptr if the upstream * connection does not use SSL. */ - virtual Ssl::ConnectionInfoSharedPtr upstreamSslConnection() const PURE; + virtual Ssl::ConnectionInfoConstSharedPtr upstreamSslConnection() const PURE; /** * @return const Router::RouteEntry* Get the route entry selected for this request. Note: this diff --git a/source/common/stream_info/stream_info_impl.h b/source/common/stream_info/stream_info_impl.h index 4c18d708c461..a384cd401cf3 100644 --- a/source/common/stream_info/stream_info_impl.h +++ b/source/common/stream_info/stream_info_impl.h @@ -189,19 +189,22 @@ struct StreamInfoImpl : public StreamInfo { return downstream_remote_address_; } - void setDownstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& connection_info) override { + void + setDownstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& connection_info) override { downstream_ssl_info_ = connection_info; } - Ssl::ConnectionInfoSharedPtr downstreamSslConnection() const override { + Ssl::ConnectionInfoConstSharedPtr downstreamSslConnection() const override { return downstream_ssl_info_; } - void setUpstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& connection_info) override { + void setUpstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& connection_info) override { upstream_ssl_info_ = connection_info; } - Ssl::ConnectionInfoSharedPtr upstreamSslConnection() const override { return upstream_ssl_info_; } + Ssl::ConnectionInfoConstSharedPtr upstreamSslConnection() const override { + return upstream_ssl_info_; + } const Router::RouteEntry* routeEntry() const override { return route_entry_; } @@ -300,8 +303,8 @@ struct StreamInfoImpl : public StreamInfo { Network::Address::InstanceConstSharedPtr downstream_local_address_; Network::Address::InstanceConstSharedPtr downstream_direct_remote_address_; Network::Address::InstanceConstSharedPtr downstream_remote_address_; - Ssl::ConnectionInfoSharedPtr downstream_ssl_info_; - Ssl::ConnectionInfoSharedPtr upstream_ssl_info_; + Ssl::ConnectionInfoConstSharedPtr downstream_ssl_info_; + Ssl::ConnectionInfoConstSharedPtr upstream_ssl_info_; std::string requested_server_name_; const Http::RequestHeaderMap* request_headers_{}; Http::RequestIDExtensionSharedPtr request_id_extension_; diff --git a/source/common/tcp_proxy/tcp_proxy.cc b/source/common/tcp_proxy/tcp_proxy.cc index 9945d10eb12f..92dd68e4be4a 100644 --- a/source/common/tcp_proxy/tcp_proxy.cc +++ b/source/common/tcp_proxy/tcp_proxy.cc @@ -499,7 +499,7 @@ void Filter::onPoolFailure(ConnectionPool::PoolFailureReason reason, void Filter::onPoolReadyBase(Upstream::HostDescriptionConstSharedPtr& host, const Network::Address::InstanceConstSharedPtr& local_address, - Ssl::ConnectionInfoSharedPtr ssl_info) { + Ssl::ConnectionInfoConstSharedPtr ssl_info) { upstream_handle_.reset(); read_callbacks_->upstreamHost(host); getStreamInfo().onUpstreamHostSelected(host); diff --git a/source/common/tcp_proxy/tcp_proxy.h b/source/common/tcp_proxy/tcp_proxy.h index 413c87eb4b54..871be2ad16f8 100644 --- a/source/common/tcp_proxy/tcp_proxy.h +++ b/source/common/tcp_proxy/tcp_proxy.h @@ -262,7 +262,7 @@ class Filter : public Network::ReadFilter, void onPoolReadyBase(Upstream::HostDescriptionConstSharedPtr& host, const Network::Address::InstanceConstSharedPtr& local_address, - Ssl::ConnectionInfoSharedPtr ssl_info); + Ssl::ConnectionInfoConstSharedPtr ssl_info); // Upstream::LoadBalancerContext const Router::MetadataMatchCriteria* metadataMatchCriteria() override { diff --git a/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc b/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc index 894a9cff7fda..74b061cbad7c 100644 --- a/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc +++ b/source/extensions/access_loggers/grpc/grpc_access_log_utils.cc @@ -149,7 +149,7 @@ void Utility::extractCommonAccessLogProperties( } if (stream_info.downstreamSslConnection() != nullptr) { auto* tls_properties = common_access_log.mutable_tls_properties(); - const Ssl::ConnectionInfoSharedPtr downstream_ssl_connection = + const Ssl::ConnectionInfoConstSharedPtr downstream_ssl_connection = stream_info.downstreamSslConnection(); tls_properties->set_tls_sni_hostname(stream_info.requestedServerName()); diff --git a/test/common/stream_info/test_util.h b/test/common/stream_info/test_util.h index b960122eb0e6..5767592c7406 100644 --- a/test/common/stream_info/test_util.h +++ b/test/common/stream_info/test_util.h @@ -90,19 +90,20 @@ class TestStreamInfo : public StreamInfo::StreamInfo { return downstream_remote_address_; } - void setDownstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& connection_info) override { + void + setDownstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& connection_info) override { downstream_connection_info_ = connection_info; } - Ssl::ConnectionInfoSharedPtr downstreamSslConnection() const override { + Ssl::ConnectionInfoConstSharedPtr downstreamSslConnection() const override { return downstream_connection_info_; } - void setUpstreamSslConnection(const Ssl::ConnectionInfoSharedPtr& connection_info) override { + void setUpstreamSslConnection(const Ssl::ConnectionInfoConstSharedPtr& connection_info) override { upstream_connection_info_ = connection_info; } - Ssl::ConnectionInfoSharedPtr upstreamSslConnection() const override { + Ssl::ConnectionInfoConstSharedPtr upstreamSslConnection() const override { return upstream_connection_info_; } void setRouteName(absl::string_view route_name) override { @@ -250,8 +251,8 @@ class TestStreamInfo : public StreamInfo::StreamInfo { Network::Address::InstanceConstSharedPtr downstream_local_address_; Network::Address::InstanceConstSharedPtr downstream_direct_remote_address_; Network::Address::InstanceConstSharedPtr downstream_remote_address_; - Ssl::ConnectionInfoSharedPtr downstream_connection_info_; - Ssl::ConnectionInfoSharedPtr upstream_connection_info_; + Ssl::ConnectionInfoConstSharedPtr downstream_connection_info_; + Ssl::ConnectionInfoConstSharedPtr upstream_connection_info_; const Router::RouteEntry* route_entry_{}; envoy::config::core::v3::Metadata metadata_{}; Envoy::StreamInfo::FilterStateSharedPtr filter_state_{ diff --git a/test/mocks/stream_info/mocks.h b/test/mocks/stream_info/mocks.h index 9fb08bbabe9e..2c5b09562e96 100644 --- a/test/mocks/stream_info/mocks.h +++ b/test/mocks/stream_info/mocks.h @@ -70,10 +70,10 @@ class MockStreamInfo : public StreamInfo { MOCK_METHOD(void, setDownstreamRemoteAddress, (const Network::Address::InstanceConstSharedPtr&)); MOCK_METHOD(const Network::Address::InstanceConstSharedPtr&, downstreamRemoteAddress, (), (const)); - MOCK_METHOD(void, setDownstreamSslConnection, (const Ssl::ConnectionInfoSharedPtr&)); - MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, downstreamSslConnection, (), (const)); - MOCK_METHOD(void, setUpstreamSslConnection, (const Ssl::ConnectionInfoSharedPtr&)); - MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, upstreamSslConnection, (), (const)); + MOCK_METHOD(void, setDownstreamSslConnection, (const Ssl::ConnectionInfoConstSharedPtr&)); + MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, downstreamSslConnection, (), (const)); + MOCK_METHOD(void, setUpstreamSslConnection, (const Ssl::ConnectionInfoConstSharedPtr&)); + MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, upstreamSslConnection, (), (const)); MOCK_METHOD(const Router::RouteEntry*, routeEntry, (), (const)); MOCK_METHOD(envoy::config::core::v3::Metadata&, dynamicMetadata, ()); MOCK_METHOD(const envoy::config::core::v3::Metadata&, dynamicMetadata, (), (const)); @@ -122,8 +122,8 @@ class MockStreamInfo : public StreamInfo { Network::Address::InstanceConstSharedPtr downstream_local_address_; Network::Address::InstanceConstSharedPtr downstream_direct_remote_address_; Network::Address::InstanceConstSharedPtr downstream_remote_address_; - Ssl::ConnectionInfoSharedPtr downstream_connection_info_; - Ssl::ConnectionInfoSharedPtr upstream_connection_info_; + Ssl::ConnectionInfoConstSharedPtr downstream_connection_info_; + Ssl::ConnectionInfoConstSharedPtr upstream_connection_info_; std::string requested_server_name_; std::string route_name_; std::string upstream_transport_failure_reason_; From bdd8d1bf3613aaa4da4bcf29f8acda87aa030035 Mon Sep 17 00:00:00 2001 From: James Buckland Date: Tue, 11 Aug 2020 09:04:01 -0400 Subject: [PATCH 06/11] [tls] Separate interface for handshaking Signed-off-by: James Buckland --- include/envoy/ssl/BUILD | 10 ++++- include/envoy/ssl/handshaker.h | 39 +++++++++++++++++++ source/extensions/transport_sockets/tls/BUILD | 1 + .../transport_sockets/tls/ssl_socket.cc | 4 +- .../transport_sockets/tls/ssl_socket.h | 24 +++++------- 5 files changed, 61 insertions(+), 17 deletions(-) create mode 100644 include/envoy/ssl/handshaker.h diff --git a/include/envoy/ssl/BUILD b/include/envoy/ssl/BUILD index 7b7e3219ec77..5f37befb41b7 100644 --- a/include/envoy/ssl/BUILD +++ b/include/envoy/ssl/BUILD @@ -15,7 +15,6 @@ envoy_cc_library( deps = [ ":ssl_socket_state", "//include/envoy/common:time_interface", - "//include/envoy/network:post_io_action_interface", ], ) @@ -76,3 +75,12 @@ envoy_cc_library( hdrs = ["ssl_socket_state.h"], deps = [], ) + +envoy_cc_library( + name = "handshaker_interface", + hdrs = ["handshaker.h"], + deps = [ + ":connection_interface", + "//include/envoy/network:post_io_action_interface", + ], +) diff --git a/include/envoy/ssl/handshaker.h b/include/envoy/ssl/handshaker.h new file mode 100644 index 000000000000..54fa19e727d9 --- /dev/null +++ b/include/envoy/ssl/handshaker.h @@ -0,0 +1,39 @@ +#pragma once + +#include "envoy/network/connection.h" +#include "envoy/network/post_io_action.h" + +#include "openssl/ssl.h" + +namespace Envoy { +namespace Ssl { + +class HandshakeCallbacks { +public: + virtual ~HandshakeCallbacks() = default; + + /** + * @return the connection state. + */ + virtual Network::Connection::State connectionState() const PURE; + + virtual void onSuccess(SSL* ssl) PURE; + virtual void onFailure() PURE; +}; + +/** + * Base interface for performing TLS handshakes. + */ +class Handshaker { +public: + virtual ~Handshaker() = default; + + /** + * Performs a TLS handshake and returns an action indicating + * whether the callsite should close the connection or keep it open. + */ + virtual Network::PostIoAction doHandshake() PURE; +}; + +} // namespace Ssl +} // namespace Envoy diff --git a/source/extensions/transport_sockets/tls/BUILD b/source/extensions/transport_sockets/tls/BUILD index c7e760764a11..e12b887d01ff 100644 --- a/source/extensions/transport_sockets/tls/BUILD +++ b/source/extensions/transport_sockets/tls/BUILD @@ -47,6 +47,7 @@ envoy_cc_library( ":utility_lib", "//include/envoy/network:connection_interface", "//include/envoy/network:transport_socket_interface", + "//include/envoy/ssl:handshaker_interface", "//include/envoy/ssl:ssl_socket_extended_info_interface", "//include/envoy/ssl:ssl_socket_state", "//include/envoy/ssl/private_key:private_key_callbacks_interface", diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index 17a8bced501c..e44bd58ae0d1 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -278,7 +278,7 @@ void SslSocket::shutdownSsl() { int rc = SSL_shutdown(rawSsl()); ENVOY_CONN_LOG(debug, "SSL shutdown: rc={}", callbacks_->connection(), rc); drainErrorQueue(); - info_->state() = Ssl::SocketState::ShutdownSent; + info_->setState(Ssl::SocketState::ShutdownSent); } } @@ -292,7 +292,7 @@ Envoy::Ssl::ClientValidationStatus SslExtendedSocketInfoImpl::certificateValidat } SslSocketInfo::SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, - HandshakeCallbacks* handshake_callbacks) + Ssl::HandshakeCallbacks* handshake_callbacks) : ssl_(std::move(ssl)), handshake_callbacks_(handshake_callbacks), state_(Ssl::SocketState::PreHandshake) { SSL_set_ex_data(ssl_.get(), ctx->sslExtendedSocketInfoIndex(), &(this->extended_socket_info_)); diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index 75491253bd09..6b7a2a64a095 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -6,6 +6,7 @@ #include "envoy/network/connection.h" #include "envoy/network/transport_socket.h" #include "envoy/secret/secret_callbacks.h" +#include "envoy/ssl/handshaker.h" #include "envoy/ssl/private_key/private_key_callbacks.h" #include "envoy/ssl/ssl_socket_extended_info.h" #include "envoy/ssl/ssl_socket_state.h" @@ -51,18 +52,10 @@ class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { Envoy::Ssl::ClientValidationStatus::NotValidated}; }; -class HandshakeCallbacks { -public: - virtual ~HandshakeCallbacks() = default; - virtual Network::Connection::State connectionState() const PURE; - virtual void onSuccess(SSL* ssl) PURE; - virtual void onFailure() PURE; -}; - -class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { +class SslSocketInfo : public Envoy::Ssl::ConnectionInfo, public Envoy::Ssl::Handshaker { public: SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, - HandshakeCallbacks* handshake_callbacks_); + Ssl::HandshakeCallbacks* handshake_callbacks_); // Ssl::ConnectionInfo bool peerCertificatePresented() const override; @@ -86,15 +79,18 @@ class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { std::string ciphersuiteString() const override; const std::string& tlsVersion() const override; absl::optional x509Extension(absl::string_view extension_name) const override; + + // Ssl::Handshaker Network::PostIoAction doHandshake() override; - Ssl::SocketState& state() { return state_; } + Ssl::SocketState state() { return state_; } + void setState(Ssl::SocketState state) { state_ = state; } SSL* ssl() const { return ssl_.get(); } bssl::UniquePtr ssl_; private: - HandshakeCallbacks* handshake_callbacks_; + Ssl::HandshakeCallbacks* handshake_callbacks_; Ssl::SocketState state_; mutable std::vector cached_uri_san_local_certificate_; @@ -118,7 +114,7 @@ using SslSocketInfoSharedPtr = std::shared_ptr; class SslSocket : public Network::TransportSocket, public Envoy::Ssl::PrivateKeyConnectionCallbacks, - public HandshakeCallbacks, + public Ssl::HandshakeCallbacks, protected Logger::Loggable { public: SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, @@ -136,7 +132,7 @@ class SslSocket : public Network::TransportSocket, Ssl::ConnectionInfoSharedPtr ssl() const override; // Ssl::PrivateKeyConnectionCallbacks void onPrivateKeyMethodComplete() override; - // HandshakeCallbacks + // Ssl::HandshakeCallbacks Network::Connection::State connectionState() const override; void onSuccess(SSL* ssl) override; void onFailure() override; From 93e2508f3246acb2c7355be8b78a3f28dd621fac Mon Sep 17 00:00:00 2001 From: James Buckland Date: Tue, 11 Aug 2020 09:28:50 -0400 Subject: [PATCH 07/11] [tls] Remove non-const ConnectionInfoSharedPtr. Signed-off-by: James Buckland --- include/envoy/network/connection.h | 2 +- include/envoy/network/transport_socket.h | 2 +- include/envoy/ssl/connection.h | 1 - source/common/network/connection_impl.h | 2 +- source/common/network/raw_buffer_socket.h | 2 +- .../quiche/quic_filter_manager_connection_impl.cc | 2 +- .../quiche/quic_filter_manager_connection_impl.h | 2 +- source/extensions/transport_sockets/alts/tsi_socket.h | 2 +- source/extensions/transport_sockets/common/passthrough.cc | 4 +++- source/extensions/transport_sockets/common/passthrough.h | 2 +- source/extensions/transport_sockets/tls/ssl_socket.cc | 4 ++-- source/extensions/transport_sockets/tls/ssl_socket.h | 2 +- source/server/api_listener_impl.h | 2 +- test/mocks/network/connection.h | 6 +++--- test/mocks/network/transport_socket.h | 2 +- 15 files changed, 19 insertions(+), 18 deletions(-) diff --git a/include/envoy/network/connection.h b/include/envoy/network/connection.h index cf632a75c512..6e9667f77804 100644 --- a/include/envoy/network/connection.h +++ b/include/envoy/network/connection.h @@ -222,7 +222,7 @@ class Connection : public Event::DeferredDeletable, public FilterManager { * @return the const SSL connection data if this is an SSL connection, or nullptr if it is not. */ // TODO(snowp): Remove this in favor of StreamInfo::downstreamSslConnection. - virtual Ssl::ConnectionInfoSharedPtr ssl() const PURE; + virtual Ssl::ConnectionInfoConstSharedPtr ssl() const PURE; /** * @return requested server name (e.g. SNI in TLS), if any. diff --git a/include/envoy/network/transport_socket.h b/include/envoy/network/transport_socket.h index 884da7aa22a0..572d8b95124b 100644 --- a/include/envoy/network/transport_socket.h +++ b/include/envoy/network/transport_socket.h @@ -144,7 +144,7 @@ class TransportSocket { /** * @return the SSL connection data if this is an SSL connection, or nullptr if it is not. */ - virtual Ssl::ConnectionInfoSharedPtr ssl() const PURE; + virtual Ssl::ConnectionInfoConstSharedPtr ssl() const PURE; }; using TransportSocketPtr = std::unique_ptr; diff --git a/include/envoy/ssl/connection.h b/include/envoy/ssl/connection.h index 16a176b82a4c..c1c1d74d3041 100644 --- a/include/envoy/ssl/connection.h +++ b/include/envoy/ssl/connection.h @@ -158,7 +158,6 @@ class ConnectionInfo { }; using ConnectionInfoConstSharedPtr = std::shared_ptr; -using ConnectionInfoSharedPtr = std::shared_ptr; } // namespace Ssl } // namespace Envoy diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index 00494c6035db..b464e2af96d1 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -76,7 +76,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback return socket_->localAddress(); } absl::optional unixSocketPeerCredentials() const override; - Ssl::ConnectionInfoSharedPtr ssl() const override { return transport_socket_->ssl(); } + Ssl::ConnectionInfoConstSharedPtr ssl() const override { return transport_socket_->ssl(); } State state() const override; void write(Buffer::Instance& data, bool end_stream) override; void setBufferLimits(uint32_t limit) override; diff --git a/source/common/network/raw_buffer_socket.h b/source/common/network/raw_buffer_socket.h index 212f700912ca..fe87bbeda605 100644 --- a/source/common/network/raw_buffer_socket.h +++ b/source/common/network/raw_buffer_socket.h @@ -20,7 +20,7 @@ class RawBufferSocket : public TransportSocket, protected Logger::LoggableconnectionSocket()->localAddress(); } -Ssl::ConnectionInfoSharedPtr QuicFilterManagerConnectionImpl::ssl() const { +Ssl::ConnectionInfoConstSharedPtr QuicFilterManagerConnectionImpl::ssl() const { // TODO(danzh): construct Ssl::ConnectionInfo from crypto stream return nullptr; } diff --git a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h index 3ea110492902..54c8e87a259d 100644 --- a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h +++ b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h @@ -56,7 +56,7 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase { Network::ConnectionImplBase::setConnectionStats(stats); quic_connection_->setConnectionStats(stats); } - Ssl::ConnectionInfoSharedPtr ssl() const override; + Ssl::ConnectionInfoConstSharedPtr ssl() const override; Network::Connection::State state() const override { if (quic_connection_ != nullptr && quic_connection_->connected()) { return Network::Connection::State::Open; diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h index 8c66c4673736..0acba405022d 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.h +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -58,7 +58,7 @@ class TsiSocket : public Network::TransportSocket, std::string protocol() const override; absl::string_view failureReason() const override; bool canFlushClose() override { return handshake_complete_; } - Envoy::Ssl::ConnectionInfoSharedPtr ssl() const override { return nullptr; } + Envoy::Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; } Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override; void closeSocket(Network::ConnectionEvent event) override; Network::IoResult doRead(Buffer::Instance& buffer) override; diff --git a/source/extensions/transport_sockets/common/passthrough.cc b/source/extensions/transport_sockets/common/passthrough.cc index a6228e2b5fd0..86fc282ed3e3 100644 --- a/source/extensions/transport_sockets/common/passthrough.cc +++ b/source/extensions/transport_sockets/common/passthrough.cc @@ -38,7 +38,9 @@ Network::IoResult PassthroughSocket::doWrite(Buffer::Instance& buffer, bool end_ void PassthroughSocket::onConnected() { transport_socket_->onConnected(); } -Ssl::ConnectionInfoSharedPtr PassthroughSocket::ssl() const { return transport_socket_->ssl(); } +Ssl::ConnectionInfoConstSharedPtr PassthroughSocket::ssl() const { + return transport_socket_->ssl(); +} } // namespace TransportSockets } // namespace Extensions diff --git a/source/extensions/transport_sockets/common/passthrough.h b/source/extensions/transport_sockets/common/passthrough.h index b6ec5d09ba3f..5084d973a865 100644 --- a/source/extensions/transport_sockets/common/passthrough.h +++ b/source/extensions/transport_sockets/common/passthrough.h @@ -21,7 +21,7 @@ class PassthroughSocket : public Network::TransportSocket { Network::IoResult doRead(Buffer::Instance& buffer) override; Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override; void onConnected() override; - Ssl::ConnectionInfoSharedPtr ssl() const override; + Ssl::ConnectionInfoConstSharedPtr ssl() const override; protected: Network::TransportSocketPtr transport_socket_; diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index e44bd58ae0d1..f8a95b42fef8 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -38,7 +38,7 @@ class NotReadySslSocket : public Network::TransportSocket { return {PostIoAction::Close, 0, false}; } void onConnected() override {} - Ssl::ConnectionInfoSharedPtr ssl() const override { return nullptr; } + Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; } }; } // namespace @@ -269,7 +269,7 @@ Network::IoResult SslSocket::doWrite(Buffer::Instance& write_buffer, bool end_st void SslSocket::onConnected() { ASSERT(info_->state() == Ssl::SocketState::PreHandshake); } -Ssl::ConnectionInfoSharedPtr SslSocket::ssl() const { return info_; } +Ssl::ConnectionInfoConstSharedPtr SslSocket::ssl() const { return info_; } void SslSocket::shutdownSsl() { ASSERT(info_->state() != Ssl::SocketState::PreHandshake); diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index 6b7a2a64a095..db5566aab30c 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -129,7 +129,7 @@ class SslSocket : public Network::TransportSocket, Network::IoResult doRead(Buffer::Instance& read_buffer) override; Network::IoResult doWrite(Buffer::Instance& write_buffer, bool end_stream) override; void onConnected() override; - Ssl::ConnectionInfoSharedPtr ssl() const override; + Ssl::ConnectionInfoConstSharedPtr ssl() const override; // Ssl::PrivateKeyConnectionCallbacks void onPrivateKeyMethodComplete() override; // Ssl::HandshakeCallbacks diff --git a/source/server/api_listener_impl.h b/source/server/api_listener_impl.h index b0e2b4045df8..bf806341bed0 100644 --- a/source/server/api_listener_impl.h +++ b/source/server/api_listener_impl.h @@ -117,7 +117,7 @@ class ApiListenerImplBase : public ApiListener, return parent_.parent_.address(); } void setConnectionStats(const Network::Connection::ConnectionStats&) override {} - Ssl::ConnectionInfoSharedPtr ssl() const override { return nullptr; } + Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; } absl::string_view requestedServerName() const override { return EMPTY_STRING; } State state() const override { return Network::Connection::State::Open; } void write(Buffer::Instance&, bool) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } diff --git a/test/mocks/network/connection.h b/test/mocks/network/connection.h index 0c22705aacef..b66d9525d268 100644 --- a/test/mocks/network/connection.h +++ b/test/mocks/network/connection.h @@ -72,7 +72,7 @@ class MockConnection : public Connection, public MockConnectionBase { unixSocketPeerCredentials, (), (const)); MOCK_METHOD(const Address::InstanceConstSharedPtr&, localAddress, (), (const)); MOCK_METHOD(void, setConnectionStats, (const ConnectionStats& stats)); - MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, ssl, (), (const)); + MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); @@ -118,7 +118,7 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase unixSocketPeerCredentials, (), (const)); MOCK_METHOD(const Address::InstanceConstSharedPtr&, localAddress, (), (const)); MOCK_METHOD(void, setConnectionStats, (const ConnectionStats& stats)); - MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, ssl, (), (const)); + MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); @@ -167,7 +167,7 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC unixSocketPeerCredentials, (), (const)); MOCK_METHOD(const Address::InstanceConstSharedPtr&, localAddress, (), (const)); MOCK_METHOD(void, setConnectionStats, (const ConnectionStats& stats)); - MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, ssl, (), (const)); + MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); diff --git a/test/mocks/network/transport_socket.h b/test/mocks/network/transport_socket.h index a6785c6b402d..ee53570c20ac 100644 --- a/test/mocks/network/transport_socket.h +++ b/test/mocks/network/transport_socket.h @@ -25,7 +25,7 @@ class MockTransportSocket : public TransportSocket { MOCK_METHOD(IoResult, doRead, (Buffer::Instance & buffer)); MOCK_METHOD(IoResult, doWrite, (Buffer::Instance & buffer, bool end_stream)); MOCK_METHOD(void, onConnected, ()); - MOCK_METHOD(Ssl::ConnectionInfoSharedPtr, ssl, (), (const)); + MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); TransportSocketCallbacks* callbacks_{}; }; From ad678fff44bd790b109ad289762a18bcbc5740a7 Mon Sep 17 00:00:00 2001 From: James Buckland Date: Wed, 12 Aug 2020 09:17:25 -0400 Subject: [PATCH 08/11] [tls] Fix invalid case style Signed-off-by: James Buckland --- source/extensions/transport_sockets/tls/ssl_socket.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index db5566aab30c..b72cd93a966d 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -55,7 +55,7 @@ class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { class SslSocketInfo : public Envoy::Ssl::ConnectionInfo, public Envoy::Ssl::Handshaker { public: SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, - Ssl::HandshakeCallbacks* handshake_callbacks_); + Ssl::HandshakeCallbacks* handshake_callbacks); // Ssl::ConnectionInfo bool peerCertificatePresented() const override; From 0d59a860db741cf7c9afe10427c514d9c74cb5be Mon Sep 17 00:00:00 2001 From: James Buckland Date: Wed, 12 Aug 2020 09:24:14 -0400 Subject: [PATCH 09/11] [tls] Rename SslSocketInfo to SslHandshakerImpl Signed-off-by: James Buckland --- .../transport_sockets/tls/ssl_socket.cc | 51 ++++++++++--------- .../transport_sockets/tls/ssl_socket.h | 10 ++-- .../proxy_filter_integration_test.cc | 8 +-- .../http/router/auto_sni_integration_test.cc | 12 ++--- .../transport_sockets/tls/ssl_socket_test.cc | 36 ++++++------- 5 files changed, 59 insertions(+), 58 deletions(-) diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index f8a95b42fef8..a5f1329191e8 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -47,7 +47,7 @@ SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, : transport_socket_options_(transport_socket_options), ctx_(std::dynamic_pointer_cast(ctx)) { bssl::UniquePtr ssl = ctx_->newSsl(transport_socket_options_.get()); - info_ = std::make_shared(std::move(ssl), ctx_, this); + info_ = std::make_shared(std::move(ssl), ctx_, this); if (state == InitialState::Client) { SSL_set_connect_state(rawSsl()); @@ -291,24 +291,24 @@ Envoy::Ssl::ClientValidationStatus SslExtendedSocketInfoImpl::certificateValidat return certificate_validation_status_; } -SslSocketInfo::SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, - Ssl::HandshakeCallbacks* handshake_callbacks) +SslHandshakerImpl::SslHandshakerImpl(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, + Ssl::HandshakeCallbacks* handshake_callbacks) : ssl_(std::move(ssl)), handshake_callbacks_(handshake_callbacks), state_(Ssl::SocketState::PreHandshake) { SSL_set_ex_data(ssl_.get(), ctx->sslExtendedSocketInfoIndex(), &(this->extended_socket_info_)); } -bool SslSocketInfo::peerCertificatePresented() const { +bool SslHandshakerImpl::peerCertificatePresented() const { bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); return cert != nullptr; } -bool SslSocketInfo::peerCertificateValidated() const { +bool SslHandshakerImpl::peerCertificateValidated() const { return extended_socket_info_.certificateValidationStatus() == Envoy::Ssl::ClientValidationStatus::Validated; } -absl::Span SslSocketInfo::uriSanLocalCertificate() const { +absl::Span SslHandshakerImpl::uriSanLocalCertificate() const { if (!cached_uri_san_local_certificate_.empty()) { return cached_uri_san_local_certificate_; } @@ -323,7 +323,7 @@ absl::Span SslSocketInfo::uriSanLocalCertificate() const { return cached_uri_san_local_certificate_; } -absl::Span SslSocketInfo::dnsSansLocalCertificate() const { +absl::Span SslHandshakerImpl::dnsSansLocalCertificate() const { if (!cached_dns_san_local_certificate_.empty()) { return cached_dns_san_local_certificate_; } @@ -337,7 +337,7 @@ absl::Span SslSocketInfo::dnsSansLocalCertificate() const { return cached_dns_san_local_certificate_; } -const std::string& SslSocketInfo::sha256PeerCertificateDigest() const { +const std::string& SslHandshakerImpl::sha256PeerCertificateDigest() const { if (!cached_sha_256_peer_certificate_digest_.empty()) { return cached_sha_256_peer_certificate_digest_; } @@ -355,7 +355,7 @@ const std::string& SslSocketInfo::sha256PeerCertificateDigest() const { return cached_sha_256_peer_certificate_digest_; } -const std::string& SslSocketInfo::sha1PeerCertificateDigest() const { +const std::string& SslHandshakerImpl::sha1PeerCertificateDigest() const { if (!cached_sha_1_peer_certificate_digest_.empty()) { return cached_sha_1_peer_certificate_digest_; } @@ -373,7 +373,7 @@ const std::string& SslSocketInfo::sha1PeerCertificateDigest() const { return cached_sha_1_peer_certificate_digest_; } -const std::string& SslSocketInfo::urlEncodedPemEncodedPeerCertificate() const { +const std::string& SslHandshakerImpl::urlEncodedPemEncodedPeerCertificate() const { if (!cached_url_encoded_pem_encoded_peer_certificate_.empty()) { return cached_url_encoded_pem_encoded_peer_certificate_; } @@ -395,7 +395,7 @@ const std::string& SslSocketInfo::urlEncodedPemEncodedPeerCertificate() const { return cached_url_encoded_pem_encoded_peer_certificate_; } -const std::string& SslSocketInfo::urlEncodedPemEncodedPeerCertificateChain() const { +const std::string& SslHandshakerImpl::urlEncodedPemEncodedPeerCertificateChain() const { if (!cached_url_encoded_pem_encoded_peer_cert_chain_.empty()) { return cached_url_encoded_pem_encoded_peer_cert_chain_; } @@ -425,7 +425,7 @@ const std::string& SslSocketInfo::urlEncodedPemEncodedPeerCertificateChain() con return cached_url_encoded_pem_encoded_peer_cert_chain_; } -absl::Span SslSocketInfo::uriSanPeerCertificate() const { +absl::Span SslHandshakerImpl::uriSanPeerCertificate() const { if (!cached_uri_san_peer_certificate_.empty()) { return cached_uri_san_peer_certificate_; } @@ -439,7 +439,7 @@ absl::Span SslSocketInfo::uriSanPeerCertificate() const { return cached_uri_san_peer_certificate_; } -absl::Span SslSocketInfo::dnsSansPeerCertificate() const { +absl::Span SslHandshakerImpl::dnsSansPeerCertificate() const { if (!cached_dns_san_peer_certificate_.empty()) { return cached_dns_san_peer_certificate_; } @@ -475,7 +475,7 @@ std::string SslSocket::protocol() const { return std::string(reinterpret_cast(proto), proto_len); } -uint16_t SslSocketInfo::ciphersuiteId() const { +uint16_t SslHandshakerImpl::ciphersuiteId() const { const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); if (cipher == nullptr) { return 0xffff; @@ -487,7 +487,7 @@ uint16_t SslSocketInfo::ciphersuiteId() const { return static_cast(SSL_CIPHER_get_id(cipher)); } -std::string SslSocketInfo::ciphersuiteString() const { +std::string SslHandshakerImpl::ciphersuiteString() const { const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); if (cipher == nullptr) { return {}; @@ -496,7 +496,7 @@ std::string SslSocketInfo::ciphersuiteString() const { return SSL_CIPHER_get_name(cipher); } -const std::string& SslSocketInfo::tlsVersion() const { +const std::string& SslHandshakerImpl::tlsVersion() const { if (!cached_tls_version_.empty()) { return cached_tls_version_; } @@ -504,7 +504,8 @@ const std::string& SslSocketInfo::tlsVersion() const { return cached_tls_version_; } -absl::optional SslSocketInfo::x509Extension(absl::string_view extension_name) const { +absl::optional +SslHandshakerImpl::x509Extension(absl::string_view extension_name) const { bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); if (!cert) { return absl::nullopt; @@ -512,7 +513,7 @@ absl::optional SslSocketInfo::x509Extension(absl::string_view exten return Utility::getX509ExtensionValue(*cert, extension_name); } -Network::PostIoAction SslSocketInfo::doHandshake() { +Network::PostIoAction SslHandshakerImpl::doHandshake() { ASSERT(state_ != Ssl::SocketState::HandshakeComplete && state_ != Ssl::SocketState::ShutdownSent); int rc = SSL_do_handshake(ssl()); if (rc == 1) { @@ -541,7 +542,7 @@ Network::PostIoAction SslSocketInfo::doHandshake() { absl::string_view SslSocket::failureReason() const { return failure_reason_; } -const std::string& SslSocketInfo::serialNumberPeerCertificate() const { +const std::string& SslHandshakerImpl::serialNumberPeerCertificate() const { if (!cached_serial_number_peer_certificate_.empty()) { return cached_serial_number_peer_certificate_; } @@ -554,7 +555,7 @@ const std::string& SslSocketInfo::serialNumberPeerCertificate() const { return cached_serial_number_peer_certificate_; } -const std::string& SslSocketInfo::issuerPeerCertificate() const { +const std::string& SslHandshakerImpl::issuerPeerCertificate() const { if (!cached_issuer_peer_certificate_.empty()) { return cached_issuer_peer_certificate_; } @@ -567,7 +568,7 @@ const std::string& SslSocketInfo::issuerPeerCertificate() const { return cached_issuer_peer_certificate_; } -const std::string& SslSocketInfo::subjectPeerCertificate() const { +const std::string& SslHandshakerImpl::subjectPeerCertificate() const { if (!cached_subject_peer_certificate_.empty()) { return cached_subject_peer_certificate_; } @@ -580,7 +581,7 @@ const std::string& SslSocketInfo::subjectPeerCertificate() const { return cached_subject_peer_certificate_; } -const std::string& SslSocketInfo::subjectLocalCertificate() const { +const std::string& SslHandshakerImpl::subjectLocalCertificate() const { if (!cached_subject_local_certificate_.empty()) { return cached_subject_local_certificate_; } @@ -593,7 +594,7 @@ const std::string& SslSocketInfo::subjectLocalCertificate() const { return cached_subject_local_certificate_; } -absl::optional SslSocketInfo::validFromPeerCertificate() const { +absl::optional SslHandshakerImpl::validFromPeerCertificate() const { bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); if (!cert) { return absl::nullopt; @@ -601,7 +602,7 @@ absl::optional SslSocketInfo::validFromPeerCertificate() const { return Utility::getValidFrom(*cert); } -absl::optional SslSocketInfo::expirationPeerCertificate() const { +absl::optional SslHandshakerImpl::expirationPeerCertificate() const { bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); if (!cert) { return absl::nullopt; @@ -609,7 +610,7 @@ absl::optional SslSocketInfo::expirationPeerCertificate() const { return Utility::getExpirationTime(*cert); } -const std::string& SslSocketInfo::sessionId() const { +const std::string& SslHandshakerImpl::sessionId() const { if (!cached_session_id_.empty()) { return cached_session_id_; } diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index b72cd93a966d..425e6644b209 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -52,10 +52,10 @@ class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { Envoy::Ssl::ClientValidationStatus::NotValidated}; }; -class SslSocketInfo : public Envoy::Ssl::ConnectionInfo, public Envoy::Ssl::Handshaker { +class SslHandshakerImpl : public Envoy::Ssl::ConnectionInfo, public Envoy::Ssl::Handshaker { public: - SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, - Ssl::HandshakeCallbacks* handshake_callbacks); + SslHandshakerImpl(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, + Ssl::HandshakeCallbacks* handshake_callbacks); // Ssl::ConnectionInfo bool peerCertificatePresented() const override; @@ -110,7 +110,7 @@ class SslSocketInfo : public Envoy::Ssl::ConnectionInfo, public Envoy::Ssl::Hand mutable SslExtendedSocketInfoImpl extended_socket_info_; }; -using SslSocketInfoSharedPtr = std::shared_ptr; +using SslHandshakerImplSharedPtr = std::shared_ptr; class SslSocket : public Network::TransportSocket, public Envoy::Ssl::PrivateKeyConnectionCallbacks, @@ -162,7 +162,7 @@ class SslSocket : public Network::TransportSocket, uint64_t bytes_to_retry_{}; std::string failure_reason_; - SslSocketInfoSharedPtr info_; + SslHandshakerImplSharedPtr info_; }; class ClientSslSocketFactory : public Network::TransportSocketFactory, diff --git a/test/extensions/filters/http/dynamic_forward_proxy/proxy_filter_integration_test.cc b/test/extensions/filters/http/dynamic_forward_proxy/proxy_filter_integration_test.cc index e066cf482805..58efa9bdac2d 100644 --- a/test/extensions/filters/http/dynamic_forward_proxy/proxy_filter_integration_test.cc +++ b/test/extensions/filters/http/dynamic_forward_proxy/proxy_filter_integration_test.cc @@ -270,8 +270,8 @@ TEST_P(ProxyFilterIntegrationTest, UpstreamTls) { auto response = codec_client_->makeHeaderOnlyRequest(request_headers); waitForNextUpstreamRequest(); - const Extensions::TransportSockets::Tls::SslSocketInfo* ssl_socket = - dynamic_cast( + const Extensions::TransportSockets::Tls::SslHandshakerImpl* ssl_socket = + dynamic_cast( fake_upstream_connection_->connection().ssl().get()); EXPECT_STREQ("localhost", SSL_get_servername(ssl_socket->ssl(), TLSEXT_NAMETYPE_host_name)); @@ -295,8 +295,8 @@ TEST_P(ProxyFilterIntegrationTest, UpstreamTlsWithIpHost) { waitForNextUpstreamRequest(); // No SNI for IP hosts. - const Extensions::TransportSockets::Tls::SslSocketInfo* ssl_socket = - dynamic_cast( + const Extensions::TransportSockets::Tls::SslHandshakerImpl* ssl_socket = + dynamic_cast( fake_upstream_connection_->connection().ssl().get()); EXPECT_STREQ(nullptr, SSL_get_servername(ssl_socket->ssl(), TLSEXT_NAMETYPE_host_name)); diff --git a/test/extensions/filters/http/router/auto_sni_integration_test.cc b/test/extensions/filters/http/router/auto_sni_integration_test.cc index 10f0d7818e3f..d180c8cad956 100644 --- a/test/extensions/filters/http/router/auto_sni_integration_test.cc +++ b/test/extensions/filters/http/router/auto_sni_integration_test.cc @@ -76,8 +76,8 @@ TEST_P(AutoSniIntegrationTest, BasicAutoSniTest) { EXPECT_TRUE(upstream_request_->complete()); EXPECT_TRUE(response_->complete()); - const Extensions::TransportSockets::Tls::SslSocketInfo* ssl_socket = - dynamic_cast( + const Extensions::TransportSockets::Tls::SslHandshakerImpl* ssl_socket = + dynamic_cast( fake_upstream_connection_->connection().ssl().get()); EXPECT_STREQ("localhost", SSL_get_servername(ssl_socket->ssl(), TLSEXT_NAMETYPE_host_name)); } @@ -93,8 +93,8 @@ TEST_P(AutoSniIntegrationTest, PassingNotDNS) { EXPECT_TRUE(upstream_request_->complete()); EXPECT_TRUE(response_->complete()); - const Extensions::TransportSockets::Tls::SslSocketInfo* ssl_socket = - dynamic_cast( + const Extensions::TransportSockets::Tls::SslHandshakerImpl* ssl_socket = + dynamic_cast( fake_upstream_connection_->connection().ssl().get()); EXPECT_STREQ(nullptr, SSL_get_servername(ssl_socket->ssl(), TLSEXT_NAMETYPE_host_name)); } @@ -112,8 +112,8 @@ TEST_P(AutoSniIntegrationTest, PassingHostWithoutPort) { EXPECT_TRUE(upstream_request_->complete()); EXPECT_TRUE(response_->complete()); - const Extensions::TransportSockets::Tls::SslSocketInfo* ssl_socket = - dynamic_cast( + const Extensions::TransportSockets::Tls::SslHandshakerImpl* ssl_socket = + dynamic_cast( fake_upstream_connection_->connection().ssl().get()); EXPECT_STREQ("example.com", SSL_get_servername(ssl_socket->ssl(), TLSEXT_NAMETYPE_host_name)); } diff --git a/test/extensions/transport_sockets/tls/ssl_socket_test.cc b/test/extensions/transport_sockets/tls/ssl_socket_test.cc index 76f3a16b56b1..45bfe819cbc0 100644 --- a/test/extensions/transport_sockets/tls/ssl_socket_test.cc +++ b/test/extensions/transport_sockets/tls/ssl_socket_test.cc @@ -627,8 +627,8 @@ const std::string testUtilV2(const TestUtilOptionsV2& options) { client_ssl_socket_factory.createTransportSocket(options.transportSocketOptions()), nullptr); if (!options.clientSession().empty()) { - const SslSocketInfo* ssl_socket = - dynamic_cast(client_connection->ssl().get()); + const SslHandshakerImpl* ssl_socket = + dynamic_cast(client_connection->ssl().get()); SSL* client_ssl_socket = ssl_socket->ssl(); SSL_CTX* client_ssl_context = SSL_get_SSL_CTX(client_ssl_socket); SSL_SESSION* client_ssl_session = @@ -670,8 +670,8 @@ const std::string testUtilV2(const TestUtilOptionsV2& options) { EXPECT_EQ(options.expectedALPNProtocol(), client_connection->nextProtocol()); } EXPECT_EQ(options.expectedClientCertUri(), server_connection->ssl()->uriSanPeerCertificate()); - const SslSocketInfo* ssl_socket = - dynamic_cast(client_connection->ssl().get()); + const SslHandshakerImpl* ssl_socket = + dynamic_cast(client_connection->ssl().get()); SSL* client_ssl_socket = ssl_socket->ssl(); if (!options.expectedProtocolVersion().empty()) { // Assert twice to ensure a cached value is returned and still valid. @@ -687,8 +687,8 @@ const std::string testUtilV2(const TestUtilOptionsV2& options) { } absl::optional server_ssl_requested_server_name; - const SslSocketInfo* server_ssl_socket = - dynamic_cast(server_connection->ssl().get()); + const SslHandshakerImpl* server_ssl_socket = + dynamic_cast(server_connection->ssl().get()); SSL* server_ssl = server_ssl_socket->ssl(); auto requested_server_name = SSL_get_servername(server_ssl, TLSEXT_NAMETYPE_host_name); if (requested_server_name != nullptr) { @@ -2560,8 +2560,8 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { ssl_socket_factory.createTransportSocket(nullptr), nullptr); // Verify that server sent list with 2 acceptable client certificate CA names. - const SslSocketInfo* ssl_socket = - dynamic_cast(client_connection->ssl().get()); + const SslHandshakerImpl* ssl_socket = + dynamic_cast(client_connection->ssl().get()); SSL_set_cert_cb( ssl_socket->ssl(), [](SSL* ssl, void*) -> int { @@ -2674,8 +2674,8 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { - const SslSocketInfo* ssl_socket = - dynamic_cast(client_connection->ssl().get()); + const SslHandshakerImpl* ssl_socket = + dynamic_cast(client_connection->ssl().get()); ssl_session = SSL_get1_session(ssl_socket->ssl()); EXPECT_TRUE(SSL_SESSION_is_resumable(ssl_session)); if (expected_lifetime_hint) { @@ -2697,8 +2697,8 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, socket2->localAddress(), Network::Address::InstanceConstSharedPtr(), ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection->addConnectionCallbacks(client_connection_callbacks); - const SslSocketInfo* ssl_socket = - dynamic_cast(client_connection->ssl().get()); + const SslHandshakerImpl* ssl_socket = + dynamic_cast(client_connection->ssl().get()); SSL_set_session(ssl_socket->ssl(), ssl_session); SSL_SESSION_free(ssl_session); @@ -2803,8 +2803,8 @@ void testSupportForStatelessSessionResumption(const std::string& server_ctx_yaml std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), stream_info); - const SslSocketInfo* ssl_socket = - dynamic_cast(server_connection->ssl().get()); + const SslHandshakerImpl* ssl_socket = + dynamic_cast(server_connection->ssl().get()); SSL* server_ssl_socket = ssl_socket->ssl(); SSL_CTX* server_ssl_context = SSL_get_SSL_CTX(server_ssl_socket); if (expect_support) { @@ -3257,8 +3257,8 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose)); EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { - const SslSocketInfo* ssl_socket = - dynamic_cast(client_connection->ssl().get()); + const SslHandshakerImpl* ssl_socket = + dynamic_cast(client_connection->ssl().get()); ssl_session = SSL_get1_session(ssl_socket->ssl()); EXPECT_TRUE(SSL_SESSION_is_resumable(ssl_session)); server_connection->close(Network::ConnectionCloseType::NoFlush); @@ -3276,8 +3276,8 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { socket2->localAddress(), Network::Address::InstanceConstSharedPtr(), ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection->addConnectionCallbacks(client_connection_callbacks); - const SslSocketInfo* ssl_socket = - dynamic_cast(client_connection->ssl().get()); + const SslHandshakerImpl* ssl_socket = + dynamic_cast(client_connection->ssl().get()); SSL_set_session(ssl_socket->ssl(), ssl_session); SSL_SESSION_free(ssl_session); From ed874865ff5e81d5ad00a9c769d2f34724a07892 Mon Sep 17 00:00:00 2001 From: James Buckland Date: Wed, 12 Aug 2020 10:45:13 -0400 Subject: [PATCH 10/11] [tls] Remove redundant doHandshake method from Connection interface. Signed-off-by: James Buckland --- include/envoy/ssl/BUILD | 3 ++- include/envoy/ssl/connection.h | 7 ------- test/mocks/ssl/mocks.h | 1 - 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/include/envoy/ssl/BUILD b/include/envoy/ssl/BUILD index 5f37befb41b7..b295a20e2a1a 100644 --- a/include/envoy/ssl/BUILD +++ b/include/envoy/ssl/BUILD @@ -79,8 +79,9 @@ envoy_cc_library( envoy_cc_library( name = "handshaker_interface", hdrs = ["handshaker.h"], + external_deps = ["ssl"], deps = [ - ":connection_interface", + "//include/envoy/network:connection_interface", "//include/envoy/network:post_io_action_interface", ], ) diff --git a/include/envoy/ssl/connection.h b/include/envoy/ssl/connection.h index c1c1d74d3041..8241c48ad8d7 100644 --- a/include/envoy/ssl/connection.h +++ b/include/envoy/ssl/connection.h @@ -5,7 +5,6 @@ #include "envoy/common/pure.h" #include "envoy/common/time.h" -#include "envoy/network/post_io_action.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -149,12 +148,6 @@ class ConnectionInfo { * exists. */ virtual absl::optional x509Extension(absl::string_view extension_name) const PURE; - - /** - * Performs a TLS handshake on the SSL object and returns an action indicating - * whether the callsite should close the connection or keep it open. - */ - virtual Network::PostIoAction doHandshake() PURE; }; using ConnectionInfoConstSharedPtr = std::shared_ptr; diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index af2c6abf1222..7567e5807cff 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -59,7 +59,6 @@ class MockConnectionInfo : public ConnectionInfo { MOCK_METHOD(std::string, ciphersuiteString, (), (const)); MOCK_METHOD(const std::string&, tlsVersion, (), (const)); MOCK_METHOD(absl::optional, x509Extension, (absl::string_view), (const)); - MOCK_METHOD(Network::PostIoAction, doHandshake, (), ()); }; class MockClientContext : public ClientContext { From 26f3c0d5fa9a2ffac416281a5b2e32f271f3d3e4 Mon Sep 17 00:00:00 2001 From: James Buckland Date: Wed, 12 Aug 2020 16:51:57 -0400 Subject: [PATCH 11/11] [misc] Add documentation for new handshaker interface. Signed-off-by: James Buckland --- include/envoy/network/transport_socket.h | 2 +- include/envoy/ssl/handshaker.h | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/include/envoy/network/transport_socket.h b/include/envoy/network/transport_socket.h index 572d8b95124b..fe054ce2f16d 100644 --- a/include/envoy/network/transport_socket.h +++ b/include/envoy/network/transport_socket.h @@ -142,7 +142,7 @@ class TransportSocket { virtual void onConnected() PURE; /** - * @return the SSL connection data if this is an SSL connection, or nullptr if it is not. + * @return the const SSL connection data if this is an SSL connection, or nullptr if it is not. */ virtual Ssl::ConnectionInfoConstSharedPtr ssl() const PURE; }; diff --git a/include/envoy/ssl/handshaker.h b/include/envoy/ssl/handshaker.h index 54fa19e727d9..de11fc85f41f 100644 --- a/include/envoy/ssl/handshaker.h +++ b/include/envoy/ssl/handshaker.h @@ -17,7 +17,15 @@ class HandshakeCallbacks { */ virtual Network::Connection::State connectionState() const PURE; + /** + * A callback which will be executed at most once upon successful completion + * of a handshake. + */ virtual void onSuccess(SSL* ssl) PURE; + + /** + * A callback which will be executed at most once upon handshake failure. + */ virtual void onFailure() PURE; };