Skip to content
This repository has been archived by the owner on Feb 20, 2023. It is now read-only.

Commit

Permalink
Always use Unix domain sockets if available (#1109)
Browse files Browse the repository at this point in the history
Co-authored-by: Andy Pavlo <pavlo@cs.brown.edu>
Co-authored-by: Wan Shen Lim <wanshen.lim@gmail.com>
  • Loading branch information
3 people authored Sep 16, 2020
1 parent 70f7e9e commit b68a4a8
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 50 deletions.
13 changes: 8 additions & 5 deletions src/include/main/db_main.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,18 @@ class DBMain {
* @param traffic_cop argument to the ConnectionHandleFactor
* @param port argument to TerrierServer
* @param connection_thread_count argument to TerrierServer
* @param socket_directory argument to TerrierServer
*/
NetworkLayer(const common::ManagedPointer<common::DedicatedThreadRegistry> thread_registry,
const common::ManagedPointer<trafficcop::TrafficCop> traffic_cop, const uint16_t port,
const uint16_t connection_thread_count) {
const uint16_t connection_thread_count, const std::string socket_directory) {
connection_handle_factory_ = std::make_unique<network::ConnectionHandleFactory>(traffic_cop);
command_factory_ = std::make_unique<network::PostgresCommandFactory>();
provider_ =
std::make_unique<network::PostgresProtocolInterpreter::Provider>(common::ManagedPointer(command_factory_));
server_ = std::make_unique<network::TerrierServer>(common::ManagedPointer(provider_),
common::ManagedPointer(connection_handle_factory_),
thread_registry, port, connection_thread_count);
server_ = std::make_unique<network::TerrierServer>(
common::ManagedPointer(provider_), common::ManagedPointer(connection_handle_factory_), thread_registry, port,
connection_thread_count, socket_directory);
}

/**
Expand Down Expand Up @@ -360,7 +361,7 @@ class DBMain {
TERRIER_ASSERT(use_traffic_cop_ && traffic_cop != DISABLED, "NetworkLayer needs TrafficCopLayer.");
network_layer =
std::make_unique<NetworkLayer>(common::ManagedPointer(thread_registry), common::ManagedPointer(traffic_cop),
network_port_, connection_thread_count_);
network_port_, connection_thread_count_, uds_file_directory_);
}

db_main->settings_manager_ = std::move(settings_manager);
Expand Down Expand Up @@ -663,6 +664,7 @@ class DBMain {
bool use_query_cache_ = true;
execution::vm::ExecutionMode execution_mode_ = execution::vm::ExecutionMode::Interpret;
uint16_t network_port_ = 15721;
std::string uds_file_directory_ = "/tmp/";
uint16_t connection_thread_count_ = 4;
bool use_network_ = false;

Expand Down Expand Up @@ -696,6 +698,7 @@ class DBMain {

gc_interval_ = settings_manager->GetInt(settings::Param::gc_interval);

uds_file_directory_ = settings_manager->GetString(settings::Param::uds_file_directory);
network_port_ = static_cast<uint16_t>(settings_manager->GetInt(settings::Param::port));
connection_thread_count_ =
static_cast<uint16_t>(settings_manager->GetInt(settings::Param::connection_thread_count));
Expand Down
7 changes: 4 additions & 3 deletions src/include/network/connection_dispatcher_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,18 @@ class ConnectionDispatcherTask : public common::NotifiableTask {
* Creates a new ConnectionDispatcherTask
*
* @param num_handlers The number of handler tasks to spawn.
* @param listen_fd The server socket fd to listen on.
* @param dedicated_thread_owner The DedicatedThreadOwner associated with this task
* @param interpreter_provider provider that constructs protocol interpreters
* @param connection_handle_factory The connection handle factory pointer to pass down to the handlers
* @param thread_registry DedicatedThreadRegistry dependency needed because it eventually spawns more threads in
* RunTask
* @param file_descriptors The list of file descriptors to listen on
*/
ConnectionDispatcherTask(uint32_t num_handlers, int listen_fd, common::DedicatedThreadOwner *dedicated_thread_owner,
ConnectionDispatcherTask(uint32_t num_handlers, common::DedicatedThreadOwner *dedicated_thread_owner,
common::ManagedPointer<ProtocolInterpreter::Provider> interpreter_provider,
common::ManagedPointer<ConnectionHandleFactory> connection_handle_factory,
common::ManagedPointer<common::DedicatedThreadRegistry> thread_registry);
common::ManagedPointer<common::DedicatedThreadRegistry> thread_registry,
std::initializer_list<int> file_descriptors);

/**
* @brief Dispatches the client connection at fd to a handler.
Expand Down
19 changes: 15 additions & 4 deletions src/include/network/terrier_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "common/dedicated_thread_owner.h"
Expand All @@ -25,6 +26,11 @@

namespace terrier::network {

// The name is based on https://www.postgresql.org/docs/9.3/runtime-config-connection.html
// {0}: Directory for the socket from -uds_file_directory on the command line.
// {1}: Port number from -port on the command line.
constexpr std::string_view UNIX_DOMAIN_SOCKET_FORMAT_STRING = "{0}/.s.PGSQL.{1}";

/**
* TerrierServer is the entry point of the network layer
*/
Expand All @@ -36,7 +42,7 @@ class TerrierServer : public common::DedicatedThreadOwner {
TerrierServer(common::ManagedPointer<ProtocolInterpreter::Provider> protocol_provider,
common::ManagedPointer<ConnectionHandleFactory> connection_handle_factory,
common::ManagedPointer<common::DedicatedThreadRegistry> thread_registry, uint16_t port,
uint16_t connection_thread_count);
uint16_t connection_thread_count, std::string socket_directory);

~TerrierServer() override = default;

Expand Down Expand Up @@ -78,6 +84,9 @@ class TerrierServer : public common::DedicatedThreadOwner {
// threads can be safely taken away, but I don't understand the networking stuff well enough to say for sure what
// that assertion is
bool OnThreadRemoval(common::ManagedPointer<common::DedicatedThreadTask> task) override { return true; }
enum SocketType { UNIX_DOMAIN_SOCKET, NETWORKED_SOCKET };
template <SocketType type>
void RegisterSocket();

std::mutex running_mutex_;
bool running_;
Expand All @@ -86,9 +95,11 @@ class TerrierServer : public common::DedicatedThreadOwner {
// For logging purposes
// static void LogCallback(int severity, const char *msg);

uint16_t port_; // port number
int listen_fd_ = -1; // server socket fd that TerrierServer is listening on
const uint32_t max_connections_; // maximum number of connections
uint16_t port_; // port number
int network_socket_fd_ = -1; // networked server socket fd that TerrierServer is listening on
int unix_domain_socket_fd_ = -1; // unix-based local socket fd that TerrierServer may listen on
const std::string socket_directory_; // Where to store the Unix domain socket
const uint32_t max_connections_; // maximum number of connections

common::ManagedPointer<ConnectionHandleFactory> connection_handle_factory_;
common::ManagedPointer<ProtocolInterpreter::Provider> provider_;
Expand Down
9 changes: 9 additions & 0 deletions src/include/settings/settings_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ SETTING_int(
terrier::settings::Callbacks::NoOp
)

// Path to socket file for Unix domain sockets
SETTING_string(
uds_file_directory,
"The directory for the Unix domain socket (default: /tmp/)",
"/tmp/",
false,
terrier::settings::Callbacks::NoOp
)

// RecordBufferSegmentPool size limit
SETTING_int(
record_buffer_segment_size,
Expand Down
14 changes: 10 additions & 4 deletions src/network/connection_dispatcher_task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@
namespace terrier::network {

ConnectionDispatcherTask::ConnectionDispatcherTask(
uint32_t num_handlers, int listen_fd, common::DedicatedThreadOwner *dedicated_thread_owner,
uint32_t num_handlers, common::DedicatedThreadOwner *dedicated_thread_owner,
common::ManagedPointer<ProtocolInterpreter::Provider> interpreter_provider,
common::ManagedPointer<ConnectionHandleFactory> connection_handle_factory,
common::ManagedPointer<common::DedicatedThreadRegistry> thread_registry)
common::ManagedPointer<common::DedicatedThreadRegistry> thread_registry,
std::initializer_list<int> file_descriptors)
: NotifiableTask(MASTER_THREAD_ID),
num_handlers_(num_handlers),
dedicated_thread_owner_(dedicated_thread_owner),
connection_handle_factory_(connection_handle_factory),
thread_registry_(thread_registry),
interpreter_provider_(interpreter_provider),
next_handler_(0) {
RegisterEvent(listen_fd, EV_READ | EV_PERSIST, METHOD_AS_CALLBACK(ConnectionDispatcherTask, DispatchConnection),
this);
for (int listen_fd : file_descriptors) {
if (listen_fd >= 0) {
RegisterEvent(listen_fd, EV_READ | EV_PERSIST, METHOD_AS_CALLBACK(ConnectionDispatcherTask, DispatchConnection),
this);
}
}

RegisterSignalEvent(SIGHUP, METHOD_AS_CALLBACK(NotifiableTask, ExitLoop), this);
}

Expand Down
1 change: 0 additions & 1 deletion src/network/network_io_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include <sys/file.h>

#include <memory>
#include <utility>

#include "network/terrier_server.h"

Expand Down
138 changes: 109 additions & 29 deletions src/network/terrier_server.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#include "network/terrier_server.h"

#include <sys/un.h>
#include <unistd.h>

#include <fstream>
#include <memory>
#include <utility>

#include "common/dedicated_thread_registry.h"
#include "common/settings.h"
Expand All @@ -18,10 +20,11 @@ namespace terrier::network {
TerrierServer::TerrierServer(common::ManagedPointer<ProtocolInterpreter::Provider> protocol_provider,
common::ManagedPointer<ConnectionHandleFactory> connection_handle_factory,
common::ManagedPointer<common::DedicatedThreadRegistry> thread_registry,
const uint16_t port, const uint16_t connection_thread_count)
const uint16_t port, const uint16_t connection_thread_count, std::string socket_directory)
: DedicatedThreadOwner(thread_registry),
running_(false),
port_(port),
socket_directory_(std::move(socket_directory)),
max_connections_(connection_thread_count),
connection_handle_factory_(connection_handle_factory),
provider_(protocol_provider) {
Expand All @@ -39,44 +42,113 @@ TerrierServer::TerrierServer(common::ManagedPointer<ProtocolInterpreter::Provide
signal(SIGPIPE, SIG_IGN);
}

void TerrierServer::RunServer() {
// This line is critical to performance for some reason
evthread_use_pthreads();
template <TerrierServer::SocketType type>
void TerrierServer::RegisterSocket() {
static_assert(type == NETWORKED_SOCKET || type == UNIX_DOMAIN_SOCKET, "There should only be two socket types.");

constexpr auto conn_backlog = common::Settings::CONNECTION_BACKLOG;
constexpr auto is_networked_socket = type == NETWORKED_SOCKET;
constexpr auto socket_description = std::string_view(is_networked_socket ? "networked" : "Unix domain");

auto &socket_fd = is_networked_socket ? network_socket_fd_ : unix_domain_socket_fd_;

// Gets the appropriate sockaddr for the given SocketType. Abuse a lambda and auto to specialize the type.
auto socket_addr = ([&] {
if constexpr (is_networked_socket) { // NOLINT
struct sockaddr_in sin = {0};

int conn_backlog = common::Settings::CONNECTION_BACKLOG;
sin.sin_family = AF_INET;
sin.sin_addr.s_addr = INADDR_ANY;
sin.sin_port = htons(port_);

struct sockaddr_in sin;
std::memset(&sin, 0, sizeof(sin));
sin.sin_family = AF_INET;
sin.sin_addr.s_addr = INADDR_ANY;
sin.sin_port = htons(port_);
return sin;
} else { // NOLINT
// Builds the socket path name
const std::string socket_path = fmt::format(UNIX_DOMAIN_SOCKET_FORMAT_STRING, socket_directory_, port_);
struct sockaddr_un sun = {0};

listen_fd_ = socket(AF_INET, SOCK_STREAM, 0);
// Validate pathname
if (socket_path.length() > sizeof(sun.sun_path) /* Max Unix socket path length */) {
NETWORK_LOG_ERROR(fmt::format("Domain socket name too long (must be <= {} characters)", sizeof(sun.sun_path)));
throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to name {} socket.", socket_description));
}

if (listen_fd_ < 0) {
NETWORK_LOG_ERROR("Failed to open socket: {}", strerror(errno));
throw NETWORK_PROCESS_EXCEPTION("Failed to open socket.");
sun.sun_family = AF_UNIX;
socket_path.copy(sun.sun_path, sizeof(sun.sun_path));

return sun;
}
})();

// Create socket
socket_fd = socket(is_networked_socket ? AF_INET : AF_UNIX, SOCK_STREAM, 0);

// Check if socket was successfully created
if (socket_fd < 0) {
NETWORK_LOG_ERROR("Failed to open {} socket: {}", socket_description, strerror(errno));
throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to open {} socket.", socket_description));
}

int reuse = 1;
setsockopt(listen_fd_, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
// For networked sockets, tell the kernel that we would like to reuse local addresses whenever possible.
if constexpr (is_networked_socket) { // NOLINT
int reuse = 1;
setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
}

int retval = bind(listen_fd_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
if (retval < 0) {
NETWORK_LOG_ERROR("Failed to bind socket: {}", strerror(errno));
throw NETWORK_PROCESS_EXCEPTION("Failed to bind socket.");
// Bind the socket
int status = bind(socket_fd, reinterpret_cast<struct sockaddr *>(&socket_addr), sizeof(socket_addr));
if (status < 0) {
NETWORK_LOG_ERROR("Failed to bind {} socket: {}", socket_description, strerror(errno), errno);

// We can recover from exactly one type of error here, contingent on this being a Unix domain socket.
if constexpr (!is_networked_socket) { // NOLINT
auto recovered = false;

if (errno == EADDRINUSE) {
// I find this disgusting, but it's the approach favored by a bunch of software that uses Unix domain sockets.
// BSD syslogd, for example, does this in *every* case--error handling or not--and I'm not one to question it.
// To elaborate, I strongly dislike the idea of overwriting existing domain sockets. The idea is that in some
// edge cases (say, the process gets kill -9'd or crashes) the OS will not remove the existing Unix socket, so
// we have to delete it ourselves. You would think there'd be a better way to handle such cases--and technically
// there is, with Linux abstract namespace sockets--but it's non-portable and it's incompatible with psql.
recovered = !std::remove(fmt::format(UNIX_DOMAIN_SOCKET_FORMAT_STRING, socket_directory_, port_).c_str()) &&
bind(socket_fd, reinterpret_cast<struct sockaddr *>(&socket_addr), sizeof(socket_addr)) >= 0;
}

if (recovered) {
NETWORK_LOG_INFO("Recovered! Managed to bind {} socket by purging a pre-existing bind.", socket_description);
} else {
throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to bind and recover {} socket.", socket_description));
}
} else { // NOLINT
throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to bind {} socket.", socket_description));
}
}
retval = listen(listen_fd_, conn_backlog);
if (retval < 0) {
NETWORK_LOG_ERROR("Failed to create listen socket: {}", strerror(errno));
throw NETWORK_PROCESS_EXCEPTION("Failed to create listen socket.");

// Listen on the socket
status = listen(socket_fd, conn_backlog);
if (status < 0) {
NETWORK_LOG_ERROR("Failed to listen on {} socket: {}", socket_description, strerror(errno));
throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to listen on {} socket.", socket_description));
}

dispatcher_task_ = thread_registry_->RegisterDedicatedThread<ConnectionDispatcherTask>(
this /* requester */, max_connections_, listen_fd_, this, common::ManagedPointer(provider_.Get()),
connection_handle_factory_, thread_registry_);
NETWORK_LOG_INFO("Listening on {} socket with port {} [PID={}]", socket_description, port_, ::getpid());
}

void TerrierServer::RunServer() {
// This line is critical to performance for some reason
evthread_use_pthreads();

// Register the network socket
RegisterSocket<NETWORKED_SOCKET>();

NETWORK_LOG_INFO("Listening on port {0} [PID={1}]", port_, ::getpid());
// Register the Unix domain socket
RegisterSocket<UNIX_DOMAIN_SOCKET>();

// Create a dispatcher to handle connections to the sockets that have been created.
dispatcher_task_ = thread_registry_->RegisterDedicatedThread<ConnectionDispatcherTask>(
this /* requester */, max_connections_, this, common::ManagedPointer(provider_.Get()), connection_handle_factory_,
thread_registry_, std::initializer_list<int>({unix_domain_socket_fd_, network_socket_fd_}));

// Set the running_ flag for any waiting threads
{
Expand All @@ -90,7 +162,15 @@ void TerrierServer::StopServer() {
const bool result UNUSED_ATTRIBUTE =
thread_registry_->StopTask(this, dispatcher_task_.CastManagedPointerTo<common::DedicatedThreadTask>());
TERRIER_ASSERT(result, "Failed to stop ConnectionDispatcherTask.");
TerrierClose(listen_fd_);

// Close the network socket
TerrierClose(network_socket_fd_);

// Close the Unix domain socket if it exists
if (unix_domain_socket_fd_ >= 0) {
std::remove(fmt::format(UNIX_DOMAIN_SOCKET_FORMAT_STRING, socket_directory_, port_).c_str());
}

NETWORK_LOG_INFO("Server Closed");

// Clear the running_ flag for any waiting threads and wake up them up with the condition variable
Expand Down
Loading

0 comments on commit b68a4a8

Please sign in to comment.