Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for mTLS #137

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion include/nebula/client/Config.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,19 @@ struct Config {
std::uint32_t idleTime_{0}; // in ms
std::uint32_t maxConnectionPoolSize_{10};
std::uint32_t minConnectionPoolSize_{0};
std::string CAPath_;
// Whether to enable SSL encryption
bool enableSSL_{false};
// Whether to enable mTLS
bool enableMTLS_{false};
// Whether to check peer CN or SAN
bool checkPeerName_{false};
std::string peerName_;
// Path to cert of CA
std::string CAPath_;
// Path to cert of client
std::string certPath_;
// path to private key of client
std::string keyPath_;
};

} // namespace nebula
4 changes: 2 additions & 2 deletions include/nebula/client/Connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "common/datatypes/Value.h"
#include "common/graph/Response.h"
#include "nebula/client/Config.h"

namespace folly {
class ScopedEventBaseThread;
Expand Down Expand Up @@ -50,8 +51,7 @@ class Connection {
bool open(const std::string &address,
int32_t port,
uint32_t timeout,
bool enableSSL,
const std::string &CAPath);
const Config &cfg = Config{});

AuthResponse authenticate(const std::string &user, const std::string &password);

Expand Down
11 changes: 11 additions & 0 deletions include/nebula/mclient/MConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,19 @@ struct MConfig {
int32_t connTimeoutInMs_{1000};
// It's as same as FLAG_meta_client_timeout_ms in nebula
int32_t clientTimeoutInMs_{60 * 1000};
// Whether to enable SSL encryption
bool enableSSL_{false};
// Whether to enable mTLS
bool enableMTLS_{false};
// Whether to check peer CN or SAN
bool checkPeerName_{false};
std::string peerName_;
// Path to cert of CA
std::string CAPath_;
// Path to cert of client
std::string certPath_;
// path to private key of client
std::string keyPath_;
};

} // namespace nebula
11 changes: 11 additions & 0 deletions include/nebula/sclient/SConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,19 @@ struct SConfig {
int32_t connTimeoutInMs_{1000};
// It's as same as FLAG_meta_client_timeout_ms in nebula
int32_t clientTimeoutInMs_{60 * 1000};
// Whether to enable SSL encryption
bool enableSSL_{false};
// Whether to enable mTLS
bool enableMTLS_{false};
// Whether to check peer CN or SAN
bool checkPeerName_{false};
std::string peerName_;
// Path to cert of CA
std::string CAPath_;
// Path to cert of client
std::string certPath_;
// path to private key of client
std::string keyPath_;
};

} // namespace nebula
39 changes: 28 additions & 11 deletions src/SSLConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,34 @@

