From 062bbec89fb46f9717fea66a1c7da7d9954ad2d6 Mon Sep 17 00:00:00 2001 From: Lizan Zhou Date: Thu, 16 Aug 2018 06:11:39 +0000 Subject: [PATCH] ssl: make Ssl::Connection const everywhere Signed-off-by: Lizan Zhou --- include/envoy/network/connection.h | 5 ----- include/envoy/network/transport_socket.h | 5 ----- include/envoy/ssl/connection.h | 6 +++--- source/common/network/connection_impl.h | 1 - source/common/network/raw_buffer_socket.h | 1 - source/common/ssl/ssl_socket.cc | 8 ++++---- source/common/ssl/ssl_socket.h | 11 +++++------ .../transport_sockets/capture/capture.cc | 2 -- .../transport_sockets/capture/capture.h | 1 - test/common/ssl/ssl_socket_test.cc | 18 +++++++++++------- test/mocks/network/mocks.h | 3 --- test/mocks/ssl/mocks.h | 6 +++--- 12 files changed, 26 insertions(+), 41 deletions(-) diff --git a/include/envoy/network/connection.h b/include/envoy/network/connection.h index b7c57f06e4c9..7574f3686b25 100644 --- a/include/envoy/network/connection.h +++ b/include/envoy/network/connection.h @@ -176,11 +176,6 @@ class Connection : public Event::DeferredDeletable, public FilterManager { */ virtual void setConnectionStats(const ConnectionStats& stats) PURE; - /** - * @return the SSL connection data if this is an SSL connection, or nullptr if it is not. - */ - virtual Ssl::Connection* ssl() PURE; - /** * @return the const SSL connection data if this is an SSL connection, or nullptr if it is not. */ diff --git a/include/envoy/network/transport_socket.h b/include/envoy/network/transport_socket.h index a5390c29853a..59172fc1115a 100644 --- a/include/envoy/network/transport_socket.h +++ b/include/envoy/network/transport_socket.h @@ -128,11 +128,6 @@ class TransportSocket { */ virtual void onConnected() PURE; - /** - * @return the SSL connection data if this is an SSL connection, or nullptr if it is not. - */ - virtual Ssl::Connection* ssl() PURE; - /** * @return the const SSL connection data if this is an SSL connection, or nullptr if it is not. */ diff --git a/include/envoy/ssl/connection.h b/include/envoy/ssl/connection.h index 98f8506497b8..8c1c42eba647 100644 --- a/include/envoy/ssl/connection.h +++ b/include/envoy/ssl/connection.h @@ -24,7 +24,7 @@ class Connection { * @return std::string the URI in the SAN feld of the local certificate. Returns "" if there is no * local certificate, or no SAN field, or no URI. **/ - virtual std::string uriSanLocalCertificate() PURE; + virtual std::string uriSanLocalCertificate() const PURE; /** * @return std::string the subject field of the local certificate in RFC 2253 format. Returns "" @@ -66,13 +66,13 @@ class Connection { * @return std::vector the DNS entries in the SAN field of the peer certificate. * Returns {} if there is no peer certificate, or no SAN field, or no DNS. **/ - virtual std::vector dnsSansPeerCertificate() PURE; + virtual std::vector dnsSansPeerCertificate() const PURE; /** * @return std::vector the DNS entries in the SAN field of the local certificate. * Returns {} if there is no local certificate, or no SAN field, or no DNS. **/ - virtual std::vector dnsSansLocalCertificate() PURE; + virtual std::vector dnsSansLocalCertificate() const PURE; }; } // namespace Ssl diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index a9e796e3b02f..9bb12ef6bc42 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -78,7 +78,6 @@ class ConnectionImpl : public virtual Connection, return socket_->localAddress(); } void setConnectionStats(const ConnectionStats& stats) override; - Ssl::Connection* ssl() override { return transport_socket_->ssl(); } const Ssl::Connection* ssl() const override { return transport_socket_->ssl(); } State state() const override; void write(Buffer::Instance& data, bool end_stream) override; diff --git a/source/common/network/raw_buffer_socket.h b/source/common/network/raw_buffer_socket.h index 3ab5ac0a2725..8b8b205ce38f 100644 --- a/source/common/network/raw_buffer_socket.h +++ b/source/common/network/raw_buffer_socket.h @@ -19,7 +19,6 @@ class RawBufferSocket : public TransportSocket, protected Logger::Loggable SslSocket::dnsSansLocalCertificate() { +std::vector SslSocket::dnsSansLocalCertificate() const { X509* cert = SSL_get_certificate(ssl_.get()); if (!cert) { return {}; @@ -284,7 +284,7 @@ std::string SslSocket::uriSanPeerCertificate() const { return getUriSanFromCertificate(cert.get()); } -std::vector SslSocket::dnsSansPeerCertificate() { +std::vector SslSocket::dnsSansPeerCertificate() const { bssl::UniquePtr cert(SSL_get_peer_certificate(ssl_.get())); if (!cert) { return {}; @@ -309,7 +309,7 @@ std::string SslSocket::getUriSanFromCertificate(X509* cert) const { return ""; } -std::vector SslSocket::getDnsSansFromCertificate(X509* cert) { +std::vector SslSocket::getDnsSansFromCertificate(X509* cert) const { bssl::UniquePtr san_names( static_cast(X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr))); if (san_names == nullptr) { diff --git a/source/common/ssl/ssl_socket.h b/source/common/ssl/ssl_socket.h index d87a861d68bb..c6b8f92b4d63 100644 --- a/source/common/ssl/ssl_socket.h +++ b/source/common/ssl/ssl_socket.h @@ -25,15 +25,15 @@ class SslSocket : public Network::TransportSocket, // Ssl::Connection bool peerCertificatePresented() const override; - std::string uriSanLocalCertificate() override; + std::string uriSanLocalCertificate() const override; const std::string& sha256PeerCertificateDigest() const override; std::string serialNumberPeerCertificate() const override; std::string subjectPeerCertificate() const override; std::string subjectLocalCertificate() const override; std::string uriSanPeerCertificate() const override; const std::string& urlEncodedPemEncodedPeerCertificate() const override; - std::vector dnsSansPeerCertificate() override; - std::vector dnsSansLocalCertificate() override; + std::vector dnsSansPeerCertificate() const override; + std::vector dnsSansLocalCertificate() const override; // Network::TransportSocket void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override; @@ -43,10 +43,9 @@ class SslSocket : public Network::TransportSocket, Network::IoResult doRead(Buffer::Instance& read_buffer) override; Network::IoResult doWrite(Buffer::Instance& write_buffer, bool end_stream) override; void onConnected() override; - Ssl::Connection* ssl() override { return this; } const Ssl::Connection* ssl() const override { return this; } - SSL* rawSslForTest() { return ssl_.get(); } + SSL* rawSslForTest() const { return ssl_.get(); } private: Network::PostIoAction doHandshake(); @@ -56,7 +55,7 @@ class SslSocket : public Network::TransportSocket, // TODO: Move helper functions to the `Ssl::Utility` namespace. std::string getUriSanFromCertificate(X509* cert) const; std::string getSubjectFromCertificate(X509* cert) const; - std::vector getDnsSansFromCertificate(X509* cert); + std::vector getDnsSansFromCertificate(X509* cert) const; Network::TransportSocketCallbacks* callbacks_{}; ContextImplSharedPtr ctx_; diff --git a/source/extensions/transport_sockets/capture/capture.cc b/source/extensions/transport_sockets/capture/capture.cc index 26e8aef21a0b..25188ea6474a 100644 --- a/source/extensions/transport_sockets/capture/capture.cc +++ b/source/extensions/transport_sockets/capture/capture.cc @@ -89,8 +89,6 @@ Network::IoResult CaptureSocket::doWrite(Buffer::Instance& buffer, bool end_stre void CaptureSocket::onConnected() { transport_socket_->onConnected(); } -Ssl::Connection* CaptureSocket::ssl() { return transport_socket_->ssl(); } - const Ssl::Connection* CaptureSocket::ssl() const { return transport_socket_->ssl(); } CaptureSocketFactory::CaptureSocketFactory( diff --git a/source/extensions/transport_sockets/capture/capture.h b/source/extensions/transport_sockets/capture/capture.h index b46f94bf2930..63b94042460a 100644 --- a/source/extensions/transport_sockets/capture/capture.h +++ b/source/extensions/transport_sockets/capture/capture.h @@ -25,7 +25,6 @@ class CaptureSocket : public Network::TransportSocket { Network::IoResult doRead(Buffer::Instance& buffer) override; Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override; void onConnected() override; - Ssl::Connection* ssl() override; const Ssl::Connection* ssl() const override; private: diff --git a/test/common/ssl/ssl_socket_test.cc b/test/common/ssl/ssl_socket_test.cc index 219de0e82955..190d80471840 100644 --- a/test/common/ssl/ssl_socket_test.cc +++ b/test/common/ssl/ssl_socket_test.cc @@ -177,7 +177,8 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, client_ssl_socket_factory.createTransportSocket(), nullptr); if (!client_session.empty()) { - Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); + const Ssl::SslSocket* ssl_socket = + dynamic_cast(client_connection->ssl()); SSL* client_ssl_socket = ssl_socket->rawSslForTest(); SSL_CTX* client_ssl_context = SSL_get_SSL_CTX(client_ssl_socket); SSL_SESSION* client_ssl_session = @@ -218,7 +219,8 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, EXPECT_EQ(expected_alpn_protocol, client_connection->nextProtocol()); } EXPECT_EQ(expected_client_cert_uri, server_connection->ssl()->uriSanPeerCertificate()); - Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); + const Ssl::SslSocket* ssl_socket = + dynamic_cast(client_connection->ssl()); SSL* client_ssl_socket = ssl_socket->rawSslForTest(); if (!expected_protocol_version.empty()) { EXPECT_EQ(expected_protocol_version, SSL_get_version(client_ssl_socket)); @@ -1705,7 +1707,7 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { ssl_socket_factory.createTransportSocket(), nullptr); // Verify that server sent list with 2 acceptable client certificate CA names. - Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); + const Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); SSL_set_cert_cb(ssl_socket->rawSslForTest(), [](SSL* ssl, void*) -> int { STACK_OF(X509_NAME)* list = SSL_get_client_CA_list(ssl); @@ -1805,7 +1807,8 @@ void testTicketSessionResumption(const std::string& server_ctx_json1, EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { - Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); + const Ssl::SslSocket* ssl_socket = + dynamic_cast(client_connection->ssl()); ssl_session = SSL_get1_session(ssl_socket->rawSslForTest()); EXPECT_TRUE(SSL_SESSION_is_resumable(ssl_session)); client_connection->close(Network::ConnectionCloseType::NoFlush); @@ -1822,7 +1825,7 @@ void testTicketSessionResumption(const std::string& server_ctx_json1, socket2.localAddress(), Network::Address::InstanceConstSharedPtr(), ssl_socket_factory.createTransportSocket(), nullptr); client_connection->addConnectionCallbacks(client_connection_callbacks); - Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); + const Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); SSL_set_session(ssl_socket->rawSslForTest(), ssl_session); SSL_SESSION_free(ssl_session); @@ -2179,7 +2182,8 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose)); EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { - Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); + const Ssl::SslSocket* ssl_socket = + dynamic_cast(client_connection->ssl()); ssl_session = SSL_get1_session(ssl_socket->rawSslForTest()); EXPECT_TRUE(SSL_SESSION_is_resumable(ssl_session)); server_connection->close(Network::ConnectionCloseType::NoFlush); @@ -2197,7 +2201,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { socket2.localAddress(), Network::Address::InstanceConstSharedPtr(), ssl_socket_factory.createTransportSocket(), nullptr); client_connection->addConnectionCallbacks(client_connection_callbacks); - Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); + const Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); SSL_set_session(ssl_socket->rawSslForTest(), ssl_session); SSL_SESSION_free(ssl_session); diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 651da87c4174..0bd73a87dbb9 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -77,7 +77,6 @@ class MockConnection : public Connection, public MockConnectionBase { MOCK_CONST_METHOD0(remoteAddress, const Address::InstanceConstSharedPtr&()); MOCK_CONST_METHOD0(localAddress, const Address::InstanceConstSharedPtr&()); MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats)); - MOCK_METHOD0(ssl, Ssl::Connection*()); MOCK_CONST_METHOD0(ssl, const Ssl::Connection*()); MOCK_CONST_METHOD0(requestedServerName, absl::string_view()); MOCK_CONST_METHOD0(state, State()); @@ -117,7 +116,6 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase MOCK_CONST_METHOD0(remoteAddress, const Address::InstanceConstSharedPtr&()); MOCK_CONST_METHOD0(localAddress, const Address::InstanceConstSharedPtr&()); MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats)); - MOCK_METHOD0(ssl, Ssl::Connection*()); MOCK_CONST_METHOD0(ssl, const Ssl::Connection*()); MOCK_CONST_METHOD0(requestedServerName, absl::string_view()); MOCK_CONST_METHOD0(state, State()); @@ -435,7 +433,6 @@ class MockTransportSocket : public TransportSocket { MOCK_METHOD1(doRead, IoResult(Buffer::Instance& buffer)); MOCK_METHOD2(doWrite, IoResult(Buffer::Instance& buffer, bool end_stream)); MOCK_METHOD0(onConnected, void()); - MOCK_METHOD0(ssl, Ssl::Connection*()); MOCK_CONST_METHOD0(ssl, const Ssl::Connection*()); }; diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 47209e822f16..8ab686d1aa37 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -36,15 +36,15 @@ class MockConnection : public Connection { ~MockConnection(); MOCK_CONST_METHOD0(peerCertificatePresented, bool()); - MOCK_METHOD0(uriSanLocalCertificate, std::string()); + MOCK_CONST_METHOD0(uriSanLocalCertificate, std::string()); MOCK_CONST_METHOD0(sha256PeerCertificateDigest, std::string&()); MOCK_CONST_METHOD0(serialNumberPeerCertificate, std::string()); MOCK_CONST_METHOD0(subjectPeerCertificate, std::string()); MOCK_CONST_METHOD0(uriSanPeerCertificate, std::string()); MOCK_CONST_METHOD0(subjectLocalCertificate, std::string()); MOCK_CONST_METHOD0(urlEncodedPemEncodedPeerCertificate, std::string&()); - MOCK_METHOD0(dnsSansPeerCertificate, std::vector()); - MOCK_METHOD0(dnsSansLocalCertificate, std::vector()); + MOCK_CONST_METHOD0(dnsSansPeerCertificate, std::vector()); + MOCK_CONST_METHOD0(dnsSansLocalCertificate, std::vector()); }; class MockClientContext : public ClientContext {