Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
From bc6f30deeef33d4bc1ecf1ea0c321d7d1804678d Mon Sep 17 00:00:00 2001
From: Bright Chen <chenguangmingfe@foxmail.com>
Date: Sun, 25 Jun 2023 14:35:36 +0800
Subject: [PATCH] Force SSL for all connections of Acceptor (#2231)

* Force SSL for all connections

* Force SSL for all connections of Acceptor

* Force SSL option in ServerOptions
---
src/brpc/acceptor.cpp | 12 ++++++-
src/brpc/acceptor.h | 4 ++-
src/brpc/server.cpp | 11 ++++--
src/brpc/server.h | 3 ++
src/brpc/socket.cpp | 5 +++
src/brpc/socket.h | 4 +++
src/brpc/socket_inl.h | 1 +
test/brpc_channel_unittest.cpp | 2 +-
test/brpc_input_messenger_unittest.cpp | 2 +-
test/brpc_socket_unittest.cpp | 4 +--
test/brpc_ssl_unittest.cpp | 50 ++++++++++++++++++++++++++
11 files changed, 90 insertions(+), 8 deletions(-)

diff --git a/src/brpc/acceptor.cpp b/src/brpc/acceptor.cpp
index 62732881..f2d1c087 100644
--- a/src/brpc/acceptor.cpp
+++ b/src/brpc/acceptor.cpp
@@ -38,6 +38,7 @@ Acceptor::Acceptor(bthread_keytable_pool_t* pool)
, _listened_fd(-1)
, _acception_id(0)
, _empty_cond(&_map_mutex)
+ , _force_ssl(false)
, _ssl_ctx(NULL)
, _use_rdma(false) {
}
@@ -48,11 +49,18 @@ Acceptor::~Acceptor() {
}

int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec,
- const std::shared_ptr<SocketSSLContext>& ssl_ctx) {
+ const std::shared_ptr<SocketSSLContext>& ssl_ctx,
+ bool force_ssl) {
if (listened_fd < 0) {
LOG(FATAL) << "Invalid listened_fd=" << listened_fd;
return -1;
}
+
+ if (!ssl_ctx && force_ssl) {
+ LOG(ERROR) << "Fail to force SSL for all connections "
+ " because ssl_ctx is NULL";
+ return -1;
+ }

BAIDU_SCOPED_LOCK(_map_mutex);
if (_status == UNINITIALIZED) {
@@ -74,6 +82,7 @@ int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec,
}
}
_idle_timeout_sec = idle_timeout_sec;
+ _force_ssl = force_ssl;
_ssl_ctx = ssl_ctx;

// Creation of _acception_id is inside lock so that OnNewConnections
@@ -274,6 +283,7 @@ void Acceptor::OnNewConnectionsUntilEAGAIN(Socket* acception) {
options.fd = in_fd;
butil::sockaddr2endpoint(&in_addr, in_len, &options.remote_side);
options.user = acception->user();
+ options.force_ssl = am->_force_ssl;
options.initial_ssl_ctx = am->_ssl_ctx;
#if BRPC_WITH_RDMA
if (am->_use_rdma) {
diff --git a/src/brpc/acceptor.h b/src/brpc/acceptor.h
index c442a60c..c82cdcc1 100644
--- a/src/brpc/acceptor.h
+++ b/src/brpc/acceptor.h
@@ -55,7 +55,8 @@ public:
// `idle_timeout_sec' > 0
// Return 0 on success, -1 otherwise.
int StartAccept(int listened_fd, int idle_timeout_sec,
- const std::shared_ptr<SocketSSLContext>& ssl_ctx);
+ const std::shared_ptr<SocketSSLContext>& ssl_ctx,
+ bool force_ssl);

// [thread-safe] Stop accepting connections.
// `closewait_ms' is not used anymore.
@@ -106,6 +107,7 @@ private:
// The map containing all the accepted sockets
SocketMap _socket_map;

+ bool _force_ssl;
std::shared_ptr<SocketSSLContext> _ssl_ctx;

// Whether to use rdma or not
diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp
index 4953f88c..ce5a0dd2 100644
--- a/src/brpc/server.cpp
+++ b/src/brpc/server.cpp
@@ -139,6 +139,7 @@ ServerOptions::ServerOptions()
, bthread_init_count(0)
, internal_port(-1)
, has_builtin_services(true)
+ , force_ssl(false)
, use_rdma(false)
, http_master_service(NULL)
, health_reporter(NULL)
@@ -933,6 +934,10 @@ int Server::StartInternal(const butil::EndPoint& endpoint,
return -1;
}
}
+ } else if (_options.force_ssl) {
+ LOG(ERROR) << "Fail to force SSL for all connections "
+ "without ServerOptions.ssl_options";
+ return -1;
}

