Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssl: make Ssl::Connection const everywhere #4179

Merged
merged 1 commit into from
Aug 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions include/envoy/network/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,6 @@ class Connection : public Event::DeferredDeletable, public FilterManager {
*/
virtual void setConnectionStats(const ConnectionStats& stats) PURE;

/**
* @return the SSL connection data if this is an SSL connection, or nullptr if it is not.
*/
virtual Ssl::Connection* ssl() PURE;

/**
* @return the const SSL connection data if this is an SSL connection, or nullptr if it is not.
*/
Expand Down
5 changes: 0 additions & 5 deletions include/envoy/network/transport_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,6 @@ class TransportSocket {
*/
virtual void onConnected() PURE;

/**
* @return the SSL connection data if this is an SSL connection, or nullptr if it is not.
*/
virtual Ssl::Connection* ssl() PURE;

/**
* @return the const SSL connection data if this is an SSL connection, or nullptr if it is not.
*/
Expand Down
6 changes: 3 additions & 3 deletions include/envoy/ssl/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Connection {
* @return std::string the URI in the SAN feld of the local certificate. Returns "" if there is no
* local certificate, or no SAN field, or no URI.
**/
virtual std::string uriSanLocalCertificate() PURE;
virtual std::string uriSanLocalCertificate() const PURE;

/**
* @return std::string the subject field of the local certificate in RFC 2253 format. Returns ""
Expand Down Expand Up @@ -66,13 +66,13 @@ class Connection {
* @return std::vector<std::string> the DNS entries in the SAN field of the peer certificate.
* Returns {} if there is no peer certificate, or no SAN field, or no DNS.
**/
virtual std::vector<std::string> dnsSansPeerCertificate() PURE;
virtual std::vector<std::string> dnsSansPeerCertificate() const PURE;

/**
* @return std::vector<std::string> the DNS entries in the SAN field of the local certificate.
* Returns {} if there is no local certificate, or no SAN field, or no DNS.
**/
virtual std::vector<std::string> dnsSansLocalCertificate() PURE;
virtual std::vector<std::string> dnsSansLocalCertificate() const PURE;
};

} // namespace Ssl
Expand Down
1 change: 0 additions & 1 deletion source/common/network/connection_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class ConnectionImpl : public virtual Connection,
return socket_->localAddress();
}
void setConnectionStats(const ConnectionStats& stats) override;
Ssl::Connection* ssl() override { return transport_socket_->ssl(); }
const Ssl::Connection* ssl() const override { return transport_socket_->ssl(); }
State state() const override;
void write(Buffer::Instance& data, bool end_stream) override;
Expand Down
1 change: 0 additions & 1 deletion source/common/network/raw_buffer_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class RawBufferSocket : public TransportSocket, protected Logger::Loggable<Logge
void onConnected() override;
IoResult doRead(Buffer::Instance& buffer) override;
IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override;
Ssl::Connection* ssl() override { return nullptr; }
const Ssl::Connection* ssl() const override { return nullptr; }

private:
Expand Down
8 changes: 4 additions & 4 deletions source/common/ssl/ssl_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ bool SslSocket::peerCertificatePresented() const {
return cert != nullptr;
}

