diff --git a/clickhouse/base/socket.cpp b/clickhouse/base/socket.cpp index c6dc920e..e0f8fb1c 100644 --- a/clickhouse/base/socket.cpp +++ b/clickhouse/base/socket.cpp @@ -112,6 +112,16 @@ void SetNonBlock(SOCKET fd, bool value) { #endif } +void SetTimeout(SOCKET fd, const SocketTimeoutParams& timeout_params) { +#if defined(_unix_) + timeval recv_timeout { .tv_sec = timeout_params.recv_timeout.count(), .tv_usec = 0 }; + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recv_timeout, sizeof(recv_timeout)); + + timeval send_timeout { .tv_sec = timeout_params.send_timeout.count(), .tv_usec = 0 }; + setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &send_timeout, sizeof(send_timeout)); +#endif +}; + ssize_t Poll(struct pollfd* fds, int nfds, int timeout) noexcept { #if defined(_win_) return WSAPoll(fds, nfds, timeout); @@ -120,7 +130,7 @@ ssize_t Poll(struct pollfd* fds, int nfds, int timeout) noexcept { #endif } -SOCKET SocketConnect(const NetworkAddress& addr) { +SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params) { int last_err = 0; for (auto res = addr.Info(); res != nullptr; res = res->ai_next) { SOCKET s(socket(res->ai_family, res->ai_socktype, res->ai_protocol)); @@ -130,6 +140,7 @@ SOCKET SocketConnect(const NetworkAddress& addr) { } SetNonBlock(s, true); + SetTimeout(s, timeout_params); if (connect(s, res->ai_addr, (int)res->ai_addrlen) != 0) { int err = getSocketErrorCode(); @@ -213,6 +224,7 @@ NetworkAddress::~NetworkAddress() { const struct addrinfo* NetworkAddress::Info() const { return info_; } + const std::string & NetworkAddress::Host() const { return host_; } @@ -220,6 +232,7 @@ const std::string & NetworkAddress::Host() const { SocketBase::~SocketBase() = default; + SocketFactory::~SocketFactory() = default; void SocketFactory::sleepFor(const std::chrono::milliseconds& duration) { @@ -227,8 +240,8 @@ void SocketFactory::sleepFor(const std::chrono::milliseconds& duration) { } -Socket::Socket(const NetworkAddress& addr) - : handle_(SocketConnect(addr)) +Socket::Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params) + : handle_(SocketConnect(addr, timeout_params)) {} Socket::Socket(Socket&& other) noexcept @@ -300,19 +313,21 @@ std::unique_ptr Socket::makeOutputStream() const { return std::make_unique(handle_); } + NonSecureSocketFactory::~NonSecureSocketFactory() {} std::unique_ptr NonSecureSocketFactory::connect(const ClientOptions &opts) { const auto address = NetworkAddress(opts.host, std::to_string(opts.port)); - auto socket = doConnect(address); + auto socket = doConnect(address, opts); setSocketOptions(*socket, opts); return socket; } -std::unique_ptr NonSecureSocketFactory::doConnect(const NetworkAddress& address) { - return std::make_unique(address); +std::unique_ptr NonSecureSocketFactory::doConnect(const NetworkAddress& address, const ClientOptions& opts) { + SocketTimeoutParams timeout_params { opts.connection_recv_timeout, opts.connection_send_timeout }; + return std::make_unique(address, timeout_params); } void NonSecureSocketFactory::setSocketOptions(Socket &socket, const ClientOptions &opts) { @@ -327,6 +342,7 @@ void NonSecureSocketFactory::setSocketOptions(Socket &socket, const ClientOption } } + SocketInput::SocketInput(SOCKET s) : s_(s) { diff --git a/clickhouse/base/socket.h b/clickhouse/base/socket.h index e7cacc19..b3d916e1 100644 --- a/clickhouse/base/socket.h +++ b/clickhouse/base/socket.h @@ -82,9 +82,14 @@ class SocketFactory { }; +struct SocketTimeoutParams { + const std::chrono::seconds recv_timeout {0}; + const std::chrono::seconds send_timeout {0}; +}; + class Socket : public SocketBase { public: - Socket(const NetworkAddress& addr); + Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params); Socket(Socket&& other) noexcept; Socket& operator=(Socket&& other) noexcept; @@ -119,7 +124,7 @@ class NonSecureSocketFactory : public SocketFactory { std::unique_ptr connect(const ClientOptions& opts) override; protected: - virtual std::unique_ptr doConnect(const NetworkAddress& address); + virtual std::unique_ptr doConnect(const NetworkAddress& address, const ClientOptions& opts); void setSocketOptions(Socket& socket, const ClientOptions& opts); }; diff --git a/clickhouse/base/sslsocket.cpp b/clickhouse/base/sslsocket.cpp index 392c22fd..29efa504 100644 --- a/clickhouse/base/sslsocket.cpp +++ b/clickhouse/base/sslsocket.cpp @@ -198,9 +198,9 @@ SSL_CTX * SSLContext::getContext() { << "\n\t handshake state: " << SSL_get_state(ssl_) \ << std::endl */ -SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, - SSLContext& context) - : Socket(addr) +SSLSocket::SSLSocket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params, + const SSLParams & ssl_params, SSLContext& context) + : Socket(addr, timeout_params) , ssl_(SSL_new(context.getContext()), &SSL_free) { auto ssl = ssl_.get(); @@ -267,8 +267,9 @@ SSLSocketFactory::SSLSocketFactory(const ClientOptions& opts) SSLSocketFactory::~SSLSocketFactory() = default; -std::unique_ptr SSLSocketFactory::doConnect(const NetworkAddress& address) { - return std::make_unique(address, ssl_params_, *ssl_context_); +std::unique_ptr SSLSocketFactory::doConnect(const NetworkAddress& address, const ClientOptions& opts) { + SocketTimeoutParams timeout_params { opts.connection_recv_timeout, opts.connection_send_timeout }; + return std::make_unique(address, timeout_params, ssl_params_, *ssl_context_); } std::unique_ptr SSLSocket::makeInputStream() const { diff --git a/clickhouse/base/sslsocket.h b/clickhouse/base/sslsocket.h index f37e4a5a..945de86d 100644 --- a/clickhouse/base/sslsocket.h +++ b/clickhouse/base/sslsocket.h @@ -48,7 +48,9 @@ class SSLContext class SSLSocket : public Socket { public: - explicit SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, SSLContext& context); + explicit SSLSocket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params, + const SSLParams& ssl_params, SSLContext& context); + SSLSocket(SSLSocket &&) = default; ~SSLSocket() override = default; @@ -69,7 +71,7 @@ class SSLSocketFactory : public NonSecureSocketFactory { ~SSLSocketFactory() override; protected: - std::unique_ptr doConnect(const NetworkAddress& address) override; + std::unique_ptr doConnect(const NetworkAddress& address, const ClientOptions& opts) override; private: const SSLParams ssl_params_; diff --git a/clickhouse/client.h b/clickhouse/client.h index 6de09b8a..679dd32c 100644 --- a/clickhouse/client.h +++ b/clickhouse/client.h @@ -86,6 +86,10 @@ struct ClientOptions { // TCP options DECLARE_FIELD(tcp_nodelay, bool, TcpNoDelay, true); + /// Connection socket timeout. If the timeout is set to zero then the operation will never timeout. + DECLARE_FIELD(connection_recv_timeout, std::chrono::seconds, SetConnectionRecvTimeout, std::chrono::seconds(0)); + DECLARE_FIELD(connection_send_timeout, std::chrono::seconds, SetConnectionSendTimeout, std::chrono::seconds(0)); + // TODO deprecate setting /** It helps to ease migration of the old codebases, which can't afford to switch * to using ColumnLowCardinalityT or ColumnLowCardinality directly, diff --git a/ut/socket_ut.cpp b/ut/socket_ut.cpp index 6f428428..36b6a65b 100644 --- a/ut/socket_ut.cpp +++ b/ut/socket_ut.cpp @@ -18,7 +18,7 @@ TEST(Socketcase, connecterror) { std::this_thread::sleep_for(std::chrono::seconds(1)); try { - Socket socket(addr); + Socket socket(addr, SocketTimeoutParams {}); } catch (const std::system_error& e) { FAIL(); } @@ -26,13 +26,37 @@ TEST(Socketcase, connecterror) { std::this_thread::sleep_for(std::chrono::seconds(1)); server.stop(); try { - Socket socket(addr); + Socket socket(addr, SocketTimeoutParams {}); FAIL(); } catch (const std::system_error& e) { ASSERT_NE(EINPROGRESS,e.code().value()); } } +TEST(Socketcase, timeoutrecv) { + using Seconds = std::chrono::seconds; + + int port = 19979; + NetworkAddress addr("localhost", std::to_string(port)); + LocalTcpServer server(port); + server.start(); + + std::this_thread::sleep_for(std::chrono::seconds(1)); + try { + Socket socket(addr, SocketTimeoutParams { .recv_timeout = Seconds(5), .send_timeout = Seconds(5) }); + + std::unique_ptr ptr_input_stream = socket.makeInputStream(); + char buf[1024]; + ptr_input_stream->Read(buf, sizeof(buf)); + + } catch (const std::system_error& e) { + ASSERT_EQ(EAGAIN, e.code().value()); + } + + std::this_thread::sleep_for(std::chrono::seconds(1)); + server.stop(); +} + // Test to verify that reading from empty socket doesn't hangs. //TEST(Socketcase, ReadFromEmptySocket) { // const int port = 12345;