Skip to content

Commit

Permalink
ssl: make Ssl::Connection const everywhere (#4179)
Browse files Browse the repository at this point in the history
Signed-off-by: Lizan Zhou <zlizan@google.com>
  • Loading branch information
lizan authored and mattklein123 committed Aug 16, 2018
1 parent 706e262 commit 3062874
Show file tree
Hide file tree
Showing 12 changed files with 26 additions and 41 deletions.
5 changes: 0 additions & 5 deletions include/envoy/network/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,6 @@ class Connection : public Event::DeferredDeletable, public FilterManager {
*/
virtual void setConnectionStats(const ConnectionStats& stats) PURE;

/**
* @return the SSL connection data if this is an SSL connection, or nullptr if it is not.
*/
virtual Ssl::Connection* ssl() PURE;

/**
* @return the const SSL connection data if this is an SSL connection, or nullptr if it is not.
*/
Expand Down
5 changes: 0 additions & 5 deletions include/envoy/network/transport_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,6 @@ class TransportSocket {
*/
virtual void onConnected() PURE;

/**
* @return the SSL connection data if this is an SSL connection, or nullptr if it is not.
*/
virtual Ssl::Connection* ssl() PURE;

/**
* @return the const SSL connection data if this is an SSL connection, or nullptr if it is not.
*/
Expand Down
6 changes: 3 additions & 3 deletions include/envoy/ssl/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Connection {
* @return std::string the URI in the SAN feld of the local certificate. Returns "" if there is no
* local certificate, or no SAN field, or no URI.
**/
virtual std::string uriSanLocalCertificate() PURE;
virtual std::string uriSanLocalCertificate() const PURE;

/**
* @return std::string the subject field of the local certificate in RFC 2253 format. Returns ""
Expand Down Expand Up @@ -66,13 +66,13 @@ class Connection {
* @return std::vector<std::string> the DNS entries in the SAN field of the peer certificate.
* Returns {} if there is no peer certificate, or no SAN field, or no DNS.
**/
virtual std::vector<std::string> dnsSansPeerCertificate() PURE;
virtual std::vector<std::string> dnsSansPeerCertificate() const PURE;

/**
* @return std::vector<std::string> the DNS entries in the SAN field of the local certificate.
* Returns {} if there is no local certificate, or no SAN field, or no DNS.
**/
virtual std::vector<std::string> dnsSansLocalCertificate() PURE;
virtual std::vector<std::string> dnsSansLocalCertificate() const PURE;
};

} // namespace Ssl
Expand Down
1 change: 0 additions & 1 deletion source/common/network/connection_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class ConnectionImpl : public virtual Connection,
return socket_->localAddress();
}
void setConnectionStats(const ConnectionStats& stats) override;
Ssl::Connection* ssl() override { return transport_socket_->ssl(); }
const Ssl::Connection* ssl() const override { return transport_socket_->ssl(); }
State state() const override;
void write(Buffer::Instance& data, bool end_stream) override;
Expand Down
1 change: 0 additions & 1 deletion source/common/network/raw_buffer_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class RawBufferSocket : public TransportSocket, protected Logger::Loggable<Logge
void onConnected() override;
IoResult doRead(Buffer::Instance& buffer) override;
IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override;
Ssl::Connection* ssl() override { return nullptr; }
const Ssl::Connection* ssl() const override { return nullptr; }

private:
Expand Down
8 changes: 4 additions & 4 deletions source/common/ssl/ssl_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ bool SslSocket::peerCertificatePresented() const {
return cert != nullptr;
}

std::string SslSocket::uriSanLocalCertificate() {
std::string SslSocket::uriSanLocalCertificate() const {
// The cert object is not owned.
X509* cert = SSL_get_certificate(ssl_.get());
if (!cert) {
Expand All @@ -228,7 +228,7 @@ std::string SslSocket::uriSanLocalCertificate() {
return getUriSanFromCertificate(cert);
}

std::vector<std::string> SslSocket::dnsSansLocalCertificate() {
std::vector<std::string> SslSocket::dnsSansLocalCertificate() const {
X509* cert = SSL_get_certificate(ssl_.get());
if (!cert) {
return {};
Expand Down Expand Up @@ -284,7 +284,7 @@ std::string SslSocket::uriSanPeerCertificate() const {
return getUriSanFromCertificate(cert.get());
}

std::vector<std::string> SslSocket::dnsSansPeerCertificate() {
std::vector<std::string> SslSocket::dnsSansPeerCertificate() const {
bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl_.get()));
if (!cert) {
return {};
Expand All @@ -309,7 +309,7 @@ std::string SslSocket::getUriSanFromCertificate(X509* cert) const {
return "";
}

std::vector<std::string> SslSocket::getDnsSansFromCertificate(X509* cert) {
std::vector<std::string> SslSocket::getDnsSansFromCertificate(X509* cert) const {
bssl::UniquePtr<GENERAL_NAMES> san_names(
static_cast<GENERAL_NAMES*>(X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)));
if (san_names == nullptr) {
Expand Down
11 changes: 5 additions & 6 deletions source/common/ssl/ssl_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ class SslSocket : public Network::TransportSocket,

// Ssl::Connection
bool peerCertificatePresented() const override;
std::string uriSanLocalCertificate() override;
std::string uriSanLocalCertificate() const override;
const std::string& sha256PeerCertificateDigest() const override;
std::string serialNumberPeerCertificate() const override;
std::string subjectPeerCertificate() const override;
std::string subjectLocalCertificate() const override;
std::string uriSanPeerCertificate() const override;
const std::string& urlEncodedPemEncodedPeerCertificate() const override;
std::vector<std::string> dnsSansPeerCertificate() override;
std::vector<std::string> dnsSansLocalCertificate() override;
std::vector<std::string> dnsSansPeerCertificate() const override;
std::vector<std::string> dnsSansLocalCertificate() const override;

// Network::TransportSocket
void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override;
Expand All @@ -43,10 +43,9 @@ class SslSocket : public Network::TransportSocket,
Network::IoResult doRead(Buffer::Instance& read_buffer) override;
Network::IoResult doWrite(Buffer::Instance& write_buffer, bool end_stream) override;
void onConnected() override;
Ssl::Connection* ssl() override { return this; }
const Ssl::Connection* ssl() const override { return this; }

SSL* rawSslForTest() { return ssl_.get(); }
SSL* rawSslForTest() const { return ssl_.get(); }

private:
Network::PostIoAction doHandshake();
Expand All @@ -56,7 +55,7 @@ class SslSocket : public Network::TransportSocket,
// TODO: Move helper functions to the `Ssl::Utility` namespace.
std::string getUriSanFromCertificate(X509* cert) const;
std::string getSubjectFromCertificate(X509* cert) const;
std::vector<std::string> getDnsSansFromCertificate(X509* cert);
std::vector<std::string> getDnsSansFromCertificate(X509* cert) const;

Network::TransportSocketCallbacks* callbacks_{};
ContextImplSharedPtr ctx_;
Expand Down
2 changes: 0 additions & 2 deletions source/extensions/transport_sockets/capture/capture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ Network::IoResult CaptureSocket::doWrite(Buffer::Instance& buffer, bool end_stre

void CaptureSocket::onConnected() { transport_socket_->onConnected(); }

Ssl::Connection* CaptureSocket::ssl() { return transport_socket_->ssl(); }

const Ssl::Connection* CaptureSocket::ssl() const { return transport_socket_->ssl(); }

CaptureSocketFactory::CaptureSocketFactory(
Expand Down
1 change: 0 additions & 1 deletion source/extensions/transport_sockets/capture/capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class CaptureSocket : public Network::TransportSocket {
Network::IoResult doRead(Buffer::Instance& buffer) override;
Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override;
void onConnected() override;
Ssl::Connection* ssl() override;
const Ssl::Connection* ssl() const override;

private:
Expand Down
18 changes: 11 additions & 7 deletions test/common/ssl/ssl_socket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto,
client_ssl_socket_factory.createTransportSocket(), nullptr);

if (!client_session.empty()) {
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket =
dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL* client_ssl_socket = ssl_socket->rawSslForTest();
SSL_CTX* client_ssl_context = SSL_get_SSL_CTX(client_ssl_socket);
SSL_SESSION* client_ssl_session =
Expand Down Expand Up @@ -218,7 +219,8 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto,
EXPECT_EQ(expected_alpn_protocol, client_connection->nextProtocol());
}
EXPECT_EQ(expected_client_cert_uri, server_connection->ssl()->uriSanPeerCertificate());
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket =
dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL* client_ssl_socket = ssl_socket->rawSslForTest();
if (!expected_protocol_version.empty()) {
EXPECT_EQ(expected_protocol_version, SSL_get_version(client_ssl_socket));
Expand Down Expand Up @@ -1705,7 +1707,7 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) {
ssl_socket_factory.createTransportSocket(), nullptr);

// Verify that server sent list with 2 acceptable client certificate CA names.
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket = dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL_set_cert_cb(ssl_socket->rawSslForTest(),
[](SSL* ssl, void*) -> int {
STACK_OF(X509_NAME)* list = SSL_get_client_CA_list(ssl);
Expand Down Expand Up @@ -1805,7 +1807,8 @@ void testTicketSessionResumption(const std::string& server_ctx_json1,

EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket =
dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
ssl_session = SSL_get1_session(ssl_socket->rawSslForTest());
EXPECT_TRUE(SSL_SESSION_is_resumable(ssl_session));
client_connection->close(Network::ConnectionCloseType::NoFlush);
Expand All @@ -1822,7 +1825,7 @@ void testTicketSessionResumption(const std::string& server_ctx_json1,
socket2.localAddress(), Network::Address::InstanceConstSharedPtr(),
ssl_socket_factory.createTransportSocket(), nullptr);
client_connection->addConnectionCallbacks(client_connection_callbacks);
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket = dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL_set_session(ssl_socket->rawSslForTest(), ssl_session);
SSL_SESSION_free(ssl_session);

Expand Down Expand Up @@ -2179,7 +2182,8 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) {
EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose));
EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket =
dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
ssl_session = SSL_get1_session(ssl_socket->rawSslForTest());
EXPECT_TRUE(SSL_SESSION_is_resumable(ssl_session));
server_connection->close(Network::ConnectionCloseType::NoFlush);
Expand All @@ -2197,7 +2201,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) {
socket2.localAddress(), Network::Address::InstanceConstSharedPtr(),
ssl_socket_factory.createTransportSocket(), nullptr);
client_connection->addConnectionCallbacks(client_connection_callbacks);
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket = dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL_set_session(ssl_socket->rawSslForTest(), ssl_session);
SSL_SESSION_free(ssl_session);

Expand Down
3 changes: 0 additions & 3 deletions test/mocks/network/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ class MockConnection : public Connection, public MockConnectionBase {
MOCK_CONST_METHOD0(remoteAddress, const Address::InstanceConstSharedPtr&());
MOCK_CONST_METHOD0(localAddress, const Address::InstanceConstSharedPtr&());
MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats));
MOCK_METHOD0(ssl, Ssl::Connection*());
MOCK_CONST_METHOD0(ssl, const Ssl::Connection*());
MOCK_CONST_METHOD0(requestedServerName, absl::string_view());
MOCK_CONST_METHOD0(state, State());
Expand Down Expand Up @@ -117,7 +116,6 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase
MOCK_CONST_METHOD0(remoteAddress, const Address::InstanceConstSharedPtr&());
MOCK_CONST_METHOD0(localAddress, const Address::InstanceConstSharedPtr&());
MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats));
MOCK_METHOD0(ssl, Ssl::Connection*());
MOCK_CONST_METHOD0(ssl, const Ssl::Connection*());
MOCK_CONST_METHOD0(requestedServerName, absl::string_view());
MOCK_CONST_METHOD0(state, State());
Expand Down Expand Up @@ -435,7 +433,6 @@ class MockTransportSocket : public TransportSocket {
MOCK_METHOD1(doRead, IoResult(Buffer::Instance& buffer));
MOCK_METHOD2(doWrite, IoResult(Buffer::Instance& buffer, bool end_stream));
MOCK_METHOD0(onConnected, void());
MOCK_METHOD0(ssl, Ssl::Connection*());
MOCK_CONST_METHOD0(ssl, const Ssl::Connection*());
};

Expand Down
6 changes: 3 additions & 3 deletions test/mocks/ssl/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ class MockConnection : public Connection {
~MockConnection();

MOCK_CONST_METHOD0(peerCertificatePresented, bool());
MOCK_METHOD0(uriSanLocalCertificate, std::string());
MOCK_CONST_METHOD0(uriSanLocalCertificate, std::string());
MOCK_CONST_METHOD0(sha256PeerCertificateDigest, std::string&());
MOCK_CONST_METHOD0(serialNumberPeerCertificate, std::string());
MOCK_CONST_METHOD0(subjectPeerCertificate, std::string());
MOCK_CONST_METHOD0(uriSanPeerCertificate, std::string());
MOCK_CONST_METHOD0(subjectLocalCertificate, std::string());
MOCK_CONST_METHOD0(urlEncodedPemEncodedPeerCertificate, std::string&());
MOCK_METHOD0(dnsSansPeerCertificate, std::vector<std::string>());
MOCK_METHOD0(dnsSansLocalCertificate, std::vector<std::string>());
MOCK_CONST_METHOD0(dnsSansPeerCertificate, std::vector<std::string>());
MOCK_CONST_METHOD0(dnsSansLocalCertificate, std::vector<std::string>());
};

class MockClientContext : public ClientContext {
Expand Down

0 comments on commit 3062874

Please sign in to comment.