diff --git a/src/activemasternode.cpp b/src/activemasternode.cpp index 1d953f26fb0c9..b1d5bfd260957 100644 --- a/src/activemasternode.cpp +++ b/src/activemasternode.cpp @@ -156,9 +156,9 @@ void CActiveMasternode::ManageStateInitial() if(!fFoundLocal) { bool empty = true; // If we have some peers, let's try to find our local address from one of them - g_connman->ForEachNodeContinueIf([&fFoundLocal, &empty, this](CNode* pnode) { + g_connman->ForEachNodeContinueIf(CConnman::AllNodes, [&fFoundLocal, &empty, this](CNode* pnode) { empty = false; - if (pnode->fSuccessfullyConnected && pnode->addr.IsIPv4()) + if (pnode->addr.IsIPv4()) fFoundLocal = GetLocal(service, &pnode->addr) && CMasternode::IsValidNetAddr(service); return !fFoundLocal; }); diff --git a/src/masternode-sync.cpp b/src/masternode-sync.cpp index 1b325e8345375..d5c3322864b61 100644 --- a/src/masternode-sync.cpp +++ b/src/masternode-sync.cpp @@ -86,7 +86,7 @@ void CMasternodeSync::SwitchToNextAsset() // TRY_LOCK(cs_vNodes, lockRecv); // if(lockRecv) { ... } - g_connman->ForEachNode([](CNode* pnode) { + g_connman->ForEachNode(CConnman::AllNodes, [](CNode* pnode) { netfulfilledman.AddFulfilledRequest(pnode->addr, "full-sync"); }); LogPrintf("CMasternodeSync::SwitchToNextAsset -- Sync has finished\n"); @@ -132,7 +132,7 @@ void CMasternodeSync::ClearFulfilledRequests() // TRY_LOCK(cs_vNodes, lockRecv); // if(!lockRecv) return; - g_connman->ForEachNode([](CNode* pnode) { + g_connman->ForEachNode(CConnman::AllNodes, [](CNode* pnode) { netfulfilledman.RemoveFulfilledRequest(pnode->addr, "spork-sync"); netfulfilledman.RemoveFulfilledRequest(pnode->addr, "masternode-list-sync"); netfulfilledman.RemoveFulfilledRequest(pnode->addr, "masternode-payment-sync"); diff --git a/src/masternodeman.cpp b/src/masternodeman.cpp index 26e03fe7af501..c4761e0f6861c 100644 --- a/src/masternodeman.cpp +++ b/src/masternodeman.cpp @@ -763,7 +763,7 @@ void CMasternodeMan::ProcessMasternodeConnections() //we don't care about this for regtest if(Params().NetworkIDString() == CBaseChainParams::REGTEST) return; - g_connman->ForEachNode([](CNode* pnode) { + g_connman->ForEachNode(CConnman::AllNodes, [](CNode* pnode) { if(pnode->fMasternode) { if(privateSendClient.infoMixingMasternode.fInfoValid && pnode->addr == privateSendClient.infoMixingMasternode.addr) return true; LogPrintf("Closing Masternode connection: peer=%d, addr=%s\n", pnode->id, pnode->addr.ToString()); diff --git a/src/net.cpp b/src/net.cpp index 14b80c8f01bcc..cc32747169813 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -66,6 +66,9 @@ const static std::string NET_MESSAGE_COMMAND_OTHER = "*other*"; +constexpr const CConnman::CFullyConnectedOnly CConnman::FullyConnectedOnly; +constexpr const CConnman::CAllNodes CConnman::AllNodes; + // // Global state variables // @@ -719,6 +722,33 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete return true; } +void CNode::SetSendVersion(int nVersionIn) +{ + // Send version may only be changed in the version message, and + // only one version message is allowed per session. We can therefore + // treat this value as const and even atomic as long as it's only used + // once a version message has been successfully processed. Any attempt to + // set this twice is an error. + if (nSendVersion != 0) { + error("Send version already set for node: %i. Refusing to change from %i to %i", id, nSendVersion, nVersionIn); + } else { + nSendVersion = nVersionIn; + } +} + +int CNode::GetSendVersion() const +{ + // The send version should always be explicitly set to + // INIT_PROTO_VERSION rather than using this value until SetSendVersion + // has been called. + if (nSendVersion == 0) { + error("Requesting unset send version for node: %i. Using %i", id, INIT_PROTO_VERSION); + return INIT_PROTO_VERSION; + } + return nSendVersion; +} + + int CNetMessage::readHeader(const char *pch, unsigned int nBytes) { // copy data to temporary parsing buffer @@ -2725,6 +2755,11 @@ void CNode::AskFor(const CInv& inv) mapAskFor.insert(std::make_pair(nRequestTime, inv)); } +bool CConnman::NodeFullyConnected(const CNode* pnode) +{ + return pnode && pnode->fSuccessfullyConnected && !pnode->fDisconnect; +} + std::vector CNode::CalculateKeyedNetGroup(CAddress& address) { if(vchSecretKey.size() == 0) { @@ -2792,7 +2827,7 @@ void CConnman::PushMessage(CNode* pnode, CDataStream& strm, const std::string& s RecordBytesSent(nBytesSent); } -bool CConnman::ForNode(const CService& addr, std::function func) +bool CConnman::ForNode(const CService& addr, std::function cond, std::function func) { CNode* found = nullptr; LOCK(cs_vNodes); @@ -2802,10 +2837,10 @@ bool CConnman::ForNode(const CService& addr, std::function f break; } } - return found != nullptr && func(found); + return found != nullptr && cond(found) && func(found); } -bool CConnman::ForNode(NodeId id, std::function func) +bool CConnman::ForNode(NodeId id, std::function cond, std::function func) { CNode* found = nullptr; LOCK(cs_vNodes); @@ -2815,7 +2850,7 @@ bool CConnman::ForNode(NodeId id, std::function func) break; } } - return found != nullptr && func(found); + return found != nullptr && cond(found) && func(found); } int64_t PoissonNextSend(int64_t nNow, int average_interval_seconds) { diff --git a/src/net.h b/src/net.h index 52840c264500a..15a71c7b3063f 100644 --- a/src/net.h +++ b/src/net.h @@ -143,8 +143,34 @@ class CConnman // because it's used in many Dash-specific places (masternode, privatesend). CNode* ConnectNode(CAddress addrConnect, const char *pszDest = NULL, bool fConnectToMasternode = false); - bool ForNode(NodeId id, std::function func); - bool ForNode(const CService& addr, std::function func); + struct CFullyConnectedOnly { + bool operator() (const CNode* pnode) const { + return NodeFullyConnected(pnode); + } + }; + + constexpr static const CFullyConnectedOnly FullyConnectedOnly{}; + + struct CAllNodes { + bool operator() (const CNode*) const {return true;} + }; + + constexpr static const CAllNodes AllNodes{}; + + bool ForNode(NodeId id, std::function cond, std::function func); + bool ForNode(const CService& addr, std::function cond, std::function func); + + template + bool ForNode(const CService& addr, Callable&& func) + { + return ForNode(addr, FullyConnectedOnly, func); + } + + template + bool ForNode(NodeId id, Callable&& func) + { + return ForNode(id, FullyConnectedOnly, func); + } template void PushMessageWithVersionAndFlag(CNode* pnode, int nVersion, int flag, const std::string& sCommand, Args&&... args) @@ -173,87 +199,105 @@ class CConnman PushMessageWithVersionAndFlag(pnode, 0, 0, sCommand, std::forward(args)...); } - template - bool ForEachNodeContinueIf(Callable&& func) + template + bool ForEachNodeContinueIf(const Condition& cond, Callable&& func) { LOCK(cs_vNodes); for (auto&& node : vNodes) - if(!func(node)) - return false; + if (cond(node)) + if(!func(node)) + return false; return true; }; template - bool ForEachNodeContinueIf(Callable&& func) const + bool ForEachNodeContinueIf(Callable&& func) + { + return ForEachNodeContinueIf(FullyConnectedOnly, func); + } + + template + bool ForEachNodeContinueIf(const Condition& cond, Callable&& func) const { LOCK(cs_vNodes); for (const auto& node : vNodes) - if(!func(node)) - return false; + if (cond(node)) + if(!func(node)) + return false; return true; }; - template - bool ForEachNodeContinueIfThen(Callable&& pre, CallableAfter&& post) + template + bool ForEachNodeContinueIf(Callable&& func) const { - bool ret = true; - LOCK(cs_vNodes); - for (auto&& node : vNodes) - if(!pre(node)) { - ret = false; - break; - } - post(); - return ret; - }; + return ForEachNodeContinueIf(FullyConnectedOnly, func); + } - template - bool ForEachNodeContinueIfThen(Callable&& pre, CallableAfter&& post) const + template + void ForEachNode(const Condition& cond, Callable&& func) { - bool ret = true; LOCK(cs_vNodes); - for (const auto& node : vNodes) - if(!pre(node)) { - ret = false; - break; - } - post(); - return ret; + for (auto&& node : vNodes) { + if (cond(node)) + func(node); + } }; template void ForEachNode(Callable&& func) + { + ForEachNode(FullyConnectedOnly, func); + } + + template + void ForEachNode(const Condition& cond, Callable&& func) const { LOCK(cs_vNodes); - for (auto&& node : vNodes) - func(node); + for (auto&& node : vNodes) { + if (cond(node)) + func(node); + } }; template void ForEachNode(Callable&& func) const + { + ForEachNode(FullyConnectedOnly, func); + } + + template + void ForEachNodeThen(const Condition& cond, Callable&& pre, CallableAfter&& post) { LOCK(cs_vNodes); - for (const auto& node : vNodes) - func(node); + for (auto&& node : vNodes) { + if (cond(node)) + pre(node); + } + post(); }; template void ForEachNodeThen(Callable&& pre, CallableAfter&& post) + { + ForEachNodeThen(FullyConnectedOnly, pre, post); + } + + template + void ForEachNodeThen(const Condition& cond, Callable&& pre, CallableAfter&& post) const { LOCK(cs_vNodes); - for (auto&& node : vNodes) - pre(node); + for (auto&& node : vNodes) { + if (cond(node)) + pre(node); + } post(); }; template void ForEachNodeThen(Callable&& pre, CallableAfter&& post) const { - LOCK(cs_vNodes); - for (const auto& node : vNodes) - pre(node); - post(); - }; + ForEachNodeThen(FullyConnectedOnly, pre, post); + } std::vector CopyNodeVector(); void ReleaseNodeVector(const std::vector& vecNodes); @@ -391,6 +435,9 @@ class CConnman void RecordBytesRecv(uint64_t bytes); void RecordBytesSent(uint64_t bytes); + // Whether the node should be passed out in ForEach* callbacks + static bool NodeFullyConnected(const CNode* pnode); + // Network usage totals CCriticalSection cs_totalBytesRecv; CCriticalSection cs_totalBytesSent; @@ -634,7 +681,7 @@ class CNode std::string addrName; CService addrLocal; int nNumWarningsSkipped; - int nVersion; + std::atomic nVersion; // strSubVer is whatever byte array we read from the wire. However, this field is intended // to be printed out, displayed to humans in various forms and so on. So we sanitize it and // store the sanitized version in cleanSubVer. The original should be used when dealing with @@ -646,7 +693,7 @@ class CNode bool fClient; bool fInbound; bool fNetworkNode; - bool fSuccessfullyConnected; + std::atomic_bool fSuccessfullyConnected; bool fDisconnect; // We use fRelayTxes for two purposes - // a) it allows us to not relay tx invs before receiving the peer's version message @@ -763,25 +810,8 @@ class CNode BOOST_FOREACH(CNetMessage &msg, vRecvMsg) msg.SetVersion(nVersionIn); } - void SetSendVersion(int nVersionIn) - { - // Send version may only be changed in the version message, and - // only one version message is allowed per session. We can therefore - // treat this value as const and even atomic as long as it's only used - // once the handshake is complete. Any attempt to set this twice is an - // error. - assert(nSendVersion == 0); - nSendVersion = nVersionIn; - } - - int GetSendVersion() const - { - // The send version should always be explicitly set to - // INIT_PROTO_VERSION rather than using this value until the handshake - // is complete. See PushMessageWithVersion(). - assert(nSendVersion != 0); - return nSendVersion; - } + void SetSendVersion(int nVersionIn); + int GetSendVersion() const; CNode* AddRef() { diff --git a/src/net_processing.cpp b/src/net_processing.cpp index f8ddb4042c091..edae321033ea3 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -1122,46 +1122,51 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, CAddress addrFrom; uint64_t nNonce = 1; uint64_t nServiceInt; - vRecv >> pfrom->nVersion >> nServiceInt >> nTime >> addrMe; - pfrom->nServices = ServiceFlags(nServiceInt); + ServiceFlags nServices; + int nVersion; + int nSendVersion; + std::string strSubVer; + int nStartingHeight = -1; + bool fRelay = true; + + vRecv >> nVersion >> nServiceInt >> nTime >> addrMe; + nSendVersion = std::min(nVersion, PROTOCOL_VERSION); + nServices = ServiceFlags(nServiceInt); if (!pfrom->fInbound) { - connman.SetServices(pfrom->addr, pfrom->nServices); + connman.SetServices(pfrom->addr, nServices); } - if (pfrom->nServicesExpected & ~pfrom->nServices) + if (pfrom->nServicesExpected & ~nServices) { - LogPrint("net", "peer=%d does not offer the expected services (%08x offered, %08x expected); disconnecting\n", pfrom->id, pfrom->nServices, pfrom->nServicesExpected); + LogPrint("net", "peer=%d does not offer the expected services (%08x offered, %08x expected); disconnecting\n", pfrom->id, nServices, pfrom->nServicesExpected); connman.PushMessageWithVersion(pfrom, INIT_PROTO_VERSION, NetMsgType::REJECT, strCommand, REJECT_NONSTANDARD, strprintf("Expected to offer services %08x", pfrom->nServicesExpected)); pfrom->fDisconnect = true; return false; } - if (pfrom->nVersion < MIN_PEER_PROTO_VERSION) + if (nVersion < MIN_PEER_PROTO_VERSION) { // disconnect from peers older than this proto version - LogPrintf("peer=%d using obsolete version %i; disconnecting\n", pfrom->id, pfrom->nVersion); + LogPrintf("peer=%d using obsolete version %i; disconnecting\n", pfrom->id, nVersion); connman.PushMessageWithVersion(pfrom, INIT_PROTO_VERSION, NetMsgType::REJECT, strCommand, REJECT_OBSOLETE, strprintf("Version must be %d or greater", MIN_PEER_PROTO_VERSION)); pfrom->fDisconnect = true; return false; } - if (pfrom->nVersion == 10300) - pfrom->nVersion = 300; + if (nVersion == 10300) + nVersion = 300; if (!vRecv.empty()) vRecv >> addrFrom >> nNonce; if (!vRecv.empty()) { - vRecv >> LIMITED_STRING(pfrom->strSubVer, MAX_SUBVERSION_LENGTH); - pfrom->cleanSubVer = SanitizeString(pfrom->strSubVer); + vRecv >> LIMITED_STRING(strSubVer, MAX_SUBVERSION_LENGTH); + } + if (!vRecv.empty()) { + vRecv >> nStartingHeight; } if (!vRecv.empty()) - vRecv >> pfrom->nStartingHeight; - if (!vRecv.empty()) - vRecv >> pfrom->fRelayTxes; // set to true after we get the first filter* message - else - pfrom->fRelayTxes = true; - + vRecv >> fRelay; // Disconnect if we connected to ourself if (pfrom->fInbound && !connman.CheckIncomingNonce(nNonce)) { @@ -1170,7 +1175,6 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, return true; } - pfrom->addrLocal = addrMe; if (pfrom->fInbound && addrMe.IsRoutable()) { SeenLocal(addrMe); @@ -1180,7 +1184,22 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, if (pfrom->fInbound) PushNodeVersion(pfrom, connman, GetAdjustedTime()); - pfrom->fClient = !(pfrom->nServices & NODE_NETWORK); + connman.PushMessageWithVersion(pfrom, INIT_PROTO_VERSION, NetMsgType::VERACK); + + pfrom->nServices = nServices; + pfrom->addrLocal = addrMe; + pfrom->strSubVer = strSubVer; + pfrom->cleanSubVer = SanitizeString(strSubVer); + pfrom->nStartingHeight = nStartingHeight; + pfrom->fClient = !(nServices & NODE_NETWORK); + { + LOCK(pfrom->cs_filter); + pfrom->fRelayTxes = fRelay; // set to true after we get the first filter* message + } + + // Change version + pfrom->SetSendVersion(nSendVersion); + pfrom->nVersion = nVersion; // Potentially mark this peer as a preferred download peer. { @@ -1188,10 +1207,6 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, UpdatePreferredDownload(pfrom, State(pfrom->GetId())); } - // Change version - connman.PushMessageWithVersion(pfrom, INIT_PROTO_VERSION, NetMsgType::VERACK); - pfrom->SetSendVersion(min(pfrom->nVersion, PROTOCOL_VERSION)); - if (!pfrom->fInbound) { // Advertise our address @@ -1231,8 +1246,6 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, item.second.RelayTo(pfrom, connman); } - pfrom->fSuccessfullyConnected = true; - string remoteAddr; if (fLogIPs) remoteAddr = ", peeraddr=" + pfrom->addr.ToString(); @@ -1259,7 +1272,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, else if (strCommand == NetMsgType::VERACK) { - pfrom->SetRecvVersion(min(pfrom->nVersion, PROTOCOL_VERSION)); + pfrom->SetRecvVersion(std::min(pfrom->nVersion.load(), PROTOCOL_VERSION)); // Mark this node as currently connected, so we update its timestamp later. if (pfrom->fNetworkNode) { @@ -1274,6 +1287,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv, // nodes) connman.PushMessage(pfrom, NetMsgType::SENDHEADERS); } + pfrom->fSuccessfullyConnected = true; } @@ -2283,8 +2297,8 @@ bool SendMessages(CNode* pto, CConnman& connman, std::atomic& interruptMsg { const Consensus::Params& consensusParams = Params().GetConsensus(); { - // Don't send anything until we get its version message - if (pto->nVersion == 0) + // Don't send anything until the version handshake is complete + if (!pto->fSuccessfullyConnected || pto->fDisconnect) return true; // diff --git a/src/privatesend-client.cpp b/src/privatesend-client.cpp index cde47427bf022..bb9ecc2f9544b 100644 --- a/src/privatesend-client.cpp +++ b/src/privatesend-client.cpp @@ -850,7 +850,7 @@ bool CPrivateSendClient::JoinExistingQueue(CAmount nBalanceNeedsAnonymized) CNode* pnodeFound = NULL; bool fDisconnect = false; - g_connman->ForNode(infoMn.addr, [&pnodeFound, &fDisconnect](CNode* pnode) { + g_connman->ForNode(infoMn.addr, CConnman::AllNodes, [&pnodeFound, &fDisconnect](CNode* pnode) { pnodeFound = pnode; if(pnodeFound->fDisconnect) { fDisconnect = true; @@ -925,7 +925,7 @@ bool CPrivateSendClient::StartNewQueue(CAmount nValueMin, CAmount nBalanceNeedsA CNode* pnodeFound = NULL; bool fDisconnect = false; - g_connman->ForNode(infoMn.addr, [&pnodeFound, &fDisconnect](CNode* pnode) { + g_connman->ForNode(infoMn.addr, CConnman::AllNodes, [&pnodeFound, &fDisconnect](CNode* pnode) { pnodeFound = pnode; if(pnodeFound->fDisconnect) { fDisconnect = true; diff --git a/src/test/DoS_tests.cpp b/src/test/DoS_tests.cpp index deb7733bdcb17..180ad667a7ef7 100644 --- a/src/test/DoS_tests.cpp +++ b/src/test/DoS_tests.cpp @@ -54,6 +54,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) dummyNode1.SetSendVersion(PROTOCOL_VERSION); GetNodeSignals().InitializeNode(&dummyNode1, *connman); dummyNode1.nVersion = 1; + dummyNode1.fSuccessfullyConnected = true; Misbehaving(dummyNode1.GetId(), 100); // Should get banned SendMessages(&dummyNode1, *connman, interruptDummy); BOOST_CHECK(connman->IsBanned(addr1)); @@ -64,6 +65,7 @@ BOOST_AUTO_TEST_CASE(DoS_banning) dummyNode2.SetSendVersion(PROTOCOL_VERSION); GetNodeSignals().InitializeNode(&dummyNode2, *connman); dummyNode2.nVersion = 1; + dummyNode2.fSuccessfullyConnected = true; Misbehaving(dummyNode2.GetId(), 50); SendMessages(&dummyNode2, *connman, interruptDummy); BOOST_CHECK(!connman->IsBanned(addr2)); // 2 not banned yet... @@ -84,6 +86,7 @@ BOOST_AUTO_TEST_CASE(DoS_banscore) dummyNode1.SetSendVersion(PROTOCOL_VERSION); GetNodeSignals().InitializeNode(&dummyNode1, *connman); dummyNode1.nVersion = 1; + dummyNode1.fSuccessfullyConnected = true; Misbehaving(dummyNode1.GetId(), 100); SendMessages(&dummyNode1, *connman, interruptDummy); BOOST_CHECK(!connman->IsBanned(addr1)); @@ -109,6 +112,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) dummyNode.SetSendVersion(PROTOCOL_VERSION); GetNodeSignals().InitializeNode(&dummyNode, *connman); dummyNode.nVersion = 1; + dummyNode.fSuccessfullyConnected = true; Misbehaving(dummyNode.GetId(), 100); SendMessages(&dummyNode, *connman, interruptDummy);