Skip to content

Socket RAII wrapper to prevent leaking socket #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 17, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 49 additions & 21 deletions clickhouse/base/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,53 @@ ssize_t Poll(struct pollfd* fds, int nfds, int timeout) noexcept {
#endif
}

#ifndef INVALID_SOCKET
const SOCKET INVALID_SOCKET = -1;
#endif

void CloseSocket(SOCKET socket) {
if (socket == INVALID_SOCKET)
return;

#if defined(_win_)
closesocket(socket);
#else
close(socket);
#endif
}

struct SocketRAIIWrapper {
SOCKET socket = INVALID_SOCKET;

~SocketRAIIWrapper() {
CloseSocket(socket);
}

SOCKET operator*() const {
return socket;
}

SOCKET release() {
auto result = socket;
socket = INVALID_SOCKET;

return result;
}
};

SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params) {
int last_err = 0;
for (auto res = addr.Info(); res != nullptr; res = res->ai_next) {
SOCKET s(socket(res->ai_family, res->ai_socktype, res->ai_protocol));
SocketRAIIWrapper s{socket(res->ai_family, res->ai_socktype, res->ai_protocol)};

if (s == -1) {
if (*s == INVALID_SOCKET) {
continue;
}

SetNonBlock(s, true);
SetTimeout(s, timeout_params);
SetNonBlock(*s, true);
SetTimeout(*s, timeout_params);

if (connect(s, res->ai_addr, (int)res->ai_addrlen) != 0) {
if (connect(*s, res->ai_addr, (int)res->ai_addrlen) != 0) {
int err = getSocketErrorCode();
if (
err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK
Expand All @@ -165,7 +199,7 @@ SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& time
#endif
) {
pollfd fd;
fd.fd = s;
fd.fd = *s;
fd.events = POLLOUT;
fd.revents = 0;
ssize_t rval = Poll(&fd, 1, 5000);
Expand All @@ -175,18 +209,18 @@ SOCKET SocketConnect(const NetworkAddress& addr, const SocketTimeoutParams& time
}
if (rval > 0) {
socklen_t len = sizeof(err);
getsockopt(s, SOL_SOCKET, SO_ERROR, (char*)&err, &len);
getsockopt(*s, SOL_SOCKET, SO_ERROR, (char*)&err, &len);

if (!err) {
SetNonBlock(s, false);
return s;
SetNonBlock(*s, false);
return s.release();
}
last_err = err;
}
}
} else {
SetNonBlock(s, false);
return s;
SetNonBlock(*s, false);
return s.release();
}
}
if (last_err > 0) {
Expand Down Expand Up @@ -265,15 +299,15 @@ Socket::Socket(const NetworkAddress & addr)
Socket::Socket(Socket&& other) noexcept
: handle_(other.handle_)
{
other.handle_ = -1;
other.handle_ = INVALID_SOCKET;
}

Socket& Socket::operator=(Socket&& other) noexcept {
if (this != &other) {
Close();

handle_ = other.handle_;
other.handle_ = -1;
other.handle_ = INVALID_SOCKET;
}

return *this;
Expand All @@ -284,14 +318,8 @@ Socket::~Socket() {
}

void Socket::Close() {
if (handle_ != -1) {
#if defined(_win_)
closesocket(handle_);
#else
close(handle_);
#endif
handle_ = -1;
}
CloseSocket(handle_);
handle_ = INVALID_SOCKET;
}

void Socket::SetTcpKeepAlive(int idle, int intvl, int cnt) noexcept {
Expand Down