diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index eed7bc5f5..f564ae532 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -93,6 +93,10 @@ add_executable(example_file_upload example_file_upload.cpp) add_warnings_optimizations(example_file_upload) target_link_libraries(example_file_upload PUBLIC Crow::Crow) +add_executable(example_unix_socket example_unix_socket.cpp) +add_warnings_optimizations(example_unix_socket) +target_link_libraries(example_unix_socket PUBLIC Crow::Crow) + if(MSVC) add_executable(example_vs example_vs.cpp) add_warnings_optimizations(example_vs) diff --git a/examples/example_unix_socket.cpp b/examples/example_unix_socket.cpp new file mode 100644 index 000000000..2c86621db --- /dev/null +++ b/examples/example_unix_socket.cpp @@ -0,0 +1,18 @@ +#include "crow.h" + +#include + +int main() +{ + crow::SimpleApp app; + + CROW_ROUTE(app, "/") + ([]() { + return "Hello, world!"; + }); + + std::string local_socket_path = "example.sock"; + unlink(local_socket_path.c_str()); + app.local_socket_path(local_socket_path).run(); + +} diff --git a/include/crow.h b/include/crow.h index e2afd8ac8..cb3720cff 100644 --- a/include/crow.h +++ b/include/crow.h @@ -5,6 +5,7 @@ #include "crow/TinySHA1.hpp" #include "crow/settings.h" #include "crow/socket_adaptors.h" +#include "crow/socket_acceptors.h" #include "crow/json.h" #include "crow/mustache.h" #include "crow/logging.h" diff --git a/include/crow/app.h b/include/crow/app.h index 4119859b7..8c244e1cb 100644 --- a/include/crow/app.h +++ b/include/crow/app.h @@ -202,11 +202,12 @@ namespace crow using self_t = Crow; /// \brief The HTTP server - using server_t = Server; - + using server_t = Server; + /// \brief An HTTP server that runs on unix domain socket + using unix_server_t = Server; #ifdef CROW_ENABLE_SSL /// \brief An HTTP server that runs on SSL with an SSLAdaptor - using ssl_server_t = Server; + using ssl_server_t = Server; #endif Crow() {} @@ -349,6 +350,20 @@ namespace crow return bindaddr_; } + /// \brief Disable tcp/ip and use unix domain socket instead + self_t& local_socket_path(std::string path) + { + bindaddr_ = path; + use_unix_ = true; + return *this; + } + + /// \brief Get the unix domain socket path + std::string local_socket_path() + { + return bindaddr_; + } + /// \brief Run the server on multiple threads using all available threads self_t& multithreaded() { @@ -516,6 +531,7 @@ namespace crow #ifdef CROW_ENABLE_SSL if (ssl_used_) { + router_.using_ssl = true; ssl_server_ = std::move(std::unique_ptr(new ssl_server_t(this, endpoint, server_name_, &middlewares_, concurrency_, timeout_, &ssl_context_))); ssl_server_->set_tick_function(tick_interval_, tick_function_); @@ -530,14 +546,30 @@ namespace crow else #endif { - server_ = std::move(std::unique_ptr(new server_t(this, endpoint, server_name_, &middlewares_, concurrency_, timeout_, nullptr))); - server_->set_tick_function(tick_interval_, tick_function_); - for (auto snum : signals_) + if (use_unix_) { - server_->signal_add(snum); + UnixSocketAcceptor::endpoint endpoint(bindaddr_); + unix_server_ = std::move(std::unique_ptr(new unix_server_t(this, endpoint, server_name_, &middlewares_, concurrency_, timeout_, nullptr))); + unix_server_->set_tick_function(tick_interval_, tick_function_); + for (auto snum : signals_) + { + unix_server_->signal_add(snum); + } + notify_server_start(); + unix_server_->run(); + } + else + { + TCPAcceptor::endpoint endpoint(asio::ip::address::from_string(bindaddr_), port_); + server_ = std::move(std::unique_ptr(new server_t(this, endpoint, server_name_, &middlewares_, concurrency_, timeout_, nullptr))); + server_->set_tick_function(tick_interval_, tick_function_); + for (auto snum : signals_) + { + server_->signal_add(snum); + } + notify_server_start(); + server_->run(); } - notify_server_start(); - server_->run(); } } @@ -571,6 +603,7 @@ namespace crow websocket->close("Server Application Terminated"); } if (server_) { server_->stop(); } + if (unix_server_) { unix_server_->stop(); } } } @@ -713,12 +746,12 @@ namespace crow status = cv_started_.wait_until(lock, wait_until); } } - if (status == std::cv_status::no_timeout) { - if (server_) - { + if (server_) { status = server_->wait_for_start(wait_until); + } else if (unix_server_) { + status = unix_server_->wait_for_start(wait_until); } #ifdef CROW_ENABLE_SSL else if (ssl_server_) @@ -763,6 +796,7 @@ namespace crow uint64_t max_payload_{UINT64_MAX}; std::string server_name_ = std::string("Crow/") + VERSION; std::string bindaddr_ = "0.0.0.0"; + bool use_unix_ = false; size_t res_stream_threshold_ = 1048576; Router router_; bool static_routes_added_{false}; @@ -784,6 +818,7 @@ namespace crow #endif std::unique_ptr server_; + std::unique_ptr unix_server_; std::vector signals_{SIGINT, SIGTERM}; diff --git a/include/crow/http_connection.h b/include/crow/http_connection.h index 76c82f014..1391b5108 100644 --- a/include/crow/http_connection.h +++ b/include/crow/http_connection.h @@ -144,9 +144,7 @@ namespace crow req_.middleware_context = static_cast(&ctx_); req_.middleware_container = static_cast(middlewares_); req_.io_context = &adaptor_.get_io_context(); - req_.remote_ip_address = adaptor_.remote_endpoint().address().to_string(); - add_keep_alive_ = req_.keep_alive; close_connection_ = req_.close_connection; diff --git a/include/crow/http_server.h b/include/crow/http_server.h index 12ff3726d..d25ad96af 100644 --- a/include/crow/http_server.h +++ b/include/crow/http_server.h @@ -37,13 +37,14 @@ namespace crow // NOTE: Already documented in "crow/app.h" using error_code = asio::error_code; #endif using tcp = asio::ip::tcp; + using stream_protocol = asio::local::stream_protocol; - template + template class Server { public: Server(Handler* handler, - const tcp::endpoint& endpoint, + typename Acceptor::endpoint endpoint, std::string server_name = std::string("Crow/") + VERSION, std::tuple* middlewares = nullptr, uint16_t concurrency = 1, @@ -153,13 +154,10 @@ namespace crow // NOTE: Already documented in "crow/app.h" on_tick(); }); } - handler_->port(acceptor_.local_endpoint().port()); - - - CROW_LOG_INFO << server_name_ - << " server is running at " << (handler_->ssl_used() ? "https://" : "http://") - << acceptor_.local_endpoint().address() << ":" << acceptor_.local_endpoint().port() << " using " << concurrency_ << " threads"; + CROW_LOG_INFO << server_name_ + << " server is running at " << acceptor_.url_display(handler_->ssl_used()) + << " using " << concurrency_ << " threads"; CROW_LOG_INFO << "Call `app.loglevel(crow::LogLevel::Warning)` to hide Info level logs."; signals_.async_wait( @@ -252,7 +250,7 @@ namespace crow // NOTE: Already documented in "crow/app.h" ic, handler_, server_name_, middlewares_, get_cached_date_str_pool_[context_idx], *task_timer_pool_[context_idx], adaptor_ctx_, task_queue_length_pool_[context_idx]); - acceptor_.async_accept( + acceptor_.raw_acceptor().async_accept( p->socket(), [this, p, &ic, context_idx](error_code ec) { if (!ec) @@ -285,7 +283,7 @@ namespace crow // NOTE: Already documented in "crow/app.h" asio::io_context io_context_; std::vector task_timer_pool_; std::vector> get_cached_date_str_pool_; - tcp::acceptor acceptor_; + Acceptor acceptor_; bool shutting_down_ = false; bool server_started_{false}; std::condition_variable cv_started_; @@ -298,6 +296,7 @@ namespace crow // NOTE: Already documented in "crow/app.h" uint16_t concurrency_{2}; std::uint8_t timeout_; std::string server_name_; + bool use_unix_; std::vector> task_queue_length_pool_; std::chrono::milliseconds tick_interval_; diff --git a/include/crow/routing.h b/include/crow/routing.h index 1ee7b0f57..046241cf0 100644 --- a/include/crow/routing.h +++ b/include/crow/routing.h @@ -121,6 +121,11 @@ namespace crow // NOTE: Already documented in "crow/app.h" res = response(404); res.end(); } + virtual void handle_upgrade(const request&, response& res, UnixSocketAdaptor&&) + { + res = response(404); + res.end(); + } #ifdef CROW_ENABLE_SSL virtual void handle_upgrade(const request&, response& res, SSLAdaptor&&) { @@ -442,6 +447,11 @@ namespace crow // NOTE: Already documented in "crow/app.h" max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload(); new crow::websocket::Connection(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_); } + void handle_upgrade(const request& req, response&, UnixSocketAdaptor&& adaptor) override + { + max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload(); + new crow::websocket::Connection(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_); + } #ifdef CROW_ENABLE_SSL void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override { diff --git a/include/crow/socket_acceptors.h b/include/crow/socket_acceptors.h new file mode 100644 index 000000000..db31f3133 --- /dev/null +++ b/include/crow/socket_acceptors.h @@ -0,0 +1,62 @@ +#pragma once +#ifndef ASIO_STANDALONE +#define ASIO_STANDALONE +#endif +#include + +namespace crow +{ + using tcp = asio::ip::tcp; + using stream_protocol = asio::local::stream_protocol; + + struct TCPAcceptor + { + using endpoint = tcp::endpoint; + tcp::acceptor acceptor_; + TCPAcceptor(asio::io_service& io_service, const endpoint& endpoint_): + acceptor_(io_service, endpoint_) {} + + int16_t port() const + { + return acceptor_.local_endpoint().port(); + } + std::string url_display(bool ssl_used) const + { + return (ssl_used ? "https://" : "http://") + acceptor_.local_endpoint().address().to_string() + ":" + std::to_string(acceptor_.local_endpoint().port()); + } + tcp::acceptor& raw_acceptor() + { + return acceptor_; + } + endpoint local_endpoint() const + { + return acceptor_.local_endpoint(); + } + }; + + struct UnixSocketAcceptor + { + using endpoint = stream_protocol::endpoint; + stream_protocol::acceptor acceptor_; + UnixSocketAcceptor(asio::io_service& io_service, const endpoint& endpoint_): + acceptor_(io_service, endpoint_, false) {} + // reuse addr must be false (https://github.com/chriskohlhoff/asio/issues/622) + + int16_t port() const + { + return 0; + } + std::string url_display(bool) const + { + return acceptor_.local_endpoint().path(); + } + stream_protocol::acceptor& raw_acceptor() + { + return acceptor_; + } + endpoint local_endpoint() const + { + return acceptor_.local_endpoint(); + } + }; +} // namespace crow \ No newline at end of file diff --git a/include/crow/socket_adaptors.h b/include/crow/socket_adaptors.h index 076e4ad05..dead417a0 100644 --- a/include/crow/socket_adaptors.h +++ b/include/crow/socket_adaptors.h @@ -33,6 +33,7 @@ namespace crow using error_code = asio::error_code; #endif using tcp = asio::ip::tcp; + using stream_protocol = asio::local::stream_protocol; /// A wrapper for the asio::ip::tcp::socket and asio::ssl::stream struct SocketAdaptor @@ -64,6 +65,11 @@ namespace crow return socket_.remote_endpoint(); } + std::string address() const + { + return socket_.remote_endpoint().address().to_string(); + } + bool is_open() { return socket_.is_open(); @@ -102,6 +108,77 @@ namespace crow tcp::socket socket_; }; + struct UnixSocketAdaptor + { + using context = void; + UnixSocketAdaptor(asio::io_service& io_service, context*): + socket_(io_service) + { + } + + asio::io_service& get_io_service() + { + return GET_IO_SERVICE(socket_); + } + + stream_protocol::socket& raw_socket() + { + return socket_; + } + + stream_protocol::socket& socket() + { + return socket_; + } + + stream_protocol::endpoint remote_endpoint() + { + return socket_.local_endpoint(); + } + + std::string address() const + { + return ""; + } + + bool is_open() + { + return socket_.is_open(); + } + + void close() + { + asio::error_code ec; + socket_.close(ec); + } + + void shutdown_readwrite() + { + asio::error_code ec; + socket_.shutdown(asio::socket_base::shutdown_type::shutdown_both, ec); + } + + void shutdown_write() + { + asio::error_code ec; + socket_.shutdown(asio::socket_base::shutdown_type::shutdown_send, ec); + } + + void shutdown_read() + { + asio::error_code ec; + socket_.shutdown(asio::socket_base::shutdown_type::shutdown_receive, ec); + } + + template + void start(F f) + { + f(asio::error_code()); + } + + stream_protocol::socket socket_; + }; + #ifdef CROW_ENABLE_SSL struct SSLAdaptor { @@ -127,6 +204,11 @@ namespace crow return raw_socket().remote_endpoint(); } + std::string address() const + { + return ssl_socket_->lowest_layer().remote_endpoint().address().to_string(); + } + bool is_open() { return ssl_socket_ ? raw_socket().is_open() : false; diff --git a/include/crow/websocket.h b/include/crow/websocket.h index a315226f6..9659a9de6 100644 --- a/include/crow/websocket.h +++ b/include/crow/websocket.h @@ -279,7 +279,7 @@ namespace crow // NOTE: Already documented in "crow/app.h" std::string get_remote_ip() override { - return adaptor_.remote_endpoint().address().to_string(); + return adaptor_.address(); } void set_max_payload_size(uint64_t payload) diff --git a/tests/unittest.cpp b/tests/unittest.cpp index 454614292..eeb9f35ac 100644 --- a/tests/unittest.cpp +++ b/tests/unittest.cpp @@ -1536,9 +1536,9 @@ struct NullSimpleMiddleware TEST_CASE("middleware_simple") { App app; - asio::ip::address adr = asio::ip::make_address(LOCALHOST_ADDRESS); - tcp::endpoint ep(adr,45451); - decltype(app)::server_t server(&app, ep); + TCPAcceptor::endpoint endpoint(asio::ip::address::from_string(LOCALHOST_ADDRESS), 45451); + decltype(app)::server_t server(&app, endpoint); + CROW_ROUTE(app, "/") ([&](const crow::request& req) { app.get_context(req); @@ -4036,6 +4036,33 @@ TEST_CASE("http2_upgrade_is_ignored") app.stop(); } +TEST_CASE("unix_socket") +{ + static char buf[2048]; + SimpleApp app; + CROW_ROUTE(app, "/").methods("GET"_method)([] { + return "A"; + }); + + constexpr const char* socket_path = "unittest.sock"; + unlink(socket_path); + auto _ = app.local_socket_path(socket_path).run_async(); + app.wait_for_server_start(); + + std::string sendmsg = "GET / HTTP/1.0\r\n\r\n"; + { + asio::io_service is; + asio::local::stream_protocol::socket c(is); + c.connect(asio::local::stream_protocol::endpoint(socket_path)); + + c.send(asio::buffer(sendmsg)); + + size_t recved = c.receive(asio::buffer(buf, 2048)); + CHECK('A' == buf[recved - 1]); + } + app.stop(); +} // unix_socket + TEST_CASE("option_header_passed_in_full") { const std::string ServerName = "AN_EXTREMELY_UNIQUE_SERVER_NAME";