diff --git a/Cargo.lock b/Cargo.lock index 053f5cdcde..2b4d130e27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1750,6 +1750,7 @@ name = "dns_server" version = "0.1.2" dependencies = [ "async-trait", + "chainstate", "clap", "common", "crypto", diff --git a/dns_server/Cargo.toml b/dns_server/Cargo.toml index f8f2603e2e..bfc4f82e15 100644 --- a/dns_server/Cargo.toml +++ b/dns_server/Cargo.toml @@ -8,6 +8,7 @@ rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +chainstate = { path = "../chainstate" } common = { path = "../common" } crypto = { path = "../crypto" } logging = { path = "../logging" } diff --git a/dns_server/src/crawler_p2p/crawler/mod.rs b/dns_server/src/crawler_p2p/crawler/mod.rs index c8ee2ef79e..5aa3795380 100644 --- a/dns_server/src/crawler_p2p/crawler/mod.rs +++ b/dns_server/src/crawler_p2p/crawler/mod.rs @@ -34,14 +34,16 @@ use std::{ time::Duration, }; +use chainstate::ban_score::BanScore; use common::{chain::ChainConfig, primitives::time::Time}; use crypto::random::{seq::IteratorRandom, Rng}; use logging::log; use p2p::{ + config::{BanDuration, BanThreshold}, error::P2pError, net::types::PeerInfo, peer_manager::{ADDR_RATE_BUCKET_SIZE, ADDR_RATE_INITIAL_SIZE, MAX_ADDR_RATE_PER_SECOND}, - types::{peer_id::PeerId, socket_address::SocketAddress}, + types::{bannable_address::BannableAddress, peer_id::PeerId, socket_address::SocketAddress}, utils::rate_limiter::RateLimiter, }; @@ -52,22 +54,35 @@ use self::address_data::{AddressData, AddressState}; /// How many outbound connection attempts can be made per heartbeat const MAX_CONNECTS_PER_HEARTBEAT: usize = 25; +#[derive(Clone)] +pub struct CrawlerConfig { + pub ban_threshold: BanThreshold, + pub ban_duration: BanDuration, +} + /// The `Crawler` is the component that communicates with Mintlayer peers using p2p, /// and based on the results, commands the DNS server to add/remove ip addresses. /// The `Crawler` emits events that communicate whether addresses were reached or, /// are unreachable anymore. pub struct Crawler { + /// Current time. This value is advanced explicitly by the caller code. now: Time, /// Chain config chain_config: Arc, + /// Crawler config + config: CrawlerConfig, + /// Map of all known addresses (including currently unreachable); these addresses /// will be periodically tested, and reachable addresses will be handed /// to the DNS server to be returned to the user on DNS queries, /// and unreachable addresses will be removed from the DNS server addresses: BTreeMap, + /// Banned addresses. + banned_addresses: BTreeMap, + /// Map of all currently connected outbound peers that we successfully /// reached and are still connected to (generally speaking, /// we don't have to stay connected to those peers, but this is an implementation detail) @@ -77,6 +92,7 @@ pub struct Crawler { struct Peer { address: SocketAddress, address_rate_limiter: RateLimiter, + ban_score: u32, } pub enum CrawlerEvent { @@ -98,6 +114,10 @@ pub enum CrawlerEvent { address: SocketAddress, error: P2pError, }, + Misbehaved { + peer_id: PeerId, + error: P2pError, + }, } pub enum CrawlerCommand { @@ -112,16 +132,24 @@ pub enum CrawlerCommand { old_state: AddressState, new_state: AddressState, }, + MarkAsBanned { + address: BannableAddress, + ban_until: Time, + }, + RemoveBannedStatus { + address: BannableAddress, + }, } impl Crawler { pub fn new( + now: Time, chain_config: Arc, + config: CrawlerConfig, loaded_addresses: BTreeSet, + loaded_banned_addresses: BTreeMap, reserved_addresses: BTreeSet, ) -> Self { - let now = common::primitives::time::get_time(); - let addresses = loaded_addresses .union(&reserved_addresses) .map(|addr| { @@ -142,7 +170,9 @@ impl Crawler { Self { now, chain_config, + config, addresses, + banned_addresses: loaded_banned_addresses, outbound_peers: BTreeMap::new(), } } @@ -153,7 +183,7 @@ impl Crawler { peer_info: PeerInfo, callback: &mut impl FnMut(CrawlerCommand), ) { - log::info!("connected open, peer_id: {}", peer_info.peer_id); + log::info!("connection opened, peer_id: {}", peer_info.peer_id); self.create_outbound_peer(peer_info.peer_id, address, peer_info, callback); } @@ -177,6 +207,77 @@ impl Crawler { AddressStateTransitionTo::Disconnected, callback, ); + + self.handle_new_ban_score(&address, error.ban_score(), callback); + } + + fn handle_misbehaved_peer( + &mut self, + peer_id: PeerId, + error: P2pError, + callback: &mut impl FnMut(CrawlerCommand), + ) { + let ban_score = error.ban_score(); + + if ban_score > 0 { + log::debug!("handling misbehaved peer, peer_id: {peer_id}"); + + let peer = self + .outbound_peers + .get_mut(&peer_id) + .expect("peer must be known (handle_misbehaved_peer)"); + peer.ban_score = peer.ban_score.saturating_add(ban_score); + + log::info!( + "Adjusting peer ban score for peer {peer_id}, adjustment: {ban_score}, new score: {}", + peer.ban_score + ); + + let address = peer.address; + let new_score = peer.ban_score; + self.handle_new_ban_score(&address, new_score, callback); + } + } + + fn handle_new_ban_score( + &mut self, + address: &SocketAddress, + new_ban_score: u32, + callback: &mut impl FnMut(CrawlerCommand), + ) { + let ban_until = (self.now + *self.config.ban_duration).expect("Unexpected ban duration"); + + if new_ban_score >= *self.config.ban_threshold { + let address = address.as_bannable(); + + log::info!("Ban threshold for address {address} reached"); + + self.disconnect_all(&address, callback); + callback(CrawlerCommand::MarkAsBanned { address, ban_until }); + self.banned_addresses.insert(address, ban_until); + } + } + + fn disconnect_all( + &mut self, + address: &BannableAddress, + callback: &mut impl FnMut(CrawlerCommand), + ) { + let to_disconnect = self + .outbound_peers + .iter() + .filter_map(|(peer_id, peer)| { + if peer.address.as_bannable() == *address { + Some((*peer_id, peer.address)) + } else { + None + } + }) + .collect::>(); + + for (peer_id, peer_address) in to_disconnect { + self.disconnect_peer(peer_id, &peer_address, callback); + } } fn handle_disconnected(&mut self, peer_id: PeerId, callback: &mut impl FnMut(CrawlerCommand)) { @@ -252,6 +353,7 @@ impl Crawler { let peer = Peer { address, address_rate_limiter, + ban_score: 0, }; let old_peer = self.outbound_peers.insert(peer_id, peer); @@ -266,12 +368,12 @@ impl Crawler { is_compatible ); - let address_data = self - .addresses - .get_mut(&address) - .expect("address must be known (create_outbound_peer)"); - if is_compatible { + let address_data = self + .addresses + .get_mut(&address) + .expect("address must be known (create_outbound_peer)"); + Self::change_address_state( self.now, &address, @@ -280,18 +382,32 @@ impl Crawler { callback, ); } else { - callback(CrawlerCommand::Disconnect { peer_id }); - - Self::change_address_state( - self.now, - &address, - address_data, - AddressStateTransitionTo::Disconnecting, - callback, - ); + self.disconnect_peer(peer_id, &address, callback); } } + fn disconnect_peer( + &mut self, + peer_id: PeerId, + address: &SocketAddress, + callback: &mut impl FnMut(CrawlerCommand), + ) { + let address_data = self + .addresses + .get_mut(address) + .expect("address must be known (disconnect_peer)"); + + callback(CrawlerCommand::Disconnect { peer_id }); + + Self::change_address_state( + self.now, + address, + address_data, + AddressStateTransitionTo::Disconnecting, + callback, + ); + } + /// Remove existing outbound peer fn remove_outbound_peer(&mut self, peer_id: PeerId, callback: &mut impl FnMut(CrawlerCommand)) { log::debug!("outbound peer removed, peer_id: {}", peer_id); @@ -319,10 +435,23 @@ impl Crawler { /// /// Select random addresses to connect to, delete old addresses from memory and DB. fn heartbeat(&mut self, callback: &mut impl FnMut(CrawlerCommand), rng: &mut impl Rng) { + self.banned_addresses.retain(|address, banned_until| { + let banned = self.now < *banned_until; + + if !banned { + callback(CrawlerCommand::RemoveBannedStatus { address: *address }); + } + + banned + }); + let connecting_addresses = self .addresses .iter_mut() - .filter(|(_address, address_data)| address_data.connect_now(self.now)) + .filter(|(address, address_data)| { + address_data.connect_now(self.now) + && self.banned_addresses.get(&address.as_bannable()).is_none() + }) .choose_multiple(rng, MAX_CONNECTS_PER_HEARTBEAT); for (address, address_data) in connecting_addresses { @@ -370,6 +499,9 @@ impl Crawler { CrawlerEvent::ConnectionError { address, error } => { self.handle_connection_error(address, error, callback); } + CrawlerEvent::Misbehaved { peer_id, error } => { + self.handle_misbehaved_peer(peer_id, error, callback) + } } } } diff --git a/dns_server/src/crawler_p2p/crawler/tests/mock_crawler.rs b/dns_server/src/crawler_p2p/crawler/tests/mock_crawler.rs index b1fa55b712..b2f13135ff 100644 --- a/dns_server/src/crawler_p2p/crawler/tests/mock_crawler.rs +++ b/dns_server/src/crawler_p2p/crawler/tests/mock_crawler.rs @@ -19,12 +19,14 @@ use std::{ time::Duration, }; -use common::chain::ChainConfig; +use common::{chain::ChainConfig, primitives::time::Time}; use crypto::random::Rng; -use p2p::types::{peer_id::PeerId, socket_address::SocketAddress}; +use p2p::types::{ + bannable_address::BannableAddress, peer_id::PeerId, socket_address::SocketAddress, +}; use crate::crawler_p2p::crawler::{ - address_data::AddressState, Crawler, CrawlerCommand, CrawlerEvent, + address_data::AddressState, Crawler, CrawlerCommand, CrawlerConfig, CrawlerEvent, }; /// Mock crawler @@ -39,6 +41,7 @@ pub struct MockCrawler { pub pending_disconnects: BTreeSet, pub peers: BTreeMap, pub peer_addresses: BTreeMap, + pub banned_addresses: BTreeMap, } #[derive(Debug)] @@ -55,14 +58,19 @@ pub struct AddressUpdate { } pub fn test_crawler( + config: CrawlerConfig, loaded_addresses: BTreeSet, + loaded_banned_addresses: BTreeMap, added_addresses: BTreeSet, ) -> MockCrawler { let chain_config = Arc::new(common::chain::config::create_mainnet()); let crawler = Crawler::new( + Time::from_duration_since_epoch(Duration::ZERO), chain_config.clone(), + config, loaded_addresses.clone(), + loaded_banned_addresses.clone(), added_addresses, ); @@ -77,6 +85,7 @@ pub fn test_crawler( pending_disconnects: Default::default(), peers: Default::default(), peer_addresses: BTreeMap::new(), + banned_addresses: loaded_banned_addresses, } } @@ -118,6 +127,10 @@ impl MockCrawler { let removed = self.pending_connects.remove(address); assert!(removed); } + CrawlerEvent::Misbehaved { + peer_id: _, + error: _, + } => {} } let mut cmd_handler = |cmd| match cmd { @@ -165,6 +178,12 @@ impl MockCrawler { new_state, }); } + CrawlerCommand::MarkAsBanned { address, ban_until } => { + self.banned_addresses.insert(address, ban_until); + } + CrawlerCommand::RemoveBannedStatus { address } => { + self.banned_addresses.remove(&address); + } }; self.crawler.step(event, &mut cmd_handler, rng); @@ -178,10 +197,61 @@ impl MockCrawler { assert!(peer.is_compatible); } - // Verify that all compatible nodes are reachable + // Verify that all compatible nodes are reachable (unless they are being disconnected + // at the moment) and all incompatible ones are non-reachable. for peer in self.peers.values() { let valid_ip = peer.is_compatible; - assert_eq!(self.reachable.contains(&peer.address), valid_ip); + let is_reachable = self.reachable.contains(&peer.address); + let peer_id = self.peer_addresses.get(&peer.address).unwrap(); + let is_being_disconnected = self.pending_disconnects.contains(peer_id); + + if valid_ip { + assert!(is_reachable || is_being_disconnected); + } else { + assert!(!is_reachable); + } + } + } + + pub fn now(&self) -> Time { + self.crawler.now + } + + pub fn assert_banned_addresses(&self, expected: &[(BannableAddress, Time)]) { + let expected: BTreeMap<_, _> = expected.iter().copied().collect(); + assert_eq!(self.banned_addresses, expected); + assert_eq!(self.crawler.banned_addresses, expected); + } + + pub fn assert_ban_scores(&self, expected: &[(PeerId, u32)]) { + let expected: BTreeMap<_, _> = expected.iter().copied().collect(); + + for (peer_id, peer) in &self.crawler.outbound_peers { + assert_eq!( + // "Compare" peer_id too, so that it appears in the message if the assertion fails. + (*peer_id, peer.ban_score), + (*peer_id, *expected.get(peer_id).unwrap_or(&0)) + ); } } + + pub fn assert_pending_connects(&self, expected: &[SocketAddress]) { + let expected: BTreeSet<_> = expected.iter().copied().collect(); + assert_eq!(self.pending_connects, expected); + } + + pub fn assert_pending_disconnects(&self, expected: &[PeerId]) { + let expected: BTreeSet<_> = expected.iter().copied().collect(); + assert_eq!(self.pending_disconnects, expected); + } + + pub fn assert_connected_peers(&self, expected: &[PeerId]) { + let expected: BTreeSet<_> = expected.iter().copied().collect(); + + let actual1: BTreeSet<_> = self.peers.keys().copied().collect(); + assert_eq!(actual1, expected); + + let actual2: BTreeSet<_> = self.crawler.outbound_peers.keys().copied().collect(); + assert_eq!(actual2, expected); + } } diff --git a/dns_server/src/crawler_p2p/crawler/tests/mod.rs b/dns_server/src/crawler_p2p/crawler/tests/mod.rs index 69d4ff3aa7..51826b329d 100644 --- a/dns_server/src/crawler_p2p/crawler/tests/mod.rs +++ b/dns_server/src/crawler_p2p/crawler/tests/mod.rs @@ -16,15 +16,19 @@ mod mock_crawler; use std::{ - collections::BTreeSet, + collections::{BTreeMap, BTreeSet}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, time::Duration, }; -use common::primitives::user_agent::mintlayer_core_user_agent; +use chainstate::ban_score::BanScore; +use common::{ + chain::ChainConfig, + primitives::{time::Time, user_agent::mintlayer_core_user_agent}, +}; use p2p::{ - config::NodeType, - error::{DialError, P2pError}, + config::{BanDuration, BanThreshold, NodeType}, + error::{DialError, P2pError, ProtocolError}, net::types::PeerInfo, testing_utils::TEST_PROTOCOL_VERSION, types::{peer_id::PeerId, socket_address::SocketAddress}, @@ -41,6 +45,8 @@ use mock_crawler::test_crawler; use crate::crawler_p2p::crawler::CrawlerEvent; +use super::CrawlerConfig; + #[rstest] #[trace] #[case(Seed::from_entropy())] @@ -49,7 +55,12 @@ fn basic(#[case] seed: Seed) { let node1: SocketAddress = "1.2.3.4:3031".parse().unwrap(); let peer1 = PeerId::new(); let chain_config = common::chain::config::create_mainnet(); - let mut crawler = test_crawler(BTreeSet::new(), [node1].into_iter().collect()); + let mut crawler = test_crawler( + make_config(), + BTreeSet::new(), + BTreeMap::new(), + [node1].into_iter().collect(), + ); crawler.timer(Duration::from_secs(100), &mut rng); assert_eq!(crawler.pending_connects.len(), 1); @@ -58,14 +69,7 @@ fn basic(#[case] seed: Seed) { crawler.step( CrawlerEvent::Connected { address: node1, - peer_info: PeerInfo { - peer_id: peer1, - protocol_version: TEST_PROTOCOL_VERSION, - network: *chain_config.magic_bytes(), - software_version: *chain_config.software_version(), - user_agent: mintlayer_core_user_agent(), - common_services: NodeType::DnsServer.into(), - }, + peer_info: make_peer_info(peer1, &chain_config), }, &mut rng, ); @@ -90,14 +94,7 @@ fn basic(#[case] seed: Seed) { crawler.step( CrawlerEvent::Connected { address: node2, - peer_info: PeerInfo { - peer_id: peer2, - protocol_version: TEST_PROTOCOL_VERSION, - network: *chain_config.magic_bytes(), - software_version: *chain_config.software_version(), - user_agent: mintlayer_core_user_agent(), - common_services: NodeType::DnsServer.into(), - }, + peer_info: make_peer_info(peer2, &chain_config), }, &mut rng, ); @@ -147,7 +144,7 @@ fn randomized(#[case] seed: Seed) { let loaded_count = rng.gen_range(0..10); let loaded_nodes = nodes.choose_multiple(&mut rng, loaded_count).cloned().collect(); - let mut crawler = test_crawler(loaded_nodes, reserved_nodes); + let mut crawler = test_crawler(make_config(), loaded_nodes, BTreeMap::new(), reserved_nodes); for _ in 0..rng.gen_range(0..100000) { crawler.timer(Duration::from_secs(rng.gen_range(0..100)), &mut rng); @@ -170,14 +167,7 @@ fn randomized(#[case] seed: Seed) { crawler.step( CrawlerEvent::Connected { address, - peer_info: PeerInfo { - peer_id: PeerId::new(), - protocol_version: TEST_PROTOCOL_VERSION, - network: *chain_config.magic_bytes(), - software_version: *chain_config.software_version(), - user_agent: mintlayer_core_user_agent(), - common_services: NodeType::DnsServer.into(), - }, + peer_info: make_peer_info(PeerId::new(), &chain_config), }, &mut rng, ) @@ -189,14 +179,7 @@ fn randomized(#[case] seed: Seed) { crawler.step( CrawlerEvent::Connected { address, - peer_info: PeerInfo { - peer_id: PeerId::new(), - protocol_version: TEST_PROTOCOL_VERSION, - network: [255, 255, 255, 255], - software_version: *chain_config.software_version(), - user_agent: mintlayer_core_user_agent(), - common_services: NodeType::DnsServer.into(), - }, + peer_info: make_peer_info(PeerId::new(), &chain_config), }, &mut rng, ) @@ -230,13 +213,18 @@ fn incompatible_node(#[case] seed: Seed) { let node1: SocketAddress = "1.2.3.4:3031".parse().unwrap(); let peer1 = PeerId::new(); let chain_config = common::chain::config::create_mainnet(); - let mut crawler = test_crawler(BTreeSet::new(), [node1].into_iter().collect()); + let mut crawler = test_crawler( + make_config(), + BTreeSet::new(), + BTreeMap::new(), + [node1].into_iter().collect(), + ); - // // Crawler attempts to connect to the specified node + // Crawler attempts to connect to the specified node crawler.timer(Duration::from_secs(100), &mut rng); assert!(crawler.pending_connects.contains(&node1)); - // // Connection to the node is successful + // Connection to the node is successful crawler.step( CrawlerEvent::Connected { address: node1, @@ -264,7 +252,9 @@ fn long_offline(#[case] seed: Seed) { let loaded_node: SocketAddress = "1.0.0.0:3031".parse().unwrap(); let added_node: SocketAddress = "2.0.0.0:3031".parse().unwrap(); let mut crawler = test_crawler( + make_config(), [loaded_node].into_iter().collect(), + BTreeMap::new(), [added_node].into_iter().collect(), ); assert!(crawler.persistent.contains(&loaded_node)); @@ -299,3 +289,368 @@ fn long_offline(#[case] seed: Seed) { assert!(crawler.connect_requests.iter().any(|addr| *addr == added_node)); crawler.connect_requests.clear(); } + +// Connect to two peers and then send CrawlerEvent::Misbehaved for one of them several times, +// making sure that the ban score is updated accordingly and that eventually the peer is banned. +// Also check that we don't reconnect to the banned peer until the ban end is reached. +#[rstest] +#[trace] +#[case(Seed::from_entropy())] +fn ban_misbehaved_peer(#[case] seed: Seed) { + let mut rng = make_seedable_rng(seed); + + let node1: SocketAddress = "1.2.3.4:1234".parse().unwrap(); + let peer1 = PeerId::new(); + let node2: SocketAddress = "2.3.4.5:2345".parse().unwrap(); + let peer2 = PeerId::new(); + + let test_error = P2pError::ProtocolError(ProtocolError::UnexpectedMessage("".to_owned())); + let test_error_ban_score = test_error.ban_score(); + assert!(test_error_ban_score > 0); + let ban_threshold = test_error_ban_score * 2; + + let chain_config = common::chain::config::create_mainnet(); + let mut crawler = test_crawler( + CrawlerConfig { + ban_duration: BanDuration::new(BAN_DURATION), + ban_threshold: BanThreshold::new(ban_threshold), + }, + BTreeSet::new(), + BTreeMap::new(), + [node1, node2].into_iter().collect(), + ); + + let times_step = Duration::from_secs(100); + + crawler.timer(times_step, &mut rng); + + crawler.assert_pending_connects(&[node1, node2]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[]); + + crawler.step( + CrawlerEvent::Connected { + address: node1, + peer_info: make_peer_info(peer1, &chain_config), + }, + &mut rng, + ); + + crawler.step( + CrawlerEvent::Connected { + address: node2, + peer_info: make_peer_info(peer2, &chain_config), + }, + &mut rng, + ); + + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer1, peer2]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[]); + + crawler.step( + CrawlerEvent::Misbehaved { + peer_id: peer1, + error: test_error.clone(), + }, + &mut rng, + ); + + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer1, peer2]); + crawler.assert_ban_scores(&[(peer1, test_error_ban_score)]); + crawler.assert_banned_addresses(&[]); + + let ban_start_time = crawler.now(); + let ban_end_time = (ban_start_time + BAN_DURATION).unwrap(); + + crawler.step( + CrawlerEvent::Misbehaved { + peer_id: peer1, + error: test_error, + }, + &mut rng, + ); + + // The peer is banned. + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[peer1]); + crawler.assert_connected_peers(&[peer1, peer2]); + crawler.assert_ban_scores(&[(peer1, test_error_ban_score * 2)]); + crawler.assert_banned_addresses(&[(node1.as_bannable(), ban_end_time)]); + + crawler.step(CrawlerEvent::Disconnected { peer_id: peer1 }, &mut rng); + + // The peer has become disconnected and its ban score was lost. But it's still banned. + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer2]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[(node1.as_bannable(), ban_end_time)]); + + // Wait some (small) time, the peer should still be banned. + crawler.timer(times_step, &mut rng); + + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer2]); + crawler.assert_ban_scores(&[(peer1, test_error_ban_score * 10)]); + crawler.assert_banned_addresses(&[(node1.as_bannable(), ban_end_time)]); + + // Sanity check + assert!(crawler.now() < ban_end_time); + + // Wait for the remaining ban time. + crawler.timer((ban_end_time - crawler.now()).unwrap(), &mut rng); + + // The peer is no longer banned; instead, it is being connected to. + crawler.assert_pending_connects(&[node1]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer2]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[]); +} + +// Connect to three peers, where two of them share the same ip address, and then send +// CrawlerEvent::Misbehaved for one of those. Make sure that both peers get disconnected and +// no connect attempts are made until the ban end is reached. +#[rstest] +#[trace] +#[case(Seed::from_entropy())] +fn ban_misbehaved_peers_with_same_address(#[case] seed: Seed) { + let mut rng = make_seedable_rng(seed); + + let node1: SocketAddress = "1.2.3.4:1234".parse().unwrap(); + let peer1 = PeerId::new(); + let node2: SocketAddress = "2.3.4.5:2345".parse().unwrap(); + let peer2 = PeerId::new(); + let node3: SocketAddress = "1.2.3.4:4321".parse().unwrap(); + let peer3 = PeerId::new(); + + assert_eq!(node3.as_bannable(), node1.as_bannable()); + + let test_error = P2pError::ProtocolError(ProtocolError::UnexpectedMessage("".to_owned())); + let test_error_ban_score = test_error.ban_score(); + assert!(test_error_ban_score > 0); + let ban_threshold = test_error_ban_score; + + let chain_config = common::chain::config::create_mainnet(); + let mut crawler = test_crawler( + CrawlerConfig { + ban_duration: BanDuration::new(BAN_DURATION), + ban_threshold: BanThreshold::new(ban_threshold), + }, + BTreeSet::new(), + BTreeMap::new(), + [node1, node2, node3].into_iter().collect(), + ); + + let times_step = Duration::from_secs(100); + + crawler.timer(times_step, &mut rng); + + crawler.assert_pending_connects(&[node1, node2, node3]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[]); + + crawler.step( + CrawlerEvent::Connected { + address: node1, + peer_info: make_peer_info(peer1, &chain_config), + }, + &mut rng, + ); + + crawler.step( + CrawlerEvent::Connected { + address: node2, + peer_info: make_peer_info(peer2, &chain_config), + }, + &mut rng, + ); + + crawler.step( + CrawlerEvent::Connected { + address: node3, + peer_info: make_peer_info(peer3, &chain_config), + }, + &mut rng, + ); + + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer1, peer2, peer3]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[]); + + let ban_start_time = crawler.now(); + let ban_end_time = (ban_start_time + BAN_DURATION).unwrap(); + + crawler.step( + CrawlerEvent::Misbehaved { + peer_id: peer1, + error: test_error.clone(), + }, + &mut rng, + ); + + // The peer1's address is banned; peer1 and peer3 are being disconnected. + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[peer1, peer3]); + crawler.assert_connected_peers(&[peer1, peer2, peer3]); + crawler.assert_ban_scores(&[(peer1, test_error_ban_score)]); + crawler.assert_banned_addresses(&[(node1.as_bannable(), ban_end_time)]); + + crawler.step(CrawlerEvent::Disconnected { peer_id: peer1 }, &mut rng); + crawler.step(CrawlerEvent::Disconnected { peer_id: peer3 }, &mut rng); + + // peer1 and peer3 are now disconnected; the ban score was lost, but the address is still banned. + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer2]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[(node1.as_bannable(), ban_end_time)]); + + // Wait some (small) time, the address should still be banned. + crawler.timer(times_step, &mut rng); + + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer2]); + crawler.assert_ban_scores(&[(peer1, test_error_ban_score * 10)]); + crawler.assert_banned_addresses(&[(node1.as_bannable(), ban_end_time)]); + + // Sanity check + assert!(crawler.now() < ban_end_time); + + // Wait for the remaining ban time. + crawler.timer((ban_end_time - crawler.now()).unwrap(), &mut rng); + + // The address is no longer banned; instead, both peer1 and peer3 are being connected to. + crawler.assert_pending_connects(&[node1, node3]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer2]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[]); +} + +// Create a crawler with 2 addresses and mark one of them as banned. +// Make sure it doesn't try to connect to the banned address. +#[rstest] +#[trace] +#[case(Seed::from_entropy())] +fn dont_connect_to_initially_banned_peer(#[case] seed: Seed) { + let mut rng = make_seedable_rng(seed); + + let node1: SocketAddress = "1.2.3.4:1234".parse().unwrap(); + let node2: SocketAddress = "2.3.4.5:2345".parse().unwrap(); + + let ban_end_time = Time::from_duration_since_epoch(BAN_DURATION); + + let mut crawler = test_crawler( + make_config(), + BTreeSet::new(), + [(node1.as_bannable(), ban_end_time)].into_iter().collect(), + [node1, node2].into_iter().collect(), + ); + + crawler.timer(Duration::from_secs(100), &mut rng); + + crawler.assert_pending_connects(&[node2]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[(node1.as_bannable(), ban_end_time)]); +} + +// Check that a peer is banned on CrawlerEvent::ConnectionError. +#[rstest] +#[trace] +#[case(Seed::from_entropy())] +fn ban_on_connection_error(#[case] seed: Seed) { + let mut rng = make_seedable_rng(seed); + + let node1: SocketAddress = "1.2.3.4:1234".parse().unwrap(); + let peer1 = PeerId::new(); + let node2: SocketAddress = "2.3.4.5:2345".parse().unwrap(); + + let test_error = P2pError::ProtocolError(ProtocolError::HandshakeExpected); + let test_error_ban_score = test_error.ban_score(); + assert!(test_error_ban_score > 0); + let ban_threshold = test_error_ban_score; + + let chain_config = common::chain::config::create_mainnet(); + let mut crawler = test_crawler( + CrawlerConfig { + ban_duration: BanDuration::new(BAN_DURATION), + ban_threshold: BanThreshold::new(ban_threshold), + }, + BTreeSet::new(), + BTreeMap::new(), + [node1, node2].into_iter().collect(), + ); + + let times_step = Duration::from_secs(100); + + crawler.timer(times_step, &mut rng); + + crawler.assert_pending_connects(&[node1, node2]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[]); + + crawler.step( + CrawlerEvent::Connected { + address: node1, + peer_info: make_peer_info(peer1, &chain_config), + }, + &mut rng, + ); + + crawler.step( + CrawlerEvent::ConnectionError { + address: node2, + error: test_error, + }, + &mut rng, + ); + + let ban_start_time = crawler.now(); + let ban_end_time = (ban_start_time + BAN_DURATION).unwrap(); + + // The ban score is not recorded, but the peer is banned. + crawler.assert_pending_connects(&[]); + crawler.assert_pending_disconnects(&[]); + crawler.assert_connected_peers(&[peer1]); + crawler.assert_ban_scores(&[]); + crawler.assert_banned_addresses(&[(node2.as_bannable(), ban_end_time)]); +} + +const BAN_DURATION: Duration = Duration::from_secs(1000); +const BAN_THRESHOLD: u32 = 100; + +fn make_config() -> CrawlerConfig { + CrawlerConfig { + ban_duration: BanDuration::new(BAN_DURATION), + ban_threshold: BanThreshold::new(BAN_THRESHOLD), + } +} + +fn make_peer_info(peer_id: PeerId, chain_config: &ChainConfig) -> PeerInfo { + PeerInfo { + peer_id, + protocol_version: TEST_PROTOCOL_VERSION, + network: *chain_config.magic_bytes(), + software_version: *chain_config.software_version(), + user_agent: mintlayer_core_user_agent(), + common_services: NodeType::DnsServer.into(), + } +} diff --git a/dns_server/src/crawler_p2p/crawler_manager/mod.rs b/dns_server/src/crawler_p2p/crawler_manager/mod.rs index 413f7dd475..692757ea1a 100644 --- a/dns_server/src/crawler_p2p/crawler_manager/mod.rs +++ b/dns_server/src/crawler_p2p/crawler_manager/mod.rs @@ -17,7 +17,7 @@ pub mod storage; pub mod storage_impl; use std::{ - collections::BTreeSet, + collections::{BTreeMap, BTreeSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr}, sync::Arc, time::Duration, @@ -38,8 +38,8 @@ use p2p::{ peerdb_common::{storage::update_db, TransactionRo, TransactionRw}, }, types::{ - ip_or_socket_address::IpOrSocketAddress, peer_address::PeerAddress, peer_id::PeerId, - socket_address::SocketAddress, IsGlobalIp, + bannable_address::BannableAddress, ip_or_socket_address::IpOrSocketAddress, + peer_address::PeerAddress, peer_id::PeerId, socket_address::SocketAddress, IsGlobalIp, }, }; use tokio::sync::mpsc; @@ -48,7 +48,7 @@ use crate::{dns_server::DnsServerCommand, error::DnsServerError}; use self::storage::{DnsServerStorage, DnsServerStorageRead, DnsServerStorageWrite}; -use super::crawler::{Crawler, CrawlerCommand, CrawlerEvent}; +use super::crawler::{Crawler, CrawlerCommand, CrawlerConfig, CrawlerEvent}; /// How often the server performs maintenance (tries to connect to new nodes) const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); @@ -91,14 +91,31 @@ pub struct CrawlerManager { dns_server_cmd_tx: mpsc::UnboundedSender, } +// Note: "pub" access is only needed because of the "load_storage_for_tests" function. +pub struct LoadedStorage { + pub known_addresses: BTreeSet, + pub banned_addresses: BTreeMap, +} + +impl LoadedStorage { + pub fn new() -> Self { + Self { + known_addresses: BTreeSet::new(), + banned_addresses: BTreeMap::new(), + } + } +} + impl CrawlerManager where N::SyncingEventReceiver: SyncingEventReceiver, N::ConnectivityHandle: ConnectivityService, { + #[allow(clippy::too_many_arguments)] pub fn new( time_getter: TimeGetter, config: CrawlerManagerConfig, + crawler_config: CrawlerConfig, chain_config: Arc, conn: N::ConnectivityHandle, sync: N::SyncingEventReceiver, @@ -108,7 +125,7 @@ where let last_crawler_timer = time_getter.get_time(); // Addresses that are stored in the DB as reachable - let loaded_addresses: BTreeSet = Self::load_storage(&storage)?; + let loaded_storage = Self::load_storage(&storage)?; // Addresses listed as reachable from the command line let reserved_addresses: BTreeSet = config @@ -119,7 +136,14 @@ where assert!(conn.local_addresses().is_empty()); - let crawler = Crawler::new(chain_config, loaded_addresses, reserved_addresses); + let crawler = Crawler::new( + last_crawler_timer, + chain_config, + crawler_config, + loaded_storage.known_addresses, + loaded_storage.banned_addresses, + reserved_addresses, + ); Ok(Self { time_getter, @@ -133,7 +157,7 @@ where }) } - fn load_storage(storage: &S) -> Result, DnsServerError> { + fn load_storage(storage: &S) -> Result { let tx = storage.transaction_ro()?; let version = tx.get_version()?; tx.close(); @@ -145,18 +169,28 @@ where } } - fn init_storage(storage: &S) -> Result, DnsServerError> { + fn init_storage(storage: &S) -> Result { let mut tx = storage.transaction_rw()?; tx.set_version(STORAGE_VERSION)?; tx.commit()?; - Ok(BTreeSet::new()) + Ok(LoadedStorage::new()) } - fn load_storage_v1(storage: &S) -> Result, DnsServerError> { + fn load_storage_v1(storage: &S) -> Result { let tx = storage.transaction_ro()?; - let addresses = + let known_addresses = tx.get_addresses()?.iter().filter_map(|address| address.parse().ok()).collect(); - Ok(addresses) + + let banned_addresses = tx + .get_banned_addresses()? + .iter() + .filter_map(|(address, ban_until)| address.parse().ok().map(|addr| (addr, *ban_until))) + .collect(); + + Ok(LoadedStorage { + known_addresses, + banned_addresses, + }) } fn handle_conn_message(&mut self, peer_id: PeerId, message: PeerManagerMessage) { @@ -219,12 +253,8 @@ where ConnectivityEvent::ConnectionClosed { peer_id } => { self.send_crawler_event(CrawlerEvent::Disconnected { peer_id }); } - ConnectivityEvent::Misbehaved { - peer_id: _, - error: _, - } => { - // Ignore all misbehave reports - // TODO: Should we ban peers when they send unexpected messages? + ConnectivityEvent::Misbehaved { peer_id, error } => { + self.send_crawler_event(CrawlerEvent::Misbehaved { peer_id, error }); } } } @@ -309,6 +339,16 @@ where _ => {} } } + CrawlerCommand::MarkAsBanned { address, ban_until } => { + update_db(storage, |tx| { + tx.add_banned_address(&address.to_string(), ban_until) + }) + .expect("update_db must succeed (add_banned_address)"); + } + CrawlerCommand::RemoveBannedStatus { address } => { + update_db(storage, |tx| tx.del_banned_address(&address.to_string())) + .expect("update_db must succeed (del_banned_address)"); + } } } @@ -345,6 +385,11 @@ where } } } + + #[cfg(test)] + pub fn load_storage_for_tests(&self) -> Result { + Self::load_storage(&self.storage) + } } #[cfg(test)] diff --git a/dns_server/src/crawler_p2p/crawler_manager/storage.rs b/dns_server/src/crawler_p2p/crawler_manager/storage.rs index 84550ed6b1..ea06d02ed2 100644 --- a/dns_server/src/crawler_p2p/crawler_manager/storage.rs +++ b/dns_server/src/crawler_p2p/crawler_manager/storage.rs @@ -13,12 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use common::primitives::time::Time; use p2p::peer_manager::peerdb_common::{TransactionRo, TransactionRw, Transactional}; pub trait DnsServerStorageRead { fn get_version(&self) -> Result, storage::Error>; fn get_addresses(&self) -> Result, storage::Error>; + + fn get_banned_addresses(&self) -> Result, storage::Error>; } pub trait DnsServerStorageWrite { @@ -27,6 +30,10 @@ pub trait DnsServerStorageWrite { fn add_address(&mut self, address: &str) -> Result<(), storage::Error>; fn del_address(&mut self, address: &str) -> Result<(), storage::Error>; + + fn add_banned_address(&mut self, address: &str, time: Time) -> Result<(), storage::Error>; + + fn del_banned_address(&mut self, address: &str) -> Result<(), storage::Error>; } // Note: here we want to say something like: diff --git a/dns_server/src/crawler_p2p/crawler_manager/storage_impl.rs b/dns_server/src/crawler_p2p/crawler_manager/storage_impl.rs index f1f9e9400a..1fdf920f62 100644 --- a/dns_server/src/crawler_p2p/crawler_manager/storage_impl.rs +++ b/dns_server/src/crawler_p2p/crawler_manager/storage_impl.rs @@ -13,7 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::time::Duration; + use super::storage::{DnsServerStorage, DnsServerStorageRead, DnsServerStorageWrite}; +use common::primitives::time::Time; use p2p::peer_manager::peerdb_common::storage_impl::{StorageImpl, StorageTxRo, StorageTxRw}; use serialization::{encoded::Encoded, DecodeAll, Encode}; use storage::MakeMapRef; @@ -28,6 +31,9 @@ storage::decl_schema! { /// Table for all reachable addresses pub DBAddresses: Map, + + /// Table for banned addresses + pub DBBannedAddresses: Map, } } @@ -52,6 +58,16 @@ impl<'st, B: storage::Backend> DnsServerStorageWrite for DnsServerStoreTxRw<'st, fn del_address(&mut self, address: &str) -> Result<(), storage::Error> { self.storage().get_mut::().del(address) } + + fn add_banned_address(&mut self, address: &str, time: Time) -> Result<(), storage::Error> { + self.storage() + .get_mut::() + .put(address, time.as_duration_since_epoch()) + } + + fn del_banned_address(&mut self, address: &str) -> Result<(), storage::Error> { + self.storage().get_mut::().del(address) + } } impl<'st, B: storage::Backend> DnsServerStorageRead for DnsServerStoreTxRo<'st, B> { @@ -68,4 +84,12 @@ impl<'st, B: storage::Backend> DnsServerStorageRead for DnsServerStoreTxRo<'st, let iter = map.prefix_iter_decoded(&())?.map(|(addr, ())| addr); Ok(iter.collect::>()) } + + fn get_banned_addresses(&self) -> Result, storage::Error> { + let map = self.storage().get::(); + let iter = map + .prefix_iter_decoded(&())? + .map(|(addr, dur)| (addr, Time::from_duration_since_epoch(dur))); + Ok(iter.collect::>()) + } } diff --git a/dns_server/src/crawler_p2p/crawler_manager/tests/mock_manager.rs b/dns_server/src/crawler_p2p/crawler_manager/tests/mock_manager.rs index 85b6a5be06..f851856899 100644 --- a/dns_server/src/crawler_p2p/crawler_manager/tests/mock_manager.rs +++ b/dns_server/src/crawler_p2p/crawler_manager/tests/mock_manager.rs @@ -18,7 +18,7 @@ //! The mock simulates a network where peers go online and offline. use std::{ - collections::BTreeMap, + collections::{BTreeMap, BTreeSet}, sync::{Arc, Mutex}, time::Duration, }; @@ -31,12 +31,12 @@ use tokio::{ use common::{ chain::ChainConfig, - primitives::{semver::SemVer, user_agent::mintlayer_core_user_agent}, + primitives::{semver::SemVer, time::Time, user_agent::mintlayer_core_user_agent}, time_getter::TimeGetter, }; use p2p::{ - config::{NodeType, P2pConfig}, - error::{DialError, P2pError}, + config::{BanDuration, BanThreshold, NodeType, P2pConfig}, + error::{DialError, P2pError, ProtocolError}, message::{AnnounceAddrRequest, PeerManagerMessage}, net::{ types::{ConnectivityEvent, PeerInfo, SyncingEvent}, @@ -44,8 +44,8 @@ use p2p::{ }, testing_utils::TEST_PROTOCOL_VERSION, types::{ - ip_or_socket_address::IpOrSocketAddress, peer_id::PeerId, services::Services, - socket_address::SocketAddress, + bannable_address::BannableAddress, ip_or_socket_address::IpOrSocketAddress, + peer_id::PeerId, services::Services, socket_address::SocketAddress, }, P2pEventHandler, }; @@ -53,19 +53,31 @@ use p2p_test_utils::P2pBasicTestTimeGetter; use utils::atomics::SeqCstAtomicBool; use crate::{ - crawler_p2p::crawler_manager::{ - storage_impl::DnsServerStorageImpl, CrawlerManager, CrawlerManagerConfig, + crawler_p2p::{ + crawler::CrawlerConfig, + crawler_manager::{ + storage::DnsServerStorage, storage_impl::DnsServerStorageImpl, CrawlerManager, + CrawlerManagerConfig, + }, }, dns_server::DnsServerCommand, }; pub struct TestNode { pub chain_config: Arc, + /// If true, connecting to the node will produce a specific ConnectionError with a non-zero + /// ban score. + pub is_erratic: bool, +} + +/// The error part of ConnectionError that "erratic" nodes produce. +pub fn erratic_node_connection_error() -> P2pError { + P2pError::ProtocolError(ProtocolError::HandshakeExpected) } #[derive(Clone)] pub struct MockStateRef { - pub crawler_config: CrawlerManagerConfig, + pub crawler_mgr_config: CrawlerManagerConfig, pub online: Arc>>, pub connected: Arc>>, pub connection_attempts: Arc>>, @@ -74,10 +86,19 @@ pub struct MockStateRef { impl MockStateRef { pub fn node_online(&self, ip: SocketAddress) { + self.node_online_impl(ip, false) + } + + pub fn erratic_node_online(&self, ip: SocketAddress) { + self.node_online_impl(ip, true) + } + + fn node_online_impl(&self, ip: SocketAddress, is_erratic: bool) { let old = self.online.lock().unwrap().insert( ip, TestNode { chain_config: Arc::new(common::chain::config::create_mainnet()), + is_erratic, }, ); assert!(old.is_none()); @@ -102,6 +123,11 @@ impl MockStateRef { }) .unwrap(); } + + pub fn report_misbehavior(&self, ip: SocketAddress, error: P2pError) { + let peer_id = *self.connected.lock().unwrap().get(&ip).unwrap(); + self.conn_tx.send(ConnectivityEvent::Misbehaved { peer_id, error }).unwrap(); + } } #[derive(Debug, PartialEq, Eq)] @@ -144,35 +170,48 @@ impl NetworkingService for MockNetworkingService { impl ConnectivityService for MockConnectivityHandle { fn connect(&mut self, address: SocketAddress, _services: Option) -> p2p::Result<()> { self.state.connection_attempts.lock().unwrap().push(address); - if let Some(node) = self.state.online.lock().unwrap().get(&address) { - let peer_id = PeerId::new(); - let peer_info = PeerInfo { - peer_id, - protocol_version: TEST_PROTOCOL_VERSION, - network: *node.chain_config.magic_bytes(), - software_version: SemVer::new(1, 2, 3), - user_agent: mintlayer_core_user_agent(), - common_services: NodeType::DnsServer.into(), - }; - let old = self.state.connected.lock().unwrap().insert(address, peer_id); - assert!(old.is_none()); - self.state - .conn_tx - .send(ConnectivityEvent::OutboundAccepted { - address, - peer_info, - receiver_address: None, - }) - .unwrap(); - } else { - self.state - .conn_tx - .send(ConnectivityEvent::ConnectionError { - address, - error: P2pError::DialError(DialError::ConnectionRefusedOrTimedOut), - }) - .unwrap(); + match self.state.online.lock().unwrap().get(&address) { + None => { + self.state + .conn_tx + .send(ConnectivityEvent::ConnectionError { + address, + error: P2pError::DialError(DialError::ConnectionRefusedOrTimedOut), + }) + .unwrap(); + } + Some(node) if node.is_erratic => { + self.state + .conn_tx + .send(ConnectivityEvent::ConnectionError { + address, + error: erratic_node_connection_error(), + }) + .unwrap(); + } + Some(node) => { + let peer_id = PeerId::new(); + let peer_info = PeerInfo { + peer_id, + protocol_version: TEST_PROTOCOL_VERSION, + network: *node.chain_config.magic_bytes(), + software_version: SemVer::new(1, 2, 3), + user_agent: mintlayer_core_user_agent(), + common_services: NodeType::DnsServer.into(), + }; + let old = self.state.connected.lock().unwrap().insert(address, peer_id); + assert!(old.is_none()); + self.state + .conn_tx + .send(ConnectivityEvent::OutboundAccepted { + address, + peer_info, + receiver_address: None, + }) + .unwrap(); + } } + Ok(()) } @@ -191,6 +230,10 @@ impl ConnectivityService for MockConnectivityHandle { .unwrap() .0; self.state.connected.lock().unwrap().remove(&address).unwrap(); + self.state + .conn_tx + .send(ConnectivityEvent::ConnectionClosed { peer_id }) + .unwrap(); Ok(()) } @@ -227,13 +270,17 @@ pub fn test_crawler( .into_iter() .map(|addr| IpOrSocketAddress::new_socket_address(addr.socket_addr())) .collect(); - let crawler_config = CrawlerManagerConfig { + let crawler_mgr_config = CrawlerManagerConfig { reserved_nodes, default_p2p_port: 3031, }; + let crawler_config = CrawlerConfig { + ban_duration: BanDuration::default(), + ban_threshold: BanThreshold::default(), + }; let state = MockStateRef { - crawler_config: crawler_config.clone(), + crawler_mgr_config: crawler_mgr_config.clone(), online: Default::default(), connected: Default::default(), connection_attempts: Default::default(), @@ -256,6 +303,7 @@ pub fn test_crawler( let crawler = CrawlerManager::::new( time_getter.get_time_getter(), + crawler_mgr_config, crawler_config, chain_config, conn, @@ -288,3 +336,29 @@ pub async fn advance_time( .await .expect_err("run should not return"); } + +pub fn assert_known_addresses(crawler: &CrawlerManager, expected: &[SocketAddress]) +where + N: NetworkingService, + S: DnsServerStorage, + N::SyncingEventReceiver: SyncingEventReceiver, + N::ConnectivityHandle: ConnectivityService, +{ + let loaded_storage = crawler.load_storage_for_tests().unwrap(); + let expected: BTreeSet<_> = expected.iter().copied().collect(); + assert_eq!(loaded_storage.known_addresses, expected); +} + +pub fn assert_banned_addresses( + crawler: &CrawlerManager, + expected: &[(BannableAddress, Time)], +) where + N: NetworkingService, + S: DnsServerStorage, + N::SyncingEventReceiver: SyncingEventReceiver, + N::ConnectivityHandle: ConnectivityService, +{ + let loaded_storage = crawler.load_storage_for_tests().unwrap(); + let expected: BTreeMap<_, _> = expected.iter().copied().collect(); + assert_eq!(loaded_storage.banned_addresses, expected); +} diff --git a/dns_server/src/crawler_p2p/crawler_manager/tests/mod.rs b/dns_server/src/crawler_p2p/crawler_manager/tests/mod.rs index 7548db4c8a..0e0bcc2841 100644 --- a/dns_server/src/crawler_p2p/crawler_manager/tests/mod.rs +++ b/dns_server/src/crawler_p2p/crawler_manager/tests/mod.rs @@ -17,12 +17,17 @@ mod mock_manager; use std::time::Duration; -use p2p::{peer_manager::peerdb_common::Transactional, types::socket_address::SocketAddress}; +use chainstate::ban_score::BanScore; +use p2p::{ + config::{BanDuration, BanThreshold}, + types::socket_address::SocketAddress, +}; +use p2p_test_utils::{expect_no_recv, expect_recv}; use crate::{ - crawler_p2p::crawler_manager::{ - storage::DnsServerStorageRead, - tests::mock_manager::{advance_time, test_crawler}, + crawler_p2p::crawler_manager::tests::mock_manager::{ + advance_time, assert_banned_addresses, assert_known_addresses, + erratic_node_connection_error, test_crawler, }, dns_server::DnsServerCommand, }; @@ -36,7 +41,7 @@ async fn basic() { state.node_online(node1); advance_time(&mut crawler, &time_getter, Duration::from_secs(60), 60).await; assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::AddAddress(node1.socket_addr().ip()) ); @@ -44,7 +49,7 @@ async fn basic() { state.node_offline(node1); advance_time(&mut crawler, &time_getter, Duration::from_secs(60), 60).await; assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::DelAddress(node1.socket_addr().ip()) ); } @@ -67,7 +72,7 @@ async fn long_offline() { state.node_online(node1); advance_time(&mut crawler, &time_getter, Duration::from_secs(60), 24 * 60).await; assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::AddAddress(node1.socket_addr().ip()) ); } @@ -85,29 +90,25 @@ async fn announced_online() { advance_time(&mut crawler, &time_getter, Duration::from_secs(60), 60).await; assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::AddAddress(node1.socket_addr().ip()) ); state.announce_address(node1, node2); advance_time(&mut crawler, &time_getter, Duration::from_secs(60), 60).await; assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::AddAddress(node2.socket_addr().ip()) ); state.announce_address(node2, node3); advance_time(&mut crawler, &time_getter, Duration::from_secs(60), 60).await; assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::AddAddress(node3.socket_addr().ip()) ); - let addresses = crawler.storage.transaction_ro().unwrap().get_addresses().unwrap(); - assert_eq!( - addresses, - vec![node1.to_string(), node2.to_string(), node3.to_string()] - ); + assert_known_addresses(&crawler, &[node1, node2, node3]); } #[tokio::test] @@ -120,7 +121,7 @@ async fn announced_offline() { advance_time(&mut crawler, &time_getter, Duration::from_secs(60), 60).await; assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::AddAddress(node1.socket_addr().ip()) ); assert_eq!(state.connection_attempts.lock().unwrap().len(), 1); @@ -135,7 +136,7 @@ async fn announced_offline() { state.announce_address(node1, node2); advance_time(&mut crawler, &time_getter, Duration::from_secs(60), 24 * 60).await; assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::AddAddress(node2.socket_addr().ip()) ); assert_eq!(state.connection_attempts.lock().unwrap().len(), 3); @@ -163,26 +164,117 @@ async fn private_ip() { // Check that only nodes with public addresses and on the default port are added to DNS assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::AddAddress(node1.socket_addr().ip()) ); assert_eq!( - command_rx.recv().await.unwrap(), + expect_recv!(command_rx), DnsServerCommand::AddAddress(node2.socket_addr().ip()) ); - assert!(command_rx.try_recv().is_err()); + expect_no_recv!(command_rx); // Check that all reachable nodes are stored in the DB - let mut addresses = crawler.storage.transaction_ro().unwrap().get_addresses().unwrap(); - let mut addresses_expected = vec![ - node1.to_string(), - node2.to_string(), - node3.to_string(), - node4.to_string(), - node5.to_string(), - node6.to_string(), - ]; - addresses.sort(); - addresses_expected.sort(); - assert_eq!(addresses, addresses_expected); + assert_known_addresses(&crawler, &[node1, node2, node3, node4, node5, node6]); +} + +#[tokio::test] +async fn ban_unban() { + let node1: SocketAddress = "1.2.3.4:3031".parse().unwrap(); + let node2: SocketAddress = "2.3.4.5:3031".parse().unwrap(); + let node3: SocketAddress = "3.4.5.6:3031".parse().unwrap(); + + let (mut crawler, state, mut command_rx, time_getter) = test_crawler(vec![node1, node2, node3]); + + // Sanity check + assert!(erratic_node_connection_error().ban_score() >= *BanThreshold::default()); + + let ban_duration = *BanDuration::default(); + + state.node_online(node1); + state.erratic_node_online(node2); + state.node_online(node3); + + let time_step = Duration::from_secs(60); + + advance_time(&mut crawler, &time_getter, time_step, 1).await; + + let node2_ban_end_time = (time_getter.get_time_getter().get_time() + ban_duration).unwrap(); + + // Only normal nodes are added to DNS + assert_eq!( + expect_recv!(command_rx), + DnsServerCommand::AddAddress(node1.socket_addr().ip()) + ); + assert_eq!( + expect_recv!(command_rx), + DnsServerCommand::AddAddress(node3.socket_addr().ip()) + ); + expect_no_recv!(command_rx); + + // node2 is banned + assert_banned_addresses(&crawler, &[(node2.as_bannable(), node2_ban_end_time)]); + + advance_time(&mut crawler, &time_getter, time_step, 1).await; + + // Report misbehavior for node1; the passed error has big enough ban score, so the node should + // be banned immediately. + state.report_misbehavior(node1, erratic_node_connection_error()); + + advance_time(&mut crawler, &time_getter, time_step, 1).await; + + let node1_ban_end_time = (time_getter.get_time_getter().get_time() + ban_duration).unwrap(); + + // Check that it's been removed from DNS. + assert_eq!( + expect_recv!(command_rx), + DnsServerCommand::DelAddress(node1.socket_addr().ip()) + ); + + // Both bad nodes are now banned. + assert_banned_addresses( + &crawler, + &[ + (node1.as_bannable(), node1_ban_end_time), + (node2.as_bannable(), node2_ban_end_time), + ], + ); + + // Node 2 comes online again and now it'll behave correctly. This shouldn't have any immediate effect though. + state.node_offline(node2); + state.node_online(node2); + + // Wait some more time, the nodes should still be banned. + advance_time(&mut crawler, &time_getter, time_step, 1).await; + assert_banned_addresses( + &crawler, + &[ + (node1.as_bannable(), node1_ban_end_time), + (node2.as_bannable(), node2_ban_end_time), + ], + ); + expect_no_recv!(command_rx); + + // Wait enough time for node2 to be unbanned. + let time_until_node2_unban = + (node2_ban_end_time - time_getter.get_time_getter().get_time()).unwrap(); + advance_time(&mut crawler, &time_getter, time_until_node2_unban, 1).await; + + // node2 is no longer banned; its address has been added to DNS. + assert_banned_addresses(&crawler, &[(node1.as_bannable(), node1_ban_end_time)]); + assert_eq!( + expect_recv!(command_rx), + DnsServerCommand::AddAddress(node2.socket_addr().ip()) + ); + + // Wait enough time for node1 to be unbanned. + let time_until_node1_unban = + (node1_ban_end_time - time_getter.get_time_getter().get_time()).unwrap(); + advance_time(&mut crawler, &time_getter, time_until_node1_unban, 1).await; + + // node1 is no longer banned; its address has been added to DNS. + assert_banned_addresses(&crawler, &[]); + assert_eq!( + expect_recv!(command_rx), + DnsServerCommand::AddAddress(node1.socket_addr().ip()) + ); } diff --git a/dns_server/src/main.rs b/dns_server/src/main.rs index d35db75821..a7626c0f35 100644 --- a/dns_server/src/main.rs +++ b/dns_server/src/main.rs @@ -31,6 +31,8 @@ use p2p::{ use utils::atomics::SeqCstAtomicBool; use utils::default_data_dir::{default_data_dir_for_chain, prepare_data_dir}; +use crate::crawler_p2p::crawler::CrawlerConfig; + mod config; mod crawler_p2p; mod dns_server; @@ -107,13 +109,19 @@ async fn run(config: Arc) -> Result::new( time_getter, + crawler_mgr_config, crawler_config, chain_config, conn, diff --git a/p2p/p2p-test-utils/src/lib.rs b/p2p/p2p-test-utils/src/lib.rs index 4a5e97fba6..89e5f5a43b 100644 --- a/p2p/p2p-test-utils/src/lib.rs +++ b/p2p/p2p-test-utils/src/lib.rs @@ -115,3 +115,46 @@ impl P2pBasicTestTimeGetter { self.current_time_millis.fetch_add(duration.as_millis() as u64); } } + +/// A timeout for blocking calls. +pub const LONG_TIMEOUT: Duration = Duration::from_secs(60); +/// A short timeout for events that shouldn't occur. +pub const SHORT_TIMEOUT: Duration = Duration::from_millis(500); + +/// Await for the specified future for some reasonably big amount of time; panic if the timeout +/// is reached. +// Note: this is implemented as a macro until #[track_caller] works correctly with async functions +// (needed to print the caller location if 'unwrap' fails). Same for the other macros below. +#[macro_export] +macro_rules! expect_future_val { + ($fut:expr) => { + tokio::time::timeout($crate::LONG_TIMEOUT, $fut) + .await + .expect("Failed to receive value in time") + }; +} + +/// Await for the specified future for a short time, expecting a timeout. +#[macro_export] +macro_rules! expect_no_future_val { + ($fut:expr) => { + tokio::time::timeout($crate::SHORT_TIMEOUT, $fut).await.unwrap_err(); + }; +} + +/// Try receiving a message from the tokio channel; panic if the channel is closed or the timeout +/// is reached. +#[macro_export] +macro_rules! expect_recv { + ($rx:expr) => { + $crate::expect_future_val!($rx.recv()).unwrap() + }; +} + +/// Try receiving a message from the tokio channel; expect that a timeout is reached. +#[macro_export] +macro_rules! expect_no_recv { + ($rx:expr) => { + $crate::expect_no_future_val!($rx.recv()) + }; +} diff --git a/p2p/src/config.rs b/p2p/src/config.rs index 430fba8afd..552186192e 100644 --- a/p2p/src/config.rs +++ b/p2p/src/config.rs @@ -121,10 +121,8 @@ pub struct P2pConfig { pub max_message_size: MaxMessageSize, /// A maximum number of announcements (hashes) for which we haven't receive transactions. pub max_peer_tx_announcements: MaxPeerTxAnnouncements, - /// A maximum number of singular unconnected headers that a peer can send before + /// A maximum number of singular unconnected headers that a V1 peer can send before /// it will be considered malicious. - // TODO: this is a legacy behavior that should be removed in the protocol v2. - // See the issue #1110. pub max_singular_unconnected_headers: MaxUnconnectedHeaders, /// A timeout after which a peer is disconnected. pub sync_stalling_timeout: SyncStallingTimeout, diff --git a/p2p/src/peer_manager/tests/mod.rs b/p2p/src/peer_manager/tests/mod.rs index 5d9e5ffece..175a41c00b 100644 --- a/p2p/src/peer_manager/tests/mod.rs +++ b/p2p/src/peer_manager/tests/mod.rs @@ -22,6 +22,7 @@ mod utils; use std::{sync::Arc, time::Duration}; +use p2p_test_utils::expect_recv; use tokio::sync::{mpsc, oneshot}; use ::utils::atomics::SeqCstAtomicBool; @@ -35,7 +36,6 @@ use tokio::{ }; use crate::{ - expect_recv, interface::types::ConnectedPeer, message::{PeerManagerMessage, PingRequest, PingResponse}, net::{ diff --git a/p2p/src/peer_manager/tests/ping.rs b/p2p/src/peer_manager/tests/ping.rs index f9fd3cb33c..990c950d68 100644 --- a/p2p/src/peer_manager/tests/ping.rs +++ b/p2p/src/peer_manager/tests/ping.rs @@ -16,12 +16,11 @@ use std::{sync::Arc, time::Duration}; use common::{chain::config, primitives::user_agent::mintlayer_core_user_agent}; -use p2p_test_utils::P2pBasicTestTimeGetter; +use p2p_test_utils::{expect_recv, P2pBasicTestTimeGetter}; use test_utils::{assert_matches, assert_matches_return_val}; use crate::{ config::{NodeType, P2pConfig}, - expect_recv, message::{PeerManagerMessage, PingRequest, PingResponse}, net::{ default_backend::{ diff --git a/p2p/src/sync/tests/helpers/mod.rs b/p2p/src/sync/tests/helpers/mod.rs index ee4c2ee3d3..74a32bf277 100644 --- a/p2p/src/sync/tests/helpers/mod.rs +++ b/p2p/src/sync/tests/helpers/mod.rs @@ -13,10 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeMap, panic, sync::Arc, time::Duration}; +use std::{collections::BTreeMap, panic, sync::Arc}; use async_trait::async_trait; use crypto::random::Rng; +use p2p_test_utils::{expect_future_val, expect_no_recv, expect_recv, SHORT_TIMEOUT}; use p2p_types::socket_address::SocketAddress; use test_utils::random::Seed; use tokio::{ @@ -64,11 +65,6 @@ use crate::{ pub mod test_node_group; -/// A timeout for blocking calls. -const LONG_TIMEOUT: Duration = Duration::from_secs(60); -/// A short timeout for events that shouldn't occur. -const SHORT_TIMEOUT: Duration = Duration::from_millis(500); - /// A wrapper over other ends of the sync manager channels that simulates a test node. /// /// Provides methods for manipulating and observing the sync manager state. @@ -216,10 +212,7 @@ impl TestNode { /// Receives a message from the sync manager. pub async fn message(&mut self) -> (PeerId, SyncMessage) { - time::timeout(LONG_TIMEOUT, self.sync_msg_receiver.recv()) - .await - .expect("Failed to receive event in time") - .unwrap() + expect_recv!(&mut self.sync_msg_receiver) } /// Try to receive a message from the sync manager. @@ -238,7 +231,7 @@ impl TestNode { /// Panics if the sync manager returns an error. pub async fn assert_no_error(&mut self) { - time::timeout(SHORT_TIMEOUT, self.error_receiver.recv()).await.unwrap_err(); + expect_no_recv!(self.error_receiver); } /// Receives the `AdjustPeerScore` event from the peer manager. @@ -507,10 +500,7 @@ struct SyncingEventReceiverMock { #[async_trait] impl SyncingEventReceiver for SyncingEventReceiverMock { async fn poll_next(&mut self) -> Result { - time::timeout(LONG_TIMEOUT, self.events_receiver.recv()) - .await - .expect("Failed to receive event in time") - .ok_or(P2pError::ChannelClosed) + expect_future_val!(self.events_receiver.recv()).ok_or(P2pError::ChannelClosed) } } diff --git a/p2p/src/sync/tests/helpers/test_node_group.rs b/p2p/src/sync/tests/helpers/test_node_group.rs index 21fb9833a2..689595c489 100644 --- a/p2p/src/sync/tests/helpers/test_node_group.rs +++ b/p2p/src/sync/tests/helpers/test_node_group.rs @@ -17,12 +17,12 @@ use common::{chain::GenBlock, primitives::Id}; use crypto::random::Rng; use futures::{future::select_all, FutureExt}; use logging::log; +use p2p_test_utils::LONG_TIMEOUT; use p2p_types::PeerId; use tokio::time; use crate::{ message::{SyncMessage, TransactionResponse}, - sync::tests::helpers::LONG_TIMEOUT, PeerManagerEvent, }; diff --git a/p2p/src/testing_utils.rs b/p2p/src/testing_utils.rs index 6f353ba8fa..7b8518aa75 100644 --- a/p2p/src/testing_utils.rs +++ b/p2p/src/testing_utils.rs @@ -225,17 +225,6 @@ pub fn peerdb_inmemory_store() -> PeerDbStorageImpl PeerDbStorageImpl::new(storage).unwrap() } -/// Receive a message from the tokio channel. -/// Panics if the channel is closed or no message received in 10 seconds. -#[macro_export] -macro_rules! expect_recv { - // Implemented as a macro until #[track_caller] works correctly with async functions - // (needed to print the caller location if unwraps fail) - ($x:expr) => { - tokio::time::timeout(Duration::from_secs(10), $x.recv()).await.unwrap().unwrap() - }; -} - pub fn test_p2p_config() -> P2pConfig { P2pConfig { bind_addresses: Default::default(),