Skip to content

Commit

Permalink
Merge pull request #52 from rustaceanrob/gen-rw
Browse files Browse the repository at this point in the history
peers: make reader/writer generic
  • Loading branch information
rustaceanrob authored Aug 4, 2024
2 parents abb1569 + 9403ffe commit a06d6ca
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 45 deletions.
14 changes: 14 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,17 @@ impl From<TrustedPeer> for (IpAddr, Option<u16>) {
(value.ip(), value.port())
}
}

/// How to connect to peers on the peer-to-peer network
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ConnectionType {
/// Version one peer-to-peer connections
ClearNet,
}

impl Default for ConnectionType {
fn default() -> Self {
ConnectionType::ClearNet
}
}
73 changes: 34 additions & 39 deletions src/peers/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::{net::IpAddr, time::Duration};

use bitcoin::Network;
use tokio::{
io::AsyncWriteExt,
net::{tcp::OwnedWriteHalf, TcpStream},
io::{AsyncWrite, AsyncWriteExt},
net::TcpStream,
select,
sync::mpsc::{self, Receiver, Sender},
};
Expand Down Expand Up @@ -71,6 +71,7 @@ impl Peer {
)
.await
.map_err(|_| PeerError::TcpConnectionFailed)?;
// Replace with generalization
let mut stream: TcpStream;
if let Ok(tcp) = timeout {
stream = tcp;
Expand All @@ -86,11 +87,8 @@ impl Peer {
return Err(PeerError::TcpConnectionFailed);
}
let mut outbound_messages = V1OutboundMessage::new(self.network);
let version_message = outbound_messages.version_message(None);
stream
.write_all(&version_message)
.await
.map_err(|_| PeerError::BufferWrite)?;
let message = outbound_messages.version_message(None);
self.write_bytes(&mut stream, message).await?;
self.message_timer.track();
let (reader, mut writer) = stream.into_split();
let (tx, mut rx) = mpsc::channel(32);
Expand Down Expand Up @@ -156,14 +154,15 @@ impl Peer {
}
}

async fn handle_peer_message<M>(
async fn handle_peer_message<M, W>(
&mut self,
message: PeerMessage,
writer: &mut OwnedWriteHalf,
writer: &mut W,
message_generator: &mut M,
) -> Result<(), PeerError>
where
M: MessageGenerator,
W: AsyncWrite + Send + Sync + Unpin,
{
match message {
PeerMessage::Version(version) => {
Expand All @@ -175,10 +174,12 @@ impl Peer {
})
.await
.map_err(|_| PeerError::ThreadChannel)?;
// FIXME: Write this after confirming the peer has CBF.
writer
.write_all(&message_generator.verack())
.await
.map_err(|_| PeerError::BufferWrite)?;
writer.flush().await.map_err(|_| PeerError::BufferWrite)?;
Ok(())
}
PeerMessage::Addr(addrs) => {
Expand Down Expand Up @@ -254,10 +255,8 @@ impl Peer {
Ok(())
}
PeerMessage::Ping(nonce) => {
writer
.write_all(&message_generator.pong(nonce))
.await
.map_err(|_| PeerError::BufferWrite)?;
let message = message_generator.pong(nonce);
self.write_bytes(writer, message).await?;
Ok(())
}
PeerMessage::Pong(_) => Ok(()),
Expand Down Expand Up @@ -285,67 +284,63 @@ impl Peer {
}
}

async fn main_thread_request<M>(
async fn main_thread_request<M, W>(
&mut self,
request: MainThreadMessage,
writer: &mut OwnedWriteHalf,
writer: &mut W,
message_generator: &mut M,
) -> Result<(), PeerError>
where
M: MessageGenerator,
W: AsyncWrite + Send + Sync + Unpin,
{
match request {
MainThreadMessage::GetAddr => {
self.message_counter.sent_addrs();
writer
.write_all(&message_generator.addr())
.await
.map_err(|_| PeerError::BufferWrite)?;
let message = message_generator.addr();
self.write_bytes(writer, message).await?;
}
MainThreadMessage::GetHeaders(config) => {
self.message_timer.track();
let message = message_generator.headers(config.locators, config.stop_hash);
writer
.write_all(&message)
.await
.map_err(|_| PeerError::BufferWrite)?;
self.write_bytes(writer, message).await?;
}
MainThreadMessage::GetFilterHeaders(config) => {
self.message_counter.sent_filter_header();
self.message_timer.track();
let message = message_generator.cf_headers(config);
writer
.write_all(&message)
.await
.map_err(|_| PeerError::BufferWrite)?;
self.write_bytes(writer, message).await?;
}
MainThreadMessage::GetFilters(config) => {
self.message_counter.sent_filters();
let message = message_generator.filters(config);
writer
.write_all(&message)
.await
.map_err(|_| PeerError::BufferWrite)?;
self.write_bytes(writer, message).await?;
}
MainThreadMessage::GetBlock(message) => {
self.message_counter.sent_block();
self.message_timer.track();
let message = message_generator.block(message);
writer
.write_all(&message)
.await
.map_err(|_| PeerError::BufferWrite)?;
self.write_bytes(writer, message).await?;
}
MainThreadMessage::BroadcastTx(transaction) => {
self.message_counter.sent_tx();
let message = message_generator.transaction(transaction);
writer
.write_all(&message)
.await
.map_err(|_| PeerError::BufferWrite)?;
self.write_bytes(writer, message).await?;
}
MainThreadMessage::Disconnect => return Err(PeerError::DisconnectCommand),
}
Ok(())
}

async fn write_bytes<W>(&mut self, writer: &mut W, message: Vec<u8>) -> Result<(), PeerError>
where
W: AsyncWrite + Send + Sync + Unpin,
{
writer
.write_all(&message)
.await
.map_err(|_| PeerError::BufferWrite)?;
writer.flush().await.map_err(|_| PeerError::BufferWrite)?;
Ok(())
}
}
20 changes: 14 additions & 6 deletions src/peers/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use bitcoin::p2p::{
Address, Magic, ServiceFlags,
};
use bitcoin::{Network, Txid};
use tokio::io::AsyncReadExt;
use tokio::net::tcp::OwnedReadHalf;
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio::sync::mpsc::Sender;

use crate::node::channel_messages::{PeerMessage, RemoteVersion};
Expand All @@ -23,14 +22,23 @@ const MAX_ADDR: usize = 1_000;
const MAX_INV: usize = 50_000;
const MAX_HEADERS: usize = 2_000;

pub(crate) struct Reader {
stream: OwnedReadHalf,
pub(crate) struct Reader<R>
where
R: AsyncRead + Send + Sync + Unpin,
{
stream: R,
tx: Sender<PeerMessage>,
network: Network,
}

impl Reader {
pub fn new(stream: OwnedReadHalf, tx: Sender<PeerMessage>, network: Network) -> Self {
impl<R> Reader<R>
where
R: AsyncRead + Send + Sync + Unpin,
{
pub fn new(stream: R, tx: Sender<PeerMessage>, network: Network) -> Self
where
R: AsyncRead + Send + Sync + Unpin,
{
Self {
stream,
tx,
Expand Down

0 comments on commit a06d6ca

Please sign in to comment.