Skip to content

Add recv/send timeouts to socket #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions clickhouse/base/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ void SetNonBlock(SOCKET fd, bool value) {
#endif
}

void SetTimeout(SOCKET fd, const SocketTimeoutParams& timeout_params) {
#if defined(_unix_)
timeval recv_timeout { .tv_sec = timeout_params.recv_timeout.count(), .tv_usec = 0 };
setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recv_timeout, sizeof(recv_timeout));

timeval send_timeout { .tv_sec = timeout_params.send_timeout.count(), .tv_usec = 0 };
setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &send_timeout, sizeof(send_timeout));
#endif
};

ssize_t Poll(struct pollfd* fds, int nfds, int timeout) noexcept {
#if defined(_win_)
return WSAPoll(fds, nfds, timeout);
Expand All @@ -120,7 +130,7 @@ ssize_t Poll(struct pollfd* fds, int nfds, int timeout) noexcept {
#endif
}

SOCKET SocketConnect(const NetworkAddress& addr) {
SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params) {
int last_err = 0;
for (auto res = addr.Info(); res != nullptr; res = res->ai_next) {
SOCKET s(socket(res->ai_family, res->ai_socktype, res->ai_protocol));
Expand All @@ -130,6 +140,7 @@ SOCKET SocketConnect(const NetworkAddress& addr) {
}

SetNonBlock(s, true);
SetTimeout(s, timeout_params);

if (connect(s, res->ai_addr, (int)res->ai_addrlen) != 0) {
int err = getSocketErrorCode();
Expand Down Expand Up @@ -213,22 +224,24 @@ NetworkAddress::~NetworkAddress() {
const struct addrinfo* NetworkAddress::Info() const {
return info_;
}

const std::string & NetworkAddress::Host() const {
return host_;
}


SocketBase::~SocketBase() = default;


SocketFactory::~SocketFactory() = default;

void SocketFactory::sleepFor(const std::chrono::milliseconds& duration) {
std::this_thread::sleep_for(duration);
}


Socket::Socket(const NetworkAddress& addr)
: handle_(SocketConnect(addr))
Socket::Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params)
: handle_(SocketConnect(addr, timeout_params))
{}

Socket::Socket(Socket&& other) noexcept
Expand Down Expand Up @@ -300,19 +313,21 @@ std::unique_ptr<OutputStream> Socket::makeOutputStream() const {
return std::make_unique<SocketOutput>(handle_);
}


NonSecureSocketFactory::~NonSecureSocketFactory() {}

std::unique_ptr<SocketBase> NonSecureSocketFactory::connect(const ClientOptions &opts) {
const auto address = NetworkAddress(opts.host, std::to_string(opts.port));

auto socket = doConnect(address);
auto socket = doConnect(address, opts);
setSocketOptions(*socket, opts);

return socket;
}

std::unique_ptr<Socket> NonSecureSocketFactory::doConnect(const NetworkAddress& address) {
return std::make_unique<Socket>(address);
std::unique_ptr<Socket> NonSecureSocketFactory::doConnect(const NetworkAddress& address, const ClientOptions& opts) {
SocketTimeoutParams timeout_params { opts.connection_recv_timeout, opts.connection_send_timeout };
return std::make_unique<Socket>(address, timeout_params);
}

void NonSecureSocketFactory::setSocketOptions(Socket &socket, const ClientOptions &opts) {
Expand All @@ -327,6 +342,7 @@ void NonSecureSocketFactory::setSocketOptions(Socket &socket, const ClientOption
}
}


SocketInput::SocketInput(SOCKET s)
: s_(s)
{
Expand Down
9 changes: 7 additions & 2 deletions clickhouse/base/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,14 @@ class SocketFactory {
};


struct SocketTimeoutParams {
const std::chrono::seconds recv_timeout {0};
const std::chrono::seconds send_timeout {0};
};

class Socket : public SocketBase {
public:
Socket(const NetworkAddress& addr);
Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params);
Socket(Socket&& other) noexcept;
Socket& operator=(Socket&& other) noexcept;

Expand Down Expand Up @@ -119,7 +124,7 @@ class NonSecureSocketFactory : public SocketFactory {
std::unique_ptr<SocketBase> connect(const ClientOptions& opts) override;

protected:
virtual std::unique_ptr<Socket> doConnect(const NetworkAddress& address);
virtual std::unique_ptr<Socket> doConnect(const NetworkAddress& address, const ClientOptions& opts);

void setSocketOptions(Socket& socket, const ClientOptions& opts);
};
Expand Down
11 changes: 6 additions & 5 deletions clickhouse/base/sslsocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,9 @@ SSL_CTX * SSLContext::getContext() {
<< "\n\t handshake state: " << SSL_get_state(ssl_) \
<< std::endl
*/
SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params,
SSLContext& context)
: Socket(addr)
SSLSocket::SSLSocket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params,
const SSLParams & ssl_params, SSLContext& context)
: Socket(addr, timeout_params)
, ssl_(SSL_new(context.getContext()), &SSL_free)
{
auto ssl = ssl_.get();
Expand Down Expand Up @@ -267,8 +267,9 @@ SSLSocketFactory::SSLSocketFactory(const ClientOptions& opts)

SSLSocketFactory::~SSLSocketFactory() = default;

std::unique_ptr<Socket> SSLSocketFactory::doConnect(const NetworkAddress& address) {
return std::make_unique<SSLSocket>(address, ssl_params_, *ssl_context_);
std::unique_ptr<Socket> SSLSocketFactory::doConnect(const NetworkAddress& address, const ClientOptions& opts) {
SocketTimeoutParams timeout_params { opts.connection_recv_timeout, opts.connection_send_timeout };
return std::make_unique<SSLSocket>(address, timeout_params, ssl_params_, *ssl_context_);
}

std::unique_ptr<InputStream> SSLSocket::makeInputStream() const {
Expand Down
6 changes: 4 additions & 2 deletions clickhouse/base/sslsocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class SSLContext

class SSLSocket : public Socket {
public:
explicit SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, SSLContext& context);
explicit SSLSocket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params,
const SSLParams& ssl_params, SSLContext& context);

SSLSocket(SSLSocket &&) = default;
~SSLSocket() override = default;

Expand All @@ -69,7 +71,7 @@ class SSLSocketFactory : public NonSecureSocketFactory {
~SSLSocketFactory() override;

protected:
std::unique_ptr<Socket> doConnect(const NetworkAddress& address) override;
std::unique_ptr<Socket> doConnect(const NetworkAddress& address, const ClientOptions& opts) override;

private:
const SSLParams ssl_params_;
Expand Down
4 changes: 4 additions & 0 deletions clickhouse/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ struct ClientOptions {
// TCP options
DECLARE_FIELD(tcp_nodelay, bool, TcpNoDelay, true);

/// Connection socket timeout. If the timeout is set to zero then the operation will never timeout.
DECLARE_FIELD(connection_recv_timeout, std::chrono::seconds, SetConnectionRecvTimeout, std::chrono::seconds(0));
DECLARE_FIELD(connection_send_timeout, std::chrono::seconds, SetConnectionSendTimeout, std::chrono::seconds(0));

// TODO deprecate setting
/** It helps to ease migration of the old codebases, which can't afford to switch
* to using ColumnLowCardinalityT or ColumnLowCardinality directly,
Expand Down
28 changes: 26 additions & 2 deletions ut/socket_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,45 @@ TEST(Socketcase, connecterror) {

std::this_thread::sleep_for(std::chrono::seconds(1));
try {
Socket socket(addr);
Socket socket(addr, SocketTimeoutParams {});
} catch (const std::system_error& e) {
FAIL();
}

std::this_thread::sleep_for(std::chrono::seconds(1));
server.stop();
try {
Socket socket(addr);
Socket socket(addr, SocketTimeoutParams {});
FAIL();
} catch (const std::system_error& e) {
ASSERT_NE(EINPROGRESS,e.code().value());
}
}

TEST(Socketcase, timeoutrecv) {
using Seconds = std::chrono::seconds;

int port = 19979;
NetworkAddress addr("localhost", std::to_string(port));
LocalTcpServer server(port);
server.start();

std::this_thread::sleep_for(std::chrono::seconds(1));
try {
Socket socket(addr, SocketTimeoutParams { .recv_timeout = Seconds(5), .send_timeout = Seconds(5) });

std::unique_ptr<InputStream> ptr_input_stream = socket.makeInputStream();
char buf[1024];
ptr_input_stream->Read(buf, sizeof(buf));

} catch (const std::system_error& e) {
ASSERT_EQ(EAGAIN, e.code().value());
}

std::this_thread::sleep_for(std::chrono::seconds(1));
server.stop();
}

// Test to verify that reading from empty socket doesn't hangs.
//TEST(Socketcase, ReadFromEmptySocket) {
// const int port = 12345;
Expand Down