_concurrency = 0;
@@ -1045,7 +1050,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint,

// Pass ownership of `sockfd' to `_am'
if (_am->StartAccept(sockfd, _options.idle_timeout_sec,
- _default_ssl_ctx) != 0) {
+ _default_ssl_ctx,
+ _options.force_ssl) != 0) {
LOG(ERROR) << "Fail to start acceptor";
return -1;
}
@@ -1085,7 +1091,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint,
}
// Pass ownership of `sockfd' to `_internal_am'
if (_internal_am->StartAccept(sockfd, _options.idle_timeout_sec,
- _default_ssl_ctx) != 0) {
+ _default_ssl_ctx,
+ false) != 0) {
LOG(ERROR) << "Fail to start internal_acceptor";
return -1;
}
diff --git a/src/brpc/server.h b/src/brpc/server.h
index c00f9dc8..e598a6e8 100644
--- a/src/brpc/server.h
+++ b/src/brpc/server.h
@@ -217,6 +217,9 @@ struct ServerOptions {
const ServerSSLOptions& ssl_options() const { return *_ssl_options; }
ServerSSLOptions* mutable_ssl_options();

+ // Force ssl for all connections of the port to Start().
+ bool force_ssl;
+
// Whether the server uses rdma or not
// Default: false
bool use_rdma;
diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp
index e0a69422..c49ca083 100644
--- a/src/brpc/socket.cpp
+++ b/src/brpc/socket.cpp
@@ -698,6 +698,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) {
m->SetFailed(rc2, "Fail to create auth_id: %s", berror(rc2));
return -1;
}
+ m->_force_ssl = options.force_ssl;
// Disable SSL check if there is no SSL context
m->_ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN);
m->_ssl_session = NULL;
@@ -2026,6 +2027,10 @@ ssize_t Socket::DoRead(size_t size_hint) {
}
// _ssl_state has been set
if (ssl_state() == SSL_OFF) {
+ if (_force_ssl) {
+ errno = ESSL;
+ return -1;
+ }
CHECK(_rdma_state == RDMA_OFF);
return _read_buf.append_from_file_descriptor(fd(), size_hint);
}
diff --git a/src/brpc/socket.h b/src/brpc/socket.h
index bd753f60..eff9474c 100644
--- a/src/brpc/socket.h
+++ b/src/brpc/socket.h
@@ -205,6 +205,8 @@ struct SocketOptions {
// one thread at any time.
void (*on_edge_triggered_events)(Socket*);
int health_check_interval_s;
+ // Only accept ssl connection.
+ bool force_ssl;
std::shared_ptr<SocketSSLContext> initial_ssl_ctx;
bool use_rdma;
bthread_keytable_pool_t* keytable_pool;
@@ -826,6 +828,8 @@ private:
// exists in server side
AuthContext* _auth_context;

+ // Only accept ssl connection.
+ bool _force_ssl;
SSLState _ssl_state;
// SSL objects cannot be read and written at the same time.
// Use mutex to protect SSL objects when ssl_state is SSL_CONNECTED.
diff --git a/src/brpc/socket_inl.h b/src/brpc/socket_inl.h
index 9423bfdf..df93ac71 100644
--- a/src/brpc/socket_inl.h
+++ b/src/brpc/socket_inl.h
@@ -57,6 +57,7 @@ inline SocketOptions::SocketOptions()
, user(NULL)
, on_edge_triggered_events(NULL)
, health_check_interval_s(-1)
+ , force_ssl(false)
, use_rdma(false)
, keytable_pool(NULL)
, conn(NULL)
diff --git a/test/brpc_channel_unittest.cpp b/test/brpc_channel_unittest.cpp
index 4de8e350..694f3f7f 100644
--- a/test/brpc_channel_unittest.cpp
+++ b/test/brpc_channel_unittest.cpp
@@ -263,7 +263,7 @@ protected:
return -1;
}
}
- if (_messenger.StartAccept(listening_fd, -1, NULL) != 0) {
+ if (_messenger.StartAccept(listening_fd, -1, NULL, false) != 0) {
return -1;
}
return 0;
diff --git a/test/brpc_input_messenger_unittest.cpp b/test/brpc_input_messenger_unittest.cpp
index 7682b83b..00b14ed4 100644
--- a/test/brpc_input_messenger_unittest.cpp
+++ b/test/brpc_input_messenger_unittest.cpp
@@ -169,7 +169,7 @@ TEST_F(MessengerTest, dispatch_tasks) {
ASSERT_TRUE(listening_fd > 0);
butil::make_non_blocking(listening_fd);
ASSERT_EQ(0, messenger[i].AddHandler(pairs[0]));
- ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL));
+ ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL, false));
}

