diff --git a/src/include/main/db_main.h b/src/include/main/db_main.h index dd346d0382..04110cc94a 100644 --- a/src/include/main/db_main.h +++ b/src/include/main/db_main.h @@ -229,17 +229,18 @@ class DBMain { * @param traffic_cop argument to the ConnectionHandleFactor * @param port argument to TerrierServer * @param connection_thread_count argument to TerrierServer + * @param socket_directory argument to TerrierServer */ NetworkLayer(const common::ManagedPointer thread_registry, const common::ManagedPointer traffic_cop, const uint16_t port, - const uint16_t connection_thread_count) { + const uint16_t connection_thread_count, const std::string socket_directory) { connection_handle_factory_ = std::make_unique(traffic_cop); command_factory_ = std::make_unique(); provider_ = std::make_unique(common::ManagedPointer(command_factory_)); - server_ = std::make_unique(common::ManagedPointer(provider_), - common::ManagedPointer(connection_handle_factory_), - thread_registry, port, connection_thread_count); + server_ = std::make_unique( + common::ManagedPointer(provider_), common::ManagedPointer(connection_handle_factory_), thread_registry, port, + connection_thread_count, socket_directory); } /** @@ -360,7 +361,7 @@ class DBMain { TERRIER_ASSERT(use_traffic_cop_ && traffic_cop != DISABLED, "NetworkLayer needs TrafficCopLayer."); network_layer = std::make_unique(common::ManagedPointer(thread_registry), common::ManagedPointer(traffic_cop), - network_port_, connection_thread_count_); + network_port_, connection_thread_count_, uds_file_directory_); } db_main->settings_manager_ = std::move(settings_manager); @@ -663,6 +664,7 @@ class DBMain { bool use_query_cache_ = true; execution::vm::ExecutionMode execution_mode_ = execution::vm::ExecutionMode::Interpret; uint16_t network_port_ = 15721; + std::string uds_file_directory_ = "/tmp/"; uint16_t connection_thread_count_ = 4; bool use_network_ = false; @@ -696,6 +698,7 @@ class DBMain { gc_interval_ = settings_manager->GetInt(settings::Param::gc_interval); + uds_file_directory_ = settings_manager->GetString(settings::Param::uds_file_directory); network_port_ = static_cast(settings_manager->GetInt(settings::Param::port)); connection_thread_count_ = static_cast(settings_manager->GetInt(settings::Param::connection_thread_count)); diff --git a/src/include/network/connection_dispatcher_task.h b/src/include/network/connection_dispatcher_task.h index 5e6bec0c4a..1a4b5a5a38 100644 --- a/src/include/network/connection_dispatcher_task.h +++ b/src/include/network/connection_dispatcher_task.h @@ -31,17 +31,18 @@ class ConnectionDispatcherTask : public common::NotifiableTask { * Creates a new ConnectionDispatcherTask * * @param num_handlers The number of handler tasks to spawn. - * @param listen_fd The server socket fd to listen on. * @param dedicated_thread_owner The DedicatedThreadOwner associated with this task * @param interpreter_provider provider that constructs protocol interpreters * @param connection_handle_factory The connection handle factory pointer to pass down to the handlers * @param thread_registry DedicatedThreadRegistry dependency needed because it eventually spawns more threads in * RunTask + * @param file_descriptors The list of file descriptors to listen on */ - ConnectionDispatcherTask(uint32_t num_handlers, int listen_fd, common::DedicatedThreadOwner *dedicated_thread_owner, + ConnectionDispatcherTask(uint32_t num_handlers, common::DedicatedThreadOwner *dedicated_thread_owner, common::ManagedPointer interpreter_provider, common::ManagedPointer connection_handle_factory, - common::ManagedPointer thread_registry); + common::ManagedPointer thread_registry, + std::initializer_list file_descriptors); /** * @brief Dispatches the client connection at fd to a handler. diff --git a/src/include/network/terrier_server.h b/src/include/network/terrier_server.h index 1a352d5e8b..6a5b4af54a 100644 --- a/src/include/network/terrier_server.h +++ b/src/include/network/terrier_server.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include "common/dedicated_thread_owner.h" @@ -25,6 +26,11 @@ namespace terrier::network { +// The name is based on https://www.postgresql.org/docs/9.3/runtime-config-connection.html +// {0}: Directory for the socket from -uds_file_directory on the command line. +// {1}: Port number from -port on the command line. +constexpr std::string_view UNIX_DOMAIN_SOCKET_FORMAT_STRING = "{0}/.s.PGSQL.{1}"; + /** * TerrierServer is the entry point of the network layer */ @@ -36,7 +42,7 @@ class TerrierServer : public common::DedicatedThreadOwner { TerrierServer(common::ManagedPointer protocol_provider, common::ManagedPointer connection_handle_factory, common::ManagedPointer thread_registry, uint16_t port, - uint16_t connection_thread_count); + uint16_t connection_thread_count, std::string socket_directory); ~TerrierServer() override = default; @@ -78,6 +84,9 @@ class TerrierServer : public common::DedicatedThreadOwner { // threads can be safely taken away, but I don't understand the networking stuff well enough to say for sure what // that assertion is bool OnThreadRemoval(common::ManagedPointer task) override { return true; } + enum SocketType { UNIX_DOMAIN_SOCKET, NETWORKED_SOCKET }; + template + void RegisterSocket(); std::mutex running_mutex_; bool running_; @@ -86,9 +95,11 @@ class TerrierServer : public common::DedicatedThreadOwner { // For logging purposes // static void LogCallback(int severity, const char *msg); - uint16_t port_; // port number - int listen_fd_ = -1; // server socket fd that TerrierServer is listening on - const uint32_t max_connections_; // maximum number of connections + uint16_t port_; // port number + int network_socket_fd_ = -1; // networked server socket fd that TerrierServer is listening on + int unix_domain_socket_fd_ = -1; // unix-based local socket fd that TerrierServer may listen on + const std::string socket_directory_; // Where to store the Unix domain socket + const uint32_t max_connections_; // maximum number of connections common::ManagedPointer connection_handle_factory_; common::ManagedPointer provider_; diff --git a/src/include/settings/settings_defs.h b/src/include/settings/settings_defs.h index 8be3050470..d3e458b7f3 100644 --- a/src/include/settings/settings_defs.h +++ b/src/include/settings/settings_defs.h @@ -22,6 +22,15 @@ SETTING_int( terrier::settings::Callbacks::NoOp ) +// Path to socket file for Unix domain sockets +SETTING_string( + uds_file_directory, + "The directory for the Unix domain socket (default: /tmp/)", + "/tmp/", + false, + terrier::settings::Callbacks::NoOp +) + // RecordBufferSegmentPool size limit SETTING_int( record_buffer_segment_size, diff --git a/src/network/connection_dispatcher_task.cpp b/src/network/connection_dispatcher_task.cpp index b4e7fb3788..c686e36c6d 100644 --- a/src/network/connection_dispatcher_task.cpp +++ b/src/network/connection_dispatcher_task.cpp @@ -10,10 +10,11 @@ namespace terrier::network { ConnectionDispatcherTask::ConnectionDispatcherTask( - uint32_t num_handlers, int listen_fd, common::DedicatedThreadOwner *dedicated_thread_owner, + uint32_t num_handlers, common::DedicatedThreadOwner *dedicated_thread_owner, common::ManagedPointer interpreter_provider, common::ManagedPointer connection_handle_factory, - common::ManagedPointer thread_registry) + common::ManagedPointer thread_registry, + std::initializer_list file_descriptors) : NotifiableTask(MASTER_THREAD_ID), num_handlers_(num_handlers), dedicated_thread_owner_(dedicated_thread_owner), @@ -21,8 +22,13 @@ ConnectionDispatcherTask::ConnectionDispatcherTask( thread_registry_(thread_registry), interpreter_provider_(interpreter_provider), next_handler_(0) { - RegisterEvent(listen_fd, EV_READ | EV_PERSIST, METHOD_AS_CALLBACK(ConnectionDispatcherTask, DispatchConnection), - this); + for (int listen_fd : file_descriptors) { + if (listen_fd >= 0) { + RegisterEvent(listen_fd, EV_READ | EV_PERSIST, METHOD_AS_CALLBACK(ConnectionDispatcherTask, DispatchConnection), + this); + } + } + RegisterSignalEvent(SIGHUP, METHOD_AS_CALLBACK(NotifiableTask, ExitLoop), this); } diff --git a/src/network/network_io_wrapper.cpp b/src/network/network_io_wrapper.cpp index 15d6f04868..bfa4f9c13d 100644 --- a/src/network/network_io_wrapper.cpp +++ b/src/network/network_io_wrapper.cpp @@ -5,7 +5,6 @@ #include #include -#include #include "network/terrier_server.h" diff --git a/src/network/terrier_server.cpp b/src/network/terrier_server.cpp index 62ec919d25..1592eede56 100644 --- a/src/network/terrier_server.cpp +++ b/src/network/terrier_server.cpp @@ -1,9 +1,11 @@ #include "network/terrier_server.h" +#include #include #include #include +#include #include "common/dedicated_thread_registry.h" #include "common/settings.h" @@ -18,10 +20,11 @@ namespace terrier::network { TerrierServer::TerrierServer(common::ManagedPointer protocol_provider, common::ManagedPointer connection_handle_factory, common::ManagedPointer thread_registry, - const uint16_t port, const uint16_t connection_thread_count) + const uint16_t port, const uint16_t connection_thread_count, std::string socket_directory) : DedicatedThreadOwner(thread_registry), running_(false), port_(port), + socket_directory_(std::move(socket_directory)), max_connections_(connection_thread_count), connection_handle_factory_(connection_handle_factory), provider_(protocol_provider) { @@ -39,44 +42,113 @@ TerrierServer::TerrierServer(common::ManagedPointer +void TerrierServer::RegisterSocket() { + static_assert(type == NETWORKED_SOCKET || type == UNIX_DOMAIN_SOCKET, "There should only be two socket types."); + + constexpr auto conn_backlog = common::Settings::CONNECTION_BACKLOG; + constexpr auto is_networked_socket = type == NETWORKED_SOCKET; + constexpr auto socket_description = std::string_view(is_networked_socket ? "networked" : "Unix domain"); + + auto &socket_fd = is_networked_socket ? network_socket_fd_ : unix_domain_socket_fd_; + + // Gets the appropriate sockaddr for the given SocketType. Abuse a lambda and auto to specialize the type. + auto socket_addr = ([&] { + if constexpr (is_networked_socket) { // NOLINT + struct sockaddr_in sin = {0}; - int conn_backlog = common::Settings::CONNECTION_BACKLOG; + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = INADDR_ANY; + sin.sin_port = htons(port_); - struct sockaddr_in sin; - std::memset(&sin, 0, sizeof(sin)); - sin.sin_family = AF_INET; - sin.sin_addr.s_addr = INADDR_ANY; - sin.sin_port = htons(port_); + return sin; + } else { // NOLINT + // Builds the socket path name + const std::string socket_path = fmt::format(UNIX_DOMAIN_SOCKET_FORMAT_STRING, socket_directory_, port_); + struct sockaddr_un sun = {0}; - listen_fd_ = socket(AF_INET, SOCK_STREAM, 0); + // Validate pathname + if (socket_path.length() > sizeof(sun.sun_path) /* Max Unix socket path length */) { + NETWORK_LOG_ERROR(fmt::format("Domain socket name too long (must be <= {} characters)", sizeof(sun.sun_path))); + throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to name {} socket.", socket_description)); + } - if (listen_fd_ < 0) { - NETWORK_LOG_ERROR("Failed to open socket: {}", strerror(errno)); - throw NETWORK_PROCESS_EXCEPTION("Failed to open socket."); + sun.sun_family = AF_UNIX; + socket_path.copy(sun.sun_path, sizeof(sun.sun_path)); + + return sun; + } + })(); + + // Create socket + socket_fd = socket(is_networked_socket ? AF_INET : AF_UNIX, SOCK_STREAM, 0); + + // Check if socket was successfully created + if (socket_fd < 0) { + NETWORK_LOG_ERROR("Failed to open {} socket: {}", socket_description, strerror(errno)); + throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to open {} socket.", socket_description)); } - int reuse = 1; - setsockopt(listen_fd_, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)); + // For networked sockets, tell the kernel that we would like to reuse local addresses whenever possible. + if constexpr (is_networked_socket) { // NOLINT + int reuse = 1; + setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)); + } - int retval = bind(listen_fd_, reinterpret_cast(&sin), sizeof(sin)); - if (retval < 0) { - NETWORK_LOG_ERROR("Failed to bind socket: {}", strerror(errno)); - throw NETWORK_PROCESS_EXCEPTION("Failed to bind socket."); + // Bind the socket + int status = bind(socket_fd, reinterpret_cast(&socket_addr), sizeof(socket_addr)); + if (status < 0) { + NETWORK_LOG_ERROR("Failed to bind {} socket: {}", socket_description, strerror(errno), errno); + + // We can recover from exactly one type of error here, contingent on this being a Unix domain socket. + if constexpr (!is_networked_socket) { // NOLINT + auto recovered = false; + + if (errno == EADDRINUSE) { + // I find this disgusting, but it's the approach favored by a bunch of software that uses Unix domain sockets. + // BSD syslogd, for example, does this in *every* case--error handling or not--and I'm not one to question it. + // To elaborate, I strongly dislike the idea of overwriting existing domain sockets. The idea is that in some + // edge cases (say, the process gets kill -9'd or crashes) the OS will not remove the existing Unix socket, so + // we have to delete it ourselves. You would think there'd be a better way to handle such cases--and technically + // there is, with Linux abstract namespace sockets--but it's non-portable and it's incompatible with psql. + recovered = !std::remove(fmt::format(UNIX_DOMAIN_SOCKET_FORMAT_STRING, socket_directory_, port_).c_str()) && + bind(socket_fd, reinterpret_cast(&socket_addr), sizeof(socket_addr)) >= 0; + } + + if (recovered) { + NETWORK_LOG_INFO("Recovered! Managed to bind {} socket by purging a pre-existing bind.", socket_description); + } else { + throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to bind and recover {} socket.", socket_description)); + } + } else { // NOLINT + throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to bind {} socket.", socket_description)); + } } - retval = listen(listen_fd_, conn_backlog); - if (retval < 0) { - NETWORK_LOG_ERROR("Failed to create listen socket: {}", strerror(errno)); - throw NETWORK_PROCESS_EXCEPTION("Failed to create listen socket."); + + // Listen on the socket + status = listen(socket_fd, conn_backlog); + if (status < 0) { + NETWORK_LOG_ERROR("Failed to listen on {} socket: {}", socket_description, strerror(errno)); + throw NETWORK_PROCESS_EXCEPTION(fmt::format("Failed to listen on {} socket.", socket_description)); } - dispatcher_task_ = thread_registry_->RegisterDedicatedThread( - this /* requester */, max_connections_, listen_fd_, this, common::ManagedPointer(provider_.Get()), - connection_handle_factory_, thread_registry_); + NETWORK_LOG_INFO("Listening on {} socket with port {} [PID={}]", socket_description, port_, ::getpid()); +} + +void TerrierServer::RunServer() { + // This line is critical to performance for some reason + evthread_use_pthreads(); + + // Register the network socket + RegisterSocket(); - NETWORK_LOG_INFO("Listening on port {0} [PID={1}]", port_, ::getpid()); + // Register the Unix domain socket + RegisterSocket(); + + // Create a dispatcher to handle connections to the sockets that have been created. + dispatcher_task_ = thread_registry_->RegisterDedicatedThread( + this /* requester */, max_connections_, this, common::ManagedPointer(provider_.Get()), connection_handle_factory_, + thread_registry_, std::initializer_list({unix_domain_socket_fd_, network_socket_fd_})); // Set the running_ flag for any waiting threads { @@ -90,7 +162,15 @@ void TerrierServer::StopServer() { const bool result UNUSED_ATTRIBUTE = thread_registry_->StopTask(this, dispatcher_task_.CastManagedPointerTo()); TERRIER_ASSERT(result, "Failed to stop ConnectionDispatcherTask."); - TerrierClose(listen_fd_); + + // Close the network socket + TerrierClose(network_socket_fd_); + + // Close the Unix domain socket if it exists + if (unix_domain_socket_fd_ >= 0) { + std::remove(fmt::format(UNIX_DOMAIN_SOCKET_FORMAT_STRING, socket_directory_, port_).c_str()); + } + NETWORK_LOG_INFO("Server Closed"); // Clear the running_ flag for any waiting threads and wake up them up with the condition variable diff --git a/test/network/network_test.cpp b/test/network/network_test.cpp index 59c6414fdd..3822ab622b 100644 --- a/test/network/network_test.cpp +++ b/test/network/network_test.cpp @@ -51,6 +51,7 @@ class NetworkTests : public TerrierTest { std::unique_ptr handle_factory_; common::DedicatedThreadRegistry thread_registry_ = common::DedicatedThreadRegistry(DISABLED); uint16_t port_ = 15721; + std::string socket_directory_ = "/tmp/"; uint16_t connection_thread_count_ = 4; FakeCommandFactory fake_command_factory_; PostgresProtocolInterpreter::Provider protocol_provider_{ @@ -81,10 +82,10 @@ class NetworkTests : public TerrierTest { try { handle_factory_ = std::make_unique(common::ManagedPointer(tcop_)); - server_ = - std::make_unique(common::ManagedPointer(&protocol_provider_), - common::ManagedPointer(handle_factory_.get()), - common::ManagedPointer(&thread_registry_), port_, connection_thread_count_); + server_ = std::make_unique( + common::ManagedPointer(&protocol_provider_), + common::ManagedPointer(handle_factory_.get()), common::ManagedPointer(&thread_registry_), port_, + connection_thread_count_, socket_directory_); server_->RunServer(); } catch (NetworkProcessException &exception) { NETWORK_LOG_ERROR("[LaunchServer] exception when launching server"); @@ -181,6 +182,35 @@ TEST_F(NetworkTests, SimpleQueryTest) { NETWORK_LOG_DEBUG("[SimpleQueryTest] Client has closed"); } +/** + * Performs the exact same test as SimpleQueryTest, but using a Unix domain socket instead. + * This just verifies that the Unix domain socket infrastructure works. + */ +// NOLINTNEXTLINE +TEST_F(NetworkTests, UnixDomainSocketTest) { + try { + /* + * We specify the location of the domain socket (defaults to /tmp/) for PSQL. + * This is necessary in order to ensure that the Unix domain socket gets used. + */ + pqxx::connection c(fmt::format("host={0} port={1} user={2} sslmode=disable application_name=psql", + socket_directory_, port_, catalog::DEFAULT_DATABASE)); + + pqxx::work txn1(c); + txn1.exec("INSERT INTO employee VALUES (1, 'Han LI');"); + txn1.exec("INSERT INTO employee VALUES (2, 'Shaokun ZOU');"); + txn1.exec("INSERT INTO employee VALUES (3, 'Yilei CHU');"); + + pqxx::result r = txn1.exec("SELECT name FROM employee where id=1;"); + txn1.commit(); + EXPECT_EQ(r.size(), 0); + } catch (const std::exception &e) { + NETWORK_LOG_ERROR("[UnixDomainSocketTest] Exception occurred: {0}", e.what()); + EXPECT_TRUE(false); + } + NETWORK_LOG_DEBUG("[UnixDomainSocketTest] Client has closed"); +} + // NOLINTNEXTLINE TEST_F(NetworkTests, BadQueryTest) { try {