Skip to content

Commit

Permalink
Make ThreadLocalSSLContextProvider EventBase-local
Browse files Browse the repository at this point in the history
Summary:
If the context is associated to the connection it should be local to the `EventBase`, not the thread, in case the `EventBase` migrates threads.

`EventBaseLocal` is slightly more expensive than `thread_local`, but the overall cost of the `get{Client,Server}ContextInfo` functions is completely negligible, so the overhead here will likely not even be measurable:
https://pxl.cl/4r09g

See motivation here https://fb.prod.workplace.com/groups/143349833027145/posts/1305210553507728

Reviewed By: disylh

Differential Revision: D54159214

fbshipit-source-id: 238cb2bcbd2b438e5a32746eb2f6e840fd0234e8
  • Loading branch information
ot authored and facebook-github-bot committed Mar 12, 2024
1 parent 94f6028 commit 1c2a60a
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 47 deletions.
1 change: 1 addition & 0 deletions mcrouter/lib/network/AsyncMcServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ class McServerThread {
}

auto contextPair = getServerContexts(
mcServerThread_->eventBase(),
opts.pemCertPath,
opts.pemKeyPath,
opts.pemCaPath,
Expand Down
4 changes: 2 additions & 2 deletions mcrouter/lib/network/SocketUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ createSocketCommon(
const auto& sessionKey = getSessionKey(connectionOptions);
if (isAsyncSSLSocketMech(mech)) {
// openssl based tls
auto sslContext = getClientContext(securityOpts, mech);
auto sslContext = getClientContext(eventBase, securityOpts, mech);
if (!sslContext) {
return folly::makeUnexpected(folly::AsyncSocketException(
folly::AsyncSocketException::SSL_ERROR,
Expand Down Expand Up @@ -174,7 +174,7 @@ createSocketCommon(
}

// tls 13 fizz
auto fizzContextAndVerifier = getFizzClientConfig(securityOpts);
auto fizzContextAndVerifier = getFizzClientConfig(eventBase, securityOpts);
if (!fizzContextAndVerifier.first) {
return folly::makeUnexpected(folly::AsyncSocketException(
folly::AsyncSocketException::SSL_ERROR,
Expand Down
82 changes: 47 additions & 35 deletions mcrouter/lib/network/ThreadLocalSSLContextProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

#include "ThreadLocalSSLContextProvider.h"

#include <unordered_map>

#include <folly/Singleton.h>
#include <folly/container/F14Map.h>
#include <folly/hash/Hash.h>
#include <folly/io/async/EventBaseLocal.h>
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/SSLOptions.h>
#include <folly/portability/OpenSSL.h>
Expand Down Expand Up @@ -371,68 +371,76 @@ std::shared_ptr<SSLContext> createClientSSLContext(
}

ClientContextInfo& getClientContextInfo(
folly::EventBase& evb,
const SecurityOptions& opts,
SecurityMech mech) {
thread_local std::
unordered_map<ContextKey, ClientContextInfo, ContextKeyHasher>
localContexts;
static folly::EventBaseLocal<folly::F14FastMap<
ContextKey,
std::unique_ptr<ClientContextInfo>,
ContextKeyHasher>>
localContexts;

ContextKey key;
key.pemCertPath = opts.sslPemCertPath;
key.pemKeyPath = opts.sslPemKeyPath;
key.pemCaPath = opts.sslPemCaPath;
key.mech = mech;

auto iter = localContexts.find(key);
if (iter == localContexts.end()) {
auto& map = localContexts.try_emplace(evb);
auto iter = map.find(key);
if (iter == map.end()) {
// Copy strings.
ClientContextInfo info;
info.pemCertPath = opts.sslPemCertPath;
info.pemKeyPath = opts.sslPemKeyPath;
info.pemCaPath = opts.sslPemCaPath;
info.mech = mech;
auto info = std::make_unique<ClientContextInfo>();
info->pemCertPath = opts.sslPemCertPath;
info->pemKeyPath = opts.sslPemKeyPath;
info->pemCaPath = opts.sslPemCaPath;
info->mech = mech;

// Point all StringPiece's to our own strings.
key.pemCertPath = info.pemCertPath;
key.pemKeyPath = info.pemKeyPath;
key.pemCaPath = info.pemCaPath;
iter = localContexts.insert(std::make_pair(key, std::move(info))).first;
key.pemCertPath = info->pemCertPath;
key.pemKeyPath = info->pemKeyPath;
key.pemCaPath = info->pemCaPath;
iter = map.try_emplace(key, std::move(info)).first;
}

return iter->second;
return *iter->second;
}

ServerContextInfo& getServerContextInfo(
folly::EventBase& evb,
folly::StringPiece pemCertPath,
folly::StringPiece pemKeyPath,
folly::StringPiece pemCaPath,
bool requireClientVerification) {
thread_local std::
unordered_map<ContextKey, ServerContextInfo, ContextKeyHasher>
localContexts;
static folly::EventBaseLocal<folly::F14FastMap<
ContextKey,
std::unique_ptr<ServerContextInfo>,
ContextKeyHasher>>
localContexts;

ContextKey key;
key.pemCertPath = pemCertPath;
key.pemKeyPath = pemKeyPath;
key.pemCaPath = pemCaPath;
key.requireClientVerification = requireClientVerification;

auto iter = localContexts.find(key);
if (iter == localContexts.end()) {
auto& map = localContexts.try_emplace(evb);
auto iter = map.find(key);
if (iter == map.end()) {
// Copy strings.
ServerContextInfo info;
info.pemCertPath = pemCertPath.toString();
info.pemKeyPath = pemKeyPath.toString();
info.pemCaPath = pemCaPath.toString();
auto info = std::make_unique<ServerContextInfo>();
info->pemCertPath = pemCertPath.toString();
info->pemKeyPath = pemKeyPath.toString();
info->pemCaPath = pemCaPath.toString();

// Point all StringPiece's to our own strings.
key.pemCertPath = info.pemCertPath;
key.pemKeyPath = info.pemKeyPath;
key.pemCaPath = info.pemCaPath;
iter = localContexts.insert(std::make_pair(key, std::move(info))).first;
key.pemCertPath = info->pemCertPath;
key.pemKeyPath = info->pemKeyPath;
key.pemCaPath = info->pemCaPath;
iter = map.try_emplace(key, std::move(info)).first;
}

return iter->second;
return *iter->second;
}

} // namespace
Expand All @@ -442,8 +450,10 @@ bool isAsyncSSLSocketMech(SecurityMech mech) {
mech == SecurityMech::KTLS12;
}

FizzContextAndVerifier getFizzClientConfig(const SecurityOptions& opts) {
auto& info = getClientContextInfo(opts, SecurityMech::TLS13_FIZZ);
FizzContextAndVerifier getFizzClientConfig(
folly::EventBase& evb,
const SecurityOptions& opts) {
auto& info = getClientContextInfo(evb, opts, SecurityMech::TLS13_FIZZ);
auto now = std::chrono::steady_clock::now();
if (info.needsFizzContext(now)) {
auto certData = readFile(opts.sslPemCertPath);
Expand All @@ -459,6 +469,7 @@ FizzContextAndVerifier getFizzClientConfig(const SecurityOptions& opts) {
}

std::shared_ptr<folly::SSLContext> getClientContext(
folly::EventBase& evb,
const SecurityOptions& opts,
SecurityMech mech) {
if (!isAsyncSSLSocketMech(mech)) {
Expand All @@ -469,7 +480,7 @@ std::shared_ptr<folly::SSLContext> getClientContext(
static_cast<uint8_t>(mech));
return nullptr;
}
auto& info = getClientContextInfo(opts, mech);
auto& info = getClientContextInfo(evb, opts, mech);
auto now = std::chrono::steady_clock::now();
if (info.needsContext(now)) {
auto ctx = createClientSSLContext(opts, mech);
Expand All @@ -479,14 +490,15 @@ std::shared_ptr<folly::SSLContext> getClientContext(
}

ServerContextPair getServerContexts(
folly::EventBase& evb,
folly::StringPiece pemCertPath,
folly::StringPiece pemKeyPath,
folly::StringPiece pemCaPath,
bool requireClientCerts,
folly::Optional<wangle::TLSTicketKeySeeds> seeds,
bool preferOcbCipher) {
auto& info = getServerContextInfo(
pemCertPath, pemKeyPath, pemCaPath, requireClientCerts);
evb, pemCertPath, pemKeyPath, pemCaPath, requireClientCerts);
auto now = std::chrono::steady_clock::now();
if (info.needsContexts(now)) {
auto certData = readFile(pemCertPath);
Expand Down
10 changes: 7 additions & 3 deletions mcrouter/lib/network/ThreadLocalSSLContextProvider.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ class ClientSSLContext : public folly::SSLContext {
};

/**
* The following methods return thread local managed SSL Contexts. Contexts are
* reloaded on demand if they are 30 minutes old on a per thread basis.
* The following methods return EventBase-local managed SSL Contexts. Contexts
* are reloaded on demand if they are 30 minutes old on a per evb basis.
*/
FizzContextAndVerifier getFizzClientConfig(const SecurityOptions& opts);
FizzContextAndVerifier getFizzClientConfig(
folly::EventBase& evb,
const SecurityOptions& opts);

/**
* Determine if we are to use a AsyncSSLSocket with the provided mech.
Expand All @@ -67,6 +69,7 @@ bool isAsyncSSLSocketMech(SecurityMech mech);
* opts.
*/
std::shared_ptr<folly::SSLContext> getClientContext(
folly::EventBase& evb,
const SecurityOptions& opts,
SecurityMech mech);

Expand All @@ -81,6 +84,7 @@ using ServerContextPair = std::pair<
* during the handshake will be rejected.
*/
ServerContextPair getServerContexts(
folly::EventBase& evb,
folly::StringPiece pemCertPath,
folly::StringPiece pemKeyPath,
folly::StringPiece pemCaPath,
Expand Down
18 changes: 11 additions & 7 deletions mcrouter/lib/network/test/AsyncMcClientTestSync.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,8 @@ TEST(AsyncMcClient, caretGoAway) {
}

TEST(AsyncMcClient, contextProviders) {
folly::EventBase evb;

auto clientCtxPaths = validClientSsl();
auto serverCtxPaths = validSsl();

Expand All @@ -912,33 +914,35 @@ TEST(AsyncMcClient, contextProviders) {
opts.sslPemKeyPath = clientCtxPaths.sslKeyPath;
opts.sslPemCaPath = clientCtxPaths.sslCaPath;
auto mech = SecurityMech::TLS;
auto clientCtx1 = getClientContext(opts, mech);
auto clientCtx2 = getClientContext(opts, mech);
auto clientCtx1 = getClientContext(evb, opts, mech);
auto clientCtx2 = getClientContext(evb, opts, mech);

// make sure mech changes the context
mech = SecurityMech::TLS_TO_PLAINTEXT;
auto clientCtx3 = getClientContext(opts, mech);
auto clientCtx4 = getClientContext(opts, mech);
auto clientCtx3 = getClientContext(evb, opts, mech);
auto clientCtx4 = getClientContext(evb, opts, mech);

auto fizzCfg1 = getFizzClientConfig(opts);
auto fizzCfg2 = getFizzClientConfig(opts);
auto fizzCfg1 = getFizzClientConfig(evb, opts);
auto fizzCfg2 = getFizzClientConfig(evb, opts);
EXPECT_EQ(fizzCfg1, fizzCfg2);

auto serverCtxs1 = getServerContexts(
evb,
serverCtxPaths.sslCertPath,
serverCtxPaths.sslKeyPath,
serverCtxPaths.sslCaPath,
true,
folly::none);
auto serverCtxs2 = getServerContexts(
evb,
serverCtxPaths.sslCertPath,
serverCtxPaths.sslKeyPath,
serverCtxPaths.sslCaPath,
true,
folly::none);

// client contexts should be the same since they are
// thread local cached
// EventBase-local
EXPECT_EQ(clientCtx1, clientCtx2);
EXPECT_EQ(clientCtx3, clientCtx4);
EXPECT_NE(clientCtx1, clientCtx3);
Expand Down

0 comments on commit 1c2a60a

Please sign in to comment.