std::string SslSocket::uriSanLocalCertificate() {
std::string SslSocket::uriSanLocalCertificate() const {
// The cert object is not owned.
X509* cert = SSL_get_certificate(ssl_.get());
if (!cert) {
Expand All @@ -228,7 +228,7 @@ std::string SslSocket::uriSanLocalCertificate() {
return getUriSanFromCertificate(cert);
}

std::vector<std::string> SslSocket::dnsSansLocalCertificate() {
std::vector<std::string> SslSocket::dnsSansLocalCertificate() const {
X509* cert = SSL_get_certificate(ssl_.get());
if (!cert) {
return {};
Expand Down Expand Up @@ -284,7 +284,7 @@ std::string SslSocket::uriSanPeerCertificate() const {
return getUriSanFromCertificate(cert.get());
}

std::vector<std::string> SslSocket::dnsSansPeerCertificate() {
std::vector<std::string> SslSocket::dnsSansPeerCertificate() const {
bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl_.get()));
if (!cert) {
return {};
Expand All @@ -309,7 +309,7 @@ std::string SslSocket::getUriSanFromCertificate(X509* cert) const {
return "";
}

std::vector<std::string> SslSocket::getDnsSansFromCertificate(X509* cert) {
std::vector<std::string> SslSocket::getDnsSansFromCertificate(X509* cert) const {
bssl::UniquePtr<GENERAL_NAMES> san_names(
static_cast<GENERAL_NAMES*>(X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)));
if (san_names == nullptr) {
Expand Down
11 changes: 5 additions & 6 deletions source/common/ssl/ssl_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ class SslSocket : public Network::TransportSocket,

// Ssl::Connection
bool peerCertificatePresented() const override;
std::string uriSanLocalCertificate() override;
std::string uriSanLocalCertificate() const override;
const std::string& sha256PeerCertificateDigest() const override;
std::string serialNumberPeerCertificate() const override;
std::string subjectPeerCertificate() const override;
std::string subjectLocalCertificate() const override;
std::string uriSanPeerCertificate() const override;
const std::string& urlEncodedPemEncodedPeerCertificate() const override;
std::vector<std::string> dnsSansPeerCertificate() override;
std::vector<std::string> dnsSansLocalCertificate() override;
std::vector<std::string> dnsSansPeerCertificate() const override;
std::vector<std::string> dnsSansLocalCertificate() const override;

// Network::TransportSocket
void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override;
Expand All @@ -43,10 +43,9 @@ class SslSocket : public Network::TransportSocket,
Network::IoResult doRead(Buffer::Instance& read_buffer) override;
Network::IoResult doWrite(Buffer::Instance& write_buffer, bool end_stream) override;
void onConnected() override;
Ssl::Connection* ssl() override { return this; }
const Ssl::Connection* ssl() const override { return this; }

SSL* rawSslForTest() { return ssl_.get(); }
SSL* rawSslForTest() const { return ssl_.get(); }

private:
Network::PostIoAction doHandshake();
Expand All @@ -56,7 +55,7 @@ class SslSocket : public Network::TransportSocket,
// TODO: Move helper functions to the `Ssl::Utility` namespace.
std::string getUriSanFromCertificate(X509* cert) const;
std::string getSubjectFromCertificate(X509* cert) const;
std::vector<std::string> getDnsSansFromCertificate(X509* cert);
std::vector<std::string> getDnsSansFromCertificate(X509* cert) const;

Network::TransportSocketCallbacks* callbacks_{};
ContextImplSharedPtr ctx_;
Expand Down
2 changes: 0 additions & 2 deletions source/extensions/transport_sockets/capture/capture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ Network::IoResult CaptureSocket::doWrite(Buffer::Instance& buffer, bool end_stre

void CaptureSocket::onConnected() { transport_socket_->onConnected(); }

Ssl::Connection* CaptureSocket::ssl() { return transport_socket_->ssl(); }

const Ssl::Connection* CaptureSocket::ssl() const { return transport_socket_->ssl(); }

CaptureSocketFactory::CaptureSocketFactory(
Expand Down
1 change: 0 additions & 1 deletion source/extensions/transport_sockets/capture/capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class CaptureSocket : public Network::TransportSocket {
Network::IoResult doRead(Buffer::Instance& buffer) override;
Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override;
void onConnected() override;
Ssl::Connection* ssl() override;
const Ssl::Connection* ssl() const override;

private:
Expand Down
18 changes: 11 additions & 7 deletions test/common/ssl/ssl_socket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto,
client_ssl_socket_factory.createTransportSocket(), nullptr);

if (!client_session.empty()) {
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket =
dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL* client_ssl_socket = ssl_socket->rawSslForTest();
SSL_CTX* client_ssl_context = SSL_get_SSL_CTX(client_ssl_socket);
SSL_SESSION* client_ssl_session =
Expand Down Expand Up @@ -218,7 +219,8 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto,
EXPECT_EQ(expected_alpn_protocol, client_connection->nextProtocol());
}
EXPECT_EQ(expected_client_cert_uri, server_connection->ssl()->uriSanPeerCertificate());
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket =
dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL* client_ssl_socket = ssl_socket->rawSslForTest();
if (!expected_protocol_version.empty()) {
EXPECT_EQ(expected_protocol_version, SSL_get_version(client_ssl_socket));
Expand Down Expand Up @@ -1705,7 +1707,7 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) {
ssl_socket_factory.createTransportSocket(), nullptr);

// Verify that server sent list with 2 acceptable client certificate CA names.
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket = dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL_set_cert_cb(ssl_socket->rawSslForTest(),
[](SSL* ssl, void*) -> int {
STACK_OF(X509_NAME)* list = SSL_get_client_CA_list(ssl);
Expand Down Expand Up @@ -1805,7 +1807,8 @@ void testTicketSessionResumption(const std::string& server_ctx_json1,

EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket =
dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
ssl_session = SSL_get1_session(ssl_socket->rawSslForTest());
EXPECT_TRUE(SSL_SESSION_is_resumable(ssl_session));
client_connection->close(Network::ConnectionCloseType::NoFlush);
Expand All @@ -1822,7 +1825,7 @@ void testTicketSessionResumption(const std::string& server_ctx_json1,
socket2.localAddress(), Network::Address::InstanceConstSharedPtr(),
ssl_socket_factory.createTransportSocket(), nullptr);
client_connection->addConnectionCallbacks(client_connection_callbacks);
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket = dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL_set_session(ssl_socket->rawSslForTest(), ssl_session);
SSL_SESSION_free(ssl_session);

Expand Down Expand Up @@ -2179,7 +2182,8 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) {
EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose));
EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket =
dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
ssl_session = SSL_get1_session(ssl_socket->rawSslForTest());
EXPECT_TRUE(SSL_SESSION_is_resumable(ssl_session));
server_connection->close(Network::ConnectionCloseType::NoFlush);
Expand All @@ -2197,7 +2201,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) {
socket2.localAddress(), Network::Address::InstanceConstSharedPtr(),
ssl_socket_factory.createTransportSocket(), nullptr);
client_connection->addConnectionCallbacks(client_connection_callbacks);
Ssl::SslSocket* ssl_socket = dynamic_cast<Ssl::SslSocket*>(client_connection->ssl());
const Ssl::SslSocket* ssl_socket = dynamic_cast<const Ssl::SslSocket*>(client_connection->ssl());
SSL_set_session(ssl_socket->rawSslForTest(), ssl_session);
SSL_SESSION_free(ssl_session);

Expand Down
3 changes: 0 additions & 3 deletions test/mocks/network/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ class MockConnection : public Connection, public MockConnectionBase {
MOCK_CONST_METHOD0(remoteAddress, const Address::InstanceConstSharedPtr&());
MOCK_CONST_METHOD0(localAddress, const Address::InstanceConstSharedPtr&());
MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats));
MOCK_METHOD0(ssl, Ssl::Connection*());
MOCK_CONST_METHOD0(ssl, const Ssl::Connection*());
MOCK_CONST_METHOD0(requestedServerName, absl::string_view());
MOCK_CONST_METHOD0(state, State());
Expand Down Expand Up @@ -117,7 +116,6 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase
MOCK_CONST_METHOD0(remoteAddress, const Address::InstanceConstSharedPtr&());
MOCK_CONST_METHOD0(localAddress, const Address::InstanceConstSharedPtr&());
MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats));
MOCK_METHOD0(ssl, Ssl::Connection*());
MOCK_CONST_METHOD0(ssl, const Ssl::Connection*());
MOCK_CONST_METHOD0(requestedServerName, absl::string_view());
MOCK_CONST_METHOD0(state, State());
Expand Down Expand Up @@ -435,7 +433,6 @@ class MockTransportSocket : public TransportSocket {
MOCK_METHOD1(doRead, IoResult(Buffer::Instance& buffer));
MOCK_METHOD2(doWrite, IoResult(Buffer::Instance& buffer, bool end_stream));
MOCK_METHOD0(onConnected, void());
MOCK_METHOD0(ssl, Ssl::Connection*());
MOCK_CONST_METHOD0(ssl, const Ssl::Connection*());
};

Expand Down
6 changes: 3 additions & 3 deletions test/mocks/ssl/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ class MockConnection : public Connection {
~MockConnection();

MOCK_CONST_METHOD0(peerCertificatePresented, bool());
MOCK_METHOD0(uriSanLocalCertificate, std::string());
MOCK_CONST_METHOD0(uriSanLocalCertificate, std::string());
MOCK_CONST_METHOD0(sha256PeerCertificateDigest, std::string&());
MOCK_CONST_METHOD0(serialNumberPeerCertificate, std::string());
MOCK_CONST_METHOD0(subjectPeerCertificate, std::string());
MOCK_CONST_METHOD0(uriSanPeerCertificate, std::string());
MOCK_CONST_METHOD0(subjectLocalCertificate, std::string());
MOCK_CONST_METHOD0(urlEncodedPemEncodedPeerCertificate, std::string&());
MOCK_METHOD0(dnsSansPeerCertificate, std::vector<std::string>());
MOCK_METHOD0(dnsSansLocalCertificate, std::vector<std::string>());
MOCK_CONST_METHOD0(dnsSansPeerCertificate, std::vector<std::string>());
MOCK_CONST_METHOD0(dnsSansLocalCertificate, std::vector<std::string>());
};

class MockClientContext : public ClientContext {
Expand Down