Skip to content

Commit

Permalink
Merge branch 'vsock-fixes'
Browse files Browse the repository at this point in the history
Filippo Storniolo says:

====================
vsock: fix server prevents clients from reconnecting

This patch series introduce fix and tests for the following vsock bug:
If the same remote peer, using the same port, tries to connect
to a server on a listening port more than once, the server will
reject the connection, causing a "connection reset by peer"
error on the remote peer. This is due to the presence of a
dangling socket from a previous connection in both the connected
and bound socket lists.
The inconsistency of the above lists only occurs when the remote
peer disconnects and the server remains active.
This bug does not occur when the server socket is closed.

More details on the first patch changelog.
The remaining patches are refactoring and test.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
davem330 committed Nov 7, 2023
2 parents 7425627 + d80f63f commit 97b9432
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 17 deletions.
16 changes: 11 additions & 5 deletions net/vmw_vsock/virtio_transport_common.c
Original file line number Diff line number Diff line change
Expand Up @@ -1369,11 +1369,17 @@ virtio_transport_recv_connected(struct sock *sk,
vsk->peer_shutdown |= RCV_SHUTDOWN;
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
vsk->peer_shutdown |= SEND_SHUTDOWN;
if (vsk->peer_shutdown == SHUTDOWN_MASK &&
vsock_stream_has_data(vsk) <= 0 &&
!sock_flag(sk, SOCK_DONE)) {
(void)virtio_transport_reset(vsk, NULL);
virtio_transport_do_close(vsk, true);
if (vsk->peer_shutdown == SHUTDOWN_MASK) {
if (vsock_stream_has_data(vsk) <= 0 && !sock_flag(sk, SOCK_DONE)) {
(void)virtio_transport_reset(vsk, NULL);
virtio_transport_do_close(vsk, true);
}
/* Remove this socket anyway because the remote peer sent
* the shutdown. This way a new connection will succeed
* if the remote peer uses the same source port,
* even if the old socket is still unreleased, but now disconnected.
*/
vsock_remove_sock(vsk);
}
if (le32_to_cpu(virtio_vsock_hdr(skb)->flags))
sk->sk_state_change(sk);
Expand Down
87 changes: 75 additions & 12 deletions tools/testing/vsock/util.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,48 @@ void vsock_wait_remote_close(int fd)
close(epollfd);
}

/* Bind to <bind_port>, connect to <cid, port> and return the file descriptor. */
int vsock_bind_connect(unsigned int cid, unsigned int port, unsigned int bind_port, int type)
{
struct sockaddr_vm sa_client = {
.svm_family = AF_VSOCK,
.svm_cid = VMADDR_CID_ANY,
.svm_port = bind_port,
};
struct sockaddr_vm sa_server = {
.svm_family = AF_VSOCK,
.svm_cid = cid,
.svm_port = port,
};

int client_fd, ret;

client_fd = socket(AF_VSOCK, type, 0);
if (client_fd < 0) {
perror("socket");
exit(EXIT_FAILURE);
}

if (bind(client_fd, (struct sockaddr *)&sa_client, sizeof(sa_client))) {
perror("bind");
exit(EXIT_FAILURE);
}

timeout_begin(TIMEOUT);
do {
ret = connect(client_fd, (struct sockaddr *)&sa_server, sizeof(sa_server));
timeout_check("connect");
} while (ret < 0 && errno == EINTR);
timeout_end();

if (ret < 0) {
perror("connect");
exit(EXIT_FAILURE);
}

return client_fd;
}

/* Connect to <cid, port> and return the file descriptor. */
static int vsock_connect(unsigned int cid, unsigned int port, int type)
{
Expand All @@ -104,6 +146,10 @@ static int vsock_connect(unsigned int cid, unsigned int port, int type)
control_expectln("LISTENING");

fd = socket(AF_VSOCK, type, 0);
if (fd < 0) {
perror("socket");
exit(EXIT_FAILURE);
}

timeout_begin(TIMEOUT);
do {
Expand Down Expand Up @@ -132,11 +178,8 @@ int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
return vsock_connect(cid, port, SOCK_SEQPACKET);
}

/* Listen on <cid, port> and return the first incoming connection. The remote
* address is stored to clientaddrp. clientaddrp may be NULL.
*/
static int vsock_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp, int type)
/* Listen on <cid, port> and return the file descriptor. */
static int vsock_listen(unsigned int cid, unsigned int port, int type)
{
union {
struct sockaddr sa;
Expand All @@ -148,16 +191,13 @@ static int vsock_accept(unsigned int cid, unsigned int port,
.svm_cid = cid,
},
};
union {
struct sockaddr sa;
struct sockaddr_vm svm;
} clientaddr;
socklen_t clientaddr_len = sizeof(clientaddr.svm);
int fd;
int client_fd;
int old_errno;

fd = socket(AF_VSOCK, type, 0);
if (fd < 0) {
perror("socket");
exit(EXIT_FAILURE);
}

if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
perror("bind");
Expand All @@ -169,6 +209,24 @@ static int vsock_accept(unsigned int cid, unsigned int port,
exit(EXIT_FAILURE);
}

return fd;
}

/* Listen on <cid, port> and return the first incoming connection. The remote
* address is stored to clientaddrp. clientaddrp may be NULL.
*/
static int vsock_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp, int type)
{
union {
struct sockaddr sa;
struct sockaddr_vm svm;
} clientaddr;
socklen_t clientaddr_len = sizeof(clientaddr.svm);
int fd, client_fd, old_errno;

fd = vsock_listen(cid, port, type);

control_writeln("LISTENING");

timeout_begin(TIMEOUT);
Expand Down Expand Up @@ -207,6 +265,11 @@ int vsock_stream_accept(unsigned int cid, unsigned int port,
return vsock_accept(cid, port, clientaddrp, SOCK_STREAM);
}

int vsock_stream_listen(unsigned int cid, unsigned int port)
{
return vsock_listen(cid, port, SOCK_STREAM);
}

int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp)
{
Expand Down
3 changes: 3 additions & 0 deletions tools/testing/vsock/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ struct test_case {
void init_signals(void);
unsigned int parse_cid(const char *str);
int vsock_stream_connect(unsigned int cid, unsigned int port);
int vsock_bind_connect(unsigned int cid, unsigned int port,
unsigned int bind_port, int type);
int vsock_seqpacket_connect(unsigned int cid, unsigned int port);
int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp);
int vsock_stream_listen(unsigned int cid, unsigned int port);
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp);
void vsock_wait_remote_close(int fd);
Expand Down
50 changes: 50 additions & 0 deletions tools/testing/vsock/vsock_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,51 @@ static void test_stream_shutrd_server(const struct test_opts *opts)
close(fd);
}

