diff --git a/thirdparty/patches/brpc-1.6.0-Force-SSL-for-all-connections-of-Acceptor.patch b/thirdparty/patches/brpc-1.6.0-Force-SSL-for-all-connections-of-Acceptor.patch new file mode 100644 index 00000000000000..86c8ccc6c749c9 --- /dev/null +++ b/thirdparty/patches/brpc-1.6.0-Force-SSL-for-all-connections-of-Acceptor.patch @@ -0,0 +1,327 @@ +From bc6f30deeef33d4bc1ecf1ea0c321d7d1804678d Mon Sep 17 00:00:00 2001 +From: Bright Chen +Date: Sun, 25 Jun 2023 14:35:36 +0800 +Subject: [PATCH] Force SSL for all connections of Acceptor (#2231) + +* Force SSL for all connections + +* Force SSL for all connections of Acceptor + +* Force SSL option in ServerOptions +--- + src/brpc/acceptor.cpp | 12 ++++++- + src/brpc/acceptor.h | 4 ++- + src/brpc/server.cpp | 11 ++++-- + src/brpc/server.h | 3 ++ + src/brpc/socket.cpp | 5 +++ + src/brpc/socket.h | 4 +++ + src/brpc/socket_inl.h | 1 + + test/brpc_channel_unittest.cpp | 2 +- + test/brpc_input_messenger_unittest.cpp | 2 +- + test/brpc_socket_unittest.cpp | 4 +-- + test/brpc_ssl_unittest.cpp | 50 ++++++++++++++++++++++++++ + 11 files changed, 90 insertions(+), 8 deletions(-) + +diff --git a/src/brpc/acceptor.cpp b/src/brpc/acceptor.cpp +index 62732881..f2d1c087 100644 +--- a/src/brpc/acceptor.cpp ++++ b/src/brpc/acceptor.cpp +@@ -38,6 +38,7 @@ Acceptor::Acceptor(bthread_keytable_pool_t* pool) + , _listened_fd(-1) + , _acception_id(0) + , _empty_cond(&_map_mutex) ++ , _force_ssl(false) + , _ssl_ctx(NULL) + , _use_rdma(false) { + } +@@ -48,11 +49,18 @@ Acceptor::~Acceptor() { + } + + int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec, +- const std::shared_ptr& ssl_ctx) { ++ const std::shared_ptr& ssl_ctx, ++ bool force_ssl) { + if (listened_fd < 0) { + LOG(FATAL) << "Invalid listened_fd=" << listened_fd; + return -1; + } ++ ++ if (!ssl_ctx && force_ssl) { ++ LOG(ERROR) << "Fail to force SSL for all connections " ++ " because ssl_ctx is NULL"; ++ return -1; ++ } + + BAIDU_SCOPED_LOCK(_map_mutex); + if (_status == UNINITIALIZED) { +@@ -74,6 +82,7 @@ int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec, + } + } + _idle_timeout_sec = idle_timeout_sec; ++ _force_ssl = force_ssl; + _ssl_ctx = ssl_ctx; + + // Creation of _acception_id is inside lock so that OnNewConnections +@@ -274,6 +283,7 @@ void Acceptor::OnNewConnectionsUntilEAGAIN(Socket* acception) { + options.fd = in_fd; + butil::sockaddr2endpoint(&in_addr, in_len, &options.remote_side); + options.user = acception->user(); ++ options.force_ssl = am->_force_ssl; + options.initial_ssl_ctx = am->_ssl_ctx; + #if BRPC_WITH_RDMA + if (am->_use_rdma) { +diff --git a/src/brpc/acceptor.h b/src/brpc/acceptor.h +index c442a60c..c82cdcc1 100644 +--- a/src/brpc/acceptor.h ++++ b/src/brpc/acceptor.h +@@ -55,7 +55,8 @@ public: + // `idle_timeout_sec' > 0 + // Return 0 on success, -1 otherwise. + int StartAccept(int listened_fd, int idle_timeout_sec, +- const std::shared_ptr& ssl_ctx); ++ const std::shared_ptr& ssl_ctx, ++ bool force_ssl); + + // [thread-safe] Stop accepting connections. + // `closewait_ms' is not used anymore. +@@ -106,6 +107,7 @@ private: + // The map containing all the accepted sockets + SocketMap _socket_map; + ++ bool _force_ssl; + std::shared_ptr _ssl_ctx; + + // Whether to use rdma or not +diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp +index 4953f88c..ce5a0dd2 100644 +--- a/src/brpc/server.cpp ++++ b/src/brpc/server.cpp +@@ -139,6 +139,7 @@ ServerOptions::ServerOptions() + , bthread_init_count(0) + , internal_port(-1) + , has_builtin_services(true) ++ , force_ssl(false) + , use_rdma(false) + , http_master_service(NULL) + , health_reporter(NULL) +@@ -933,6 +934,10 @@ int Server::StartInternal(const butil::EndPoint& endpoint, + return -1; + } + } ++ } else if (_options.force_ssl) { ++ LOG(ERROR) << "Fail to force SSL for all connections " ++ "without ServerOptions.ssl_options"; ++ return -1; + } + + _concurrency = 0; +@@ -1045,7 +1050,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint, + + // Pass ownership of `sockfd' to `_am' + if (_am->StartAccept(sockfd, _options.idle_timeout_sec, +- _default_ssl_ctx) != 0) { ++ _default_ssl_ctx, ++ _options.force_ssl) != 0) { + LOG(ERROR) << "Fail to start acceptor"; + return -1; + } +@@ -1085,7 +1091,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint, + } + // Pass ownership of `sockfd' to `_internal_am' + if (_internal_am->StartAccept(sockfd, _options.idle_timeout_sec, +- _default_ssl_ctx) != 0) { ++ _default_ssl_ctx, ++ false) != 0) { + LOG(ERROR) << "Fail to start internal_acceptor"; + return -1; + } +diff --git a/src/brpc/server.h b/src/brpc/server.h +index c00f9dc8..e598a6e8 100644 +--- a/src/brpc/server.h ++++ b/src/brpc/server.h +@@ -217,6 +217,9 @@ struct ServerOptions { + const ServerSSLOptions& ssl_options() const { return *_ssl_options; } + ServerSSLOptions* mutable_ssl_options(); + ++ // Force ssl for all connections of the port to Start(). ++ bool force_ssl; ++ + // Whether the server uses rdma or not + // Default: false + bool use_rdma; +diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp +index e0a69422..c49ca083 100644 +--- a/src/brpc/socket.cpp ++++ b/src/brpc/socket.cpp +@@ -698,6 +698,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) { + m->SetFailed(rc2, "Fail to create auth_id: %s", berror(rc2)); + return -1; + } ++ m->_force_ssl = options.force_ssl; + // Disable SSL check if there is no SSL context + m->_ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN); + m->_ssl_session = NULL; +@@ -2026,6 +2027,10 @@ ssize_t Socket::DoRead(size_t size_hint) { + } + // _ssl_state has been set + if (ssl_state() == SSL_OFF) { ++ if (_force_ssl) { ++ errno = ESSL; ++ return -1; ++ } + CHECK(_rdma_state == RDMA_OFF); + return _read_buf.append_from_file_descriptor(fd(), size_hint); + } +diff --git a/src/brpc/socket.h b/src/brpc/socket.h +index bd753f60..eff9474c 100644 +--- a/src/brpc/socket.h ++++ b/src/brpc/socket.h +@@ -205,6 +205,8 @@ struct SocketOptions { + // one thread at any time. + void (*on_edge_triggered_events)(Socket*); + int health_check_interval_s; ++ // Only accept ssl connection. ++ bool force_ssl; + std::shared_ptr initial_ssl_ctx; + bool use_rdma; + bthread_keytable_pool_t* keytable_pool; +@@ -826,6 +828,8 @@ private: + // exists in server side + AuthContext* _auth_context; + ++ // Only accept ssl connection. ++ bool _force_ssl; + SSLState _ssl_state; + // SSL objects cannot be read and written at the same time. + // Use mutex to protect SSL objects when ssl_state is SSL_CONNECTED. +diff --git a/src/brpc/socket_inl.h b/src/brpc/socket_inl.h +index 9423bfdf..df93ac71 100644 +--- a/src/brpc/socket_inl.h ++++ b/src/brpc/socket_inl.h +@@ -57,6 +57,7 @@ inline SocketOptions::SocketOptions() + , user(NULL) + , on_edge_triggered_events(NULL) + , health_check_interval_s(-1) ++ , force_ssl(false) + , use_rdma(false) + , keytable_pool(NULL) + , conn(NULL) +diff --git a/test/brpc_channel_unittest.cpp b/test/brpc_channel_unittest.cpp +index 4de8e350..694f3f7f 100644 +--- a/test/brpc_channel_unittest.cpp ++++ b/test/brpc_channel_unittest.cpp +@@ -263,7 +263,7 @@ protected: + return -1; + } + } +- if (_messenger.StartAccept(listening_fd, -1, NULL) != 0) { ++ if (_messenger.StartAccept(listening_fd, -1, NULL, false) != 0) { + return -1; + } + return 0; +diff --git a/test/brpc_input_messenger_unittest.cpp b/test/brpc_input_messenger_unittest.cpp +index 7682b83b..00b14ed4 100644 +--- a/test/brpc_input_messenger_unittest.cpp ++++ b/test/brpc_input_messenger_unittest.cpp +@@ -169,7 +169,7 @@ TEST_F(MessengerTest, dispatch_tasks) { + ASSERT_TRUE(listening_fd > 0); + butil::make_non_blocking(listening_fd); + ASSERT_EQ(0, messenger[i].AddHandler(pairs[0])); +- ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL)); ++ ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL, false)); + } + + for (size_t i = 0; i < NCLIENT; ++i) { +diff --git a/test/brpc_socket_unittest.cpp b/test/brpc_socket_unittest.cpp +index 3f080911..36a3b1b0 100644 +--- a/test/brpc_socket_unittest.cpp ++++ b/test/brpc_socket_unittest.cpp +@@ -339,7 +339,7 @@ TEST_F(SocketTest, single_threaded_connect_and_write) { + ASSERT_TRUE(listening_fd > 0); + butil::make_non_blocking(listening_fd); + ASSERT_EQ(0, messenger->AddHandler(pairs[0])); +- ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL)); ++ ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false)); + + brpc::SocketId id = 8888; + brpc::SocketOptions options; +@@ -727,7 +727,7 @@ TEST_F(SocketTest, health_check) { + ASSERT_TRUE(listening_fd > 0); + butil::make_non_blocking(listening_fd); + ASSERT_EQ(0, messenger->AddHandler(pairs[0])); +- ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL)); ++ ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false)); + + int64_t start_time = butil::gettimeofday_us(); + nref = -1; +diff --git a/test/brpc_ssl_unittest.cpp b/test/brpc_ssl_unittest.cpp +index f32dbcb7..7d58e455 100644 +--- a/test/brpc_ssl_unittest.cpp ++++ b/test/brpc_ssl_unittest.cpp +@@ -35,6 +35,7 @@ + #include "echo.pb.h" + + namespace brpc { ++ + void ExtractHostnames(X509* x, std::vector* hostnames); + } // namespace brpc + +@@ -175,6 +176,55 @@ TEST_F(SSLTest, sanity) { + ASSERT_EQ(0, server.Join()); + } + ++TEST_F(SSLTest, force_ssl) { ++ const int port = 8613; ++ brpc::Server server; ++ brpc::ServerOptions options; ++ EchoServiceImpl echo_svc; ++ ASSERT_EQ(0, server.AddService( ++ &echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE)); ++ ++ options.force_ssl = true; ++ ASSERT_EQ(-1, server.Start(port, &options)); ++ ++ brpc::CertInfo cert; ++ cert.certificate = "cert1.crt"; ++ cert.private_key = "cert1.key"; ++ options.mutable_ssl_options()->default_cert = cert; ++ ++ ASSERT_EQ(0, server.Start(port, &options)); ++ ++ test::EchoRequest req; ++ req.set_message(EXP_REQUEST); ++ { ++ brpc::Channel channel; ++ brpc::ChannelOptions coptions; ++ coptions.mutable_ssl_options(); ++ coptions.mutable_ssl_options()->sni_name = "localhost"; ++ ASSERT_EQ(0, channel.Init("localhost", port, &coptions)); ++ ++ brpc::Controller cntl; ++ test::EchoService_Stub stub(&channel); ++ test::EchoResponse res; ++ stub.Echo(&cntl, &req, &res, NULL); ++ EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText(); ++ } ++ ++ { ++ brpc::Channel channel; ++ ASSERT_EQ(0, channel.Init("localhost", port, NULL)); ++ ++ brpc::Controller cntl; ++ test::EchoService_Stub stub(&channel); ++ test::EchoResponse res; ++ stub.Echo(&cntl, &req, &res, NULL); ++ EXPECT_TRUE(cntl.Failed()); ++ } ++ ++ ASSERT_EQ(0, server.Stop(0)); ++ ASSERT_EQ(0, server.Join()); ++} ++ + void CheckCert(const char* cname, const char* cert) { + const int port = 8613; + brpc::Channel channel; +-- +2.51.0 +