diff --git a/network/src/service.rs b/network/src/service.rs index 3796572c13..e883c23a76 100644 --- a/network/src/service.rs +++ b/network/src/service.rs @@ -17,7 +17,7 @@ use network_api::messages::{ use network_api::peer_score::{BlockBroadcastEntry, HandleState, LinearScore, Score}; use network_api::{NetworkActor, PeerMessageHandler}; use network_p2p::{Event, NetworkWorker}; -use rand::RngCore; +use rand::prelude::SliceRandom; use starcoin_config::NodeConfig; use starcoin_crypto::HashValue; use starcoin_network_rpc::NetworkRpcService; @@ -572,15 +572,42 @@ fn select_random_peers( ) -> Vec { let (min_peers, max_peers) = peer_num_range.into_inner(); let peers_len = peers.len(); - // sqrt(x)/x scaled to max u32 - let fraction = ((peers_len as f64).powf(-0.5) * (u32::max_value() as f64).round()) as u32; - let small = peers_len < (min_peers as usize); + // take sqrt(x) peers + let mut count = (peers_len as f64).powf(0.5).round() as u32; + count = count.min(max_peers).max(min_peers); let mut random = rand::thread_rng(); - peers - .keys() - .cloned() - .filter(|_| small || random.next_u32() < fraction) - .take(max_peers as usize) - .collect() + let mut peer_ids: Vec<_> = peers.keys().cloned().collect(); + peer_ids.shuffle(&mut random); + peer_ids.truncate(count as usize); + peer_ids +} + +#[cfg(test)] +mod test { + use crate::service::{select_random_peers, Peer}; + use network_api::{PeerId, PeerInfo}; + use starcoin_types::startup_info::ChainInfo; + use std::collections::HashMap; + + fn create_peers(n: u32) -> HashMap { + (0..n) + .map(|_| { + let peer_id = PeerId::random(); + let peer = Peer::new(PeerInfo::new(peer_id.clone(), ChainInfo::random())); + (peer_id, peer) + }) + .collect() + } + + #[test] + fn test_select_peer() { + assert_eq!(select_random_peers(1..=3, &create_peers(2)).len(), 1); + assert_eq!(select_random_peers(2..=5, &create_peers(9)).len(), 3); + assert_eq!(select_random_peers(8..=128, &create_peers(3)).len(), 3); + assert_eq!(select_random_peers(8..=128, &create_peers(4)).len(), 4); + assert_eq!(select_random_peers(8..=128, &create_peers(10)).len(), 8); + assert_eq!(select_random_peers(8..=128, &create_peers(25)).len(), 8); + assert_eq!(select_random_peers(8..=128, &create_peers(64)).len(), 8); + } }