static void test_double_bind_connect_server(const struct test_opts *opts)
{
int listen_fd, client_fd, i;
struct sockaddr_vm sa_client;
socklen_t socklen_client = sizeof(sa_client);

listen_fd = vsock_stream_listen(VMADDR_CID_ANY, 1234);

for (i = 0; i < 2; i++) {
control_writeln("LISTENING");

timeout_begin(TIMEOUT);
do {
client_fd = accept(listen_fd, (struct sockaddr *)&sa_client,
&socklen_client);
timeout_check("accept");
} while (client_fd < 0 && errno == EINTR);
timeout_end();

if (client_fd < 0) {
perror("accept");
exit(EXIT_FAILURE);
}

/* Waiting for remote peer to close connection */
vsock_wait_remote_close(client_fd);
}

close(listen_fd);
}

static void test_double_bind_connect_client(const struct test_opts *opts)
{
int i, client_fd;

for (i = 0; i < 2; i++) {
/* Wait until server is ready to accept a new connection */
control_expectln("LISTENING");

client_fd = vsock_bind_connect(opts->peer_cid, 1234, 4321, SOCK_STREAM);

close(client_fd);
}
}

static struct test_case test_cases[] = {
{
.name = "SOCK_STREAM connection reset",
Expand Down Expand Up @@ -1285,6 +1330,11 @@ static struct test_case test_cases[] = {
.run_client = test_stream_msgzcopy_empty_errq_client,
.run_server = test_stream_msgzcopy_empty_errq_server,
},
{
.name = "SOCK_STREAM double bind connect",
.run_client = test_double_bind_connect_client,
.run_server = test_double_bind_connect_server,
},
{},
};

Expand Down

0 comments on commit 97b9432

Please sign in to comment.