From 52bbf262652c04319f78ed594fcbc41924b44d99 Mon Sep 17 00:00:00 2001 From: Rob N Date: Sat, 13 Jul 2024 10:58:25 -1000 Subject: [PATCH] peer: clean up reader and peer --- src/peers/counter.rs | 1 + src/peers/outbound_messages.rs | 8 ++---- src/peers/peer.rs | 8 ++++-- src/peers/reader.rs | 52 ++++++++++++++-------------------- src/peers/traits.rs | 8 ++---- 5 files changed, 32 insertions(+), 45 deletions(-) diff --git a/src/peers/counter.rs b/src/peers/counter.rs index ce78c30..ae3e025 100644 --- a/src/peers/counter.rs +++ b/src/peers/counter.rs @@ -133,6 +133,7 @@ mod tests { use super::MessageTimer; #[tokio::test] + #[ignore = "time wasting"] async fn test_timer_works() { let mut timer = MessageTimer::new(); assert!(!timer.unresponsive()); diff --git a/src/peers/outbound_messages.rs b/src/peers/outbound_messages.rs index 4549425..b7bf084 100644 --- a/src/peers/outbound_messages.rs +++ b/src/peers/outbound_messages.rs @@ -64,16 +64,12 @@ impl MessageGenerator for V1OutboundMessage { serialize(&data) } - fn get_addr(&mut self) -> Vec { + fn addr(&mut self) -> Vec { let data = RawNetworkMessage::new(self.network.magic(), NetworkMessage::GetAddr); serialize(&data) } - fn get_headers( - &mut self, - locator_hashes: Vec, - stop_hash: Option, - ) -> Vec { + fn headers(&mut self, locator_hashes: Vec, stop_hash: Option) -> Vec { let msg = GetHeadersMessage::new(locator_hashes, stop_hash.unwrap_or(BlockHash::all_zeros())); let data = diff --git a/src/peers/peer.rs b/src/peers/peer.rs index 4fec27d..5689873 100644 --- a/src/peers/peer.rs +++ b/src/peers/peer.rs @@ -83,6 +83,7 @@ impl Peer { .write_all(&version_message) .await .map_err(|_| PeerError::BufferWrite)?; + self.message_timer.track(); let (reader, mut writer) = stream.into_split(); let (tx, mut rx) = mpsc::channel(32); let mut peer_reader = Reader::new(reader, tx, self.network); @@ -104,7 +105,7 @@ impl Peer { } select! { // The peer sent us a message - peer_message = tokio::time::timeout(Duration::from_secs(CONNECTION_TIMEOUT), rx.recv())=> { + peer_message = tokio::time::timeout(Duration::from_secs(CONNECTION_TIMEOUT), rx.recv()) => { if let Ok(peer_message) = peer_message { match peer_message { Some(message) => { @@ -157,6 +158,7 @@ impl Peer { match message { PeerMessage::Version(version) => { self.message_counter.got_version(); + self.message_timer.untrack(); self.main_thread_sender .send(PeerThreadMessage { nonce: self.nonce, @@ -288,14 +290,14 @@ impl Peer { MainThreadMessage::GetAddr => { self.message_counter.sent_addrs(); writer - .write_all(&message_generator.get_addr()) + .write_all(&message_generator.addr()) .await .map_err(|_| PeerError::BufferWrite)?; } MainThreadMessage::GetHeaders(config) => { self.message_counter.sent_header(); self.message_timer.track(); - let message = message_generator.get_headers(config.locators, config.stop_hash); + let message = message_generator.headers(config.locators, config.stop_hash); writer .write_all(&message) .await diff --git a/src/peers/reader.rs b/src/peers/reader.rs index 7ab9fca..3c5ed38 100644 --- a/src/peers/reader.rs +++ b/src/peers/reader.rs @@ -1,6 +1,3 @@ -use std::time::SystemTime; -use std::time::UNIX_EPOCH; - use bitcoin::consensus::deserialize; use bitcoin::consensus::deserialize_partial; use bitcoin::consensus::Decodable; @@ -24,14 +21,13 @@ use crate::node::messages::RejectPayload; const ONE_MONTH: u64 = 2_500_000; const ONE_MINUTE: u64 = 60; -// The peer must have sent at least 10 messages to trigger DOS -const MINIMUM_DOS_THRESHOLD: u64 = 10; -// We allow up to 5000 messages per second -const RATE_LIMIT: u64 = 5000; +const MAX_MESSAGE_BYTES: u32 = 1024 * 1024 * 32; +// From Bitcoin Core PR #29575 +const MAX_ADDR: usize = 1_000; +const MAX_INV: usize = 50_000; +const MAX_HEADERS: usize = 2_000; pub(crate) struct Reader { - num_messages: u64, - start_time: u64, stream: OwnedReadHalf, tx: Sender, network: Network, @@ -39,13 +35,7 @@ pub(crate) struct Reader { impl Reader { pub fn new(stream: OwnedReadHalf, tx: Sender, network: Network) -> Self { - let start_time = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time went backwards") - .as_secs(); Self { - num_messages: 0, - start_time, stream, tx, network, @@ -54,7 +44,7 @@ impl Reader { pub(crate) async fn read_from_remote(&mut self) -> Result<(), PeerReadError> { loop { - // v1 headers are 24 bytes + // V1 headers are 24 bytes let mut message_buf = vec![0_u8; 24]; let _ = self .stream @@ -69,21 +59,9 @@ impl Reader { return Err(PeerReadError::Deserialization); } // Message is too long - if header.length > (1024 * 1024 * 32) as u32 { + if header.length > MAX_MESSAGE_BYTES { return Err(PeerReadError::Deserialization); } - // DOS protection - self.num_messages += 1; - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time went backwards") - .as_secs(); - let duration = now - self.start_time; - if self.num_messages > MINIMUM_DOS_THRESHOLD - && self.num_messages.checked_div(duration).unwrap_or(0) > RATE_LIMIT - { - return Err(PeerReadError::TooManyMessages); - } let mut contents_buf = vec![0_u8; header.length as usize]; let _ = self.stream.read_exact(&mut contents_buf).await.unwrap(); message_buf.extend_from_slice(&contents_buf); @@ -111,6 +89,9 @@ fn parse_message(message: &NetworkMessage) -> Option { })), NetworkMessage::Verack => Some(PeerMessage::Verack), NetworkMessage::Addr(addresses) => { + if addresses.len() > MAX_ADDR { + return Some(PeerMessage::Disconnect); + } let addresses: Vec
= addresses .iter() .filter(|f| f.1.services.has(ServiceFlags::COMPACT_FILTERS)) @@ -120,6 +101,9 @@ fn parse_message(message: &NetworkMessage) -> Option { Some(PeerMessage::Addr(addresses)) } NetworkMessage::Inv(inventory) => { + if inventory.len() > MAX_INV { + return Some(PeerMessage::Disconnect); + } let mut hashes = Vec::new(); for i in inventory { match i { @@ -142,7 +126,12 @@ fn parse_message(message: &NetworkMessage) -> Option { NetworkMessage::MemPool => None, NetworkMessage::Tx(_) => None, NetworkMessage::Block(block) => Some(PeerMessage::Block(block.clone())), - NetworkMessage::Headers(headers) => Some(PeerMessage::Headers(headers.clone())), + NetworkMessage::Headers(headers) => { + if headers.len() > MAX_HEADERS { + return Some(PeerMessage::Disconnect); + } + Some(PeerMessage::Headers(headers.clone())) + } NetworkMessage::SendHeaders => None, NetworkMessage::GetAddr => None, NetworkMessage::Ping(nonce) => Some(PeerMessage::Ping(*nonce)), @@ -174,6 +163,9 @@ fn parse_message(message: &NetworkMessage) -> Option { NetworkMessage::FeeFilter(_) => None, NetworkMessage::WtxidRelay => None, NetworkMessage::AddrV2(addresses) => { + if addresses.len() > MAX_ADDR { + return Some(PeerMessage::Disconnect); + } let addresses: Vec
= addresses .iter() .filter(|f| f.services.has(ServiceFlags::COMPACT_FILTERS)) diff --git a/src/peers/traits.rs b/src/peers/traits.rs index 8fa379b..9145c8c 100644 --- a/src/peers/traits.rs +++ b/src/peers/traits.rs @@ -10,13 +10,9 @@ pub(crate) trait MessageGenerator { fn verack(&mut self) -> Vec; - fn get_addr(&mut self) -> Vec; + fn addr(&mut self) -> Vec; - fn get_headers( - &mut self, - locator_hashes: Vec, - stop_hash: Option, - ) -> Vec; + fn headers(&mut self, locator_hashes: Vec, stop_hash: Option) -> Vec; fn cf_headers(&mut self, message: GetCFHeaders) -> Vec;