diff --git a/p2p/src/peer_manager/mod.rs b/p2p/src/peer_manager/mod.rs index 08128a4625..342c6b5241 100644 --- a/p2p/src/peer_manager/mod.rs +++ b/p2p/src/peer_manager/mod.rs @@ -345,7 +345,29 @@ where ); if peer.score >= *self.p2p_config.ban_threshold { - self.peerdb.ban(peer.address.as_bannable()); + let address = peer.address.as_bannable(); + self.ban(address); + } + } + + fn ban(&mut self, address: T::BannableAddress) { + let to_disconnect = self + .peers + .values() + .filter_map(|peer| { + if peer.address.as_bannable() == address { + Some(peer.info.peer_id) + } else { + None + } + }) + .collect::<Vec<_>>(); + + log::info!("Ban {:?}, disconnect peers: {:?}", address, to_disconnect); + + self.peerdb.ban(address); + + for peer_id in to_disconnect { self.disconnect(peer_id, None); } } @@ -1032,7 +1054,7 @@ where response.send(self.peerdb.list_banned().cloned().collect()) } PeerManagerEvent::Ban(address, response) => { - self.peerdb.ban(address); + self.ban(address); response.send(Ok(())); } PeerManagerEvent::Unban(address, response) => { diff --git a/p2p/src/peer_manager/tests/ban.rs b/p2p/src/peer_manager/tests/ban.rs index 8290c85353..8842338e23 100644 --- a/p2p/src/peer_manager/tests/ban.rs +++ b/p2p/src/peer_manager/tests/ban.rs @@ -16,20 +16,28 @@ use std::sync::Arc; use crate::{ - net::types::{services::Service, Role}, + config::NodeType, + net::{ + default_backend::{types::Command, ConnectivityHandle}, + types::{services::Service, PeerInfo, Role}, + }, + peer_manager::PeerManager, protocol::{NETWORK_PROTOCOL_CURRENT, NETWORK_PROTOCOL_MIN}, testing_utils::{ - connect_and_accept_services, connect_services, get_connectivity_event, RandomAddressMaker, - TestChannelAddressMaker, TestTcpAddressMaker, TestTransportChannel, TestTransportMaker, - TestTransportNoise, TestTransportTcp, + connect_and_accept_services, connect_services, get_connectivity_event, + peerdb_inmemory_store, test_p2p_config, RandomAddressMaker, TestChannelAddressMaker, + TestTcpAddressMaker, TestTransportChannel, TestTransportMaker, TestTransportNoise, + TestTransportTcp, }, types::peer_id::PeerId, utils::oneshot_nofail, + PeerManagerEvent, }; use common::{ chain::config, primitives::{semver::SemVer, user_agent::mintlayer_core_user_agent}, }; +use p2p_test_utils::P2pBasicTestTimeGetter; use crate::{ error::{P2pError, PeerError}, @@ -387,3 +395,67 @@ async fn inbound_connection_invalid_magic_noise() { >() .await; } + +// Test that manually banned peers are also disconnected +#[test] +fn ban_and_disconnect() { + type TestNetworkingService = DefaultNetworkingService<TcpTransportSocket>; + + let chain_config = Arc::new(config::create_mainnet()); + let p2p_config = Arc::new(test_p2p_config()); + let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel(); + let (_conn_tx, conn_rx) = tokio::sync::mpsc::unbounded_channel(); + let (_peer_tx, peer_rx) = + tokio::sync::mpsc::unbounded_channel::<PeerManagerEvent<TestNetworkingService>>(); + let time_getter = P2pBasicTestTimeGetter::new(); + let connectivity_handle = ConnectivityHandle::<TestNetworkingService, TcpTransportSocket>::new( + vec![], + cmd_tx, + conn_rx, + ); + + let mut pm = PeerManager::new( + Arc::clone(&chain_config), + Arc::clone(&p2p_config), + connectivity_handle, + peer_rx, + time_getter.get_time_getter(), + peerdb_inmemory_store(), + ) + .unwrap(); + + let peer_id_1 = PeerId::new(); + let address_1 = TestTcpAddressMaker::new(); + let peer_info = PeerInfo { + peer_id: peer_id_1, + protocol: NETWORK_PROTOCOL_CURRENT, + network: *chain_config.magic_bytes(), + version: *chain_config.version(), + user_agent: mintlayer_core_user_agent(), + services: NodeType::Full.into(), + }; + pm.accept_connection(address_1, Role::Inbound, peer_info, None); + assert_eq!(pm.peers.len(), 1); + + // Peer is accepted by the peer manager + match cmd_rx.try_recv() { + Ok(Command::Accept { peer_id }) if peer_id == peer_id_1 => {} + v => panic!("unexpected command: {v:?}"), + } + + let (ban_tx, mut ban_rx) = oneshot_nofail::channel(); + pm.handle_control_event(PeerManagerEvent::Ban(address_1.as_bannable(), ban_tx)); + ban_rx.try_recv().unwrap().unwrap(); + + // Peer is disconnected by the peer manager + match cmd_rx.try_recv() { + Ok(Command::Disconnect { peer_id }) if peer_id == peer_id_1 => {} + v => panic!("unexpected command: {v:?}"), + } + + // No more messages + match cmd_rx.try_recv() { + Err(_) => {} + v => panic!("unexpected command: {v:?}"), + } +} diff --git a/p2p/src/utils/oneshot_nofail.rs b/p2p/src/utils/oneshot_nofail.rs index 6f201f6733..a03f5b7575 100644 --- a/p2p/src/utils/oneshot_nofail.rs +++ b/p2p/src/utils/oneshot_nofail.rs @@ -25,7 +25,7 @@ use std::{ task::{Context, Poll}, }; -use tokio::sync::oneshot; +use tokio::sync::oneshot::{self, error::TryRecvError}; #[derive(Debug)] pub struct Sender<T>(oneshot::Sender<T>); @@ -47,6 +47,12 @@ impl<T> Future for Receiver<T> { } } +impl<T> Receiver<T> { + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.0.try_recv() + } +} + pub fn channel<T>() -> (Sender<T>, Receiver<T>) { let (sender, receiver) = oneshot::channel(); (Sender(sender), Receiver(receiver))