diff --git a/crates/floresta-wire/src/p2p_wire/peer.rs b/crates/floresta-wire/src/p2p_wire/peer.rs index b2ca2c32..ed32b239 100644 --- a/crates/floresta-wire/src/p2p_wire/peer.rs +++ b/crates/floresta-wire/src/p2p_wire/peer.rs @@ -27,13 +27,19 @@ use bitcoin::{ BlockHash, BlockHeader, Network, Transaction, }; use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt}; -use log::{debug, error, warn}; +use log::{error, warn}; use std::{ fmt::Debug, sync::Arc, time::{Duration, Instant}, }; +/// If we send a ping, and our peer takes more than PING_TIMEOUT to +/// reply, disconnect. +const PING_TIMEOUT: u64 = 30; +/// If the last message we've got was more than XX, send out a ping +const SEND_PING_TIMEOUT: u64 = 60; + #[derive(Debug, PartialEq)] enum State { None, @@ -65,8 +71,9 @@ pub struct Peer { user_agent: String, messages: u64, start_time: Instant, + last_message: Instant, current_best_block: i32, - last_ping: Instant, + last_ping: Option, id: u32, node_tx: Sender, state: State, @@ -75,6 +82,7 @@ pub struct Peer { address_id: usize, feeler: bool, wants_addrv2: bool, + shutdown: bool, } #[derive(Debug, Error)] pub enum PeerError { @@ -92,6 +100,8 @@ pub enum PeerError { MagicBitsMismatch, #[error("Peer sent us too many message in a short period of time")] TooManyMessages, + #[error("Peer timed a ping out")] + Timeout, } impl Debug for Peer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -105,7 +115,7 @@ type Result = std::result::Result; impl Peer { pub async fn read_loop(mut self) -> Result<()> { let err = self.peer_loop_inner().await; - debug!("Peer connection loop closed: {err:?}"); + warn!("Peer connection loop closed: {err:?}"); self.send_to_node(PeerMessages::Disconnected(self.address_id)) .await; Ok(()) @@ -133,18 +143,40 @@ impl Peer { } } }; + if self.shutdown { + return Ok(()); + } + // If we send a ping and our peer doesn't respond in time, disconnect + if let Some(when) = self.last_ping { + if when.elapsed().as_secs() > PING_TIMEOUT { + return Err(PeerError::Timeout); + } + } + + // Send a ping to check if this peer is still good + let last_message = self.last_message.elapsed().as_secs(); + if last_message > SEND_PING_TIMEOUT { + let nonce = rand::random(); + self.last_ping = Some(Instant::now()); + self.write(NetworkMessage::Ping(nonce)).await?; + } + // divide the number of messages by the number of seconds we've been connected, // if it's more than 100 msg/sec, this peer is sending us too many messages, and we should // disconnect. - if self.messages > 0 - && self.messages / Instant::now().duration_since(self.start_time).as_secs() > 10 - { + let msg_sec = self + .messages + .checked_div(Instant::now().duration_since(self.start_time).as_secs()) + .unwrap_or(0); + + if msg_sec > 10 { error!( "Peer {} is sending us too many messages, disconnecting", self.id ); return Err(PeerError::TooManyMessages); } + if let State::SentVersion(when) = self.state { if Instant::now().duration_since(when) > Duration::from_secs(10) { return Err(PeerError::UnexpectedMessage); @@ -181,6 +213,7 @@ impl Peer { .await; } NodeRequest::Shutdown => { + self.shutdown = true; let _ = self.stream.shutdown(); } NodeRequest::GetAddresses => { @@ -201,6 +234,8 @@ impl Peer { Ok(()) } pub async fn handle_peer_message(&mut self, message: RawNetworkMessage) -> Result<()> { + self.last_message = Instant::now(); + match self.state { State::Connected => match message.payload { NetworkMessage::Inv(inv) => { @@ -262,10 +297,12 @@ impl Peer { self.wants_addrv2 = true; self.write(NetworkMessage::SendAddrV2).await?; } + NetworkMessage::Pong(_) => { + self.last_ping = None; + } NetworkMessage::Unknown { command, payload } => { warn!("Unknown message: {} {:?}", command, payload); } - // Explicitly ignore these messages, if something changes in the future // this would cause a compile error. NetworkMessage::Verack @@ -288,7 +325,6 @@ impl Peer { | NetworkMessage::GetCFilters(_) | NetworkMessage::MemPool | NetworkMessage::MerkleBlock(_) - | NetworkMessage::Pong(_) | NetworkMessage::SendCmpct(_) => {} }, State::None | State::SentVersion(_) => match message.payload { @@ -381,7 +417,8 @@ impl Peer { current_best_block: -1, id, mempool, - last_ping: Instant::now(), + last_ping: None, + last_message: Instant::now(), network, node_tx, services: ServiceFlags::NONE, @@ -394,6 +431,7 @@ impl Peer { node_requests, feeler, wants_addrv2: false, + shutdown: false, }; spawn(peer.read_loop()); } @@ -426,7 +464,8 @@ impl Peer { current_best_block: -1, id, mempool, - last_ping: Instant::now(), + last_ping: None, + last_message: Instant::now(), network, node_tx, services: ServiceFlags::NONE, @@ -439,11 +478,11 @@ impl Peer { node_requests, feeler, wants_addrv2: false, + shutdown: false, }; spawn(peer.read_loop()); } async fn handle_ping(&mut self, nonce: u64) -> Result<()> { - self.last_ping = Instant::now(); let pong = make_pong(nonce); self.write(pong).await }