Skip to content

Commit

Permalink
Added support for mTLS
Browse files Browse the repository at this point in the history
  • Loading branch information
dutor committed Apr 29, 2024
1 parent 03a5067 commit 77cd472
Show file tree
Hide file tree
Showing 20 changed files with 147 additions and 51 deletions.
13 changes: 12 additions & 1 deletion include/nebula/client/Config.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,19 @@ struct Config {
std::uint32_t idleTime_{0}; // in ms
std::uint32_t maxConnectionPoolSize_{10};
std::uint32_t minConnectionPoolSize_{0};
std::string CAPath_;
// Whether to enable SSL encryption
bool enableSSL_{false};
// Whether to enable mTLS
bool enableMTLS_{false};
// Whether to check peer CN or SAN
bool checkPeerName_{false};
std::string peerName_;
// Path to cert of CA
std::string CAPath_;
// Path to cert of client
std::string certPath_;
// path to private key of client
std::string keyPath_;
};

} // namespace nebula
4 changes: 2 additions & 2 deletions include/nebula/client/Connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "common/datatypes/Value.h"
#include "common/graph/Response.h"
#include "nebula/client/Config.h"

namespace folly {
class ScopedEventBaseThread;
Expand Down Expand Up @@ -50,8 +51,7 @@ class Connection {
bool open(const std::string &address,
int32_t port,
uint32_t timeout,
bool enableSSL,
const std::string &CAPath);
const Config &cfg = Config{});

AuthResponse authenticate(const std::string &user, const std::string &password);

Expand Down
11 changes: 11 additions & 0 deletions include/nebula/mclient/MConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,19 @@ struct MConfig {
int32_t connTimeoutInMs_{1000};
// It's as same as FLAG_meta_client_timeout_ms in nebula
int32_t clientTimeoutInMs_{60 * 1000};
// Whether to enable SSL encryption
bool enableSSL_{false};
// Whether to enable mTLS
bool enableMTLS_{false};
// Whether to check peer CN or SAN
bool checkPeerName_{false};
std::string peerName_;
// Path to cert of CA
std::string CAPath_;
// Path to cert of client
std::string certPath_;
// path to private key of client
std::string keyPath_;
};

} // namespace nebula
11 changes: 11 additions & 0 deletions include/nebula/sclient/SConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,19 @@ struct SConfig {
int32_t connTimeoutInMs_{1000};
// It's as same as FLAG_meta_client_timeout_ms in nebula
int32_t clientTimeoutInMs_{60 * 1000};
// Whether to enable SSL encryption
bool enableSSL_{false};
// Whether to enable mTLS
bool enableMTLS_{false};
// Whether to check peer CN or SAN
bool checkPeerName_{false};
std::string peerName_;
// Path to cert of CA
std::string CAPath_;
// Path to cert of client
std::string certPath_;
// path to private key of client
std::string keyPath_;
};

} // namespace nebula
35 changes: 25 additions & 10 deletions src/SSLConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,32 @@

