diff --git a/subsys/net/lib/sockets/sockets.c b/subsys/net/lib/sockets/sockets.c index 3a0705266d5f8..8a9b8ce8dfb26 100644 --- a/subsys/net/lib/sockets/sockets.c +++ b/subsys/net/lib/sockets/sockets.c @@ -525,10 +525,22 @@ int zsock_connect_ctx(struct net_context *ctx, const struct sockaddr *addr, cb = zsock_connected_cb; } - SET_ERRNO(net_context_recv(ctx, zsock_received_cb, K_NO_WAIT, - ctx->user_data)); - SET_ERRNO(net_context_connect(ctx, addr, addrlen, cb, timeout, - ctx->user_data)); + if (net_context_get_type(ctx) == SOCK_STREAM) { + /* For STREAM sockets net_context_recv() only installs + * recv callback w/o side effects, and it has to be done + * first to avoid race condition, when TCP stream data + * arrives right after connect. + */ + SET_ERRNO(net_context_recv(ctx, zsock_received_cb, + K_NO_WAIT, ctx->user_data)); + SET_ERRNO(net_context_connect(ctx, addr, addrlen, cb, + timeout, ctx->user_data)); + } else { + SET_ERRNO(net_context_connect(ctx, addr, addrlen, cb, + timeout, ctx->user_data)); + SET_ERRNO(net_context_recv(ctx, zsock_received_cb, + K_NO_WAIT, ctx->user_data)); + } } return 0; diff --git a/tests/net/socket/udp/src/main.c b/tests/net/socket/udp/src/main.c index 11200157c2a42..310f130d5d85f 100644 --- a/tests/net/socket/udp/src/main.c +++ b/tests/net/socket/udp/src/main.c @@ -1293,6 +1293,96 @@ ZTEST(net_socket_udp, test_23_v6_dgram_overflow) BUF_AND_SIZE(test_str_all_tx_bufs)); } +static void test_dgram_connected(int sock_c, int sock_s1, int sock_s2, + struct sockaddr *addr_c, socklen_t addrlen_c, + struct sockaddr *addr_s1, socklen_t addrlen_s1, + struct sockaddr *addr_s2, socklen_t addrlen_s2) +{ + uint8_t tx_buf = 0xab; + uint8_t rx_buf; + int rv; + + rv = bind(sock_c, addr_c, addrlen_c); + zassert_equal(rv, 0, "client bind failed"); + + rv = bind(sock_s1, addr_s1, addrlen_s1); + zassert_equal(rv, 0, "server bind failed"); + + rv = bind(sock_s2, addr_s2, addrlen_s2); + zassert_equal(rv, 0, "server bind failed"); + + rv = connect(sock_c, addr_s1, addrlen_s1); + zassert_equal(rv, 0, "connect failed"); + + /* Verify that a datagram can be received from the connected address */ + rv = sendto(sock_s1, &tx_buf, sizeof(tx_buf), 0, addr_c, addrlen_c); + zassert_equal(rv, sizeof(tx_buf), "send failed %d", errno); + + /* Give the packet a chance to go through the net stack */ + k_msleep(10); + + rx_buf = 0; + rv = recv(sock_c, &rx_buf, sizeof(rx_buf), MSG_DONTWAIT); + zassert_equal(rv, sizeof(rx_buf), "recv failed"); + zassert_equal(rx_buf, tx_buf, "wrong data"); + + /* Verify that a datagram is not received from other address */ + rv = sendto(sock_s2, &tx_buf, sizeof(tx_buf), 0, addr_c, addrlen_c); + zassert_equal(rv, sizeof(tx_buf), "send failed"); + + /* Give the packet a chance to go through the net stack */ + k_msleep(10); + + rv = recv(sock_c, &rx_buf, sizeof(rx_buf), MSG_DONTWAIT); + zassert_equal(rv, -1, "recv should've failed"); + zassert_equal(errno, EAGAIN, "incorrect errno"); + + rv = close(sock_c); + zassert_equal(rv, 0, "close failed"); + rv = close(sock_s1); + zassert_equal(rv, 0, "close failed"); + rv = close(sock_s2); + zassert_equal(rv, 0, "close failed"); +} + +ZTEST(net_socket_udp, test_24_v4_dgram_connected) +{ + int client_sock; + int server_sock_1; + int server_sock_2; + struct sockaddr_in client_addr; + struct sockaddr_in server_addr_1; + struct sockaddr_in server_addr_2; + + prepare_sock_udp_v4(MY_IPV4_ADDR, CLIENT_PORT, &client_sock, &client_addr); + prepare_sock_udp_v4(MY_IPV4_ADDR, SERVER_PORT, &server_sock_1, &server_addr_1); + prepare_sock_udp_v4(MY_IPV4_ADDR, SERVER_PORT + 1, &server_sock_2, &server_addr_2); + + test_dgram_connected(client_sock, server_sock_1, server_sock_2, + (struct sockaddr *)&client_addr, sizeof(client_addr), + (struct sockaddr *)&server_addr_1, sizeof(server_addr_1), + (struct sockaddr *)&server_addr_2, sizeof(server_addr_2)); +} + +ZTEST(net_socket_udp, test_25_v6_dgram_connected) +{ + int client_sock; + int server_sock_1; + int server_sock_2; + struct sockaddr_in6 client_addr; + struct sockaddr_in6 server_addr_1; + struct sockaddr_in6 server_addr_2; + + prepare_sock_udp_v6(MY_IPV6_ADDR, CLIENT_PORT, &client_sock, &client_addr); + prepare_sock_udp_v6(MY_IPV6_ADDR, SERVER_PORT, &server_sock_1, &server_addr_1); + prepare_sock_udp_v6(MY_IPV6_ADDR, SERVER_PORT + 1, &server_sock_2, &server_addr_2); + + test_dgram_connected(client_sock, server_sock_1, server_sock_2, + (struct sockaddr *)&client_addr, sizeof(client_addr), + (struct sockaddr *)&server_addr_1, sizeof(server_addr_1), + (struct sockaddr *)&server_addr_2, sizeof(server_addr_2)); +} + static void after(void *arg) { ARG_UNUSED(arg);