namespace nebula {

std::shared_ptr<folly::SSLContext> createSSLContext(const std::string &CAPath) {
auto context = std::make_shared<folly::SSLContext>();
if (!CAPath.empty()) {
context->loadTrustedCertificates(CAPath.c_str());
// don't do peer name validation
context->authenticate(true, false);
// verify the server cert
context->setVerificationOption(folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
}
folly::ssl::setSignatureAlgorithms<folly::ssl::SSLCommonOptions>(*context);
return context;
std::shared_ptr<folly::SSLContext> createSSLContext(const SSLConfig &cfg) {
if (cfg.check_peer_name && cfg.peer_name.empty()) {
throw std::runtime_error("peer name checking enabled but not provied");
}

if (cfg.enable_mtls && (cfg.cert_path.empty() || cfg.key_path.empty())) {
throw std::runtime_error("mTLS enabled but cert/key not provided");
}

auto context = std::make_shared<folly::SSLContext>();

if (!cfg.ca_path.empty()) {
context->loadTrustedCertificates(cfg.ca_path.c_str());
if (cfg.check_peer_name) {
context->authenticate(true, true, cfg.peer_name);
} else {
context->authenticate(true, false);
}
context->setVerificationOption(folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
}

if (cfg.enable_mtls) {
context->loadCertKeyPairFromFiles(cfg.cert_path.c_str(), cfg.key_path.c_str());
}

folly::ssl::setSignatureAlgorithms<folly::ssl::SSLCommonOptions>(*context);

return context;
}

} // namespace nebula
16 changes: 15 additions & 1 deletion src/SSLConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@

namespace nebula {

std::shared_ptr<folly::SSLContext> createSSLContext(const std::string &CAPath);
struct SSLConfig final {
// Whether to enable mTLS(mutual TLS authentication)
bool enable_mtls{false};
// Check whether the given peername matches the CN or SAN in the certificate
bool check_peer_name{false};
std::string peer_name;
// Path to certificate(s) of the CA used to authenticate the cert of server
std::string ca_path;
// Path to the client cert, must be present if mTLS enabled
std::string cert_path;
// Path to the client private key, must be present if mTLS enabled
std::string key_path;
};

std::shared_ptr<folly::SSLContext> createSSLContext(const SSLConfig &cfg);

} // namespace nebula
16 changes: 11 additions & 5 deletions src/client/Connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ Connection &Connection::operator=(Connection &&c) {
bool Connection::open(const std::string &address,
int32_t port,
uint32_t timeout,
bool enableSSL,
const std::string &CAPath) {
const Config &cfg) {
if (address.empty()) {
return false;
}
Expand All @@ -91,10 +90,17 @@ bool Connection::open(const std::string &address,
return false;
}
clientLoopThread_->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait(
[this, &complete, &socket, timeout, &socketAddr, enableSSL, &CAPath]() {
[this, &complete, &socket, timeout, &socketAddr, &cfg]() {
try {
if (enableSSL) {
socket = folly::AsyncSSLSocket::newSocket(nebula::createSSLContext(CAPath),
if (cfg.enableSSL_) {
SSLConfig sslcfg;
sslcfg.enable_mtls = cfg.enableMTLS_;
sslcfg.check_peer_name = cfg.checkPeerName_;
sslcfg.peer_name = cfg.peerName_;
sslcfg.ca_path = cfg.CAPath_;
sslcfg.cert_path = cfg.certPath_;
sslcfg.key_path = cfg.keyPath_;
socket = folly::AsyncSSLSocket::newSocket(nebula::createSSLContext(sslcfg),
clientLoopThread_->getEventBase());
socket->connect(nullptr, std::move(socketAddr), timeout);
} else {
Expand Down
3 changes: 1 addition & 2 deletions src/client/ConnectionPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ void ConnectionPool::newConnection(std::size_t cursor, std::size_t count) {
if (conn.open(address_[addrCursor].first,
address_[addrCursor].second,
config_.timeout_,
config_.enableSSL_,
config_.CAPath_)) {
config_)) {
++connectionCount;
conns_.emplace_back(std::move(conn));
}
Expand Down
14 changes: 11 additions & 3 deletions src/client/tests/ConnectionSSLTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ class ConnectionTest : public ClientTest {};

TEST_F(ConnectionTest, SSL) {
nebula::Connection c;
nebula::Config cfg;
cfg.enableSSL_ = true;

ASSERT_TRUE(c.open(kServerHost, 9669, 10, true, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 10, cfg));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand All @@ -38,7 +40,10 @@ TEST_F(ConnectionTest, SSL) {
TEST_F(ConnectionTest, SSCA) {
{
nebula::Connection c;
ASSERT_TRUE(c.open(kServerHost, 9669, 10, true, "./test.ca.pem"));
nebula::Config cfg;
cfg.enableSSL_ = true;
cfg.CAPath_ = "./test.ca.pem";
ASSERT_TRUE(c.open(kServerHost, 9669, 10, cfg));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand All @@ -55,7 +60,10 @@ TEST_F(ConnectionTest, SSCA) {
{
// mismatch
nebula::Connection c;
ASSERT_FALSE(c.open(kServerHost, 9669, 10, true, "./test.2.crt"));
nebula::Config cfg;
cfg.enableSSL_ = true;
cfg.CAPath_ = "./test.2.crt";
ASSERT_FALSE(c.open(kServerHost, 9669, 10, cfg));
}
}

Expand Down
14 changes: 7 additions & 7 deletions src/client/tests/ConnectionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ConnectionTest : public ClientTest {
});

// open
ASSERT_TRUE(c.open(kServerHost, 9669, 0, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 0));

// ping
EXPECT_TRUE(c.ping());
Expand Down Expand Up @@ -128,7 +128,7 @@ TEST_F(ConnectionTest, Basic) {
TEST_F(ConnectionTest, Timeout) {
nebula::Connection c;

ASSERT_TRUE(c.open(kServerHost, 9669, 100, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 100));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand Down Expand Up @@ -167,7 +167,7 @@ TEST_F(ConnectionTest, Timeout) {
TEST_F(ConnectionTest, JsonResult) {
nebula::Connection c;

ASSERT_TRUE(c.open(kServerHost, 9669, 10, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 10));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand All @@ -187,7 +187,7 @@ TEST_F(ConnectionTest, JsonResult) {
TEST_F(ConnectionTest, DurationResult) {
nebula::Connection c;

ASSERT_TRUE(c.open(kServerHost, 9669, 10, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 10));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand All @@ -204,7 +204,7 @@ TEST_F(ConnectionTest, DurationResult) {
TEST_F(ConnectionTest, ExecuteParameter) {
nebula::Connection c;

ASSERT_TRUE(c.open(kServerHost, 9669, 10, false, ""));
ASSERT_TRUE(c.open(kServerHost, 9669, 10));

// auth
auto authResp = c.authenticate("root", "nebula");
Expand Down Expand Up @@ -232,13 +232,13 @@ TEST_F(ConnectionTest, ExecuteParameter) {
TEST_F(ConnectionTest, InvalidPort) {
nebula::Connection c;

ASSERT_FALSE(c.open(kServerHost, 2333, 10, false, ""));
ASSERT_FALSE(c.open(kServerHost, 2333, 10));
}

TEST_F(ConnectionTest, InvalidHost) {
nebula::Connection c;

ASSERT_FALSE(c.open("Invalid Host", 9669, 10, false, ""));
ASSERT_FALSE(c.open("Invalid Host", 9669, 10));
}

int main(int argc, char **argv) {
Expand Down
2 changes: 1 addition & 1 deletion src/client/tests/RegistHost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ int main(int argc, char** argv) {
google::SetStderrLogging(google::INFO);

nebula::ConnectionPool pool;
nebula::Config c{10, 0, 300, 0, "", FLAGS_enable_ssl};
nebula::Config c{10, 0, 300, 0, FLAGS_enable_ssl, false, false, "", "", "", ""};
pool.init({FLAGS_server}, c);
auto session = pool.getSession("root", "nebula");
CHECK(session.valid());
Expand Down
6 changes: 4 additions & 2 deletions src/client/tests/SessionPoolTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class SessionPoolTest : public ClientTest {
protected:
void SetUp() {
nebula::ConnectionPool pool;
pool.init({kServerHost ":9669"}, nebula::Config{0, 0, 1, 0, "", false});
nebula::Config cfg{0, 0, 1, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, cfg);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());

Expand All @@ -41,7 +42,8 @@ class SessionPoolTest : public ClientTest {

void TearDown() {
nebula::ConnectionPool pool;
pool.init({kServerHost ":9669"}, nebula::Config{0, 0, 1, 0, "", false});
nebula::Config cfg{0, 0, 1, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, cfg);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());

Expand Down
2 changes: 1 addition & 1 deletion src/client/tests/SessionSSLTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SessionTest : public ClientTest {};

TEST_F(SessionTest, SSL) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 10, 0, "", true};
nebula::Config c{10, 0, 10, 0, true, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand Down
10 changes: 5 additions & 5 deletions src/client/tests/SessionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ TEST_F(SessionTest, InvalidAddress) {

TEST_F(SessionTest, Data) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 300, 0, "", false};
nebula::Config c{10, 0, 300, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand Down Expand Up @@ -192,7 +192,7 @@ TEST_F(SessionTest, Data) {

TEST_F(SessionTest, Timeout) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 100, 0, "", false};
nebula::Config c{10, 0, 100, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand Down Expand Up @@ -228,7 +228,7 @@ TEST_F(SessionTest, Timeout) {

TEST_F(SessionTest, JsonResult) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 10, 0, "", false};
nebula::Config c{10, 0, 10, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand All @@ -246,7 +246,7 @@ TEST_F(SessionTest, JsonResult) {

TEST_F(SessionTest, DurationResult) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 10, 0, "", false};
nebula::Config c{10, 0, 10, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand All @@ -261,7 +261,7 @@ TEST_F(SessionTest, DurationResult) {

TEST_F(SessionTest, ExecuteParameter) {
nebula::ConnectionPool pool;
nebula::Config c{10, 0, 10, 0, "", false};
nebula::Config c{10, 0, 10, 0, false, false, false, "", "", "", ""};
pool.init({kServerHost ":9669"}, c);
auto session = pool.getSession("root", "nebula");
ASSERT_TRUE(session.valid());
Expand Down
9 changes: 8 additions & 1 deletion src/mclient/MetaClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@ MetaClient::MetaClient(const std::vector<std::string>& metaAddrs, const MConfig&
}
CHECK(!metaAddrs_.empty()) << "metaAddrs_ is empty";
mConfig_ = mConfig;
SSLConfig sslcfg;
sslcfg.enable_mtls = mConfig_.enableMTLS_;
sslcfg.check_peer_name = mConfig_.checkPeerName_;
sslcfg.peer_name = mConfig_.peerName_;
sslcfg.ca_path = mConfig_.CAPath_;
sslcfg.cert_path = mConfig_.certPath_;
sslcfg.key_path = mConfig_.keyPath_;

ioExecutor_ = std::make_shared<folly::IOThreadPoolExecutor>(std::thread::hardware_concurrency());
clientsMan_ = std::make_shared<thrift::ThriftClientManager<meta::cpp2::MetaServiceAsyncClient>>(
mConfig_.connTimeoutInMs_, mConfig_.enableSSL_, mConfig_.CAPath_);
mConfig_.connTimeoutInMs_, mConfig_.enableSSL_, sslcfg);
bool b = loadData(); // load data into cache
if (!b) {
LOG(ERROR) << "load data failed";
Expand Down
Loading
Loading