for (size_t i = 0; i < NCLIENT; ++i) {
diff --git a/test/brpc_socket_unittest.cpp b/test/brpc_socket_unittest.cpp
index 3f080911..36a3b1b0 100644
--- a/test/brpc_socket_unittest.cpp
+++ b/test/brpc_socket_unittest.cpp
@@ -339,7 +339,7 @@ TEST_F(SocketTest, single_threaded_connect_and_write) {
ASSERT_TRUE(listening_fd > 0);
butil::make_non_blocking(listening_fd);
ASSERT_EQ(0, messenger->AddHandler(pairs[0]));
- ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL));
+ ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false));

brpc::SocketId id = 8888;
brpc::SocketOptions options;
@@ -727,7 +727,7 @@ TEST_F(SocketTest, health_check) {
ASSERT_TRUE(listening_fd > 0);
butil::make_non_blocking(listening_fd);
ASSERT_EQ(0, messenger->AddHandler(pairs[0]));
- ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL));
+ ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false));

int64_t start_time = butil::gettimeofday_us();
nref = -1;
diff --git a/test/brpc_ssl_unittest.cpp b/test/brpc_ssl_unittest.cpp
index f32dbcb7..7d58e455 100644
--- a/test/brpc_ssl_unittest.cpp
+++ b/test/brpc_ssl_unittest.cpp
@@ -35,6 +35,7 @@
#include "echo.pb.h"

namespace brpc {
+
void ExtractHostnames(X509* x, std::vector<std::string>* hostnames);
} // namespace brpc

@@ -175,6 +176,55 @@ TEST_F(SSLTest, sanity) {
ASSERT_EQ(0, server.Join());
}

+TEST_F(SSLTest, force_ssl) {
+ const int port = 8613;
+ brpc::Server server;
+ brpc::ServerOptions options;
+ EchoServiceImpl echo_svc;
+ ASSERT_EQ(0, server.AddService(
+ &echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE));
+
+ options.force_ssl = true;
+ ASSERT_EQ(-1, server.Start(port, &options));
+
+ brpc::CertInfo cert;
+ cert.certificate = "cert1.crt";
+ cert.private_key = "cert1.key";
+ options.mutable_ssl_options()->default_cert = cert;
+
+ ASSERT_EQ(0, server.Start(port, &options));
+
+ test::EchoRequest req;
+ req.set_message(EXP_REQUEST);
+ {
+ brpc::Channel channel;
+ brpc::ChannelOptions coptions;
+ coptions.mutable_ssl_options();
+ coptions.mutable_ssl_options()->sni_name = "localhost";
+ ASSERT_EQ(0, channel.Init("localhost", port, &coptions));
+
+ brpc::Controller cntl;
+ test::EchoService_Stub stub(&channel);
+ test::EchoResponse res;
+ stub.Echo(&cntl, &req, &res, NULL);
+ EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText();
+ }
+
+ {
+ brpc::Channel channel;
+ ASSERT_EQ(0, channel.Init("localhost", port, NULL));
+
+ brpc::Controller cntl;
+ test::EchoService_Stub stub(&channel);
+ test::EchoResponse res;
+ stub.Echo(&cntl, &req, &res, NULL);
+ EXPECT_TRUE(cntl.Failed());
+ }
+
+ ASSERT_EQ(0, server.Stop(0));
+ ASSERT_EQ(0, server.Join());
+}
+
void CheckCert(const char* cname, const char* cert) {
const int port = 8613;
brpc::Channel channel;
--
2.51.0

Loading