diff --git a/src/net.cpp b/src/net.cpp index 7166cd6fd655..255cb969c01d 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -423,13 +423,13 @@ bool CConnman::CheckIncomingNonce(uint64_t nonce) } /** Get the bind address for a socket as CAddress */ -static CAddress GetBindAddress(SOCKET sock) +static CAddress GetBindAddress(const Sock& sock) { CAddress addr_bind; struct sockaddr_storage sockaddr_bind; socklen_t sockaddr_bind_len = sizeof(sockaddr_bind); - if (sock != INVALID_SOCKET) { - if (!getsockname(sock, (struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) { + if (sock.Get() != INVALID_SOCKET) { + if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) { addr_bind.SetSockAddr((const struct sockaddr*)&sockaddr_bind); } else { LogPrint(BCLog::NET, "Warning: getsockname failed\n"); @@ -572,9 +572,19 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo NodeId id = GetNewNodeId(); uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); if (!addr_bind.IsValid()) { - addr_bind = GetBindAddress(sock->Get()); - } - CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type, /* inbound_onion */ false, std::move(i2p_transient_session)); + addr_bind = GetBindAddress(*sock); + } + CNode* pnode = new CNode(id, + nLocalServices, + std::move(sock), + addrConnect, + CalculateKeyedNetGroup(addrConnect), + nonce, + addr_bind, + pszDest ? pszDest : "", + conn_type, + /*inbound_onion=*/false, + std::move(i2p_transient_session)); pnode->AddRef(); statsClient.inc("peers.connect", 1.0f); @@ -589,15 +599,15 @@ void CNode::CloseSocketDisconnect(CConnman* connman) AssertLockHeld(connman->m_nodes_mutex); fDisconnect = true; - LOCK2(connman->cs_mapSocketToNode, cs_hSocket); - if (hSocket == INVALID_SOCKET) { + LOCK2(connman->cs_mapSocketToNode, m_sock_mutex); + if (!m_sock) { return; } fHasRecvData = false; fCanSendData = false; - connman->mapSocketToNode.erase(hSocket); + connman->mapSocketToNode.erase(m_sock->Get()); { LOCK(connman->cs_sendable_receivable_nodes); connman->mapReceivableNodes.erase(GetId()); @@ -611,12 +621,12 @@ void CNode::CloseSocketDisconnect(CConnman* connman) } } - if (connman->m_edge_trig_events && !connman->m_edge_trig_events->UnregisterEvents(hSocket)) { + if (connman->m_edge_trig_events && !connman->m_edge_trig_events->UnregisterEvents(m_sock->Get())) { LogPrint(BCLog::NET, "EdgeTriggeredEvents::UnregisterEvents() failed\n"); } LogPrint(BCLog::NET, "disconnecting peer=%d\n", id); - CloseSocket(hSocket); + m_sock.reset(); m_i2p_sam_session.reset(); statsClient.inc("peers.disconnect", 1.0f); @@ -909,10 +919,11 @@ size_t CConnman::SocketSendData(CNode& node) assert(data.size() > node.nSendOffset); int nBytes = 0; { - LOCK(node.cs_hSocket); - if (node.hSocket == INVALID_SOCKET) + LOCK(node.m_sock_mutex); + if (!node.m_sock) { break; - nBytes = send(node.hSocket, reinterpret_cast(data.data()) + node.nSendOffset, data.size() - node.nSendOffset, MSG_NOSIGNAL | MSG_DONTWAIT); + } + nBytes = node.m_sock->Send(reinterpret_cast(data.data()) + node.nSendOffset, data.size() - node.nSendOffset, MSG_NOSIGNAL | MSG_DONTWAIT); } if (nBytes > 0) { node.m_last_send = GetTime(); @@ -1220,9 +1231,10 @@ bool CConnman::AttemptToEvictConnection() void CConnman::AcceptConnection(const ListenSocket& hListenSocket, CMasternodeSync& mn_sync) { struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); - SOCKET hSocket = accept(hListenSocket.socket, (struct sockaddr*)&sockaddr, &len); + auto sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len); CAddress addr; - if (hSocket == INVALID_SOCKET) { + + if (!sock) { const int nErr = WSAGetLastError(); if (nErr != WSAEWOULDBLOCK) { LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr)); @@ -1236,15 +1248,15 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket, CMasternodeSy addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE}; } - const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(hSocket)), NODE_NONE}; + const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock)), NODE_NONE}; NetPermissionFlags permissionFlags = NetPermissionFlags::None; hListenSocket.AddSocketPermissionFlags(permissionFlags); - CreateNodeFromAcceptedSocket(hSocket, permissionFlags, addr_bind, addr, mn_sync); + CreateNodeFromAcceptedSocket(std::move(sock), permissionFlags, addr_bind, addr, mn_sync); } -void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, +void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, NetPermissionFlags permissionFlags, const CAddress& addr_bind, const CAddress& addr, @@ -1287,27 +1299,28 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, if (!fNetworkActive) { LogPrint(BCLog::NET_NETCONN, "%s: not accepting new connections\n", strDropped); - CloseSocket(hSocket); return; } - if (!IsSelectableSocket(hSocket)) + if (!IsSelectableSocket(sock->Get())) { LogPrintf("%s: non-selectable socket\n", strDropped); - CloseSocket(hSocket); return; } // According to the internet TCP_NODELAY is not carried into accepted sockets // on all platforms. Set it again here just to be sure. - SetSocketNoDelay(hSocket); + const int on{1}; + if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { + LogPrint(BCLog::NET, "connection from %s: unable to set TCP_NODELAY, continuing anyway\n", + addr.ToString()); + } // Don't accept connections from banned peers. bool banned = m_banman && m_banman->IsBanned(addr); if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && banned) { LogPrint(BCLog::NET, "%s (banned)\n", strDropped); - CloseSocket(hSocket); return; } @@ -1316,7 +1329,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && nInbound + 1 >= nMaxInbound && discouraged) { LogPrint(BCLog::NET, "connection from %s dropped (discouraged)\n", addr.ToString()); - CloseSocket(hSocket); return; } @@ -1330,7 +1342,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, if (!AttemptToEvictConnection()) { // No connection to evict, disconnect the new connection LogPrint(BCLog::NET, "failed to find an eviction candidate - connection dropped (full)\n"); - CloseSocket(hSocket); return; } nInbound--; @@ -1339,7 +1350,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, // don't accept incoming connections until blockchain is synced if (fMasternodeMode && !mn_sync.IsBlockchainSynced()) { LogPrint(BCLog::NET, "AcceptConnection -- blockchain is not synced yet, skipping inbound connection attempt\n"); - CloseSocket(hSocket); return; } @@ -1352,7 +1362,16 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, } const bool inbound_onion = std::find(m_onion_binds.begin(), m_onion_binds.end(), addr_bind) != m_onion_binds.end(); - CNode* pnode = new CNode(id, nodeServices, hSocket, addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion); + CNode* pnode = new CNode(id, + nodeServices, + std::move(sock), + addr, + CalculateKeyedNetGroup(addr), + nonce, + addr_bind, + /*addrNameIn=*/"", + ConnectionType::INBOUND, + inbound_onion); pnode->AddRef(); pnode->m_permissionFlags = permissionFlags; // If this flag is present, the user probably expect that RPC and QT report it as whitelisted (backward compatibility) @@ -1360,19 +1379,24 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, pnode->m_prefer_evict = discouraged; m_msgproc->InitializeNode(pnode); - if (fLogIPs) { - LogPrint(BCLog::NET_NETCONN, "connection from %s accepted, sock=%d, peer=%d\n", addr.ToString(), hSocket, pnode->GetId()); - } else { - LogPrint(BCLog::NET_NETCONN, "connection accepted, sock=%d, peer=%d\n", hSocket, pnode->GetId()); + { + LOCK(pnode->m_sock_mutex); + if (fLogIPs) { + LogPrint(BCLog::NET_NETCONN, "connection from %s accepted, sock=%d, peer=%d\n", addr.ToString(), pnode->m_sock->Get(), pnode->GetId()); + } else { + LogPrint(BCLog::NET_NETCONN, "connection accepted, sock=%d, peer=%d\n", pnode->m_sock->Get(), pnode->GetId()); + } } { LOCK(m_nodes_mutex); m_nodes.push_back(pnode); - WITH_LOCK(cs_mapSocketToNode, mapSocketToNode.emplace(hSocket, pnode)); + } + { + LOCK2(cs_mapSocketToNode, pnode->m_sock_mutex); + mapSocketToNode.emplace(pnode->m_sock->Get(), pnode); if (m_edge_trig_events) { - LOCK(pnode->cs_hSocket); - if (!m_edge_trig_events->RegisterEvents(pnode->hSocket)) { + if (!m_edge_trig_events->RegisterEvents(pnode->m_sock->Get())) { LogPrint(BCLog::NET, "EdgeTriggeredEvents::RegisterEvents() failed\n"); } } @@ -1459,10 +1483,10 @@ void CConnman::DisconnectNodes() if (GetTimeMillis() < pnode->nDisconnectLingerTime) { // everything flushed to the kernel? if (!pnode->fSocketShutdown && pnode->nSendMsgSize == 0) { - LOCK(pnode->cs_hSocket); - if (pnode->hSocket != INVALID_SOCKET) { + LOCK(pnode->m_sock_mutex); + if (pnode->m_sock) { // Give the other side a chance to detect the disconnect as early as possible (recv() will return 0) - ::shutdown(pnode->hSocket, SD_SEND); + ::shutdown(pnode->m_sock->Get(), SD_SEND); } pnode->fSocketShutdown = true; } @@ -1656,7 +1680,7 @@ bool CConnman::GenerateSelectSet(const std::vector& nodes, std::set& error_set) { for (const ListenSocket& hListenSocket : vhListenSocket) { - recv_set.insert(hListenSocket.socket); + recv_set.insert(hListenSocket.sock->Get()); } for (CNode* pnode : nodes) @@ -1664,16 +1688,17 @@ bool CConnman::GenerateSelectSet(const std::vector& nodes, bool select_recv = !pnode->fHasRecvData; bool select_send = !pnode->fCanSendData; - LOCK(pnode->cs_hSocket); - if (pnode->hSocket == INVALID_SOCKET) + LOCK(pnode->m_sock_mutex); + if (!pnode->m_sock) { continue; + } - error_set.insert(pnode->hSocket); + error_set.insert(pnode->m_sock->Get()); if (select_send) { - send_set.insert(pnode->hSocket); + send_set.insert(pnode->m_sock->Get()); } if (select_recv) { - recv_set.insert(pnode->hSocket); + recv_set.insert(pnode->m_sock->Get()); } } @@ -2128,7 +2153,7 @@ void CConnman::SocketHandlerListening(const std::set& recv_set, CMastern if (interruptNet) { return; } - if (recv_set.count(listen_socket.socket) > 0) { + if (recv_set.count(listen_socket.sock->Get()) > 0) { AcceptConnection(listen_socket, mn_sync); } } @@ -2140,10 +2165,10 @@ size_t CConnman::SocketRecvData(CNode *pnode) uint8_t pchBuf[0x10000]; int nBytes = 0; { - LOCK(pnode->cs_hSocket); - if (pnode->hSocket == INVALID_SOCKET) + LOCK(pnode->m_sock_mutex); + if (!pnode->m_sock) return 0; - nBytes = recv(pnode->hSocket, (char*)pchBuf, sizeof(pchBuf), MSG_DONTWAIT); + nBytes = recv(pnode->m_sock->Get(), (char*)pchBuf, sizeof(pchBuf), MSG_DONTWAIT); if (nBytes < (int)sizeof(pchBuf)) { pnode->fHasRecvData = false; } @@ -3050,8 +3075,8 @@ void CConnman::OpenNetworkConnection(const CAddress& addrConnect, bool fCountFai } { - LOCK(pnode->cs_hSocket); - LogPrint(BCLog::NET_NETCONN, "CConnman::%s -- successfully connected to %s, sock=%d, peer=%d\n", __func__, getIpStr(), pnode->hSocket, pnode->GetId()); + LOCK(pnode->m_sock_mutex); + LogPrint(BCLog::NET_NETCONN, "CConnman::%s -- successfully connected to %s, sock=%d, peer=%d\n", __func__, getIpStr(), pnode->m_sock->Get(), pnode->GetId()); } if (grantOutbound) @@ -3063,17 +3088,19 @@ void CConnman::OpenNetworkConnection(const CAddress& addrConnect, bool fCountFai pnode->m_masternode_probe_connection = true; { - LOCK2(cs_mapSocketToNode, pnode->cs_hSocket); - mapSocketToNode.emplace(pnode->hSocket, pnode); + LOCK2(cs_mapSocketToNode, pnode->m_sock_mutex); + mapSocketToNode.emplace(pnode->m_sock->Get(), pnode); } m_msgproc->InitializeNode(pnode); { LOCK(m_nodes_mutex); m_nodes.push_back(pnode); + } + { if (m_edge_trig_events) { - LOCK(pnode->cs_hSocket); - if (!m_edge_trig_events->RegisterEvents(pnode->hSocket)) { + LOCK(pnode->m_sock_mutex); + if (!m_edge_trig_events->RegisterEvents(pnode->m_sock->Get())) { LogPrint(BCLog::NET, "EdgeTriggeredEvents::RegisterEvents() failed\n"); } } @@ -3168,7 +3195,7 @@ void CConnman::ThreadI2PAcceptIncoming(CMasternodeSync& mn_sync) continue; } - CreateNodeFromAcceptedSocket(conn.sock->Release(), NetPermissionFlags::None, + CreateNodeFromAcceptedSocket(std::move(conn.sock), NetPermissionFlags::None, CAddress{conn.me, NODE_NONE}, CAddress{conn.peer, NODE_NONE}, mn_sync); } } @@ -3196,22 +3223,30 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, // Allow binding if the port is still in TIME_WAIT state after // the program was closed and restarted. - setsockopt(sock->Get(), SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)); + if (sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) { + strError = strprintf(Untranslated("Error setting SO_REUSEADDR on socket: %s, continuing anyway"), NetworkErrorString(WSAGetLastError())); + LogPrintf("%s\n", strError.original); + } // some systems don't have IPV6_V6ONLY but are always v6only; others do have the option // and enable it by default or not. Try to enable it, if possible. if (addrBind.IsIPv6()) { #ifdef IPV6_V6ONLY - setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)); + if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) { + strError = strprintf(Untranslated("Error setting IPV6_V6ONLY on socket: %s, continuing anyway"), NetworkErrorString(WSAGetLastError())); + LogPrintf("%s\n", strError.original); + } #endif #ifdef WIN32 int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED; - setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)); + if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)) == SOCKET_ERROR) { + strError = strprintf(Untranslated("Error setting IPV6_PROTECTION_LEVEL on socket: %s, continuing anyway"), NetworkErrorString(WSAGetLastError())); + LogPrintf("%s\n", strError.original); + } #endif } - if (::bind(sock->Get(), (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR) - { + if (sock->Bind(reinterpret_cast(&sockaddr), len) == SOCKET_ERROR) { int nErr = WSAGetLastError(); if (nErr == WSAEADDRINUSE) strError = strprintf(_("Unable to bind to %s on this computer. %s is probably already running."), addrBind.ToString(), PACKAGE_NAME); @@ -3223,7 +3258,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, LogPrintf("Bound to %s\n", addrBind.ToString()); // Listen for incoming connections - if (listen(sock->Get(), SOMAXCONN) == SOCKET_ERROR) + if (sock->Listen(SOMAXCONN) == SOCKET_ERROR) { strError = strprintf(_("Error: Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintf("%s\n", strError.original); @@ -3235,7 +3270,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, return false; } - vhListenSocket.push_back(ListenSocket(sock->Release(), permissions)); + vhListenSocket.emplace_back(std::move(sock), permissions); return true; } @@ -3582,12 +3617,10 @@ void CConnman::StopNodes() pnode->CloseSocketDisconnect(this); } for (ListenSocket& hListenSocket : vhListenSocket) { - if (hListenSocket.socket != INVALID_SOCKET) { - if (m_edge_trig_events && !m_edge_trig_events->RemoveSocket(hListenSocket.socket)) { + if (hListenSocket.sock) { + if (m_edge_trig_events && !m_edge_trig_events->RemoveSocket(hListenSocket.sock->Get())) { LogPrintf("EdgeTriggeredEvents::RemoveSocket() failed\n"); } - if (!CloseSocket(hListenSocket.socket)) - LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError())); } } @@ -4051,8 +4084,9 @@ ServiceFlags CConnman::GetLocalServices() const unsigned int CConnman::GetReceiveFloodSize() const { return nReceiveFloodSize; } -CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion, std::unique_ptr&& i2p_sam_session) - : nTimeConnected{GetTimeSeconds()}, +CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, std::shared_ptr sock, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion, std::unique_ptr&& i2p_sam_session) + : m_sock{sock}, + nTimeConnected{GetTimeSeconds()}, addr{addrIn}, addrBind{addrBindIn}, m_addr_name{addrNameIn.empty() ? addr.ToStringIPPort() : addrNameIn}, @@ -4065,7 +4099,6 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const m_i2p_sam_session{std::move(i2p_sam_session)} { if (inbound_onion) assert(conn_type_in == ConnectionType::INBOUND); - hSocket = hSocketIn; for (const std::string &msg : getAllNetMessageTypes()) mapRecvBytesPerMsgCmd[msg] = 0; @@ -4081,11 +4114,6 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const m_serializer = std::make_unique(V1TransportSerializer()); } -CNode::~CNode() -{ - CloseSocket(hSocket); -} - bool CConnman::NodeFullyConnected(const CNode* pnode) { return pnode && pnode->fSuccessfullyConnected && !pnode->fDisconnect; diff --git a/src/net.h b/src/net.h index 38668592a616..22edb30c14c2 100644 --- a/src/net.h +++ b/src/net.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -439,7 +440,17 @@ class CNode NetPermissionFlags m_permissionFlags{ NetPermissionFlags::None }; std::atomic nServices{NODE_NONE}; - SOCKET hSocket GUARDED_BY(cs_hSocket); + + /** + * Socket used for communication with the node. + * May not own a Sock object (after `CloseSocketDisconnect()` or during tests). + * `shared_ptr` (instead of `unique_ptr`) is used to avoid premature close of + * the underlying file descriptor by one thread while another thread is + * poll(2)-ing it for activity. + * @see https://github.com/bitcoin/bitcoin/issues/21744 for details. + */ + std::shared_ptr m_sock GUARDED_BY(m_sock_mutex); + /** Total size of all vSendMsg entries */ size_t nSendSize GUARDED_BY(cs_vSend){0}; /** Offset inside the first vSendMsg already sent */ @@ -448,7 +459,7 @@ class CNode std::list> vSendMsg GUARDED_BY(cs_vSend); std::atomic nSendMsgSize{0}; Mutex cs_vSend; - Mutex cs_hSocket; + Mutex m_sock_mutex; Mutex cs_vRecv; RecursiveMutex cs_vProcessMsg; @@ -629,8 +640,7 @@ class CNode bool IsBlockRelayOnly() const; - CNode(NodeId id, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const CAddress &addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress &addrBindIn, const std::string &addrNameIn, ConnectionType conn_type_in, bool inbound_onion, std::unique_ptr&& i2p_sam_session = nullptr); - ~CNode(); + CNode(NodeId id, ServiceFlags nLocalServicesIn, std::shared_ptr sock, const CAddress &addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress &addrBindIn, const std::string &addrNameIn, ConnectionType conn_type_in, bool inbound_onion, std::unique_ptr&& i2p_sam_session = nullptr); CNode(const CNode&) = delete; CNode& operator=(const CNode&) = delete; @@ -684,7 +694,7 @@ class CNode nRefCount--; } - void CloseSocketDisconnect(CConnman* connman) EXCLUSIVE_LOCKS_REQUIRED(!cs_hSocket); + void CloseSocketDisconnect(CConnman* connman) EXCLUSIVE_LOCKS_REQUIRED(!m_sock_mutex); void CopyStats(CNodeStats& stats) EXCLUSIVE_LOCKS_REQUIRED(!m_subver_mutex, !m_addr_local_mutex, !cs_vSend, !cs_vRecv); @@ -794,7 +804,7 @@ class CNode * closed. * Otherwise this unique_ptr is empty. */ - std::unique_ptr m_i2p_sam_session GUARDED_BY(cs_hSocket); + std::unique_ptr m_i2p_sam_session GUARDED_BY(m_sock_mutex); }; /** @@ -1221,9 +1231,13 @@ friend class CNode; private: struct ListenSocket { public: - SOCKET socket; + std::shared_ptr sock; inline void AddSocketPermissionFlags(NetPermissionFlags& flags) const { NetPermissions::AddFlag(flags, m_permissions); } - ListenSocket(SOCKET socket_, NetPermissionFlags permissions_) : socket(socket_), m_permissions(permissions_) {} + ListenSocket(std::shared_ptr sock_, NetPermissionFlags permissions_) + : sock{sock_}, m_permissions{permissions_} + { + } + private: NetPermissionFlags m_permissions; }; @@ -1251,12 +1265,12 @@ friend class CNode; /** * Create a `CNode` object from a socket that has just been accepted and add the node to * the `m_nodes` member. - * @param[in] hSocket Connected socket to communicate with the peer. + * @param[in] sock Connected socket to communicate with the peer. * @param[in] permissionFlags The peer's permissions. * @param[in] addr_bind The address and port at our side of the connection. * @param[in] addr The address and port at the peer's side of the connection. */ - void CreateNodeFromAcceptedSocket(SOCKET hSocket, + void CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, NetPermissionFlags permissionFlags, const CAddress& addr_bind, const CAddress& addr, diff --git a/src/netbase.cpp b/src/netbase.cpp index 544f575da03d..97dbe905dcff 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -498,10 +498,11 @@ std::unique_ptr CreateSockTCP(const CService& address_family) return nullptr; } + auto sock = std::make_unique(hSocket); + // Ensure that waiting for I/O on this socket won't result in undefined // behavior. - if (!IsSelectableSocket(hSocket)) { - CloseSocket(hSocket); + if (!IsSelectableSocket(sock->Get())) { LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n"); return nullptr; } @@ -510,19 +511,24 @@ std::unique_ptr CreateSockTCP(const CService& address_family) int set = 1; // Set the no-sigpipe option on the socket for BSD systems, other UNIXes // should use the MSG_NOSIGNAL flag for every send. - setsockopt(hSocket, SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)); + if (sock->SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)) == SOCKET_ERROR) { + LogPrintf("Error setting SO_NOSIGPIPE on socket: %s, continuing anyway\n", + NetworkErrorString(WSAGetLastError())); + } #endif // Set the no-delay option (disable Nagle's algorithm) on the TCP socket. - SetSocketNoDelay(hSocket); + const int on{1}; + if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { + LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n"); + } // Set the non-blocking option on the socket. - if (!SetSocketNonBlocking(hSocket)) { - CloseSocket(hSocket); + if (!SetSocketNonBlocking(sock->Get())) { LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError())); return nullptr; } - return std::make_unique(hSocket); + return sock; } std::function(const CService&)> CreateSock = CreateSockTCP; @@ -729,13 +735,6 @@ bool SetSocketNonBlocking(const SOCKET& hSocket) return true; } -bool SetSocketNoDelay(const SOCKET& hSocket) -{ - int set = 1; - int rc = setsockopt(hSocket, IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int)); - return rc == 0; -} - void InterruptSocks5(bool interrupt) { interruptSocks5Recv = interrupt; diff --git a/src/netbase.h b/src/netbase.h index 3d98f9cb7328..ef0eb85eae8e 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -227,8 +227,6 @@ bool ConnectThroughProxy(const Proxy& proxy, const std::string& strDest, uint16_ /** Enable non-blocking mode for a socket */ bool SetSocketNonBlocking(const SOCKET& hSocket); -/** Set the TCP_NODELAY flag on a socket */ -bool SetSocketNoDelay(const SOCKET& hSocket); void InterruptSocks5(bool interrupt); /** diff --git a/src/test/denialofservice_tests.cpp b/src/test/denialofservice_tests.cpp index c45d20be377a..e3f9d7103951 100644 --- a/src/test/denialofservice_tests.cpp +++ b/src/test/denialofservice_tests.cpp @@ -74,7 +74,16 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) // Mock an outbound peer CAddress addr1(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode1(id++, ServiceFlags(NODE_NETWORK), INVALID_SOCKET, addr1, /* nKeyedNetGroupIn */ 0, /* nLocalHostNonceIn */ 0, CAddress(), /* pszDest */ "", ConnectionType::OUTBOUND_FULL_RELAY, /* inbound_onion */ false); + CNode dummyNode1{id++, + ServiceFlags(NODE_NETWORK), + /*sock=*/nullptr, + addr1, + /*nKeyedNetGroupIn=*/0, + /*nLocalHostNonceIn=*/0, + CAddress(), + /*addrNameIn=*/"", + ConnectionType::OUTBOUND_FULL_RELAY, + /*inbound_onion=*/false}; dummyNode1.SetCommonVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode1); @@ -124,7 +133,16 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) static void AddRandomOutboundPeer(std::vector& vNodes, PeerManager& peerLogic, ConnmanTestMsg& connman) { CAddress addr(ip(g_insecure_rand_ctx.randbits(32)), NODE_NONE); - vNodes.emplace_back(new CNode(id++, ServiceFlags(NODE_NETWORK), INVALID_SOCKET, addr, /* nKeyedNetGroupIn */ 0, /* nLocalHostNonceIn */ 0, CAddress(), /* pszDest */ "", ConnectionType::OUTBOUND_FULL_RELAY, /* inbound_onion */ false)); + vNodes.emplace_back(new CNode{id++, + ServiceFlags(NODE_NETWORK), + /*sock=*/nullptr, + addr, + /*nKeyedNetGroupIn=*/0, + /*nLocalHostNonceIn=*/0, + CAddress(), + /*addrNameIn=*/"", + ConnectionType::OUTBOUND_FULL_RELAY, + /*inbound_onion=*/false}); CNode &node = *vNodes.back(); node.SetCommonVersion(PROTOCOL_VERSION); @@ -220,7 +238,16 @@ BOOST_AUTO_TEST_CASE(peer_discouragement) banman->ClearBanned(); CAddress addr1(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode1(id++, NODE_NETWORK, INVALID_SOCKET, addr1, /* nKeyedNetGroupIn */ 0, /* nLocalHostNonceIn */ 0, CAddress(), /* pszDest */ "", ConnectionType::INBOUND, /* inbound_onion */ false); + CNode dummyNode1{id++, + NODE_NETWORK, + /*sock=*/nullptr, + addr1, + /*nKeyedNetGroupIn=*/0, + /*nLocalHostNonceIn=*/0, + CAddress(), + /*addrNameIn=*/"", + ConnectionType::INBOUND, + /*inbound_onion=*/false}; dummyNode1.SetCommonVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode1); dummyNode1.fSuccessfullyConnected = true; @@ -233,7 +260,16 @@ BOOST_AUTO_TEST_CASE(peer_discouragement) BOOST_CHECK(!banman->IsDiscouraged(ip(0xa0b0c001|0x0000ff00))); // Different IP, not discouraged CAddress addr2(ip(0xa0b0c002), NODE_NONE); - CNode dummyNode2(id++, NODE_NETWORK, INVALID_SOCKET, addr2, /* nKeyedNetGroupIn */ 1, /* nLocalHostNonceIn */ 1, CAddress(), /* pszDest */ "", ConnectionType::INBOUND, /* inbound_onion */ false); + CNode dummyNode2{id++, + NODE_NETWORK, + /*sock=*/nullptr, + addr2, + /*nKeyedNetGroupIn=*/1, + /*nLocalHostNonceIn=*/1, + CAddress(), + /*pszDest=*/"", + ConnectionType::INBOUND, + /*inbound_onion=*/false}; dummyNode2.SetCommonVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode2); dummyNode2.fSuccessfullyConnected = true; @@ -271,7 +307,16 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) SetMockTime(nStartTime); // Overrides future calls to GetTime() CAddress addr(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode(id++, NODE_NETWORK, INVALID_SOCKET, addr, /* nKeyedNetGroupIn */ 4, /* nLocalHostNonceIn */ 4, CAddress(), /* pszDest */ "", ConnectionType::INBOUND, /* inbound_onion */ false); + CNode dummyNode{id++, + NODE_NETWORK, + /*sock=*/nullptr, + addr, + /*nKeyedNetGroupIn=*/4, + /*nLocalHostNonceIn=*/4, + CAddress(), + /*addrNameIn=*/"", + ConnectionType::INBOUND, + /*inbound_onion=*/false}; dummyNode.SetCommonVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode); dummyNode.fSuccessfullyConnected = true; diff --git a/src/test/fuzz/util.cpp b/src/test/fuzz/util.cpp index 36b6233cab19..cb88167ee1c6 100644 --- a/src/test/fuzz/util.cpp +++ b/src/test/fuzz/util.cpp @@ -10,6 +10,8 @@ #include #include +#include + FuzzedSock::FuzzedSock(FuzzedDataProvider& fuzzed_data_provider) : m_fuzzed_data_provider{fuzzed_data_provider} { @@ -155,6 +157,59 @@ int FuzzedSock::Connect(const sockaddr*, socklen_t) const return 0; } +int FuzzedSock::Bind(const sockaddr*, socklen_t) const +{ + // Have a permanent error at bind_errnos[0] because when the fuzzed data is exhausted + // SetFuzzedErrNo() will always set the global errno to bind_errnos[0]. We want to + // avoid this method returning -1 and setting errno to a temporary error (like EAGAIN) + // repeatedly because proper code should retry on temporary errors, leading to an + // infinite loop. + constexpr std::array bind_errnos{ + EACCES, + EADDRINUSE, + EADDRNOTAVAIL, + EAGAIN, + }; + if (m_fuzzed_data_provider.ConsumeBool()) { + SetFuzzedErrNo(m_fuzzed_data_provider, bind_errnos); + return -1; + } + return 0; +} + +int FuzzedSock::Listen(int) const +{ + // Have a permanent error at listen_errnos[0] because when the fuzzed data is exhausted + // SetFuzzedErrNo() will always set the global errno to listen_errnos[0]. We want to + // avoid this method returning -1 and setting errno to a temporary error (like EAGAIN) + // repeatedly because proper code should retry on temporary errors, leading to an + // infinite loop. + constexpr std::array listen_errnos{ + EADDRINUSE, + EINVAL, + EOPNOTSUPP, + }; + if (m_fuzzed_data_provider.ConsumeBool()) { + SetFuzzedErrNo(m_fuzzed_data_provider, listen_errnos); + return -1; + } + return 0; +} + +std::unique_ptr FuzzedSock::Accept(sockaddr* addr, socklen_t* addr_len) const +{ + constexpr std::array accept_errnos{ + ECONNABORTED, + EINTR, + ENOMEM, + }; + if (m_fuzzed_data_provider.ConsumeBool()) { + SetFuzzedErrNo(m_fuzzed_data_provider, accept_errnos); + return std::unique_ptr(); + } + return std::make_unique(m_fuzzed_data_provider); +} + int FuzzedSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const { constexpr std::array getsockopt_errnos{ @@ -174,6 +229,33 @@ int FuzzedSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* op return 0; } +int FuzzedSock::SetSockOpt(int, int, const void*, socklen_t) const +{ + constexpr std::array setsockopt_errnos{ + ENOMEM, + ENOBUFS, + }; + if (m_fuzzed_data_provider.ConsumeBool()) { + SetFuzzedErrNo(m_fuzzed_data_provider, setsockopt_errnos); + return -1; + } + return 0; +} + +int FuzzedSock::GetSockName(sockaddr* name, socklen_t* name_len) const +{ + constexpr std::array getsockname_errnos{ + ECONNRESET, + ENOBUFS, + }; + if (m_fuzzed_data_provider.ConsumeBool()) { + SetFuzzedErrNo(m_fuzzed_data_provider, getsockname_errnos); + return -1; + } + *name_len = m_fuzzed_data_provider.ConsumeData(name, *name_len); + return 0; +} + bool FuzzedSock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const { constexpr std::array wait_errnos{ diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index 558ef976a9f8..f3a18a1efb03 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -40,6 +40,54 @@ class PeerManager; +class FuzzedSock : public Sock +{ + FuzzedDataProvider& m_fuzzed_data_provider; + + /** + * Data to return when `MSG_PEEK` is used as a `Recv()` flag. + * If `MSG_PEEK` is used, then our `Recv()` returns some random data as usual, but on the next + * `Recv()` call we must return the same data, thus we remember it here. + */ + mutable std::optional m_peek_data; + +public: + explicit FuzzedSock(FuzzedDataProvider& fuzzed_data_provider); + + ~FuzzedSock() override; + + FuzzedSock& operator=(Sock&& other) override; + + void Reset() override; + + ssize_t Send(const void* data, size_t len, int flags) const override; + + ssize_t Recv(void* buf, size_t len, int flags) const override; + + int Connect(const sockaddr*, socklen_t) const override; + + int Bind(const sockaddr*, socklen_t) const override; + + int Listen(int backlog) const override; + + std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const override; + + int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override; + + int SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const override; + + int GetSockName(sockaddr* name, socklen_t* name_len) const override; + + bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override; + + bool IsConnected(std::string& errmsg) const override; +}; + +[[nodiscard]] inline FuzzedSock ConsumeSock(FuzzedDataProvider& fuzzed_data_provider) +{ + return FuzzedSock{fuzzed_data_provider}; +} + template void CallOneOf(FuzzedDataProvider& fuzzed_data_provider, Callables... callables) { @@ -313,7 +361,7 @@ auto ConsumeNode(FuzzedDataProvider& fuzzed_data_provider, const std::optional(0, std::numeric_limits::max())); const ServiceFlags local_services = ConsumeWeakEnum(fuzzed_data_provider, ALL_SERVICE_FLAGS); - const SOCKET socket = INVALID_SOCKET; + const auto sock = std::make_shared(fuzzed_data_provider); const CAddress address = ConsumeAddress(fuzzed_data_provider); const uint64_t keyed_net_group = fuzzed_data_provider.ConsumeIntegral(); const uint64_t local_host_nonce = fuzzed_data_provider.ConsumeIntegral(); @@ -323,9 +371,27 @@ auto ConsumeNode(FuzzedDataProvider& fuzzed_data_provider, const std::optional(node_id, local_services, socket, address, keyed_net_group, local_host_nonce, addr_bind, addr_name, conn_type, inbound_onion); + return std::make_unique(node_id, + local_services, + sock, + address, + keyed_net_group, + local_host_nonce, + addr_bind, + addr_name, + conn_type, + inbound_onion); } else { - return CNode{node_id, local_services, socket, address, keyed_net_group, local_host_nonce, addr_bind, addr_name, conn_type, inbound_onion}; + return CNode{node_id, + local_services, + sock, + address, + keyed_net_group, + local_host_nonce, + addr_bind, + addr_name, + conn_type, + inbound_onion}; } } inline std::unique_ptr ConsumeNodeAsUniquePtr(FuzzedDataProvider& fdp, const std::optional& node_id_in = std::nullopt) { return ConsumeNode(fdp, node_id_in); } @@ -534,42 +600,4 @@ void ReadFromStream(FuzzedDataProvider& fuzzed_data_provider, Stream& stream) no } } -class FuzzedSock : public Sock -{ - FuzzedDataProvider& m_fuzzed_data_provider; - - /** - * Data to return when `MSG_PEEK` is used as a `Recv()` flag. - * If `MSG_PEEK` is used, then our `Recv()` returns some random data as usual, but on the next - * `Recv()` call we must return the same data, thus we remember it here. - */ - mutable std::optional m_peek_data; - -public: - explicit FuzzedSock(FuzzedDataProvider& fuzzed_data_provider); - - ~FuzzedSock() override; - - FuzzedSock& operator=(Sock&& other) override; - - void Reset() override; - - ssize_t Send(const void* data, size_t len, int flags) const override; - - ssize_t Recv(void* buf, size_t len, int flags) const override; - - int Connect(const sockaddr*, socklen_t) const override; - - int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override; - - bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override; - - bool IsConnected(std::string& errmsg) const override; -}; - -[[nodiscard]] inline FuzzedSock ConsumeSock(FuzzedDataProvider& fuzzed_data_provider) -{ - return FuzzedSock{fuzzed_data_provider}; -} - #endif // BITCOIN_TEST_FUZZ_UTIL_H diff --git a/src/test/net_tests.cpp b/src/test/net_tests.cpp index e7fabd9e8209..06f37a146892 100644 --- a/src/test/net_tests.cpp +++ b/src/test/net_tests.cpp @@ -44,7 +44,6 @@ BOOST_AUTO_TEST_CASE(cnode_listen_port) BOOST_AUTO_TEST_CASE(cnode_simple_test) { - SOCKET hSocket = INVALID_SOCKET; NodeId id = 0; in_addr ipv4Addr; @@ -53,12 +52,16 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) CAddress addr = CAddress(CService(ipv4Addr, 7777), NODE_NETWORK); std::string pszDest; - std::unique_ptr pnode1 = std::make_unique( - id++, NODE_NETWORK, hSocket, addr, - /* nKeyedNetGroupIn = */ 0, - /* nLocalHostNonceIn = */ 0, - CAddress(), pszDest, ConnectionType::OUTBOUND_FULL_RELAY, - /* inbound_onion = */ false); + std::unique_ptr pnode1 = std::make_unique(id++, + NODE_NETWORK, + /*sock=*/nullptr, + addr, + /*nKeyedNetGroupIn=*/0, + /*nLocalHostNonceIn=*/0, + CAddress(), + pszDest, + ConnectionType::OUTBOUND_FULL_RELAY, + /*inbound_onion=*/false); BOOST_CHECK(pnode1->IsFullOutboundConn() == true); BOOST_CHECK(pnode1->IsManualConn() == false); BOOST_CHECK(pnode1->IsBlockOnlyConn() == false); @@ -68,12 +71,16 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) BOOST_CHECK(pnode1->m_inbound_onion == false); BOOST_CHECK_EQUAL(pnode1->ConnectedThroughNetwork(), Network::NET_IPV4); - std::unique_ptr pnode2 = std::make_unique( - id++, NODE_NETWORK, hSocket, addr, - /* nKeyedNetGroupIn = */ 1, - /* nLocalHostNonceIn = */ 1, - CAddress(), pszDest, ConnectionType::INBOUND, - /* inbound_onion = */ false); + std::unique_ptr pnode2 = std::make_unique(id++, + NODE_NETWORK, + /*sock=*/nullptr, + addr, + /*nKeyedNetGroupIn=*/1, + /*nLocalHostNonceIn=*/1, + CAddress(), + pszDest, + ConnectionType::INBOUND, + /*inbound_onion=*/false); BOOST_CHECK(pnode2->IsFullOutboundConn() == false); BOOST_CHECK(pnode2->IsManualConn() == false); BOOST_CHECK(pnode2->IsBlockOnlyConn() == false); @@ -83,12 +90,16 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) BOOST_CHECK(pnode2->m_inbound_onion == false); BOOST_CHECK_EQUAL(pnode2->ConnectedThroughNetwork(), Network::NET_IPV4); - std::unique_ptr pnode3 = std::make_unique( - id++, NODE_NETWORK, hSocket, addr, - /* nKeyedNetGroupIn = */ 0, - /* nLocalHostNonceIn = */ 0, - CAddress(), pszDest, ConnectionType::OUTBOUND_FULL_RELAY, - /* inbound_onion = */ false); + std::unique_ptr pnode3 = std::make_unique(id++, + NODE_NETWORK, + /*sock=*/nullptr, + addr, + /*nKeyedNetGroupIn=*/0, + /*nLocalHostNonceIn=*/0, + CAddress(), + pszDest, + ConnectionType::OUTBOUND_FULL_RELAY, + /*inbound_onion=*/false); BOOST_CHECK(pnode3->IsFullOutboundConn() == true); BOOST_CHECK(pnode3->IsManualConn() == false); BOOST_CHECK(pnode3->IsBlockOnlyConn() == false); @@ -98,12 +109,16 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) BOOST_CHECK(pnode3->m_inbound_onion == false); BOOST_CHECK_EQUAL(pnode3->ConnectedThroughNetwork(), Network::NET_IPV4); - std::unique_ptr pnode4 = std::make_unique( - id++, NODE_NETWORK, hSocket, addr, - /* nKeyedNetGroupIn = */ 1, - /* nLocalHostNonceIn = */ 1, - CAddress(), pszDest, ConnectionType::INBOUND, - /* inbound_onion = */ true); + std::unique_ptr pnode4 = std::make_unique(id++, + NODE_NETWORK, + /*sock=*/nullptr, + addr, + /*nKeyedNetGroupIn=*/1, + /*nLocalHostNonceIn=*/1, + CAddress(), + pszDest, + ConnectionType::INBOUND, + /*inbound_onion=*/true); BOOST_CHECK(pnode4->IsFullOutboundConn() == false); BOOST_CHECK(pnode4->IsManualConn() == false); BOOST_CHECK(pnode4->IsBlockOnlyConn() == false); @@ -608,7 +623,16 @@ BOOST_AUTO_TEST_CASE(ipv4_peer_with_ipv6_addrMe_test) in_addr ipv4AddrPeer; ipv4AddrPeer.s_addr = 0xa0b0c001; CAddress addr = CAddress(CService(ipv4AddrPeer, 7777), NODE_NETWORK); - std::unique_ptr pnode = std::make_unique(0, NODE_NETWORK, INVALID_SOCKET, addr, /* nKeyedNetGroupIn */ 0, /* nLocalHostNonceIn */ 0, CAddress{}, /* pszDest */ std::string{}, ConnectionType::OUTBOUND_FULL_RELAY, /* inbound_onion */ false); + std::unique_ptr pnode = std::make_unique(/*id=*/0, + NODE_NETWORK, + /*sock=*/nullptr, + addr, + /*nKeyedNetGroupIn=*/0, + /*nLocalHostNonceIn=*/0, + CAddress{}, + /*pszDest=*/std::string{}, + ConnectionType::OUTBOUND_FULL_RELAY, + /*inbound_onion=*/false); pnode->fSuccessfullyConnected.store(true); // the peer claims to be reaching us via IPv6 diff --git a/src/test/util/net.h b/src/test/util/net.h index 79635be91004..cfb465c5fd67 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -13,6 +13,7 @@ #include #include #include +#include #include struct ConnmanTestMsg : public CConnman { @@ -126,12 +127,41 @@ class StaticContentsSock : public Sock int Connect(const sockaddr*, socklen_t) const override { return 0; } + int Bind(const sockaddr*, socklen_t) const override { return 0; } + + int Listen(int) const override { return 0; } + + std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const override + { + if (addr != nullptr) { + // Pretend all connections come from 5.5.5.5:6789 + memset(addr, 0x00, *addr_len); + const socklen_t write_len = static_cast(sizeof(sockaddr_in)); + if (*addr_len >= write_len) { + *addr_len = write_len; + sockaddr_in* addr_in = reinterpret_cast(addr); + addr_in->sin_family = AF_INET; + memset(&addr_in->sin_addr, 0x05, sizeof(addr_in->sin_addr)); + addr_in->sin_port = htons(6789); + } + } + return std::make_unique(""); + }; + int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override { std::memset(opt_val, 0x0, *opt_len); return 0; } + int SetSockOpt(int, int, const void*, socklen_t) const override { return 0; } + + int GetSockName(sockaddr* name, socklen_t* name_len) const override + { + std::memset(name, 0x0, *name_len); + return 0; + } + bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override diff --git a/src/util/sock.cpp b/src/util/sock.cpp index 1a4d67a65eee..1d12669652d6 100644 --- a/src/util/sock.cpp +++ b/src/util/sock.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -73,11 +74,57 @@ int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const return connect(m_socket, addr, addr_len); } +int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const +{ + return bind(m_socket, addr, addr_len); +} + +int Sock::Listen(int backlog) const +{ + return listen(m_socket, backlog); +} + +std::unique_ptr Sock::Accept(sockaddr* addr, socklen_t* addr_len) const +{ +#ifdef WIN32 + static constexpr auto ERR = INVALID_SOCKET; +#else + static constexpr auto ERR = SOCKET_ERROR; +#endif + + std::unique_ptr sock; + + const auto socket = accept(m_socket, addr, addr_len); + if (socket != ERR) { + try { + sock = std::make_unique(socket); + } catch (const std::exception&) { +#ifdef WIN32 + closesocket(socket); +#else + close(socket); +#endif + } + } + + return sock; +} + int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const { return getsockopt(m_socket, level, opt_name, static_cast(opt_val), opt_len); } +int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const +{ + return setsockopt(m_socket, level, opt_name, static_cast(opt_val), opt_len); +} + +int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const +{ + return getsockname(m_socket, name, name_len); +} + bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const { #ifdef USE_POLL diff --git a/src/util/sock.h b/src/util/sock.h index 324e0c763ed5..377face66b8f 100644 --- a/src/util/sock.h +++ b/src/util/sock.h @@ -10,6 +10,7 @@ #include #include +#include #include /** @@ -144,6 +145,26 @@ class Sock */ [[nodiscard]] virtual int Connect(const sockaddr* addr, socklen_t addr_len) const; + /** + * bind(2) wrapper. Equivalent to `bind(this->Get(), addr, addr_len)`. Code that uses this + * wrapper can be unit tested if this method is overridden by a mock Sock implementation. + */ + [[nodiscard]] virtual int Bind(const sockaddr* addr, socklen_t addr_len) const; + + /** + * listen(2) wrapper. Equivalent to `listen(this->Get(), backlog)`. Code that uses this + * wrapper can be unit tested if this method is overridden by a mock Sock implementation. + */ + [[nodiscard]] virtual int Listen(int backlog) const; + + /** + * accept(2) wrapper. Equivalent to `std::make_unique(accept(this->Get(), addr, addr_len))`. + * Code that uses this wrapper can be unit tested if this method is overridden by a mock Sock + * implementation. + * The returned unique_ptr is empty if `accept()` failed in which case errno will be set. + */ + [[nodiscard]] virtual std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const; + /** * getsockopt(2) wrapper. Equivalent to * `getsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this @@ -154,6 +175,23 @@ class Sock void* opt_val, socklen_t* opt_len) const; + /** + * setsockopt(2) wrapper. Equivalent to + * `setsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this + * wrapper can be unit tested if this method is overridden by a mock Sock implementation. + */ + [[nodiscard]] virtual int SetSockOpt(int level, + int opt_name, + const void* opt_val, + socklen_t opt_len) const; + + /** + * getsockname(2) wrapper. Equivalent to + * `getsockname(this->Get(), name, name_len)`. Code that uses this + * wrapper can be unit tested if this method is overridden by a mock Sock implementation. + */ + [[nodiscard]] virtual int GetSockName(sockaddr* name, socklen_t* name_len) const; + using Event = uint8_t; /**