namespace nebula {

std::shared_ptr<folly::SSLContext> createSSLContext(const std::string &CAPath) {
auto context = std::make_shared<folly::SSLContext>();
if (!CAPath.empty()) {
context->loadTrustedCertificates(CAPath.c_str());
// don't do peer name validation
context->authenticate(true, false);
// verify the server cert
std::shared_ptr<folly::SSLContext> createSSLContext(const SSLConfig &cfg) {
if (cfg.check_peer_name && cfg.peer_name.empty()) {
throw std::runtime_error("peer name checking enabled but not provied");
}

if (cfg.enable_mtls && (cfg.cert_path.empty() || cfg.key_path.empty())) {
throw std::runtime_error("mTLS enabled but cert/key not provided");
}

auto context = std::make_shared<folly::SSLContext>();

context->loadTrustedCertificates(cfg.ca_path.c_str());
context->setVerificationOption(folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
}
folly::ssl::setSignatureAlgorithms<folly::ssl::SSLCommonOptions>(*context);
return context;
folly::ssl::setSignatureAlgorithms<folly::ssl::SSLCommonOptions>(*context);

if (cfg.check_peer_name) {
context->authenticate(true, true, cfg.peer_name);
} else {
context->authenticate(true, false);
}

if (cfg.enable_mtls) {
context->loadCertKeyPairFromFiles(cfg.cert_path.c_str(), cfg.key_path.c_str());
}

return context;
}

} // namespace nebula
16 changes: 15 additions & 1 deletion src/SSLConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@

namespace nebula {

std::shared_ptr<folly::SSLContext> createSSLContext(const std::string &CAPath);
struct SSLConfig final {
// Whether to enable mTLS(mutual TLS authentication)
bool enable_mtls{false};
// Check whether the given peername matches the CN or SAN in the certificate
bool check_peer_name{false};
std::string peer_name;
// Path to certificate(s) of the CA used to authenticate the cert of server
std::string ca_path;
// Path to the client cert, must be present if mTLS enabled
std::string cert_path;
// Path to the client private key, must be present if mTLS enabled
std::string key_path;
};

std::shared_ptr<folly::SSLContext> createSSLContext(const SSLConfig &cfg);

} // namespace nebula
16 changes: 11 additions & 5 deletions src/client/Connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ Connection &Connection::operator=(Connection &&c) {
bool Connection::open(const std::string &address,
int32_t port,
uint32_t timeout,
bool enableSSL,
const std::string &CAPath) {
const Config &cfg) {
if (address.empty()) {
return false;
}
Expand All @@ -91,10 +90,17 @@ bool Connection::open(const std::string &address,
return false;
}
clientLoopThread_->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait(
[this, &complete, &socket, timeout, &socketAddr, enableSSL, &CAPath]() {
[this, &complete, &socket, timeout, &socketAddr, &cfg]() {
try {
if (enableSSL) {
socket = folly::AsyncSSLSocket::newSocket(nebula::createSSLContext(CAPath),
if (cfg.enableSSL_) {
SSLConfig sslcfg;
sslcfg.enable_mtls = cfg.enableMTLS_;
sslcfg.check_peer_name = cfg.checkPeerName_;
sslcfg.peer_name = cfg.peerName_;
sslcfg.ca_path = cfg.CAPath_;
sslcfg.cert_path = cfg.certPath_;
sslcfg.key_path = cfg.keyPath_;
socket = folly::AsyncSSLSocket::newSocket(nebula::createSSLContext(sslcfg),
clientLoopThread_->getEventBase());
socket->connect(nullptr, std::move(socketAddr), timeout);
} else {
Expand Down
3 changes: 1 addition & 2 deletions src/client/ConnectionPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ void ConnectionPool::newConnection(std::size_t cursor, std::size_t count) {
if (conn.open(address_[addrCursor].first,
address_[addrCursor].second,
config_.timeout_,
config_.enableSSL_,
config_.CAPath_)) {
config_)) {
++connectionCount;
conns_.emplace_back(std::move(conn));
}
Expand Down
14 changes: 11 additions & 3 deletions src/client/tests/ConnectionSSLTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ class ConnectionTest : public ClientTest {};

TEST_F(ConnectionTest, SSL) {
nebula::Connection c;
nebula::Config cfg;
cfg.enableSSL_ = true;

ASSERT_TRUE(c.open(kServerHost, 9669, 10, true, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 10, cfg));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand All @@ -38,7 +40,10 @@ TEST_F(ConnectionTest, SSL) {
TEST_F(ConnectionTest, SSCA) {
{
nebula::Connection c;
ASSERT_TRUE(c.open(kServerHost, 9669, 10, true, "./test.ca.pem"));
nebula::Config cfg;
cfg.enableSSL_ = true;
cfg.CAPath_ = "./test.ca.pem";
ASSERT_TRUE(c.open(kServerHost, 9669, 10, cfg));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand All @@ -55,7 +60,10 @@ TEST_F(ConnectionTest, SSCA) {
{
// mismatch
nebula::Connection c;
ASSERT_FALSE(c.open(kServerHost, 9669, 10, true, "./test.2.crt"));
nebula::Config cfg;
cfg.enableSSL_ = true;
cfg.CAPath_ = "./test.2.crt";
ASSERT_FALSE(c.open(kServerHost, 9669, 10, cfg));
}
}

Expand Down
14 changes: 7 additions & 7 deletions src/client/tests/ConnectionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ConnectionTest : public ClientTest {
});

// open
ASSERT_TRUE(c.open(kServerHost, 9669, 0, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 0));

// ping
EXPECT_TRUE(c.ping());
Expand Down Expand Up @@ -128,7 +128,7 @@ TEST_F(ConnectionTest, Basic) {
TEST_F(ConnectionTest, Timeout) {
nebula::Connection c;

ASSERT_TRUE(c.open(kServerHost, 9669, 100, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 100));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand Down Expand Up @@ -167,7 +167,7 @@ TEST_F(ConnectionTest, Timeout) {
TEST_F(ConnectionTest, JsonResult) {
nebula::Connection c;

ASSERT_TRUE(c.open(kServerHost, 9669, 10, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 10));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand All @@ -187,7 +187,7 @@ TEST_F(ConnectionTest, JsonResult) {
TEST_F(ConnectionTest, DurationResult) {
nebula::Connection c;

ASSERT_TRUE(c.open(kServerHost, 9669, 10, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 10));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand All @@ -204,7 +204,7 @@ TEST_F(ConnectionTest, DurationResult) {
TEST_F(ConnectionTest, ExecuteParameter) {
nebula::Connection c;

ASSERT_TRUE(c.open(kServerHost, 9669, 10, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 10));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand Down Expand Up @@ -232,13 +232,13 @@ TEST_F(ConnectionTest, ExecuteParameter) {
TEST_F(ConnectionTest, InvalidPort) {
nebula::Connection c;

ASSERT_FALSE(c.open(kServerHost, 2333, 10, false, ""));
ASSERT_FALSE(c.open(kServerHost, 2333, 10));
}

TEST_F(ConnectionTest, InvalidHost) {
nebula::Connection c;

ASSERT_FALSE(c.open("Invalid Host", 9669, 10, false, ""));
ASSERT_FALSE(c.open("Invalid Host", 9669, 10));
}

int main(int argc, char **argv) {
Expand Down
2 changes: 1 addition & 1 deletion src/client/tests/RegistHost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ int main(int argc, char** argv) {
google::SetStderrLogging(google::INFO);

nebula::ConnectionPool pool;
nebula::Config c{10, 0, 300, 0, "", FLAGS_enable_ssl};
nebula::Config c{10, 0, 300, 0, FLAGS_enable_ssl, false, false, "", "", "", ""};
pool.init({FLAGS_server}, c);
auto session = pool.getSession("root", "nebula");
CHECK(session.valid());
Expand Down
6 changes: 4 additions & 2 deletions src/client/tests/SessionPoolTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class SessionPoolTest : public ClientTest {
protected:
void SetUp() {
nebula::ConnectionPool pool;
pool.init({kServerHost ":9669"}, nebula::Config{0, 0, 1, 0, "", false});
nebula::Config cfg{0, 0, 1, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, cfg);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());

Expand All @@ -41,7 +42,8 @@ class SessionPoolTest : public ClientTest {

void TearDown() {
nebula::ConnectionPool pool;
pool.init({kServerHost ":9669"}, nebula::Config{0, 0, 1, 0, "", false});
nebula::Config cfg{0, 0, 1, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, cfg);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());

Expand Down
2 changes: 1 addition & 1 deletion src/client/tests/SessionSSLTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SessionTest : public ClientTest {};

TEST_F(SessionTest, SSL) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 10, 0, "", true};
nebula::Config c{10, 0, 10, 0, true, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand Down
10 changes: 5 additions & 5 deletions src/client/tests/SessionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ TEST_F(SessionTest, InvalidAddress) {

TEST_F(SessionTest, Data) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 300, 0, "", false};
nebula::Config c{10, 0, 300, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand Down Expand Up @@ -192,7 +192,7 @@ TEST_F(SessionTest, Data) {

TEST_F(SessionTest, Timeout) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 100, 0, "", false};
nebula::Config c{10, 0, 100, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand Down Expand Up @@ -228,7 +228,7 @@ TEST_F(SessionTest, Timeout) {

TEST_F(SessionTest, JsonResult) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 10, 0, "", false};
nebula::Config c{10, 0, 10, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand All @@ -246,7 +246,7 @@ TEST_F(SessionTest, JsonResult) {

TEST_F(SessionTest, DurationResult) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 10, 0, "", false};
nebula::Config c{10, 0, 10, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand All @@ -261,7 +261,7 @@ TEST_F(SessionTest, DurationResult) {

TEST_F(SessionTest, ExecuteParameter) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 10, 0, "", false};
nebula::Config c{10, 0, 10, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand Down
9 changes: 8 additions & 1 deletion src/mclient/MetaClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@ MetaClient::MetaClient(const std::vector<std::string>& metaAddrs, const MConfig&
}
CHECK(!metaAddrs_.empty()) << "metaAddrs_ is empty";
mConfig_ = mConfig;
SSLConfig sslcfg;
sslcfg.enable_mtls = mConfig_.enableMTLS_;
sslcfg.check_peer_name = mConfig_.checkPeerName_;
sslcfg.peer_name = mConfig_.peerName_;
sslcfg.ca_path = mConfig_.CAPath_;
sslcfg.cert_path = mConfig_.certPath_;
sslcfg.key_path = mConfig_.keyPath_;

ioExecutor_ = std::make_shared<folly::IOThreadPoolExecutor>(std::thread::hardware_concurrency());
clientsMan_ = std::make_shared<thrift::ThriftClientManager<meta::cpp2::MetaServiceAsyncClient>>(
mConfig_.connTimeoutInMs_, mConfig_.enableSSL_, mConfig_.CAPath_);
mConfig_.connTimeoutInMs_, mConfig_.enableSSL_, sslcfg);
bool b = loadData(); // load data into cache
if (!b) {
LOG(ERROR) << "load data failed";
Expand Down
Loading

0 comments on commit 77cd472

Please sign in to comment.