From dc70505d7ef9dc609d07afd946aa528e6e14454f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 6 Mar 2025 01:07:54 -0500 Subject: [PATCH 001/135] Add initial setup for noise wrapper --- src/crypto/handshake.rs | 5 +++-- src/lib.rs | 1 + src/protocol.rs | 7 ++++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 64db407..ee2d941 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -69,6 +69,7 @@ impl HandshakeResult { } } +#[derive(Debug)] pub(crate) struct Handshake { result: HandshakeResult, state: HandshakeState, @@ -170,11 +171,11 @@ impl Handshake { Ok(tx_buf) } - pub(crate) fn into_result(self) -> Result { + pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { if !self.complete() { Err(Error::new(ErrorKind::Other, "Handshake is not complete")) } else { - Ok(self.result) + Ok(&self.result) } } } diff --git a/src/lib.rs b/src/lib.rs index 531a068..0e5f037 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -123,6 +123,7 @@ mod constants; mod crypto; mod duplex; mod message; +mod noise; mod protocol; mod reader; mod util; diff --git a/src/protocol.rs b/src/protocol.rs index 7b8d468..3e8a2b5 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -10,6 +10,7 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; +use tracing::trace; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; @@ -286,8 +287,8 @@ where } fn init(&mut self) -> Result<()> { - tracing::debug!( - "protocol init, state {:?}, options {:?}", + trace!( + "protocol Init, state {:?}, options {:?}", self.state, self.options ); @@ -479,7 +480,7 @@ where self.state = State::Established; } // Store handshake result - self.handshake = Some(handshake_result); + self.handshake = Some(handshake_result.clone()); } Ok(()) } From 640efc9944c93dc5f7a6e0930c066ee20a6d7e74 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:01:28 -0500 Subject: [PATCH 002/135] use futures --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index d77679f..a5ac273 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ futures-lite = "1" sha2 = "0.10" curve25519-dalek = "4" crypto_secretstream = "0.2" +futures = "0.3.13" [dependencies.hypercore] version = "0.14.0" From 11afd0be5df208a061426dbb7dab5ce048952ff9 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:01:47 -0500 Subject: [PATCH 003/135] fix logger --- examples/replication.rs | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/examples/replication.rs b/examples/replication.rs index bf65b72..459df9f 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -3,23 +3,23 @@ use async_std::net::{TcpListener, TcpStream}; use async_std::prelude::*; use async_std::sync::{Arc, Mutex}; use async_std::task; -use env_logger::Env; use futures_lite::stream::StreamExt; use hypercore::{ Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, VerifyingKey, }; -use log::*; use std::collections::HashMap; use std::convert::TryInto; use std::env; use std::fmt::Debug; +use std::sync::OnceLock; +use tracing::{error, info}; use hypercore_protocol::schema::*; use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; fn main() { - init_logger(); + log(); if env::args().count() < 3 { usage(); } @@ -93,8 +93,8 @@ async fn onconnection( let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); info!("protocol created, polling for next()"); while let Some(event) = protocol.next().await { - let event = event?; info!("protocol event {:?}", event); + let event = event?; match event { Event::Handshake(_) => { if is_initiator { @@ -414,9 +414,21 @@ async fn onmessage( Ok(()) } -/// Init EnvLogger, logging info, warn and error messages to stdout. -pub fn init_logger() { - env_logger::from_env(Env::default().default_filter_or("info")).init(); +#[allow(unused)] +pub fn log() { + use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + static START_LOGS: OnceLock<()> = OnceLock::new(); + START_LOGS.get_or_init(|| { + tracing_subscriber::fmt() + .with_target(true) + .with_line_number(true) + // print when instrumented funtion enters + .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) + .with_file(true) + .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable + .without_time() + .init(); + }); } /// Log a result if it's an error. From 5f7d10b9dc8908d952bb44081909cc84ae7912a7 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:02:09 -0500 Subject: [PATCH 004/135] Add test_utils --- src/lib.rs | 2 ++ src/test_utils.rs | 91 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 src/test_utils.rs diff --git a/src/lib.rs b/src/lib.rs index 0e5f037..f12bcdb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,6 +126,8 @@ mod message; mod noise; mod protocol; mod reader; +#[cfg(test)] +mod test_utils; mod util; mod writer; diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 0000000..9eb986c --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,91 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_channel::{unbounded, Receiver, SendError, Sender}; +use futures::{Sink, SinkExt, Stream, StreamExt}; + +#[derive(Debug)] +pub struct Io { + receiver: Receiver>, + sender: Sender>, +} + +impl Default for Io { + fn default() -> Self { + let (sender, receiver) = unbounded(); + Self { sender, receiver } + } +} + +impl Stream for Io { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.receiver).poll_next(cx) + } +} + +impl Sink> for Io { + type Error = SendError>; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + let _ = self.sender.try_send(item); + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +#[derive(Default, Debug)] +pub struct TwoWay { + l_to_r: Io, + r_to_l: Io, +} + +impl TwoWay { + fn split_sides(self) -> (Io, Io) { + let left = Io { + sender: self.l_to_r.sender, + receiver: self.r_to_l.receiver, + }; + let right = Io { + sender: self.r_to_l.sender, + receiver: self.l_to_r.receiver, + }; + (left, right) + } +} + +pub fn create_connected() -> (Io, Io) { + TwoWay::default().split_sides() +} +#[tokio::test] +async fn way_one() { + let mut a = Io::default(); + let _ = a.send(b"hello".into()).await; + let Some(res) = a.next().await else { panic!() }; + assert_eq!(res, b"hello"); +} + +#[tokio::test] +async fn split() { + let (mut left, mut right) = (TwoWay::default()).split_sides(); + + left.send(b"hello".to_vec()).await; + let Some(res) = right.next().await else { + panic!(); + }; + assert_eq!(res, b"hello"); +} From 4ee231c741f64bc67f5b7837b84d799f620d383e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:02:34 -0500 Subject: [PATCH 005/135] Add standalone noise wrapper --- src/lib.rs | 1 + src/noise.rs | 186 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 src/noise.rs diff --git a/src/lib.rs b/src/lib.rs index f12bcdb..646a93e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -136,6 +136,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; +pub use noise::Encrypted; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, diff --git a/src/noise.rs b/src/noise.rs new file mode 100644 index 0000000..c0efe4f --- /dev/null +++ b/src/noise.rs @@ -0,0 +1,186 @@ +use futures::{Sink, Stream}; +use std::{collections::VecDeque, io::Result, mem::replace, pin::Pin, task::Poll}; +use tracing::{error, trace, warn}; + +use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; + +#[derive(Debug)] +pub(crate) enum Step { + NotInitialized, + Handshake(Box), + SecretStream((EncryptCipher, HandshakeResult)), + Established((EncryptCipher, DecryptCipher, HandshakeResult)), +} + +/// Wrap a stream with encryption +#[derive(Debug)] +pub struct Encrypted { + io: IO, + step: Step, + is_initiator: bool, + encrypted_tx: VecDeque>, + encrypted_rx: VecDeque>, + plain_tx: VecDeque>, + plain_rx: VecDeque>, +} + +impl Encrypted +where + IO: Stream + Sink> + Send + Unpin + 'static, +{ + /// Create [`Self`] from a Stream/Sink + pub fn new(is_initiator: bool, io: IO) -> Self { + Self { + io, + is_initiator, + step: Step::NotInitialized, + encrypted_tx: Default::default(), + encrypted_rx: Default::default(), + plain_tx: Default::default(), + plain_rx: Default::default(), + } + } +} + +impl> + Send + Unpin + 'static> Stream for Encrypted { + type Item = Vec; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let Encrypted { + io, + step, + is_initiator, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + .. + } = self.get_mut(); + + if let Step::Established((encryptor, decryptor, ..)) = step { + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + } else { + break; + } + } + + // decrypt any incromming encrypted messages + while let Some(incoming_msg) = encrypted_rx.pop_front() { + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => plain_rx.push_back(plain_msg), + Err(e) => error!("RX message failed to decrypt: {e:?}"), + } + } + + // encrypt any pending plaintext outgoinng messages + while let Some(mut plain_out) = plain_tx.pop_front() { + // it encrypts in-place?? + if let Err(_e) = encryptor.encrypt(&mut plain_out) { + todo!("We failed to encrypt our own message...?"); + } + encrypted_tx.push_back(plain_out); + } + + // emit any messages that are ready + if let Some(msg) = plain_rx.pop_front() { + Poll::Ready(Some(msg)) + } else { + Poll::Pending + } + } else { + // Still setting up + if let Ok(Some(msg)) = init(step, *is_initiator) { + // queue the init message to send first + encrypted_tx.push_front(msg); + } + while let Some(incoming_msg) = encrypted_rx.pop_front() { + if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator) { + for msg in msgs { + encrypted_tx.push_back(msg); + } + } + } + Poll::Pending + } + } +} + +fn init(step: &mut Step, is_initiator: bool) -> Result>> { + if !matches!(step, Step::NotInitialized) { + return Ok(None); + } + trace!( + "protocol Init, state {:?}, is_initiator {:?}", + step, + is_initiator + ); + let mut handshake = Handshake::new(is_initiator)?; + let out = handshake.start()?; + // next up is handshaking + *step = Step::Handshake(Box::new(handshake)); + Ok(out) +} + +fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { + match &step { + Step::NotInitialized => { + warn!("Encrypted state was reset"); + let mut handshake = Handshake::new(is_initiator)?; + let start_msg = handshake.start()?; + *step = Step::Handshake(Box::new(handshake)); + + Ok(start_msg.map(|x| vec![x]).unwrap_or(vec![])) + } + Step::Handshake(_) => { + let mut out = vec![]; + if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { + if let Some(response) = handshake.read(msg)? { + out.push(response); + } + + if handshake.complete() { + let handshake_result = handshake.into_result()?; + // The cipher will be put to use to the writer only after the peer's answer has come + let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; + out.push(init_msg); + *step = Step::SecretStream((cipher, handshake_result.clone())); + } else { + *step = Step::Handshake(handshake); + } + } + Ok(out) + } + Step::SecretStream(_) => { + if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) + { + let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; + *step = Step::Established((enc_cipher, dec_cipher, hs_result)); + } + Ok(vec![]) + } + Step::Established((..)) => todo!(), + } +} + +#[cfg(test)] +mod tset { + use crate::test_utils::create_connected; + + use super::*; + use futures::{SinkExt, StreamExt}; + + #[tokio::test] + async fn test_encrypted() -> Result<()> { + let (left, right) = create_connected(); + let left = Encrypted::new(true, left); + let right = Encrypted::new(true, right); + //left.send(b"hello").await?; + todo!() + } +} From ca061fe7fab492b41b6641eb7e419c5c639a76ce Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:04:43 -0500 Subject: [PATCH 006/135] lint test_utils --- src/test_utils.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index 9eb986c..273935f 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -7,7 +7,7 @@ use async_channel::{unbounded, Receiver, SendError, Sender}; use futures::{Sink, SinkExt, Stream, StreamExt}; #[derive(Debug)] -pub struct Io { +pub(crate) struct Io { receiver: Receiver>, sender: Sender>, } @@ -49,7 +49,7 @@ impl Sink> for Io { } #[derive(Default, Debug)] -pub struct TwoWay { +pub(crate) struct TwoWay { l_to_r: Io, r_to_l: Io, } @@ -68,7 +68,7 @@ impl TwoWay { } } -pub fn create_connected() -> (Io, Io) { +pub(crate) fn create_connected() -> (Io, Io) { TwoWay::default().split_sides() } #[tokio::test] @@ -83,7 +83,7 @@ async fn way_one() { async fn split() { let (mut left, mut right) = (TwoWay::default()).split_sides(); - left.send(b"hello".to_vec()).await; + left.send(b"hello".to_vec()).await.unwrap(); let Some(res) = right.next().await else { panic!(); }; From 3e82e0f5f9e787a2b5b780bb818af6fb5a16945a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 9 Mar 2025 00:21:12 -0500 Subject: [PATCH 007/135] wip noise --- src/noise.rs | 343 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 312 insertions(+), 31 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index c0efe4f..a28c2c8 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -1,6 +1,13 @@ use futures::{Sink, Stream}; -use std::{collections::VecDeque, io::Result, mem::replace, pin::Pin, task::Poll}; -use tracing::{error, trace, warn}; +use std::{ + collections::VecDeque, + fmt::Debug, + io::Result, + mem::replace, + pin::Pin, + task::{Context, Poll, Waker}, +}; +use tracing::{error, info, instrument, trace, warn}; use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; @@ -22,14 +29,17 @@ pub struct Encrypted { encrypted_rx: VecDeque>, plain_tx: VecDeque>, plain_rx: VecDeque>, + flush: bool, + count: usize, + name: String, } impl Encrypted where - IO: Stream + Sink> + Send + Unpin + 'static, + IO: Stream> + Sink> + Send + Unpin + 'static, { /// Create [`Self`] from a Stream/Sink - pub fn new(is_initiator: bool, io: IO) -> Self { + pub fn new(is_initiator: bool, io: IO, name: &str) -> Self { Self { io, is_initiator, @@ -38,17 +48,35 @@ where encrypted_rx: Default::default(), plain_tx: Default::default(), plain_rx: Default::default(), + flush: false, + count: 0, + name: name.to_string(), } } } -impl> + Send + Unpin + 'static> Stream for Encrypted { - type Item = Vec; +impl> + Sink> + Send + Unpin + Debug + 'static> Sink> + for Encrypted +{ + type Error = (); + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { + trace!("{} add plain tx", self.name); + self.plain_tx.push_back(item); + Ok(()) + } - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let Encrypted { io, step, @@ -57,24 +85,233 @@ impl> + Send + Unpin + 'static> Stream for Encrypted 200 { + //panic!(); + } + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + trace!( + "{name} enc tx send msg + {encrypted_out:?} +" + ); + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + *flush = true; + } else { + break; + } + } + if *flush { + // confusing docs related to start send + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send + // First part says: + // "you must use **poll_flush** ... inorder to garuntee + // completions of send" + // Then it says: + // " It is only necessary to call poll_flush if you need to guarantee that all + // of the items placed into the Sink have been sent" + // + // So do I need to do it or not? + // must `poll_flush` be called for **anything** to send? + match Sink::poll_flush(Pin::new(io), cx) { + Poll::Ready(Ok(())) => { + *flush = false; + trace!("{name} flushed good"); + } + Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), + Poll::Pending => { + // More confusing docs + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush + // It says: + // "Returns Poll::Pending if there is more work left to do, in which case the + // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when + // poll_flush should be called again." + // Does this mean, each time this task wakes up again from this code path that + // I must trigger another poll_flush? But how would I know i need more + // flushing? + *flush = true; + } + } + } + + // pull in any incomming encrypted messages + loop { + match Stream::poll_next(Pin::new(io), cx) { + Poll::Pending => break, + Poll::Ready(None) => todo!(), + Poll::Ready(Some(encrypted_msg)) => { + trace!( + "{name} enc rx queue + {encrypted_msg:?} + ); +" + ); + encrypted_rx.push_back(encrypted_msg); + } + } + } + if let Step::Established((encryptor, decryptor, ..)) = step { - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - let _todo = Sink::start_send(Pin::new(io), encrypted_out); - } else { - break; + // decrypt any incromming encrypted messages + while let Some(incoming_msg) = encrypted_rx.pop_front() { + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => { + trace!("{name} plain rx queue"); + plain_rx.push_back(plain_msg); + } + Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), + } + } + + // encrypt any pending plaintext outgoinng messages + while let Some(mut plain_out) = plain_tx.pop_front() { + // it encrypts in-place?? + if let Err(_e) = encryptor.encrypt(&mut plain_out) { + todo!("{name} We failed to encrypt our own message...?"); + } + trace!("{name} enc tx queue"); + encrypted_tx.push_back(plain_out); + } + + if *flush { + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } else { + trace!("{name} doing setup"); + // Still setting up + if let Ok(Some(msg)) = init(step, *is_initiator) { + // queue the init message to send first + trace!("{name} queue initial msg"); + encrypted_tx.push_front(msg); + } + while let Some(incoming_msg) = encrypted_rx.pop_front() { + trace!("{name} recieved setup msg"); + if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { + for msg in msgs { + trace!("{name} queue more setup msg"); + encrypted_tx.push_front(msg); + } } } + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + todo!() + } +} +impl> + Sink> + Send + Unpin + Debug + 'static> Stream + for Encrypted +{ + type Item = Vec; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Encrypted { + io, + step, + is_initiator, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + flush, + count, + name, + .. + } = self.get_mut(); + + *count += 1; + if *count > 200 { + //panic!(); + } + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + trace!( + "{name} enc tx send msg + {encrypted_out:?} +" + ); + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + *flush = true; + } else { + break; + } + } + if *flush { + // confusing docs related to start send + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send + // First part says: + // "you must use **poll_flush** ... inorder to garuntee + // completions of send" + // Then it says: + // " It is only necessary to call poll_flush if you need to guarantee that all + // of the items placed into the Sink have been sent" + // + // So do I need to do it or not? + // must `poll_flush` be called for **anything** to send? + match Sink::poll_flush(Pin::new(io), cx) { + Poll::Ready(Ok(())) => { + *flush = false; + trace!("{name} flushed good"); + } + Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), + Poll::Pending => { + // More confusing docs + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush + // It says: + // "Returns Poll::Pending if there is more work left to do, in which case the + // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when + // poll_flush should be called again." + // Does this mean, each time this task wakes up again from this code path that + // I must trigger another poll_flush? But how would I know i need more + // flushing? + *flush = true; + } + } + } + + // pull in any incomming encrypted messages + loop { + match Stream::poll_next(Pin::new(io), cx) { + Poll::Pending => break, + Poll::Ready(None) => todo!(), + Poll::Ready(Some(encrypted_msg)) => { + trace!( + "{name} enc rx queue + {encrypted_msg:?} + ); +" + ); + encrypted_rx.push_back(encrypted_msg); + } + } + } + + if let Step::Established((encryptor, decryptor, ..)) = step { // decrypt any incromming encrypted messages while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => plain_rx.push_back(plain_msg), - Err(e) => error!("RX message failed to decrypt: {e:?}"), + Ok((plain_msg, _tag)) => { + trace!("{name} plain rx queue"); + plain_rx.push_back(plain_msg); + } + Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), } } @@ -82,35 +319,49 @@ impl> + Send + Unpin + 'static> Stream for Encrypted> + Sink> + Send + Unpin + 'static>( + encrypted: &mut Encrypted, + cx: &mut Context<'_>, +) { + todo!() +} + fn init(step: &mut Step, is_initiator: bool) -> Result>> { if !matches!(step, Step::NotInitialized) { return Ok(None); @@ -122,15 +373,19 @@ fn init(step: &mut Step, is_initiator: bool) -> Result>> { ); let mut handshake = Handshake::new(is_initiator)?; let out = handshake.start()?; - // next up is handshaking *step = Step::Handshake(Box::new(handshake)); Ok(out) } -fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { +fn handle_setup_message( + step: &mut Step, + msg: &[u8], + is_initiator: bool, + name: &str, +) -> Result>> { match &step { Step::NotInitialized => { - warn!("Encrypted state was reset"); + warn!("{name} Encrypted state was reset"); let mut handshake = Handshake::new(is_initiator)?; let start_msg = handshake.start()?; *step = Step::Handshake(Box::new(handshake)); @@ -140,14 +395,33 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu Step::Handshake(_) => { let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { - if let Some(response) = handshake.read(msg)? { + if let Some(response) = match handshake.read(msg) { + Ok(x) => x, + Err(e) => { + error!("error in handshake.read(msg) {e:?}"); + return Err(e); + } + } { out.push(response); } if handshake.complete() { - let handshake_result = handshake.into_result()?; + let handshake_result = match handshake.into_result() { + Ok(x) => x, + Err(e) => { + error!("into-result error {e:?}"); + return Err(e); + } + }; // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; + let (cipher, init_msg) = + match EncryptCipher::from_handshake_tx(handshake_result) { + Ok(x) => x, + Err(e) => { + error!("from_handshake_tx error {e:?}"); + return Err(e); + } + }; out.push(init_msg); *step = Step::SecretStream((cipher, handshake_result.clone())); } else { @@ -170,17 +444,24 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu #[cfg(test)] mod tset { - use crate::test_utils::create_connected; + use crate::test_utils::{create_connected, log, Io}; use super::*; use futures::{SinkExt, StreamExt}; #[tokio::test] async fn test_encrypted() -> Result<()> { + log(); let (left, right) = create_connected(); - let left = Encrypted::new(true, left); - let right = Encrypted::new(true, right); - //left.send(b"hello").await?; + let mut left = Encrypted::new(true, left, "left"); + let mut right = Encrypted::new(true, right, "right"); + tokio::task::spawn(async move { + left.send(b"hello".into()).await.unwrap(); + }); + //tokio::task::spawn(async move { + // let x = left.next().await; + //}); + dbg!(right.next().await); todo!() } } From 241622f36c65e7043935e989ad41b7c7dc67c247 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 9 Mar 2025 23:10:52 -0400 Subject: [PATCH 008/135] wip poll impl --- src/noise.rs | 134 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 133 insertions(+), 1 deletion(-) diff --git a/src/noise.rs b/src/noise.rs index a28c2c8..e977315 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -359,7 +359,139 @@ fn poll> + Sink> + Send + Unpin + 'static>( encrypted: &mut Encrypted, cx: &mut Context<'_>, ) { - todo!() + let Encrypted { + io, + step, + is_initiator, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + flush, + count, + name, + .. + } = encrypted; + + *count += 1; + if *count > 200 { + //panic!(); + } + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + trace!( + "{name} enc tx send msg + {encrypted_out:?} +" + ); + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + *flush = true; + } else { + break; + } + } + if *flush { + // confusing docs related to start send + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send + // First part says: + // "you must use **poll_flush** ... inorder to garuntee + // completions of send" + // Then it says: + // " It is only necessary to call poll_flush if you need to guarantee that all + // of the items placed into the Sink have been sent" + // + // So do I need to do it or not? + // must `poll_flush` be called for **anything** to send? + match Sink::poll_flush(Pin::new(io), cx) { + Poll::Ready(Ok(())) => { + *flush = false; + trace!("{name} flushed good"); + } + Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), + Poll::Pending => { + // More confusing docs + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush + // It says: + // "Returns Poll::Pending if there is more work left to do, in which case the + // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when + // poll_flush should be called again." + // Does this mean, each time this task wakes up again from this code path that + // I must trigger another poll_flush? But how would I know i need more + // flushing? + *flush = true; + } + } + } + + // pull in any incomming encrypted messages + loop { + match Stream::poll_next(Pin::new(io), cx) { + Poll::Pending => break, + Poll::Ready(None) => todo!(), + Poll::Ready(Some(encrypted_msg)) => { + trace!( + "{name} enc rx queue + {encrypted_msg:?} + ); +" + ); + encrypted_rx.push_back(encrypted_msg); + } + } + } + + if let Step::Established((encryptor, decryptor, ..)) = step { + // decrypt any incromming encrypted messages + while let Some(incoming_msg) = encrypted_rx.pop_front() { + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => { + trace!("{name} plain rx queue"); + plain_rx.push_back(plain_msg); + } + Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), + } + } + + // encrypt any pending plaintext outgoinng messages + while let Some(mut plain_out) = plain_tx.pop_front() { + // it encrypts in-place?? + if let Err(_e) = encryptor.encrypt(&mut plain_out) { + todo!("{name} We failed to encrypt our own message...?"); + } + trace!("{name} enc tx queue"); + encrypted_tx.push_back(plain_out); + } + + // emit any messages that are ready + if let Some(msg) = plain_rx.pop_front() { + trace!("{name} plain rx emit"); + //Poll::Ready(Some(msg)) + todo!() + } else { + //Poll::Pending + todo!() + } + } else { + trace!("{name} doing setup"); + // Still setting up + if let Ok(Some(msg)) = init(step, *is_initiator) { + // queue the init message to send first + trace!("{name} queue initial msg"); + encrypted_tx.push_front(msg); + } + while let Some(incoming_msg) = encrypted_rx.pop_front() { + trace!("{name} recieved setup msg"); + if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { + for msg in msgs { + trace!("{name} queue more setup msg"); + encrypted_tx.push_front(msg); + } + } + } + cx.waker().wake_by_ref(); + todo!() + } } fn init(step: &mut Step, is_initiator: bool) -> Result>> { From 2591d2889eab4dfb358bde75a7ef61337d271146 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 10 Mar 2025 16:51:07 -0400 Subject: [PATCH 009/135] Add start_raw and read_raw --- src/crypto/handshake.rs | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index ee2d941..21c5442 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -101,15 +101,17 @@ impl Handshake { }) } - pub(crate) fn start(&mut self) -> Result>> { + pub(crate) fn start_raw(&mut self) -> Result>> { if self.is_initiator() { let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); - Ok(Some(wrapped)) + Ok(Some(self.tx_buf[..tx_len].to_vec())) } else { Ok(None) } } + pub(crate) fn start(&mut self) -> Result>> { + Ok(self.start_raw()?.map(|x| wrap_uint24_le(&x))) + } pub(crate) fn complete(&self) -> bool { self.complete @@ -124,13 +126,13 @@ impl Handshake { .read_message(msg, &mut self.rx_buf) .map_err(map_err) } - fn send(&mut self) -> Result { + pub(crate) fn send(&mut self) -> Result { self.state .write_message(&self.payload, &mut self.tx_buf) .map_err(map_err) } - pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { + pub(crate) fn read_raw(&mut self, msg: &[u8]) -> Result>> { // eprintln!("hs read len {}", msg.len()); if self.complete() { return Err(Error::new(ErrorKind::Other, "Handshake read after finish")); @@ -138,16 +140,17 @@ impl Handshake { let _rx_len = self.recv(msg)?; + // first non-init if !self.is_initiator() && !self.did_receive { self.did_receive = true; let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); + let wrapped = self.tx_buf[..tx_len].to_vec(); return Ok(Some(wrapped)); } let tx_buf = if self.is_initiator() { let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); + let wrapped = self.tx_buf[..tx_len].to_vec(); Some(wrapped) } else { None @@ -170,6 +173,10 @@ impl Handshake { self.complete = true; Ok(tx_buf) } + // reads in `msg` without framing bytes, but emits msg WITH framing bytes + pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { + Ok(self.read_raw(msg)?.map(|x| wrap_uint24_le(&x))) + } pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { if !self.complete() { From 400d57d354e02e0bd0b3efdc3fe5f958ff1bd73a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 11 Mar 2025 01:30:52 -0400 Subject: [PATCH 010/135] lint --- src/noise.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index e977315..86de5b3 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -5,9 +5,9 @@ use std::{ io::Result, mem::replace, pin::Pin, - task::{Context, Poll, Waker}, + task::{Context, Poll}, }; -use tracing::{error, info, instrument, trace, warn}; +use tracing::{error, info, trace, warn}; use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; @@ -210,7 +210,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static fn poll_close( self: Pin<&mut Self>, - cx: &mut Context<'_>, + _cx: &mut Context<'_>, ) -> Poll> { todo!() } @@ -355,7 +355,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } -fn poll> + Sink> + Send + Unpin + 'static>( +fn _poll> + Sink> + Send + Unpin + 'static>( encrypted: &mut Encrypted, cx: &mut Context<'_>, ) { @@ -464,7 +464,7 @@ fn poll> + Sink> + Send + Unpin + 'static>( } // emit any messages that are ready - if let Some(msg) = plain_rx.pop_front() { + if let Some(_msg) = plain_rx.pop_front() { trace!("{name} plain rx emit"); //Poll::Ready(Some(msg)) todo!() From 7c9e2f53ba1aa026bacf00e9dc2584b363b65820 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 11 Mar 2025 01:31:37 -0400 Subject: [PATCH 011/135] use read_raw & start_raw --- src/noise.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 86de5b3..467008e 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -504,7 +504,7 @@ fn init(step: &mut Step, is_initiator: bool) -> Result>> { is_initiator ); let mut handshake = Handshake::new(is_initiator)?; - let out = handshake.start()?; + let out = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); Ok(out) } @@ -519,7 +519,7 @@ fn handle_setup_message( Step::NotInitialized => { warn!("{name} Encrypted state was reset"); let mut handshake = Handshake::new(is_initiator)?; - let start_msg = handshake.start()?; + let start_msg = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); Ok(start_msg.map(|x| vec![x]).unwrap_or(vec![])) @@ -527,10 +527,10 @@ fn handle_setup_message( Step::Handshake(_) => { let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { - if let Some(response) = match handshake.read(msg) { + if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { - error!("error in handshake.read(msg) {e:?}"); + error!("error in handshake.read_raw(msg) {e:?}"); return Err(e); } } { From 20302eb4a01a8c238e4db89345f4768c749379fb Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 11 Mar 2025 17:23:52 -0400 Subject: [PATCH 012/135] Encrypted stream is now working! --- src/noise.rs | 354 +++++++++++++++++++-------------------------------- 1 file changed, 128 insertions(+), 226 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 467008e..c22c1fd 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -7,16 +7,26 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{error, info, trace, warn}; +use tracing::{error, info, instrument, trace, warn}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; +use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult, RawEncryptCipher}; + +macro_rules! name { + ($name:tt) => {{ + if $name { + "initiator" + } else { + "other" + } + }}; +} #[derive(Debug)] pub(crate) enum Step { NotInitialized, Handshake(Box), - SecretStream((EncryptCipher, HandshakeResult)), - Established((EncryptCipher, DecryptCipher, HandshakeResult)), + SecretStream((RawEncryptCipher, HandshakeResult)), + Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), } /// Wrap a stream with encryption @@ -31,15 +41,22 @@ pub struct Encrypted { plain_rx: VecDeque>, flush: bool, count: usize, - name: String, } +fn ename(is_initiator: bool) -> String { + if is_initiator { + "initiator".to_string() + } else { + "other".to_string() + } +} impl Encrypted where - IO: Stream> + Sink> + Send + Unpin + 'static, + IO: Stream> + Sink> + Send + Unpin + Debug + 'static, { /// Create [`Self`] from a Stream/Sink - pub fn new(is_initiator: bool, io: IO, name: &str) -> Self { + #[instrument(skip_all, fields(name = %ename(is_initiator)))] + pub fn new(is_initiator: bool, io: IO) -> Self { Self { io, is_initiator, @@ -50,7 +67,6 @@ where plain_rx: Default::default(), flush: false, count: 0, - name: name.to_string(), } } } @@ -67,12 +83,14 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Ready(Ok(())) } + #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { - trace!("{} add plain tx", self.name); + trace!("add plain tx"); self.plain_tx.push_back(item); Ok(()) } + #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -87,7 +105,6 @@ impl> + Sink> + Send + Unpin + Debug + 'static plain_rx, flush, count, - name, .. } = self.get_mut(); @@ -99,9 +116,8 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { trace!( - "{name} enc tx send msg - {encrypted_out:?} -" + name = %ename(*is_initiator), + "enc tx send msg\n{encrypted_out:?}" ); let _todo = Sink::start_send(Pin::new(io), encrypted_out); *flush = true; @@ -124,9 +140,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!("{name} flushed good"); + trace!(name = %ename(*is_initiator), "flushed good"); } - Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), + Poll::Ready(Err(_e)) => error!( + name = %ename(*is_initiator), + "Error sending encrypted msg"), Poll::Pending => { // More confusing docs // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush @@ -148,12 +166,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Pending => break, Poll::Ready(None) => todo!(), Poll::Ready(Some(encrypted_msg)) => { - trace!( - "{name} enc rx queue - {encrypted_msg:?} - ); -" - ); + trace!(name = %ename(*is_initiator), "enc rx queue\n{encrypted_msg:?}"); encrypted_rx.push_back(encrypted_msg); } } @@ -164,41 +177,45 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!("{name} plain rx queue"); + trace!(name = %ename(*is_initiator), "plain rx queue"); plain_rx.push_back(plain_msg); } - Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), + Err(e) => { + error!(name = %ename(*is_initiator), "RX message failed to decrypt: {e:?}") + } } } // encrypt any pending plaintext outgoinng messages while let Some(mut plain_out) = plain_tx.pop_front() { // it encrypts in-place?? - if let Err(_e) = encryptor.encrypt(&mut plain_out) { - todo!("{name} We failed to encrypt our own message...?"); - } - trace!("{name} enc tx queue"); - encrypted_tx.push_back(plain_out); + let enc_out = match encryptor.encrypt(&mut plain_out) { + Ok(x) => x, + Err(_e) => todo!("We failed to encrypt our own message...?"), + }; + trace!(name = %ename(*is_initiator), "enc from plain tx queue\n{enc_out:?}"); + encrypted_tx.push_back(enc_out); + *flush = true; } if *flush { + cx.waker().wake_by_ref(); Poll::Pending } else { Poll::Ready(Ok(())) } } else { - trace!("{name} doing setup"); // Still setting up - if let Ok(Some(msg)) = init(step, *is_initiator) { + if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { // queue the init message to send first - trace!("{name} queue initial msg"); + trace!(name = %ename(*is_initiator),"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!("{name} recieved setup msg"); - if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { - for msg in msgs { - trace!("{name} queue more setup msg"); + trace!(name = %ename(*is_initiator),"recieved setup msg"); + if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator) { + for msg in msgs.into_iter().rev() { + trace!(name = %ename(*is_initiator),"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } } @@ -208,6 +225,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } + #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] fn poll_close( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -220,6 +238,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static { type Item = Vec; + #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -231,7 +250,6 @@ impl> + Sink> + Send + Unpin + Debug + 'static plain_rx, flush, count, - name, .. } = self.get_mut(); @@ -242,8 +260,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!( - "{name} enc tx send msg + trace!(name = %ename(*is_initiator), "enc tx send msg {encrypted_out:?} " ); @@ -268,9 +285,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!("{name} flushed good"); + trace!(name = %ename(*is_initiator), "flushed good"); + } + Poll::Ready(Err(_e)) => { + error!(name = %ename(*is_initiator), "Error sending encrypted msg") } - Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), Poll::Pending => { // More confusing docs // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush @@ -290,14 +309,9 @@ impl> + Sink> + Send + Unpin + Debug + 'static loop { match Stream::poll_next(Pin::new(io), cx) { Poll::Pending => break, - Poll::Ready(None) => todo!(), + Poll::Ready(None) => break, Poll::Ready(Some(encrypted_msg)) => { - trace!( - "{name} enc rx queue - {encrypted_msg:?} - ); -" - ); + trace!(name = %ename(*is_initiator), "enc rx queue\n{encrypted_msg:?}"); encrypted_rx.push_back(encrypted_msg); } } @@ -308,43 +322,50 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!("{name} plain rx queue"); + trace!(name = %ename(*is_initiator), "plain rx queue"); plain_rx.push_back(plain_msg); } - Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), + Err(e) => { + error!(name = %ename(*is_initiator),"RX message failed to decrypt: {e:?}") + } } } // encrypt any pending plaintext outgoinng messages while let Some(mut plain_out) = plain_tx.pop_front() { - // it encrypts in-place?? - if let Err(_e) = encryptor.encrypt(&mut plain_out) { - todo!("{name} We failed to encrypt our own message...?"); - } - trace!("{name} enc tx queue"); - encrypted_tx.push_back(plain_out); + let enc_out = match encryptor.encrypt(&mut plain_out) { + Ok(x) => x, + Err(_e) => todo!("We failed to encrypt our own message...?"), + }; + trace!(name = %ename(*is_initiator), "enc from plain tx queue\n{enc_out:?}"); + encrypted_tx.push_back(enc_out); } // emit any messages that are ready if let Some(msg) = plain_rx.pop_front() { - trace!("{name} plain rx emit"); + trace!(name = %ename(*is_initiator), "plain rx emit"); Poll::Ready(Some(msg)) } else { Poll::Pending } } else { - trace!("{name} doing setup"); // Still setting up - if let Ok(Some(msg)) = init(step, *is_initiator) { + if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { // queue the init message to send first - trace!("{name} queue initial msg"); + trace!(name = %ename(*is_initiator),"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!("{name} recieved setup msg"); - if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { - for msg in msgs { - trace!("{name} queue more setup msg"); + trace!(name = %ename(*is_initiator), "recieved setup msg"); + if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, *is_initiator) { + Ok(x) => Ok(x), + Err(e) => { + error!("handle_setup_message error: {e:?}"); + Err(e) + } + } { + for msg in msgs.into_iter().rev() { + trace!(name = %ename(*is_initiator),"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } } @@ -355,169 +376,22 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } -fn _poll> + Sink> + Send + Unpin + 'static>( - encrypted: &mut Encrypted, - cx: &mut Context<'_>, -) { - let Encrypted { - io, - step, - is_initiator, - encrypted_tx, - encrypted_rx, - plain_tx, - plain_rx, - flush, - count, - name, - .. - } = encrypted; - - *count += 1; - if *count > 200 { - //panic!(); - } - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!( - "{name} enc tx send msg - {encrypted_out:?} -" - ); - let _todo = Sink::start_send(Pin::new(io), encrypted_out); - *flush = true; - } else { - break; - } - } - if *flush { - // confusing docs related to start send - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send - // First part says: - // "you must use **poll_flush** ... inorder to garuntee - // completions of send" - // Then it says: - // " It is only necessary to call poll_flush if you need to guarantee that all - // of the items placed into the Sink have been sent" - // - // So do I need to do it or not? - // must `poll_flush` be called for **anything** to send? - match Sink::poll_flush(Pin::new(io), cx) { - Poll::Ready(Ok(())) => { - *flush = false; - trace!("{name} flushed good"); - } - Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), - Poll::Pending => { - // More confusing docs - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush - // It says: - // "Returns Poll::Pending if there is more work left to do, in which case the - // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when - // poll_flush should be called again." - // Does this mean, each time this task wakes up again from this code path that - // I must trigger another poll_flush? But how would I know i need more - // flushing? - *flush = true; - } - } - } - - // pull in any incomming encrypted messages - loop { - match Stream::poll_next(Pin::new(io), cx) { - Poll::Pending => break, - Poll::Ready(None) => todo!(), - Poll::Ready(Some(encrypted_msg)) => { - trace!( - "{name} enc rx queue - {encrypted_msg:?} - ); -" - ); - encrypted_rx.push_back(encrypted_msg); - } - } - } - - if let Step::Established((encryptor, decryptor, ..)) = step { - // decrypt any incromming encrypted messages - while let Some(incoming_msg) = encrypted_rx.pop_front() { - match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => { - trace!("{name} plain rx queue"); - plain_rx.push_back(plain_msg); - } - Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), - } - } - - // encrypt any pending plaintext outgoinng messages - while let Some(mut plain_out) = plain_tx.pop_front() { - // it encrypts in-place?? - if let Err(_e) = encryptor.encrypt(&mut plain_out) { - todo!("{name} We failed to encrypt our own message...?"); - } - trace!("{name} enc tx queue"); - encrypted_tx.push_back(plain_out); - } - - // emit any messages that are ready - if let Some(_msg) = plain_rx.pop_front() { - trace!("{name} plain rx emit"); - //Poll::Ready(Some(msg)) - todo!() - } else { - //Poll::Pending - todo!() - } - } else { - trace!("{name} doing setup"); - // Still setting up - if let Ok(Some(msg)) = init(step, *is_initiator) { - // queue the init message to send first - trace!("{name} queue initial msg"); - encrypted_tx.push_front(msg); - } - while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!("{name} recieved setup msg"); - if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { - for msg in msgs { - trace!("{name} queue more setup msg"); - encrypted_tx.push_front(msg); - } - } - } - cx.waker().wake_by_ref(); - todo!() - } -} - -fn init(step: &mut Step, is_initiator: bool) -> Result>> { +fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { if !matches!(step, Step::NotInitialized) { return Ok(None); } - trace!( - "protocol Init, state {:?}, is_initiator {:?}", - step, - is_initiator - ); + trace!(name = %ename(is_initiator), "Init, state {step:?}"); let mut handshake = Handshake::new(is_initiator)?; let out = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); Ok(out) } -fn handle_setup_message( - step: &mut Step, - msg: &[u8], - is_initiator: bool, - name: &str, -) -> Result>> { +#[instrument(skip_all, fields(name = %ename(is_initiator)))] +fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { match &step { Step::NotInitialized => { - warn!("{name} Encrypted state was reset"); + warn!("{} Encrypted state was reset", name!(is_initiator)); let mut handshake = Handshake::new(is_initiator)?; let start_msg = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); @@ -527,17 +401,23 @@ fn handle_setup_message( Step::Handshake(_) => { let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { + trace!("Read in handshake msg\n{msg:?}"); if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { - error!("error in handshake.read_raw(msg) {e:?}"); + panic!("error in handshake.read_raw(msg) {e:?}"); return Err(e); } } { + info!( + "{} read message and emitting response {response:?}", + name!(is_initiator) + ); out.push(response); } if handshake.complete() { + info!("{} HS complete. Making result", name!(is_initiator)); let handshake_result = match handshake.into_result() { Ok(x) => x, Err(e) => { @@ -547,13 +427,14 @@ fn handle_setup_message( }; // The cipher will be put to use to the writer only after the peer's answer has come let (cipher, init_msg) = - match EncryptCipher::from_handshake_tx(handshake_result) { + match RawEncryptCipher::from_handshake_tx(handshake_result) { Ok(x) => x, Err(e) => { error!("from_handshake_tx error {e:?}"); return Err(e); } }; + info!("{} made enc cipher", name!(is_initiator)); out.push(init_msg); *step = Step::SecretStream((cipher, handshake_result.clone())); } else { @@ -563,6 +444,7 @@ fn handle_setup_message( Ok(out) } Step::SecretStream(_) => { + info!("E're a secret stream now!!!!!"); if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) { let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; @@ -576,24 +458,44 @@ fn handle_setup_message( #[cfg(test)] mod tset { - use crate::test_utils::{create_connected, log, Io}; + use crate::test_utils::{create_connected, log}; use super::*; use futures::{SinkExt, StreamExt}; + #[tokio::test] + async fn steps() -> Result<()> { + // figure out handshake problem + let mut left_hs = Handshake::new(true)?; + let s1 = left_hs.start_raw()?.unwrap(); + + println!("s1 {s1:?}"); + let mut right_hs = Handshake::new(false)?; + + let s2 = right_hs.read_raw(&s1)?.unwrap(); + println!("s2 {s2:?}"); + + let s3 = left_hs.read_raw(&s2)?.unwrap(); + println!("s3 {s3:?}"); + + let s4 = right_hs.read_raw(&s3)?; + + println!("s4 {s4:?}"); + Ok(()) + } + #[tokio::test] async fn test_encrypted() -> Result<()> { log(); + let expected = b"hello"; let (left, right) = create_connected(); - let mut left = Encrypted::new(true, left, "left"); - let mut right = Encrypted::new(true, right, "right"); + let mut left = Encrypted::new(true, left); + let mut right = Encrypted::new(false, right); tokio::task::spawn(async move { - left.send(b"hello".into()).await.unwrap(); + left.send(expected.into()).await.unwrap(); }); - //tokio::task::spawn(async move { - // let x = left.next().await; - //}); - dbg!(right.next().await); - todo!() + let result = right.next().await.unwrap(); + assert_eq!(result, expected); + Ok(()) } } From 792358dfa1eab32e9aed497bd7e96a289a6c8d5e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 11 Mar 2025 17:29:41 -0400 Subject: [PATCH 013/135] Add RawEncrytpCipher --- src/crypto/cipher.rs | 78 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index c0e54a9..28ff1f4 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -184,3 +184,81 @@ fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { let result = result.as_slice(); out.copy_from_slice(result); } + +//NB "raw" here means UN-framed. No frame header. +const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; + +pub(crate) struct RawDecryptCipher { + pull_stream: PullStream, +} + +pub(crate) struct RawEncryptCipher { + push_stream: PushStream, +} +impl std::fmt::Debug for RawDecryptCipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DecryptCipher(crypto_secretstream)") + } +} + +impl std::fmt::Debug for RawEncryptCipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "EncryptCipher(crypto_secretstream)") + } +} + +impl RawEncryptCipher { + pub(crate) fn from_handshake_tx( + handshake_result: &HandshakeResult, + ) -> std::io::Result<(Self, Vec)> { + let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] + .try_into() + .expect("split_tx with incorrect length"); + let key = Key::from(key); + + let mut header_message: [u8; RAW_HEADER_MSG_LEN] = [0; RAW_HEADER_MSG_LEN]; + + write_stream_id( + &handshake_result.handshake_hash, + handshake_result.is_initiator, + &mut header_message[..STREAM_ID_LENGTH], + ); + + let (header, push_stream) = PushStream::init(OsRng, &key); + let header = header.as_ref(); + header_message[STREAM_ID_LENGTH..].copy_from_slice(header); + let msg = header_message.to_vec(); + Ok((Self { push_stream }, msg)) + } + + // Possible API's: + // encrypted message is (tag + encrypted + mac ) + // to have *zero* alocations we could + // * take a buffer of the expected final length, plantext starts at 1 to 1 + planetext.len() + // * final length is 1 + plaintext.len() + mac.len() + // * we write tag to 0 + // * encrypt plain text part in place + // * write mac to end + // + // it would be akward to take an array like this. We could infer the plaintext via the buffer + // it's range would be (1..(buf.len() - mac.len())) + // encypt-in-place the palintext, + // For now... let's just return the encrypted buffer + /// Encrypts message in the given buffer to the same buffer, returns number of byte + pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result> { + let mut out = buf.to_vec(); + self.push_stream + .push(&mut out, &[], Tag::Message) + .map_err(|err| { + io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) + })?; + Ok(out) + } + /// Get the length needed for encryption, that includes padding. + pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { + // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe + // extra room. + // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ + plaintext_len + 2 * 15 + } +} From 1f6fb96dd9a01e15e41db3f0f520034d3188fd91 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 12 Mar 2025 14:31:46 -0400 Subject: [PATCH 014/135] More tests, logging, lints --- src/crypto/cipher.rs | 41 ++++++++++++++------- src/crypto/handshake.rs | 3 ++ src/crypto/mod.rs | 2 +- src/noise.rs | 80 +++++++++++++++++++++++++---------------- src/protocol.rs | 4 +-- src/test_utils.rs | 19 ++++++++++ src/util.rs | 2 +- 7 files changed, 104 insertions(+), 47 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 28ff1f4..5c8d8a2 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -188,17 +188,9 @@ fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { //NB "raw" here means UN-framed. No frame header. const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; -pub(crate) struct RawDecryptCipher { - pull_stream: PullStream, -} - pub(crate) struct RawEncryptCipher { push_stream: PushStream, -} -impl std::fmt::Debug for RawDecryptCipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DecryptCipher(crypto_secretstream)") - } + buf: Vec, } impl std::fmt::Debug for RawEncryptCipher { @@ -228,7 +220,13 @@ impl RawEncryptCipher { let header = header.as_ref(); header_message[STREAM_ID_LENGTH..].copy_from_slice(header); let msg = header_message.to_vec(); - Ok((Self { push_stream }, msg)) + Ok(( + Self { + push_stream, + buf: Default::default(), + }, + msg, + )) } // Possible API's: @@ -244,9 +242,12 @@ impl RawEncryptCipher { // it's range would be (1..(buf.len() - mac.len())) // encypt-in-place the palintext, // For now... let's just return the encrypted buffer - /// Encrypts message in the given buffer to the same buffer, returns number of byte - pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result> { - let mut out = buf.to_vec(); + // + /// Encrypts `msg` and returns the encrypted bytes + pub(crate) fn encrypt(&mut self, msg: &[u8]) -> io::Result> { + // NB: the result is written in place to the provided, however the buffer must be able to + // grow, since the encrypted message is bigger. So here we convert the slice to a vec. + let mut out = msg.to_vec(); self.push_stream .push(&mut out, &[], Tag::Message) .map_err(|err| { @@ -254,6 +255,20 @@ impl RawEncryptCipher { })?; Ok(out) } + + pub(crate) fn encrypt_in_place<'a>(&'a mut self, msg: &[u8]) -> io::Result<&'a [u8]> { + let min_safe_length = self.safe_encrypted_len(msg.len()); + if self.buf.len() < min_safe_length { + self.buf.resize(min_safe_length, 0); + } + // write message starting at index 1. we write the tag to index zero + self.buf[1..].copy_from_slice(msg); + // insert tag + // let enc_len = self.encrypt_no_alloc(&mut self.buff, 1..(1 + msg.len()))?; + // self.buf[..enc_len] + todo!() + } + /// Get the length needed for encryption, that includes padding. pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 21c5442..74a1ada 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -7,6 +7,7 @@ use blake2::{ use snow::resolvers::{DefaultResolver, FallbackResolver}; use snow::{Builder, Error as SnowError, HandshakeState}; use std::io::{Error, ErrorKind, Result}; +use tracing::instrument; const CIPHERKEYLEN: usize = 32; const HANDSHAKE_PATTERN: &str = "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b"; @@ -81,6 +82,7 @@ pub(crate) struct Handshake { } impl Handshake { + #[instrument] pub(crate) fn new(is_initiator: bool) -> Result { let (state, local_pubkey) = build_handshake_state(is_initiator).map_err(map_err)?; @@ -132,6 +134,7 @@ impl Handshake { .map_err(map_err) } + #[instrument(skip_all, fields(is_initiator = %self.result.is_initiator))] pub(crate) fn read_raw(&mut self, msg: &[u8]) -> Result>> { // eprintln!("hs read len {}", msg.len()); if self.complete() { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 66bb62d..27f12b4 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,5 +1,5 @@ mod cipher; mod curve; mod handshake; -pub(crate) use cipher::{DecryptCipher, EncryptCipher}; +pub(crate) use cipher::{DecryptCipher, EncryptCipher, RawEncryptCipher}; pub(crate) use handshake::{Handshake, HandshakeResult}; diff --git a/src/noise.rs b/src/noise.rs index c22c1fd..beff242 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -7,19 +7,9 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{error, info, instrument, trace, warn}; +use tracing::{debug, error, info, instrument, trace, warn}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult, RawEncryptCipher}; - -macro_rules! name { - ($name:tt) => {{ - if $name { - "initiator" - } else { - "other" - } - }}; -} +use crate::crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}; #[derive(Debug)] pub(crate) enum Step { @@ -28,6 +18,20 @@ pub(crate) enum Step { SecretStream((RawEncryptCipher, HandshakeResult)), Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), } +impl std::fmt::Display for Step { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Step::NotInitialized => "NotInitialized", + Step::Handshake(_) => "Handshake", + Step::SecretStream(_) => "SecretStream", + Step::Established(_) => "Established", + } + ) + } +} /// Wrap a stream with encryption #[derive(Debug)] @@ -55,7 +59,7 @@ where IO: Stream> + Sink> + Send + Unpin + Debug + 'static, { /// Create [`Self`] from a Stream/Sink - #[instrument(skip_all, fields(name = %ename(is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %is_initiator))] pub fn new(is_initiator: bool, io: IO) -> Self { Self { io, @@ -83,14 +87,14 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Ready(Ok(())) } - #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { trace!("add plain tx"); self.plain_tx.push_back(item); Ok(()) } - #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -225,7 +229,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } - #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] fn poll_close( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -238,7 +242,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static { type Item = Vec; - #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -387,14 +391,15 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } -#[instrument(skip_all, fields(name = %ename(is_initiator)))] +#[instrument(skip_all, fields(is_initiator = %is_initiator))] fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { match &step { Step::NotInitialized => { - warn!("{} Encrypted state was reset", name!(is_initiator)); + warn!(initiator = %is_initiator, "Encrypted state was reset"); let mut handshake = Handshake::new(is_initiator)?; let start_msg = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); + debug!(initiator = %is_initiator, "Step changed to {step}"); Ok(start_msg.map(|x| vec![x]).unwrap_or(vec![])) } @@ -405,19 +410,18 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { - panic!("error in handshake.read_raw(msg) {e:?}"); return Err(e); } } { info!( - "{} read message and emitting response {response:?}", - name!(is_initiator) + initiator = %is_initiator, + "read message and emitting response {response:?}", ); out.push(response); } if handshake.complete() { - info!("{} HS complete. Making result", name!(is_initiator)); + debug!(initiator = %is_initiator, "HS complete. Making result"); let handshake_result = match handshake.into_result() { Ok(x) => x, Err(e) => { @@ -434,9 +438,9 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu return Err(e); } }; - info!("{} made enc cipher", name!(is_initiator)); out.push(init_msg); *step = Step::SecretStream((cipher, handshake_result.clone())); + debug!(initiator = %is_initiator, "Step changed to {step}"); } else { *step = Step::Handshake(handshake); } @@ -449,6 +453,7 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu { let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; *step = Step::Established((enc_cipher, dec_cipher, hs_result)); + debug!(initiator = %is_initiator, "Step changed to {step}"); } Ok(vec![]) } @@ -465,7 +470,6 @@ mod tset { #[tokio::test] async fn steps() -> Result<()> { - // figure out handshake problem let mut left_hs = Handshake::new(true)?; let s1 = left_hs.start_raw()?.unwrap(); @@ -481,21 +485,37 @@ mod tset { let s4 = right_hs.read_raw(&s3)?; println!("s4 {s4:?}"); + // both sides now ready Ok(()) } #[tokio::test] async fn test_encrypted() -> Result<()> { log(); - let expected = b"hello"; + let hello = b"hello"; + let world = b"world"; let (left, right) = create_connected(); let mut left = Encrypted::new(true, left); let mut right = Encrypted::new(false, right); - tokio::task::spawn(async move { - left.send(expected.into()).await.unwrap(); + + // NB: we cannot totally finish 'left.send' until the other side becomes active + // this is because the handshake with the other side ('right') must complete + // before the message is sent. So we must spawn here, so we can proceed to run 'right' + let left_handle = tokio::task::spawn(async move { + left.send(hello.into()).await.unwrap(); + left }); - let result = right.next().await.unwrap(); - assert_eq!(result, expected); + + // right recieves left's message + assert_eq!(right.next().await.unwrap(), hello); + + let mut left = left_handle.await?; + + // now that the encrypted channel is established, we don't need to spawn. + right.send(world.into()).await.unwrap(); + + // left recieves right's message + assert_eq!(left.next().await.unwrap(), world); Ok(()) } } diff --git a/src/protocol.rs b/src/protocol.rs index 3e8a2b5..9d1ebe9 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -10,7 +10,7 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::trace; +use tracing::{info, trace}; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; @@ -466,7 +466,7 @@ where if self.options.encrypted { // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(&handshake_result)?; + let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; self.state = State::SecretStream(Some(cipher)); // Send the secret stream init message header to the other side diff --git a/src/test_utils.rs b/src/test_utils.rs index 273935f..7d8c3a7 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,5 +1,6 @@ use std::{ pin::Pin, + sync::OnceLock, task::{Context, Poll}, }; @@ -71,6 +72,24 @@ impl TwoWay { pub(crate) fn create_connected() -> (Io, Io) { TwoWay::default().split_sides() } + +#[allow(dead_code)] +pub(crate) fn log() { + use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + static START_LOGS: OnceLock<()> = OnceLock::new(); + START_LOGS.get_or_init(|| { + tracing_subscriber::fmt() + .with_target(true) + .with_line_number(true) + // print when instrumented funtion enters + .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) + .with_file(true) + .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable + .without_time() + .init(); + }); +} + #[tokio::test] async fn way_one() { let mut a = Io::default(); diff --git a/src/util.rs b/src/util.rs index c99ff9c..1350728 100644 --- a/src/util.rs +++ b/src/util.rs @@ -31,7 +31,7 @@ pub(crate) fn map_channel_err(err: async_channel::SendError) -> Error { pub(crate) const UINT_24_LENGTH: usize = 3; #[inline] -pub(crate) fn wrap_uint24_le(data: &Vec) -> Vec { +pub(crate) fn wrap_uint24_le(data: &[u8]) -> Vec { let mut buf: Vec = vec![0; 3]; let n = data.len(); write_uint24_le(n, &mut buf); From 730baac9c90a876b1b322c594b6f1dfa510b0050 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 12 Mar 2025 16:10:59 -0400 Subject: [PATCH 015/135] rm encrypt_in_place --- src/crypto/cipher.rs | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 5c8d8a2..cbf84bc 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -190,7 +190,6 @@ const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; pub(crate) struct RawEncryptCipher { push_stream: PushStream, - buf: Vec, } impl std::fmt::Debug for RawEncryptCipher { @@ -220,13 +219,7 @@ impl RawEncryptCipher { let header = header.as_ref(); header_message[STREAM_ID_LENGTH..].copy_from_slice(header); let msg = header_message.to_vec(); - Ok(( - Self { - push_stream, - buf: Default::default(), - }, - msg, - )) + Ok((Self { push_stream }, msg)) } // Possible API's: @@ -255,25 +248,4 @@ impl RawEncryptCipher { })?; Ok(out) } - - pub(crate) fn encrypt_in_place<'a>(&'a mut self, msg: &[u8]) -> io::Result<&'a [u8]> { - let min_safe_length = self.safe_encrypted_len(msg.len()); - if self.buf.len() < min_safe_length { - self.buf.resize(min_safe_length, 0); - } - // write message starting at index 1. we write the tag to index zero - self.buf[1..].copy_from_slice(msg); - // insert tag - // let enc_len = self.encrypt_no_alloc(&mut self.buff, 1..(1 + msg.len()))?; - // self.buf[..enc_len] - todo!() - } - - /// Get the length needed for encryption, that includes padding. - pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { - // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe - // extra room. - // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ - plaintext_len + 2 * 15 - } } From bb7354b97cefbbffd18571274a8bf0f20959618f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 12 Mar 2025 16:11:46 -0400 Subject: [PATCH 016/135] rm stuff used for development --- src/noise.rs | 72 ++++++++++++++++++---------------------------------- 1 file changed, 24 insertions(+), 48 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index beff242..4851c5b 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -44,16 +44,8 @@ pub struct Encrypted { plain_tx: VecDeque>, plain_rx: VecDeque>, flush: bool, - count: usize, } -fn ename(is_initiator: bool) -> String { - if is_initiator { - "initiator".to_string() - } else { - "other".to_string() - } -} impl Encrypted where IO: Stream> + Sink> + Send + Unpin + Debug + 'static, @@ -70,7 +62,6 @@ where plain_tx: Default::default(), plain_rx: Default::default(), flush: false, - count: 0, } } } @@ -108,21 +99,13 @@ impl> + Sink> + Send + Unpin + Debug + 'static plain_tx, plain_rx, flush, - count, .. } = self.get_mut(); - *count += 1; - if *count > 200 { - //panic!(); - } // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!( - name = %ename(*is_initiator), - "enc tx send msg\n{encrypted_out:?}" - ); + trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); let _todo = Sink::start_send(Pin::new(io), encrypted_out); *flush = true; } else { @@ -144,10 +127,10 @@ impl> + Sink> + Send + Unpin + Debug + 'static match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!(name = %ename(*is_initiator), "flushed good"); + trace!(initiator = %is_initiator, "flushed good"); } Poll::Ready(Err(_e)) => error!( - name = %ename(*is_initiator), + initiator = %is_initiator, "Error sending encrypted msg"), Poll::Pending => { // More confusing docs @@ -168,9 +151,9 @@ impl> + Sink> + Send + Unpin + Debug + 'static loop { match Stream::poll_next(Pin::new(io), cx) { Poll::Pending => break, - Poll::Ready(None) => todo!(), + Poll::Ready(None) => break, Poll::Ready(Some(encrypted_msg)) => { - trace!(name = %ename(*is_initiator), "enc rx queue\n{encrypted_msg:?}"); + trace!(initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); encrypted_rx.push_back(encrypted_msg); } } @@ -181,11 +164,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!(name = %ename(*is_initiator), "plain rx queue"); + trace!(initiator = %is_initiator, "plain rx queue"); plain_rx.push_back(plain_msg); } Err(e) => { - error!(name = %ename(*is_initiator), "RX message failed to decrypt: {e:?}") + error!(initiator = %is_initiator, "RX message failed to decrypt: {e:?}") } } } @@ -197,7 +180,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static Ok(x) => x, Err(_e) => todo!("We failed to encrypt our own message...?"), }; - trace!(name = %ename(*is_initiator), "enc from plain tx queue\n{enc_out:?}"); + trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); encrypted_tx.push_back(enc_out); *flush = true; } @@ -212,14 +195,14 @@ impl> + Sink> + Send + Unpin + Debug + 'static // Still setting up if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { // queue the init message to send first - trace!(name = %ename(*is_initiator),"queue initial msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(name = %ename(*is_initiator),"recieved setup msg"); + trace!(initiator = %is_initiator,"recieved setup msg"); if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator) { for msg in msgs.into_iter().rev() { - trace!(name = %ename(*is_initiator),"queue more setup msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } } @@ -253,21 +236,13 @@ impl> + Sink> + Send + Unpin + Debug + 'static plain_tx, plain_rx, flush, - count, .. } = self.get_mut(); - *count += 1; - if *count > 200 { - //panic!(); - } // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(name = %ename(*is_initiator), "enc tx send msg - {encrypted_out:?} -" - ); + trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); let _todo = Sink::start_send(Pin::new(io), encrypted_out); *flush = true; } else { @@ -289,10 +264,10 @@ impl> + Sink> + Send + Unpin + Debug + 'static match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!(name = %ename(*is_initiator), "flushed good"); + trace!(initiator = %is_initiator, "flushed good"); } Poll::Ready(Err(_e)) => { - error!(name = %ename(*is_initiator), "Error sending encrypted msg") + error!(initiator = %is_initiator, "Error sending encrypted msg") } Poll::Pending => { // More confusing docs @@ -315,7 +290,8 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Pending => break, Poll::Ready(None) => break, Poll::Ready(Some(encrypted_msg)) => { - trace!(name = %ename(*is_initiator), "enc rx queue\n{encrypted_msg:?}"); + trace!( + initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); encrypted_rx.push_back(encrypted_msg); } } @@ -326,11 +302,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!(name = %ename(*is_initiator), "plain rx queue"); + trace!(initiator = %is_initiator, "plain rx queue"); plain_rx.push_back(plain_msg); } Err(e) => { - error!(name = %ename(*is_initiator),"RX message failed to decrypt: {e:?}") + error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") } } } @@ -341,13 +317,13 @@ impl> + Sink> + Send + Unpin + Debug + 'static Ok(x) => x, Err(_e) => todo!("We failed to encrypt our own message...?"), }; - trace!(name = %ename(*is_initiator), "enc from plain tx queue\n{enc_out:?}"); + trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); encrypted_tx.push_back(enc_out); } // emit any messages that are ready if let Some(msg) = plain_rx.pop_front() { - trace!(name = %ename(*is_initiator), "plain rx emit"); + trace!(initiator = %is_initiator, "plain rx emit"); Poll::Ready(Some(msg)) } else { Poll::Pending @@ -356,11 +332,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static // Still setting up if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { // queue the init message to send first - trace!(name = %ename(*is_initiator),"queue initial msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(name = %ename(*is_initiator), "recieved setup msg"); + trace!(initiator = %is_initiator, "recieved setup msg"); if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, *is_initiator) { Ok(x) => Ok(x), Err(e) => { @@ -369,7 +345,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } { for msg in msgs.into_iter().rev() { - trace!(name = %ename(*is_initiator),"queue more setup msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } } @@ -384,7 +360,7 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { if !matches!(step, Step::NotInitialized) { return Ok(None); } - trace!(name = %ename(is_initiator), "Init, state {step:?}"); + trace!(initiator = %is_initiator, "Init, state {step:?}"); let mut handshake = Handshake::new(is_initiator)?; let out = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); From 4e6217a0122d1770c2324e6be297d07e172db620 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 13 Mar 2025 17:23:39 -0400 Subject: [PATCH 017/135] Add LengthPrefixed framing --- src/framing.rs | 249 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 + 2 files changed, 251 insertions(+) create mode 100644 src/framing.rs diff --git a/src/framing.rs b/src/framing.rs new file mode 100644 index 0000000..3504895 --- /dev/null +++ b/src/framing.rs @@ -0,0 +1,249 @@ +use std::{ + collections::VecDeque, + fmt::Debug, + io::Result, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{Sink, Stream}; +use futures_lite::io::{AsyncRead, AsyncWrite}; +use tracing::{debug, instrument, trace}; + +use crate::util::{stat_uint24_le, wrap_uint24_le}; + +const BUF_SIZE: usize = 1024 * 8; +const HEADER_LEN: usize = 3; + +/// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream +pub struct LengthPrefixed { + io: IO, + to_stream: Vec, + from_sink: VecDeque>, + /// The index in [`Self::buf`] of the last byte that was to the [`Stream`]. + last_out_idx: usize, + /// The index in [`Self::buf`] of the last byte that was read from [`Self::io`] via + /// [`AsyncRead`] + last_data_idx: usize, + step: Step, +} +impl Debug for LengthPrefixed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Format()") + } +} +impl LengthPrefixed +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + /// Build [`LengthPrefixed`] around an [`AsyncWrite`]/[`AsyncRead`] thing. + pub fn new(io: IO) -> Self { + Self { + io, + to_stream: vec![0u8; BUF_SIZE], + from_sink: VecDeque::new(), + last_out_idx: 0, + last_data_idx: 0, + step: Step::Header, + } + } +} + +#[derive(Debug)] +enum Step { + Header, + Body { start: usize, end: u64 }, +} + +impl Stream for LengthPrefixed +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + type Item = Result>; + + #[instrument(skip_all)] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + trace!("from poll next!!"); + let Self { + io, + to_stream, + last_out_idx, + last_data_idx, + step, + .. + } = self.get_mut(); + let n_bytes_read = match Pin::new(io).poll_read(cx, &mut to_stream[*last_data_idx..]) { + Poll::Ready(Ok(n)) => n, + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => 0, + }; + // TODO handle if to_stream is full + trace!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); + *last_data_idx += n_bytes_read; + // grow buffer if it's full + if *last_data_idx == to_stream.len() - 1 { + to_stream.extend(vec![0; to_stream.len() * 2]); + } + + if let Step::Header = step { + trace!(step = ?*step, "enter"); + if *last_data_idx - *last_out_idx < HEADER_LEN { + trace!("not enough bytes to read header"); + return Poll::Pending; + } + let Some((header_len, body_len)) = + stat_uint24_le(&to_stream[*last_out_idx..(*last_out_idx + HEADER_LEN)]) + else { + // we check above the there is room for header so this should never happen + todo!() + }; + + let cur_frame_start = *last_out_idx + header_len; + let cur_frame_end = (cur_frame_start as u64) + body_len; + *step = Step::Body { + start: cur_frame_start, + end: cur_frame_end, + }; + } + + if let Step::Body { start, end } = step { + let end = *end as usize; + if end <= *last_data_idx { + debug!(frame_size = end - *start, "Frame ready"); + let out = to_stream[*start..end].to_vec(); + *step = Step::Header; + *last_out_idx = end; + + return Poll::Ready(Some(Ok(out))); + } + } + Poll::Pending + } +} +impl Sink> for LengthPrefixed +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + type Error = std::io::Error; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { + self.from_sink.push_back(wrap_uint24_le(&item)); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let Self { from_sink, io, .. } = self.get_mut(); + if let Some(msg) = from_sink.pop_front() { + match Pin::new(io).poll_write(cx, &msg) { + Poll::Pending => { + from_sink.push_front(msg); + return Poll::Pending; + } + Poll::Ready(Ok(n)) => { + if n != msg.len() { + from_sink.push_front(msg[n..].to_vec()); + return Poll::Ready(Ok(())); + } + } + Poll::Ready(Err(_e)) => todo!(), + } + } + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + todo!() + } +} +#[cfg(test)] +mod test { + use crate::test_utils::log; + + use super::*; + use futures::{ + io::{AsyncReadExt, AsyncWriteExt}, + AsyncRead, AsyncWrite, SinkExt, StreamExt, + }; + use tokio_util::compat::TokioAsyncReadCompatExt; + + fn duplex(channel_size: usize) -> (impl AsyncRead + AsyncWrite, impl AsyncRead + AsyncWrite) { + let (left, right) = tokio::io::duplex(channel_size); + (left.compat(), right.compat()) + } + + #[tokio::test] + async fn t_duplex() -> Result<()> { + let (mut left, mut right) = duplex(64); + left.write_all(b"hello").await?; + let mut b = vec![0; 5]; + right.read_exact(&mut b).await?; + assert_eq!(b, b"hello"); + Ok(()) + } + + #[tokio::test] + async fn t_input() -> Result<()> { + log(); + let (left, mut right) = duplex(64); + let mut lp = LengthPrefixed::new(left); + let input = b"yelp"; + let msg = wrap_uint24_le(input); + dbg!(&msg); + right.write_all(&msg).await?; + let Some(Ok(rx)) = lp.next().await else { + panic!() + }; + assert_eq!(rx, input); + Ok(()) + } + #[tokio::test] + async fn t_stream_many() -> Result<()> { + log(); + let (left, mut right) = duplex(64); + let mut lp = LengthPrefixed::new(left); + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; + for d in data { + let msg = wrap_uint24_le(d); + right.write_all(&msg).await?; + } + for d in data { + dbg!(); + let Some(Ok(res)) = lp.next().await else { + panic!(); + }; + dbg!(&res); + assert_eq!(&res, d); + } + Ok(()) + } + #[tokio::test] + async fn t_sink_many() -> Result<()> { + log(); + let (left, mut right) = duplex(64); + let mut lp = LengthPrefixed::new(left); + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; + for d in data { + lp.send(d.to_vec()).await.unwrap(); + } + + let mut expected = vec![]; + data.iter().for_each(|d| expected.extend(wrap_uint24_le(d))); + let mut result = vec![0; expected.len()]; + right.read_exact(&mut result).await?; + assert_eq!(result, expected); + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 646a93e..b1a043a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,6 +122,7 @@ mod channels; mod constants; mod crypto; mod duplex; +mod framing; mod message; mod noise; mod protocol; @@ -136,6 +137,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; +pub use framing::LengthPrefixed; pub use noise::Encrypted; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ From b576c9f4d2a7fc16885c9bc53a07d1f4b011306b Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 13 Mar 2025 17:49:41 -0400 Subject: [PATCH 018/135] use var for header len --- src/crypto/cipher.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index cbf84bc..8278692 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -148,7 +148,7 @@ impl EncryptCipher { let encrypted_len = to_encrypt.len(); write_uint24_le(encrypted_len, buf); buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); - Ok(3 + encrypted_len) + Ok(header_len + encrypted_len) } else { Err(io::Error::new( io::ErrorKind::InvalidData, From a3a50d071da37309926104e1995cd87e703a35ef Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 13 Mar 2025 17:50:19 -0400 Subject: [PATCH 019/135] refactor encrypted poll functions --- src/noise.rs | 357 +++++++++++++++++++++++---------------------------- 1 file changed, 159 insertions(+), 198 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 4851c5b..84fc381 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -69,7 +69,7 @@ where impl> + Sink> + Send + Unpin + Debug + 'static> Sink> for Encrypted { - type Error = (); + type Error = std::io::Error; fn poll_ready( self: Pin<&mut Self>, @@ -102,88 +102,19 @@ impl> + Sink> + Send + Unpin + Debug + 'static .. } = self.get_mut(); - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); - let _todo = Sink::start_send(Pin::new(io), encrypted_out); - *flush = true; - } else { - break; - } - } - if *flush { - // confusing docs related to start send - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send - // First part says: - // "you must use **poll_flush** ... inorder to garuntee - // completions of send" - // Then it says: - // " It is only necessary to call poll_flush if you need to guarantee that all - // of the items placed into the Sink have been sent" - // - // So do I need to do it or not? - // must `poll_flush` be called for **anything** to send? - match Sink::poll_flush(Pin::new(io), cx) { - Poll::Ready(Ok(())) => { - *flush = false; - trace!(initiator = %is_initiator, "flushed good"); - } - Poll::Ready(Err(_e)) => error!( - initiator = %is_initiator, - "Error sending encrypted msg"), - Poll::Pending => { - // More confusing docs - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush - // It says: - // "Returns Poll::Pending if there is more work left to do, in which case the - // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when - // poll_flush should be called again." - // Does this mean, each time this task wakes up again from this code path that - // I must trigger another poll_flush? But how would I know i need more - // flushing? - *flush = true; - } - } - } - - // pull in any incomming encrypted messages - loop { - match Stream::poll_next(Pin::new(io), cx) { - Poll::Pending => break, - Poll::Ready(None) => break, - Poll::Ready(Some(encrypted_msg)) => { - trace!(initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); - encrypted_rx.push_back(encrypted_msg); - } - } - } + poll_encrypted_side_io(io, cx, encrypted_tx, encrypted_rx, *is_initiator, flush); if let Step::Established((encryptor, decryptor, ..)) = step { - // decrypt any incromming encrypted messages - while let Some(incoming_msg) = encrypted_rx.pop_front() { - match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => { - trace!(initiator = %is_initiator, "plain rx queue"); - plain_rx.push_back(plain_msg); - } - Err(e) => { - error!(initiator = %is_initiator, "RX message failed to decrypt: {e:?}") - } - } - } - - // encrypt any pending plaintext outgoinng messages - while let Some(mut plain_out) = plain_tx.pop_front() { - // it encrypts in-place?? - let enc_out = match encryptor.encrypt(&mut plain_out) { - Ok(x) => x, - Err(_e) => todo!("We failed to encrypt our own message...?"), - }; - trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); - encrypted_tx.push_back(enc_out); - *flush = true; - } + poll_do_encrypt_and_decrypt( + encryptor, + decryptor, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + *is_initiator, + flush, + ); if *flush { cx.waker().wake_by_ref(); @@ -192,21 +123,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Ready(Ok(())) } } else { - // Still setting up - if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { - // queue the init message to send first - trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); - encrypted_tx.push_front(msg); - } - while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(initiator = %is_initiator,"recieved setup msg"); - if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator) { - for msg in msgs.into_iter().rev() { - trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); - encrypted_tx.push_front(msg); - } - } - } + poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator); cx.waker().wake_by_ref(); Poll::Pending } @@ -220,6 +137,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static todo!() } } + impl> + Sink> + Send + Unpin + Debug + 'static> Stream for Encrypted { @@ -239,88 +157,19 @@ impl> + Sink> + Send + Unpin + Debug + 'static .. } = self.get_mut(); - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); - let _todo = Sink::start_send(Pin::new(io), encrypted_out); - *flush = true; - } else { - break; - } - } - if *flush { - // confusing docs related to start send - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send - // First part says: - // "you must use **poll_flush** ... inorder to garuntee - // completions of send" - // Then it says: - // " It is only necessary to call poll_flush if you need to guarantee that all - // of the items placed into the Sink have been sent" - // - // So do I need to do it or not? - // must `poll_flush` be called for **anything** to send? - match Sink::poll_flush(Pin::new(io), cx) { - Poll::Ready(Ok(())) => { - *flush = false; - trace!(initiator = %is_initiator, "flushed good"); - } - Poll::Ready(Err(_e)) => { - error!(initiator = %is_initiator, "Error sending encrypted msg") - } - Poll::Pending => { - // More confusing docs - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush - // It says: - // "Returns Poll::Pending if there is more work left to do, in which case the - // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when - // poll_flush should be called again." - // Does this mean, each time this task wakes up again from this code path that - // I must trigger another poll_flush? But how would I know i need more - // flushing? - *flush = true; - } - } - } - - // pull in any incomming encrypted messages - loop { - match Stream::poll_next(Pin::new(io), cx) { - Poll::Pending => break, - Poll::Ready(None) => break, - Poll::Ready(Some(encrypted_msg)) => { - trace!( - initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); - encrypted_rx.push_back(encrypted_msg); - } - } - } + poll_encrypted_side_io(io, cx, encrypted_tx, encrypted_rx, *is_initiator, flush); if let Step::Established((encryptor, decryptor, ..)) = step { - // decrypt any incromming encrypted messages - while let Some(incoming_msg) = encrypted_rx.pop_front() { - match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => { - trace!(initiator = %is_initiator, "plain rx queue"); - plain_rx.push_back(plain_msg); - } - Err(e) => { - error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") - } - } - } - - // encrypt any pending plaintext outgoinng messages - while let Some(mut plain_out) = plain_tx.pop_front() { - let enc_out = match encryptor.encrypt(&mut plain_out) { - Ok(x) => x, - Err(_e) => todo!("We failed to encrypt our own message...?"), - }; - trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); - encrypted_tx.push_back(enc_out); - } - + poll_do_encrypt_and_decrypt( + encryptor, + decryptor, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + *is_initiator, + flush, + ); // emit any messages that are ready if let Some(msg) = plain_rx.pop_front() { trace!(initiator = %is_initiator, "plain rx emit"); @@ -329,31 +178,144 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Pending } } else { - // Still setting up - if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { - // queue the init message to send first - trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); + poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator); + cx.waker().wake_by_ref(); + Poll::Pending + } + } +} + +fn poll_setup( + step: &mut Step, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>, + is_initiator: bool, +) { + // Still setting up + if let Ok(Some(msg)) = maybe_init(step, is_initiator) { + // queue the init message to send first + trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); + encrypted_tx.push_front(msg); + } + while let Some(incoming_msg) = encrypted_rx.pop_front() { + trace!(initiator = %is_initiator, "recieved setup msg"); + if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, is_initiator) { + Ok(x) => Ok(x), + Err(e) => { + error!("handle_setup_message error: {e:?}"); + Err(e) + } + } { + for msg in msgs.into_iter().rev() { + trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } - while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(initiator = %is_initiator, "recieved setup msg"); - if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, *is_initiator) { - Ok(x) => Ok(x), - Err(e) => { - error!("handle_setup_message error: {e:?}"); - Err(e) - } - } { - for msg in msgs.into_iter().rev() { - trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); - encrypted_tx.push_front(msg); - } - } + } + } +} + +fn poll_encrypted_side_io< + IO: Stream> + Sink> + Send + Unpin + Debug + 'static, +>( + io: &mut IO, + cx: &mut Context<'_>, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>, + is_initiator: bool, + flush: &mut bool, +) { + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + *flush = true; + } else { + break; + } + } + if *flush { + // confusing docs related to start send + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send + // First part says: + // "you must use **poll_flush** ... inorder to garuntee + // completions of send" + // Then it says: + // " It is only necessary to call poll_flush if you need to guarantee that all + // of the items placed into the Sink have been sent" + // + // So do I need to do it or not? + // must `poll_flush` be called for **anything** to send? + match Sink::poll_flush(Pin::new(io), cx) { + Poll::Ready(Ok(())) => { + *flush = false; + trace!(initiator = %is_initiator, "flushed good"); + } + Poll::Ready(Err(_e)) => error!( + initiator = %is_initiator, + "Error sending encrypted msg"), + Poll::Pending => { + // More confusing docs + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush + // It says: + // "Returns Poll::Pending if there is more work left to do, in which case the + // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when + // poll_flush should be called again." + // Does this mean, each time this task wakes up again from this code path that + // I must trigger another poll_flush? But how would I know i need more + // flushing? + *flush = true; } - cx.waker().wake_by_ref(); - Poll::Pending } } + // pull in any incomming encrypted messages + loop { + match Stream::poll_next(Pin::new(io), cx) { + Poll::Pending => break, + Poll::Ready(None) => break, + Poll::Ready(Some(encrypted_msg)) => { + trace!(initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); + encrypted_rx.push_back(encrypted_msg); + } + } + } +} + +/// Process messages waiting to be encrypted or decrypted +// TODO sholud this return a Result +fn poll_do_encrypt_and_decrypt( + encryptor: &mut RawEncryptCipher, + decryptor: &mut DecryptCipher, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>, + plain_tx: &mut VecDeque>, + plain_rx: &mut VecDeque>, + is_initiator: bool, + flush: &mut bool, +) { + // decrypt any incromming encrypted messages + while let Some(incoming_msg) = encrypted_rx.pop_front() { + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => { + trace!(initiator = %is_initiator, "plain rx queue"); + plain_rx.push_back(plain_msg); + } + Err(e) => { + error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") + } + } + } + + // encrypt any pending plaintext outgoinng messages + while let Some(plain_out) = plain_tx.pop_front() { + let enc_out = match encryptor.encrypt(&plain_out) { + Ok(x) => x, + Err(_e) => todo!("We failed to encrypt our own message...?"), + }; + trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); + encrypted_tx.push_back(enc_out); + *flush = true; + } } fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { @@ -439,7 +401,7 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu #[cfg(test)] mod tset { - use crate::test_utils::{create_connected, log}; + use crate::test_utils::create_connected; use super::*; use futures::{SinkExt, StreamExt}; @@ -467,7 +429,6 @@ mod tset { #[tokio::test] async fn test_encrypted() -> Result<()> { - log(); let hello = b"hello"; let world = b"world"; let (left, right) = create_connected(); From 94785a52b0074356f6b7b78380b6d9e99a6c50ed Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 13 Mar 2025 17:51:50 -0400 Subject: [PATCH 020/135] rename writer fields --- src/protocol.rs | 2 +- src/writer.rs | 59 +++++++++++++++++++++++++++---------------------- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 9d1ebe9..1f24b1a 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -593,7 +593,7 @@ where fn queue_frame_direct(&mut self, body: Vec) -> Result { let mut frame = Frame::RawBatch(vec![body]); - self.write_state.try_queue_direct(&mut frame) + self.write_state.try_encode_frame_for_tx(&mut frame) } fn accept_channel(&mut self, local_id: usize) -> Result<()> { diff --git a/src/writer.rs b/src/writer.rs index e3cc5da..38d6dcf 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -19,11 +19,11 @@ pub(crate) enum Step { pub(crate) struct WriteState { queue: VecDeque, - buf: Vec, current_frame: Option, - start: usize, - end: usize, cipher: Option, + buf: Vec, + written_up_to_idx: usize, + should_write_up_to_idx: usize, step: Step, } @@ -31,12 +31,12 @@ impl fmt::Debug for WriteState { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WriteState") .field("queue (len)", &self.queue.len()) - .field("step", &self.step) - .field("buf (len)", &self.buf.len()) .field("current_frame", &self.current_frame) - .field("start", &self.start) - .field("end", &self.end) .field("cipher", &self.cipher.is_some()) + .field("buf (len)", &self.buf.len()) + .field("start", &self.written_up_to_idx) + .field("end", &self.should_write_up_to_idx) + .field("step", &self.step) .finish() } } @@ -47,8 +47,8 @@ impl WriteState { queue: VecDeque::new(), buf: vec![0u8; BUF_SIZE], current_frame: None, - start: 0, - end: 0, + written_up_to_idx: 0, + should_write_up_to_idx: 0, cipher: None, step: Step::Processing, } @@ -61,7 +61,7 @@ impl WriteState { self.queue.push_back(frame.into()) } - pub(crate) fn try_queue_direct(&mut self, frame: &mut T) -> Result { + pub(crate) fn try_encode_frame_for_tx(&mut self, frame: &mut T) -> Result { let promised_len = frame.encoded_len()?; let padded_promised_len = self.safe_encrypted_len(promised_len); if self.buf.len() < padded_promised_len { @@ -70,13 +70,15 @@ impl WriteState { if padded_promised_len > self.remaining() { return Ok(false); } - let actual_len = frame.encode(&mut self.buf[self.end..])?; + + // write frame starting at end. fram is from end to end + actual_end + let actual_len = frame.encode(&mut self.buf[self.should_write_up_to_idx..])?; if actual_len != promised_len { panic!( "encoded_len() did not return that right size, expected={promised_len}, actual={actual_len}" ); } - self.advance(padded_promised_len)?; + self.encrypt_frame_contents(padded_promised_len)?; Ok(true) } @@ -93,16 +95,18 @@ impl WriteState { } } - fn advance(&mut self, n: usize) -> Result<()> { - let end = self.end + n; + fn encrypt_frame_contents(&mut self, max_message_size: usize) -> Result<()> { + let end_of_message_index = self.should_write_up_to_idx + max_message_size; let encrypted_end = if let Some(ref mut cipher) = self.cipher { - self.end + cipher.encrypt(&mut self.buf[self.end..end])? + self.should_write_up_to_idx + + cipher + .encrypt(&mut self.buf[self.should_write_up_to_idx..end_of_message_index])? } else { - end + end_of_message_index }; - self.end = encrypted_end; + self.should_write_up_to_idx = encrypted_end; Ok(()) } @@ -111,11 +115,11 @@ impl WriteState { } fn remaining(&self) -> usize { - self.buf.len() - self.end + self.buf.len() - self.should_write_up_to_idx } fn pending(&self) -> usize { - self.end - self.start + self.should_write_up_to_idx - self.written_up_to_idx } pub(crate) fn poll_send( @@ -134,7 +138,7 @@ impl WriteState { } if let Some(mut frame) = self.current_frame.take() { - if !self.try_queue_direct(&mut frame)? { + if !self.try_encode_frame_for_tx(&mut frame)? { self.current_frame = Some(frame); } } @@ -145,13 +149,14 @@ impl WriteState { Step::Writing } Step::Writing => { - let n = ready!( - Pin::new(&mut writer).poll_write(cx, &self.buf[self.start..self.end]) - )?; - self.start += n; - if self.start == self.end { - self.start = 0; - self.end = 0; + let n = ready!(Pin::new(&mut writer).poll_write( + cx, + &self.buf[self.written_up_to_idx..self.should_write_up_to_idx] + ))?; + self.written_up_to_idx += n; + if self.written_up_to_idx == self.should_write_up_to_idx { + self.written_up_to_idx = 0; + self.should_write_up_to_idx = 0; } Step::Flushing } From fb88b8912d100de85c6a3986d758f82a78d78cd9 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 14 Mar 2025 15:41:11 -0400 Subject: [PATCH 021/135] s/3/header_len/g --- src/crypto/cipher.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 8278692..4dfaf19 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -85,8 +85,8 @@ impl DecryptCipher { let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; let decrypted_len = to_decrypt.len(); write_uint24_le(decrypted_len, buf); - let decrypted_end = 3 + to_decrypt.len(); - buf[3..decrypted_end].copy_from_slice(to_decrypt.as_slice()); + let decrypted_end = header_len + to_decrypt.len(); + buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); // Set extra bytes in the buffer to 0 let encrypted_end = header_len + body_len; buf[decrypted_end..encrypted_end].fill(0x00); From 1bb619d436b4098e2aa4774c898599f3b5839156 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 15 Mar 2025 15:50:31 -0400 Subject: [PATCH 022/135] add tokio-util for tests --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index a5ac273..ff89935 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ futures = "0.3.13" log = "0.4" test-log = { version = "0.2.11", default-features = false, features = ["trace"] } tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } +tokio-util = { version = "0.7.14", features = ["compat"] } [features] default = ["tokio", "sparse"] From b6db23333447d6b476806edbac0e4f91d072b418 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 15 Mar 2025 15:51:14 -0400 Subject: [PATCH 023/135] Add result channel to test utils. refactor to use futures channels bc they implement Sender --- src/test_utils.rs | 107 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 100 insertions(+), 7 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index 7d8c3a7..2e9b6cd 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -4,8 +4,13 @@ use std::{ task::{Context, Poll}, }; -use async_channel::{unbounded, Receiver, SendError, Sender}; -use futures::{Sink, SinkExt, Stream, StreamExt}; +//use async_channel::{unbounded, Receiver, SendError, Sender}; +use futures::{ + channel::mpsc::{ + unbounded, SendError, UnboundedReceiver as Receiver, UnboundedSender as Sender, + }, + Sink, SinkExt, Stream, StreamExt, +}; #[derive(Debug)] pub(crate) struct Io { @@ -29,15 +34,14 @@ impl Stream for Io { } impl Sink> for Io { - type Error = SendError>; + type Error = SendError; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { - let _ = self.sender.try_send(item); - Ok(()) + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + Pin::new(&mut self.sender).start_send(item) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -73,7 +77,6 @@ pub(crate) fn create_connected() -> (Io, Io) { TwoWay::default().split_sides() } -#[allow(dead_code)] pub(crate) fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: OnceLock<()> = OnceLock::new(); @@ -108,3 +111,93 @@ async fn split() { }; assert_eq!(res, b"hello"); } + +#[derive(Debug)] +pub(crate) struct Moo { + receiver: Rx, + sender: Tx, +} + +impl + Unpin, Tx: Unpin> Stream for Moo { + type Item = RxItem; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + Pin::new(&mut this.receiver).poll_next(cx) + } +} + +impl + Unpin> Sink + for Moo +{ + type Error = SendError; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: TxItem) -> Result<(), Self::Error> { + let this = self.get_mut(); + Pin::new(&mut this.sender).start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +/// Creaee [`Moo`] from return value of [`unbounded`] +impl From<(Tx, Rx)> for Moo { + fn from(value: (Tx, Rx)) -> Self { + Moo { + receiver: value.1, + sender: value.0, + } + } +} + +impl Moo { + /// connect two [`Moo`]s + fn connect( + self, + other: Moo, + ) -> (Moo, Moo) { + let left = Moo { + receiver: self.receiver, + sender: other.sender, + }; + let right = Moo { + receiver: other.receiver, + sender: self.sender, + }; + (left, right) + } +} + +fn result_channel() -> (Sender>, impl Stream, String>>) { + let (tx, rx) = unbounded::>(); + (tx, rx.map(|x| Ok(x))) +} + +pub(crate) fn create_result_connected() -> ( + Moo, String>>, impl Sink>>, + Moo, String>>, impl Sink>>, +) { + let a = Moo::from(result_channel()); + let b = Moo::from(result_channel()); + a.connect(b) +} + +#[tokio::test] +async fn foo() -> Result<(), Box> { + let a = Moo::from(result_channel()); + let b = Moo::from(result_channel()); + let (mut left, mut right) = a.connect(b); + left.send(b"hello".to_vec()).await?; + assert_eq!(right.next().await.unwrap(), Ok(b"hello".into())); + Ok(()) +} From 7780cb2b189f351f7547605f3a2e266e2361394f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 15 Mar 2025 15:52:54 -0400 Subject: [PATCH 024/135] refactor framing tests --- src/framing.rs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 3504895..64576f3 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -63,7 +63,6 @@ where #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - trace!("from poll next!!"); let Self { io, to_stream, @@ -169,23 +168,23 @@ where } } #[cfg(test)] -mod test { +pub(crate) mod test { use crate::test_utils::log; use super::*; - use futures::{ - io::{AsyncReadExt, AsyncWriteExt}, - AsyncRead, AsyncWrite, SinkExt, StreamExt, - }; + use futures::{SinkExt, StreamExt}; + use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio_util::compat::TokioAsyncReadCompatExt; - fn duplex(channel_size: usize) -> (impl AsyncRead + AsyncWrite, impl AsyncRead + AsyncWrite) { + pub(crate) fn duplex( + channel_size: usize, + ) -> (impl AsyncRead + AsyncWrite, impl AsyncRead + AsyncWrite) { let (left, right) = tokio::io::duplex(channel_size); (left.compat(), right.compat()) } #[tokio::test] - async fn t_duplex() -> Result<()> { + async fn duplex_works() -> Result<()> { let (mut left, mut right) = duplex(64); left.write_all(b"hello").await?; let mut b = vec![0; 5]; @@ -195,7 +194,7 @@ mod test { } #[tokio::test] - async fn t_input() -> Result<()> { + async fn input() -> Result<()> { log(); let (left, mut right) = duplex(64); let mut lp = LengthPrefixed::new(left); @@ -210,7 +209,7 @@ mod test { Ok(()) } #[tokio::test] - async fn t_stream_many() -> Result<()> { + async fn stream_many() -> Result<()> { log(); let (left, mut right) = duplex(64); let mut lp = LengthPrefixed::new(left); @@ -230,7 +229,7 @@ mod test { Ok(()) } #[tokio::test] - async fn t_sink_many() -> Result<()> { + async fn sink_many() -> Result<()> { log(); let (left, mut right) = duplex(64); let mut lp = LengthPrefixed::new(left); From ec9edc2ea626ad079c68ca1e40b153dd4fe943bd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 15 Mar 2025 16:43:48 -0400 Subject: [PATCH 025/135] Get Encrypted working with Result> from io --- src/test_utils.rs | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index 2e9b6cd..d35af0e 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,10 +1,11 @@ use std::{ + io::{self, ErrorKind}, pin::Pin, sync::OnceLock, task::{Context, Poll}, }; -//use async_channel::{unbounded, Receiver, SendError, Sender}; +//use async_channel::{unbounded, Receiver, io::Error, Sender}; use futures::{ channel::mpsc::{ unbounded, SendError, UnboundedReceiver as Receiver, UnboundedSender as Sender, @@ -34,14 +35,16 @@ impl Stream for Io { } impl Sink> for Io { - type Error = SendError; + type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { - Pin::new(&mut self.sender).start_send(item) + Pin::new(&mut self.sender) + .start_send(item) + .map_err(|_e| io::Error::new(ErrorKind::Other, "SendError")) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -112,7 +115,6 @@ async fn split() { assert_eq!(res, b"hello"); } -#[derive(Debug)] pub(crate) struct Moo { receiver: Rx, sender: Tx, @@ -127,10 +129,10 @@ impl + Unpin, Tx: Unpin> Stream for Moo } } -impl + Unpin> Sink +impl + Unpin> Sink for Moo { - type Error = SendError; + type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -138,7 +140,9 @@ impl + Unpi fn start_send(self: Pin<&mut Self>, item: TxItem) -> Result<(), Self::Error> { let this = self.get_mut(); - Pin::new(&mut this.sender).start_send(item) + Pin::new(&mut this.sender) + .start_send(item) + .map_err(|_e| io::Error::new(ErrorKind::Other, "SendError")) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -178,14 +182,14 @@ impl Moo { } } -fn result_channel() -> (Sender>, impl Stream, String>>) { +fn result_channel() -> (Sender>, impl Stream>>) { let (tx, rx) = unbounded::>(); (tx, rx.map(|x| Ok(x))) } pub(crate) fn create_result_connected() -> ( - Moo, String>>, impl Sink>>, - Moo, String>>, impl Sink>>, + Moo>>, impl Sink>>, + Moo>>, impl Sink>>, ) { let a = Moo::from(result_channel()); let b = Moo::from(result_channel()); @@ -198,6 +202,6 @@ async fn foo() -> Result<(), Box> { let b = Moo::from(result_channel()); let (mut left, mut right) = a.connect(b); left.send(b"hello".to_vec()).await?; - assert_eq!(right.next().await.unwrap(), Ok(b"hello".into())); + assert_eq!(right.next().await.unwrap()?, b"hello".to_vec()); Ok(()) } From c9a7709b5ba09b4de44a71c121126622331c850b Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 16 Mar 2025 15:42:48 -0400 Subject: [PATCH 026/135] Make Encrypted receive a Result --- src/noise.rs | 118 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 82 insertions(+), 36 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 84fc381..dd687ec 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -34,24 +34,38 @@ impl std::fmt::Display for Step { } /// Wrap a stream with encryption -#[derive(Debug)] pub struct Encrypted { io: IO, step: Step, is_initiator: bool, encrypted_tx: VecDeque>, - encrypted_rx: VecDeque>, + encrypted_rx: VecDeque>>, plain_tx: VecDeque>, - plain_rx: VecDeque>, + plain_rx: VecDeque>>, flush: bool, } +impl std::fmt::Debug for Encrypted { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Encrypted") + //.field("io", &self.io) + .field("step", &self.step) + .field("is_initiator", &self.is_initiator) + //.field("encrypted_tx", &self.encrypted_tx) + .field("encrypted_rx", &self.encrypted_rx) + .field("plain_tx", &self.plain_tx) + .field("plain_rx", &self.plain_rx) + .field("flush", &self.flush) + .finish() + } +} + impl Encrypted where - IO: Stream> + Sink> + Send + Unpin + Debug + 'static, + IO: Stream>> + Sink> + Send + Unpin + 'static, { /// Create [`Self`] from a Stream/Sink - #[instrument(skip_all, fields(is_initiator = %is_initiator))] + #[instrument(skip_all, fields(initiator = %is_initiator))] pub fn new(is_initiator: bool, io: IO) -> Self { Self { io, @@ -66,26 +80,31 @@ where } } -impl> + Sink> + Send + Unpin + Debug + 'static> Sink> - for Encrypted +impl< + IO: Stream>> + + Sink, Error = std::io::Error> + + Send + + Unpin + + 'static, + > Sink> for Encrypted { type Error = std::io::Error; fn poll_ready( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, ) -> Poll> { - Poll::Ready(Ok(())) + Sink::poll_ready(Pin::new(&mut self.io), cx) } - #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { - trace!("add plain tx"); + info!(initiator = %self.is_initiator, "enqueue plain_tx\n{item:?}"); self.plain_tx.push_back(item); Ok(()) } - #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -129,7 +148,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } - #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_close( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -138,12 +157,12 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } -impl> + Sink> + Send + Unpin + Debug + 'static> Stream +impl>> + Sink> + Send + Unpin + 'static> Stream for Encrypted { - type Item = Vec; + type Item = Result>; - #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -185,10 +204,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } +#[instrument(skip_all, fields(initiator = %is_initiator))] fn poll_setup( step: &mut Step, encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, is_initiator: bool, ) { // Still setting up @@ -197,30 +217,54 @@ fn poll_setup( trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } - while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(initiator = %is_initiator, "recieved setup msg"); - if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, is_initiator) { - Ok(x) => Ok(x), - Err(e) => { - error!("handle_setup_message error: {e:?}"); - Err(e) + // TODO handle error + loop { + match encrypted_rx.pop_front() { + None => { + debug!( + " + num_encrypted_rx = {} + num_encrypted_tx = {} +no more encrp incoming", + encrypted_rx.len(), + encrypted_tx.len(), + ); + break; } - } { - for msg in msgs.into_iter().rev() { - trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); - encrypted_tx.push_front(msg); + Some(Err(e)) => { + error!( + num_encrypted_rx = 0, + num_encrypted_tx = encrypted_tx.len(), + "{e:?}" + ); + break; + } + Some(Ok(incoming_msg)) => { + info!(initiator = %is_initiator, "recieved setup msg"); + if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, is_initiator) { + Ok(x) => Ok(x), + Err(e) => { + error!("handle_setup_message error: {e:?}"); + Err(e) + } + } { + for msg in msgs.into_iter().rev() { + info!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); + encrypted_tx.push_front(msg); + } + } } } } } fn poll_encrypted_side_io< - IO: Stream> + Sink> + Send + Unpin + Debug + 'static, + IO: Stream>> + Sink> + Send + Unpin + 'static, >( io: &mut IO, cx: &mut Context<'_>, encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, is_initiator: bool, flush: &mut bool, ) { @@ -287,18 +331,20 @@ fn poll_do_encrypt_and_decrypt( encryptor: &mut RawEncryptCipher, decryptor: &mut DecryptCipher, encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, plain_tx: &mut VecDeque>, - plain_rx: &mut VecDeque>, + plain_rx: &mut VecDeque>>, is_initiator: bool, flush: &mut bool, ) { // decrypt any incromming encrypted messages - while let Some(incoming_msg) = encrypted_rx.pop_front() { + // TODO handle error + while let Some(Ok(incoming_msg)) = encrypted_rx.pop_front() { + info!(initiator = %is_initiator, "enc rx decrypting\n{incoming_msg:?}"); match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!(initiator = %is_initiator, "plain rx queue"); - plain_rx.push_back(plain_msg); + info!(initiator = %is_initiator, "plain rx queue"); + plain_rx.push_back(Ok(plain_msg)); } Err(e) => { error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") From f2cd806b0a849c0b9ade8347cb0159db7c7d6aec Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 16 Mar 2025 16:13:08 -0400 Subject: [PATCH 027/135] Fix impl of Sink fro Framing poll_flush fixes the issue of the messages not being sent --- src/framing.rs | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 64576f3..32ac6d8 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -132,32 +132,40 @@ where Poll::Ready(Ok(())) } + #[instrument(skip_all)] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { self.from_sink.push_back(wrap_uint24_le(&item)); Ok(()) } + #[instrument(skip_all)] fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { let Self { from_sink, io, .. } = self.get_mut(); - if let Some(msg) = from_sink.pop_front() { - match Pin::new(io).poll_write(cx, &msg) { - Poll::Pending => { - from_sink.push_front(msg); - return Poll::Pending; - } - Poll::Ready(Ok(n)) => { - if n != msg.len() { - from_sink.push_front(msg[n..].to_vec()); - return Poll::Ready(Ok(())); + loop { + if let Some(msg) = from_sink.pop_front() { + match Pin::new(&mut *io).poll_write(cx, &msg) { + Poll::Pending => { + from_sink.push_front(msg); + debug!("AsyncWrite busy, could not flush"); + return Poll::Pending; } + Poll::Ready(Ok(n)) => { + if n != msg.len() { + from_sink.push_front(msg[n..].to_vec()); + warn!("only wrote [{n} / {}]", msg.len()); + } + debug!("flushed whole message of N=[{n}] bytes"); + } + Poll::Ready(Err(_e)) => todo!(), } - Poll::Ready(Err(_e)) => todo!(), + } else { + debug!("No messages in self.from_sink. Flush done"); + return Poll::Ready(Ok(())); } } - Poll::Ready(Ok(())) } fn poll_close( From 08056577e27f6a9727fe24e7185b99b85b9fcca3 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 16 Mar 2025 23:46:00 -0400 Subject: [PATCH 028/135] Add docs handle todos --- src/framing.rs | 96 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 19 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 32ac6d8..66ce84f 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -7,24 +7,28 @@ use std::{ }; use futures::{Sink, Stream}; + use futures_lite::io::{AsyncRead, AsyncWrite}; -use tracing::{debug, instrument, trace}; +use tracing::{debug, error, info, instrument, trace, warn}; use crate::util::{stat_uint24_le, wrap_uint24_le}; -const BUF_SIZE: usize = 1024 * 8; +const BUF_SIZE: usize = 1024 * 64; const HEADER_LEN: usize = 3; /// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream pub struct LengthPrefixed { io: IO, + /// Data from [`Self::io`]'s [`AsyncRead`] interface to be sent out via the [`Stream`] interface. to_stream: Vec, + /// Data from the `Sink` interface to be written out to [`Self::io`]'s [`AsyncWrite`] interface. from_sink: VecDeque>, - /// The index in [`Self::buf`] of the last byte that was to the [`Stream`]. + /// The index in [`Self::to_stream`] of the last byte that was to the [`Stream`]. last_out_idx: usize, - /// The index in [`Self::buf`] of the last byte that was read from [`Self::io`] via + /// The index in [`Self::to_stream`] of the last byte that was read from [`Self::io`]'s /// [`AsyncRead`] last_data_idx: usize, + /// Current step of a message being parsed step: Step, } impl Debug for LengthPrefixed { @@ -71,30 +75,36 @@ where step, .. } = self.get_mut(); + debug!( + "Try to AsyncRead up to (buff_size[{}] - last_data_idx[{}]) = [{}]", + to_stream.len(), + *last_data_idx, + to_stream.len() - *last_data_idx + ); let n_bytes_read = match Pin::new(io).poll_read(cx, &mut to_stream[*last_data_idx..]) { Poll::Ready(Ok(n)) => n, Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), - Poll::Pending => 0, + Poll::Pending => { + cx.waker().wake_by_ref(); + 0 + } }; // TODO handle if to_stream is full - trace!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); + debug!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); *last_data_idx += n_bytes_read; // grow buffer if it's full if *last_data_idx == to_stream.len() - 1 { + warn!("We filled our buffer!"); to_stream.extend(vec![0; to_stream.len() * 2]); } if let Step::Header = step { trace!(step = ?*step, "enter"); - if *last_data_idx - *last_out_idx < HEADER_LEN { + let cur_data = &to_stream[*last_out_idx..*last_data_idx]; + + let Some((header_len, body_len)) = stat_uint24_le(cur_data) else { trace!("not enough bytes to read header"); return Poll::Pending; - } - let Some((header_len, body_len)) = - stat_uint24_le(&to_stream[*last_out_idx..(*last_out_idx + HEADER_LEN)]) - else { - // we check above the there is room for header so this should never happen - todo!() }; let cur_frame_start = *last_out_idx + header_len; @@ -105,6 +115,7 @@ where }; } + info!(step = ?*step, "enter"); if let Step::Body { start, end } = step { let end = *end as usize; if end <= *last_data_idx { @@ -149,20 +160,22 @@ where match Pin::new(&mut *io).poll_write(cx, &msg) { Poll::Pending => { from_sink.push_front(msg); - debug!("AsyncWrite busy, could not flush"); return Poll::Pending; } Poll::Ready(Ok(n)) => { if n != msg.len() { from_sink.push_front(msg[n..].to_vec()); - warn!("only wrote [{n} / {}]", msg.len()); + warn!("only wrote [{n} / {}] bytes of message", msg.len()); } debug!("flushed whole message of N=[{n}] bytes"); } - Poll::Ready(Err(_e)) => todo!(), + Poll::Ready(Err(e)) => { + error!("Error flushing data"); + return Poll::Ready(Err(e)); + } } } else { - debug!("No messages in self.from_sink. Flush done"); + debug!("No more messages to flush"); return Poll::Ready(Ok(())); } } @@ -170,9 +183,10 @@ where fn poll_close( self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { - todo!() + let Self { io, .. } = self.get_mut(); + Pin::new(&mut *io).poll_close(cx) } } #[cfg(test)] @@ -253,4 +267,48 @@ pub(crate) mod test { assert_eq!(result, expected); Ok(()) } + + #[tokio::test] + async fn left_and_right() -> Result<()> { + let (left, right) = duplex(64); + + let mut leftlp = LengthPrefixed::new(left); + let mut rightlp = LengthPrefixed::new(right); + + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; + for d in data { + rightlp.send(d.to_vec()).await.unwrap(); + } + + let mut result1 = vec![]; + for _ in data { + result1.push(leftlp.next().await.unwrap().unwrap()); + } + assert_eq!(result1, data); + + for d in data { + leftlp.send(d.to_vec()).await.unwrap(); + } + let mut result2 = vec![]; + for _ in data { + result2.push(rightlp.next().await.unwrap().unwrap()); + } + assert_eq!(result2, data); + + let mut r3 = vec![]; + let mut r4 = vec![]; + for d in data { + rightlp.send(d.to_vec()).await.unwrap(); + leftlp.send(d.to_vec()).await.unwrap(); + } + + for _ in data { + r3.push(rightlp.next().await.unwrap().unwrap()); + r4.push(leftlp.next().await.unwrap().unwrap()); + } + assert_eq!(r3, data); + assert_eq!(r4, data); + + Ok(()) + } } From b9c5482f69056a7ce0d2f17ed8f404457375d295 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 00:48:48 -0400 Subject: [PATCH 029/135] Add encryption_established more tests better logs --- src/noise.rs | 184 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 119 insertions(+), 65 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index dd687ec..15525de 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -78,6 +78,11 @@ where flush: false, } } + + /// Wether an encrypted connection has been established. + pub fn encryption_established(&self) -> bool { + matches!(self.step, Step::Established(_)) + } } impl< @@ -214,29 +219,17 @@ fn poll_setup( // Still setting up if let Ok(Some(msg)) = maybe_init(step, is_initiator) { // queue the init message to send first - trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); + info!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } // TODO handle error loop { match encrypted_rx.pop_front() { None => { - debug!( - " - num_encrypted_rx = {} - num_encrypted_tx = {} -no more encrp incoming", - encrypted_rx.len(), - encrypted_tx.len(), - ); break; } Some(Err(e)) => { - error!( - num_encrypted_rx = 0, - num_encrypted_tx = encrypted_tx.len(), - "{e:?}" - ); + error!("Recieved an error during setup encryption setup: {e:?}"); break; } Some(Ok(incoming_msg)) => { @@ -258,6 +251,7 @@ no more encrp incoming", } } +#[instrument(skip_all, fields(initiator = %is_initiator))] fn poll_encrypted_side_io< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -271,33 +265,25 @@ fn poll_encrypted_side_io< // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); - let _todo = Sink::start_send(Pin::new(io), encrypted_out); + info!(initiator = %is_initiator, msg_len = encrypted_out.len(), "enc tx send msg\n{encrypted_out:?}"); + if let Err(_e) = Sink::start_send(Pin::new(io), encrypted_out) { + error!("Error polling encyrpted side io") + } + *flush = true; } else { break; } } if *flush { - // confusing docs related to start send - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send - // First part says: - // "you must use **poll_flush** ... inorder to garuntee - // completions of send" - // Then it says: - // " It is only necessary to call poll_flush if you need to guarantee that all - // of the items placed into the Sink have been sent" - // - // So do I need to do it or not? - // must `poll_flush` be called for **anything** to send? match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!(initiator = %is_initiator, "flushed good"); + info!(initiator = %is_initiator, "flushed good"); + } + Poll::Ready(Err(_e)) => { + error!(initiator = %is_initiator, "Error sending encrypted msg") } - Poll::Ready(Err(_e)) => error!( - initiator = %is_initiator, - "Error sending encrypted msg"), Poll::Pending => { // More confusing docs // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush @@ -308,6 +294,7 @@ fn poll_encrypted_side_io< // Does this mean, each time this task wakes up again from this code path that // I must trigger another poll_flush? But how would I know i need more // flushing? + debug!("flush not completed"); *flush = true; } } @@ -327,6 +314,7 @@ fn poll_encrypted_side_io< /// Process messages waiting to be encrypted or decrypted // TODO sholud this return a Result +#[instrument(skip_all)] fn poll_do_encrypt_and_decrypt( encryptor: &mut RawEncryptCipher, decryptor: &mut DecryptCipher, @@ -375,7 +363,7 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } -#[instrument(skip_all, fields(is_initiator = %is_initiator))] +#[instrument(skip_all, fields(initiator = %is_initiator))] fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { match &step { Step::NotInitialized => { @@ -432,7 +420,6 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu Ok(out) } Step::SecretStream(_) => { - info!("E're a secret stream now!!!!!"); if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) { let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; @@ -447,58 +434,125 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu #[cfg(test)] mod tset { - use crate::test_utils::create_connected; + + use crate::{framing::test::duplex, test_utils::create_result_connected, LengthPrefixed}; use super::*; - use futures::{SinkExt, StreamExt}; + use futures::{future::join, SinkExt, StreamExt}; #[tokio::test] - async fn steps() -> Result<()> { - let mut left_hs = Handshake::new(true)?; - let s1 = left_hs.start_raw()?.unwrap(); + async fn encrypted() -> Result<()> { + let hello = b"hello".to_vec(); + let world = b"world".to_vec(); + let (lc, rc) = create_result_connected(); + let mut left = Encrypted::new(true, lc); + let mut right = Encrypted::new(false, rc); - println!("s1 {s1:?}"); - let mut right_hs = Handshake::new(false)?; + let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(receieved.unwrap()?, hello); - let s2 = right_hs.read_raw(&s1)?.unwrap(); - println!("s2 {s2:?}"); + assert!(left.encryption_established()); + assert!(right.encryption_established()); - let s3 = left_hs.read_raw(&s2)?.unwrap(); - println!("s3 {s3:?}"); + // NB: we cannot totally finish 'left.send' until the other side becomes active + // because the handshake with the other side ('right') must complete + // before the 'hello' message is sent. So we poll both the send and receive concurrently. + let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + // right recieves left's message + assert_eq!(receieved.unwrap()?, hello); - let s4 = right_hs.read_raw(&s3)?; + // now that the encrypted channel is established, we don't need to spawn. + right.send(world.clone()).await.unwrap(); - println!("s4 {s4:?}"); - // both sides now ready + // left recieves right's message + assert_eq!(left.next().await.unwrap()?, world); + Ok(()) + } + #[tokio::test] + async fn encrypted_many() -> Result<()> { + let hello = b"hello".to_vec(); + let data = vec![ + b"yolo".to_vec(), + b"squalor".to_vec(), + b"idleness".to_vec(), + b"hello".to_vec(), + b"stuff".to_vec(), + ]; + let (lc, rc) = create_result_connected(); + let mut left = Encrypted::new(true, lc); + let mut right = Encrypted::new(false, rc); + + let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(receieved.unwrap()?, hello); + + for d in &data { + right.send(d.to_vec()).await?; + } + let mut result = vec![]; + for _ in &data { + result.push(left.next().await.unwrap()?); + } + assert_eq!(result, data); Ok(()) } #[tokio::test] - async fn test_encrypted() -> Result<()> { - let hello = b"hello"; - let world = b"world"; - let (left, right) = create_connected(); + async fn with_framing() -> Result<()> { + crate::test_utils::log(); + let hello = b"hello".to_vec(); + + let (left, right) = duplex(1024 * 64); + let left = LengthPrefixed::new(left); + let right = LengthPrefixed::new(right); + let mut left = Encrypted::new(true, left); let mut right = Encrypted::new(false, right); - // NB: we cannot totally finish 'left.send' until the other side becomes active - // this is because the handshake with the other side ('right') must complete - // before the message is sent. So we must spawn here, so we can proceed to run 'right' - let left_handle = tokio::task::spawn(async move { - left.send(hello.into()).await.unwrap(); - left - }); + let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(receieved.unwrap()?, hello); - // right recieves left's message - assert_eq!(right.next().await.unwrap(), hello); + let data = vec![ + b"yolo".to_vec(), + b"squalor".to_vec(), + b"idleness".to_vec(), + b"hello".to_vec(), + b"stuff".to_vec(), + ]; - let mut left = left_handle.await?; + // send right to left + for d in &data { + right.send(d.to_vec()).await?; + } + let mut result = vec![]; + for _ in &data { + result.push(left.next().await.unwrap()?); + } + assert_eq!(result, data); - // now that the encrypted channel is established, we don't need to spawn. - right.send(world.into()).await.unwrap(); + // send left to right + for d in &data { + left.send(d.to_vec()).await?; + } + let mut result = vec![]; + for _ in &data { + result.push(right.next().await.unwrap()?); + } + assert_eq!(result, data); + + // send both ways + for d in &data { + left.send(d.to_vec()).await?; + right.send(d.to_vec()).await?; + } + let mut left_result = vec![]; + let mut right_result = vec![]; + for _ in &data { + right_result.push(right.next().await.unwrap()?); + left_result.push(left.next().await.unwrap()?); + } + assert_eq!(right_result, data); + assert_eq!(left_result, data); - // left recieves right's message - assert_eq!(left.next().await.unwrap(), world); Ok(()) } } From e547a1f7de2437546231fc0a673308d5ff7e6e05 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 12:18:28 -0400 Subject: [PATCH 030/135] logs --- src/noise.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/noise.rs b/src/noise.rs index 15525de..a846b28 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -346,7 +346,7 @@ fn poll_do_encrypt_and_decrypt( Ok(x) => x, Err(_e) => todo!("We failed to encrypt our own message...?"), }; - trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); + trace!(initiator = %is_initiator, encrypted_msg_length = enc_out.len(), "enqueue new encrypted message from plain tx queue\n{enc_out:?}"); encrypted_tx.push_back(enc_out); *flush = true; } From 83dbe3a77cb0cd2d5b4fe2cee8a7beb74df4222d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 14:24:05 -0400 Subject: [PATCH 031/135] Add framing buffer rotation --- src/framing.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 66ce84f..38b0f77 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -84,18 +84,15 @@ where let n_bytes_read = match Pin::new(io).poll_read(cx, &mut to_stream[*last_data_idx..]) { Poll::Ready(Ok(n)) => n, Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), - Poll::Pending => { - cx.waker().wake_by_ref(); - 0 - } + Poll::Pending => 0, }; // TODO handle if to_stream is full debug!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); *last_data_idx += n_bytes_read; // grow buffer if it's full if *last_data_idx == to_stream.len() - 1 { - warn!("We filled our buffer!"); - to_stream.extend(vec![0; to_stream.len() * 2]); + warn!("Buffer full, double it's size"); + to_stream.extend(vec![0; to_stream.len()]); } if let Step::Header = step { @@ -122,14 +119,18 @@ where debug!(frame_size = end - *start, "Frame ready"); let out = to_stream[*start..end].to_vec(); *step = Step::Header; - *last_out_idx = end; + // remove bytes we're done with + to_stream.rotate_left(end); + *last_data_idx -= end; + *last_out_idx = 0; return Poll::Ready(Some(Ok(out))); } } Poll::Pending } } + impl Sink> for LengthPrefixed where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, From a585aefad5705d79794818fbd94689d432083efd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 15:34:42 -0400 Subject: [PATCH 032/135] bump futures to non-yanked version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ff89935..170c32f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ futures-lite = "1" sha2 = "0.10" curve25519-dalek = "4" crypto_secretstream = "0.2" -futures = "0.3.13" +futures = "0.3.31" [dependencies.hypercore] version = "0.14.0" From 31e628134b44433c0855a6296b5562c5998c8486 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 17:55:12 -0400 Subject: [PATCH 033/135] handle setup errors and add test --- src/noise.rs | 202 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 159 insertions(+), 43 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index a846b28..5126576 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -18,21 +18,6 @@ pub(crate) enum Step { SecretStream((RawEncryptCipher, HandshakeResult)), Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), } -impl std::fmt::Display for Step { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Step::NotInitialized => "NotInitialized", - Step::Handshake(_) => "Handshake", - Step::SecretStream(_) => "SecretStream", - Step::Established(_) => "Established", - } - ) - } -} - /// Wrap a stream with encryption pub struct Encrypted { io: IO, @@ -45,21 +30,6 @@ pub struct Encrypted { flush: bool, } -impl std::fmt::Debug for Encrypted { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Encrypted") - //.field("io", &self.io) - .field("step", &self.step) - .field("is_initiator", &self.is_initiator) - //.field("encrypted_tx", &self.encrypted_tx) - .field("encrypted_rx", &self.encrypted_rx) - .field("plain_tx", &self.plain_tx) - .field("plain_rx", &self.plain_rx) - .field("flush", &self.flush) - .finish() - } -} - impl Encrypted where IO: Stream>> + Sink> + Send + Unpin + 'static, @@ -78,7 +48,6 @@ where flush: false, } } - /// Wether an encrypted connection has been established. pub fn encryption_established(&self) -> bool { matches!(self.step, Step::Established(_)) @@ -147,7 +116,7 @@ impl< Poll::Ready(Ok(())) } } else { - poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator); + poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator, flush); cx.waker().wake_by_ref(); Poll::Pending } @@ -202,7 +171,7 @@ impl>> + Sink> + Send + Unpin + 'static Poll::Pending } } else { - poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator); + poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator, flush); cx.waker().wake_by_ref(); Poll::Pending } @@ -215,7 +184,12 @@ fn poll_setup( encrypted_tx: &mut VecDeque>, encrypted_rx: &mut VecDeque>>, is_initiator: bool, + flush: &mut bool, ) { + // if we get an error, it could be because the other side reset, and is sending a new + // initialization message. + // If this is the case, we should retry this message after the error. + // But to avoid repeatedly retrying the first message, we should only retry if it is *not* the first msg. // Still setting up if let Ok(Some(msg)) = maybe_init(step, is_initiator) { // queue the init message to send first @@ -234,7 +208,14 @@ fn poll_setup( } Some(Ok(incoming_msg)) => { info!(initiator = %is_initiator, "recieved setup msg"); - if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, is_initiator) { + if let Ok(msgs) = match handle_setup_message( + step, + &incoming_msg, + is_initiator, + encrypted_tx, + encrypted_rx, + flush, + ) { Ok(x) => Ok(x), Err(e) => { error!("handle_setup_message error: {e:?}"); @@ -252,6 +233,7 @@ fn poll_setup( } #[instrument(skip_all, fields(initiator = %is_initiator))] +/// Fills `encrypted_rx` and drains `encrypted_tx`. fn poll_encrypted_side_io< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -363,25 +345,64 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } +fn reset_encrypted( + step: &mut Step, + maybe_init_message: Option>, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, + flush: &mut bool, +) { + *step = Step::NotInitialized; + encrypted_tx.clear(); + encrypted_rx.clear(); + if let Some(msg) = maybe_init_message { + encrypted_rx.push_front(Ok(msg)); + } + *flush = false; +} + +/// handle setup messages: if any are incorrect (cause an error) the state is reset #[instrument(skip_all, fields(initiator = %is_initiator))] -fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { +fn handle_setup_message( + step: &mut Step, + msg: &[u8], + is_initiator: bool, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, + flush: &mut bool, +) -> Result>> { + // this would only happen after reset with a bad message. + let mut first_message = false; + if let Step::NotInitialized = step { + first_message = true; + assert!(!is_initiator); + warn!(initiator = %is_initiator, "Encrypted state was reset"); + let mut handshake = Handshake::new(is_initiator)?; + let _ = handshake.start_raw()?; + *step = Step::Handshake(Box::new(handshake)); + } match &step { Step::NotInitialized => { - warn!(initiator = %is_initiator, "Encrypted state was reset"); - let mut handshake = Handshake::new(is_initiator)?; - let start_msg = handshake.start_raw()?; - *step = Step::Handshake(Box::new(handshake)); - debug!(initiator = %is_initiator, "Step changed to {step}"); - - Ok(start_msg.map(|x| vec![x]).unwrap_or(vec![])) + unreachable!("should not happen") } Step::Handshake(_) => { + dbg!(); let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { trace!("Read in handshake msg\n{msg:?}"); if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { + let maybe_init_message = + (!first_message && !is_initiator).then_some(msg.to_vec()); + + reset_encrypted( + step, + maybe_init_message, + encrypted_tx, + encrypted_rx, + flush, + ); return Err(e); } } { @@ -432,13 +453,46 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu } } +impl std::fmt::Display for Step { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Step::NotInitialized => "NotInitialized", + Step::Handshake(_) => "Handshake", + Step::SecretStream(_) => "SecretStream", + Step::Established(_) => "Established", + } + ) + } +} + +impl std::fmt::Debug for Encrypted { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Encrypted") + //.field("io", &self.io) + //.field("step", &self.step) + .field("is_initiator", &self.is_initiator) + .field("encrypted_tx", &self.encrypted_tx) + .field("encrypted_rx", &self.encrypted_rx) + .field("plain_tx", &self.plain_tx) + .field("plain_rx", &self.plain_rx) + .field("flush", &self.flush) + .finish() + } +} + #[cfg(test)] mod tset { use crate::{framing::test::duplex, test_utils::create_result_connected, LengthPrefixed}; use super::*; - use futures::{future::join, SinkExt, StreamExt}; + use futures::{ + future::{join, select, Either}, + SinkExt, StreamExt, + }; #[tokio::test] async fn encrypted() -> Result<()> { @@ -555,4 +609,66 @@ mod tset { Ok(()) } + + #[tokio::test] + async fn test_setup_error_causes_re_init() -> Result<()> { + let (lc, mut init_side_messages) = create_result_connected(); + let (mut other_side_messages, rc) = create_result_connected(); + let mut left = Encrypted::new(true, lc); + let mut right = Encrypted::new(false, rc); + let hello = b"hello".to_vec(); + + let send_fut = tokio::task::spawn(async move { + left.send(hello).await.unwrap(); + left + }); + + let init_msg = init_side_messages.next().await.unwrap()?; + + other_side_messages.send(init_msg).await?; + // other side encrypted needs to be polled to do work and send a response + let other_send_fut = tokio::task::spawn(async move { + right.send(b"other hello".to_vec()).await.unwrap(); + right + }); + + let _first_response = other_side_messages.next().await.unwrap()?; + // both sides now have a handshake in progress + + // send a bad message to init side. It should reset, and emit new init msg + init_side_messages.send(b"bad msg".to_vec()).await?; + let new_init_msg = init_side_messages.next().await.unwrap()?; + + other_side_messages.send(new_init_msg).await?; + let new_response = other_side_messages.next().await.unwrap()?; + init_side_messages.send(new_response).await?; + let final_setup_message = init_side_messages.next().await.unwrap()?; + other_side_messages.send(final_setup_message).await?; + + // exchange one more message then we're set up + init_side_messages + .send(other_side_messages.next().await.unwrap()?) + .await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + // now our spawned sends can complete + let mut left = send_fut.await?; + let mut right = other_send_fut.await?; + + // exchange hellos + init_side_messages + .send(other_side_messages.next().await.unwrap()?) + .await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + + assert!(left.encryption_established()); + assert!(right.encryption_established()); + assert_eq!(right.next().await.unwrap()?, b"hello"); + assert_eq!(left.next().await.unwrap()?, b"other hello"); + + Ok(()) + } } From 8a2370bae2709cb35732344117f9ec410b8e2908 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:35:18 -0400 Subject: [PATCH 034/135] RMME show logs in example --- examples-nodejs/run.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples-nodejs/run.js b/examples-nodejs/run.js index c96541f..ac77bba 100644 --- a/examples-nodejs/run.js +++ b/examples-nodejs/run.js @@ -37,7 +37,8 @@ function startRust (mode, key, color, name) { color: color || 'blue', env: { ...process.env, - RUST_LOG_STYLE: 'always' + RUST_LOG_STYLE: 'always', + RUST_LOG: 'trace' } }) return rust From 918a307c75d6e773bb22c260c990f0866b9e965f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:37:45 -0400 Subject: [PATCH 035/135] RMME extra docs --- src/crypto/cipher.rs | 3 +++ src/noise.rs | 2 +- src/protocol.rs | 9 +++++++++ src/reader.rs | 5 +++++ src/writer.rs | 19 +++++++++++++++++++ 5 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 4dfaf19..f2dc9b9 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -88,6 +88,7 @@ impl DecryptCipher { let decrypted_end = header_len + to_decrypt.len(); buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); // Set extra bytes in the buffer to 0 + // Why? let encrypted_end = header_len + body_len; buf[decrypted_end..encrypted_end].fill(0x00); Ok(decrypted_end) @@ -136,6 +137,8 @@ impl EncryptCipher { /// Encrypts message in the given buffer to the same buffer, returns number of bytes /// of total message. + /// NB: we expect the first 3 bytes of the buffer to a size header. + /// The encrypted buffer will also be written prepended with a size header, with it's new size. pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { let stat = stat_uint24_le(buf); if let Some((header_len, body_len)) = stat { diff --git a/src/noise.rs b/src/noise.rs index 5126576..a7bd306 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -295,7 +295,7 @@ fn poll_encrypted_side_io< } /// Process messages waiting to be encrypted or decrypted -// TODO sholud this return a Result +// TODO sholud this return a Result? #[instrument(skip_all)] fn poll_do_encrypt_and_decrypt( encryptor: &mut RawEncryptCipher, diff --git a/src/protocol.rs b/src/protocol.rs index 1f24b1a..673b307 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -301,6 +301,7 @@ where let mut handshake = Handshake::new(self.options.is_initiator)?; // If the handshake start returns a buffer, send it now. if let Some(buf) = handshake.start()? { + // TODO what if this fails? or returns false self.queue_frame_direct(buf.to_vec()).unwrap(); } self.read_state.set_frame_type(FrameType::Raw); @@ -375,6 +376,7 @@ where if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { return Err(e); } + // if no parking or setup in progress if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { return Ok(()); } @@ -406,11 +408,17 @@ where State::SecretStream(_) => self.on_secret_stream_message(buf)?, State::Established => { if let Some(processed_state) = processed_state.as_ref() { + // last state before established let previous_state = if self.options.encrypted { + // was SecretStream if we're encrypted State::SecretStream(None) } else { + // or wa hasdshake if we're not encrypted State::Handshake(None) }; + + // if htis raw_batch included regular messages (not handshake) + // after handshake stuff if processed_state == &format!("{previous_state:?}") { // This is the unlucky case where the batch had two or more messages where // the first one was correctly identified as Raw but everything @@ -591,6 +599,7 @@ where self.queued_events.push_back(event); } + /// enequeu a buf to be sent fn queue_frame_direct(&mut self, body: Vec) -> Result { let mut frame = Frame::RawBatch(vec![body]); self.write_state.try_encode_frame_for_tx(&mut frame) diff --git a/src/reader.rs b/src/reader.rs index 51b370b..5664d56 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -113,6 +113,9 @@ impl ReadState { if success { if let Some(ref mut cipher) = self.cipher { let mut dec_end = self.start; + // What happens if decrypt fails here? + // next call to this func would have same start, corret? + // so it'd fail repeatedly? for (index, header_len, body_len) in segments { let de = cipher.decrypt( &mut self.buf[self.start + index..end], @@ -137,6 +140,7 @@ impl ReadState { } } + /// Moves start of unprocessed data to the start of the buffer. And resize if necessary. fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) { let (last_index, last_header_len, last_body_len) = last_segment; let total_incoming_length = last_index + last_header_len + last_body_len; @@ -207,6 +211,7 @@ impl ReadState { } #[allow(clippy::type_complexity)] +// get segments from buff fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> { let mut index: usize = 0; let len = buf.len(); diff --git a/src/writer.rs b/src/writer.rs index 38d6dcf..56bbaf6 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -9,6 +9,10 @@ use std::pin::Pin; use std::task::{Context, Poll}; const BUF_SIZE: usize = 1024 * 64; +// This is the largest size that will fit in u24. +// a message is larger than this we should error. +// also check message is smaller than this when we are encrypting. +const _MAX_MSG_SIZE: usize = 2usize.pow(24) - 1; #[derive(Debug)] pub(crate) enum Step { @@ -64,9 +68,12 @@ impl WriteState { pub(crate) fn try_encode_frame_for_tx(&mut self, frame: &mut T) -> Result { let promised_len = frame.encoded_len()?; let padded_promised_len = self.safe_encrypted_len(promised_len); + // this handles when a message would be longer than the entire buffer if self.buf.len() < padded_promised_len { self.buf.resize(padded_promised_len, 0u8); } + + // check we have enough room if padded_promised_len > self.remaining() { return Ok(false); } @@ -78,6 +85,14 @@ impl WriteState { "encoded_len() did not return that right size, expected={promised_len}, actual={actual_len}" ); } + // Instead of the above, write the buffer to a new vec `foo` of length `promised_length` + // encode frame.to this buff + // slice `foo[(header_len /* 3*/)..actual_len]` this is the fram data + // encrypt this in place + // replace header at start of foo + // write its len to self.buf and then write it to self.buf + // slice from + self.encrypt_frame_contents(padded_promised_len)?; Ok(true) } @@ -95,6 +110,10 @@ impl WriteState { } } + /// The frame should be written to `self.buf` before calling this. And + /// `self.should_write_up_to_idx` should mark the start of the message. + /// `max_message_size` is the maximum size the message could be when it is encrypted + /// We encrypt the message in-place on `self.buf`. fn encrypt_frame_contents(&mut self, max_message_size: usize) -> Result<()> { let end_of_message_index = self.should_write_up_to_idx + max_message_size; From ecc499a14ca979204748d04779f47542567f422a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:38:03 -0400 Subject: [PATCH 036/135] add const header len --- src/framing.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/framing.rs b/src/framing.rs index 38b0f77..2e8cf32 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -14,7 +14,7 @@ use tracing::{debug, error, info, instrument, trace, warn}; use crate::util::{stat_uint24_le, wrap_uint24_le}; const BUF_SIZE: usize = 1024 * 64; -const HEADER_LEN: usize = 3; +const _HEADER_LEN: usize = 3; /// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream pub struct LengthPrefixed { From 8a96462b73174b5578a7b0dc766c809c72f5565e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:38:29 -0400 Subject: [PATCH 037/135] lint --- src/noise.rs | 5 +---- src/protocol.rs | 2 +- src/test_utils.rs | 10 ++-------- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index a7bd306..cf800f3 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -489,10 +489,7 @@ mod tset { use crate::{framing::test::duplex, test_utils::create_result_connected, LengthPrefixed}; use super::*; - use futures::{ - future::{join, select, Either}, - SinkExt, StreamExt, - }; + use futures::{future::join, SinkExt, StreamExt}; #[tokio::test] async fn encrypted() -> Result<()> { diff --git a/src/protocol.rs b/src/protocol.rs index 673b307..89f3df1 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -10,7 +10,7 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::{info, trace}; +use tracing::trace; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; diff --git a/src/test_utils.rs b/src/test_utils.rs index d35af0e..ff1a3c2 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -7,9 +7,7 @@ use std::{ //use async_channel::{unbounded, Receiver, io::Error, Sender}; use futures::{ - channel::mpsc::{ - unbounded, SendError, UnboundedReceiver as Receiver, UnboundedSender as Sender, - }, + channel::mpsc::{unbounded, UnboundedReceiver as Receiver, UnboundedSender as Sender}, Sink, SinkExt, Stream, StreamExt, }; @@ -76,10 +74,6 @@ impl TwoWay { } } -pub(crate) fn create_connected() -> (Io, Io) { - TwoWay::default().split_sides() -} - pub(crate) fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: OnceLock<()> = OnceLock::new(); @@ -123,7 +117,7 @@ pub(crate) struct Moo { impl + Unpin, Tx: Unpin> Stream for Moo { type Item = RxItem; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); Pin::new(&mut this.receiver).poll_next(cx) } From a85e42ea35a2a88073dff9dcf87b6b42e0912261 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:49:30 -0400 Subject: [PATCH 038/135] helpful names --- src/protocol.rs | 3 ++- src/writer.rs | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 89f3df1..930f9bd 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -602,7 +602,8 @@ where /// enequeu a buf to be sent fn queue_frame_direct(&mut self, body: Vec) -> Result { let mut frame = Frame::RawBatch(vec![body]); - self.write_state.try_encode_frame_for_tx(&mut frame) + self.write_state + .try_encode_and_enqueue_frame_for_tx(&mut frame) } fn accept_channel(&mut self, local_id: usize) -> Result<()> { diff --git a/src/writer.rs b/src/writer.rs index 56bbaf6..d91adfb 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -65,7 +65,10 @@ impl WriteState { self.queue.push_back(frame.into()) } - pub(crate) fn try_encode_frame_for_tx(&mut self, frame: &mut T) -> Result { + pub(crate) fn try_encode_and_enqueue_frame_for_tx( + &mut self, + frame: &mut T, + ) -> Result { let promised_len = frame.encoded_len()?; let padded_promised_len = self.safe_encrypted_len(promised_len); // this handles when a message would be longer than the entire buffer @@ -93,7 +96,7 @@ impl WriteState { // write its len to self.buf and then write it to self.buf // slice from - self.encrypt_frame_contents(padded_promised_len)?; + self.encrypt_frame_contents_onto_buf(padded_promised_len)?; Ok(true) } @@ -114,7 +117,7 @@ impl WriteState { /// `self.should_write_up_to_idx` should mark the start of the message. /// `max_message_size` is the maximum size the message could be when it is encrypted /// We encrypt the message in-place on `self.buf`. - fn encrypt_frame_contents(&mut self, max_message_size: usize) -> Result<()> { + fn encrypt_frame_contents_onto_buf(&mut self, max_message_size: usize) -> Result<()> { let end_of_message_index = self.should_write_up_to_idx + max_message_size; let encrypted_end = if let Some(ref mut cipher) = self.cipher { @@ -157,7 +160,7 @@ impl WriteState { } if let Some(mut frame) = self.current_frame.take() { - if !self.try_encode_frame_for_tx(&mut frame)? { + if !self.try_encode_and_enqueue_frame_for_tx(&mut frame)? { self.current_frame = Some(frame); } } From 3e1483a21f40cd70030cebf1743293f9c8af4512 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:53:02 -0400 Subject: [PATCH 039/135] rename framing struct --- src/framing.rs | 20 ++++++++++---------- src/lib.rs | 2 +- src/noise.rs | 8 +++++--- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 2e8cf32..e7c9c5c 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -17,7 +17,7 @@ const BUF_SIZE: usize = 1024 * 64; const _HEADER_LEN: usize = 3; /// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream -pub struct LengthPrefixed { +pub struct Uint24LELengthPrefixedFraming { io: IO, /// Data from [`Self::io`]'s [`AsyncRead`] interface to be sent out via the [`Stream`] interface. to_stream: Vec, @@ -31,12 +31,12 @@ pub struct LengthPrefixed { /// Current step of a message being parsed step: Step, } -impl Debug for LengthPrefixed { +impl Debug for Uint24LELengthPrefixedFraming { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Format()") } } -impl LengthPrefixed +impl Uint24LELengthPrefixedFraming where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, { @@ -59,7 +59,7 @@ enum Step { Body { start: usize, end: u64 }, } -impl Stream for LengthPrefixed +impl Stream for Uint24LELengthPrefixedFraming where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, { @@ -131,7 +131,7 @@ where } } -impl Sink> for LengthPrefixed +impl Sink> for Uint24LELengthPrefixedFraming where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, { @@ -220,7 +220,7 @@ pub(crate) mod test { async fn input() -> Result<()> { log(); let (left, mut right) = duplex(64); - let mut lp = LengthPrefixed::new(left); + let mut lp = Uint24LELengthPrefixedFraming::new(left); let input = b"yelp"; let msg = wrap_uint24_le(input); dbg!(&msg); @@ -235,7 +235,7 @@ pub(crate) mod test { async fn stream_many() -> Result<()> { log(); let (left, mut right) = duplex(64); - let mut lp = LengthPrefixed::new(left); + let mut lp = Uint24LELengthPrefixedFraming::new(left); let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; for d in data { let msg = wrap_uint24_le(d); @@ -255,7 +255,7 @@ pub(crate) mod test { async fn sink_many() -> Result<()> { log(); let (left, mut right) = duplex(64); - let mut lp = LengthPrefixed::new(left); + let mut lp = Uint24LELengthPrefixedFraming::new(left); let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; for d in data { lp.send(d.to_vec()).await.unwrap(); @@ -273,8 +273,8 @@ pub(crate) mod test { async fn left_and_right() -> Result<()> { let (left, right) = duplex(64); - let mut leftlp = LengthPrefixed::new(left); - let mut rightlp = LengthPrefixed::new(right); + let mut leftlp = Uint24LELengthPrefixedFraming::new(left); + let mut rightlp = Uint24LELengthPrefixedFraming::new(right); let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; for d in data { diff --git a/src/lib.rs b/src/lib.rs index b1a043a..07a677b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,7 +137,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; -pub use framing::LengthPrefixed; +pub use framing::Uint24LELengthPrefixedFraming; pub use noise::Encrypted; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ diff --git a/src/noise.rs b/src/noise.rs index cf800f3..cf651ff 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -486,7 +486,9 @@ impl std::fmt::Debug for Encrypted { #[cfg(test)] mod tset { - use crate::{framing::test::duplex, test_utils::create_result_connected, LengthPrefixed}; + use crate::{ + framing::test::duplex, test_utils::create_result_connected, Uint24LELengthPrefixedFraming, + }; use super::*; use futures::{future::join, SinkExt, StreamExt}; @@ -553,8 +555,8 @@ mod tset { let hello = b"hello".to_vec(); let (left, right) = duplex(1024 * 64); - let left = LengthPrefixed::new(left); - let right = LengthPrefixed::new(right); + let left = Uint24LELengthPrefixedFraming::new(left); + let right = Uint24LELengthPrefixedFraming::new(right); let mut left = Encrypted::new(true, left); let mut right = Encrypted::new(false, right); From 2a602911489e65e1a316950328a43359ffab0ac3 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 13:03:13 -0400 Subject: [PATCH 040/135] Add func for building encrypted framed channel --- src/lib.rs | 2 +- src/noise.rs | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 07a677b..e4c0744 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -138,7 +138,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; pub use framing::Uint24LELengthPrefixedFraming; -pub use noise::Encrypted; +pub use noise::{encyrpted_framed_message_channel, Encrypted}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, diff --git a/src/noise.rs b/src/noise.rs index cf651ff..c2269ea 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -1,4 +1,4 @@ -use futures::{Sink, Stream}; +use futures::{AsyncRead, AsyncWrite, Sink, Stream}; use std::{ collections::VecDeque, fmt::Debug, @@ -9,8 +9,19 @@ use std::{ }; use tracing::{debug, error, info, instrument, trace, warn}; -use crate::crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}; +use crate::{ + crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}, + Uint24LELengthPrefixedFraming, +}; +/// Create a framed and encrypted Stream/Sink that reads/writes to an AsyncRead/AsyncWrite. +pub fn encyrpted_framed_message_channel( + is_initiator: bool, + io: IO, +) -> Encrypted> { + let framed = Uint24LELengthPrefixedFraming::new(io); + Encrypted::new(is_initiator, framed) +} #[derive(Debug)] pub(crate) enum Step { NotInitialized, From 41fd39e76d54a53e340821650fa0f0f2b80eb62c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 13:24:52 -0400 Subject: [PATCH 041/135] Move Protocol in prep for feature flagging --- src/protocol/mod.rs | 4 ++++ src/{protocol.rs => protocol/old.rs} | 0 2 files changed, 4 insertions(+) create mode 100644 src/protocol/mod.rs rename src/{protocol.rs => protocol/old.rs} (100%) diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 0000000..18be509 --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,4 @@ +mod old; + +pub(crate) use old::Options; +pub use old::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; diff --git a/src/protocol.rs b/src/protocol/old.rs similarity index 100% rename from src/protocol.rs rename to src/protocol/old.rs From 8fca159fd9b5983dd678f3eb77a7a490e6b6ae9f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 13:47:10 -0400 Subject: [PATCH 042/135] add second protocol impl behind feature flag --- Cargo.toml | 1 + src/protocol/mod.rs | 11 +- src/protocol/modern.rs | 697 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 708 insertions(+), 1 deletion(-) create mode 100644 src/protocol/modern.rs diff --git a/Cargo.toml b/Cargo.toml index 170c32f..82790eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,7 @@ tokio-util = { version = "0.7.14", features = ["compat"] } [features] default = ["tokio", "sparse"] +protocol = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" ] diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 18be509..7382df8 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,4 +1,13 @@ -mod old; +#[cfg(feature = "protocol")] +mod modern; +#[cfg(feature = "protocol")] +pub(crate) use modern::Options; +#[cfg(feature = "protocol")] +pub use modern::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; +#[cfg(not(feature = "protocol"))] +mod old; +#[cfg(not(feature = "protocol"))] pub(crate) use old::Options; +#[cfg(not(feature = "protocol"))] pub use old::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs new file mode 100644 index 0000000..930f9bd --- /dev/null +++ b/src/protocol/modern.rs @@ -0,0 +1,697 @@ +use async_channel::{Receiver, Sender}; +use futures_lite::io::{AsyncRead, AsyncWrite}; +use futures_lite::stream::Stream; +use futures_timer::Delay; +use std::collections::VecDeque; +use std::convert::TryInto; +use std::fmt; +use std::future::Future; +use std::io::{self, Error, ErrorKind, Result}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; +use tracing::trace; + +use crate::channels::{Channel, ChannelMap}; +use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; +use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; +use crate::message::{ChannelMessage, Frame, FrameType, Message}; +use crate::reader::ReadState; +use crate::schema::*; +use crate::util::{map_channel_err, pretty_hash}; +use crate::writer::WriteState; + +macro_rules! return_error { + ($msg:expr) => { + if let Err(e) = $msg { + return Poll::Ready(Err(e)); + } + }; +} + +const CHANNEL_CAP: usize = 1000; +const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); + +/// Options for a Protocol instance. +#[derive(Debug)] +pub(crate) struct Options { + /// Whether this peer initiated the IO connection for this protoccol + pub(crate) is_initiator: bool, + /// Enable or disable the handshake. + /// Disabling the handshake will also disable capabilitity verification. + /// Don't disable this if you're not 100% sure you want this. + pub(crate) noise: bool, + /// Enable or disable transport encryption. + pub(crate) encrypted: bool, +} + +impl Options { + /// Create with default options. + pub(crate) fn new(is_initiator: bool) -> Self { + Self { + is_initiator, + noise: true, + encrypted: true, + } + } +} + +/// Remote public key (32 bytes). +pub(crate) type RemotePublicKey = [u8; 32]; +/// Discovery key (32 bytes). +pub type DiscoveryKey = [u8; 32]; +/// Key (32 bytes). +pub type Key = [u8; 32]; + +/// A protocol event. +#[non_exhaustive] +#[derive(PartialEq)] +pub enum Event { + /// Emitted after the handshake with the remote peer is complete. + /// This is the first event (if the handshake is not disabled). + Handshake(RemotePublicKey), + /// Emitted when the remote peer opens a channel that we did not yet open. + DiscoveryKey(DiscoveryKey), + /// Emitted when a channel is established. + Channel(Channel), + /// Emitted when a channel is closed. + Close(DiscoveryKey), + /// Convenience event to make it possible to signal the protocol from a channel. + /// See channel.signal_local() and protocol.commands().signal_local(). + LocalSignal((String, Vec)), +} + +/// A protocol command. +#[derive(Debug)] +pub enum Command { + /// Open a channel + Open(Key), + /// Close a channel by discovery key + Close(DiscoveryKey), + /// Signal locally to protocol + SignalLocal((String, Vec)), +} + +impl fmt::Debug for Event { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Event::Handshake(remote_key) => { + write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key)) + } + Event::DiscoveryKey(discovery_key) => { + write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key)) + } + Event::Channel(channel) => { + write!(f, "Channel({})", &pretty_hash(channel.discovery_key())) + } + Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)), + Event::LocalSignal((name, data)) => { + write!(f, "LocalSignal(name={},len={})", name, data.len()) + } + } + } +} + +/// Protocol state +#[allow(clippy::large_enum_variant)] +pub(crate) enum State { + NotInitialized, + // The Handshake struct sits behind an option only so that we can .take() + // it out, it's never actually empty when in State::Handshake. + Handshake(Option), + SecretStream(Option), + Established, +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + State::NotInitialized => write!(f, "NotInitialized"), + State::Handshake(_) => write!(f, "Handshaking"), + State::SecretStream(_) => write!(f, "SecretStream"), + State::Established => write!(f, "Established"), + } + } +} + +/// A Protocol stream. +pub struct Protocol { + write_state: WriteState, + read_state: ReadState, + io: IO, + state: State, + options: Options, + handshake: Option, + channels: ChannelMap, + command_rx: Receiver, + command_tx: CommandTx, + outbound_rx: Receiver>, + outbound_tx: Sender>, + keepalive: Delay, + queued_events: VecDeque, +} + +impl std::fmt::Debug for Protocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Protocol") + .field("write_state", &self.write_state) + .field("read_state", &self.read_state) + //.field("io", &self.io) + .field("state", &self.state) + .field("options", &self.options) + .field("handshake", &self.handshake) + .field("channels", &self.channels) + .field("command_rx", &self.command_rx) + .field("command_tx", &self.command_tx) + .field("outbound_rx", &self.outbound_rx) + .field("outbound_tx", &self.outbound_tx) + .field("keepalive", &self.keepalive) + .field("queued_events", &self.queued_events) + .finish() + } +} + +impl Protocol +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + /// Create a new protocol instance. + pub(crate) fn new(io: IO, options: Options) -> Self { + let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP); + let (outbound_tx, outbound_rx): ( + Sender>, + Receiver>, + ) = async_channel::bounded(1); + Protocol { + io, + read_state: ReadState::new(), + write_state: WriteState::new(), + options, + state: State::NotInitialized, + channels: ChannelMap::new(), + handshake: None, + command_rx, + command_tx: CommandTx(command_tx), + outbound_tx, + outbound_rx, + keepalive: Delay::new(Duration::from_secs(DEFAULT_KEEPALIVE as u64)), + queued_events: VecDeque::new(), + } + } + + /// Whether this protocol stream initiated the underlying IO connection. + pub fn is_initiator(&self) -> bool { + self.options.is_initiator + } + + /// Get your own Noise public key. + /// + /// Empty before the handshake completed. + pub fn public_key(&self) -> Option<&[u8]> { + match &self.handshake { + None => None, + Some(handshake) => Some(handshake.local_pubkey.as_slice()), + } + } + + /// Get the remote's Noise public key. + /// + /// Empty before the handshake completed. + pub fn remote_public_key(&self) -> Option<&[u8]> { + match &self.handshake { + None => None, + Some(handshake) => Some(handshake.remote_pubkey.as_slice()), + } + } + + /// Get a sender to send commands. + pub fn commands(&self) -> CommandTx { + self.command_tx.clone() + } + + /// Give a command to the protocol. + pub async fn command(&mut self, command: Command) -> Result<()> { + self.command_tx.send(command).await + } + + /// Open a new protocol channel. + /// + /// Once the other side proofed that it also knows the `key`, the channel is emitted as + /// `Event::Channel` on the protocol event stream. + pub async fn open(&mut self, key: Key) -> Result<()> { + self.command_tx.open(key).await + } + + /// Iterator of all currently opened channels. + pub fn channels(&self) -> impl Iterator { + self.channels.iter().map(|c| c.discovery_key()) + } + + /// Stop the protocol and return the inner reader and writer. + pub fn release(self) -> IO { + self.io + } + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if let State::NotInitialized = this.state { + return_error!(this.init()); + } + + // Drain queued events first. + if let Some(event) = this.queued_events.pop_front() { + return Poll::Ready(Ok(event)); + } + + // Read and process incoming messages. + return_error!(this.poll_inbound_read(cx)); + + if let State::Established = this.state { + // Check for commands, but only once the connection is established. + return_error!(this.poll_commands(cx)); + } + + // Poll the keepalive timer. + this.poll_keepalive(cx); + + // Write everything we can write. + return_error!(this.poll_outbound_write(cx)); + + // Check if any events are enqueued. + if let Some(event) = this.queued_events.pop_front() { + Poll::Ready(Ok(event)) + } else { + Poll::Pending + } + } + + fn init(&mut self) -> Result<()> { + trace!( + "protocol Init, state {:?}, options {:?}", + self.state, + self.options + ); + match self.state { + State::NotInitialized => {} + _ => return Ok(()), + }; + + self.state = if self.options.noise { + let mut handshake = Handshake::new(self.options.is_initiator)?; + // If the handshake start returns a buffer, send it now. + if let Some(buf) = handshake.start()? { + // TODO what if this fails? or returns false + self.queue_frame_direct(buf.to_vec()).unwrap(); + } + self.read_state.set_frame_type(FrameType::Raw); + State::Handshake(Some(handshake)) + } else { + self.read_state.set_frame_type(FrameType::Message); + State::Established + }; + + Ok(()) + } + + /// Poll commands. + fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { + while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { + self.on_command(command)?; + } + Ok(()) + } + + /// Poll the keepalive timer and queue a ping message if needed. + fn poll_keepalive(&mut self, cx: &mut Context<'_>) { + if Pin::new(&mut self.keepalive).poll(cx).is_ready() { + if let State::Established = self.state { + // 24 bit header for the empty message, hence the 3 + self.write_state + .queue_frame(Frame::RawBatch(vec![vec![0u8; 3]])); + } + self.keepalive.reset(KEEPALIVE_DURATION); + } + } + + fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { + // If message is close, close the local channel. + if let ChannelMessage { + channel, + message: Message::Close(_), + .. + } = message + { + self.close_local(*channel); + // If message is a LocalSignal, emit an event and return false to indicate + // this message should be filtered out. + } else if let ChannelMessage { + message: Message::LocalSignal((name, data)), + .. + } = message + { + self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec()))); + return false; + } + true + } + + /// Poll for inbound messages and processs them. + fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { + loop { + let msg = self.read_state.poll_reader(cx, &mut self.io); + match msg { + Poll::Ready(Ok(message)) => { + self.on_inbound_frame(message)?; + } + Poll::Ready(Err(e)) => return Err(e), + Poll::Pending => return Ok(()), + } + } + } + + /// Poll for outbound messages and write them. + fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { + loop { + if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { + return Err(e); + } + // if no parking or setup in progress + if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { + return Ok(()); + } + + match Pin::new(&mut self.outbound_rx).poll_next(cx) { + Poll::Ready(Some(mut messages)) => { + if !messages.is_empty() { + messages.retain(|message| self.on_outbound_message(message)); + if !messages.is_empty() { + let frame = Frame::MessageBatch(messages); + self.write_state.park_frame(frame); + } + } + } + Poll::Ready(None) => unreachable!("Channel closed before end"), + Poll::Pending => return Ok(()), + } + } + } + + fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { + match frame { + Frame::RawBatch(raw_batch) => { + let mut processed_state: Option = None; + for buf in raw_batch { + let state_name: String = format!("{:?}", self.state); + match self.state { + State::Handshake(_) => self.on_handshake_message(buf)?, + State::SecretStream(_) => self.on_secret_stream_message(buf)?, + State::Established => { + if let Some(processed_state) = processed_state.as_ref() { + // last state before established + let previous_state = if self.options.encrypted { + // was SecretStream if we're encrypted + State::SecretStream(None) + } else { + // or wa hasdshake if we're not encrypted + State::Handshake(None) + }; + + // if htis raw_batch included regular messages (not handshake) + // after handshake stuff + if processed_state == &format!("{previous_state:?}") { + // This is the unlucky case where the batch had two or more messages where + // the first one was correctly identified as Raw but everything + // after that should have been (decrypted and) a MessageBatch. Correct the mistake + // here post-hoc. + let buf = self.read_state.decrypt_buf(&buf)?; + let frame = Frame::decode(&buf, &FrameType::Message)?; + self.on_inbound_frame(frame)?; + continue; + } + } + unreachable!( + "May not receive raw frames in Established state" + ) + } + _ => unreachable!( + "May not receive raw frames outside of handshake or secretstream state, was {:?}", + self.state + ), + }; + if processed_state.is_none() { + processed_state = Some(state_name) + } + } + Ok(()) + } + Frame::MessageBatch(channel_messages) => match self.state { + State::Established => { + for channel_message in channel_messages { + self.on_inbound_message(channel_message)? + } + Ok(()) + } + _ => unreachable!("May not receive message batch frames when not established"), + }, + } + } + + fn on_handshake_message(&mut self, buf: Vec) -> Result<()> { + let mut handshake = match &mut self.state { + State::Handshake(handshake) => handshake.take().unwrap(), + _ => unreachable!("May not call on_handshake_message when not in Handshake state"), + }; + + if let Some(response_buf) = handshake.read(&buf)? { + self.queue_frame_direct(response_buf.to_vec()).unwrap(); + } + + if !handshake.complete() { + self.state = State::Handshake(Some(handshake)); + } else { + let handshake_result = handshake.into_result()?; + + if self.options.encrypted { + // The cipher will be put to use to the writer only after the peer's answer has come + let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; + self.state = State::SecretStream(Some(cipher)); + + // Send the secret stream init message header to the other side + self.queue_frame_direct(init_msg).unwrap(); + } else { + // Skip secret stream and go straight to Established, then notify about + // handshake + self.read_state.set_frame_type(FrameType::Message); + let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; + self.queue_event(Event::Handshake(remote_public_key)); + self.state = State::Established; + } + // Store handshake result + self.handshake = Some(handshake_result.clone()); + } + Ok(()) + } + + fn on_secret_stream_message(&mut self, buf: Vec) -> Result<()> { + let encrypt_cipher = match &mut self.state { + State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(), + _ => { + unreachable!("May not call on_secret_stream_message when not in SecretStream state") + } + }; + let handshake_result = &self + .handshake + .as_ref() + .expect("Handshake result must be set before secret stream"); + let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?; + self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher); + self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher); + self.read_state.set_frame_type(FrameType::Message); + + // Lastly notify that handshake is ready and set state to established + let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; + self.queue_event(Event::Handshake(remote_public_key)); + self.state = State::Established; + Ok(()) + } + + fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { + // let channel_message = ChannelMessage::decode(buf)?; + let (remote_id, message) = channel_message.into_split(); + match message { + Message::Open(msg) => self.on_open(remote_id, msg)?, + Message::Close(msg) => self.on_close(remote_id, msg)?, + _ => self + .channels + .forward_inbound_message(remote_id as usize, message)?, + } + Ok(()) + } + + fn on_command(&mut self, command: Command) -> Result<()> { + match command { + Command::Open(key) => self.command_open(key), + Command::Close(discovery_key) => self.command_close(discovery_key), + Command::SignalLocal((name, data)) => self.command_signal_local(name, data), + } + } + + /// Open a Channel with the given key. Adding it to our channel map + fn command_open(&mut self, key: Key) -> Result<()> { + // Create a new channel. + let channel_handle = self.channels.attach_local(key); + // Safe because attach_local always puts Some(local_id) + let local_id = channel_handle.local_id().unwrap(); + let discovery_key = *channel_handle.discovery_key(); + + // If the channel was already opened from the remote end, verify, and if + // verification is ok, push a channel open event. + if channel_handle.is_connected() { + self.accept_channel(local_id)?; + } + + // Tell the remote end about the new channel. + let capability = self.capability(&key); + let channel = local_id as u64; + let message = Message::Open(Open { + channel, + protocol: PROTOCOL_NAME.to_string(), + discovery_key: discovery_key.to_vec(), + capability, + }); + let channel_message = ChannelMessage::new(channel, message); + self.write_state + .queue_frame(Frame::MessageBatch(vec![channel_message])); + Ok(()) + } + + fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> { + if self.channels.has_channel(&discovery_key) { + self.channels.remove(&discovery_key); + self.queue_event(Event::Close(discovery_key)); + } + Ok(()) + } + + fn command_signal_local(&mut self, name: String, data: Vec) -> Result<()> { + self.queue_event(Event::LocalSignal((name, data))); + Ok(()) + } + + fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { + let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; + let channel_handle = + self.channels + .attach_remote(discovery_key, ch as usize, msg.capability); + + if channel_handle.is_connected() { + let local_id = channel_handle.local_id().unwrap(); + self.accept_channel(local_id)?; + } else { + self.queue_event(Event::DiscoveryKey(discovery_key)); + } + + Ok(()) + } + + fn queue_event(&mut self, event: Event) { + self.queued_events.push_back(event); + } + + /// enequeu a buf to be sent + fn queue_frame_direct(&mut self, body: Vec) -> Result { + let mut frame = Frame::RawBatch(vec![body]); + self.write_state + .try_encode_and_enqueue_frame_for_tx(&mut frame) + } + + fn accept_channel(&mut self, local_id: usize) -> Result<()> { + let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; + self.verify_remote_capability(remote_capability.cloned(), key)?; + let channel = self.channels.accept(local_id, self.outbound_tx.clone())?; + self.queue_event(Event::Channel(channel)); + Ok(()) + } + + fn close_local(&mut self, local_id: u64) { + if let Some(channel) = self.channels.get_local(local_id as usize) { + let discovery_key = *channel.discovery_key(); + self.channels.remove(&discovery_key); + self.queue_event(Event::Close(discovery_key)); + } + } + + fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> { + if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) { + let discovery_key = *channel_handle.discovery_key(); + // There is a possibility both sides will close at the same time, so + // the channel could be closed already, let's tolerate that. + self.channels + .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?; + self.channels.remove(&discovery_key); + self.queue_event(Event::Close(discovery_key)); + } + Ok(()) + } + + fn capability(&self, key: &[u8]) -> Option> { + match self.handshake.as_ref() { + Some(handshake) => handshake.capability(key), + None => None, + } + } + + fn verify_remote_capability(&self, capability: Option>, key: &[u8]) -> Result<()> { + match self.handshake.as_ref() { + Some(handshake) => handshake.verify_remote_capability(capability, key), + None => Err(Error::new( + ErrorKind::PermissionDenied, + "Missing handshake state for capability verification", + )), + } + } +} + +impl Stream for Protocol +where + IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + type Item = Result; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Protocol::poll_next(self, cx).map(Some) + } +} + +/// Send [Command](Command)s to the [Protocol](Protocol). +#[derive(Clone, Debug)] +pub struct CommandTx(Sender); + +impl CommandTx { + /// Send a protocol command + pub async fn send(&mut self, command: Command) -> Result<()> { + self.0.send(command).await.map_err(map_channel_err) + } + /// Open a protocol channel. + /// + /// The channel will be emitted on the main protocol. + pub async fn open(&mut self, key: Key) -> Result<()> { + self.send(Command::Open(key)).await + } + + /// Close a protocol channel. + pub async fn close(&mut self, discovery_key: DiscoveryKey) -> Result<()> { + self.send(Command::Close(discovery_key)).await + } + + /// Send a local signal event to the protocol. + pub async fn signal_local(&mut self, name: &str, data: Vec) -> Result<()> { + self.send(Command::SignalLocal((name.to_string(), data))) + .await + } +} + +fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> { + key.try_into() + .map_err(|_e| io::Error::new(io::ErrorKind::InvalidInput, "Key must be 32 bytes long")) +} From e7e7dd7f619afd761f836de64d3ffdd428165074 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 13:53:17 -0400 Subject: [PATCH 043/135] fix spelling --- src/lib.rs | 2 +- src/noise.rs | 2 +- src/protocol/modern.rs | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e4c0744..88b3d32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -138,7 +138,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; pub use framing::Uint24LELengthPrefixedFraming; -pub use noise::{encyrpted_framed_message_channel, Encrypted}; +pub use noise::{encrypted_framed_message_channel, Encrypted}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, diff --git a/src/noise.rs b/src/noise.rs index c2269ea..3dbf660 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -15,7 +15,7 @@ use crate::{ }; /// Create a framed and encrypted Stream/Sink that reads/writes to an AsyncRead/AsyncWrite. -pub fn encyrpted_framed_message_channel( +pub fn encrypted_framed_message_channel( is_initiator: bool, io: IO, ) -> Encrypted> { diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index 930f9bd..6e599da 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -17,9 +17,11 @@ use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; use crate::message::{ChannelMessage, Frame, FrameType, Message}; use crate::reader::ReadState; -use crate::schema::*; use crate::util::{map_channel_err, pretty_hash}; use crate::writer::WriteState; +use crate::{ + encrypted_framed_message_channel, schema::*, Encrypted, Uint24LELengthPrefixedFraming, +}; macro_rules! return_error { ($msg:expr) => { From bb390b463d5077dccb2b7add5caa5ffb20c94c8d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 14:03:48 -0400 Subject: [PATCH 044/135] WIP Drop in encrypted channel compiles, fixed some errors, 'unreachable_code' warnings bc of todo!()'s --- benches/throughput.rs | 7 +++---- src/protocol/modern.rs | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/benches/throughput.rs b/benches/throughput.rs index 76d6874..7f9890d 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -4,7 +4,7 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use futures::future::Either; use futures::io::{AsyncRead, AsyncWrite}; use futures::stream::{FuturesUnordered, StreamExt}; -use hypercore_protocol::{schema::*, Duplex}; +use hypercore_protocol::schema::*; use hypercore_protocol::{Channel, Event, Message, ProtocolBuilder}; use log::*; use std::time::Instant; @@ -88,7 +88,7 @@ async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { kill_tx } -async fn onconnection(reader: R, writer: W, is_initiator: bool) -> Duplex +async fn onconnection(reader: R, writer: W, is_initiator: bool) where R: AsyncRead + Send + Unpin + 'static, W: AsyncWrite + Send + Unpin + 'static, @@ -108,12 +108,11 @@ where task::spawn(onchannel(channel, is_initiator)); } Event::Close(_dkey) => { - return protocol.release(); + return; } _ => {} } } - protocol.release() } async fn onchannel(mut channel: Channel, is_initiator: bool) { diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index 6e599da..e0e3c3a 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -140,7 +140,7 @@ impl fmt::Debug for State { pub struct Protocol { write_state: WriteState, read_state: ReadState, - io: IO, + io: Encrypted>, state: State, options: Options, handshake: Option, @@ -184,8 +184,9 @@ where Sender>, Receiver>, ) = async_channel::bounded(1); + Protocol { - io, + io: encrypted_framed_message_channel(options.is_initiator, io), read_state: ReadState::new(), write_state: WriteState::new(), options, @@ -250,7 +251,7 @@ where } /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> IO { + pub fn release(self) -> Encrypted> { self.io } @@ -359,10 +360,10 @@ where } /// Poll for inbound messages and processs them. - fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { + fn poll_inbound_read(&mut self, _cx: &mut Context<'_>) -> Result<()> { loop { - let msg = self.read_state.poll_reader(cx, &mut self.io); - match msg { + //let msg = self.read_state.poll_reader(cx, &mut self.io); + match todo!() { Poll::Ready(Ok(message)) => { self.on_inbound_frame(message)?; } @@ -375,9 +376,9 @@ where /// Poll for outbound messages and write them. fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { - return Err(e); - } + //if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { + // return Err(e); + //} // if no parking or setup in progress if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { return Ok(()); From df1e63405aaeb6f3ce58e1276c8c7678bfbf6f76 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 15:24:16 -0400 Subject: [PATCH 045/135] make Encoder trait immutable self --- src/message.rs | 82 +++++++++++++++++++++++--------------------------- 1 file changed, 38 insertions(+), 44 deletions(-) diff --git a/src/message.rs b/src/message.rs index 27b74c1..832a5f4 100644 --- a/src/message.rs +++ b/src/message.rs @@ -20,20 +20,20 @@ pub(crate) enum FrameType { /// (channel messages, messages, and individual message types through prost). pub(crate) trait Encoder: Sized + fmt::Debug { /// Calculates the length that the encoded message needs. - fn encoded_len(&mut self) -> Result; + fn encoded_len(&self) -> Result; /// Encodes the message to a buffer. /// /// An error will be returned if the buffer does not have sufficient capacity. - fn encode(&mut self, buf: &mut [u8]) -> Result; + fn encode(&self, buf: &mut [u8]) -> Result; } impl Encoder for &[u8] { - fn encoded_len(&mut self) -> Result { + fn encoded_len(&self) -> Result { Ok(self.len()) } - fn encode(&mut self, buf: &mut [u8]) -> Result { + fn encode(&self, buf: &mut [u8]) -> Result { let len = self.encoded_len()?; if len > buf.len() { return Err(EncodingError::new( @@ -232,7 +232,7 @@ impl Frame { } } - fn preencode(&mut self, state: &mut State) -> Result { + fn preencode(&self, state: &mut State) -> Result { match self { Self::RawBatch(raw_batch) => { for raw in raw_batch { @@ -257,7 +257,7 @@ impl Frame { state.add_end(2)?; let mut current_channel: u64 = messages[0].channel; state.preencode(¤t_channel)?; - for message in messages.iter_mut() { + for message in messages.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel @@ -277,7 +277,7 @@ impl Frame { } impl Encoder for Frame { - fn encoded_len(&mut self) -> Result { + fn encoded_len(&self) -> Result { let body_len = self.preencode(&mut State::new())?; match self { Self::RawBatch(_) => Ok(body_len), @@ -285,7 +285,7 @@ impl Encoder for Frame { } } - fn encode(&mut self, buf: &mut [u8]) -> Result { + fn encode(&self, buf: &mut [u8]) -> Result { let mut state = State::new(); let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; let body_len = self.preencode(&mut state)?; @@ -303,7 +303,7 @@ impl Encoder for Frame { } } #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref mut messages) => { + Self::MessageBatch(ref messages) => { write_uint24_le(body_len, buf); let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); if messages.len() == 1 { @@ -326,7 +326,7 @@ impl Encoder for Frame { state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; let mut current_channel: u64 = messages[0].channel; state.encode(¤t_channel, buf)?; - for message in messages.iter_mut() { + for message in messages.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel @@ -582,54 +582,48 @@ impl ChannelMessage { /// Performance optimization for letting calling encoded_len() already do /// the preencode phase of compact_encoding. - fn prepare_state(&mut self) -> Result<(), EncodingError> { - if self.state.is_none() { - let state = if let Message::Open(_) = self.message { - // Open message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else if let Message::Close(_) = self.message { - // Close message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else { - // The header is the channel id uint followed by message type uint - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let mut state = HypercoreState::new(); - let typ = self.message.typ(); - (*state).preencode(&typ)?; - self.message.preencode(&mut state)?; - state - }; - self.state = Some(state); - } - Ok(()) + fn prepare_state(&self) -> Result { + Ok(if let Message::Open(_) = self.message { + // Open message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else if let Message::Close(_) = self.message { + // Close message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else { + // The header is the channel id uint followed by message type uint + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 + let mut state = HypercoreState::new(); + let typ = self.message.typ(); + (*state).preencode(&typ)?; + self.message.preencode(&mut state)?; + state + }) } } impl Encoder for ChannelMessage { - fn encoded_len(&mut self) -> Result { - self.prepare_state()?; - Ok(self.state.as_ref().unwrap().end()) + fn encoded_len(&self) -> Result { + Ok(self.prepare_state()?.end()) } fn encode(&mut self, buf: &mut [u8]) -> Result { - self.prepare_state()?; - let state = self.state.as_mut().unwrap(); + let mut state = self.prepare_state()?; if let Message::Open(_) = self.message { // Open message is different in that the type byte is missing - self.message.encode(state, buf)?; + self.message.encode(&mut state, buf)?; } else if let Message::Close(_) = self.message { // Close message is different in that the type byte is missing - self.message.encode(state, buf)?; + self.message.encode(&mut state, buf)?; } else { let typ = self.message.typ(); state.0.encode(&typ, buf)?; - self.message.encode(state, buf)?; + self.message.encode(&mut state, buf)?; } Ok(state.start()) } From a860221b79353dedbae00864bcb900f7bf368136 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 15:53:57 -0400 Subject: [PATCH 046/135] rm unused State from ChannelMessage --- src/message.rs | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/message.rs b/src/message.rs index 832a5f4..78ccff5 100644 --- a/src/message.rs +++ b/src/message.rs @@ -479,7 +479,6 @@ impl fmt::Display for Message { pub(crate) struct ChannelMessage { pub(crate) channel: u64, pub(crate) message: Message, - state: Option, } impl PartialEq for ChannelMessage { @@ -497,11 +496,7 @@ impl fmt::Debug for ChannelMessage { impl ChannelMessage { /// Create a new message. pub(crate) fn new(channel: u64, message: Message) -> Self { - Self { - channel, - message, - state: None, - } + Self { channel, message } } /// Consume self and return (channel, Message). @@ -527,7 +522,6 @@ impl ChannelMessage { Self { channel: open_msg.channel, message: Message::Open(open_msg), - state: None, }, state.start(), )) @@ -550,7 +544,6 @@ impl ChannelMessage { Self { channel: close_msg.channel, message: Message::Close(close_msg), - state: None, }, state.start(), )) @@ -570,14 +563,7 @@ impl ChannelMessage { let mut state = State::from_buffer(buf); let typ: u64 = state.decode(buf)?; let (message, length) = Message::decode(&buf[state.start()..], typ)?; - Ok(( - Self { - channel, - message, - state: None, - }, - state.start() + length, - )) + Ok((Self { channel, message }, state.start() + length)) } /// Performance optimization for letting calling encoded_len() already do From d6e1e72e4331f65d84373a5127b4025d42030862 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 15:54:12 -0400 Subject: [PATCH 047/135] split out Vec encoding --- src/message.rs | 161 +++++++++++++++++++++++++++++-------------------- 1 file changed, 95 insertions(+), 66 deletions(-) diff --git a/src/message.rs b/src/message.rs index 78ccff5..a807bf9 100644 --- a/src/message.rs +++ b/src/message.rs @@ -239,37 +239,8 @@ impl Frame { state.add_end(raw.as_slice().encoded_len()?)?; } } - #[allow(clippy::comparison_chain)] Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } + state.add_end(messages.encoded_len()?)?; } } Ok(state.end()) @@ -302,46 +273,104 @@ impl Encoder for Frame { raw.as_slice().encode(buf)?; } } - #[allow(clippy::comparison_chain)] Self::MessageBatch(ref messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes + messages.encode(buf)?; + } + }; + Ok(len) + } +} + +fn prencode_channel_messages( + messages: &[ChannelMessage], + state: &mut State, +) -> Result { + match messages.len().cmp(&1) { + std::cmp::Ordering::Less => {} + std::cmp::Ordering::Equal => { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else { + state.preencode(&messages[0].channel)?; + state.add_end(messages[0].encoded_len()?)?; + } + } + std::cmp::Ordering::Greater => { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.add_end(2)?; + let mut current_channel: u64 = messages[0].channel; + state.preencode(¤t_channel)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.add_end(1)?; + state.preencode(&message.channel)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.preencode(&message_length)?; + state.add_end(message_length)?; + } + } + }; + Ok(state.end()) +} + +impl Encoder for Vec { + fn encoded_len(&self) -> Result { + let mut state = State::new(); + prencode_channel_messages(self, &mut state) + } + + fn encode(&self, buf: &mut [u8]) -> Result { + const HEADER_LEN: usize = 3; + let mut state = State::new(); + let body_len = prencode_channel_messages(self, &mut state)?; + write_uint24_le(body_len, buf); + let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + match self.len().cmp(&1) { + std::cmp::Ordering::Less => {} + std::cmp::Ordering::Equal => { + if let Message::Open(_) = &self[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(1_u8), buf)?; + state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + } else if let Message::Close(_) = &self[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(3_u8), buf)?; + state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + } else { + state.encode(&self[0].channel, buf)?; + state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + } + } + std::cmp::Ordering::Greater => { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + let mut current_channel: u64 = self[0].channel; + state.encode(¤t_channel, buf)?; + for message in self.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; + state.encode(&message.channel, buf)?; + current_channel = message.channel; } + let message_length = message.encoded_len()?; + state.encode(&message_length, buf)?; + state.add_start(message.encode(&mut buf[state.start()..])?)?; } } - }; - Ok(len) + } + Ok(HEADER_LEN + body_len) } } @@ -598,7 +627,7 @@ impl Encoder for ChannelMessage { Ok(self.prepare_state()?.end()) } - fn encode(&mut self, buf: &mut [u8]) -> Result { + fn encode(&self, buf: &mut [u8]) -> Result { let mut state = self.prepare_state()?; if let Message::Open(_) = self.message { // Open message is different in that the type byte is missing From 7dbb12574ff6566b24ebb079c700fe6b9bd937f1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 18:15:43 -0400 Subject: [PATCH 048/135] make integration tests use tokio --- tests/_util.rs | 58 +++++++++++++++++++------------------------------- tests/basic.rs | 46 ++++++++++++++++++--------------------- 2 files changed, 43 insertions(+), 61 deletions(-) diff --git a/tests/_util.rs b/tests/_util.rs index 9d0f9bf..3064c08 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -1,10 +1,27 @@ use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task::{self, JoinHandle}; use futures_lite::io::{AsyncRead, AsyncWrite}; +use futures_lite::StreamExt; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; use instant::Duration; +use std::future::Future; use std::io; +use tokio::task::JoinHandle; + +pub(crate) fn log() { + use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); + START_LOGS.get_or_init(|| { + tracing_subscriber::fmt() + .with_target(true) + .with_line_number(true) + // print when instrumented funtion enters + .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) + .with_file(true) + .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable + .without_time() + .init(); + }); +} pub type MemoryProtocol = Protocol>; pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol)> { @@ -18,21 +35,11 @@ pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol) Ok((a, b)) } -pub type TcpProtocol = Protocol; -pub async fn create_pair_tcp() -> io::Result<(TcpProtocol, TcpProtocol)> { - let (stream_a, stream_b) = tcp::pair().await?; - let a = ProtocolBuilder::new(true).connect(stream_a); - let b = ProtocolBuilder::new(false).connect(stream_b); - Ok((a, b)) -} - -pub fn next_event( - mut proto: Protocol, -) -> impl Future, io::Result)> +pub fn next_event(mut proto: Protocol) -> JoinHandle<(Protocol, io::Result)> where IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - task::spawn(async move { + tokio::task::spawn(async move { let e1 = proto.next().await; let e1 = e1.unwrap(); (proto, e1) @@ -62,7 +69,7 @@ pub fn drive_until_channel( where IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - task::spawn(async move { + tokio::task::spawn(async move { while let Some(event) = proto.next().await { let event = event?; if let Event::Channel(channel) = event { @@ -76,27 +83,6 @@ where }) } -pub mod tcp { - use async_std::net::{TcpListener, TcpStream}; - use async_std::prelude::*; - use async_std::task; - use std::io::{Error, ErrorKind, Result}; - pub async fn pair() -> Result<(TcpStream, TcpStream)> { - let address = "localhost:9999"; - let listener = TcpListener::bind(&address).await?; - let mut incoming = listener.incoming(); - - let connect_task = task::spawn(async move { TcpStream::connect(&address).await }); - - let server_stream = incoming.next().await; - let server_stream = - server_stream.ok_or_else(|| Error::new(ErrorKind::Other, "Stream closed"))?; - let server_stream = server_stream?; - let client_stream = connect_task.await?; - Ok((server_stream, client_stream)) - } -} - const RETRY_TIMEOUT: u64 = 100_u64; const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; pub async fn wait_for_localhost_port(port: u32) { diff --git a/tests/basic.rs b/tests/basic.rs index 8a99c7e..062cf35 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,26 +1,22 @@ -#![allow(dead_code, unused_imports)] - -use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use hypercore_protocol::{discovery_key, Channel, Event, Message, Protocol, ProtocolBuilder}; +use _util::{ + create_pair_memory, drive_until_channel, event_channel, event_discovery_key, next_event, +}; +use futures_lite::StreamExt; +use hypercore_protocol::{discovery_key, Event, Message}; use hypercore_protocol::{schema::*, DiscoveryKey}; use std::io; -use test_log::test; +use tokio::task; mod _util; -use _util::*; -#[test(async_std::test)] +#[tokio::test] async fn basic_protocol() -> anyhow::Result<()> { - // env_logger::init(); let (proto_a, proto_b) = create_pair_memory().await?; let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_a, event_a) = next_a.await; - let (proto_b, event_b) = next_b.await; + let (mut proto_a, event_a) = next_a.await?; + let (proto_b, event_b) = next_b.await?; assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); @@ -35,18 +31,18 @@ async fn basic_protocol() -> anyhow::Result<()> { let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_b, event_b) = next_b.await; + let (mut proto_b, event_b) = next_b.await?; assert!(matches!(event_b, Ok(Event::DiscoveryKey(_)))); assert_eq!(event_discovery_key(event_b.unwrap()), discovery_key(&key)); proto_b.open(key).await?; let next_b = next_event(proto_b); - let (proto_b, event_b) = next_b.await; + let (proto_b, event_b) = next_b.await?; assert!(matches!(event_b, Ok(Event::Channel(_)))); let mut channel_b = event_channel(event_b.unwrap()); - let (proto_a, event_a) = next_a.await; + let (proto_a, event_a) = next_a.await?; assert!(matches!(event_a, Ok(Event::Channel(_)))); let mut channel_a = event_channel(event_a.unwrap()); @@ -68,8 +64,8 @@ async fn basic_protocol() -> anyhow::Result<()> { channel_a.close().await?; - let (_, event_a) = next_a.await; - let (_, event_b) = next_b.await; + let (_, event_a) = next_a.await?; + let (_, event_b) = next_b.await?; assert!(matches!(event_a, Ok(Event::Close(_)))); assert!(matches!(event_b, Ok(Event::Close(_)))); @@ -78,7 +74,7 @@ async fn basic_protocol() -> anyhow::Result<()> { Ok(()) } -#[test(async_std::test)] +#[tokio::test] async fn open_close_channels() -> anyhow::Result<()> { let (mut proto_a, mut proto_b) = create_pair_memory().await?; @@ -91,8 +87,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = drive_until_channel(proto_a); let next_b = drive_until_channel(proto_b); - let (mut proto_a, mut channel_a1) = next_a.await?; - let (mut proto_b, mut channel_b1) = next_b.await?; + let (mut proto_a, mut channel_a1) = next_a.await??; + let (mut proto_b, mut channel_b1) = next_b.await??; proto_a.open(key2).await?; proto_b.open(key2).await?; @@ -100,8 +96,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = drive_until_channel(proto_a); let next_b = drive_until_channel(proto_b); - let (proto_a, mut channel_a2) = next_a.await?; - let (proto_b, mut channel_b2) = next_b.await?; + let (proto_a, mut channel_a2) = next_a.await??; + let (proto_b, mut channel_b2) = next_b.await??; eprintln!( "got channels: {:?}", @@ -119,8 +115,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_a, ev_a) = next_a.await; - let (mut proto_b, ev_b) = next_b.await; + let (mut proto_a, ev_a) = next_a.await?; + let (mut proto_b, ev_b) = next_b.await?; let ev_a = ev_a?; let ev_b = ev_b?; eprintln!("next a: {ev_a:?}"); From bbad25b265dc08598bdf4aab3c9acbc014ee370e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 19 Mar 2025 11:19:10 -0400 Subject: [PATCH 049/135] RESETME --- Cargo.toml | 7 +- src/framing.rs | 3 - src/lib.rs | 1 + src/message.rs | 148 ++++++++++++++------------- src/mqueue.rs | 208 +++++++++++++++++++++++++++++++++++++ src/noise.rs | 3 +- src/protocol/modern.rs | 226 +++++++---------------------------------- src/protocol/old.rs | 1 + src/schema.rs | 9 +- src/test_utils.rs | 5 +- tests/_util.rs | 7 +- tests/basic.rs | 11 +- 12 files changed, 349 insertions(+), 280 deletions(-) create mode 100644 src/mqueue.rs diff --git a/Cargo.toml b/Cargo.toml index 82790eb..7c15c8b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,8 +42,9 @@ crypto_secretstream = "0.2" futures = "0.3.31" [dependencies.hypercore] -version = "0.14.0" -default-features = false +path = "../core" +#version = "0.14.0" +#default-features = false [dev-dependencies] @@ -64,7 +65,7 @@ tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } tokio-util = { version = "0.7.14", features = ["compat"] } [features] -default = ["tokio", "sparse"] +default = ["tokio", "sparse", "protocol"] protocol = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" diff --git a/src/framing.rs b/src/framing.rs index e7c9c5c..ce4d7ba 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -223,7 +223,6 @@ pub(crate) mod test { let mut lp = Uint24LELengthPrefixedFraming::new(left); let input = b"yelp"; let msg = wrap_uint24_le(input); - dbg!(&msg); right.write_all(&msg).await?; let Some(Ok(rx)) = lp.next().await else { panic!() @@ -242,11 +241,9 @@ pub(crate) mod test { right.write_all(&msg).await?; } for d in data { - dbg!(); let Some(Ok(res)) = lp.next().await else { panic!(); }; - dbg!(&res); assert_eq!(&res, d); } Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 88b3d32..9fca95e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,6 +124,7 @@ mod crypto; mod duplex; mod framing; mod message; +mod mqueue; mod noise; mod protocol; mod reader; diff --git a/src/message.rs b/src/message.rs index a807bf9..869baf0 100644 --- a/src/message.rs +++ b/src/message.rs @@ -6,6 +6,7 @@ use hypercore::encoding::{ use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; +use tracing::instrument; /// The type of a data frame. #[derive(Debug, Clone, PartialEq)] @@ -76,6 +77,79 @@ impl From> for Frame { } } +pub(crate) fn decode_channel_messages( + buf: &[u8], +) -> Result<(Vec, usize), io::Error> { + if buf.len() >= 3 && buf[0] == 0x00 { + if buf[1] == 0x00 { + // Batch of messages + dbg!(); + let mut messages: Vec = vec![]; + let mut state = State::new_with_start_and_end(2, buf.len()); + + // First, there is the original channel + let mut current_channel: u64 = state.decode(buf)?; + while state.start() < state.end() { + // Length of the message is inbetween here + let channel_message_length: usize = state.decode(buf)?; + if state.start() + channel_message_length > state.end() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length, {} + {} > {}", + state.start(), + channel_message_length, + state.end() + ), + )); + } + // Then the actual message + let (channel_message, _) = ChannelMessage::decode( + &buf[state.start()..state.start() + channel_message_length], + current_channel, + )?; + messages.push(channel_message); + state.add_start(channel_message_length)?; + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if state.start() < state.end() && buf[state.start()] == 0x00 { + state.add_start(1)?; + current_channel = state.decode(buf)?; + } + } + Ok((messages, state.start())) + } else if buf[1] == 0x01 { + dbg!(); + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((vec![channel_message], length + 2)) + } else if buf[1] == 0x03 { + dbg!(); + // Close message + let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; + Ok((vec![channel_message], length + 2)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid special message", + )) + } + } else if buf.len() >= 2 { + dbg!(); + // Single message + let mut state = State::from_buffer(buf); + let channel: u64 = state.decode(buf)?; + let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; + Ok((vec![channel_message], state.start() + length)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:02X?}"), + )) + } +} + impl Frame { /// Decodes a frame from a buffer containing multiple concurrent messages. pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { @@ -163,73 +237,8 @@ impl Frame { } fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } - } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) - } - } else if buf.len() >= 2 { - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), - )) - } + let (channel_messages, bytes_read) = decode_channel_messages(buf)?; + Ok((Frame::MessageBatch(channel_messages), bytes_read)) } fn preencode(&self, state: &mut State) -> Result { @@ -290,7 +299,7 @@ fn prencode_channel_messages( std::cmp::Ordering::Equal => { if let Message::Open(_) = &messages[0].message { // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; + state.add_end(2 + dbg!(&messages[0].encoded_len()?))?; } else if let Message::Close(_) = &messages[0].message { // This is a special case with 0x00, 0x03 intro bytes state.add_end(2 + &messages[0].encoded_len()?)?; @@ -327,6 +336,7 @@ impl Encoder for Vec { prencode_channel_messages(self, &mut state) } + #[instrument] fn encode(&self, buf: &mut [u8]) -> Result { const HEADER_LEN: usize = 3; let mut state = State::new(); @@ -655,7 +665,7 @@ mod tests { ($( $msg:expr ),*) => { $( let channel = rand::random::() as u64; - let mut channel_message = ChannelMessage::new(channel, $msg); + let channel_message = ChannelMessage::new(channel, $msg); let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); let mut buf = vec![0u8; encoded_len]; let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); diff --git a/src/mqueue.rs b/src/mqueue.rs new file mode 100644 index 0000000..b968937 --- /dev/null +++ b/src/mqueue.rs @@ -0,0 +1,208 @@ +//! Interface for reading and writing message to a Stream/Sink + +use std::{ + collections::VecDeque, + io::Result, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{AsyncRead, AsyncWrite, Sink, Stream}; +use tracing::{debug, error, info, instrument, trace}; + +use crate::{ + encrypted_framed_message_channel, + message::{decode_channel_messages, ChannelMessage, Encoder as _}, +}; + +pub(crate) struct MessageIo { + io: IO, + write_queue: VecDeque, +} + +use crate::{framing::Uint24LELengthPrefixedFraming, noise::Encrypted}; + +pub(crate) fn encrypted_and_framed( + is_initiator: bool, + io: BytesTxRx, +) -> MessageIo>> { + let io = encrypted_framed_message_channel(is_initiator, io); + MessageIo { + io, + write_queue: Default::default(), + } +} +impl>> + Sink> + Send + Unpin + 'static> MessageIo { + pub(crate) fn new(io: IO) -> Self { + Self { + io, + write_queue: Default::default(), + } + } + + pub(crate) fn enqueue(&mut self, msg: ChannelMessage) { + self.write_queue.push_back(msg) + } + + #[instrument(skip_all)] + pub(crate) fn poll_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut pending = true; + // TODO handle error? + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(&mut self.io), cx) { + pending = false; + if self.write_queue.is_empty() { + break; + } + let mut messages = vec![]; + while let Some(msg) = self.write_queue.pop_front() { + messages.push(msg); + } + + let mut buf = vec![0; messages.encoded_len()?]; + dbg!(&buf); + match messages.encode(&mut buf) { + Ok(_) => {} + Err(e) => { + error!(error = ?e, "error encoding messages"); + return Poll::Ready(Err(e.into())); + } + } + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf) { + error!("error in start_send"); + todo!() + } + + match Sink::poll_flush(Pin::new(&mut self.io), cx) { + Poll::Ready(Ok(())) => { + debug!("flushed"); + } + Poll::Ready(Err(_e)) => { + error!("Error flushing"); + return todo!(); + } + Poll::Pending => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } + + if pending { + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } + + pub(crate) fn poll_inbound( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + match Pin::new(&mut self.io).poll_next(cx) { + Poll::Ready(Some(Ok(encoded))) => { + match decode_channel_messages(&encoded) { + Ok((messsages, n_read)) => { + assert_eq!(n_read, encoded.len()); // I think this is always true + Poll::Ready(Ok(messsages)) + } + Err(_) => todo!(), + } + } + Poll::Ready(Some(Err(_e))) => todo!(), + Poll::Ready(None) => todo!(), + Poll::Pending => Poll::Pending, + } + } +} + +impl>> + Sink> + Send + Unpin + 'static> Stream + for MessageIo +{ + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let out_res = self.poll_outbound(cx); + match out_res { + Poll::Ready(res) => match res { + Ok(okres) => trace!(res = ?okres, "MessageIo poll_outbound"), + Err(e) => error!(error = ?e, "MessageIo error in poll_outbound"), + }, + Poll::Pending => trace!("MessageIo poll_outbound Pending"), + } + + let in_res = self.poll_inbound(cx); + trace!(poll_inbound = ?in_res, "MessageIo"); + + in_res.map(Some) + } +} + +#[cfg(test)] +mod test { + use std::io::Result; + + use futures::future::{join, select}; + use futures_lite::StreamExt; + + use crate::{ + framing::test::duplex, + message::{decode_channel_messages, ChannelMessage, Encoder as _}, + mqueue::encrypted_and_framed, + schema::{NoData, Open}, + test_utils::log, + }; + fn new_msg(channel: u64) -> ChannelMessage { + ChannelMessage { + channel, + message: crate::Message::NoData(NoData { request: channel }), + } + } + + #[tokio::test] + async fn mqueue() -> Result<()> { + log(); + let m = vec![new_msg(0)]; + let mut buf = vec![0; m.encoded_len()?]; + dbg!(&buf.len()); + dbg!(); + m.encode(&mut buf)?; + dbg!(&buf); + + let res = dbg!(decode_channel_messages(&buf))?; + assert_eq!(vec![new_msg(42402)], res.0); + dbg!(&buf); + + Ok(()) + + /* + let (left, right) = duplex(1024 * 64); + let mut left = encrypted_and_framed(true, left); + let mut right = encrypted_and_framed(false, right); + left.enqueue(new_msg(42)); + right.enqueue(new_msg(38)); + + match select(left.next(), right.next()).await { + futures::future::Either::Left(ll) => { + println!( + "left + + ooooooooooooooooooooo + + " + ); + } + futures::future::Either::Right(rr) => { + println!( + "rightllllllllllllllll + + ------------------------- + + " + ); + } + } + Ok(()) + */ + } +} diff --git a/src/noise.rs b/src/noise.rs index 3dbf660..d2a3a1d 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -22,6 +22,7 @@ pub fn encrypted_framed_message_channel { - dbg!(); let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { trace!("Read in handshake msg\n{msg:?}"); @@ -562,7 +562,6 @@ mod tset { #[tokio::test] async fn with_framing() -> Result<()> { - crate::test_utils::log(); let hello = b"hello".to_vec(); let (left, right) = duplex(1024 * 64); diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index e0e3c3a..cb71f50 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -5,20 +5,18 @@ use futures_timer::Delay; use std::collections::VecDeque; use std::convert::TryInto; use std::fmt; -use std::future::Future; use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::trace; +use tracing::instrument; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; -use crate::message::{ChannelMessage, Frame, FrameType, Message}; -use crate::reader::ReadState; +use crate::crypto::{EncryptCipher, Handshake, HandshakeResult}; +use crate::message::{ChannelMessage, Frame, Message}; +use crate::mqueue::MessageIo; use crate::util::{map_channel_err, pretty_hash}; -use crate::writer::WriteState; use crate::{ encrypted_framed_message_channel, schema::*, Encrypted, Uint24LELengthPrefixedFraming, }; @@ -138,9 +136,7 @@ impl fmt::Debug for State { /// A Protocol stream. pub struct Protocol { - write_state: WriteState, - read_state: ReadState, - io: Encrypted>, + io: MessageIo>>, state: State, options: Options, handshake: Option, @@ -156,8 +152,6 @@ pub struct Protocol { impl std::fmt::Debug for Protocol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Protocol") - .field("write_state", &self.write_state) - .field("read_state", &self.read_state) //.field("io", &self.io) .field("state", &self.state) .field("options", &self.options) @@ -186,9 +180,7 @@ where ) = async_channel::bounded(1); Protocol { - io: encrypted_framed_message_channel(options.is_initiator, io), - read_state: ReadState::new(), - write_state: WriteState::new(), + io: MessageIo::new(encrypted_framed_message_channel(options.is_initiator, io)), options, state: State::NotInitialized, channels: ChannelMap::new(), @@ -251,17 +243,14 @@ where } /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> Encrypted> { + pub fn release(self) -> MessageIo>> { self.io } + #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - if let State::NotInitialized = this.state { - return_error!(this.init()); - } - // Drain queued events first. if let Some(event) = this.queued_events.pop_front() { return Poll::Ready(Ok(event)); @@ -270,10 +259,8 @@ where // Read and process incoming messages. return_error!(this.poll_inbound_read(cx)); - if let State::Established = this.state { - // Check for commands, but only once the connection is established. - return_error!(this.poll_commands(cx)); - } + // Check for commands, but only once the connection is established. + return_error!(this.poll_commands(cx)); // Poll the keepalive timer. this.poll_keepalive(cx); @@ -289,34 +276,6 @@ where } } - fn init(&mut self) -> Result<()> { - trace!( - "protocol Init, state {:?}, options {:?}", - self.state, - self.options - ); - match self.state { - State::NotInitialized => {} - _ => return Ok(()), - }; - - self.state = if self.options.noise { - let mut handshake = Handshake::new(self.options.is_initiator)?; - // If the handshake start returns a buffer, send it now. - if let Some(buf) = handshake.start()? { - // TODO what if this fails? or returns false - self.queue_frame_direct(buf.to_vec()).unwrap(); - } - self.read_state.set_frame_type(FrameType::Raw); - State::Handshake(Some(handshake)) - } else { - self.read_state.set_frame_type(FrameType::Message); - State::Established - }; - - Ok(()) - } - /// Poll commands. fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { @@ -325,8 +284,9 @@ where Ok(()) } - /// Poll the keepalive timer and queue a ping message if needed. - fn poll_keepalive(&mut self, cx: &mut Context<'_>) { + /// TODO Poll the keepalive timer and queue a ping message if needed. + fn poll_keepalive(&mut self, _cx: &mut Context<'_>) { + /* if Pin::new(&mut self.keepalive).poll(cx).is_ready() { if let State::Established = self.state { // 24 bit header for the empty message, hence the 3 @@ -335,8 +295,10 @@ where } self.keepalive.reset(KEEPALIVE_DURATION); } + */ } + // just handles Close and LocalSignal?? fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { // If message is close, close the local channel. if let ChannelMessage { @@ -360,12 +322,11 @@ where } /// Poll for inbound messages and processs them. - fn poll_inbound_read(&mut self, _cx: &mut Context<'_>) -> Result<()> { + fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - //let msg = self.read_state.poll_reader(cx, &mut self.io); - match todo!() { - Poll::Ready(Ok(message)) => { - self.on_inbound_frame(message)?; + match self.io.poll_inbound(cx) { + Poll::Ready(Ok(messages)) => { + self.on_inbound_channel_messages(messages)?; } Poll::Ready(Err(e)) => return Err(e), Poll::Pending => return Ok(()), @@ -374,23 +335,20 @@ where } /// Poll for outbound messages and write them. + /// Reads messages from Self::outbound and sends them over io fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - //if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { - // return Err(e); - //} // if no parking or setup in progress - if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { - return Ok(()); + if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) { + return Err(e); } - + // send messages outbound_rx match Pin::new(&mut self.outbound_rx).poll_next(cx) { Poll::Ready(Some(mut messages)) => { if !messages.is_empty() { messages.retain(|message| self.on_outbound_message(message)); - if !messages.is_empty() { - let frame = Frame::MessageBatch(messages); - self.write_state.park_frame(frame); + for msg in messages { + self.io.enqueue(msg); } } } @@ -400,125 +358,13 @@ where } } - fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { - match frame { - Frame::RawBatch(raw_batch) => { - let mut processed_state: Option = None; - for buf in raw_batch { - let state_name: String = format!("{:?}", self.state); - match self.state { - State::Handshake(_) => self.on_handshake_message(buf)?, - State::SecretStream(_) => self.on_secret_stream_message(buf)?, - State::Established => { - if let Some(processed_state) = processed_state.as_ref() { - // last state before established - let previous_state = if self.options.encrypted { - // was SecretStream if we're encrypted - State::SecretStream(None) - } else { - // or wa hasdshake if we're not encrypted - State::Handshake(None) - }; - - // if htis raw_batch included regular messages (not handshake) - // after handshake stuff - if processed_state == &format!("{previous_state:?}") { - // This is the unlucky case where the batch had two or more messages where - // the first one was correctly identified as Raw but everything - // after that should have been (decrypted and) a MessageBatch. Correct the mistake - // here post-hoc. - let buf = self.read_state.decrypt_buf(&buf)?; - let frame = Frame::decode(&buf, &FrameType::Message)?; - self.on_inbound_frame(frame)?; - continue; - } - } - unreachable!( - "May not receive raw frames in Established state" - ) - } - _ => unreachable!( - "May not receive raw frames outside of handshake or secretstream state, was {:?}", - self.state - ), - }; - if processed_state.is_none() { - processed_state = Some(state_name) - } - } - Ok(()) - } - Frame::MessageBatch(channel_messages) => match self.state { - State::Established => { - for channel_message in channel_messages { - self.on_inbound_message(channel_message)? - } - Ok(()) - } - _ => unreachable!("May not receive message batch frames when not established"), - }, - } - } - - fn on_handshake_message(&mut self, buf: Vec) -> Result<()> { - let mut handshake = match &mut self.state { - State::Handshake(handshake) => handshake.take().unwrap(), - _ => unreachable!("May not call on_handshake_message when not in Handshake state"), - }; - - if let Some(response_buf) = handshake.read(&buf)? { - self.queue_frame_direct(response_buf.to_vec()).unwrap(); - } - - if !handshake.complete() { - self.state = State::Handshake(Some(handshake)); - } else { - let handshake_result = handshake.into_result()?; - - if self.options.encrypted { - // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; - self.state = State::SecretStream(Some(cipher)); - - // Send the secret stream init message header to the other side - self.queue_frame_direct(init_msg).unwrap(); - } else { - // Skip secret stream and go straight to Established, then notify about - // handshake - self.read_state.set_frame_type(FrameType::Message); - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - } - // Store handshake result - self.handshake = Some(handshake_result.clone()); + fn on_inbound_channel_messages(&mut self, channel_messages: Vec) -> Result<()> { + for channel_message in channel_messages { + self.on_inbound_message(channel_message)? } Ok(()) } - fn on_secret_stream_message(&mut self, buf: Vec) -> Result<()> { - let encrypt_cipher = match &mut self.state { - State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(), - _ => { - unreachable!("May not call on_secret_stream_message when not in SecretStream state") - } - }; - let handshake_result = &self - .handshake - .as_ref() - .expect("Handshake result must be set before secret stream"); - let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?; - self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher); - self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher); - self.read_state.set_frame_type(FrameType::Message); - - // Lastly notify that handshake is ready and set state to established - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - Ok(()) - } - fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { // let channel_message = ChannelMessage::decode(buf)?; let (remote_id, message) = channel_message.into_split(); @@ -564,8 +410,7 @@ where capability, }); let channel_message = ChannelMessage::new(channel, message); - self.write_state - .queue_frame(Frame::MessageBatch(vec![channel_message])); + self.io.enqueue(channel_message); Ok(()) } @@ -602,13 +447,6 @@ where self.queued_events.push_back(event); } - /// enequeu a buf to be sent - fn queue_frame_direct(&mut self, body: Vec) -> Result { - let mut frame = Frame::RawBatch(vec![body]); - self.write_state - .try_encode_and_enqueue_frame_for_tx(&mut frame) - } - fn accept_channel(&mut self, local_id: usize) -> Result<()> { let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; self.verify_remote_capability(remote_capability.cloned(), key)?; @@ -662,7 +500,11 @@ where { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Protocol::poll_next(self, cx).map(Some) + match Protocol::poll_next(self, cx) { + Poll::Ready(Ok(e)) => Poll::Ready(Some(Ok(e))), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + } } } diff --git a/src/protocol/old.rs b/src/protocol/old.rs index 930f9bd..01af713 100644 --- a/src/protocol/old.rs +++ b/src/protocol/old.rs @@ -387,6 +387,7 @@ where messages.retain(|message| self.on_outbound_message(message)); if !messages.is_empty() { let frame = Frame::MessageBatch(messages); + // TODO try replacing this with queue_frame self.write_state.park_frame(frame); } } diff --git a/src/schema.rs b/src/schema.rs index ef58e77..bf35416 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -18,9 +18,9 @@ pub struct Open { impl CompactEncoding for State { fn preencode(&mut self, value: &Open) -> Result { - self.preencode(&value.channel)?; - self.preencode(&value.protocol)?; - self.preencode(&value.discovery_key)?; + dbg!(self.preencode(&value.channel)?); + dbg!(self.preencode(&value.protocol)?); + dbg!(self.preencode(&value.discovery_key)?); if value.capability.is_some() { self.add_end(1)?; // flags for future use self.preencode_fixed_32()?; @@ -29,6 +29,7 @@ impl CompactEncoding for State { } fn encode(&mut self, value: &Open, buffer: &mut [u8]) -> Result { + dbg!(); self.encode(&value.channel, buffer)?; self.encode(&value.protocol, buffer)?; self.encode(&value.discovery_key, buffer)?; @@ -369,7 +370,7 @@ pub struct NoData { impl CompactEncoding for State { fn preencode(&mut self, value: &NoData) -> Result { - self.preencode(&value.request) + dbg!(self.preencode(dbg!(&value.request))) } fn encode(&mut self, value: &NoData, buffer: &mut [u8]) -> Result { diff --git a/src/test_utils.rs b/src/test_utils.rs index ff1a3c2..b12f440 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,7 +1,6 @@ use std::{ io::{self, ErrorKind}, pin::Pin, - sync::OnceLock, task::{Context, Poll}, }; @@ -74,9 +73,9 @@ impl TwoWay { } } -pub(crate) fn log() { +pub fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; - static START_LOGS: OnceLock<()> = OnceLock::new(); + static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); START_LOGS.get_or_init(|| { tracing_subscriber::fmt() .with_target(true) diff --git a/tests/_util.rs b/tests/_util.rs index 3064c08..d1fa197 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -3,10 +3,10 @@ use futures_lite::io::{AsyncRead, AsyncWrite}; use futures_lite::StreamExt; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; use instant::Duration; -use std::future::Future; use std::io; use tokio::task::JoinHandle; +#[allow(unused)] pub(crate) fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); @@ -83,9 +83,10 @@ where }) } -const RETRY_TIMEOUT: u64 = 100_u64; -const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; +#[allow(unused)] pub async fn wait_for_localhost_port(port: u32) { + const RETRY_TIMEOUT: u64 = 100_u64; + const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; loop { let timeout = async_std::future::timeout( Duration::from_millis(NO_RESPONSE_TIMEOUT), diff --git a/tests/basic.rs b/tests/basic.rs index 062cf35..92e4c8b 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -13,10 +13,19 @@ mod _util; async fn basic_protocol() -> anyhow::Result<()> { let (proto_a, proto_b) = create_pair_memory().await?; + dbg!(); let next_a = next_event(proto_a); + dbg!(); let next_b = next_event(proto_b); - let (mut proto_a, event_a) = next_a.await?; + dbg!(); let (proto_b, event_b) = next_b.await?; + dbg!(); + let (mut proto_a, event_a) = next_a.await?; + //let (a, b) = join(next_a, next_b).await; + dbg!(); + //let (mut proto_a, event_a) = a?; + dbg!(); + //let (proto_b, event_b) = b?; assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); From e9d6236a2a2b47bd71b376453a190b9a00f61569 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 19 Mar 2025 12:22:57 -0400 Subject: [PATCH 050/135] Add frame encoding test --- src/message.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/message.rs b/src/message.rs index 869baf0..25a2545 100644 --- a/src/message.rs +++ b/src/message.rs @@ -676,6 +676,28 @@ mod tests { } } + #[test] + fn frame_encode_decode() -> std::io::Result<()> { + let msg = Message::Synchronize(Synchronize { + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }); + + let channel = rand::random::() as u64; + let channel_message = ChannelMessage::new(channel, msg); + + let frame = Frame::from(channel_message); + let mut buf = vec![0; frame.encoded_len()?]; + frame.encode(&mut buf)?; + let res_frame = Frame::decode(&buf, &FrameType::Message)?; + assert_eq!(res_frame, frame); + Ok(()) + } + #[test] fn message_encode_decode() { message_enc_dec! { From 15d2895ba5ed704e4302a4bd9dacc97c7779de32 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 20 Mar 2025 15:45:34 -0400 Subject: [PATCH 051/135] fix ChannelMessage Encoding. Restore Frame encoding --- src/message.rs | 322 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 292 insertions(+), 30 deletions(-) diff --git a/src/message.rs b/src/message.rs index 25a2545..90028aa 100644 --- a/src/message.rs +++ b/src/message.rs @@ -8,6 +8,8 @@ use std::fmt; use std::io; use tracing::instrument; +const UINT24_HEADER_LEN: usize = 3; + /// The type of a data frame. #[derive(Debug, Clone, PartialEq)] pub(crate) enum FrameType { @@ -71,13 +73,60 @@ impl From for Frame { } } +impl From> for Frame { + fn from(m: Vec) -> Self { + Self::MessageBatch(m) + } +} + impl From> for Frame { fn from(m: Vec) -> Self { Self::RawBatch(vec![m]) } } -pub(crate) fn decode_channel_messages( +pub(crate) fn decode_many_channel_messages( + buf: &[u8], +) -> Result<(Vec, usize), io::Error> { + let mut index = 0; + let mut combined_messages: Vec = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + let (msgs, length) = decode_one_channel_message( + &buf[index + header_len..index + header_len + body_len as usize], + )?; + if length != body_len as usize { + tracing::warn!( + "Did not know what to do with all the bytes, got {} but decoded {}. \ + This may be because the peer implements a newer protocol version \ + that has extra fields.", + body_len, + length + ); + } + for message in msgs { + combined_messages.push(message); + } + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in multi-message chunk", + )); + } + } + Ok((combined_messages, index)) +} +// bad name bc it returns many. More like, decode unframed channel messages +pub(crate) fn decode_one_channel_message( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { if buf.len() >= 3 && buf[0] == 0x00 { @@ -237,8 +286,76 @@ impl Frame { } fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - let (channel_messages, bytes_read) = decode_channel_messages(buf)?; - Ok((Frame::MessageBatch(channel_messages), bytes_read)) + println!("decode_message {buf:02X?}"); + // buffer length >= 3 or more and starts with 0 is message batch + if buf.len() >= 3 && buf[0] == 0x00 { + if buf[1] == 0x00 { + // Batch of messages + let mut messages: Vec = vec![]; + let mut state = State::new_with_start_and_end(2, buf.len()); + + // First, there is the original channel + let mut current_channel: u64 = state.decode(buf)?; + while state.start() < state.end() { + // Length of the message is inbetween here + let channel_message_length: usize = state.decode(buf)?; + if state.start() + channel_message_length > state.end() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length, {} + {} > {}", + state.start(), + channel_message_length, + state.end() + ), + )); + } + // Then the actual message + let (channel_message, _) = ChannelMessage::decode( + &buf[state.start()..state.start() + channel_message_length], + current_channel, + )?; + messages.push(channel_message); + state.add_start(channel_message_length)?; + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if state.start() < state.end() && buf[state.start()] == 0x00 { + state.add_start(1)?; + current_channel = state.decode(buf)?; + } + } + Ok((Frame::MessageBatch(messages), state.start())) + } else if buf[1] == 0x01 { + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else if buf[1] == 0x03 { + // Close message + let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid special message", + )) + } + } else if buf.len() >= 2 { + // len >= and + // Single message + let mut state = State::from_buffer(buf); + let channel: u64 = state.decode(buf)?; + let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; + Ok(( + Frame::MessageBatch(vec![channel_message]), + state.start() + length, + )) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:02X?}"), + )) + } } fn preencode(&self, state: &mut State) -> Result { @@ -248,8 +365,37 @@ impl Frame { state.add_end(raw.as_slice().encoded_len()?)?; } } + #[allow(clippy::comparison_chain)] Self::MessageBatch(messages) => { - state.add_end(messages.encoded_len()?)?; + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else { + (*state).preencode(&messages[0].channel)?; + state.add_end(messages[0].encoded_len()?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.add_end(2)?; + let mut current_channel: u64 = messages[0].channel; + state.preencode(¤t_channel)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.add_end(1)?; + state.preencode(&message.channel)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.preencode(&message_length)?; + state.add_end(message_length)?; + } + } } } Ok(state.end()) @@ -282,8 +428,43 @@ impl Encoder for Frame { raw.as_slice().encode(buf)?; } } + #[allow(clippy::comparison_chain)] Self::MessageBatch(ref messages) => { - messages.encode(buf)?; + write_uint24_le(body_len, buf); + let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(1_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(3_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else { + state.encode(&messages[0].channel, buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + let mut current_channel: u64 = messages[0].channel; + state.encode(¤t_channel, buf)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.encode(&(0_u8), buf)?; + state.encode(&message.channel, buf)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.encode(&message_length, buf)?; + state.add_start(message.encode(&mut buf[state.start()..])?)?; + } + } } }; Ok(len) @@ -333,12 +514,11 @@ fn prencode_channel_messages( impl Encoder for Vec { fn encoded_len(&self) -> Result { let mut state = State::new(); - prencode_channel_messages(self, &mut state) + Ok(prencode_channel_messages(self, &mut state)? + UINT24_HEADER_LEN) } #[instrument] fn encode(&self, buf: &mut [u8]) -> Result { - const HEADER_LEN: usize = 3; let mut state = State::new(); let body_len = prencode_channel_messages(self, &mut state)?; write_uint24_le(body_len, buf); @@ -380,7 +560,7 @@ impl Encoder for Vec { } } } - Ok(HEADER_LEN + body_len) + Ok(UINT24_HEADER_LEN + body_len) } } @@ -656,6 +836,7 @@ impl Encoder for ChannelMessage { #[cfg(test)] mod tests { + use super::*; use hypercore::{ DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, @@ -676,28 +857,6 @@ mod tests { } } - #[test] - fn frame_encode_decode() -> std::io::Result<()> { - let msg = Message::Synchronize(Synchronize { - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }); - - let channel = rand::random::() as u64; - let channel_message = ChannelMessage::new(channel, msg); - - let frame = Frame::from(channel_message); - let mut buf = vec![0; frame.encoded_len()?]; - frame.encode(&mut buf)?; - let res_frame = Frame::decode(&buf, &FrameType::Message)?; - assert_eq!(res_frame, frame); - Ok(()) - } - #[test] fn message_encode_decode() { message_enc_dec! { @@ -781,4 +940,107 @@ mod tests { }) }; } + + fn message_test_data() -> Vec { + vec![ + Message::Synchronize(Synchronize { + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }), + Message::Request(Request { + id: 1, + fork: 1, + block: Some(RequestBlock { + index: 5, + nodes: 10, + }), + hash: Some(RequestBlock { + index: 20, + nodes: 0, + }), + seek: Some(RequestSeek { bytes: 10 }), + upgrade: Some(RequestUpgrade { + start: 0, + length: 10, + }), + }), + Message::Cancel(Cancel { request: 1 }), + Message::Data(Data { + request: 1, + fork: 5, + block: Some(DataBlock { + index: 5, + nodes: vec![Node::new(1, vec![0x01; 32], 100)], + value: vec![0xFF; 10], + }), + hash: Some(DataHash { + index: 20, + nodes: vec![Node::new(2, vec![0x02; 32], 200)], + }), + seek: Some(DataSeek { + bytes: 10, + nodes: vec![Node::new(3, vec![0x03; 32], 300)], + }), + upgrade: Some(DataUpgrade { + start: 0, + length: 10, + nodes: vec![Node::new(4, vec![0x04; 32], 400)], + additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)], + signature: vec![0xAB; 32], + }), + }), + Message::NoData(NoData { request: 2 }), + Message::Want(Want { + start: 0, + length: 100, + }), + Message::Unwant(Unwant { + start: 10, + length: 2, + }), + Message::Bitfield(Bitfield { + start: 20, + bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF], + }), + Message::Range(Range { + drop: true, + start: 12345, + length: 100000, + }), + Message::Extension(Extension { + name: "custom_extension/v1/open".to_string(), + message: vec![0x44, 20], + }), + ] + } + + #[test] + fn compare_with_frame_encoding_decoding() -> std::io::Result<()> { + let channel = 42; + for msg in message_test_data() { + let channel_message = ChannelMessage::new(channel, msg); + let frame = Frame::from(channel_message.clone()); + let cmvec = vec![channel_message.clone()]; + + let mut fbuf = vec![0; frame.encoded_len()?]; + let mut cbuf = vec![0; cmvec.encoded_len()?]; + + assert_eq!(cbuf, fbuf); + + frame.encode(&mut fbuf)?; + cmvec.encode(&mut cbuf)?; + + assert_eq!(cbuf, fbuf); + + let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; + assert_eq!(fres, frame); + let cres_m = decode_many_channel_messages(&cbuf)?.0; + assert_eq!(cres_m, cmvec); + } + Ok(()) + } } From 002ba3681f6a9a957e63851b646b008bf65de272 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 25 Mar 2025 12:29:49 -0400 Subject: [PATCH 052/135] rm/gate old unused stuff. add instrument --- src/framing.rs | 67 +++- src/lib.rs | 4 + src/message.rs | 28 -- src/mqueue.rs | 102 +++--- src/oldmessage.rs | 814 +++++++++++++++++++++++++++++++++++++++++ src/protocol/modern.rs | 38 +- src/protocol/old.rs | 7 +- src/test_utils.rs | 2 +- src/util.rs | 26 ++ tests/_util.rs | 19 + tests/basic.rs | 4 +- 11 files changed, 991 insertions(+), 120 deletions(-) create mode 100644 src/oldmessage.rs diff --git a/src/framing.rs b/src/framing.rs index ce4d7ba..8b8ae8f 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -192,11 +192,12 @@ where } #[cfg(test)] pub(crate) mod test { - use crate::test_utils::log; + use crate::{test_utils::log, Duplex}; use super::*; use futures::{SinkExt, StreamExt}; use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + use tokio::spawn; use tokio_util::compat::TokioAsyncReadCompatExt; pub(crate) fn duplex( @@ -307,6 +308,70 @@ pub(crate) mod test { assert_eq!(r3, data); assert_eq!(r4, data); + Ok(()) + } + #[tokio::test] + async fn left_and_right_sluice() -> Result<()> { + let (ar, bw) = sluice::pipe::pipe(); + let (br, aw) = sluice::pipe::pipe(); + let left = Duplex::new(ar, aw); + let right = Duplex::new(br, bw); + + let mut leftlp = Uint24LELengthPrefixedFraming::new(left); + let mut rightlp = Uint24LELengthPrefixedFraming::new(right); + + // NB sluice has a max "chunk" thing of 4 + // so we limit the data we're sending to 3 things + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle"]; + // NB this sluice pipe + // + for d in data { + rightlp.feed(d.to_vec()).await.unwrap(); + } + let rflush = spawn(async move { + rightlp.flush().await.unwrap(); + rightlp + }); + + let mut result1 = vec![]; + for _ in data { + result1.push(leftlp.next().await.unwrap().unwrap()); + } + let mut rightlp = rflush.await?; + + assert_eq!(result1, data); + + for d in data { + leftlp.feed(d.to_vec()).await.unwrap(); + } + let lflush = spawn(async move { + leftlp.flush().await.unwrap(); + leftlp + }); + + let mut result2 = vec![]; + for _ in data { + result2.push(rightlp.next().await.unwrap().unwrap()); + } + let mut leftlp = lflush.await?; + assert_eq!(result2, data); + + let mut r3 = vec![]; + let mut r4 = vec![]; + + for d in data { + rightlp.send(d.to_vec()).await.unwrap(); + leftlp.send(d.to_vec()).await.unwrap(); + } + + for _ in data { + r3.push(rightlp.next().await.unwrap().unwrap()); + r4.push(leftlp.next().await.unwrap().unwrap()); + } + + assert_eq!(r3, data); + assert_eq!(r4, data); + Ok(()) } } diff --git a/src/lib.rs b/src/lib.rs index 9fca95e..999aa26 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,11 +126,15 @@ mod framing; mod message; mod mqueue; mod noise; +#[cfg(not(feature = "protocol"))] +mod oldmessage; mod protocol; +#[cfg(not(feature = "protocol"))] mod reader; #[cfg(test)] mod test_utils; mod util; +#[cfg(not(feature = "protocol"))] mod writer; /// The wire messages used by the protocol. diff --git a/src/message.rs b/src/message.rs index 90028aa..6d5200c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -13,7 +13,6 @@ const UINT24_HEADER_LEN: usize = 3; /// The type of a data frame. #[derive(Debug, Clone, PartialEq)] pub(crate) enum FrameType { - Raw, Message, } @@ -203,32 +202,6 @@ impl Frame { /// Decodes a frame from a buffer containing multiple concurrent messages. pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { match frame_type { - FrameType::Raw => { - let mut index = 0; - let mut raw_batch: Vec> = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - raw_batch.push( - buf[index + header_len..index + header_len + body_len as usize] - .to_vec(), - ); - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in raw batch", - )); - } - } - Ok(Frame::RawBatch(raw_batch)) - } FrameType::Message => { let mut index = 0; let mut combined_messages: Vec = vec![]; @@ -277,7 +250,6 @@ impl Frame { /// Decode a frame from a buffer. pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { match frame_type { - FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])), FrameType::Message => { let (frame, _) = Self::decode_message(buf)?; Ok(frame) diff --git a/src/mqueue.rs b/src/mqueue.rs index b968937..39b4c9b 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -12,7 +12,7 @@ use tracing::{debug, error, info, instrument, trace}; use crate::{ encrypted_framed_message_channel, - message::{decode_channel_messages, ChannelMessage, Encoder as _}, + message::{decode_many_channel_messages, ChannelMessage, Encoder as _}, }; pub(crate) struct MessageIo { @@ -22,16 +22,6 @@ pub(crate) struct MessageIo { use crate::{framing::Uint24LELengthPrefixedFraming, noise::Encrypted}; -pub(crate) fn encrypted_and_framed( - is_initiator: bool, - io: BytesTxRx, -) -> MessageIo>> { - let io = encrypted_framed_message_channel(is_initiator, io); - MessageIo { - io, - write_queue: Default::default(), - } -} impl>> + Sink> + Send + Unpin + 'static> MessageIo { pub(crate) fn new(io: IO) -> Self { Self { @@ -101,7 +91,7 @@ impl>> + Sink> + Send + Unpin + 'static ) -> Poll>> { match Pin::new(&mut self.io).poll_next(cx) { Poll::Ready(Some(Ok(encoded))) => { - match decode_channel_messages(&encoded) { + match decode_many_channel_messages(&encoded) { Ok((messsages, n_read)) => { assert_eq!(n_read, encoded.len()); // I think this is always true Poll::Ready(Ok(messsages)) @@ -142,16 +132,27 @@ impl>> + Sink> + Send + Unpin + 'static mod test { use std::io::Result; - use futures::future::{join, select}; + use futures::{future::select, AsyncRead, AsyncWrite}; use futures_lite::StreamExt; use crate::{ - framing::test::duplex, - message::{decode_channel_messages, ChannelMessage, Encoder as _}, - mqueue::encrypted_and_framed, - schema::{NoData, Open}, - test_utils::log, + encrypted_framed_message_channel, framing::test::duplex, message::ChannelMessage, + schema::NoData, test_utils::log, Encrypted, Uint24LELengthPrefixedFraming, }; + + use super::MessageIo; + pub(crate) fn encrypted_and_framed< + BytesTxRx: AsyncRead + AsyncWrite + Send + Unpin + 'static, + >( + is_initiator: bool, + io: BytesTxRx, + ) -> MessageIo>> { + let io = encrypted_framed_message_channel(is_initiator, io); + MessageIo { + io, + write_queue: Default::default(), + } + } fn new_msg(channel: u64) -> ChannelMessage { ChannelMessage { channel, @@ -162,47 +163,32 @@ mod test { #[tokio::test] async fn mqueue() -> Result<()> { log(); - let m = vec![new_msg(0)]; - let mut buf = vec![0; m.encoded_len()?]; - dbg!(&buf.len()); - dbg!(); - m.encode(&mut buf)?; - dbg!(&buf); - - let res = dbg!(decode_channel_messages(&buf))?; - assert_eq!(vec![new_msg(42402)], res.0); - dbg!(&buf); - - Ok(()) - - /* - let (left, right) = duplex(1024 * 64); - let mut left = encrypted_and_framed(true, left); - let mut right = encrypted_and_framed(false, right); - left.enqueue(new_msg(42)); - right.enqueue(new_msg(38)); - match select(left.next(), right.next()).await { - futures::future::Either::Left(ll) => { - println!( - "left - - ooooooooooooooooooooo - - " - ); - } - futures::future::Either::Right(rr) => { - println!( - "rightllllllllllllllll - - ------------------------- - - " - ); - } + let rtolm = new_msg(38); + let ltorm = new_msg(42); + + let (left, right) = duplex(1024 * 64); + let mut left = encrypted_and_framed(true, left); + let mut right = encrypted_and_framed(false, right); + left.enqueue(ltorm.clone()); + right.enqueue(rtolm.clone()); + + match select(left.next(), right.next()).await { + futures::future::Either::Left((m, _)) => { + if let Some(Ok(res)) = m { + assert_eq!(res, vec![rtolm]); + } else { + panic!(); } - Ok(()) - */ + } + futures::future::Either::Right((m, _)) => { + if let Some(Ok(res)) = m { + assert_eq!(res, vec![ltorm]); + } else { + panic!(); + } + } + } + Ok(()) } } diff --git a/src/oldmessage.rs b/src/oldmessage.rs new file mode 100644 index 0000000..8cb2c61 --- /dev/null +++ b/src/oldmessage.rs @@ -0,0 +1,814 @@ +use crate::schema::*; +use crate::util::{stat_uint24_le, write_uint24_le}; +use hypercore::encoding::{ + CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State, +}; +use pretty_hash::fmt as pretty_fmt; +use std::fmt; +use std::io; + +/// The type of a data frame. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum FrameType { + Raw, + Message, +} + +/// Encode data into a buffer. +/// +/// This trait is implemented on data frames and their components +/// (channel messages, messages, and individual message types through prost). +pub(crate) trait Encoder: Sized + fmt::Debug { + /// Calculates the length that the encoded message needs. + fn encoded_len(&mut self) -> Result; + + /// Encodes the message to a buffer. + /// + /// An error will be returned if the buffer does not have sufficient capacity. + fn encode(&mut self, buf: &mut [u8]) -> Result; +} + +impl Encoder for &[u8] { + fn encoded_len(&mut self) -> Result { + Ok(self.len()) + } + + fn encode(&mut self, buf: &mut [u8]) -> Result { + let len = self.encoded_len()?; + if len > buf.len() { + return Err(EncodingError::new( + EncodingErrorKind::Overflow, + &format!("Length does not fit buffer, {} > {}", len, buf.len()), + )); + } + buf[..len].copy_from_slice(&self[..]); + Ok(len) + } +} + +/// A frame of data, either a buffer or a message. +#[derive(Clone, PartialEq)] +pub(crate) enum Frame { + /// A raw batch binary buffer. Used in the handshaking phase. + RawBatch(Vec>), + /// Message batch, containing one or more channel messsages. Used for everything after the handshake. + MessageBatch(Vec), +} + +impl fmt::Debug for Frame { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), + Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), + } + } +} + +impl From for Frame { + fn from(m: ChannelMessage) -> Self { + Self::MessageBatch(vec![m]) + } +} + +impl From> for Frame { + fn from(m: Vec) -> Self { + Self::RawBatch(vec![m]) + } +} + +impl Frame { + /// Decodes a frame from a buffer containing multiple concurrent messages. + pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { + match frame_type { + FrameType::Raw => { + let mut index = 0; + let mut raw_batch: Vec> = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + raw_batch.push( + buf[index + header_len..index + header_len + body_len as usize] + .to_vec(), + ); + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in raw batch", + )); + } + } + Ok(Frame::RawBatch(raw_batch)) + } + FrameType::Message => { + let mut index = 0; + let mut combined_messages: Vec = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + let (frame, length) = Self::decode_message( + &buf[index + header_len..index + header_len + body_len as usize], + )?; + if length != body_len as usize { + tracing::warn!( + "Did not know what to do with all the bytes, got {} but decoded {}. \ + This may be because the peer implements a newer protocol version \ + that has extra fields.", + body_len, + length + ); + } + if let Frame::MessageBatch(messages) = frame { + for message in messages { + combined_messages.push(message); + } + } else { + unreachable!("Can not get Raw messages"); + } + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in multi-message chunk", + )); + } + } + Ok(Frame::MessageBatch(combined_messages)) + } + } + } + + /// Decode a frame from a buffer. + pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { + match frame_type { + FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])), + FrameType::Message => { + let (frame, _) = Self::decode_message(buf)?; + Ok(frame) + } + } + } + + fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { + println!("decode_message {buf:02X?}"); + // buffer length >= 3 or more and starts with 0 is message batch + if buf.len() >= 3 && buf[0] == 0x00 { + if buf[1] == 0x00 { + // Batch of messages + let mut messages: Vec = vec![]; + let mut state = State::new_with_start_and_end(2, buf.len()); + + // First, there is the original channel + let mut current_channel: u64 = state.decode(buf)?; + while state.start() < state.end() { + // Length of the message is inbetween here + let channel_message_length: usize = state.decode(buf)?; + if state.start() + channel_message_length > state.end() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length, {} + {} > {}", + state.start(), + channel_message_length, + state.end() + ), + )); + } + // Then the actual message + let (channel_message, _) = ChannelMessage::decode( + &buf[state.start()..state.start() + channel_message_length], + current_channel, + )?; + messages.push(channel_message); + state.add_start(channel_message_length)?; + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if state.start() < state.end() && buf[state.start()] == 0x00 { + state.add_start(1)?; + current_channel = state.decode(buf)?; + } + } + Ok((Frame::MessageBatch(messages), state.start())) + } else if buf[1] == 0x01 { + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else if buf[1] == 0x03 { + // Close message + let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid special message", + )) + } + } else if buf.len() >= 2 { + // len >= and + // Single message + let mut state = State::from_buffer(buf); + let channel: u64 = state.decode(buf)?; + let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; + Ok(( + Frame::MessageBatch(vec![channel_message]), + state.start() + length, + )) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:02X?}"), + )) + } + } + + fn preencode(&mut self, state: &mut State) -> Result { + match self { + Self::RawBatch(raw_batch) => { + for raw in raw_batch { + state.add_end(raw.as_slice().encoded_len()?)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(messages) => { + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else { + (*state).preencode(&messages[0].channel)?; + state.add_end(messages[0].encoded_len()?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.add_end(2)?; + let mut current_channel: u64 = messages[0].channel; + state.preencode(¤t_channel)?; + for message in messages.iter_mut() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.add_end(1)?; + state.preencode(&message.channel)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.preencode(&message_length)?; + state.add_end(message_length)?; + } + } + } + } + Ok(state.end()) + } +} + +impl Encoder for Frame { + fn encoded_len(&mut self) -> Result { + let body_len = self.preencode(&mut State::new())?; + match self { + Self::RawBatch(_) => Ok(body_len), + Self::MessageBatch(_) => Ok(3 + body_len), + } + } + + fn encode(&mut self, buf: &mut [u8]) -> Result { + let mut state = State::new(); + let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; + let body_len = self.preencode(&mut state)?; + let len = body_len + header_len; + if buf.len() < len { + return Err(EncodingError::new( + EncodingErrorKind::Overflow, + &format!("Length does not fit buffer, {} > {}", len, buf.len()), + )); + } + match self { + Self::RawBatch(ref raw_batch) => { + for raw in raw_batch { + raw.as_slice().encode(buf)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(ref mut messages) => { + write_uint24_le(body_len, buf); + let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(1_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(3_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else { + state.encode(&messages[0].channel, buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + let mut current_channel: u64 = messages[0].channel; + state.encode(¤t_channel, buf)?; + for message in messages.iter_mut() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.encode(&(0_u8), buf)?; + state.encode(&message.channel, buf)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.encode(&message_length, buf)?; + state.add_start(message.encode(&mut buf[state.start()..])?)?; + } + } + } + }; + Ok(len) + } +} + +/// A protocol message. +#[derive(Debug, Clone, PartialEq)] +#[allow(missing_docs)] +pub enum Message { + Open(Open), + Close(Close), + Synchronize(Synchronize), + Request(Request), + Cancel(Cancel), + Data(Data), + NoData(NoData), + Want(Want), + Unwant(Unwant), + Bitfield(Bitfield), + Range(Range), + Extension(Extension), + /// A local signalling message never sent over the wire + LocalSignal((String, Vec)), +} + +impl Message { + /// Wire type of this message. + pub(crate) fn typ(&self) -> u64 { + match self { + Self::Synchronize(_) => 0, + Self::Request(_) => 1, + Self::Cancel(_) => 2, + Self::Data(_) => 3, + Self::NoData(_) => 4, + Self::Want(_) => 5, + Self::Unwant(_) => 6, + Self::Bitfield(_) => 7, + Self::Range(_) => 8, + Self::Extension(_) => 9, + value => unimplemented!("{} does not have a type", value), + } + } + + /// Decode a message from a buffer based on type. + pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> { + let mut state = HypercoreState::from_buffer(buf); + let message = match typ { + 0 => Ok(Self::Synchronize((*state).decode(buf)?)), + 1 => Ok(Self::Request(state.decode(buf)?)), + 2 => Ok(Self::Cancel((*state).decode(buf)?)), + 3 => Ok(Self::Data(state.decode(buf)?)), + 4 => Ok(Self::NoData((*state).decode(buf)?)), + 5 => Ok(Self::Want((*state).decode(buf)?)), + 6 => Ok(Self::Unwant((*state).decode(buf)?)), + 7 => Ok(Self::Bitfield((*state).decode(buf)?)), + 8 => Ok(Self::Range((*state).decode(buf)?)), + 9 => Ok(Self::Extension((*state).decode(buf)?)), + _ => Err(EncodingError::new( + EncodingErrorKind::InvalidData, + &format!("Invalid message type to decode: {typ}"), + )), + }?; + Ok((message, state.start())) + } + + /// Pre-encodes a message to state, returns length + pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result { + match self { + Self::Open(ref message) => state.0.preencode(message)?, + Self::Close(ref message) => state.0.preencode(message)?, + Self::Synchronize(ref message) => state.0.preencode(message)?, + Self::Request(ref message) => state.preencode(message)?, + Self::Cancel(ref message) => state.0.preencode(message)?, + Self::Data(ref message) => state.preencode(message)?, + Self::NoData(ref message) => state.0.preencode(message)?, + Self::Want(ref message) => state.0.preencode(message)?, + Self::Unwant(ref message) => state.0.preencode(message)?, + Self::Bitfield(ref message) => state.0.preencode(message)?, + Self::Range(ref message) => state.0.preencode(message)?, + Self::Extension(ref message) => state.0.preencode(message)?, + Self::LocalSignal(_) => 0, + }; + Ok(state.end()) + } + + /// Encodes a message to a given buffer, using preencoded state, results size + pub(crate) fn encode( + &self, + state: &mut HypercoreState, + buf: &mut [u8], + ) -> Result { + match self { + Self::Open(ref message) => state.0.encode(message, buf)?, + Self::Close(ref message) => state.0.encode(message, buf)?, + Self::Synchronize(ref message) => state.0.encode(message, buf)?, + Self::Request(ref message) => state.encode(message, buf)?, + Self::Cancel(ref message) => state.0.encode(message, buf)?, + Self::Data(ref message) => state.encode(message, buf)?, + Self::NoData(ref message) => state.0.encode(message, buf)?, + Self::Want(ref message) => state.0.encode(message, buf)?, + Self::Unwant(ref message) => state.0.encode(message, buf)?, + Self::Bitfield(ref message) => state.0.encode(message, buf)?, + Self::Range(ref message) => state.0.encode(message, buf)?, + Self::Extension(ref message) => state.0.encode(message, buf)?, + Self::LocalSignal(_) => 0, + }; + Ok(state.start()) + } +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Open(msg) => write!( + f, + "Open(discovery_key: {}, capability <{}>)", + pretty_fmt(&msg.discovery_key).unwrap(), + msg.capability.as_ref().map_or(0, |c| c.len()) + ), + Self::Data(msg) => write!( + f, + "Data(request: {}, fork: {}, block: {}, hash: {}, seek: {}, upgrade: {})", + msg.request, + msg.fork, + msg.block.is_some(), + msg.hash.is_some(), + msg.seek.is_some(), + msg.upgrade.is_some(), + ), + _ => write!(f, "{:?}", &self), + } + } +} + +/// A message on a channel. +#[derive(Clone)] +pub(crate) struct ChannelMessage { + pub(crate) channel: u64, + pub(crate) message: Message, + state: Option, +} + +impl PartialEq for ChannelMessage { + fn eq(&self, other: &Self) -> bool { + self.channel == other.channel && self.message == other.message + } +} + +impl fmt::Debug for ChannelMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ChannelMessage({}, {})", self.channel, self.message) + } +} + +impl ChannelMessage { + /// Create a new message. + pub(crate) fn new(channel: u64, message: Message) -> Self { + Self { + channel, + message, + state: None, + } + } + + /// Consume self and return (channel, Message). + pub(crate) fn into_split(self) -> (u64, Message) { + (self.channel, self.message) + } + + /// Decodes an open message for a channel message from a buffer. + /// + /// Note: `buf` has to have a valid length, and without the 3 LE + /// bytes in it + pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { + if buf.len() <= 5 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "received too short Open message", + )); + } + + let mut state = State::new_with_start_and_end(0, buf.len()); + let open_msg: Open = state.decode(buf)?; + Ok(( + Self { + channel: open_msg.channel, + message: Message::Open(open_msg), + state: None, + }, + state.start(), + )) + } + + /// Decodes a close message for a channel message from a buffer. + /// + /// Note: `buf` has to have a valid length, and without the 3 LE + /// bytes in it + pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { + if buf.is_empty() { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "received too short Close message", + )); + } + let mut state = State::new_with_start_and_end(0, buf.len()); + let close_msg: Close = state.decode(buf)?; + Ok(( + Self { + channel: close_msg.channel, + message: Message::Close(close_msg), + state: None, + }, + state.start(), + )) + } + + /// Decode a normal channel message from a buffer. + /// + /// Note: `buf` has to have a valid length, and without the 3 LE + /// bytes in it + pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> { + if buf.len() <= 1 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "received empty message", + )); + } + let mut state = State::from_buffer(buf); + let typ: u64 = state.decode(buf)?; + let (message, length) = Message::decode(&buf[state.start()..], typ)?; + Ok(( + Self { + channel, + message, + state: None, + }, + state.start() + length, + )) + } + + /// Performance optimization for letting calling encoded_len() already do + /// the preencode phase of compact_encoding. + fn prepare_state(&mut self) -> Result<(), EncodingError> { + if self.state.is_none() { + let state = if let Message::Open(_) = self.message { + // Open message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else if let Message::Close(_) = self.message { + // Close message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else { + // The header is the channel id uint followed by message type uint + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 + let mut state = HypercoreState::new(); + let typ = self.message.typ(); + (*state).preencode(&typ)?; + self.message.preencode(&mut state)?; + state + }; + self.state = Some(state); + } + Ok(()) + } +} + +impl Encoder for ChannelMessage { + fn encoded_len(&mut self) -> Result { + self.prepare_state()?; + Ok(self.state.as_ref().unwrap().end()) + } + + fn encode(&mut self, buf: &mut [u8]) -> Result { + self.prepare_state()?; + let state = self.state.as_mut().unwrap(); + if let Message::Open(_) = self.message { + // Open message is different in that the type byte is missing + self.message.encode(state, buf)?; + } else if let Message::Close(_) = self.message { + // Close message is different in that the type byte is missing + self.message.encode(state, buf)?; + } else { + let typ = self.message.typ(); + state.0.encode(&typ, buf)?; + self.message.encode(state, buf)?; + } + Ok(state.start()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use hypercore::{ + DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, + }; + + macro_rules! message_enc_dec { + ($( $msg:expr ),*) => { + $( + let channel = rand::random::() as u64; + let mut channel_message = ChannelMessage::new(channel, $msg); + let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); + let mut buf = vec![0u8; encoded_len]; + let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); + let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split(); + assert_eq!(channel, decoded.0); + assert_eq!($msg, decoded.1); + )* + } + } + #[test] + fn frame_encode_decode() -> std::io::Result<()> { + let msg = Message::Synchronize(Synchronize { + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }); + + let channel = rand::random::() as u64; + let channel_message = ChannelMessage::new(channel, msg); + + let mut frame = Frame::from(channel_message); + let mut buf = vec![0; frame.encoded_len()?]; + frame.encode(&mut buf)?; + let res_frame = Frame::decode_multiple(&buf, &FrameType::Message)?; + assert_eq!(res_frame, frame); + Ok(()) + } + #[test] + fn frame_encode_decode_bar() -> std::io::Result<()> { + let msg = Message::Synchronize(Synchronize { + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }); + + //let channel = rand::random::() as u64; + let channel = 42; + let channel_message = ChannelMessage::new(channel, msg); + + let mut frame = Frame::from(channel_message.clone()); + + let mut fbuf = vec![0; frame.encoded_len()?]; + + frame.encode(&mut fbuf)?; + + let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; + assert_eq!(fres, frame); + ///assert_eq!(cres, cmvec); + //println!("REG frame buf\t{frame_buf:02X?}"); + //let res_frame = Frame::decode(&frame_buf, &FrameType::Message)?; + //dbg!(res_frame); + //let res_frame = Frame::decode_multiple(&frame_buf, &FrameType::Message)?; + //dbg!(res_frame); + + //let mut vec_frame_buf = vec![0; vec_frame.encoded_len()?]; + //vec_frame.encode(&mut vec_frame_buf)?; + + //assert_eq!(vec_frame_buf, frame_buf); + //println!("VEC frame buf\t{vec_frame_buf:02X?}"); + + //let res_frame = Frame::decode(&vec_frame_buf, &FrameType::Message)?; + //dbg!(res_frame); + //let res_frame = Frame::decode_multiple(&vec_frame_buf, &FrameType::Message)?; + //dbg!(&res_frame); + + //let (msg, _len) = decode_channel_messages(&vec_frame_buf)?; + //assert_eq!(msg, vec![channel_message]); + + //assert_eq!(res_frame, frame); + Ok(()) + } + + #[test] + fn message_encode_decode() { + message_enc_dec! { + Message::Synchronize(Synchronize{ + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }), + Message::Request(Request { + id: 1, + fork: 1, + block: Some(RequestBlock { + index: 5, + nodes: 10, + }), + hash: Some(RequestBlock { + index: 20, + nodes: 0 + }), + seek: Some(RequestSeek { + bytes: 10 + }), + upgrade: Some(RequestUpgrade { + start: 0, + length: 10 + }) + }), + Message::Cancel(Cancel { + request: 1, + }), + Message::Data(Data{ + request: 1, + fork: 5, + block: Some(DataBlock { + index: 5, + nodes: vec![Node::new(1, vec![0x01; 32], 100)], + value: vec![0xFF; 10] + }), + hash: Some(DataHash { + index: 20, + nodes: vec![Node::new(2, vec![0x02; 32], 200)], + }), + seek: Some(DataSeek { + bytes: 10, + nodes: vec![Node::new(3, vec![0x03; 32], 300)], + }), + upgrade: Some(DataUpgrade { + start: 0, + length: 10, + nodes: vec![Node::new(4, vec![0x04; 32], 400)], + additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)], + signature: vec![0xAB; 32] + }) + }), + Message::NoData(NoData { + request: 2, + }), + Message::Want(Want { + start: 0, + length: 100, + }), + Message::Unwant(Unwant { + start: 10, + length: 2, + }), + Message::Bitfield(Bitfield { + start: 20, + bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF], + }), + Message::Range(Range { + drop: true, + start: 12345, + length: 100000 + }), + Message::Extension(Extension { + name: "custom_extension/v1/open".to_string(), + message: vec![0x44, 20] + }) + }; + } +} diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index cb71f50..acbdc05 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -13,8 +13,8 @@ use tracing::instrument; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::{EncryptCipher, Handshake, HandshakeResult}; -use crate::message::{ChannelMessage, Frame, Message}; +use crate::crypto::HandshakeResult; +use crate::message::{ChannelMessage, Message}; use crate::mqueue::MessageIo; use crate::util::{map_channel_err, pretty_hash}; use crate::{ @@ -30,7 +30,6 @@ macro_rules! return_error { } const CHANNEL_CAP: usize = 1000; -const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); /// Options for a Protocol instance. #[derive(Debug)] @@ -112,32 +111,9 @@ impl fmt::Debug for Event { } } -/// Protocol state -#[allow(clippy::large_enum_variant)] -pub(crate) enum State { - NotInitialized, - // The Handshake struct sits behind an option only so that we can .take() - // it out, it's never actually empty when in State::Handshake. - Handshake(Option), - SecretStream(Option), - Established, -} - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::NotInitialized => write!(f, "NotInitialized"), - State::Handshake(_) => write!(f, "Handshaking"), - State::SecretStream(_) => write!(f, "SecretStream"), - State::Established => write!(f, "Established"), - } - } -} - /// A Protocol stream. pub struct Protocol { io: MessageIo>>, - state: State, options: Options, handshake: Option, channels: ChannelMap, @@ -153,7 +129,6 @@ impl std::fmt::Debug for Protocol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Protocol") //.field("io", &self.io) - .field("state", &self.state) .field("options", &self.options) .field("handshake", &self.handshake) .field("channels", &self.channels) @@ -182,7 +157,6 @@ where Protocol { io: MessageIo::new(encrypted_framed_message_channel(options.is_initiator, io)), options, - state: State::NotInitialized, channels: ChannelMap::new(), handshake: None, command_rx, @@ -272,6 +246,7 @@ where if let Some(event) = this.queued_events.pop_front() { Poll::Ready(Ok(event)) } else { + cx.waker().wake_by_ref(); Poll::Pending } } @@ -287,6 +262,7 @@ where /// TODO Poll the keepalive timer and queue a ping message if needed. fn poll_keepalive(&mut self, _cx: &mut Context<'_>) { /* + const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); if Pin::new(&mut self.keepalive).poll(cx).is_ready() { if let State::Established = self.state { // 24 bit header for the empty message, hence the 3 @@ -295,7 +271,7 @@ where } self.keepalive.reset(KEEPALIVE_DURATION); } - */ + */ } // just handles Close and LocalSignal?? @@ -322,6 +298,7 @@ where } /// Poll for inbound messages and processs them. + #[instrument(skip_all)] fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { match self.io.poll_inbound(cx) { @@ -336,6 +313,7 @@ where /// Poll for outbound messages and write them. /// Reads messages from Self::outbound and sends them over io + #[instrument(skip_all)] fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { // if no parking or setup in progress @@ -365,6 +343,7 @@ where Ok(()) } + #[instrument(skip_all)] fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { // let channel_message = ChannelMessage::decode(buf)?; let (remote_id, message) = channel_message.into_split(); @@ -387,6 +366,7 @@ where } /// Open a Channel with the given key. Adding it to our channel map + #[instrument(skip_all)] fn command_open(&mut self, key: Key) -> Result<()> { // Create a new channel. let channel_handle = self.channels.attach_local(key); diff --git a/src/protocol/old.rs b/src/protocol/old.rs index 01af713..2c7d4c5 100644 --- a/src/protocol/old.rs +++ b/src/protocol/old.rs @@ -10,7 +10,7 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::trace; +use tracing::{instrument, trace}; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; @@ -252,6 +252,7 @@ where self.io } + #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -357,6 +358,7 @@ where } /// Poll for inbound messages and processs them. + #[instrument(skip_all)] fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { let msg = self.read_state.poll_reader(cx, &mut self.io); @@ -371,6 +373,7 @@ where } /// Poll for outbound messages and write them. + #[instrument(skip_all)] fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { @@ -398,6 +401,7 @@ where } } + #[instrument(skip_all)] fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { match frame { Frame::RawBatch(raw_batch) => { @@ -539,6 +543,7 @@ where } /// Open a Channel with the given key. Adding it to our channel map + #[instrument(skip_all)] fn command_open(&mut self, key: Key) -> Result<()> { // Create a new channel. let channel_handle = self.channels.attach_local(key); diff --git a/src/test_utils.rs b/src/test_utils.rs index b12f440..e67d756 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -73,7 +73,7 @@ impl TwoWay { } } -pub fn log() { +pub(crate) fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); START_LOGS.get_or_init(|| { diff --git a/src/util.rs b/src/util.rs index 1350728..21e4c75 100644 --- a/src/util.rs +++ b/src/util.rs @@ -29,7 +29,33 @@ pub(crate) fn map_channel_err(err: async_channel::SendError) -> Error { } pub(crate) const UINT_24_LENGTH: usize = 3; +#[cfg(feature = "uint24")] +mod uint24 { + use super::UINT_24_LENGTH; + pub struct Uint24LE([u8; UINT_24_LENGTH]); + impl Uint24LE { + pub const MAX_USIZE: usize = 16777215; + pub const SIZE: usize = UINT_24_LENGTH; + } + + impl AsRef<[u8; 3]> for Uint24LE { + fn as_ref(&self) -> &[u8; 3] { + &self.0 + } + } + // TODO we are using std::io::Error everywhere so I won't add a new one but this isn't ideal + impl TryFrom for Uint24LE { + type Error = Error; + + fn try_from(n: usize) -> Result { + if n > Self::MAX_USIZE { + todo!() + } + Ok(Self([(n & 255) as u8, (n >> 8) as u8, (n >> 16) as u8])) + } + } +} #[inline] pub(crate) fn wrap_uint24_le(data: &[u8]) -> Vec { let mut buf: Vec = vec![0; 3]; diff --git a/tests/_util.rs b/tests/_util.rs index d1fa197..aec496d 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -4,6 +4,7 @@ use futures_lite::StreamExt; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; use instant::Duration; use std::io; +use tokio::io::DuplexStream; use tokio::task::JoinHandle; #[allow(unused)] @@ -23,7 +24,16 @@ pub(crate) fn log() { }); } +type TokioDuplex = tokio_util::compat::Compat; + +pub(crate) fn duplex(channel_size: usize) -> (TokioDuplex, TokioDuplex) { + use tokio_util::compat::TokioAsyncReadCompatExt as _; + let (left, right) = tokio::io::duplex(channel_size); + (left.compat(), right.compat()) +} + pub type MemoryProtocol = Protocol>; + pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol)> { let (ar, bw) = sluice::pipe::pipe(); let (br, aw) = sluice::pipe::pipe(); @@ -35,6 +45,15 @@ pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol) Ok((a, b)) } +pub async fn create_pair_memory2() -> io::Result<(Protocol, Protocol)> { + let (left, right) = duplex(1024 * 1024); + let a = ProtocolBuilder::new(true); + let b = ProtocolBuilder::new(false); + let a = a.connect(left); + let b = b.connect(right); + Ok((a, b)) +} + pub fn next_event(mut proto: Protocol) -> JoinHandle<(Protocol, io::Result)> where IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, diff --git a/tests/basic.rs b/tests/basic.rs index 92e4c8b..a102bc0 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,5 +1,6 @@ use _util::{ - create_pair_memory, drive_until_channel, event_channel, event_discovery_key, next_event, + create_pair_memory, create_pair_memory2, drive_until_channel, event_channel, + event_discovery_key, next_event, }; use futures_lite::StreamExt; use hypercore_protocol::{discovery_key, Event, Message}; @@ -170,7 +171,6 @@ async fn open_close_channels() -> anyhow::Result<()> { assert_eq!(msg_b, Some(want(0, 10))); eprintln!("all good!"); - Ok(()) } From fae480439eabab3d849946b8d24eeb57fac668de Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 25 Mar 2025 12:31:37 -0400 Subject: [PATCH 053/135] doc comments --- src/reader.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/reader.rs b/src/reader.rs index 5664d56..cc80c5c 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -116,6 +116,7 @@ impl ReadState { // What happens if decrypt fails here? // next call to this func would have same start, corret? // so it'd fail repeatedly? + // Why not just decrypt to the end? for (index, header_len, body_len) in segments { let de = cipher.decrypt( &mut self.buf[self.start + index..end], @@ -144,10 +145,13 @@ impl ReadState { fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) { let (last_index, last_header_len, last_body_len) = last_segment; let total_incoming_length = last_index + last_header_len + last_body_len; + if self.buf.len() < total_incoming_length { // The incoming segments will not fit into the buffer, need to resize it self.buf.resize(total_incoming_length, 0u8); } + + // to-read length let temp = self.buf[self.start..].to_vec(); let len = temp.len(); self.buf[..len].copy_from_slice(&temp[..]); @@ -187,17 +191,21 @@ impl ReadState { } } + // one message within an encrypted frame + // encrypted frame [ u24 header + encoded_frame [ ]] Step::Body { header_len, body_len, } => { let message_len = header_len + body_len; let range = self.start + header_len..self.start + message_len; + // this includes a a frame header let frame = Frame::decode(&self.buf[range], &self.frame_type); self.start += message_len; self.step = Step::Header; return Some(frame); } + // multiple message within an encrypted frame Step::Batch => { let frame = Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type); @@ -211,7 +219,9 @@ impl ReadState { } #[allow(clippy::type_complexity)] -// get segments from buff +/// Given a buff get all the segments (starting_index_in_buffer, header_len, buffer_len) +/// returns returns `(true, segments)` if we read all segments, but (false, ..) if there +/// are remaining segments fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> { let mut index: usize = 0; let len = buf.len(); From b019e9f143e8c084dc3cb33743b3687d99964cc2 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 25 Mar 2025 12:43:51 -0400 Subject: [PATCH 054/135] feature gate oldmessage --- Cargo.toml | 2 ++ src/lib.rs | 3 +-- src/message/mod.rs | 11 +++++++++++ src/{message.rs => message/modern.rs} | 0 src/{oldmessage.rs => message/old.rs} | 2 +- 5 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 src/message/mod.rs rename src/{message.rs => message/modern.rs} (100%) rename src/{oldmessage.rs => message/old.rs} (99%) diff --git a/Cargo.toml b/Cargo.toml index 7c15c8b..7aeb9e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,8 @@ tokio-util = { version = "0.7.14", features = ["compat"] } [features] default = ["tokio", "sparse", "protocol"] +#default = ["tokio", "sparse"] +uint24 = [] protocol = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" diff --git a/src/lib.rs b/src/lib.rs index 999aa26..1990b97 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,10 +124,9 @@ mod crypto; mod duplex; mod framing; mod message; +#[cfg(feature = "protocol")] mod mqueue; mod noise; -#[cfg(not(feature = "protocol"))] -mod oldmessage; mod protocol; #[cfg(not(feature = "protocol"))] mod reader; diff --git a/src/message/mod.rs b/src/message/mod.rs new file mode 100644 index 0000000..1526f3a --- /dev/null +++ b/src/message/mod.rs @@ -0,0 +1,11 @@ +#[cfg(feature = "protocol")] +mod modern; + +#[cfg(feature = "protocol")] +pub use modern::*; + +#[cfg(not(feature = "protocol"))] +mod old; + +#[cfg(not(feature = "protocol"))] +pub use old::*; diff --git a/src/message.rs b/src/message/modern.rs similarity index 100% rename from src/message.rs rename to src/message/modern.rs diff --git a/src/oldmessage.rs b/src/message/old.rs similarity index 99% rename from src/oldmessage.rs rename to src/message/old.rs index 8cb2c61..373eea2 100644 --- a/src/oldmessage.rs +++ b/src/message/old.rs @@ -703,7 +703,7 @@ mod tests { let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; assert_eq!(fres, frame); - ///assert_eq!(cres, cmvec); + //assert_eq!(cres, cmvec); //println!("REG frame buf\t{frame_buf:02X?}"); //let res_frame = Frame::decode(&frame_buf, &FrameType::Message)?; //dbg!(res_frame); From 24fe4a1f2fbf2e9373f413c29760a6e3f8471490 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 25 Mar 2025 13:19:07 -0400 Subject: [PATCH 055/135] fix all warnings --- src/constants.rs | 12 +- src/crypto/cipher.rs | 168 +++++++++++++------------- src/crypto/handshake.rs | 5 +- src/crypto/mod.rs | 4 + src/message/modern.rs | 257 +++++++++++++++++++--------------------- src/mqueue.rs | 13 +- src/protocol/modern.rs | 5 - src/protocol/old.rs | 5 - src/writer.rs | 4 - tests/basic.rs | 14 +-- 10 files changed, 232 insertions(+), 255 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 77285ee..73d0748 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,15 +1,17 @@ /// Seed for the discovery key hash pub(crate) const DISCOVERY_NS_BUF: &[u8] = b"hypercore"; -/// Default timeout (in seconds) -pub(crate) const DEFAULT_TIMEOUT: u32 = 20; - /// Default keepalive interval (in seconds) pub(crate) const DEFAULT_KEEPALIVE: u32 = 10; +/// v10: Protocol name +pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; + // 16,78MB is the max encrypted wire message size (will be much smaller usually). // This limitation stems from the 24bit header. +#[cfg(not(feature = "protocol"))] pub(crate) const MAX_MESSAGE_SIZE: u64 = 0xFFFFFF; -/// v10: Protocol name -pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; +/// Default timeout (in seconds) +#[cfg(not(feature = "protocol"))] +pub(crate) const DEFAULT_TIMEOUT: u32 = 20; diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index f2dc9b9..8ef6c9e 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -1,5 +1,4 @@ use super::HandshakeResult; -use crate::util::{stat_uint24_le, write_uint24_le, UINT_24_LENGTH}; use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, @@ -11,28 +10,17 @@ use std::io; const STREAM_ID_LENGTH: usize = 32; const KEY_LENGTH: usize = 32; -const HEADER_MSG_LEN: usize = UINT_24_LENGTH + STREAM_ID_LENGTH + Header::BYTES; pub(crate) struct DecryptCipher { pull_stream: PullStream, } -pub(crate) struct EncryptCipher { - push_stream: PushStream, -} - impl std::fmt::Debug for DecryptCipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "DecryptCipher(crypto_secretstream)") } } -impl std::fmt::Debug for EncryptCipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "EncryptCipher(crypto_secretstream)") - } -} - impl DecryptCipher { pub(crate) fn from_handshake_rx_and_init_msg( handshake_result: &HandshakeResult, @@ -75,25 +63,6 @@ impl DecryptCipher { let pull_stream = PullStream::init(Header::from(header), &key); Ok(Self { pull_stream }) } - - pub(crate) fn decrypt( - &mut self, - buf: &mut [u8], - header_len: usize, - body_len: usize, - ) -> io::Result { - let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; - let decrypted_len = to_decrypt.len(); - write_uint24_le(decrypted_len, buf); - let decrypted_end = header_len + to_decrypt.len(); - buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); - // Set extra bytes in the buffer to 0 - // Why? - let encrypted_end = header_len + body_len; - buf[decrypted_end..encrypted_end].fill(0x00); - Ok(decrypted_end) - } - pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> io::Result<(Vec, Tag)> { let mut to_decrypt = buf.to_vec(); let tag = &self.pull_stream.pull(&mut to_decrypt, &[]).map_err(|err| { @@ -103,63 +72,102 @@ impl DecryptCipher { } } -impl EncryptCipher { - pub(crate) fn from_handshake_tx( - handshake_result: &HandshakeResult, - ) -> std::io::Result<(Self, Vec)> { - let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] - .try_into() - .expect("split_tx with incorrect length"); - let key = Key::from(key); +#[cfg(not(feature = "protocol"))] +mod encrypt_cipher { + use super::*; + use crate::util::{stat_uint24_le, write_uint24_le, UINT_24_LENGTH}; + const HEADER_MSG_LEN: usize = UINT_24_LENGTH + STREAM_ID_LENGTH + Header::BYTES; - let mut header_message: [u8; HEADER_MSG_LEN] = [0; HEADER_MSG_LEN]; - write_uint24_le(STREAM_ID_LENGTH + Header::BYTES, &mut header_message); - write_stream_id( - &handshake_result.handshake_hash, - handshake_result.is_initiator, - &mut header_message[UINT_24_LENGTH..UINT_24_LENGTH + STREAM_ID_LENGTH], - ); + pub(crate) struct EncryptCipher { + push_stream: PushStream, + } - let (header, push_stream) = PushStream::init(OsRng, &key); - let header = header.as_ref(); - header_message[UINT_24_LENGTH + STREAM_ID_LENGTH..].copy_from_slice(header); - let msg = header_message.to_vec(); - Ok((Self { push_stream }, msg)) + impl std::fmt::Debug for EncryptCipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "EncryptCipher(crypto_secretstream)") + } } - /// Get the length needed for encryption, that includes padding. - pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { - // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe - // extra room. - // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ - plaintext_len + 2 * 15 + impl EncryptCipher { + pub(crate) fn from_handshake_tx( + handshake_result: &HandshakeResult, + ) -> std::io::Result<(Self, Vec)> { + let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] + .try_into() + .expect("split_tx with incorrect length"); + let key = Key::from(key); + + let mut header_message: [u8; HEADER_MSG_LEN] = [0; HEADER_MSG_LEN]; + write_uint24_le(STREAM_ID_LENGTH + Header::BYTES, &mut header_message); + write_stream_id( + &handshake_result.handshake_hash, + handshake_result.is_initiator, + &mut header_message[UINT_24_LENGTH..UINT_24_LENGTH + STREAM_ID_LENGTH], + ); + + let (header, push_stream) = PushStream::init(OsRng, &key); + let header = header.as_ref(); + header_message[UINT_24_LENGTH + STREAM_ID_LENGTH..].copy_from_slice(header); + let msg = header_message.to_vec(); + Ok((Self { push_stream }, msg)) + } + + /// Get the length needed for encryption, that includes padding. + pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { + // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe + // extra room. + // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ + plaintext_len + 2 * 15 + } + + /// Encrypts message in the given buffer to the same buffer, returns number of bytes + /// of total message. + /// NB: we expect the first 3 bytes of the buffer to a size header. + /// The encrypted buffer will also be written prepended with a size header, with it's new size. + pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { + let stat = stat_uint24_le(buf); + if let Some((header_len, body_len)) = stat { + let mut to_encrypt = buf[header_len..header_len + body_len as usize].to_vec(); + self.push_stream + .push(&mut to_encrypt, &[], Tag::Message) + .map_err(|err| { + io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) + })?; + let encrypted_len = to_encrypt.len(); + write_uint24_le(encrypted_len, buf); + buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); + Ok(header_len + encrypted_len) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Could not encrypt invalid data, len: {}", buf.len()), + )) + } + } } - /// Encrypts message in the given buffer to the same buffer, returns number of bytes - /// of total message. - /// NB: we expect the first 3 bytes of the buffer to a size header. - /// The encrypted buffer will also be written prepended with a size header, with it's new size. - pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { - let stat = stat_uint24_le(buf); - if let Some((header_len, body_len)) = stat { - let mut to_encrypt = buf[header_len..header_len + body_len as usize].to_vec(); - self.push_stream - .push(&mut to_encrypt, &[], Tag::Message) - .map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) - })?; - let encrypted_len = to_encrypt.len(); - write_uint24_le(encrypted_len, buf); - buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); - Ok(header_len + encrypted_len) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Could not encrypt invalid data, len: {}", buf.len()), - )) + impl DecryptCipher { + pub(crate) fn decrypt( + &mut self, + buf: &mut [u8], + header_len: usize, + body_len: usize, + ) -> io::Result { + let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; + let decrypted_len = to_decrypt.len(); + write_uint24_le(decrypted_len, buf); + let decrypted_end = header_len + to_decrypt.len(); + buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); + // Set extra bytes in the buffer to 0 + // Why? + let encrypted_end = header_len + body_len; + buf[decrypted_end..encrypted_end].fill(0x00); + Ok(decrypted_end) } } } +#[cfg(not(feature = "protocol"))] +pub use encrypt_cipher::*; // NB: These values come from Javascript-side // @@ -197,7 +205,7 @@ pub(crate) struct RawEncryptCipher { impl std::fmt::Debug for RawEncryptCipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "EncryptCipher(crypto_secretstream)") + write!(f, "RawEncryptCipher(crypto_secretstream)") } } diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 74a1ada..fc5ad02 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -1,5 +1,4 @@ use super::curve::CurveResolver; -use crate::util::wrap_uint24_le; use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, @@ -111,8 +110,9 @@ impl Handshake { Ok(None) } } + #[cfg(not(feature = "protocol"))] pub(crate) fn start(&mut self) -> Result>> { - Ok(self.start_raw()?.map(|x| wrap_uint24_le(&x))) + Ok(self.start_raw()?.map(|x| crate::util::wrap_uint24_le(&x))) } pub(crate) fn complete(&self) -> bool { @@ -177,6 +177,7 @@ impl Handshake { Ok(tx_buf) } // reads in `msg` without framing bytes, but emits msg WITH framing bytes + #[cfg(not(feature = "protocol"))] pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { Ok(self.read_raw(msg)?.map(|x| wrap_uint24_le(&x))) } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 27f12b4..3de592a 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,5 +1,9 @@ mod cipher; mod curve; mod handshake; +#[cfg(not(feature = "protocol"))] pub(crate) use cipher::{DecryptCipher, EncryptCipher, RawEncryptCipher}; + +#[cfg(feature = "protocol")] +pub(crate) use cipher::{DecryptCipher, RawEncryptCipher}; pub(crate) use handshake::{Handshake, HandshakeResult}; diff --git a/src/message/modern.rs b/src/message/modern.rs index 6d5200c..8b16988 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -10,12 +10,6 @@ use tracing::instrument; const UINT24_HEADER_LEN: usize = 3; -/// The type of a data frame. -#[derive(Debug, Clone, PartialEq)] -pub(crate) enum FrameType { - Message, -} - /// Encode data into a buffer. /// /// This trait is implemented on data frames and their components @@ -200,136 +194,6 @@ pub(crate) fn decode_one_channel_message( impl Frame { /// Decodes a frame from a buffer containing multiple concurrent messages. - pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Message => { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - let (frame, length) = Self::decode_message( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - tracing::warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, - length - ); - } - if let Frame::MessageBatch(messages) = frame { - for message in messages { - combined_messages.push(message); - } - } else { - unreachable!("Can not get Raw messages"); - } - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } - } - Ok(Frame::MessageBatch(combined_messages)) - } - } - } - - /// Decode a frame from a buffer. - pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Message => { - let (frame, _) = Self::decode_message(buf)?; - Ok(frame) - } - } - } - - fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - println!("decode_message {buf:02X?}"); - // buffer length >= 3 or more and starts with 0 is message batch - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } - } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) - } - } else if buf.len() >= 2 { - // len >= and - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), - )) - } - } - fn preencode(&self, state: &mut State) -> Result { match self { Self::RawBatch(raw_batch) => { @@ -989,6 +853,125 @@ mod tests { }), ] } + impl Frame { + pub(crate) fn decode_multiple(buf: &[u8]) -> Result { + let mut index = 0; + let mut combined_messages: Vec = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + let (frame, length) = Self::decode_message( + &buf[index + header_len..index + header_len + body_len as usize], + )?; + if length != body_len as usize { + tracing::warn!( + "Did not know what to do with all the bytes, got {} but decoded {}. \ + This may be because the peer implements a newer protocol version \ + that has extra fields.", + body_len, + length + ); + } + if let Frame::MessageBatch(messages) = frame { + for message in messages { + combined_messages.push(message); + } + } else { + unreachable!("Can not get Raw messages"); + } + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in multi-message chunk", + )); + } + } + Ok(Frame::MessageBatch(combined_messages)) + } + + fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { + println!("decode_message {buf:02X?}"); + // buffer length >= 3 or more and starts with 0 is message batch + if buf.len() >= 3 && buf[0] == 0x00 { + if buf[1] == 0x00 { + // Batch of messages + let mut messages: Vec = vec![]; + let mut state = State::new_with_start_and_end(2, buf.len()); + + // First, there is the original channel + let mut current_channel: u64 = state.decode(buf)?; + while state.start() < state.end() { + // Length of the message is inbetween here + let channel_message_length: usize = state.decode(buf)?; + if state.start() + channel_message_length > state.end() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length, {} + {} > {}", + state.start(), + channel_message_length, + state.end() + ), + )); + } + // Then the actual message + let (channel_message, _) = ChannelMessage::decode( + &buf[state.start()..state.start() + channel_message_length], + current_channel, + )?; + messages.push(channel_message); + state.add_start(channel_message_length)?; + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if state.start() < state.end() && buf[state.start()] == 0x00 { + state.add_start(1)?; + current_channel = state.decode(buf)?; + } + } + Ok((Frame::MessageBatch(messages), state.start())) + } else if buf[1] == 0x01 { + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else if buf[1] == 0x03 { + // Close message + let (channel_message, length) = + ChannelMessage::decode_close_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid special message", + )) + } + } else if buf.len() >= 2 { + // len >= and + // Single message + let mut state = State::from_buffer(buf); + let channel: u64 = state.decode(buf)?; + let (channel_message, length) = + ChannelMessage::decode(&buf[state.start()..], channel)?; + Ok(( + Frame::MessageBatch(vec![channel_message]), + state.start() + length, + )) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:02X?}"), + )) + } + } + } #[test] fn compare_with_frame_encoding_decoding() -> std::io::Result<()> { @@ -1008,7 +991,7 @@ mod tests { assert_eq!(cbuf, fbuf); - let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; + let fres = Frame::decode_multiple(&fbuf)?; assert_eq!(fres, frame); let cres_m = decode_many_channel_messages(&cbuf)?.0; assert_eq!(cres_m, cmvec); diff --git a/src/mqueue.rs b/src/mqueue.rs index 39b4c9b..c802d34 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -7,21 +7,16 @@ use std::{ task::{Context, Poll}, }; -use futures::{AsyncRead, AsyncWrite, Sink, Stream}; -use tracing::{debug, error, info, instrument, trace}; +use futures::{Sink, Stream}; +use tracing::{debug, error, instrument, trace}; -use crate::{ - encrypted_framed_message_channel, - message::{decode_many_channel_messages, ChannelMessage, Encoder as _}, -}; +use crate::message::{decode_many_channel_messages, ChannelMessage, Encoder as _}; pub(crate) struct MessageIo { io: IO, write_queue: VecDeque, } -use crate::{framing::Uint24LELengthPrefixedFraming, noise::Encrypted}; - impl>> + Sink> + Send + Unpin + 'static> MessageIo { pub(crate) fn new(io: IO) -> Self { Self { @@ -68,7 +63,7 @@ impl>> + Sink> + Send + Unpin + 'static } Poll::Ready(Err(_e)) => { error!("Error flushing"); - return todo!(); + todo!() } Poll::Pending => { cx.waker().wake_by_ref(); diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index acbdc05..f9bfa80 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -216,11 +216,6 @@ where self.channels.iter().map(|c| c.discovery_key()) } - /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> MessageIo>> { - self.io - } - #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); diff --git a/src/protocol/old.rs b/src/protocol/old.rs index 2c7d4c5..b5f44ec 100644 --- a/src/protocol/old.rs +++ b/src/protocol/old.rs @@ -247,11 +247,6 @@ where self.channels.iter().map(|c| c.discovery_key()) } - /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> IO { - self.io - } - #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); diff --git a/src/writer.rs b/src/writer.rs index d91adfb..df89949 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -9,10 +9,6 @@ use std::pin::Pin; use std::task::{Context, Poll}; const BUF_SIZE: usize = 1024 * 64; -// This is the largest size that will fit in u24. -// a message is larger than this we should error. -// also check message is smaller than this when we are encrypting. -const _MAX_MSG_SIZE: usize = 2usize.pow(24) - 1; #[derive(Debug)] pub(crate) enum Step { diff --git a/tests/basic.rs b/tests/basic.rs index a102bc0..5730dbc 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -12,22 +12,20 @@ mod _util; #[tokio::test] async fn basic_protocol() -> anyhow::Result<()> { - let (proto_a, proto_b) = create_pair_memory().await?; + _util::log(); + let (proto_a, proto_b) = create_pair_memory2().await?; - dbg!(); let next_a = next_event(proto_a); - dbg!(); let next_b = next_event(proto_b); - dbg!(); - let (proto_b, event_b) = next_b.await?; - dbg!(); let (mut proto_a, event_a) = next_a.await?; + let (proto_b, event_b) = next_b.await?; + //let (a, b) = join(next_a, next_b).await; - dbg!(); //let (mut proto_a, event_a) = a?; - dbg!(); //let (proto_b, event_b) = b?; + dbg!(&event_a); + dbg!(&event_b); assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); From b23bead8bdb5aaa974724b5513bce68afcfd2805 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 1 Apr 2025 16:24:05 -0400 Subject: [PATCH 056/135] Add tracing::instrument --- src/channels.rs | 13 +++++++------ src/crypto/handshake.rs | 4 +++- src/message/modern.rs | 2 +- src/noise.rs | 3 ++- src/protocol/modern.rs | 8 +++++++- src/protocol/old.rs | 12 ++++++++++-- src/writer.rs | 2 ++ 7 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/channels.rs b/src/channels.rs index c2e22f8..8e82116 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -13,7 +13,7 @@ use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::task::Poll; -use tracing::debug; +use tracing::instrument; /// A protocol channel. /// @@ -93,7 +93,6 @@ impl Channel { "Channel is closed", )); } - debug!("TX:\n{message:?}\n"); let message = ChannelMessage::new(self.local_id as u64, message); self.outbound_tx .send(vec![message]) @@ -122,10 +121,7 @@ impl Channel { let messages = messages .iter() - .map(|message| { - debug!("TX:\n{message:?}\n"); - ChannelMessage::new(self.local_id as u64, message.clone()) - }) + .map(|message| ChannelMessage::new(self.local_id as u64, message.clone())) .collect(); self.outbound_tx .send(messages) @@ -249,6 +245,7 @@ impl ChannelHandle { self.remote_state.as_ref().map(|s| s.remote_id) } + #[instrument(skip_all, fields(local_id = local_id))] pub(crate) fn attach_local(&mut self, local_id: usize, key: Key) { let local_state = LocalState { local_id, key }; self.local_state = Some(local_state); @@ -271,11 +268,13 @@ impl ChannelHandle { return Err(error("Channel is not opened from both local and remote")); } // Safe because of the is_connected() check above. + dbg!(&self.local_state, &self.remote_state); let local_state = self.local_state.as_ref().unwrap(); let remote_state = self.remote_state.as_ref().unwrap(); Ok((&local_state.key, remote_state.remote_capability.as_ref())) } + #[instrument(skip_all)] pub(crate) fn open(&mut self, outbound_tx: Sender>) -> Channel { let local_state = self .local_state @@ -433,6 +432,7 @@ impl ChannelMap { self.channels.remove(&hdkey); } + #[instrument(skip(self))] pub(crate) fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec>)> { let channel_handle = self .get_local(local_id) @@ -477,6 +477,7 @@ impl ChannelMap { Ok(()) } + #[instrument(skip_all)] fn alloc_local(&mut self) -> usize { let empty_id = self .local_id diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index fc5ad02..8659f09 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -33,6 +33,7 @@ pub(crate) struct HandshakeResult { } impl HandshakeResult { + #[instrument(skip_all)] pub(crate) fn capability(&self, key: &[u8]) -> Option> { Some(replicate_capability( self.is_initiator, @@ -49,6 +50,7 @@ impl HandshakeResult { )) } + #[instrument(skip_all)] pub(crate) fn verify_remote_capability( &self, capability: Option>, @@ -179,7 +181,7 @@ impl Handshake { // reads in `msg` without framing bytes, but emits msg WITH framing bytes #[cfg(not(feature = "protocol"))] pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { - Ok(self.read_raw(msg)?.map(|x| wrap_uint24_le(&x))) + Ok(self.read_raw(msg)?.map(|x| crate::util::wrap_uint24_le(&x))) } pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { diff --git a/src/message/modern.rs b/src/message/modern.rs index 8b16988..9d7346d 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -353,7 +353,7 @@ impl Encoder for Vec { Ok(prencode_channel_messages(self, &mut state)? + UINT24_HEADER_LEN) } - #[instrument] + #[instrument(skip_all)] fn encode(&self, buf: &mut [u8]) -> Result { let mut state = State::new(); let body_len = prencode_channel_messages(self, &mut state)?; diff --git a/src/noise.rs b/src/noise.rs index d2a3a1d..c1fa5d6 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -148,7 +148,7 @@ impl>> + Sink> + Send + Unpin + 'static { type Item = Result>; - #[instrument(skip_all, fields(initiator = %self.is_initiator))] + #[instrument(skip(cx), fields(initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -357,6 +357,7 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } +#[instrument(skip_all)] fn reset_encrypted( step: &mut Step, maybe_init_message: Option>, diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index f9bfa80..6098ff7 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -216,7 +216,7 @@ where self.channels.iter().map(|c| c.discovery_key()) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(initiator = ?self.is_initiator()))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -352,6 +352,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_command(&mut self, command: Command) -> Result<()> { match command { Command::Open(key) => self.command_open(key), @@ -402,6 +403,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; let channel_handle = @@ -418,10 +420,12 @@ where Ok(()) } + #[instrument(skip(self))] fn queue_event(&mut self, event: Event) { self.queued_events.push_back(event); } + #[instrument(skip(self))] fn accept_channel(&mut self, local_id: usize) -> Result<()> { let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; self.verify_remote_capability(remote_capability.cloned(), key)?; @@ -451,6 +455,7 @@ where Ok(()) } + #[instrument(skip_all)] fn capability(&self, key: &[u8]) -> Option> { match self.handshake.as_ref() { Some(handshake) => handshake.capability(key), @@ -458,6 +463,7 @@ where } } + #[instrument(skip_all)] fn verify_remote_capability(&self, capability: Option>, key: &[u8]) -> Result<()> { match self.handshake.as_ref() { Some(handshake) => handshake.verify_remote_capability(capability, key), diff --git a/src/protocol/old.rs b/src/protocol/old.rs index b5f44ec..20c9064 100644 --- a/src/protocol/old.rs +++ b/src/protocol/old.rs @@ -230,6 +230,7 @@ where } /// Give a command to the protocol. + #[instrument(skip(self))] pub async fn command(&mut self, command: Command) -> Result<()> { self.command_tx.send(command).await } @@ -238,6 +239,7 @@ where /// /// Once the other side proofed that it also knows the `key`, the channel is emitted as /// `Event::Channel` on the protocol event stream. + #[instrument(skip(self))] pub async fn open(&mut self, key: Key) -> Result<()> { self.command_tx.open(key).await } @@ -247,7 +249,7 @@ where self.channels.iter().map(|c| c.discovery_key()) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(initiator = ?self.is_initiator()))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -311,6 +313,7 @@ where } /// Poll commands. + #[instrument(skip_all)] fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { self.on_command(command)?; @@ -515,7 +518,7 @@ where self.state = State::Established; Ok(()) } - + #[instrument(skip_all)] fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { // let channel_message = ChannelMessage::decode(buf)?; let (remote_id, message) = channel_message.into_split(); @@ -529,6 +532,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_command(&mut self, command: Command) -> Result<()> { match command { Command::Open(key) => self.command_open(key), @@ -580,6 +584,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; let channel_handle = @@ -596,6 +601,7 @@ where Ok(()) } + #[instrument(skip(self))] fn queue_event(&mut self, event: Event) { self.queued_events.push_back(event); } @@ -607,6 +613,7 @@ where .try_encode_and_enqueue_frame_for_tx(&mut frame) } + #[instrument(skip(self))] fn accept_channel(&mut self, local_id: usize) -> Result<()> { let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; self.verify_remote_capability(remote_capability.cloned(), key)?; @@ -636,6 +643,7 @@ where Ok(()) } + #[instrument(skip_all)] fn capability(&self, key: &[u8]) -> Option> { match self.handshake.as_ref() { Some(handshake) => handshake.capability(key), diff --git a/src/writer.rs b/src/writer.rs index df89949..9a1465b 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -1,5 +1,6 @@ use crate::crypto::EncryptCipher; use crate::message::{Encoder, Frame}; +use tracing::instrument; use futures_lite::{ready, AsyncWrite}; use std::collections::VecDeque; @@ -61,6 +62,7 @@ impl WriteState { self.queue.push_back(frame.into()) } + #[instrument(skip(self))] pub(crate) fn try_encode_and_enqueue_frame_for_tx( &mut self, frame: &mut T, From 16029ee471e2c31315a170d568ce76f1b12fdfcd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 1 Apr 2025 16:34:30 -0400 Subject: [PATCH 057/135] rm dbg --- src/schema.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index bf35416..08e5221 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -18,9 +18,10 @@ pub struct Open { impl CompactEncoding for State { fn preencode(&mut self, value: &Open) -> Result { - dbg!(self.preencode(&value.channel)?); - dbg!(self.preencode(&value.protocol)?); - dbg!(self.preencode(&value.discovery_key)?); + let start = self.end(); + self.preencode(&value.channel)?; + self.preencode(&value.protocol)?; + self.preencode(&value.discovery_key)?; if value.capability.is_some() { self.add_end(1)?; // flags for future use self.preencode_fixed_32()?; @@ -29,7 +30,6 @@ impl CompactEncoding for State { } fn encode(&mut self, value: &Open, buffer: &mut [u8]) -> Result { - dbg!(); self.encode(&value.channel, buffer)?; self.encode(&value.protocol, buffer)?; self.encode(&value.discovery_key, buffer)?; @@ -370,7 +370,7 @@ pub struct NoData { impl CompactEncoding for State { fn preencode(&mut self, value: &NoData) -> Result { - dbg!(self.preencode(dbg!(&value.request))) + self.preencode(&value.request) } fn encode(&mut self, value: &NoData, buffer: &mut [u8]) -> Result { From 5bf98ed4e524c95be68b1cdf87071eda52d958bb Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:44:24 -0400 Subject: [PATCH 058/135] rm unused --- src/channels.rs | 1 - src/crypto/handshake.rs | 2 +- src/framing.rs | 2 +- src/message/modern.rs | 6 +----- src/mqueue.rs | 1 - src/schema.rs | 1 - 6 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/channels.rs b/src/channels.rs index 8e82116..1b94ece 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -268,7 +268,6 @@ impl ChannelHandle { return Err(error("Channel is not opened from both local and remote")); } // Safe because of the is_connected() check above. - dbg!(&self.local_state, &self.remote_state); let local_state = self.local_state.as_ref().unwrap(); let remote_state = self.remote_state.as_ref().unwrap(); Ok((&local_state.key, remote_state.remote_capability.as_ref())) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 8659f09..0094edb 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -237,7 +237,7 @@ fn map_err(e: SnowError) -> Error { } /// Create a hash used to indicate replication capability. -/// See https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L11 +/// See JavaScript [here](https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L11). fn replicate_capability(is_initiator: bool, key: &[u8], handshake_hash: &[u8]) -> Vec { let seed = if is_initiator { REPLICATE_INITIATOR diff --git a/src/framing.rs b/src/framing.rs index 8b8ae8f..a51cea9 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -40,7 +40,7 @@ impl Uint24LELengthPrefixedFraming where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, { - /// Build [`LengthPrefixed`] around an [`AsyncWrite`]/[`AsyncRead`] thing. + /// Build [`Uint24LELengthPrefixedFraming`] around an [`AsyncWrite`]/[`AsyncRead`] thing. pub fn new(io: IO) -> Self { Self { io, diff --git a/src/message/modern.rs b/src/message/modern.rs index 9d7346d..11c1491 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -125,7 +125,6 @@ pub(crate) fn decode_one_channel_message( if buf.len() >= 3 && buf[0] == 0x00 { if buf[1] == 0x00 { // Batch of messages - dbg!(); let mut messages: Vec = vec![]; let mut state = State::new_with_start_and_end(2, buf.len()); @@ -162,12 +161,10 @@ pub(crate) fn decode_one_channel_message( } Ok((messages, state.start())) } else if buf[1] == 0x01 { - dbg!(); // Open message let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; Ok((vec![channel_message], length + 2)) } else if buf[1] == 0x03 { - dbg!(); // Close message let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; Ok((vec![channel_message], length + 2)) @@ -178,7 +175,6 @@ pub(crate) fn decode_one_channel_message( )) } } else if buf.len() >= 2 { - dbg!(); // Single message let mut state = State::from_buffer(buf); let channel: u64 = state.decode(buf)?; @@ -316,7 +312,7 @@ fn prencode_channel_messages( std::cmp::Ordering::Equal => { if let Message::Open(_) = &messages[0].message { // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + dbg!(&messages[0].encoded_len()?))?; + state.add_end(2 + &messages[0].encoded_len()?)?; } else if let Message::Close(_) = &messages[0].message { // This is a special case with 0x00, 0x03 intro bytes state.add_end(2 + &messages[0].encoded_len()?)?; diff --git a/src/mqueue.rs b/src/mqueue.rs index c802d34..ea99f42 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -44,7 +44,6 @@ impl>> + Sink> + Send + Unpin + 'static } let mut buf = vec![0; messages.encoded_len()?]; - dbg!(&buf); match messages.encode(&mut buf) { Ok(_) => {} Err(e) => { diff --git a/src/schema.rs b/src/schema.rs index 08e5221..ef58e77 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -18,7 +18,6 @@ pub struct Open { impl CompactEncoding for State { fn preencode(&mut self, value: &Open) -> Result { - let start = self.end(); self.preencode(&value.channel)?; self.preencode(&value.protocol)?; self.preencode(&value.discovery_key)?; From d3bfd07b89d28c56940a11f4d2fc165f39fe8e85 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:50:24 -0400 Subject: [PATCH 059/135] rm unused --- src/crypto/cipher.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 8ef6c9e..53c291f 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -167,7 +167,7 @@ mod encrypt_cipher { } } #[cfg(not(feature = "protocol"))] -pub use encrypt_cipher::*; +pub(crate) use encrypt_cipher::*; // NB: These values come from Javascript-side // From 094caa83c6b5ae26dbe2116e6d36d43b0a6bf07d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:51:09 -0400 Subject: [PATCH 060/135] pub HandshakeResult --- src/crypto/handshake.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 0094edb..72c9da3 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -23,7 +23,7 @@ const REPLICATE_RESPONDER: [u8; 32] = [ ]; #[derive(Debug, Clone, Default)] -pub(crate) struct HandshakeResult { +pub struct HandshakeResult { pub(crate) is_initiator: bool, pub(crate) local_pubkey: Vec, pub(crate) remote_pubkey: Vec, From 7a8b2dadb99529e64fbe5fcf9ea8d7bf4252e5c5 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:52:19 -0400 Subject: [PATCH 061/135] custom debug for Framig --- src/framing.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/framing.rs b/src/framing.rs index a51cea9..7daef38 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -33,7 +33,14 @@ pub struct Uint24LELengthPrefixedFraming { } impl Debug for Uint24LELengthPrefixedFraming { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Format()") + f.debug_struct("Framer") + //.field("io", &self.io) + .field("to_stream.len()", &self.to_stream.len()) + .field("from_sink", &self.from_sink.len()) + .field("last_out_idx", &self.last_out_idx) + .field("last_data_idx", &self.last_data_idx) + .field("step", &self.step) + .finish() } } impl Uint24LELengthPrefixedFraming From a7fac6deab1fab5378e195dc48bb9d72ad2888f7 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:52:47 -0400 Subject: [PATCH 062/135] feature gates --- src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 1990b97..c13ccae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -142,8 +142,9 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; pub use framing::Uint24LELengthPrefixedFraming; -pub use noise::{encrypted_framed_message_channel, Encrypted}; +pub use noise::{encrypted_framed_message_channel, Encrypted, Event as NoiseEvent}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() +#[cfg(feature = "protocol")] pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, }; From a189b047f6d7bb6a0b43e9becab7153cf1fe7990 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:54:35 -0400 Subject: [PATCH 063/135] rm println --- src/message/old.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/message/old.rs b/src/message/old.rs index 373eea2..d4afd64 100644 --- a/src/message/old.rs +++ b/src/message/old.rs @@ -163,7 +163,6 @@ impl Frame { } fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - println!("decode_message {buf:02X?}"); // buffer length >= 3 or more and starts with 0 is message batch if buf.len() >= 3 && buf[0] == 0x00 { if buf[1] == 0x00 { From 9db631bf5a25171908e3869a1f03aad1621ea582 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:56:24 -0400 Subject: [PATCH 064/135] rm print --- src/message/modern.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index 11c1491..2d8e732 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -894,7 +894,6 @@ mod tests { } fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - println!("decode_message {buf:02X?}"); // buffer length >= 3 or more and starts with 0 is message batch if buf.len() >= 3 && buf[0] == 0x00 { if buf[1] == 0x00 { From a7325077854fb1319000c981f844bccc9908c580 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:18:18 -0400 Subject: [PATCH 065/135] refactor mqueue to pass through non byte messages --- src/mqueue.rs | 126 ++++++++++++++++++++++++++++---------------------- 1 file changed, 72 insertions(+), 54 deletions(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index ea99f42..7e1fffd 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -2,22 +2,60 @@ use std::{ collections::VecDeque, + fmt::Debug, io::Result, pin::Pin, task::{Context, Poll}, }; use futures::{Sink, Stream}; -use tracing::{debug, error, instrument, trace}; +use tracing::{error, instrument}; -use crate::message::{decode_many_channel_messages, ChannelMessage, Encoder as _}; +use crate::{ + message::{decode_many_channel_messages, ChannelMessage, Encoder as _}, + noise::EncryptionInfo, + NoiseEvent, +}; + +#[derive(Debug)] +pub(crate) enum MqueueEvent { + Meta(EncryptionInfo), + Message(Result>), +} + +impl From for MqueueEvent { + fn from(e: NoiseEvent) -> Self { + match e { + NoiseEvent::Meta(einf) => Self::Meta(einf), + NoiseEvent::Decrypted(dec_res) => { + match dec_res { + Ok(encoded) => match decode_many_channel_messages(&encoded) { + //assert_eq!(_n_read, encoded.len()); } + Ok((messsages, _n_read)) => Self::Message(Ok(messsages)), + Err(e) => Self::Message(Err(e)), + }, + Err(e) => Self::Message(Err(e)), + } + } + } + } +} pub(crate) struct MessageIo { io: IO, write_queue: VecDeque, } -impl>> + Sink> + Send + Unpin + 'static> MessageIo { +impl Debug for MessageIo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MessageIo") + //.field("io", &self.io) + .field("write_queue", &self.write_queue) + .finish() + } +} + +impl + Sink> + Send + Unpin + 'static> MessageIo { pub(crate) fn new(io: IO) -> Self { Self { io, @@ -48,26 +86,24 @@ impl>> + Sink> + Send + Unpin + 'static Ok(_) => {} Err(e) => { error!(error = ?e, "error encoding messages"); + // TODO this would probably be a programming error. + // if so, this sholud just be an unwrap/expect return Poll::Ready(Err(e.into())); } } if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf) { - error!("error in start_send"); todo!() } match Sink::poll_flush(Pin::new(&mut self.io), cx) { - Poll::Ready(Ok(())) => { - debug!("flushed"); - } Poll::Ready(Err(_e)) => { - error!("Error flushing"); todo!() } Poll::Pending => { cx.waker().wake_by_ref(); return Poll::Pending; } + _ => {} } } @@ -79,46 +115,21 @@ impl>> + Sink> + Send + Unpin + 'static } } - pub(crate) fn poll_inbound( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> { - match Pin::new(&mut self.io).poll_next(cx) { - Poll::Ready(Some(Ok(encoded))) => { - match decode_many_channel_messages(&encoded) { - Ok((messsages, n_read)) => { - assert_eq!(n_read, encoded.len()); // I think this is always true - Poll::Ready(Ok(messsages)) - } - Err(_) => todo!(), - } - } - Poll::Ready(Some(Err(_e))) => todo!(), - Poll::Ready(None) => todo!(), - Poll::Pending => Poll::Pending, - } + pub(crate) fn poll_inbound(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io) + .poll_next(cx) + .map(|opt| opt.map(MqueueEvent::from)) } } -impl>> + Sink> + Send + Unpin + 'static> Stream +impl + Sink> + Send + Unpin + 'static> Stream for MessageIo { - type Item = Result>; + type Item = MqueueEvent; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let out_res = self.poll_outbound(cx); - match out_res { - Poll::Ready(res) => match res { - Ok(okres) => trace!(res = ?okres, "MessageIo poll_outbound"), - Err(e) => error!(error = ?e, "MessageIo error in poll_outbound"), - }, - Poll::Pending => trace!("MessageIo poll_outbound Pending"), - } - - let in_res = self.poll_inbound(cx); - trace!(poll_inbound = ?in_res, "MessageIo"); - - in_res.map(Some) + let _ = self.poll_outbound(cx); + self.poll_inbound(cx) } } @@ -134,7 +145,7 @@ mod test { schema::NoData, test_utils::log, Encrypted, Uint24LELengthPrefixedFraming, }; - use super::MessageIo; + use super::{MessageIo, MqueueEvent}; pub(crate) fn encrypted_and_framed< BytesTxRx: AsyncRead + AsyncWrite + Send + Unpin + 'static, >( @@ -154,6 +165,13 @@ mod test { } } + fn take_messages(e: Option) -> Option> { + match e { + Some(MqueueEvent::Message(Result::Ok(out))) => Some(out), + _ => None, + } + } + #[tokio::test] async fn mqueue() -> Result<()> { log(); @@ -167,19 +185,19 @@ mod test { left.enqueue(ltorm.clone()); right.enqueue(rtolm.clone()); - match select(left.next(), right.next()).await { - futures::future::Either::Left((m, _)) => { - if let Some(Ok(res)) = m { - assert_eq!(res, vec![rtolm]); - } else { - panic!(); + loop { + match select(left.next(), right.next()).await { + futures::future::Either::Left((m, _)) => { + if let Some(m) = take_messages(m) { + assert_eq!(m, vec![rtolm]); + break; + } } - } - futures::future::Either::Right((m, _)) => { - if let Some(Ok(res)) = m { - assert_eq!(res, vec![ltorm]); - } else { - panic!(); + futures::future::Either::Right((m, _)) => { + if let Some(m) = take_messages(m) { + assert_eq!(m, vec![rtolm]); + break; + } } } } From c5888f569dff7d7b08385abd39d7d921c940abb1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:29:14 -0400 Subject: [PATCH 066/135] use tracing-tree for viewing logs in tests --- Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 7aeb9e3..4ceb1f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,8 @@ sluice = "0.5.4" futures = "0.3.13" log = "0.4" test-log = { version = "0.2.11", default-features = false, features = ["trace"] } -tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } +tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } +tracing-tree = "0.4.0" tokio-util = { version = "0.7.14", features = ["compat"] } [features] From d21d4a9972b80feb3ead499396d6d501a0d4ef22 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:31:05 -0400 Subject: [PATCH 067/135] use tracing-tree --- src/test_utils.rs | 26 +++++++++++++++++--------- tests/_util.rs | 26 +++++++++++++++++--------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index e67d756..3f687ea 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -74,17 +74,25 @@ impl TwoWay { } pub(crate) fn log() { - use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); START_LOGS.get_or_init(|| { - tracing_subscriber::fmt() - .with_target(true) - .with_line_number(true) - // print when instrumented funtion enters - .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) - .with_file(true) - .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable - .without_time() + use tracing_subscriber::{ + layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter, + }; + let env_filter = EnvFilter::from_default_env(); // Reads `RUST_LOG` environment variable + + // Create the hierarchical layer from tracing_tree + let tree_layer = tracing_tree::HierarchicalLayer::new(2) // 2 spaces per indent level + .with_targets(true) + .with_bracketed_fields(true) + .with_indent_lines(true) + .with_span_modes(true) + .with_thread_ids(false) + .with_thread_names(false); + + tracing_subscriber::registry() + .with(env_filter) + .with(tree_layer) .init(); }); } diff --git a/tests/_util.rs b/tests/_util.rs index aec496d..e2cc679 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -9,17 +9,25 @@ use tokio::task::JoinHandle; #[allow(unused)] pub(crate) fn log() { - use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); START_LOGS.get_or_init(|| { - tracing_subscriber::fmt() - .with_target(true) - .with_line_number(true) - // print when instrumented funtion enters - .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) - .with_file(true) - .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable - .without_time() + use tracing_subscriber::{ + layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter, + }; + let env_filter = EnvFilter::from_default_env(); // Reads `RUST_LOG` environment variable + + // Create the hierarchical layer from tracing_tree + let tree_layer = tracing_tree::HierarchicalLayer::new(2) // 2 spaces per indent level + .with_targets(true) + .with_bracketed_fields(true) + .with_indent_lines(true) + .with_span_modes(true) + .with_thread_ids(false) + .with_thread_names(false); + + tracing_subscriber::registry() + .with(env_filter) + .with(tree_layer) .init(); }); } From 3062f0781dbc54870f7096026ae1d62a67fea04d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:33:59 -0400 Subject: [PATCH 068/135] rm unused async --- tests/_util.rs | 4 ++-- tests/basic.rs | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/_util.rs b/tests/_util.rs index e2cc679..d15be38 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -42,7 +42,7 @@ pub(crate) fn duplex(channel_size: usize) -> (TokioDuplex, TokioDuplex) { pub type MemoryProtocol = Protocol>; -pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol)> { +pub fn create_pair_memory() -> (MemoryProtocol, MemoryProtocol) { let (ar, bw) = sluice::pipe::pipe(); let (br, aw) = sluice::pipe::pipe(); @@ -50,7 +50,7 @@ pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol) let b = ProtocolBuilder::new(false); let a = a.connect_rw(ar, aw); let b = b.connect_rw(br, bw); - Ok((a, b)) + (a, b) } pub async fn create_pair_memory2() -> io::Result<(Protocol, Protocol)> { diff --git a/tests/basic.rs b/tests/basic.rs index 5730dbc..280e5be 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -24,8 +24,6 @@ async fn basic_protocol() -> anyhow::Result<()> { //let (mut proto_a, event_a) = a?; //let (proto_b, event_b) = b?; - dbg!(&event_a); - dbg!(&event_b); assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); @@ -84,7 +82,7 @@ async fn basic_protocol() -> anyhow::Result<()> { #[tokio::test] async fn open_close_channels() -> anyhow::Result<()> { - let (mut proto_a, mut proto_b) = create_pair_memory().await?; + let (mut proto_a, mut proto_b) = create_pair_memory(); let key1 = [0u8; 32]; let key2 = [1u8; 32]; From 294184057d90ccb626b054c8082797f8d4cb5041 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:52:11 -0400 Subject: [PATCH 069/135] expose handshake result, refactor, fix deadlock --- src/noise.rs | 401 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 268 insertions(+), 133 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index c1fa5d6..40ce6ac 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -7,7 +7,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{debug, error, info, instrument, trace, warn}; +use tracing::{debug, error, instrument, trace, warn}; use crate::{ crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}, @@ -30,6 +30,53 @@ pub(crate) enum Step { SecretStream((RawEncryptCipher, HandshakeResult)), Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), } + +impl Step { + fn established(&self) -> bool { + matches!(self, Step::Established(_)) + } +} + +impl std::fmt::Display for Step { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Step::NotInitialized => "NotInitialized", + Step::Handshake(_) => "Handshake", + Step::SecretStream(_) => "SecretStream", + Step::Established(_) => "Established", + } + ) + } +} + +#[derive(Debug)] +/// Encryption related info +pub enum EncryptionInfo { + Handshake(HandshakeResult), +} +#[derive(Debug)] +/// Decrypted messages and encryption related events +pub enum Event { + /// Events related to the encryption stream + Meta(EncryptionInfo), + /// A decrypted message + Decrypted(Result>), +} + +impl From>> for Event { + fn from(value: Result>) -> Self { + Self::Decrypted(value) + } +} +impl From for Event { + fn from(value: HandshakeResult) -> Self { + Self::Meta(EncryptionInfo::Handshake(value)) + } +} + /// Wrap a stream with encryption pub struct Encrypted { io: IO, @@ -38,7 +85,7 @@ pub struct Encrypted { encrypted_tx: VecDeque>, encrypted_rx: VecDeque>>, plain_tx: VecDeque>, - plain_rx: VecDeque>>, + plain_rx: VecDeque, flush: bool, } @@ -62,7 +109,7 @@ where } /// Wether an encrypted connection has been established. pub fn encryption_established(&self) -> bool { - matches!(self.step, Step::Established(_)) + self.step.established() } } @@ -85,7 +132,7 @@ impl< #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { - info!(initiator = %self.is_initiator, "enqueue plain_tx\n{item:?}"); + trace!(initiator = %self.is_initiator, "enqueue plain_tx"); self.plain_tx.push_back(item); Ok(()) } @@ -95,6 +142,10 @@ impl< self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { + // The flow here can be understood as reading from the encrypted side moving those messages + // through to the plaintext side, then reading new plaintext messages and moving them to + // the encrypted side. + // We do this repeatedly until there's nothing else to do let Encrypted { io, step, @@ -107,30 +158,37 @@ impl< .. } = self.get_mut(); - poll_encrypted_side_io(io, cx, encrypted_tx, encrypted_rx, *is_initiator, flush); - - if let Step::Established((encryptor, decryptor, ..)) = step { - poll_do_encrypt_and_decrypt( - encryptor, - decryptor, + loop { + poll_message_throughput( + io, + cx, + step, encrypted_tx, encrypted_rx, - plain_tx, plain_rx, + plain_tx, *is_initiator, flush, ); + poll_outgoing_encrypted_messages(io, cx, encrypted_tx, *is_initiator, flush); - if *flush { - cx.waker().wake_by_ref(); - Poll::Pending - } else { - Poll::Ready(Ok(())) + // check if we've done all possible work + if did_as_much_as_possible( + io, + cx, + step, + encrypted_tx, + encrypted_rx, + plain_tx, + *is_initiator, + ) { + if !step.established() || !encrypted_tx.is_empty() || *flush { + trace!(not_established = !step.established(), tx_msgs_waiting = !encrypted_tx.is_empty(), flush = ?flush, "not done flushing"); + cx.waker().wake_by_ref(); + return Poll::Pending; + } + return Poll::Ready(Ok(())); } - } else { - poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator, flush); - cx.waker().wake_by_ref(); - Poll::Pending } } @@ -143,10 +201,32 @@ impl< } } +/// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. +fn did_as_much_as_possible< + IO: Stream>> + Sink> + Send + Unpin + 'static, +>( + io: &mut IO, + cx: &mut Context<'_>, + step: &mut Step, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, + plain_tx: &mut VecDeque>, + is_initiator: bool, +) -> bool { + // No incoming encrypted messages available. + poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator).is_pending() + // We're unable to send any anymore encrypted/setup messages either because we have none or the `Sink` is unavailable. + && (encrypted_tx.is_empty() || Sink::poll_ready(Pin::new(io), cx).is_pending()) + // No encrypted messages waiting to be decrypted. + && encrypted_rx.is_empty() + // No plaint text messages waiting to be enccrypted or we're still setting up + && (plain_tx.is_empty() || !step.established()) +} + impl>> + Sink> + Send + Unpin + 'static> Stream for Encrypted { - type Item = Result>; + type Item = Event; #[instrument(skip(cx), fields(initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -162,39 +242,71 @@ impl>> + Sink> + Send + Unpin + 'static .. } = self.get_mut(); - poll_encrypted_side_io(io, cx, encrypted_tx, encrypted_rx, *is_initiator, flush); - - if let Step::Established((encryptor, decryptor, ..)) = step { - poll_do_encrypt_and_decrypt( - encryptor, - decryptor, - encrypted_tx, - encrypted_rx, - plain_tx, - plain_rx, - *is_initiator, - flush, - ); - // emit any messages that are ready + if poll_message_throughput( + io, + cx, + step, + encrypted_tx, + encrypted_rx, + plain_rx, + plain_tx, + *is_initiator, + flush, + ) { if let Some(msg) = plain_rx.pop_front() { - trace!(initiator = %is_initiator, "plain rx emit"); Poll::Ready(Some(msg)) } else { Poll::Pending } } else { - poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator, flush); cx.waker().wake_by_ref(); Poll::Pending } } } +/// Handle all message throughput. Sends, encrypts and decrypts messages +/// Returns `true` `step` is already [`Step::Established`]. +fn poll_message_throughput< + IO: Stream>> + Sink> + Send + Unpin + 'static, +>( + io: &mut IO, + cx: &mut Context<'_>, + step: &mut Step, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, + plain_rx: &mut VecDeque, + plain_tx: &mut VecDeque>, + is_initiator: bool, + flush: &mut bool, +) -> bool { + poll_outgoing_encrypted_messages(io, cx, encrypted_tx, is_initiator, flush); + let _ = poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator); + if let Step::Established((encryptor, decryptor, ..)) = step { + // decrypt incomming msgs + poll_decrypt(decryptor, encrypted_rx, plain_rx, is_initiator); + // encrypt any pending plaintext outgoinng messages + poll_encrypt(encryptor, encrypted_tx, plain_tx, is_initiator, flush); + true + } else { + poll_setup( + step, + encrypted_tx, + encrypted_rx, + plain_rx, + is_initiator, + flush, + ); + false + } +} + #[instrument(skip_all, fields(initiator = %is_initiator))] fn poll_setup( step: &mut Step, encrypted_tx: &mut VecDeque>, encrypted_rx: &mut VecDeque>>, + plain_rx: &mut VecDeque, is_initiator: bool, flush: &mut bool, ) { @@ -205,27 +317,25 @@ fn poll_setup( // Still setting up if let Ok(Some(msg)) = maybe_init(step, is_initiator) { // queue the init message to send first - info!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue initial msg"); encrypted_tx.push_front(msg); } // TODO handle error - loop { - match encrypted_rx.pop_front() { - None => { - break; - } - Some(Err(e)) => { + while let Some(enc_res) = encrypted_rx.pop_front() { + match enc_res { + Err(e) => { error!("Recieved an error during setup encryption setup: {e:?}"); break; } - Some(Ok(incoming_msg)) => { - info!(initiator = %is_initiator, "recieved setup msg"); + Ok(incoming_msg) => { + trace!(initiator = %is_initiator, "encrypted_rx dequeue recieved setup msg"); if let Ok(msgs) = match handle_setup_message( step, &incoming_msg, is_initiator, encrypted_tx, encrypted_rx, + plain_rx, flush, ) { Ok(x) => Ok(x), @@ -235,31 +345,34 @@ fn poll_setup( } } { for msg in msgs.into_iter().rev() { - info!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue more setup msg"); encrypted_tx.push_front(msg); } } } } + + if step.established() { + return; + } } } #[instrument(skip_all, fields(initiator = %is_initiator))] /// Fills `encrypted_rx` and drains `encrypted_tx`. -fn poll_encrypted_side_io< +fn poll_outgoing_encrypted_messages< IO: Stream>> + Sink> + Send + Unpin + 'static, >( io: &mut IO, cx: &mut Context<'_>, encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, is_initiator: bool, flush: &mut bool, ) { // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - info!(initiator = %is_initiator, msg_len = encrypted_out.len(), "enc tx send msg\n{encrypted_out:?}"); + trace!(initiator = %is_initiator, msg_len = encrypted_out.len(), "TX message"); if let Err(_e) = Sink::start_send(Pin::new(io), encrypted_out) { error!("Error polling encyrpted side io") } @@ -273,74 +386,86 @@ fn poll_encrypted_side_io< match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - info!(initiator = %is_initiator, "flushed good"); + trace!(initiator = %is_initiator, "all flushed"); } Poll::Ready(Err(_e)) => { error!(initiator = %is_initiator, "Error sending encrypted msg") } Poll::Pending => { - // More confusing docs - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush - // It says: - // "Returns Poll::Pending if there is more work left to do, in which case the - // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when - // poll_flush should be called again." - // Does this mean, each time this task wakes up again from this code path that - // I must trigger another poll_flush? But how would I know i need more - // flushing? - debug!("flush not completed"); + // flush not complete try again later *flush = true; } } } +} + +fn poll_incomming_encrypted_messages< + IO: Stream>> + Sink> + Send + Unpin + 'static, +>( + io: &mut IO, + cx: &mut Context<'_>, + encrypted_rx: &mut VecDeque>>, + is_initiator: bool, +) -> Poll<()> { // pull in any incomming encrypted messages - loop { - match Stream::poll_next(Pin::new(io), cx) { - Poll::Pending => break, - Poll::Ready(None) => break, - Poll::Ready(Some(encrypted_msg)) => { - trace!(initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); - encrypted_rx.push_back(encrypted_msg); - } - } + let mut got_some = false; + while let Poll::Ready(Some(encrypted_msg)) = Stream::poll_next(Pin::new(io), cx) { + trace!(initiator = %is_initiator, "RX message"); + encrypted_rx.push_back(encrypted_msg); + got_some = true; + } + if got_some { + Poll::Ready(()) + } else { + Poll::Pending } } -/// Process messages waiting to be encrypted or decrypted -// TODO sholud this return a Result? #[instrument(skip_all)] -fn poll_do_encrypt_and_decrypt( - encryptor: &mut RawEncryptCipher, +fn poll_decrypt( decryptor: &mut DecryptCipher, - encrypted_tx: &mut VecDeque>, encrypted_rx: &mut VecDeque>>, - plain_tx: &mut VecDeque>, - plain_rx: &mut VecDeque>>, + plain_rx: &mut VecDeque, is_initiator: bool, - flush: &mut bool, ) { // decrypt any incromming encrypted messages // TODO handle error - while let Some(Ok(incoming_msg)) = encrypted_rx.pop_front() { - info!(initiator = %is_initiator, "enc rx decrypting\n{incoming_msg:?}"); - match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => { - info!(initiator = %is_initiator, "plain rx queue"); - plain_rx.push_back(Ok(plain_msg)); + while let Some(incoming_msg_res) = encrypted_rx.pop_front() { + match incoming_msg_res { + Ok(incoming_msg) => { + trace!(initiator = %is_initiator, "encrypted_rx dequeue decrypt"); + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => { + trace!(initiator = %is_initiator, "plain rx queue"); + plain_rx.push_back(Event::from(Ok(plain_msg))); + } + Err(e) => { + error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") + } + } } Err(e) => { error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") } } } +} +#[instrument(skip_all)] +fn poll_encrypt( + encryptor: &mut RawEncryptCipher, + encrypted_tx: &mut VecDeque>, + plain_tx: &mut VecDeque>, + is_initiator: bool, + flush: &mut bool, +) { // encrypt any pending plaintext outgoinng messages while let Some(plain_out) = plain_tx.pop_front() { let enc_out = match encryptor.encrypt(&plain_out) { Ok(x) => x, Err(_e) => todo!("We failed to encrypt our own message...?"), }; - trace!(initiator = %is_initiator, encrypted_msg_length = enc_out.len(), "enqueue new encrypted message from plain tx queue\n{enc_out:?}"); + trace!(initiator = %is_initiator, encrypted_msg_length = enc_out.len(), "enqueue new encrypted message from plain tx queue"); encrypted_tx.push_back(enc_out); *flush = true; } @@ -365,6 +490,7 @@ fn reset_encrypted( encrypted_rx: &mut VecDeque>>, flush: &mut bool, ) { + error!("Encrypted RESET"); *step = Step::NotInitialized; encrypted_tx.clear(); encrypted_rx.clear(); @@ -382,6 +508,7 @@ fn handle_setup_message( is_initiator: bool, encrypted_tx: &mut VecDeque>, encrypted_rx: &mut VecDeque>>, + plain_rx: &mut VecDeque, flush: &mut bool, ) -> Result>> { // this would only happen after reset with a bad message. @@ -401,7 +528,7 @@ fn handle_setup_message( Step::Handshake(_) => { let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { - trace!("Read in handshake msg\n{msg:?}"); + trace!("RX handshake msg"); if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { @@ -418,15 +545,15 @@ fn handle_setup_message( return Err(e); } } { - info!( + trace!( initiator = %is_initiator, - "read message and emitting response {response:?}", + "read message and emitting response", ); out.push(response); } if handshake.complete() { - debug!(initiator = %is_initiator, "HS complete. Making result"); + debug!(initiator = %is_initiator, "Handshake completed"); let handshake_result = match handshake.into_result() { Ok(x) => x, Err(e) => { @@ -456,6 +583,7 @@ fn handle_setup_message( if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) { let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; + plain_rx.push_back(Event::from(hs_result.clone())); *step = Step::Established((enc_cipher, dec_cipher, hs_result)); debug!(initiator = %is_initiator, "Step changed to {step}"); } @@ -465,38 +593,23 @@ fn handle_setup_message( } } -impl std::fmt::Display for Step { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Step::NotInitialized => "NotInitialized", - Step::Handshake(_) => "Handshake", - Step::SecretStream(_) => "SecretStream", - Step::Established(_) => "Established", - } - ) - } -} - impl std::fmt::Debug for Encrypted { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Encrypted") //.field("io", &self.io) //.field("step", &self.step) - .field("is_initiator", &self.is_initiator) - .field("encrypted_tx", &self.encrypted_tx) - .field("encrypted_rx", &self.encrypted_rx) - .field("plain_tx", &self.plain_tx) - .field("plain_rx", &self.plain_rx) - .field("flush", &self.flush) + .field("initiator", &self.is_initiator) + .field("encrypted_tx.len()", &self.encrypted_tx.len()) + .field("encrypted_rx", &self.encrypted_rx.len()) + .field("plain_tx", &self.plain_tx.len()) + .field("plain_rx", &self.plain_rx.len()) + //.field("flush", &self.flush) .finish() } } #[cfg(test)] -mod tset { +mod test { use crate::{ framing::test::duplex, test_utils::create_result_connected, Uint24LELengthPrefixedFraming, @@ -505,6 +618,12 @@ mod tset { use super::*; use futures::{future::join, SinkExt, StreamExt}; + fn inner(e: Option) -> Vec { + if let Some(Event::Decrypted(Ok(x))) = e { + return x; + } + panic!() + } #[tokio::test] async fn encrypted() -> Result<()> { let hello = b"hello".to_vec(); @@ -513,26 +632,31 @@ mod tset { let mut left = Encrypted::new(true, lc); let mut right = Encrypted::new(false, rc); - let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; - assert_eq!(receieved.unwrap()?, hello); + let (_sent, _received) = join(left.send(hello.clone()), right.next()).await; + let (_sent, received) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(inner(received), hello); assert!(left.encryption_established()); + assert!(right.encryption_established()); // NB: we cannot totally finish 'left.send' until the other side becomes active // because the handshake with the other side ('right') must complete // before the 'hello' message is sent. So we poll both the send and receive concurrently. - let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + let (_sent, received) = join(left.send(hello.clone()), right.next()).await; + // right recieves left's message - assert_eq!(receieved.unwrap()?, hello); + assert_eq!(inner(received), hello); // now that the encrypted channel is established, we don't need to spawn. right.send(world.clone()).await.unwrap(); // left recieves right's message - assert_eq!(left.next().await.unwrap()?, world); + left.next().await; + assert_eq!(inner(left.next().await), world); Ok(()) } + #[tokio::test] async fn encrypted_many() -> Result<()> { let hello = b"hello".to_vec(); @@ -547,15 +671,17 @@ mod tset { let mut left = Encrypted::new(true, lc); let mut right = Encrypted::new(false, rc); - let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; - assert_eq!(receieved.unwrap()?, hello); + let (_sent, _received) = join(left.send(hello.clone()), right.next()).await; + let (_sent, received) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(inner(received), hello); for d in &data { right.send(d.to_vec()).await?; } let mut result = vec![]; + let _ = left.next().await; for _ in &data { - result.push(left.next().await.unwrap()?); + result.push(inner(left.next().await)); } assert_eq!(result, data); Ok(()) @@ -572,8 +698,8 @@ mod tset { let mut left = Encrypted::new(true, left); let mut right = Encrypted::new(false, right); - let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; - assert_eq!(receieved.unwrap()?, hello); + let (_sent, _received) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(inner(right.next().await), hello); let data = vec![ b"yolo".to_vec(), @@ -587,9 +713,10 @@ mod tset { for d in &data { right.send(d.to_vec()).await?; } + let _ = left.next().await; let mut result = vec![]; for _ in &data { - result.push(left.next().await.unwrap()?); + result.push(inner(left.next().await)); } assert_eq!(result, data); @@ -599,20 +726,22 @@ mod tset { } let mut result = vec![]; for _ in &data { - result.push(right.next().await.unwrap()?); + result.push(inner(right.next().await)); } assert_eq!(result, data); // send both ways + let mut res = vec![]; for d in &data { left.send(d.to_vec()).await?; right.send(d.to_vec()).await?; + res.push(d.to_vec()); } let mut left_result = vec![]; let mut right_result = vec![]; for _ in &data { - right_result.push(right.next().await.unwrap()?); - left_result.push(left.next().await.unwrap()?); + right_result.push(inner(right.next().await)); + left_result.push(inner(left.next().await)); } assert_eq!(right_result, data); assert_eq!(left_result, data); @@ -647,13 +776,16 @@ mod tset { // send a bad message to init side. It should reset, and emit new init msg init_side_messages.send(b"bad msg".to_vec()).await?; - let new_init_msg = init_side_messages.next().await.unwrap()?; - other_side_messages.send(new_init_msg).await?; - let new_response = other_side_messages.next().await.unwrap()?; - init_side_messages.send(new_response).await?; - let final_setup_message = init_side_messages.next().await.unwrap()?; - other_side_messages.send(final_setup_message).await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + init_side_messages + .send(other_side_messages.next().await.unwrap()?) + .await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; // exchange one more message then we're set up init_side_messages @@ -676,8 +808,11 @@ mod tset { assert!(left.encryption_established()); assert!(right.encryption_established()); - assert_eq!(right.next().await.unwrap()?, b"hello"); - assert_eq!(left.next().await.unwrap()?, b"other hello"); + let _ = right.next().await; + let _ = left.next().await; + + assert_eq!(inner(right.next().await), b"hello"); + assert_eq!(inner(left.next().await), b"other hello"); Ok(()) } From 4a35d1ffa54d20c041d3c1cd19ed9a1295592d01 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 13:13:57 -0400 Subject: [PATCH 070/135] wait for setup to handle commands --- src/protocol/modern.rs | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index 6098ff7..42bd5d6 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -9,13 +9,14 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::instrument; +use tracing::{debug, error, instrument, warn}; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; use crate::crypto::HandshakeResult; use crate::message::{ChannelMessage, Message}; -use crate::mqueue::MessageIo; +use crate::mqueue::{MessageIo, MqueueEvent}; +use crate::noise::EncryptionInfo; use crate::util::{map_channel_err, pretty_hash}; use crate::{ encrypted_framed_message_channel, schema::*, Encrypted, Uint24LELengthPrefixedFraming, @@ -229,7 +230,9 @@ where return_error!(this.poll_inbound_read(cx)); // Check for commands, but only once the connection is established. - return_error!(this.poll_commands(cx)); + if this.options.noise && this.handshake.is_some() { + return_error!(this.poll_commands(cx)); + } // Poll the keepalive timer. this.poll_keepalive(cx); @@ -241,7 +244,6 @@ where if let Some(event) = this.queued_events.pop_front() { Poll::Ready(Ok(event)) } else { - cx.waker().wake_by_ref(); Poll::Pending } } @@ -249,7 +251,10 @@ where /// Poll commands. fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { - self.on_command(command)?; + if let Err(e) = self.on_command(command) { + error!(error = ?e, "Error handling command"); + return Err(e); + } } Ok(()) } @@ -297,10 +302,21 @@ where fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { match self.io.poll_inbound(cx) { - Poll::Ready(Ok(messages)) => { - self.on_inbound_channel_messages(messages)?; - } - Poll::Ready(Err(e)) => return Err(e), + Poll::Ready(opt) => match opt { + Some(e) => match e { + MqueueEvent::Meta(einf) => match einf { + EncryptionInfo::Handshake(hs_res) => { + let remote_pubkey = parse_key(&hs_res.remote_pubkey)?; + self.handshake = Some(hs_res); + debug!(handshake = ?self.handshake, "set Protocol::handshake"); + self.queue_event(Event::Handshake(remote_pubkey)) + } + }, + MqueueEvent::Message(msgs) => self.on_inbound_channel_messages(msgs?)?, + }, + + None => return Ok(()), + }, Poll::Pending => return Ok(()), } } @@ -313,6 +329,7 @@ where loop { // if no parking or setup in progress if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) { + error!(err = ?e, "error from poll_outbound"); return Err(e); } // send messages outbound_rx @@ -489,7 +506,7 @@ where } } -/// Send [Command](Command)s to the [Protocol](Protocol). +/// Send [`Command`]s to the [`Protocol`]. #[derive(Clone, Debug)] pub struct CommandTx(Sender); From 35a6bac4742756d31e17832246016052af6fcc51 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 9 Apr 2025 12:34:18 -0500 Subject: [PATCH 071/135] move unused framed stuff into tests and do rename --- src/message/modern.rs | 310 +++++++++++++++++++++--------------------- src/mqueue.rs | 4 +- 2 files changed, 157 insertions(+), 157 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index 2d8e732..e70bed2 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -42,43 +42,7 @@ impl Encoder for &[u8] { } } -/// A frame of data, either a buffer or a message. -#[derive(Clone, PartialEq)] -pub(crate) enum Frame { - /// A raw batch binary buffer. Used in the handshaking phase. - RawBatch(Vec>), - /// Message batch, containing one or more channel messsages. Used for everything after the handshake. - MessageBatch(Vec), -} - -impl fmt::Debug for Frame { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), - Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), - } - } -} - -impl From for Frame { - fn from(m: ChannelMessage) -> Self { - Self::MessageBatch(vec![m]) - } -} - -impl From> for Frame { - fn from(m: Vec) -> Self { - Self::MessageBatch(m) - } -} - -impl From> for Frame { - fn from(m: Vec) -> Self { - Self::RawBatch(vec![m]) - } -} - -pub(crate) fn decode_many_channel_messages( +pub(crate) fn decode_framed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { let mut index = 0; @@ -93,7 +57,7 @@ pub(crate) fn decode_many_channel_messages( let stat = stat_uint24_le(&buf[index..]); if let Some((header_len, body_len)) = stat { - let (msgs, length) = decode_one_channel_message( + let (msgs, length) = decode_unframed_channel_messages( &buf[index + header_len..index + header_len + body_len as usize], )?; if length != body_len as usize { @@ -119,7 +83,7 @@ pub(crate) fn decode_many_channel_messages( Ok((combined_messages, index)) } // bad name bc it returns many. More like, decode unframed channel messages -pub(crate) fn decode_one_channel_message( +pub(crate) fn decode_unframed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { if buf.len() >= 3 && buf[0] == 0x00 { @@ -188,121 +152,6 @@ pub(crate) fn decode_one_channel_message( } } -impl Frame { - /// Decodes a frame from a buffer containing multiple concurrent messages. - fn preencode(&self, state: &mut State) -> Result { - match self { - Self::RawBatch(raw_batch) => { - for raw in raw_batch { - state.add_end(raw.as_slice().encoded_len()?)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } - } - } - Ok(state.end()) - } -} - -impl Encoder for Frame { - fn encoded_len(&self) -> Result { - let body_len = self.preencode(&mut State::new())?; - match self { - Self::RawBatch(_) => Ok(body_len), - Self::MessageBatch(_) => Ok(3 + body_len), - } - } - - fn encode(&self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; - let body_len = self.preencode(&mut state)?; - let len = body_len + header_len; - if buf.len() < len { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - match self { - Self::RawBatch(ref raw_batch) => { - for raw in raw_batch { - raw.as_slice().encode(buf)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; - } - } - } - }; - Ok(len) - } -} - fn prencode_channel_messages( messages: &[ChannelMessage], state: &mut State, @@ -688,6 +537,156 @@ mod tests { )* } } + /// A frame of data, either a buffer or a message. + #[derive(Clone, PartialEq)] + pub(crate) enum Frame { + /// A raw batch binary buffer. Used in the handshaking phase. + RawBatch(Vec>), + /// Message batch, containing one or more channel messsages. Used for everything after the handshake. + MessageBatch(Vec), + } + + impl fmt::Debug for Frame { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), + Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), + } + } + } + + impl From for Frame { + fn from(m: ChannelMessage) -> Self { + Self::MessageBatch(vec![m]) + } + } + + impl From> for Frame { + fn from(m: Vec) -> Self { + Self::MessageBatch(m) + } + } + + impl From> for Frame { + fn from(m: Vec) -> Self { + Self::RawBatch(vec![m]) + } + } + + impl Frame { + /// Decodes a frame from a buffer containing multiple concurrent messages. + fn preencode(&self, state: &mut State) -> Result { + match self { + Self::RawBatch(raw_batch) => { + for raw in raw_batch { + state.add_end(raw.as_slice().encoded_len()?)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(messages) => { + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else { + (*state).preencode(&messages[0].channel)?; + state.add_end(messages[0].encoded_len()?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.add_end(2)?; + let mut current_channel: u64 = messages[0].channel; + state.preencode(¤t_channel)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.add_end(1)?; + state.preencode(&message.channel)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.preencode(&message_length)?; + state.add_end(message_length)?; + } + } + } + } + Ok(state.end()) + } + } + + impl Encoder for Frame { + fn encoded_len(&self) -> Result { + let body_len = self.preencode(&mut State::new())?; + match self { + Self::RawBatch(_) => Ok(body_len), + Self::MessageBatch(_) => Ok(3 + body_len), + } + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let mut state = State::new(); + let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; + let body_len = self.preencode(&mut state)?; + let len = body_len + header_len; + if buf.len() < len { + return Err(EncodingError::new( + EncodingErrorKind::Overflow, + &format!("Length does not fit buffer, {} > {}", len, buf.len()), + )); + } + match self { + Self::RawBatch(ref raw_batch) => { + for raw in raw_batch { + raw.as_slice().encode(buf)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(ref messages) => { + write_uint24_le(body_len, buf); + let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(1_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(3_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else { + state.encode(&messages[0].channel, buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + let mut current_channel: u64 = messages[0].channel; + state.encode(¤t_channel, buf)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.encode(&(0_u8), buf)?; + state.encode(&message.channel, buf)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.encode(&message_length, buf)?; + state.add_start(message.encode(&mut buf[state.start()..])?)?; + } + } + } + }; + Ok(len) + } + } #[test] fn message_encode_decode() { @@ -849,6 +848,7 @@ mod tests { }), ] } + impl Frame { pub(crate) fn decode_multiple(buf: &[u8]) -> Result { let mut index = 0; @@ -988,7 +988,7 @@ mod tests { let fres = Frame::decode_multiple(&fbuf)?; assert_eq!(fres, frame); - let cres_m = decode_many_channel_messages(&cbuf)?.0; + let cres_m = decode_framed_channel_messages(&cbuf)?.0; assert_eq!(cres_m, cmvec); } Ok(()) diff --git a/src/mqueue.rs b/src/mqueue.rs index 7e1fffd..cd86caf 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -12,7 +12,7 @@ use futures::{Sink, Stream}; use tracing::{error, instrument}; use crate::{ - message::{decode_many_channel_messages, ChannelMessage, Encoder as _}, + message::{decode_framed_channel_messages, ChannelMessage, Encoder as _}, noise::EncryptionInfo, NoiseEvent, }; @@ -29,7 +29,7 @@ impl From for MqueueEvent { NoiseEvent::Meta(einf) => Self::Meta(einf), NoiseEvent::Decrypted(dec_res) => { match dec_res { - Ok(encoded) => match decode_many_channel_messages(&encoded) { + Ok(encoded) => match decode_framed_channel_messages(&encoded) { //assert_eq!(_n_read, encoded.len()); } Ok((messsages, _n_read)) => Self::Message(Ok(messsages)), Err(e) => Self::Message(Err(e)), From 483e9ffdb592189bc1a1908d6d173b58c99ce832 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 9 Apr 2025 14:46:24 -0500 Subject: [PATCH 072/135] impl CompactEncodable for Schema --- src/schema.rs | 408 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 406 insertions(+), 2 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index ef58e77..328e1f6 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,6 +1,10 @@ -use hypercore::encoding::{CompactEncoding, EncodingError, HypercoreState, State}; +use hypercore::encoding::{ + take_array, take_array_mut, write_array, write_slice, CompactEncodable, CompactEncoding, + EncodingError, HypercoreState, State, +}; use hypercore::{ - DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, RequestUpgrade, + chain_encoded_bytes, decode, sum_encoded_size, DataBlock, DataHash, DataSeek, DataUpgrade, + Proof, RequestBlock, RequestSeek, RequestUpgrade, }; /// Open message @@ -16,6 +20,55 @@ pub struct Open { pub capability: Option>, } +impl CompactEncodable for Open { + fn encoded_size(&self) -> Result { + let out = sum_encoded_size!(self, channel, protocol, discovery_key); + if self.capability.is_some() { + return Ok( + out + + 1 // flags for future use + + 32, // TODO capabalilities buff should always be 32 bytes, but it's a vec + ); + } + Ok(out) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = chain_encoded_bytes!(self, buffer, channel, protocol, discovery_key); + if let Some(cap) = &self.capability { + let (_, rest) = take_array_mut::<1>(rest)?; + return write_slice(cap, rest); + } + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (channel, rest) = u64::decode(buffer)?; + let (protocol, rest) = String::decode(rest)?; + let (discovery_key, rest) = >::decode(rest)?; + // TODO this is a CLEAR bug it assumes nothing is encoded after this message + let (capability, rest) = if !rest.is_empty() { + let (_, rest) = take_array::<1>(rest)?; + let (capability, rest) = take_array::<32>(rest)?; + (Some(capability.to_vec()), rest) + } else { + (None, rest) + }; + Ok(( + Open { + channel, + protocol, + discovery_key, + capability, + }, + rest, + )) + } +} + impl CompactEncoding for State { fn preencode(&mut self, value: &Open) -> Result { self.preencode(&value.channel)?; @@ -43,6 +96,7 @@ impl CompactEncoding for State { let channel: u64 = self.decode(buffer)?; let protocol: String = self.decode(buffer)?; let discovery_key: Vec = self.decode(buffer)?; + // TODO This is a BUG!!! when anything is encoded **after** Open message let capability: Option> = if self.start() < self.end() { self.add_start(1)?; // flags for future use let capability: Vec = self.decode_fixed_32(buffer)?.to_vec(); @@ -66,6 +120,22 @@ pub struct Close { pub channel: u64, } +impl CompactEncodable for Close { + fn encoded_size(&self) -> Result { + Ok(self.channel.encoded_size()?) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.channel.encoded_bytes(buffer) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Close, buffer, {channel: u64}) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Close) -> Result { self.preencode(&value.channel) @@ -98,6 +168,50 @@ pub struct Synchronize { pub can_upgrade: bool, } +impl CompactEncodable for Synchronize { + fn encoded_size(&self) -> Result { + Ok(1 + sum_encoded_size!(self, fork, length, remote_length)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; + flags |= if self.uploading { 2 } else { 0 }; + flags |= if self.downloading { 4 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + Ok(chain_encoded_bytes!( + self, + rest, + fork, + length, + remote_length + )) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (fork, rest) = u64::decode(rest)?; + let (length, rest) = u64::decode(rest)?; + let (remote_length, rest) = u64::decode(rest)?; + let can_upgrade = flags & 1 != 0; + let uploading = flags & 2 != 0; + let downloading = flags & 4 != 0; + Ok(( + Synchronize { + fork, + length, + remote_length, + can_upgrade, + uploading, + downloading, + }, + rest, + )) + } +} + impl CompactEncoding for State { fn preencode(&mut self, value: &Synchronize) -> Result { self.add_end(1)?; // flags @@ -152,6 +266,85 @@ pub struct Request { pub upgrade: Option, } +macro_rules! maybe_decode { + ($cond:expr, $type:ty, $buf:ident) => { + if $cond { + let (result, rest) = <$type>::decode($buf)?; + (Some(result), rest) + } else { + (None, $buf) + } + }; +} + +impl CompactEncodable for Request { + fn encoded_size(&self) -> Result { + let mut out = 1; // flags + out += sum_encoded_size!(self, id, fork); + if let Some(block) = &self.block { + out += block.encoded_size()?; + } + if let Some(hash) = &self.hash { + out += hash.encoded_size()?; + } + if let Some(seek) = &self.seek { + out += seek.encoded_size()?; + } + if let Some(upgrade) = &self.upgrade { + out += upgrade.encoded_size()?; + } + Ok(out) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; + flags |= if self.hash.is_some() { 2 } else { 0 }; + flags |= if self.seek.is_some() { 4 } else { 0 }; + flags |= if self.upgrade.is_some() { 8 } else { 0 }; + let mut rest = write_array(&[flags], buffer)?; + chain_encoded_bytes!(self, rest, id, fork); + + if let Some(block) = &self.block { + rest = block.encoded_bytes(rest)?; + } + if let Some(hash) = &self.hash { + rest = hash.encoded_bytes(rest)?; + } + if let Some(seek) = &self.seek { + rest = seek.encoded_bytes(rest)?; + } + if let Some(upgrade) = &self.upgrade { + rest = upgrade.encoded_bytes(rest)?; + } + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (id, rest) = u64::decode(rest)?; + let (fork, rest) = u64::decode(rest)?; + + let (block, rest) = maybe_decode!(flags & 1 != 0, RequestBlock, rest); + let (hash, rest) = maybe_decode!(flags & 2 != 0, RequestBlock, rest); + let (seek, rest) = maybe_decode!(flags & 4 != 0, RequestSeek, rest); + let (upgrade, rest) = maybe_decode!(flags & 8 != 0, RequestUpgrade, rest); + Ok(( + Request { + id, + fork, + block, + hash, + seek, + upgrade, + }, + rest, + )) + } +} + impl CompactEncoding for HypercoreState { fn preencode(&mut self, value: &Request) -> Result { self.add_end(1)?; // flags @@ -237,6 +430,23 @@ pub struct Cancel { pub request: u64, } +impl CompactEncodable for Cancel { + fn encoded_size(&self) -> Result { + self.request.encoded_size() + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.request.encoded_bytes(buffer) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (request, rest) = u64::decode(buffer)?; + Ok((Cancel { request }, rest)) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Cancel) -> Result { self.preencode(&value.request) @@ -269,6 +479,74 @@ pub struct Data { pub upgrade: Option, } +macro_rules! opt_encoded_size { + ($opt:expr, $sum:ident) => { + if let Some(thing) = $opt { + $sum += thing.encoded_size()?; + } + }; +} + +macro_rules! opt_encoded_bytes { + ($opt:expr, $buf:ident) => { + if let Some(thing) = $opt { + thing.encoded_bytes($buf)? + } else { + $buf + } + }; +} +impl CompactEncodable for Data { + fn encoded_size(&self) -> Result { + let mut out = 1; // flags + out += sum_encoded_size!(self, request, fork); + opt_encoded_size!(&self.block, out); + opt_encoded_size!(&self.hash, out); + opt_encoded_size!(&self.seek, out); + opt_encoded_size!(&self.upgrade, out); + Ok(out) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; + flags |= if self.hash.is_some() { 2 } else { 0 }; + flags |= if self.seek.is_some() { 4 } else { 0 }; + flags |= if self.upgrade.is_some() { 8 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + chain_encoded_bytes!(self, rest, request, fork); + + let rest = opt_encoded_bytes!(&self.block, rest); + let rest = opt_encoded_bytes!(&self.hash, rest); + let rest = opt_encoded_bytes!(&self.seek, rest); + let rest = opt_encoded_bytes!(&self.upgrade, rest); + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (request, rest) = u64::decode(rest)?; + let (fork, rest) = u64::decode(rest)?; + let (block, rest) = maybe_decode!(flags & 1 != 0, DataBlock, rest); + let (hash, rest) = maybe_decode!(flags & 2 != 0, DataHash, rest); + let (seek, rest) = maybe_decode!(flags & 4 != 0, DataSeek, rest); + let (upgrade, rest) = maybe_decode!(flags & 8 != 0, DataUpgrade, rest); + Ok(( + Data { + request, + fork, + block, + hash, + seek, + upgrade, + }, + rest, + )) + } +} + impl CompactEncoding for HypercoreState { fn preencode(&mut self, value: &Data) -> Result { self.add_end(1)?; // flags @@ -367,6 +645,22 @@ pub struct NoData { pub request: u64, } +impl CompactEncodable for NoData { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, request)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, request)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(NoData, buffer, { request: u64 }) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &NoData) -> Result { self.preencode(&value.request) @@ -390,6 +684,23 @@ pub struct Want { /// Length pub length: u64, } + +impl CompactEncodable for Want { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, start, length)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, start, length)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Self, buffer, { start: u64, length: u64 }) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Want) -> Result { self.preencode(&value.start)?; @@ -416,6 +727,24 @@ pub struct Unwant { /// Length pub length: u64, } + +impl CompactEncodable for Unwant { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, start, length)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, start, length)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Self, buffer, { start: u64, length: u64 }) + } +} + impl CompactEncoding for State { fn preencode(&mut self, value: &Unwant) -> Result { self.preencode(&value.start)?; @@ -442,6 +771,22 @@ pub struct Bitfield { /// Bitfield in 32 bit chunks beginning from `start` pub bitfield: Vec, } +impl CompactEncodable for Bitfield { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, start, bitfield)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, start, bitfield)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Self, buffer, { start: u64, bitfield: Vec }) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Bitfield) -> Result { self.preencode(&value.start)?; @@ -473,6 +818,49 @@ pub struct Range { pub length: u64, } +impl CompactEncodable for Range { + fn encoded_size(&self) -> Result { + let mut out = 1 + sum_encoded_size!(self, start); + if self.length != 1 { + out += self.length.encoded_size()?; + } + Ok(out) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.drop { 1 } else { 0 }; + flags |= if self.length == 1 { 2 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + let rest = self.start.encoded_bytes(rest)?; + if self.length != 1 { + return self.length.encoded_bytes(rest); + } + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (start, rest) = u64::decode(rest)?; + let drop = flags & 1 != 0; + let (length, rest) = if flags & 2 != 0 { + (1, rest) + } else { + u64::decode(rest)? + }; + Ok(( + Range { + drop, + length, + start, + }, + rest, + )) + } +} + impl CompactEncoding for State { fn preencode(&mut self, value: &Range) -> Result { self.add_end(1)?; // flags @@ -519,6 +907,22 @@ pub struct Extension { /// Message content, use empty vector for no data. pub message: Vec, } +impl CompactEncodable for Extension { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, name, message)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, name, message)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Self, buffer, { name: String, message: Vec }) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Extension) -> Result { self.preencode(&value.name)?; From 350f03a6fe5b8891403d2aba22ae6f2172ae152a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 14 Apr 2025 17:19:09 -0400 Subject: [PATCH 073/135] encoded_bytes renamed to encode --- src/schema.rs | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index 328e1f6..676348e 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -33,7 +33,7 @@ impl CompactEncodable for Open { Ok(out) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let rest = chain_encoded_bytes!(self, buffer, channel, protocol, discovery_key); if let Some(cap) = &self.capability { let (_, rest) = take_array_mut::<1>(rest)?; @@ -125,8 +125,8 @@ impl CompactEncodable for Close { Ok(self.channel.encoded_size()?) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - self.channel.encoded_bytes(buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.channel.encode(buffer) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -173,7 +173,7 @@ impl CompactEncodable for Synchronize { Ok(1 + sum_encoded_size!(self, fork, length, remote_length)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; flags |= if self.uploading { 2 } else { 0 }; flags |= if self.downloading { 4 } else { 0 }; @@ -296,7 +296,7 @@ impl CompactEncodable for Request { Ok(out) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; flags |= if self.hash.is_some() { 2 } else { 0 }; flags |= if self.seek.is_some() { 4 } else { 0 }; @@ -305,16 +305,16 @@ impl CompactEncodable for Request { chain_encoded_bytes!(self, rest, id, fork); if let Some(block) = &self.block { - rest = block.encoded_bytes(rest)?; + rest = block.encode(rest)?; } if let Some(hash) = &self.hash { - rest = hash.encoded_bytes(rest)?; + rest = hash.encode(rest)?; } if let Some(seek) = &self.seek { - rest = seek.encoded_bytes(rest)?; + rest = seek.encode(rest)?; } if let Some(upgrade) = &self.upgrade { - rest = upgrade.encoded_bytes(rest)?; + rest = upgrade.encode(rest)?; } Ok(rest) } @@ -435,8 +435,8 @@ impl CompactEncodable for Cancel { self.request.encoded_size() } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - self.request.encoded_bytes(buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.request.encode(buffer) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -490,7 +490,7 @@ macro_rules! opt_encoded_size { macro_rules! opt_encoded_bytes { ($opt:expr, $buf:ident) => { if let Some(thing) = $opt { - thing.encoded_bytes($buf)? + thing.encode($buf)? } else { $buf } @@ -507,7 +507,7 @@ impl CompactEncodable for Data { Ok(out) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; flags |= if self.hash.is_some() { 2 } else { 0 }; flags |= if self.seek.is_some() { 4 } else { 0 }; @@ -650,7 +650,7 @@ impl CompactEncodable for NoData { Ok(sum_encoded_size!(self, request)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, request)) } @@ -690,7 +690,7 @@ impl CompactEncodable for Want { Ok(sum_encoded_size!(self, start, length)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, start, length)) } @@ -733,7 +733,7 @@ impl CompactEncodable for Unwant { Ok(sum_encoded_size!(self, start, length)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, start, length)) } @@ -776,7 +776,7 @@ impl CompactEncodable for Bitfield { Ok(sum_encoded_size!(self, start, bitfield)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, start, bitfield)) } @@ -827,13 +827,13 @@ impl CompactEncodable for Range { Ok(out) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.drop { 1 } else { 0 }; flags |= if self.length == 1 { 2 } else { 0 }; let rest = write_array(&[flags], buffer)?; - let rest = self.start.encoded_bytes(rest)?; + let rest = self.start.encode(rest)?; if self.length != 1 { - return self.length.encoded_bytes(rest); + return self.length.encode(rest); } Ok(rest) } @@ -912,7 +912,7 @@ impl CompactEncodable for Extension { Ok(sum_encoded_size!(self, name, message)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, name, message)) } From bf0d309f68d99c3d291bb798ab3fc158ec3d9c12 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 22 Apr 2025 14:16:53 -0400 Subject: [PATCH 074/135] use new CompactEncoding in schema.rs --- src/schema.rs | 465 +++++--------------------------------------------- 1 file changed, 40 insertions(+), 425 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index 676348e..8a9a0a2 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,10 +1,10 @@ use hypercore::encoding::{ - take_array, take_array_mut, write_array, write_slice, CompactEncodable, CompactEncoding, - EncodingError, HypercoreState, State, + map_encode, sum_encoded_size, take_array, take_array_mut, write_array, write_slice, + CompactEncoding, EncodingError, }; use hypercore::{ - chain_encoded_bytes, decode, sum_encoded_size, DataBlock, DataHash, DataSeek, DataUpgrade, - Proof, RequestBlock, RequestSeek, RequestUpgrade, + decode, DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, + RequestUpgrade, }; /// Open message @@ -20,9 +20,9 @@ pub struct Open { pub capability: Option>, } -impl CompactEncodable for Open { +impl CompactEncoding for Open { fn encoded_size(&self) -> Result { - let out = sum_encoded_size!(self, channel, protocol, discovery_key); + let out = sum_encoded_size!(self.channel, self.protocol, self.discovery_key); if self.capability.is_some() { return Ok( out @@ -34,7 +34,7 @@ impl CompactEncodable for Open { } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - let rest = chain_encoded_bytes!(self, buffer, channel, protocol, discovery_key); + let rest = map_encode!(buffer, self.channel, self.protocol, self.discovery_key); if let Some(cap) = &self.capability { let (_, rest) = take_array_mut::<1>(rest)?; return write_slice(cap, rest); @@ -69,50 +69,6 @@ impl CompactEncodable for Open { } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Open) -> Result { - self.preencode(&value.channel)?; - self.preencode(&value.protocol)?; - self.preencode(&value.discovery_key)?; - if value.capability.is_some() { - self.add_end(1)?; // flags for future use - self.preencode_fixed_32()?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Open, buffer: &mut [u8]) -> Result { - self.encode(&value.channel, buffer)?; - self.encode(&value.protocol, buffer)?; - self.encode(&value.discovery_key, buffer)?; - if let Some(capability) = &value.capability { - self.add_start(1)?; // flags for future use - self.encode_fixed_32(capability, buffer)?; - } - Ok(self.start()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let channel: u64 = self.decode(buffer)?; - let protocol: String = self.decode(buffer)?; - let discovery_key: Vec = self.decode(buffer)?; - // TODO This is a BUG!!! when anything is encoded **after** Open message - let capability: Option> = if self.start() < self.end() { - self.add_start(1)?; // flags for future use - let capability: Vec = self.decode_fixed_32(buffer)?.to_vec(); - Some(capability) - } else { - None - }; - Ok(Open { - channel, - protocol, - discovery_key, - capability, - }) - } -} - /// Close message #[derive(Debug, Clone, PartialEq)] pub struct Close { @@ -120,7 +76,7 @@ pub struct Close { pub channel: u64, } -impl CompactEncodable for Close { +impl CompactEncoding for Close { fn encoded_size(&self) -> Result { Ok(self.channel.encoded_size()?) } @@ -136,20 +92,6 @@ impl CompactEncodable for Close { decode!(Close, buffer, {channel: u64}) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Close) -> Result { - self.preencode(&value.channel) - } - - fn encode(&mut self, value: &Close, buffer: &mut [u8]) -> Result { - self.encode(&value.channel, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let channel: u64 = self.decode(buffer)?; - Ok(Close { channel }) - } -} /// Synchronize message. Type 0. #[derive(Debug, Clone, PartialEq)] @@ -168,22 +110,22 @@ pub struct Synchronize { pub can_upgrade: bool, } -impl CompactEncodable for Synchronize { +impl CompactEncoding for Synchronize { fn encoded_size(&self) -> Result { - Ok(1 + sum_encoded_size!(self, fork, length, remote_length)) + Ok(1 + sum_encoded_size!(self.fork, self.length, self.remote_length)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; flags |= if self.uploading { 2 } else { 0 }; flags |= if self.downloading { 4 } else { 0 }; + dbg!(flags); let rest = write_array(&[flags], buffer)?; - Ok(chain_encoded_bytes!( - self, + Ok(map_encode!( rest, - fork, - length, - remote_length + self.fork, + self.length, + self.remote_length )) } @@ -192,6 +134,7 @@ impl CompactEncodable for Synchronize { Self: Sized, { let ([flags], rest) = take_array::<1>(buffer)?; + dbg!(flags); let (fork, rest) = u64::decode(rest)?; let (length, rest) = u64::decode(rest)?; let (remote_length, rest) = u64::decode(rest)?; @@ -212,43 +155,6 @@ impl CompactEncodable for Synchronize { } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Synchronize) -> Result { - self.add_end(1)?; // flags - self.preencode(&value.fork)?; - self.preencode(&value.length)?; - self.preencode(&value.remote_length) - } - - fn encode(&mut self, value: &Synchronize, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.can_upgrade { 1 } else { 0 }; - flags |= if value.uploading { 2 } else { 0 }; - flags |= if value.downloading { 4 } else { 0 }; - self.encode(&flags, buffer)?; - self.encode(&value.fork, buffer)?; - self.encode(&value.length, buffer)?; - self.encode(&value.remote_length, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.decode(buffer)?; - let fork: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - let remote_length: u64 = self.decode(buffer)?; - let can_upgrade = flags & 1 != 0; - let uploading = flags & 2 != 0; - let downloading = flags & 4 != 0; - Ok(Synchronize { - fork, - length, - remote_length, - can_upgrade, - uploading, - downloading, - }) - } -} - /// Request message. Type 1. #[derive(Debug, Clone, PartialEq)] pub struct Request { @@ -277,10 +183,10 @@ macro_rules! maybe_decode { }; } -impl CompactEncodable for Request { +impl CompactEncoding for Request { fn encoded_size(&self) -> Result { let mut out = 1; // flags - out += sum_encoded_size!(self, id, fork); + out += sum_encoded_size!(self.id, self.fork); if let Some(block) = &self.block { out += block.encoded_size()?; } @@ -302,7 +208,7 @@ impl CompactEncodable for Request { flags |= if self.seek.is_some() { 4 } else { 0 }; flags |= if self.upgrade.is_some() { 8 } else { 0 }; let mut rest = write_array(&[flags], buffer)?; - chain_encoded_bytes!(self, rest, id, fork); + rest = map_encode!(rest, self.id, self.fork); if let Some(block) = &self.block { rest = block.encode(rest)?; @@ -345,84 +251,6 @@ impl CompactEncodable for Request { } } -impl CompactEncoding for HypercoreState { - fn preencode(&mut self, value: &Request) -> Result { - self.add_end(1)?; // flags - self.0.preencode(&value.id)?; - self.0.preencode(&value.fork)?; - if let Some(block) = &value.block { - self.preencode(block)?; - } - if let Some(hash) = &value.hash { - self.preencode(hash)?; - } - if let Some(seek) = &value.seek { - self.preencode(seek)?; - } - if let Some(upgrade) = &value.upgrade { - self.preencode(upgrade)?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Request, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; - flags |= if value.hash.is_some() { 2 } else { 0 }; - flags |= if value.seek.is_some() { 4 } else { 0 }; - flags |= if value.upgrade.is_some() { 8 } else { 0 }; - self.0.encode(&flags, buffer)?; - self.0.encode(&value.id, buffer)?; - self.0.encode(&value.fork, buffer)?; - if let Some(block) = &value.block { - self.encode(block, buffer)?; - } - if let Some(hash) = &value.hash { - self.encode(hash, buffer)?; - } - if let Some(seek) = &value.seek { - self.encode(seek, buffer)?; - } - if let Some(upgrade) = &value.upgrade { - self.encode(upgrade, buffer)?; - } - Ok(self.start()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.0.decode(buffer)?; - let id: u64 = self.0.decode(buffer)?; - let fork: u64 = self.0.decode(buffer)?; - let block: Option = if flags & 1 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let hash: Option = if flags & 2 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let seek: Option = if flags & 4 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let upgrade: Option = if flags & 8 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - Ok(Request { - id, - fork, - block, - hash, - seek, - upgrade, - }) - } -} - /// Cancel message for a [Request]. Type 2 #[derive(Debug, Clone, PartialEq)] pub struct Cancel { @@ -430,7 +258,7 @@ pub struct Cancel { pub request: u64, } -impl CompactEncodable for Cancel { +impl CompactEncoding for Cancel { fn encoded_size(&self) -> Result { self.request.encoded_size() } @@ -447,20 +275,6 @@ impl CompactEncodable for Cancel { Ok((Cancel { request }, rest)) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Cancel) -> Result { - self.preencode(&value.request) - } - - fn encode(&mut self, value: &Cancel, buffer: &mut [u8]) -> Result { - self.encode(&value.request, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let request: u64 = self.decode(buffer)?; - Ok(Cancel { request }) - } -} /// Data message responding to received [Request]. Type 3. #[derive(Debug, Clone, PartialEq)] @@ -496,10 +310,10 @@ macro_rules! opt_encoded_bytes { } }; } -impl CompactEncodable for Data { +impl CompactEncoding for Data { fn encoded_size(&self) -> Result { let mut out = 1; // flags - out += sum_encoded_size!(self, request, fork); + out += sum_encoded_size!(self.request, self.fork); opt_encoded_size!(&self.block, out); opt_encoded_size!(&self.hash, out); opt_encoded_size!(&self.seek, out); @@ -513,7 +327,7 @@ impl CompactEncodable for Data { flags |= if self.seek.is_some() { 4 } else { 0 }; flags |= if self.upgrade.is_some() { 8 } else { 0 }; let rest = write_array(&[flags], buffer)?; - chain_encoded_bytes!(self, rest, request, fork); + let rest = map_encode!(rest, self.request, self.fork); let rest = opt_encoded_bytes!(&self.block, rest); let rest = opt_encoded_bytes!(&self.hash, rest); @@ -547,84 +361,6 @@ impl CompactEncodable for Data { } } -impl CompactEncoding for HypercoreState { - fn preencode(&mut self, value: &Data) -> Result { - self.add_end(1)?; // flags - self.0.preencode(&value.request)?; - self.0.preencode(&value.fork)?; - if let Some(block) = &value.block { - self.preencode(block)?; - } - if let Some(hash) = &value.hash { - self.preencode(hash)?; - } - if let Some(seek) = &value.seek { - self.preencode(seek)?; - } - if let Some(upgrade) = &value.upgrade { - self.preencode(upgrade)?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Data, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; - flags |= if value.hash.is_some() { 2 } else { 0 }; - flags |= if value.seek.is_some() { 4 } else { 0 }; - flags |= if value.upgrade.is_some() { 8 } else { 0 }; - self.0.encode(&flags, buffer)?; - self.0.encode(&value.request, buffer)?; - self.0.encode(&value.fork, buffer)?; - if let Some(block) = &value.block { - self.encode(block, buffer)?; - } - if let Some(hash) = &value.hash { - self.encode(hash, buffer)?; - } - if let Some(seek) = &value.seek { - self.encode(seek, buffer)?; - } - if let Some(upgrade) = &value.upgrade { - self.encode(upgrade, buffer)?; - } - Ok(self.start()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.0.decode(buffer)?; - let request: u64 = self.0.decode(buffer)?; - let fork: u64 = self.0.decode(buffer)?; - let block: Option = if flags & 1 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let hash: Option = if flags & 2 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let seek: Option = if flags & 4 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let upgrade: Option = if flags & 8 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - Ok(Data { - request, - fork, - block, - hash, - seek, - upgrade, - }) - } -} - impl Data { /// Transform Data message into a Proof emptying fields pub fn into_proof(&mut self) -> Proof { @@ -645,13 +381,13 @@ pub struct NoData { pub request: u64, } -impl CompactEncodable for NoData { +impl CompactEncoding for NoData { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, request)) + Ok(sum_encoded_size!(self.request)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, request)) + Ok(map_encode!(buffer, self.request)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -661,20 +397,6 @@ impl CompactEncodable for NoData { decode!(NoData, buffer, { request: u64 }) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &NoData) -> Result { - self.preencode(&value.request) - } - - fn encode(&mut self, value: &NoData, buffer: &mut [u8]) -> Result { - self.encode(&value.request, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let request: u64 = self.decode(buffer)?; - Ok(NoData { request }) - } -} /// Want message. Type 5. #[derive(Debug, Clone, PartialEq)] @@ -685,13 +407,13 @@ pub struct Want { pub length: u64, } -impl CompactEncodable for Want { +impl CompactEncoding for Want { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, start, length)) + Ok(sum_encoded_size!(self.start, self.length)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, start, length)) + Ok(map_encode!(buffer, self.start, self.length)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -701,23 +423,6 @@ impl CompactEncodable for Want { decode!(Self, buffer, { start: u64, length: u64 }) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Want) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.length) - } - - fn encode(&mut self, value: &Want, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.length, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - Ok(Want { start, length }) - } -} /// Un-want message. Type 6. #[derive(Debug, Clone, PartialEq)] @@ -728,13 +433,13 @@ pub struct Unwant { pub length: u64, } -impl CompactEncodable for Unwant { +impl CompactEncoding for Unwant { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, start, length)) + Ok(sum_encoded_size!(self.start, self.length)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, start, length)) + Ok(map_encode!(buffer, self.start, self.length)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -745,24 +450,6 @@ impl CompactEncodable for Unwant { } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Unwant) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.length) - } - - fn encode(&mut self, value: &Unwant, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.length, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - Ok(Unwant { start, length }) - } -} - /// Bitfield message. Type 7. #[derive(Debug, Clone, PartialEq)] pub struct Bitfield { @@ -771,13 +458,13 @@ pub struct Bitfield { /// Bitfield in 32 bit chunks beginning from `start` pub bitfield: Vec, } -impl CompactEncodable for Bitfield { +impl CompactEncoding for Bitfield { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, start, bitfield)) + Ok(sum_encoded_size!(self.start, self.bitfield)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, start, bitfield)) + Ok(map_encode!(buffer, self.start, self.bitfield)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -787,23 +474,6 @@ impl CompactEncodable for Bitfield { decode!(Self, buffer, { start: u64, bitfield: Vec }) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Bitfield) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.bitfield) - } - - fn encode(&mut self, value: &Bitfield, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.bitfield, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let bitfield: Vec = self.decode(buffer)?; - Ok(Bitfield { start, bitfield }) - } -} /// Range message. Type 8. /// Notifies Peer's that the Sender has a range of contiguous blocks. @@ -818,9 +488,9 @@ pub struct Range { pub length: u64, } -impl CompactEncodable for Range { +impl CompactEncoding for Range { fn encoded_size(&self) -> Result { - let mut out = 1 + sum_encoded_size!(self, start); + let mut out = 1 + sum_encoded_size!(self.start); if self.length != 1 { out += self.length.encoded_size()?; } @@ -861,44 +531,6 @@ impl CompactEncodable for Range { } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Range) -> Result { - self.add_end(1)?; // flags - self.preencode(&value.start)?; - if value.length != 1 { - self.preencode(&value.length)?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Range, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.drop { 1 } else { 0 }; - flags |= if value.length == 1 { 2 } else { 0 }; - self.encode(&flags, buffer)?; - self.encode(&value.start, buffer)?; - if value.length != 1 { - self.encode(&value.length, buffer)?; - } - Ok(self.end()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.decode(buffer)?; - let start: u64 = self.decode(buffer)?; - let drop = flags & 1 != 0; - let length: u64 = if flags & 2 != 0 { - 1 - } else { - self.decode(buffer)? - }; - Ok(Range { - drop, - length, - start, - }) - } -} - /// Extension message. Type 9. Use this for custom messages in your application. #[derive(Debug, Clone, PartialEq)] pub struct Extension { @@ -907,13 +539,13 @@ pub struct Extension { /// Message content, use empty vector for no data. pub message: Vec, } -impl CompactEncodable for Extension { +impl CompactEncoding for Extension { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, name, message)) + Ok(sum_encoded_size!(self.name, self.message)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, name, message)) + Ok(map_encode!(buffer, self.name, self.message)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -923,20 +555,3 @@ impl CompactEncodable for Extension { decode!(Self, buffer, { name: String, message: Vec }) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Extension) -> Result { - self.preencode(&value.name)?; - self.preencode_raw_buffer(&value.message) - } - - fn encode(&mut self, value: &Extension, buffer: &mut [u8]) -> Result { - self.encode(&value.name, buffer)?; - self.encode_raw_buffer(&value.message, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let name: String = self.decode(buffer)?; - let message: Vec = self.decode_raw_buffer(buffer)?; - Ok(Extension { name, message }) - } -} From 1700d900514abc1097e4a38d9dcae4a9d50a231c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 23 Apr 2025 13:44:51 -0400 Subject: [PATCH 075/135] make test easier to debug --- src/mqueue.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index cd86caf..de15e37 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -161,7 +161,9 @@ mod test { fn new_msg(channel: u64) -> ChannelMessage { ChannelMessage { channel, - message: crate::Message::NoData(NoData { request: channel }), + message: crate::Message::NoData(NoData { + request: channel + 1, + }), } } From 1d511598860605ccbd862bcb01b223f7a09adda1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 23 Apr 2025 13:45:05 -0400 Subject: [PATCH 076/135] lint --- src/schema.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/schema.rs b/src/schema.rs index 8a9a0a2..41ae6aa 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -78,7 +78,7 @@ pub struct Close { impl CompactEncoding for Close { fn encoded_size(&self) -> Result { - Ok(self.channel.encoded_size()?) + self.channel.encoded_size() } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { From a3395c9050f48004d6b0e152bd311a12d05bb837 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 23 Apr 2025 13:50:58 -0400 Subject: [PATCH 077/135] rename old encoder trait to fix name collisio --- src/message/modern.rs | 521 +++++++++--------------------------------- src/mqueue.rs | 2 +- 2 files changed, 106 insertions(+), 417 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index e70bed2..b9b6792 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -21,7 +21,7 @@ pub(crate) trait Encoder: Sized + fmt::Debug { /// Encodes the message to a buffer. /// /// An error will be returned if the buffer does not have sufficient capacity. - fn encode(&self, buf: &mut [u8]) -> Result; + fn encoder_encode(&self, buf: &mut [u8]) -> Result; } impl Encoder for &[u8] { @@ -29,7 +29,7 @@ impl Encoder for &[u8] { Ok(self.len()) } - fn encode(&self, buf: &mut [u8]) -> Result { + fn encoder_encode(&self, buf: &mut [u8]) -> Result { let len = self.encoded_len()?; if len > buf.len() { return Err(EncodingError::new( @@ -199,7 +199,7 @@ impl Encoder for Vec { } #[instrument(skip_all)] - fn encode(&self, buf: &mut [u8]) -> Result { + fn encoder_encode(&self, buf: &mut [u8]) -> Result { let mut state = State::new(); let body_len = prencode_channel_messages(self, &mut state)?; write_uint24_le(body_len, buf); @@ -393,6 +393,17 @@ impl fmt::Debug for ChannelMessage { } } +impl fmt::Display for ChannelMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "ChannelMessage {{ channel {}, message {} }}", + self.channel, + self.message.name() + ) + } +} + impl ChannelMessage { /// Create a new message. pub(crate) fn new(channel: u64, message: Message) -> Self { @@ -409,21 +420,21 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { - if buf.len() <= 5 { + let og_len = buf.len(); + if og_len <= 5 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received too short Open message", )); } - let mut state = State::new_with_start_and_end(0, buf.len()); - let open_msg: Open = state.decode(buf)?; + let (open_msg, buf) = Open::decode(buf)?; Ok(( Self { channel: open_msg.channel, message: Message::Open(open_msg), }, - state.start(), + og_len - buf.len(), )) } @@ -432,86 +443,104 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { + let og_len = buf.len(); if buf.is_empty() { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received too short Close message", )); } - let mut state = State::new_with_start_and_end(0, buf.len()); - let close_msg: Close = state.decode(buf)?; + let (close, buf) = Close::decode(buf)?; Ok(( Self { - channel: close_msg.channel, - message: Message::Close(close_msg), + channel: close.channel, + message: Message::Close(close), }, - state.start(), + og_len - buf.len(), )) } + #[instrument(err, skip_all)] + pub(crate) fn decode_from_channel_and_message( + buf: &[u8], + ) -> Result<(Self, &[u8]), EncodingError> { + //::decode(buf) + let (channel, buf) = u64::decode(buf)?; + let (message, buf) = ::decode(buf)?; + Ok((Self { channel, message }, buf)) + } /// Decode a normal channel message from a buffer. /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it - pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> { + pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { if buf.len() <= 1 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received empty message", )); } - let mut state = State::from_buffer(buf); - let typ: u64 = state.decode(buf)?; - let (message, length) = Message::decode(&buf[state.start()..], typ)?; - Ok((Self { channel, message }, state.start() + length)) + let (message, buf) = ::decode(buf)?; + Ok((Self { channel, message }, buf)) } /// Performance optimization for letting calling encoded_len() already do /// the preencode phase of compact_encoding. - fn prepare_state(&self) -> Result { + fn prepare_state(&self) -> Result { Ok(if let Message::Open(_) = self.message { // Open message doesn't have a type // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state + self.message.encoded_size()? } else if let Message::Close(_) = self.message { // Close message doesn't have a type // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state + self.message.encoded_size()? } else { // The header is the channel id uint followed by message type uint // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let mut state = HypercoreState::new(); let typ = self.message.typ(); - (*state).preencode(&typ)?; - self.message.preencode(&mut state)?; - state + typ.encoded_size()? + self.message.encoded_size()? }) } } +/// NB: currently this is just for a standalone channel message. ChannelMessages in a vec decode & +/// encode differently +impl CompactEncoding for ChannelMessage { + fn encoded_size(&self) -> Result { + Ok(self.channel.encoded_size()? + self.message.encoded_size()?) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = self.channel.encode(buffer)?; + ::encode(&self.message, rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + ChannelMessage::decode_from_channel_and_message(buffer) + } +} impl Encoder for ChannelMessage { fn encoded_len(&self) -> Result { - Ok(self.prepare_state()?.end()) + self.prepare_state() } - fn encode(&self, buf: &mut [u8]) -> Result { - let mut state = self.prepare_state()?; - if let Message::Open(_) = self.message { + #[instrument(skip_all)] + fn encoder_encode(&self, buf: &mut [u8]) -> Result { + let before = buf.len(); + let rest = if let Message::Open(_) = self.message { // Open message is different in that the type byte is missing - self.message.encode(&mut state, buf)?; + ::encode(&self.message, buf)? } else if let Message::Close(_) = self.message { // Close message is different in that the type byte is missing - self.message.encode(&mut state, buf)?; + ::encode(&self.message, buf)? } else { - let typ = self.message.typ(); - state.0.encode(&typ, buf)?; - self.message.encode(&mut state, buf)?; - } - Ok(state.start()) + ::encode(&self.message, buf)? + }; + Ok(before - rest.len()) } } @@ -528,168 +557,49 @@ mod tests { $( let channel = rand::random::() as u64; let channel_message = ChannelMessage::new(channel, $msg); - let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); + let encoded_len = channel_message.encoded_len()?; let mut buf = vec![0u8; encoded_len]; - let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); - let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split(); - assert_eq!(channel, decoded.0); - assert_eq!($msg, decoded.1); + let rest = ::encode(&channel_message, &mut buf)?; + assert!(rest.is_empty()); + let (decoded, rest) = ::decode(&buf)?; + assert!(rest.is_empty()); + assert_eq!(decoded, channel_message); )* } } - /// A frame of data, either a buffer or a message. - #[derive(Clone, PartialEq)] - pub(crate) enum Frame { - /// A raw batch binary buffer. Used in the handshaking phase. - RawBatch(Vec>), - /// Message batch, containing one or more channel messsages. Used for everything after the handshake. - MessageBatch(Vec), - } - - impl fmt::Debug for Frame { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), - Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), - } - } - } - - impl From for Frame { - fn from(m: ChannelMessage) -> Self { - Self::MessageBatch(vec![m]) - } - } - - impl From> for Frame { - fn from(m: Vec) -> Self { - Self::MessageBatch(m) - } - } - - impl From> for Frame { - fn from(m: Vec) -> Self { - Self::RawBatch(vec![m]) - } - } - - impl Frame { - /// Decodes a frame from a buffer containing multiple concurrent messages. - fn preencode(&self, state: &mut State) -> Result { - match self { - Self::RawBatch(raw_batch) => { - for raw in raw_batch { - state.add_end(raw.as_slice().encoded_len()?)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } - } - } - Ok(state.end()) - } - } - - impl Encoder for Frame { - fn encoded_len(&self) -> Result { - let body_len = self.preencode(&mut State::new())?; - match self { - Self::RawBatch(_) => Ok(body_len), - Self::MessageBatch(_) => Ok(3 + body_len), - } - } - - fn encode(&self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; - let body_len = self.preencode(&mut state)?; - let len = body_len + header_len; - if buf.len() < len { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - match self { - Self::RawBatch(ref raw_batch) => { - for raw in raw_batch { - raw.as_slice().encode(buf)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; - } - } - } - }; - Ok(len) - } + #[test] + fn boo() -> Result<(), EncodingError> { + let m = Message::Cancel(Cancel { request: 1 }); + let m = Message::Request(Request { + id: 1, + fork: 1, + block: Some(RequestBlock { + index: 5, + nodes: 10, + }), + hash: Some(RequestBlock { + index: 20, + nodes: 0, + }), + seek: Some(RequestSeek { bytes: 10 }), + upgrade: Some(RequestUpgrade { + start: 0, + length: 10, + }), + }); + let channel = rand::random::() as u64; + let channel_message = ChannelMessage::new(channel, m); + let encoded_len = channel_message.encoded_len()?; + let mut buf = vec![0u8; encoded_len]; + let rest = ::encode(&channel_message, &mut buf)?; + assert!(rest.is_empty()); + let (decoded, rest) = ::decode(&buf)?; + assert!(rest.is_empty()); + assert_eq!(decoded, channel_message); + Ok(()) } - #[test] - fn message_encode_decode() { + fn message_encode_decode() -> Result<(), EncodingError> { message_enc_dec! { Message::Synchronize(Synchronize{ fork: 0, @@ -770,227 +680,6 @@ mod tests { message: vec![0x44, 20] }) }; - } - - fn message_test_data() -> Vec { - vec![ - Message::Synchronize(Synchronize { - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }), - Message::Request(Request { - id: 1, - fork: 1, - block: Some(RequestBlock { - index: 5, - nodes: 10, - }), - hash: Some(RequestBlock { - index: 20, - nodes: 0, - }), - seek: Some(RequestSeek { bytes: 10 }), - upgrade: Some(RequestUpgrade { - start: 0, - length: 10, - }), - }), - Message::Cancel(Cancel { request: 1 }), - Message::Data(Data { - request: 1, - fork: 5, - block: Some(DataBlock { - index: 5, - nodes: vec![Node::new(1, vec![0x01; 32], 100)], - value: vec![0xFF; 10], - }), - hash: Some(DataHash { - index: 20, - nodes: vec![Node::new(2, vec![0x02; 32], 200)], - }), - seek: Some(DataSeek { - bytes: 10, - nodes: vec![Node::new(3, vec![0x03; 32], 300)], - }), - upgrade: Some(DataUpgrade { - start: 0, - length: 10, - nodes: vec![Node::new(4, vec![0x04; 32], 400)], - additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)], - signature: vec![0xAB; 32], - }), - }), - Message::NoData(NoData { request: 2 }), - Message::Want(Want { - start: 0, - length: 100, - }), - Message::Unwant(Unwant { - start: 10, - length: 2, - }), - Message::Bitfield(Bitfield { - start: 20, - bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF], - }), - Message::Range(Range { - drop: true, - start: 12345, - length: 100000, - }), - Message::Extension(Extension { - name: "custom_extension/v1/open".to_string(), - message: vec![0x44, 20], - }), - ] - } - - impl Frame { - pub(crate) fn decode_multiple(buf: &[u8]) -> Result { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - let (frame, length) = Self::decode_message( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - tracing::warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, - length - ); - } - if let Frame::MessageBatch(messages) = frame { - for message in messages { - combined_messages.push(message); - } - } else { - unreachable!("Can not get Raw messages"); - } - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } - } - Ok(Frame::MessageBatch(combined_messages)) - } - - fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - // buffer length >= 3 or more and starts with 0 is message batch - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } - } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = - ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) - } - } else if buf.len() >= 2 { - // len >= and - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = - ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), - )) - } - } - } - - #[test] - fn compare_with_frame_encoding_decoding() -> std::io::Result<()> { - let channel = 42; - for msg in message_test_data() { - let channel_message = ChannelMessage::new(channel, msg); - let frame = Frame::from(channel_message.clone()); - let cmvec = vec![channel_message.clone()]; - - let mut fbuf = vec![0; frame.encoded_len()?]; - let mut cbuf = vec![0; cmvec.encoded_len()?]; - - assert_eq!(cbuf, fbuf); - - frame.encode(&mut fbuf)?; - cmvec.encode(&mut cbuf)?; - - assert_eq!(cbuf, fbuf); - - let fres = Frame::decode_multiple(&fbuf)?; - assert_eq!(fres, frame); - let cres_m = decode_framed_channel_messages(&cbuf)?.0; - assert_eq!(cres_m, cmvec); - } Ok(()) } } diff --git a/src/mqueue.rs b/src/mqueue.rs index de15e37..9be4ab7 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -82,7 +82,7 @@ impl + Sink> + Send + Unpin + 'static> Mes } let mut buf = vec![0; messages.encoded_len()?]; - match messages.encode(&mut buf) { + match messages.encoder_encode(&mut buf) { Ok(_) => {} Err(e) => { error!(error = ?e, "error encoding messages"); From a9bab5696148048295829f8a530265fe0f9d6c6c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 23 Apr 2025 13:51:56 -0400 Subject: [PATCH 078/135] cleaning up messages --- src/message/modern.rs | 333 ++++++++++++++++++++++-------------------- 1 file changed, 175 insertions(+), 158 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index b9b6792..05e37be 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -1,12 +1,12 @@ use crate::schema::*; use crate::util::{stat_uint24_le, write_uint24_le}; use hypercore::encoding::{ - CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State, + decode_usize, take_array, write_array, CompactEncoding, EncodingError, EncodingErrorKind, }; use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::instrument; +use tracing::{instrument, trace}; const UINT24_HEADER_LEN: usize = 3; @@ -86,44 +86,41 @@ pub(crate) fn decode_framed_channel_messages( pub(crate) fn decode_unframed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { - if buf.len() >= 3 && buf[0] == 0x00 { + let og_len = buf.len(); + if og_len >= 3 && buf[0] == 0x00 { if buf[1] == 0x00 { + let (_, mut buf) = take_array::<2>(buf)?; // Batch of messages let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { + let mut current_channel; + (current_channel, buf) = u64::decode(buf)?; + while !buf.is_empty() { // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { + let channel_message_length; + (channel_message_length, buf) = decode_usize(buf)?; + if channel_message_length > buf.len() { return Err(io::Error::new( io::ErrorKind::InvalidData, format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() + "received invalid message length: [{channel_message_length}] but we have [{}] remaining bytes. Initial buffer size [{og_len}]", + buf.len() ), )); } // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; + let channel_message; + (channel_message, buf) = ChannelMessage::decode(buf, current_channel)?; messages.push(channel_message); - state.add_start(channel_message_length)?; // After that, if there is an extra 0x00, that means the channel // changed. This works because of LE encoding, and channels starting // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; + if !buf.is_empty() && buf[0] == 0x00 { + (current_channel, buf) = u64::decode(buf)?; } } - Ok((messages, state.start())) + Ok((messages, og_len - buf.len())) } else if buf[1] == 0x01 { // Open message let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; @@ -139,11 +136,11 @@ pub(crate) fn decode_unframed_channel_messages( )) } } else if buf.len() >= 2 { + trace!("Decoding single ChannelMessage"); // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok((vec![channel_message], state.start() + length)) + let og_len = buf.len(); + let (channel_message, buf) = ChannelMessage::decode_from_channel_and_message(buf)?; + Ok((vec![channel_message], og_len - buf.len())) } else { Err(io::Error::new( io::ErrorKind::InvalidData, @@ -152,92 +149,84 @@ pub(crate) fn decode_unframed_channel_messages( } } -fn prencode_channel_messages( - messages: &[ChannelMessage], - state: &mut State, -) -> Result { - match messages.len().cmp(&1) { - std::cmp::Ordering::Less => {} +fn prencode_channel_messages(messages: &[ChannelMessage]) -> Result { + Ok(match messages.len().cmp(&1) { + std::cmp::Ordering::Less => 0, std::cmp::Ordering::Equal => { if let Message::Open(_) = &messages[0].message { // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; + 2 + &messages[0].encoded_len()? } else if let Message::Close(_) = &messages[0].message { // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; + 2 + &messages[0].encoded_len()? } else { - state.preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; + messages[0].channel.encoded_size()? + messages[0].encoded_size()? } } std::cmp::Ordering::Greater => { // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; + let mut out = 2; let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; + out += current_channel.encoded_size()?; for message in messages.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel - state.add_end(1)?; - state.preencode(&message.channel)?; + out += 1 + message.channel.encoded_size()?; current_channel = message.channel; } let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; + out += message.encoded_size()? + message_length; } + out } - }; - Ok(state.end()) + }) } impl Encoder for Vec { fn encoded_len(&self) -> Result { - let mut state = State::new(); - Ok(prencode_channel_messages(self, &mut state)? + UINT24_HEADER_LEN) + Ok(prencode_channel_messages(self)? + UINT24_HEADER_LEN) } #[instrument(skip_all)] fn encoder_encode(&self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let body_len = prencode_channel_messages(self, &mut state)?; - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + let body_len = prencode_channel_messages(self)?; + let mut u24_bytes = [0, 0, 0]; + write_uint24_le(body_len, u24_bytes.as_mut_slice()); + let mut buf = write_array(&u24_bytes, buf)?; + // skip the u24 we just wrote match self.len().cmp(&1) { std::cmp::Ordering::Less => {} std::cmp::Ordering::Equal => { + trace!("Encoding single ChannelMessage {}", self[0]); if let Message::Open(_) = &self[0].message { // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + buf = write_array(&[0, 1], buf)?; + self[0].encode(buf)?; } else if let Message::Close(_) = &self[0].message { // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + buf = write_array(&[0, 3], buf)?; + self[0].encode(buf)?; } else { - state.encode(&self[0].channel, buf)?; - state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + self[0].encode(buf)?; } } std::cmp::Ordering::Greater => { // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + buf = write_array(&[0, 0], buf)?; let mut current_channel: u64 = self[0].channel; - state.encode(¤t_channel, buf)?; + buf = current_channel.encode(buf)?; for message in self.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; + buf = write_array(&[0], buf)?; + buf = message.channel.encode(buf)?; current_channel = message.channel; } let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; + buf = (message_length as u32).encode(buf)?; + buf = message.encode(buf)?; } } } @@ -265,6 +254,113 @@ pub enum Message { LocalSignal((String, Vec)), } +macro_rules! message_from { + ($($val:ident),+) => { + $( + impl From<$val> for Message { + fn from(value: $val) -> Self { + Message::$val(value) + } + } + )* + } +} +message_from!( + Open, + Close, + Synchronize, + Request, + Cancel, + Data, + NoData, + Want, + Unwant, + Bitfield, + Range, + Extension +); + +macro_rules! decode_message { + ($type:ty, $buf:expr) => {{ + let (x, rest) = <$type>::decode($buf)?; + (Message::from(x), rest) + }}; +} + +impl CompactEncoding for Message { + fn encoded_size(&self) -> Result { + let typ_size = if let Self::Open(_) | Self::Close(_) = &self { + 0 + } else { + self.typ().encoded_size()? + }; + let msg_size = match self { + Self::LocalSignal(_) => Ok(0), + Self::Open(x) => x.encoded_size(), + Self::Close(x) => x.encoded_size(), + Self::Synchronize(x) => x.encoded_size(), + Self::Request(x) => x.encoded_size(), + Self::Cancel(x) => x.encoded_size(), + Self::Data(x) => x.encoded_size(), + Self::NoData(x) => x.encoded_size(), + Self::Want(x) => x.encoded_size(), + Self::Unwant(x) => x.encoded_size(), + Self::Bitfield(x) => x.encoded_size(), + Self::Range(x) => x.encoded_size(), + Self::Extension(x) => x.encoded_size(), + }?; + Ok(typ_size + msg_size) + } + + #[instrument(skip_all, fields(name = self.name()))] + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = if let Self::Open(_) | Self::Close(_) = &self { + buffer + } else { + self.typ().encode(buffer)? + }; + match self { + Self::Open(x) => x.encode(rest), + Self::Close(x) => x.encode(rest), + Self::Synchronize(x) => x.encode(rest), + Self::Request(x) => x.encode(rest), + Self::Cancel(x) => x.encode(rest), + Self::Data(x) => x.encode(rest), + Self::NoData(x) => x.encode(rest), + Self::Want(x) => x.encode(rest), + Self::Unwant(x) => x.encode(rest), + Self::Bitfield(x) => x.encode(rest), + Self::Range(x) => x.encode(rest), + Self::Extension(x) => x.encode(rest), + Self::LocalSignal(_) => unimplemented!("do not encode LocalSignal"), + } + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (typ, rest) = u64::decode(buffer)?; + Ok(match typ { + 0 => decode_message!(Synchronize, rest), + 1 => decode_message!(Request, rest), + 2 => decode_message!(Cancel, rest), + 3 => decode_message!(Data, rest), + 4 => decode_message!(NoData, rest), + 5 => decode_message!(Want, rest), + 6 => decode_message!(Unwant, rest), + 7 => decode_message!(Bitfield, rest), + 8 => decode_message!(Range, rest), + 9 => decode_message!(Extension, rest), + _ => { + return Err(EncodingError::new( + EncodingErrorKind::InvalidData, + &format!("Invalid message type to decode: {typ}"), + )) + } + }) + } +} impl Message { /// Wire type of this message. pub(crate) fn typ(&self) -> u64 { @@ -282,71 +378,23 @@ impl Message { value => unimplemented!("{} does not have a type", value), } } - - /// Decode a message from a buffer based on type. - pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> { - let mut state = HypercoreState::from_buffer(buf); - let message = match typ { - 0 => Ok(Self::Synchronize((*state).decode(buf)?)), - 1 => Ok(Self::Request(state.decode(buf)?)), - 2 => Ok(Self::Cancel((*state).decode(buf)?)), - 3 => Ok(Self::Data(state.decode(buf)?)), - 4 => Ok(Self::NoData((*state).decode(buf)?)), - 5 => Ok(Self::Want((*state).decode(buf)?)), - 6 => Ok(Self::Unwant((*state).decode(buf)?)), - 7 => Ok(Self::Bitfield((*state).decode(buf)?)), - 8 => Ok(Self::Range((*state).decode(buf)?)), - 9 => Ok(Self::Extension((*state).decode(buf)?)), - _ => Err(EncodingError::new( - EncodingErrorKind::InvalidData, - &format!("Invalid message type to decode: {typ}"), - )), - }?; - Ok((message, state.start())) - } - - /// Pre-encodes a message to state, returns length - pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result { + /// Get the name of the message + pub fn name(&self) -> &'static str { match self { - Self::Open(ref message) => state.0.preencode(message)?, - Self::Close(ref message) => state.0.preencode(message)?, - Self::Synchronize(ref message) => state.0.preencode(message)?, - Self::Request(ref message) => state.preencode(message)?, - Self::Cancel(ref message) => state.0.preencode(message)?, - Self::Data(ref message) => state.preencode(message)?, - Self::NoData(ref message) => state.0.preencode(message)?, - Self::Want(ref message) => state.0.preencode(message)?, - Self::Unwant(ref message) => state.0.preencode(message)?, - Self::Bitfield(ref message) => state.0.preencode(message)?, - Self::Range(ref message) => state.0.preencode(message)?, - Self::Extension(ref message) => state.0.preencode(message)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.end()) - } - - /// Encodes a message to a given buffer, using preencoded state, results size - pub(crate) fn encode( - &self, - state: &mut HypercoreState, - buf: &mut [u8], - ) -> Result { - match self { - Self::Open(ref message) => state.0.encode(message, buf)?, - Self::Close(ref message) => state.0.encode(message, buf)?, - Self::Synchronize(ref message) => state.0.encode(message, buf)?, - Self::Request(ref message) => state.encode(message, buf)?, - Self::Cancel(ref message) => state.0.encode(message, buf)?, - Self::Data(ref message) => state.encode(message, buf)?, - Self::NoData(ref message) => state.0.encode(message, buf)?, - Self::Want(ref message) => state.0.encode(message, buf)?, - Self::Unwant(ref message) => state.0.encode(message, buf)?, - Self::Bitfield(ref message) => state.0.encode(message, buf)?, - Self::Range(ref message) => state.0.encode(message, buf)?, - Self::Extension(ref message) => state.0.encode(message, buf)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.start()) + Message::Open(_) => "Open", + Message::Close(_) => "Close", + Message::Synchronize(_) => "Synchronize", + Message::Request(_) => "Request", + Message::Cancel(_) => "Cancel", + Message::Data(_) => "Data", + Message::NoData(_) => "NoData", + Message::Want(_) => "Want", + Message::Unwant(_) => "Unwant", + Message::Bitfield(_) => "Bitfield", + Message::Range(_) => "Range", + Message::Extension(_) => "Extension", + Message::LocalSignal(_) => "LocalSignal", + } } } @@ -568,37 +616,6 @@ mod tests { } } #[test] - fn boo() -> Result<(), EncodingError> { - let m = Message::Cancel(Cancel { request: 1 }); - let m = Message::Request(Request { - id: 1, - fork: 1, - block: Some(RequestBlock { - index: 5, - nodes: 10, - }), - hash: Some(RequestBlock { - index: 20, - nodes: 0, - }), - seek: Some(RequestSeek { bytes: 10 }), - upgrade: Some(RequestUpgrade { - start: 0, - length: 10, - }), - }); - let channel = rand::random::() as u64; - let channel_message = ChannelMessage::new(channel, m); - let encoded_len = channel_message.encoded_len()?; - let mut buf = vec![0u8; encoded_len]; - let rest = ::encode(&channel_message, &mut buf)?; - assert!(rest.is_empty()); - let (decoded, rest) = ::decode(&buf)?; - assert!(rest.is_empty()); - assert_eq!(decoded, channel_message); - Ok(()) - } - #[test] fn message_encode_decode() -> Result<(), EncodingError> { message_enc_dec! { Message::Synchronize(Synchronize{ From 80c29ee629a34177fa0577726ed5baddb86b9e1c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 24 Apr 2025 14:19:09 -0400 Subject: [PATCH 079/135] removing Encoder trait --- src/message/modern.rs | 74 ++++++++++++++----------------------------- 1 file changed, 24 insertions(+), 50 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index 05e37be..d3b8551 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -2,11 +2,12 @@ use crate::schema::*; use crate::util::{stat_uint24_le, write_uint24_le}; use hypercore::encoding::{ decode_usize, take_array, write_array, CompactEncoding, EncodingError, EncodingErrorKind, + VecEncodable, }; use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::{instrument, trace}; +use tracing::{instrument, trace, warn}; const UINT24_HEADER_LEN: usize = 3; @@ -24,24 +25,6 @@ pub(crate) trait Encoder: Sized + fmt::Debug { fn encoder_encode(&self, buf: &mut [u8]) -> Result; } -impl Encoder for &[u8] { - fn encoded_len(&self) -> Result { - Ok(self.len()) - } - - fn encoder_encode(&self, buf: &mut [u8]) -> Result { - let len = self.encoded_len()?; - if len > buf.len() { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - buf[..len].copy_from_slice(&self[..]); - Ok(len) - } -} - pub(crate) fn decode_framed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { @@ -61,12 +44,11 @@ pub(crate) fn decode_framed_channel_messages( &buf[index + header_len..index + header_len + body_len as usize], )?; if length != body_len as usize { - tracing::warn!( + warn!( "Did not know what to do with all the bytes, got {} but decoded {}. \ This may be because the peer implements a newer protocol version \ that has extra fields.", - body_len, - length + body_len, length ); } for message in msgs { @@ -150,32 +132,24 @@ pub(crate) fn decode_unframed_channel_messages( } fn prencode_channel_messages(messages: &[ChannelMessage]) -> Result { - Ok(match messages.len().cmp(&1) { - std::cmp::Ordering::Less => 0, - std::cmp::Ordering::Equal => { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - 2 + &messages[0].encoded_len()? - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - 2 + &messages[0].encoded_len()? - } else { - messages[0].channel.encoded_size()? + messages[0].encoded_size()? - } - } - std::cmp::Ordering::Greater => { - // Two intro bytes 0x00 0x00, then channel id, then lengths + Ok(match messages { + [] => 0, + [msg] => match msg.message { + Message::Open(_) | Message::Close(_) => 2 + msg.encoded_size()?, + _ => msg.encoded_size()?, + }, + msgs => { let mut out = 2; let mut current_channel: u64 = messages[0].channel; out += current_channel.encoded_size()?; - for message in messages.iter() { + for message in msgs.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel out += 1 + message.channel.encoded_size()?; current_channel = message.channel; } - let message_length = message.encoded_len()?; + let message_length = message.message.encoded_size()?; out += message.encoded_size()? + message_length; } out @@ -224,7 +198,7 @@ impl Encoder for Vec { buf = message.channel.encode(buf)?; current_channel = message.channel; } - let message_length = message.encoded_len()?; + let message_length = message.message.encoded_size()?; buf = (message_length as u32).encode(buf)?; buf = message.encode(buf)?; } @@ -535,19 +509,17 @@ impl ChannelMessage { /// Performance optimization for letting calling encoded_len() already do /// the preencode phase of compact_encoding. fn prepare_state(&self) -> Result { - Ok(if let Message::Open(_) = self.message { - // Open message doesn't have a type + Ok(match self.message { + // Open & Close message doesn't have a type byte // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - self.message.encoded_size()? - } else if let Message::Close(_) = self.message { - // Close message doesn't have a type // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - self.message.encoded_size()? - } else { + Message::Open(_) | Message::Close(_) => self.message.encoded_size()?, // The header is the channel id uint followed by message type uint // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let typ = self.message.typ(); - typ.encoded_size()? + self.message.encoded_size()? + _ => { + let typ = self.message.typ(); + typ.encoded_size()? + self.message.encoded_size()? + } }) } } @@ -597,7 +569,8 @@ mod tests { use super::*; use hypercore::{ - DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, + encoding::to_encoded_bytes, DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, + RequestSeek, RequestUpgrade, }; macro_rules! message_enc_dec { @@ -615,6 +588,7 @@ mod tests { )* } } + #[test] fn message_encode_decode() -> Result<(), EncodingError> { message_enc_dec! { From a57a885f39c0568b5c49eb17ceef5d18246603ed Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 24 Apr 2025 16:05:12 -0400 Subject: [PATCH 080/135] wip impl VecEncodable for CompactEncoding --- src/message/modern.rs | 136 +++++++++++++++++++++++++++++------------- src/util.rs | 1 + 2 files changed, 94 insertions(+), 43 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index d3b8551..bdc6164 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -1,8 +1,8 @@ use crate::schema::*; use crate::util::{stat_uint24_le, write_uint24_le}; use hypercore::encoding::{ - decode_usize, take_array, write_array, CompactEncoding, EncodingError, EncodingErrorKind, - VecEncodable, + decode_usize, take_array, take_array_mut, write_array, CompactEncoding, EncodingError, + EncodingErrorKind, VecEncodable, }; use pretty_hash::fmt as pretty_fmt; use std::fmt; @@ -10,6 +10,10 @@ use std::io; use tracing::{instrument, trace, warn}; const UINT24_HEADER_LEN: usize = 3; +const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; +const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; +const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; +const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; /// Encode data into a buffer. /// @@ -64,12 +68,12 @@ pub(crate) fn decode_framed_channel_messages( } Ok((combined_messages, index)) } -// bad name bc it returns many. More like, decode unframed channel messages pub(crate) fn decode_unframed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { let og_len = buf.len(); if og_len >= 3 && buf[0] == 0x00 { + // batch of NOT open/close messages if buf[1] == 0x00 { let (_, mut buf) = take_array::<2>(buf)?; // Batch of messages @@ -157,6 +161,19 @@ fn prencode_channel_messages(messages: &[ChannelMessage]) -> Result Result<&mut [u8], EncodingError> { + let (header, rest) = take_array_mut::(buf)?; + write_uint24_le(n, header); + Ok(rest) +} + +/// decode a u24 from `buffer` as a `usize` +fn decode_u24(buffer: &[u8]) -> Result<(usize, &[u8]), EncodingError> { + let (u24_bytes, rest) = take_array::(buffer)?; + let (_, out) = stat_uint24_le(&u24_bytes).expect("input garunteed to be long enough"); + Ok((out as usize, rest)) +} + impl Encoder for Vec { fn encoded_len(&self) -> Result { Ok(prencode_channel_messages(self)? + UINT24_HEADER_LEN) @@ -165,9 +182,7 @@ impl Encoder for Vec { #[instrument(skip_all)] fn encoder_encode(&self, buf: &mut [u8]) -> Result { let body_len = prencode_channel_messages(self)?; - let mut u24_bytes = [0, 0, 0]; - write_uint24_le(body_len, u24_bytes.as_mut_slice()); - let mut buf = write_array(&u24_bytes, buf)?; + let mut buf = checked_write_uint24_le(body_len, buf)?; // skip the u24 we just wrote match self.len().cmp(&1) { std::cmp::Ordering::Less => {} @@ -505,23 +520,6 @@ impl ChannelMessage { let (message, buf) = ::decode(buf)?; Ok((Self { channel, message }, buf)) } - - /// Performance optimization for letting calling encoded_len() already do - /// the preencode phase of compact_encoding. - fn prepare_state(&self) -> Result { - Ok(match self.message { - // Open & Close message doesn't have a type byte - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - Message::Open(_) | Message::Close(_) => self.message.encoded_size()?, - // The header is the channel id uint followed by message type uint - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - _ => { - let typ = self.message.typ(); - typ.encoded_size()? + self.message.encoded_size()? - } - }) - } } /// NB: currently this is just for a standalone channel message. ChannelMessages in a vec decode & @@ -543,24 +541,77 @@ impl CompactEncoding for ChannelMessage { ChannelMessage::decode_from_channel_and_message(buffer) } } -impl Encoder for ChannelMessage { - fn encoded_len(&self) -> Result { - self.prepare_state() + +impl VecEncodable for ChannelMessage { + fn vec_encoded_size(vec: &[Self]) -> Result + where + Self: Sized, + { + Ok(prencode_channel_messages(vec)? + UINT24_HEADER_LEN) } - #[instrument(skip_all)] - fn encoder_encode(&self, buf: &mut [u8]) -> Result { - let before = buf.len(); - let rest = if let Message::Open(_) = self.message { - // Open message is different in that the type byte is missing - ::encode(&self.message, buf)? - } else if let Message::Close(_) = self.message { - // Close message is different in that the type byte is missing - ::encode(&self.message, buf)? - } else { - ::encode(&self.message, buf)? - }; - Ok(before - rest.len()) + fn vec_encode<'a>(vec: &[Self], buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> + where + Self: Sized, + { + let body_len = prencode_channel_messages(&vec)?; + let mut buffer = checked_write_uint24_le(body_len, buffer)?; + match vec { + [] => Ok(buffer), + [msg] => { + buffer = match msg.message { + Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, buffer)?, + Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, buffer)?, + _ => msg.channel.encode(buffer)?, + }; + msg.message.encode(buffer) + } + msgs => { + buffer = write_array(&MULTI_MESSAGE_PREFIX, buffer)?; + let mut current_channel: u64 = msgs[0].channel; + buffer = current_channel.encode(buffer)?; + for msg in msgs { + if msg.channel != current_channel { + buffer = write_array(&CHANNEL_CHANGE_SEPERATOR, buffer)?; + buffer = msg.channel.encode(buffer)?; + current_channel = msg.channel; + } + let msg_len = msg.message.encoded_size()?; + buffer = (msg_len as u32).encode(buffer)?; + buffer = msg.message.encode(buffer)?; + } + Ok(buffer) + } + } + } + + fn vec_decode(buffer: &[u8]) -> Result<(Vec, &[u8]), EncodingError> + where + Self: Sized, + { + let mut index = 0; + let mut combined_messages: Vec = vec![]; + while index < buffer.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buffer[index] == 0 { + index += 1; + continue; + } + let (frame_len, next_frame_start) = decode_u24(&buffer[index..])?; + let (msgs, length) = decode_unframed_channel_messages(&next_frame_start[..frame_len]) + .map_err(|e| EncodingError::external(&format!("{e}")))?; + if length != frame_len { + warn!( + "Did not know what to do with all the bytes, got {frame_len} but decoded {length}. \ + This may be because the peer implements a newer protocol version \ + that has extra fields.", + ); + } + combined_messages.extend(msgs); + index += UINT24_HEADER_LEN + frame_len; + } + todo!() } } @@ -569,8 +620,7 @@ mod tests { use super::*; use hypercore::{ - encoding::to_encoded_bytes, DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, - RequestSeek, RequestUpgrade, + DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, }; macro_rules! message_enc_dec { @@ -578,8 +628,8 @@ mod tests { $( let channel = rand::random::() as u64; let channel_message = ChannelMessage::new(channel, $msg); - let encoded_len = channel_message.encoded_len()?; - let mut buf = vec![0u8; encoded_len]; + let encoded_size = channel_message.encoded_size()?; + let mut buf = vec![0u8; encoded_size]; let rest = ::encode(&channel_message, &mut buf)?; assert!(rest.is_empty()); let (decoded, rest) = ::decode(&buf)?; diff --git a/src/util.rs b/src/util.rs index 21e4c75..579a0fd 100644 --- a/src/util.rs +++ b/src/util.rs @@ -73,6 +73,7 @@ pub(crate) fn write_uint24_le(n: usize, buf: &mut [u8]) { } #[inline] +/// Read uint24 from the given `buffer` as a `u64` pub(crate) fn stat_uint24_le(buffer: &[u8]) -> Option<(usize, u64)> { if buffer.len() >= 3 { let len = From 564c0060c6fbcdd52b59108166562a3535bab211 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:28:17 -0400 Subject: [PATCH 081/135] Only use ChanMsg::channel when not Open & Close --- src/message/modern.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index bdc6164..ae0dd22 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -526,11 +526,21 @@ impl ChannelMessage { /// encode differently impl CompactEncoding for ChannelMessage { fn encoded_size(&self) -> Result { - Ok(self.channel.encoded_size()? + self.message.encoded_size()?) + let channel_size = if let Message::Open(_) | Message::Close(_) = &self.message { + 0 + } else { + self.channel.encoded_size()? + }; + + Ok(channel_size + self.message.encoded_size()?) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - let rest = self.channel.encode(buffer)?; + let rest = if let Message::Open(_) | Message::Close(_) = &self.message { + buffer + } else { + self.channel.encode(buffer)? + }; ::encode(&self.message, rest) } From 861ee29bd60254128681f7860d162e88880a604f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:28:55 -0400 Subject: [PATCH 082/135] Add #[instrument] --- src/message/modern.rs | 1 + src/schema.rs | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/message/modern.rs b/src/message/modern.rs index ae0dd22..23524b2 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -456,6 +456,7 @@ impl ChannelMessage { /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it + #[instrument(skip_all, err)] pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { let og_len = buf.len(); if og_len <= 5 { diff --git a/src/schema.rs b/src/schema.rs index 41ae6aa..c58a40b 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -6,6 +6,7 @@ use hypercore::{ decode, DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, RequestUpgrade, }; +use tracing::instrument; /// Open message #[derive(Debug, Clone, PartialEq)] @@ -21,6 +22,7 @@ pub struct Open { } impl CompactEncoding for Open { + #[instrument(skip_all, ret, err)] fn encoded_size(&self) -> Result { let out = sum_encoded_size!(self.channel, self.protocol, self.discovery_key); if self.capability.is_some() { @@ -33,6 +35,7 @@ impl CompactEncoding for Open { Ok(out) } + #[instrument(skip_all)] fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let rest = map_encode!(buffer, self.channel, self.protocol, self.discovery_key); if let Some(cap) = &self.capability { @@ -42,6 +45,7 @@ impl CompactEncoding for Open { Ok(rest) } + #[instrument(skip_all, err)] fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> where Self: Sized, From a62ce55b58cfef5000e88f180930d66bb2a63651 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:33:38 -0400 Subject: [PATCH 083/135] rm old stuff --- src/message/old.rs | 813 -------------------------------------------- src/protocol/old.rs | 706 -------------------------------------- src/reader.rs | 246 -------------- src/writer.rs | 198 ----------- 4 files changed, 1963 deletions(-) delete mode 100644 src/message/old.rs delete mode 100644 src/protocol/old.rs delete mode 100644 src/reader.rs delete mode 100644 src/writer.rs diff --git a/src/message/old.rs b/src/message/old.rs deleted file mode 100644 index d4afd64..0000000 --- a/src/message/old.rs +++ /dev/null @@ -1,813 +0,0 @@ -use crate::schema::*; -use crate::util::{stat_uint24_le, write_uint24_le}; -use hypercore::encoding::{ - CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State, -}; -use pretty_hash::fmt as pretty_fmt; -use std::fmt; -use std::io; - -/// The type of a data frame. -#[derive(Debug, Clone, PartialEq)] -pub(crate) enum FrameType { - Raw, - Message, -} - -/// Encode data into a buffer. -/// -/// This trait is implemented on data frames and their components -/// (channel messages, messages, and individual message types through prost). -pub(crate) trait Encoder: Sized + fmt::Debug { - /// Calculates the length that the encoded message needs. - fn encoded_len(&mut self) -> Result; - - /// Encodes the message to a buffer. - /// - /// An error will be returned if the buffer does not have sufficient capacity. - fn encode(&mut self, buf: &mut [u8]) -> Result; -} - -impl Encoder for &[u8] { - fn encoded_len(&mut self) -> Result { - Ok(self.len()) - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - let len = self.encoded_len()?; - if len > buf.len() { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - buf[..len].copy_from_slice(&self[..]); - Ok(len) - } -} - -/// A frame of data, either a buffer or a message. -#[derive(Clone, PartialEq)] -pub(crate) enum Frame { - /// A raw batch binary buffer. Used in the handshaking phase. - RawBatch(Vec>), - /// Message batch, containing one or more channel messsages. Used for everything after the handshake. - MessageBatch(Vec), -} - -impl fmt::Debug for Frame { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), - Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), - } - } -} - -impl From for Frame { - fn from(m: ChannelMessage) -> Self { - Self::MessageBatch(vec![m]) - } -} - -impl From> for Frame { - fn from(m: Vec) -> Self { - Self::RawBatch(vec![m]) - } -} - -impl Frame { - /// Decodes a frame from a buffer containing multiple concurrent messages. - pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Raw => { - let mut index = 0; - let mut raw_batch: Vec> = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - raw_batch.push( - buf[index + header_len..index + header_len + body_len as usize] - .to_vec(), - ); - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in raw batch", - )); - } - } - Ok(Frame::RawBatch(raw_batch)) - } - FrameType::Message => { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - let (frame, length) = Self::decode_message( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - tracing::warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, - length - ); - } - if let Frame::MessageBatch(messages) = frame { - for message in messages { - combined_messages.push(message); - } - } else { - unreachable!("Can not get Raw messages"); - } - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } - } - Ok(Frame::MessageBatch(combined_messages)) - } - } - } - - /// Decode a frame from a buffer. - pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])), - FrameType::Message => { - let (frame, _) = Self::decode_message(buf)?; - Ok(frame) - } - } - } - - fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - // buffer length >= 3 or more and starts with 0 is message batch - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } - } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) - } - } else if buf.len() >= 2 { - // len >= and - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), - )) - } - } - - fn preencode(&mut self, state: &mut State) -> Result { - match self { - Self::RawBatch(raw_batch) => { - for raw in raw_batch { - state.add_end(raw.as_slice().encoded_len()?)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter_mut() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } - } - } - Ok(state.end()) - } -} - -impl Encoder for Frame { - fn encoded_len(&mut self) -> Result { - let body_len = self.preencode(&mut State::new())?; - match self { - Self::RawBatch(_) => Ok(body_len), - Self::MessageBatch(_) => Ok(3 + body_len), - } - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; - let body_len = self.preencode(&mut state)?; - let len = body_len + header_len; - if buf.len() < len { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - match self { - Self::RawBatch(ref raw_batch) => { - for raw in raw_batch { - raw.as_slice().encode(buf)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref mut messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter_mut() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; - } - } - } - }; - Ok(len) - } -} - -/// A protocol message. -#[derive(Debug, Clone, PartialEq)] -#[allow(missing_docs)] -pub enum Message { - Open(Open), - Close(Close), - Synchronize(Synchronize), - Request(Request), - Cancel(Cancel), - Data(Data), - NoData(NoData), - Want(Want), - Unwant(Unwant), - Bitfield(Bitfield), - Range(Range), - Extension(Extension), - /// A local signalling message never sent over the wire - LocalSignal((String, Vec)), -} - -impl Message { - /// Wire type of this message. - pub(crate) fn typ(&self) -> u64 { - match self { - Self::Synchronize(_) => 0, - Self::Request(_) => 1, - Self::Cancel(_) => 2, - Self::Data(_) => 3, - Self::NoData(_) => 4, - Self::Want(_) => 5, - Self::Unwant(_) => 6, - Self::Bitfield(_) => 7, - Self::Range(_) => 8, - Self::Extension(_) => 9, - value => unimplemented!("{} does not have a type", value), - } - } - - /// Decode a message from a buffer based on type. - pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> { - let mut state = HypercoreState::from_buffer(buf); - let message = match typ { - 0 => Ok(Self::Synchronize((*state).decode(buf)?)), - 1 => Ok(Self::Request(state.decode(buf)?)), - 2 => Ok(Self::Cancel((*state).decode(buf)?)), - 3 => Ok(Self::Data(state.decode(buf)?)), - 4 => Ok(Self::NoData((*state).decode(buf)?)), - 5 => Ok(Self::Want((*state).decode(buf)?)), - 6 => Ok(Self::Unwant((*state).decode(buf)?)), - 7 => Ok(Self::Bitfield((*state).decode(buf)?)), - 8 => Ok(Self::Range((*state).decode(buf)?)), - 9 => Ok(Self::Extension((*state).decode(buf)?)), - _ => Err(EncodingError::new( - EncodingErrorKind::InvalidData, - &format!("Invalid message type to decode: {typ}"), - )), - }?; - Ok((message, state.start())) - } - - /// Pre-encodes a message to state, returns length - pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result { - match self { - Self::Open(ref message) => state.0.preencode(message)?, - Self::Close(ref message) => state.0.preencode(message)?, - Self::Synchronize(ref message) => state.0.preencode(message)?, - Self::Request(ref message) => state.preencode(message)?, - Self::Cancel(ref message) => state.0.preencode(message)?, - Self::Data(ref message) => state.preencode(message)?, - Self::NoData(ref message) => state.0.preencode(message)?, - Self::Want(ref message) => state.0.preencode(message)?, - Self::Unwant(ref message) => state.0.preencode(message)?, - Self::Bitfield(ref message) => state.0.preencode(message)?, - Self::Range(ref message) => state.0.preencode(message)?, - Self::Extension(ref message) => state.0.preencode(message)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.end()) - } - - /// Encodes a message to a given buffer, using preencoded state, results size - pub(crate) fn encode( - &self, - state: &mut HypercoreState, - buf: &mut [u8], - ) -> Result { - match self { - Self::Open(ref message) => state.0.encode(message, buf)?, - Self::Close(ref message) => state.0.encode(message, buf)?, - Self::Synchronize(ref message) => state.0.encode(message, buf)?, - Self::Request(ref message) => state.encode(message, buf)?, - Self::Cancel(ref message) => state.0.encode(message, buf)?, - Self::Data(ref message) => state.encode(message, buf)?, - Self::NoData(ref message) => state.0.encode(message, buf)?, - Self::Want(ref message) => state.0.encode(message, buf)?, - Self::Unwant(ref message) => state.0.encode(message, buf)?, - Self::Bitfield(ref message) => state.0.encode(message, buf)?, - Self::Range(ref message) => state.0.encode(message, buf)?, - Self::Extension(ref message) => state.0.encode(message, buf)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.start()) - } -} - -impl fmt::Display for Message { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Open(msg) => write!( - f, - "Open(discovery_key: {}, capability <{}>)", - pretty_fmt(&msg.discovery_key).unwrap(), - msg.capability.as_ref().map_or(0, |c| c.len()) - ), - Self::Data(msg) => write!( - f, - "Data(request: {}, fork: {}, block: {}, hash: {}, seek: {}, upgrade: {})", - msg.request, - msg.fork, - msg.block.is_some(), - msg.hash.is_some(), - msg.seek.is_some(), - msg.upgrade.is_some(), - ), - _ => write!(f, "{:?}", &self), - } - } -} - -/// A message on a channel. -#[derive(Clone)] -pub(crate) struct ChannelMessage { - pub(crate) channel: u64, - pub(crate) message: Message, - state: Option, -} - -impl PartialEq for ChannelMessage { - fn eq(&self, other: &Self) -> bool { - self.channel == other.channel && self.message == other.message - } -} - -impl fmt::Debug for ChannelMessage { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "ChannelMessage({}, {})", self.channel, self.message) - } -} - -impl ChannelMessage { - /// Create a new message. - pub(crate) fn new(channel: u64, message: Message) -> Self { - Self { - channel, - message, - state: None, - } - } - - /// Consume self and return (channel, Message). - pub(crate) fn into_split(self) -> (u64, Message) { - (self.channel, self.message) - } - - /// Decodes an open message for a channel message from a buffer. - /// - /// Note: `buf` has to have a valid length, and without the 3 LE - /// bytes in it - pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { - if buf.len() <= 5 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "received too short Open message", - )); - } - - let mut state = State::new_with_start_and_end(0, buf.len()); - let open_msg: Open = state.decode(buf)?; - Ok(( - Self { - channel: open_msg.channel, - message: Message::Open(open_msg), - state: None, - }, - state.start(), - )) - } - - /// Decodes a close message for a channel message from a buffer. - /// - /// Note: `buf` has to have a valid length, and without the 3 LE - /// bytes in it - pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { - if buf.is_empty() { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "received too short Close message", - )); - } - let mut state = State::new_with_start_and_end(0, buf.len()); - let close_msg: Close = state.decode(buf)?; - Ok(( - Self { - channel: close_msg.channel, - message: Message::Close(close_msg), - state: None, - }, - state.start(), - )) - } - - /// Decode a normal channel message from a buffer. - /// - /// Note: `buf` has to have a valid length, and without the 3 LE - /// bytes in it - pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> { - if buf.len() <= 1 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "received empty message", - )); - } - let mut state = State::from_buffer(buf); - let typ: u64 = state.decode(buf)?; - let (message, length) = Message::decode(&buf[state.start()..], typ)?; - Ok(( - Self { - channel, - message, - state: None, - }, - state.start() + length, - )) - } - - /// Performance optimization for letting calling encoded_len() already do - /// the preencode phase of compact_encoding. - fn prepare_state(&mut self) -> Result<(), EncodingError> { - if self.state.is_none() { - let state = if let Message::Open(_) = self.message { - // Open message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else if let Message::Close(_) = self.message { - // Close message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else { - // The header is the channel id uint followed by message type uint - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let mut state = HypercoreState::new(); - let typ = self.message.typ(); - (*state).preencode(&typ)?; - self.message.preencode(&mut state)?; - state - }; - self.state = Some(state); - } - Ok(()) - } -} - -impl Encoder for ChannelMessage { - fn encoded_len(&mut self) -> Result { - self.prepare_state()?; - Ok(self.state.as_ref().unwrap().end()) - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - self.prepare_state()?; - let state = self.state.as_mut().unwrap(); - if let Message::Open(_) = self.message { - // Open message is different in that the type byte is missing - self.message.encode(state, buf)?; - } else if let Message::Close(_) = self.message { - // Close message is different in that the type byte is missing - self.message.encode(state, buf)?; - } else { - let typ = self.message.typ(); - state.0.encode(&typ, buf)?; - self.message.encode(state, buf)?; - } - Ok(state.start()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use hypercore::{ - DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, - }; - - macro_rules! message_enc_dec { - ($( $msg:expr ),*) => { - $( - let channel = rand::random::() as u64; - let mut channel_message = ChannelMessage::new(channel, $msg); - let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); - let mut buf = vec![0u8; encoded_len]; - let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); - let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split(); - assert_eq!(channel, decoded.0); - assert_eq!($msg, decoded.1); - )* - } - } - #[test] - fn frame_encode_decode() -> std::io::Result<()> { - let msg = Message::Synchronize(Synchronize { - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }); - - let channel = rand::random::() as u64; - let channel_message = ChannelMessage::new(channel, msg); - - let mut frame = Frame::from(channel_message); - let mut buf = vec![0; frame.encoded_len()?]; - frame.encode(&mut buf)?; - let res_frame = Frame::decode_multiple(&buf, &FrameType::Message)?; - assert_eq!(res_frame, frame); - Ok(()) - } - #[test] - fn frame_encode_decode_bar() -> std::io::Result<()> { - let msg = Message::Synchronize(Synchronize { - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }); - - //let channel = rand::random::() as u64; - let channel = 42; - let channel_message = ChannelMessage::new(channel, msg); - - let mut frame = Frame::from(channel_message.clone()); - - let mut fbuf = vec![0; frame.encoded_len()?]; - - frame.encode(&mut fbuf)?; - - let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; - assert_eq!(fres, frame); - //assert_eq!(cres, cmvec); - //println!("REG frame buf\t{frame_buf:02X?}"); - //let res_frame = Frame::decode(&frame_buf, &FrameType::Message)?; - //dbg!(res_frame); - //let res_frame = Frame::decode_multiple(&frame_buf, &FrameType::Message)?; - //dbg!(res_frame); - - //let mut vec_frame_buf = vec![0; vec_frame.encoded_len()?]; - //vec_frame.encode(&mut vec_frame_buf)?; - - //assert_eq!(vec_frame_buf, frame_buf); - //println!("VEC frame buf\t{vec_frame_buf:02X?}"); - - //let res_frame = Frame::decode(&vec_frame_buf, &FrameType::Message)?; - //dbg!(res_frame); - //let res_frame = Frame::decode_multiple(&vec_frame_buf, &FrameType::Message)?; - //dbg!(&res_frame); - - //let (msg, _len) = decode_channel_messages(&vec_frame_buf)?; - //assert_eq!(msg, vec![channel_message]); - - //assert_eq!(res_frame, frame); - Ok(()) - } - - #[test] - fn message_encode_decode() { - message_enc_dec! { - Message::Synchronize(Synchronize{ - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }), - Message::Request(Request { - id: 1, - fork: 1, - block: Some(RequestBlock { - index: 5, - nodes: 10, - }), - hash: Some(RequestBlock { - index: 20, - nodes: 0 - }), - seek: Some(RequestSeek { - bytes: 10 - }), - upgrade: Some(RequestUpgrade { - start: 0, - length: 10 - }) - }), - Message::Cancel(Cancel { - request: 1, - }), - Message::Data(Data{ - request: 1, - fork: 5, - block: Some(DataBlock { - index: 5, - nodes: vec![Node::new(1, vec![0x01; 32], 100)], - value: vec![0xFF; 10] - }), - hash: Some(DataHash { - index: 20, - nodes: vec![Node::new(2, vec![0x02; 32], 200)], - }), - seek: Some(DataSeek { - bytes: 10, - nodes: vec![Node::new(3, vec![0x03; 32], 300)], - }), - upgrade: Some(DataUpgrade { - start: 0, - length: 10, - nodes: vec![Node::new(4, vec![0x04; 32], 400)], - additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)], - signature: vec![0xAB; 32] - }) - }), - Message::NoData(NoData { - request: 2, - }), - Message::Want(Want { - start: 0, - length: 100, - }), - Message::Unwant(Unwant { - start: 10, - length: 2, - }), - Message::Bitfield(Bitfield { - start: 20, - bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF], - }), - Message::Range(Range { - drop: true, - start: 12345, - length: 100000 - }), - Message::Extension(Extension { - name: "custom_extension/v1/open".to_string(), - message: vec![0x44, 20] - }) - }; - } -} diff --git a/src/protocol/old.rs b/src/protocol/old.rs deleted file mode 100644 index 20c9064..0000000 --- a/src/protocol/old.rs +++ /dev/null @@ -1,706 +0,0 @@ -use async_channel::{Receiver, Sender}; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use futures_lite::stream::Stream; -use futures_timer::Delay; -use std::collections::VecDeque; -use std::convert::TryInto; -use std::fmt; -use std::future::Future; -use std::io::{self, Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; -use tracing::{instrument, trace}; - -use crate::channels::{Channel, ChannelMap}; -use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; -use crate::message::{ChannelMessage, Frame, FrameType, Message}; -use crate::reader::ReadState; -use crate::schema::*; -use crate::util::{map_channel_err, pretty_hash}; -use crate::writer::WriteState; - -macro_rules! return_error { - ($msg:expr) => { - if let Err(e) = $msg { - return Poll::Ready(Err(e)); - } - }; -} - -const CHANNEL_CAP: usize = 1000; -const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); - -/// Options for a Protocol instance. -#[derive(Debug)] -pub(crate) struct Options { - /// Whether this peer initiated the IO connection for this protoccol - pub(crate) is_initiator: bool, - /// Enable or disable the handshake. - /// Disabling the handshake will also disable capabilitity verification. - /// Don't disable this if you're not 100% sure you want this. - pub(crate) noise: bool, - /// Enable or disable transport encryption. - pub(crate) encrypted: bool, -} - -impl Options { - /// Create with default options. - pub(crate) fn new(is_initiator: bool) -> Self { - Self { - is_initiator, - noise: true, - encrypted: true, - } - } -} - -/// Remote public key (32 bytes). -pub(crate) type RemotePublicKey = [u8; 32]; -/// Discovery key (32 bytes). -pub type DiscoveryKey = [u8; 32]; -/// Key (32 bytes). -pub type Key = [u8; 32]; - -/// A protocol event. -#[non_exhaustive] -#[derive(PartialEq)] -pub enum Event { - /// Emitted after the handshake with the remote peer is complete. - /// This is the first event (if the handshake is not disabled). - Handshake(RemotePublicKey), - /// Emitted when the remote peer opens a channel that we did not yet open. - DiscoveryKey(DiscoveryKey), - /// Emitted when a channel is established. - Channel(Channel), - /// Emitted when a channel is closed. - Close(DiscoveryKey), - /// Convenience event to make it possible to signal the protocol from a channel. - /// See channel.signal_local() and protocol.commands().signal_local(). - LocalSignal((String, Vec)), -} - -/// A protocol command. -#[derive(Debug)] -pub enum Command { - /// Open a channel - Open(Key), - /// Close a channel by discovery key - Close(DiscoveryKey), - /// Signal locally to protocol - SignalLocal((String, Vec)), -} - -impl fmt::Debug for Event { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Event::Handshake(remote_key) => { - write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key)) - } - Event::DiscoveryKey(discovery_key) => { - write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key)) - } - Event::Channel(channel) => { - write!(f, "Channel({})", &pretty_hash(channel.discovery_key())) - } - Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)), - Event::LocalSignal((name, data)) => { - write!(f, "LocalSignal(name={},len={})", name, data.len()) - } - } - } -} - -/// Protocol state -#[allow(clippy::large_enum_variant)] -pub(crate) enum State { - NotInitialized, - // The Handshake struct sits behind an option only so that we can .take() - // it out, it's never actually empty when in State::Handshake. - Handshake(Option), - SecretStream(Option), - Established, -} - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::NotInitialized => write!(f, "NotInitialized"), - State::Handshake(_) => write!(f, "Handshaking"), - State::SecretStream(_) => write!(f, "SecretStream"), - State::Established => write!(f, "Established"), - } - } -} - -/// A Protocol stream. -pub struct Protocol { - write_state: WriteState, - read_state: ReadState, - io: IO, - state: State, - options: Options, - handshake: Option, - channels: ChannelMap, - command_rx: Receiver, - command_tx: CommandTx, - outbound_rx: Receiver>, - outbound_tx: Sender>, - keepalive: Delay, - queued_events: VecDeque, -} - -impl std::fmt::Debug for Protocol { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Protocol") - .field("write_state", &self.write_state) - .field("read_state", &self.read_state) - //.field("io", &self.io) - .field("state", &self.state) - .field("options", &self.options) - .field("handshake", &self.handshake) - .field("channels", &self.channels) - .field("command_rx", &self.command_rx) - .field("command_tx", &self.command_tx) - .field("outbound_rx", &self.outbound_rx) - .field("outbound_tx", &self.outbound_tx) - .field("keepalive", &self.keepalive) - .field("queued_events", &self.queued_events) - .finish() - } -} - -impl Protocol -where - IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, -{ - /// Create a new protocol instance. - pub(crate) fn new(io: IO, options: Options) -> Self { - let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP); - let (outbound_tx, outbound_rx): ( - Sender>, - Receiver>, - ) = async_channel::bounded(1); - Protocol { - io, - read_state: ReadState::new(), - write_state: WriteState::new(), - options, - state: State::NotInitialized, - channels: ChannelMap::new(), - handshake: None, - command_rx, - command_tx: CommandTx(command_tx), - outbound_tx, - outbound_rx, - keepalive: Delay::new(Duration::from_secs(DEFAULT_KEEPALIVE as u64)), - queued_events: VecDeque::new(), - } - } - - /// Whether this protocol stream initiated the underlying IO connection. - pub fn is_initiator(&self) -> bool { - self.options.is_initiator - } - - /// Get your own Noise public key. - /// - /// Empty before the handshake completed. - pub fn public_key(&self) -> Option<&[u8]> { - match &self.handshake { - None => None, - Some(handshake) => Some(handshake.local_pubkey.as_slice()), - } - } - - /// Get the remote's Noise public key. - /// - /// Empty before the handshake completed. - pub fn remote_public_key(&self) -> Option<&[u8]> { - match &self.handshake { - None => None, - Some(handshake) => Some(handshake.remote_pubkey.as_slice()), - } - } - - /// Get a sender to send commands. - pub fn commands(&self) -> CommandTx { - self.command_tx.clone() - } - - /// Give a command to the protocol. - #[instrument(skip(self))] - pub async fn command(&mut self, command: Command) -> Result<()> { - self.command_tx.send(command).await - } - - /// Open a new protocol channel. - /// - /// Once the other side proofed that it also knows the `key`, the channel is emitted as - /// `Event::Channel` on the protocol event stream. - #[instrument(skip(self))] - pub async fn open(&mut self, key: Key) -> Result<()> { - self.command_tx.open(key).await - } - - /// Iterator of all currently opened channels. - pub fn channels(&self) -> impl Iterator { - self.channels.iter().map(|c| c.discovery_key()) - } - - #[instrument(skip_all, fields(initiator = ?self.is_initiator()))] - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - if let State::NotInitialized = this.state { - return_error!(this.init()); - } - - // Drain queued events first. - if let Some(event) = this.queued_events.pop_front() { - return Poll::Ready(Ok(event)); - } - - // Read and process incoming messages. - return_error!(this.poll_inbound_read(cx)); - - if let State::Established = this.state { - // Check for commands, but only once the connection is established. - return_error!(this.poll_commands(cx)); - } - - // Poll the keepalive timer. - this.poll_keepalive(cx); - - // Write everything we can write. - return_error!(this.poll_outbound_write(cx)); - - // Check if any events are enqueued. - if let Some(event) = this.queued_events.pop_front() { - Poll::Ready(Ok(event)) - } else { - Poll::Pending - } - } - - fn init(&mut self) -> Result<()> { - trace!( - "protocol Init, state {:?}, options {:?}", - self.state, - self.options - ); - match self.state { - State::NotInitialized => {} - _ => return Ok(()), - }; - - self.state = if self.options.noise { - let mut handshake = Handshake::new(self.options.is_initiator)?; - // If the handshake start returns a buffer, send it now. - if let Some(buf) = handshake.start()? { - // TODO what if this fails? or returns false - self.queue_frame_direct(buf.to_vec()).unwrap(); - } - self.read_state.set_frame_type(FrameType::Raw); - State::Handshake(Some(handshake)) - } else { - self.read_state.set_frame_type(FrameType::Message); - State::Established - }; - - Ok(()) - } - - /// Poll commands. - #[instrument(skip_all)] - fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { - while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { - self.on_command(command)?; - } - Ok(()) - } - - /// Poll the keepalive timer and queue a ping message if needed. - fn poll_keepalive(&mut self, cx: &mut Context<'_>) { - if Pin::new(&mut self.keepalive).poll(cx).is_ready() { - if let State::Established = self.state { - // 24 bit header for the empty message, hence the 3 - self.write_state - .queue_frame(Frame::RawBatch(vec![vec![0u8; 3]])); - } - self.keepalive.reset(KEEPALIVE_DURATION); - } - } - - fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { - // If message is close, close the local channel. - if let ChannelMessage { - channel, - message: Message::Close(_), - .. - } = message - { - self.close_local(*channel); - // If message is a LocalSignal, emit an event and return false to indicate - // this message should be filtered out. - } else if let ChannelMessage { - message: Message::LocalSignal((name, data)), - .. - } = message - { - self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec()))); - return false; - } - true - } - - /// Poll for inbound messages and processs them. - #[instrument(skip_all)] - fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { - loop { - let msg = self.read_state.poll_reader(cx, &mut self.io); - match msg { - Poll::Ready(Ok(message)) => { - self.on_inbound_frame(message)?; - } - Poll::Ready(Err(e)) => return Err(e), - Poll::Pending => return Ok(()), - } - } - } - - /// Poll for outbound messages and write them. - #[instrument(skip_all)] - fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { - loop { - if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { - return Err(e); - } - // if no parking or setup in progress - if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { - return Ok(()); - } - - match Pin::new(&mut self.outbound_rx).poll_next(cx) { - Poll::Ready(Some(mut messages)) => { - if !messages.is_empty() { - messages.retain(|message| self.on_outbound_message(message)); - if !messages.is_empty() { - let frame = Frame::MessageBatch(messages); - // TODO try replacing this with queue_frame - self.write_state.park_frame(frame); - } - } - } - Poll::Ready(None) => unreachable!("Channel closed before end"), - Poll::Pending => return Ok(()), - } - } - } - - #[instrument(skip_all)] - fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { - match frame { - Frame::RawBatch(raw_batch) => { - let mut processed_state: Option = None; - for buf in raw_batch { - let state_name: String = format!("{:?}", self.state); - match self.state { - State::Handshake(_) => self.on_handshake_message(buf)?, - State::SecretStream(_) => self.on_secret_stream_message(buf)?, - State::Established => { - if let Some(processed_state) = processed_state.as_ref() { - // last state before established - let previous_state = if self.options.encrypted { - // was SecretStream if we're encrypted - State::SecretStream(None) - } else { - // or wa hasdshake if we're not encrypted - State::Handshake(None) - }; - - // if htis raw_batch included regular messages (not handshake) - // after handshake stuff - if processed_state == &format!("{previous_state:?}") { - // This is the unlucky case where the batch had two or more messages where - // the first one was correctly identified as Raw but everything - // after that should have been (decrypted and) a MessageBatch. Correct the mistake - // here post-hoc. - let buf = self.read_state.decrypt_buf(&buf)?; - let frame = Frame::decode(&buf, &FrameType::Message)?; - self.on_inbound_frame(frame)?; - continue; - } - } - unreachable!( - "May not receive raw frames in Established state" - ) - } - _ => unreachable!( - "May not receive raw frames outside of handshake or secretstream state, was {:?}", - self.state - ), - }; - if processed_state.is_none() { - processed_state = Some(state_name) - } - } - Ok(()) - } - Frame::MessageBatch(channel_messages) => match self.state { - State::Established => { - for channel_message in channel_messages { - self.on_inbound_message(channel_message)? - } - Ok(()) - } - _ => unreachable!("May not receive message batch frames when not established"), - }, - } - } - - fn on_handshake_message(&mut self, buf: Vec) -> Result<()> { - let mut handshake = match &mut self.state { - State::Handshake(handshake) => handshake.take().unwrap(), - _ => unreachable!("May not call on_handshake_message when not in Handshake state"), - }; - - if let Some(response_buf) = handshake.read(&buf)? { - self.queue_frame_direct(response_buf.to_vec()).unwrap(); - } - - if !handshake.complete() { - self.state = State::Handshake(Some(handshake)); - } else { - let handshake_result = handshake.into_result()?; - - if self.options.encrypted { - // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; - self.state = State::SecretStream(Some(cipher)); - - // Send the secret stream init message header to the other side - self.queue_frame_direct(init_msg).unwrap(); - } else { - // Skip secret stream and go straight to Established, then notify about - // handshake - self.read_state.set_frame_type(FrameType::Message); - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - } - // Store handshake result - self.handshake = Some(handshake_result.clone()); - } - Ok(()) - } - - fn on_secret_stream_message(&mut self, buf: Vec) -> Result<()> { - let encrypt_cipher = match &mut self.state { - State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(), - _ => { - unreachable!("May not call on_secret_stream_message when not in SecretStream state") - } - }; - let handshake_result = &self - .handshake - .as_ref() - .expect("Handshake result must be set before secret stream"); - let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?; - self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher); - self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher); - self.read_state.set_frame_type(FrameType::Message); - - // Lastly notify that handshake is ready and set state to established - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - Ok(()) - } - #[instrument(skip_all)] - fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { - // let channel_message = ChannelMessage::decode(buf)?; - let (remote_id, message) = channel_message.into_split(); - match message { - Message::Open(msg) => self.on_open(remote_id, msg)?, - Message::Close(msg) => self.on_close(remote_id, msg)?, - _ => self - .channels - .forward_inbound_message(remote_id as usize, message)?, - } - Ok(()) - } - - #[instrument(skip(self))] - fn on_command(&mut self, command: Command) -> Result<()> { - match command { - Command::Open(key) => self.command_open(key), - Command::Close(discovery_key) => self.command_close(discovery_key), - Command::SignalLocal((name, data)) => self.command_signal_local(name, data), - } - } - - /// Open a Channel with the given key. Adding it to our channel map - #[instrument(skip_all)] - fn command_open(&mut self, key: Key) -> Result<()> { - // Create a new channel. - let channel_handle = self.channels.attach_local(key); - // Safe because attach_local always puts Some(local_id) - let local_id = channel_handle.local_id().unwrap(); - let discovery_key = *channel_handle.discovery_key(); - - // If the channel was already opened from the remote end, verify, and if - // verification is ok, push a channel open event. - if channel_handle.is_connected() { - self.accept_channel(local_id)?; - } - - // Tell the remote end about the new channel. - let capability = self.capability(&key); - let channel = local_id as u64; - let message = Message::Open(Open { - channel, - protocol: PROTOCOL_NAME.to_string(), - discovery_key: discovery_key.to_vec(), - capability, - }); - let channel_message = ChannelMessage::new(channel, message); - self.write_state - .queue_frame(Frame::MessageBatch(vec![channel_message])); - Ok(()) - } - - fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> { - if self.channels.has_channel(&discovery_key) { - self.channels.remove(&discovery_key); - self.queue_event(Event::Close(discovery_key)); - } - Ok(()) - } - - fn command_signal_local(&mut self, name: String, data: Vec) -> Result<()> { - self.queue_event(Event::LocalSignal((name, data))); - Ok(()) - } - - #[instrument(skip(self))] - fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { - let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; - let channel_handle = - self.channels - .attach_remote(discovery_key, ch as usize, msg.capability); - - if channel_handle.is_connected() { - let local_id = channel_handle.local_id().unwrap(); - self.accept_channel(local_id)?; - } else { - self.queue_event(Event::DiscoveryKey(discovery_key)); - } - - Ok(()) - } - - #[instrument(skip(self))] - fn queue_event(&mut self, event: Event) { - self.queued_events.push_back(event); - } - - /// enequeu a buf to be sent - fn queue_frame_direct(&mut self, body: Vec) -> Result { - let mut frame = Frame::RawBatch(vec![body]); - self.write_state - .try_encode_and_enqueue_frame_for_tx(&mut frame) - } - - #[instrument(skip(self))] - fn accept_channel(&mut self, local_id: usize) -> Result<()> { - let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; - self.verify_remote_capability(remote_capability.cloned(), key)?; - let channel = self.channels.accept(local_id, self.outbound_tx.clone())?; - self.queue_event(Event::Channel(channel)); - Ok(()) - } - - fn close_local(&mut self, local_id: u64) { - if let Some(channel) = self.channels.get_local(local_id as usize) { - let discovery_key = *channel.discovery_key(); - self.channels.remove(&discovery_key); - self.queue_event(Event::Close(discovery_key)); - } - } - - fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> { - if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) { - let discovery_key = *channel_handle.discovery_key(); - // There is a possibility both sides will close at the same time, so - // the channel could be closed already, let's tolerate that. - self.channels - .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?; - self.channels.remove(&discovery_key); - self.queue_event(Event::Close(discovery_key)); - } - Ok(()) - } - - #[instrument(skip_all)] - fn capability(&self, key: &[u8]) -> Option> { - match self.handshake.as_ref() { - Some(handshake) => handshake.capability(key), - None => None, - } - } - - fn verify_remote_capability(&self, capability: Option>, key: &[u8]) -> Result<()> { - match self.handshake.as_ref() { - Some(handshake) => handshake.verify_remote_capability(capability, key), - None => Err(Error::new( - ErrorKind::PermissionDenied, - "Missing handshake state for capability verification", - )), - } - } -} - -impl Stream for Protocol -where - IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ - type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Protocol::poll_next(self, cx).map(Some) - } -} - -/// Send [Command](Command)s to the [Protocol](Protocol). -#[derive(Clone, Debug)] -pub struct CommandTx(Sender); - -impl CommandTx { - /// Send a protocol command - pub async fn send(&mut self, command: Command) -> Result<()> { - self.0.send(command).await.map_err(map_channel_err) - } - /// Open a protocol channel. - /// - /// The channel will be emitted on the main protocol. - pub async fn open(&mut self, key: Key) -> Result<()> { - self.send(Command::Open(key)).await - } - - /// Close a protocol channel. - pub async fn close(&mut self, discovery_key: DiscoveryKey) -> Result<()> { - self.send(Command::Close(discovery_key)).await - } - - /// Send a local signal event to the protocol. - pub async fn signal_local(&mut self, name: &str, data: Vec) -> Result<()> { - self.send(Command::SignalLocal((name.to_string(), data))) - .await - } -} - -fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> { - key.try_into() - .map_err(|_e| io::Error::new(io::ErrorKind::InvalidInput, "Key must be 32 bytes long")) -} diff --git a/src/reader.rs b/src/reader.rs deleted file mode 100644 index cc80c5c..0000000 --- a/src/reader.rs +++ /dev/null @@ -1,246 +0,0 @@ -use crate::crypto::DecryptCipher; -use futures_lite::io::AsyncRead; -use futures_timer::Delay; -use std::future::Future; -use std::io::{Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use crate::constants::{DEFAULT_TIMEOUT, MAX_MESSAGE_SIZE}; -use crate::message::{Frame, FrameType}; -use crate::util::stat_uint24_le; -use std::time::Duration; - -const TIMEOUT: Duration = Duration::from_secs(DEFAULT_TIMEOUT as u64); -const READ_BUF_INITIAL_SIZE: usize = 1024 * 128; - -#[derive(Debug)] -pub(crate) struct ReadState { - /// The read buffer. - buf: Vec, - /// The start of the not-yet-processed byte range in the read buffer. - start: usize, - /// The end of the not-yet-processed byte range in the read buffer. - end: usize, - /// The logical state of the reading (either header or body). - step: Step, - /// The timeout after which the connection is closed. - timeout: Delay, - /// Optional decryption cipher. - cipher: Option, - /// The frame type to be passed to the decoder. - frame_type: FrameType, -} - -impl ReadState { - pub(crate) fn new() -> ReadState { - ReadState { - buf: vec![0u8; READ_BUF_INITIAL_SIZE], - start: 0, - end: 0, - step: Step::Header, - timeout: Delay::new(TIMEOUT), - cipher: None, - frame_type: FrameType::Raw, - } - } -} - -#[derive(Debug)] -enum Step { - Header, - Body { - header_len: usize, - body_len: usize, - }, - /// Multiple messages one after another - Batch, -} - -impl ReadState { - pub(crate) fn upgrade_with_decrypt_cipher(&mut self, decrypt_cipher: DecryptCipher) { - self.cipher = Some(decrypt_cipher); - } - - /// Decrypts a given buf with stored cipher, if present. Used to correct - /// the rare mistake that more than two messages came in where the first - /// one created the cipher, and the next one should have been decrypted - /// but wasn't. - pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> Result> { - if let Some(cipher) = self.cipher.as_mut() { - Ok(cipher.decrypt_buf(buf)?.0) - } else { - Ok(buf.to_vec()) - } - } - - pub(crate) fn set_frame_type(&mut self, frame_type: FrameType) { - self.frame_type = frame_type; - } - - pub(crate) fn poll_reader( - &mut self, - cx: &mut Context<'_>, - mut reader: &mut R, - ) -> Poll> - where - R: AsyncRead + Unpin, - { - let mut incomplete = true; - loop { - if !incomplete { - if let Some(result) = self.process() { - return Poll::Ready(result); - } - } else { - incomplete = false; - } - let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf[self.end..]) { - Poll::Ready(Ok(n)) if n > 0 => n, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - // If the reader is pending, poll the timeout. - Poll::Pending | Poll::Ready(Ok(_)) => { - // Return Pending if the timeout is pending, or an error if the - // timeout expired (i.e. returned Poll::Ready). - return Pin::new(&mut self.timeout) - .poll(cx) - .map(|()| Err(Error::new(ErrorKind::TimedOut, "Remote timed out"))); - } - }; - - let end = self.end + n; - let (success, segments) = create_segments(&self.buf[self.start..end])?; - if success { - if let Some(ref mut cipher) = self.cipher { - let mut dec_end = self.start; - // What happens if decrypt fails here? - // next call to this func would have same start, corret? - // so it'd fail repeatedly? - // Why not just decrypt to the end? - for (index, header_len, body_len) in segments { - let de = cipher.decrypt( - &mut self.buf[self.start + index..end], - header_len, - body_len, - )?; - dec_end = self.start + index + de; - } - self.end = dec_end; - } else { - self.end = end; - } - } else { - // Could not segment due to buffer being full, need to cycle the buffer - // and possibly resize it too if the message is too big. - self.cycle_buf_and_resize_if_needed(segments[segments.len() - 1]); - - // Set incomplete flag to skip processing and instead poll more data - incomplete = true; - } - self.timeout.reset(TIMEOUT); - } - } - - /// Moves start of unprocessed data to the start of the buffer. And resize if necessary. - fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) { - let (last_index, last_header_len, last_body_len) = last_segment; - let total_incoming_length = last_index + last_header_len + last_body_len; - - if self.buf.len() < total_incoming_length { - // The incoming segments will not fit into the buffer, need to resize it - self.buf.resize(total_incoming_length, 0u8); - } - - // to-read length - let temp = self.buf[self.start..].to_vec(); - let len = temp.len(); - self.buf[..len].copy_from_slice(&temp[..]); - self.end = len; - self.start = 0; - } - - fn process(&mut self) -> Option> { - loop { - match self.step { - Step::Header => { - let stat = stat_uint24_le(&self.buf[self.start..self.end]); - if let Some((header_len, body_len)) = stat { - if body_len == 0 { - // This is a keepalive message, just remain in Step::Header - self.start += header_len; - return None; - } else if (self.start + header_len + body_len as usize) < self.end { - // There are more than one message here, create a batch from all of - // then - self.step = Step::Batch; - } else { - let body_len = body_len as usize; - if body_len > MAX_MESSAGE_SIZE as usize { - return Some(Err(Error::new( - ErrorKind::InvalidData, - "Message length above max allowed size", - ))); - } - self.step = Step::Body { - header_len, - body_len, - }; - } - } else { - return Some(Err(Error::new(ErrorKind::InvalidData, "Invalid header"))); - } - } - - // one message within an encrypted frame - // encrypted frame [ u24 header + encoded_frame [ ]] - Step::Body { - header_len, - body_len, - } => { - let message_len = header_len + body_len; - let range = self.start + header_len..self.start + message_len; - // this includes a a frame header - let frame = Frame::decode(&self.buf[range], &self.frame_type); - self.start += message_len; - self.step = Step::Header; - return Some(frame); - } - // multiple message within an encrypted frame - Step::Batch => { - let frame = - Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type); - self.start = self.end; - self.step = Step::Header; - return Some(frame); - } - } - } - } -} - -#[allow(clippy::type_complexity)] -/// Given a buff get all the segments (starting_index_in_buffer, header_len, buffer_len) -/// returns returns `(true, segments)` if we read all segments, but (false, ..) if there -/// are remaining segments -fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> { - let mut index: usize = 0; - let len = buf.len(); - let mut segments: Vec<(usize, usize, usize)> = vec![]; - while index < len { - if let Some((header_len, body_len)) = stat_uint24_le(&buf[index..]) { - let body_len = body_len as usize; - segments.push((index, header_len, body_len)); - if len < index + header_len + body_len { - // The segments will not fit, return false to indicate that more needs to be read - return Ok((false, segments)); - } - index += header_len + body_len; - } else { - return Err(Error::new( - ErrorKind::InvalidData, - "Could not read header while decrypting", - )); - } - } - Ok((true, segments)) -} diff --git a/src/writer.rs b/src/writer.rs deleted file mode 100644 index 9a1465b..0000000 --- a/src/writer.rs +++ /dev/null @@ -1,198 +0,0 @@ -use crate::crypto::EncryptCipher; -use crate::message::{Encoder, Frame}; -use tracing::instrument; - -use futures_lite::{ready, AsyncWrite}; -use std::collections::VecDeque; -use std::fmt; -use std::io::Result; -use std::pin::Pin; -use std::task::{Context, Poll}; - -const BUF_SIZE: usize = 1024 * 64; - -#[derive(Debug)] -pub(crate) enum Step { - Flushing, - Writing, - Processing, -} - -pub(crate) struct WriteState { - queue: VecDeque, - current_frame: Option, - cipher: Option, - buf: Vec, - written_up_to_idx: usize, - should_write_up_to_idx: usize, - step: Step, -} - -impl fmt::Debug for WriteState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("WriteState") - .field("queue (len)", &self.queue.len()) - .field("current_frame", &self.current_frame) - .field("cipher", &self.cipher.is_some()) - .field("buf (len)", &self.buf.len()) - .field("start", &self.written_up_to_idx) - .field("end", &self.should_write_up_to_idx) - .field("step", &self.step) - .finish() - } -} - -impl WriteState { - pub(crate) fn new() -> Self { - Self { - queue: VecDeque::new(), - buf: vec![0u8; BUF_SIZE], - current_frame: None, - written_up_to_idx: 0, - should_write_up_to_idx: 0, - cipher: None, - step: Step::Processing, - } - } - - pub(crate) fn queue_frame(&mut self, frame: F) - where - F: Into, - { - self.queue.push_back(frame.into()) - } - - #[instrument(skip(self))] - pub(crate) fn try_encode_and_enqueue_frame_for_tx( - &mut self, - frame: &mut T, - ) -> Result { - let promised_len = frame.encoded_len()?; - let padded_promised_len = self.safe_encrypted_len(promised_len); - // this handles when a message would be longer than the entire buffer - if self.buf.len() < padded_promised_len { - self.buf.resize(padded_promised_len, 0u8); - } - - // check we have enough room - if padded_promised_len > self.remaining() { - return Ok(false); - } - - // write frame starting at end. fram is from end to end + actual_end - let actual_len = frame.encode(&mut self.buf[self.should_write_up_to_idx..])?; - if actual_len != promised_len { - panic!( - "encoded_len() did not return that right size, expected={promised_len}, actual={actual_len}" - ); - } - // Instead of the above, write the buffer to a new vec `foo` of length `promised_length` - // encode frame.to this buff - // slice `foo[(header_len /* 3*/)..actual_len]` this is the fram data - // encrypt this in place - // replace header at start of foo - // write its len to self.buf and then write it to self.buf - // slice from - - self.encrypt_frame_contents_onto_buf(padded_promised_len)?; - Ok(true) - } - - pub(crate) fn can_park_frame(&self) -> bool { - self.current_frame.is_none() - } - - pub(crate) fn park_frame(&mut self, frame: F) - where - F: Into, - { - if self.current_frame.is_none() { - self.current_frame = Some(frame.into()) - } - } - - /// The frame should be written to `self.buf` before calling this. And - /// `self.should_write_up_to_idx` should mark the start of the message. - /// `max_message_size` is the maximum size the message could be when it is encrypted - /// We encrypt the message in-place on `self.buf`. - fn encrypt_frame_contents_onto_buf(&mut self, max_message_size: usize) -> Result<()> { - let end_of_message_index = self.should_write_up_to_idx + max_message_size; - - let encrypted_end = if let Some(ref mut cipher) = self.cipher { - self.should_write_up_to_idx - + cipher - .encrypt(&mut self.buf[self.should_write_up_to_idx..end_of_message_index])? - } else { - end_of_message_index - }; - - self.should_write_up_to_idx = encrypted_end; - Ok(()) - } - - pub(crate) fn upgrade_with_encrypt_cipher(&mut self, encrypt_cipher: EncryptCipher) { - self.cipher = Some(encrypt_cipher); - } - - fn remaining(&self) -> usize { - self.buf.len() - self.should_write_up_to_idx - } - - fn pending(&self) -> usize { - self.should_write_up_to_idx - self.written_up_to_idx - } - - pub(crate) fn poll_send( - &mut self, - cx: &mut Context<'_>, - mut writer: &mut W, - ) -> Poll> - where - W: AsyncWrite + Unpin, - { - loop { - self.step = match self.step { - Step::Processing => { - if self.current_frame.is_none() && !self.queue.is_empty() { - self.current_frame = self.queue.pop_front(); - } - - if let Some(mut frame) = self.current_frame.take() { - if !self.try_encode_and_enqueue_frame_for_tx(&mut frame)? { - self.current_frame = Some(frame); - } - } - - if self.pending() == 0 { - return Poll::Ready(Ok(())); - } - Step::Writing - } - Step::Writing => { - let n = ready!(Pin::new(&mut writer).poll_write( - cx, - &self.buf[self.written_up_to_idx..self.should_write_up_to_idx] - ))?; - self.written_up_to_idx += n; - if self.written_up_to_idx == self.should_write_up_to_idx { - self.written_up_to_idx = 0; - self.should_write_up_to_idx = 0; - } - Step::Flushing - } - Step::Flushing => { - ready!(Pin::new(&mut writer).poll_flush(cx))?; - Step::Processing - } - } - } - } - - fn safe_encrypted_len(&self, encoded_len: usize) -> usize { - if let Some(cipher) = &self.cipher { - cipher.safe_encrypted_len(encoded_len) - } else { - encoded_len - } - } -} From 6f1995de65a560003d04dbcebf96d5f3ed51e660 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:36:26 -0400 Subject: [PATCH 084/135] cargo clippy --fix --- benches/pipe.rs | 29 +++++++++++------------ benches/throughput.rs | 53 ++++++++++++++++++------------------------- src/message/modern.rs | 2 +- src/test_utils.rs | 2 +- 4 files changed, 37 insertions(+), 49 deletions(-) diff --git a/benches/pipe.rs b/benches/pipe.rs index 630146c..b726545 100644 --- a/benches/pipe.rs +++ b/benches/pipe.rs @@ -18,7 +18,7 @@ fn bench_throughput(c: &mut Criterion) { env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); let mut group = c.benchmark_group("pipe"); group.sample_size(10); - group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS as u64)); + group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS)); group.bench_function("pipe_echo", |b| { b.iter(|| { task::block_on(async move { @@ -72,7 +72,7 @@ where debug!("[{}] EVENT {:?}", is_initiator, event); match event { Event::Handshake(_) => { - protocol.open(key.clone()).await?; + protocol.open(key).await?; } Event::DiscoveryKey(_dkey) => {} Event::Channel(channel) => { @@ -92,7 +92,7 @@ where } Some(Err(err)) => { error!("ERROR {:?}", err); - return Err(err.into()); + return Err(err); } None => return Ok(0), } @@ -127,20 +127,17 @@ async fn on_channel_init(i: u64, mut channel: Channel) -> Result { let start = std::time::Instant::now(); while let Some(message) = channel.next().await { - match message { - Message::Data(mut data) => { - len += value_len(&data); - debug!("[a] recv {}", index(&data)); - if index(&data) >= COUNT { - debug!("close at {}", index(&data)); - channel.close().await?; - break; - } else { - increment_index(&mut data); - channel.send(Message::Data(data)).await?; - } + if let Message::Data(mut data) = message { + len += value_len(&data); + debug!("[a] recv {}", index(&data)); + if index(&data) >= COUNT { + debug!("close at {}", index(&data)); + channel.close().await?; + break; + } else { + increment_index(&mut data); + channel.send(Message::Data(data)).await?; } - _ => {} } } // let bytes = (COUNT * SIZE) as f64; diff --git a/benches/throughput.rs b/benches/throughput.rs index 7f9890d..6b9d6af 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -71,15 +71,12 @@ async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { // let kill_rx = &mut kill_rx; loop { match futures::future::select(incoming.next(), &mut kill_rx).await { - Either::Left((next, _)) => match next { - Some(Ok(stream)) => { - let peer_addr = stream.peer_addr().unwrap(); - debug!("new connection from {}", peer_addr); - task::spawn(async move { - onconnection(stream.clone(), stream, false).await; - }); - } - _ => {} + Either::Left((next, _)) => if let Some(Ok(stream)) = next { + let peer_addr = stream.peer_addr().unwrap(); + debug!("new connection from {}", peer_addr); + task::spawn(async move { + onconnection(stream.clone(), stream, false).await; + }); }, Either::Right((_, _)) => return, } @@ -101,7 +98,7 @@ where // eprintln!("RECV EVENT [{}] {:?}", protocol.is_initiator(), event); match event { Event::Handshake(_) => { - protocol.open(key.clone()).await.unwrap(); + protocol.open(key).await.unwrap(); } Event::DiscoveryKey(_) => {} Event::Channel(channel) => { @@ -126,10 +123,7 @@ async fn onchannel(mut channel: Channel, is_initiator: bool) { async fn channel_server(channel: &mut Channel) { while let Some(message) = channel.next().await { - match message { - Message::Data(_) => channel.send(message).await.unwrap(), - _ => {} - } + if let Message::Data(_) = message { channel.send(message).await.unwrap() } } } @@ -139,24 +133,21 @@ async fn channel_client(channel: &mut Channel) { let message = msg_data(0, data.clone()); channel.send(message).await.unwrap(); while let Some(message) = channel.next().await { - match message { - Message::Data(ref msg) => { - if index(msg) < COUNT { - let message = msg_data(index(msg) + 1, data.clone()); - channel.send(message).await.unwrap(); - } else { - let time = start.elapsed(); - let bytes = COUNT * SIZE; - trace!( - "client completed. {} blocks, {} bytes, {:?}", - index(msg), - bytes, - time - ); - break; - } + if let Message::Data(ref msg) = message { + if index(msg) < COUNT { + let message = msg_data(index(msg) + 1, data.clone()); + channel.send(message).await.unwrap(); + } else { + let time = start.elapsed(); + let bytes = COUNT * SIZE; + trace!( + "client completed. {} blocks, {} bytes, {:?}", + index(msg), + bytes, + time + ); + break; } - _ => {} } } } diff --git a/src/message/modern.rs b/src/message/modern.rs index 23524b2..1b68e24 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -565,7 +565,7 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - let body_len = prencode_channel_messages(&vec)?; + let body_len = prencode_channel_messages(vec)?; let mut buffer = checked_write_uint24_le(body_len, buffer)?; match vec { [] => Ok(buffer), diff --git a/src/test_utils.rs b/src/test_utils.rs index 3f687ea..5309529 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -185,7 +185,7 @@ impl Moo { fn result_channel() -> (Sender>, impl Stream>>) { let (tx, rx) = unbounded::>(); - (tx, rx.map(|x| Ok(x))) + (tx, rx.map(Ok)) } pub(crate) fn create_result_connected() -> ( From ca329b4a9c59a1bdd9469f9d47f5007c27d325c4 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:45:14 -0400 Subject: [PATCH 085/135] Remove protocol feature --- Cargo.toml | 3 +- src/constants.rs | 9 ---- src/crypto/cipher.rs | 97 ----------------------------------------- src/crypto/handshake.rs | 9 ---- src/crypto/mod.rs | 4 -- src/lib.rs | 6 --- src/message/mod.rs | 9 ---- src/protocol/mod.rs | 10 ----- 8 files changed, 1 insertion(+), 146 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4ceb1f3..0862678 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,10 +66,9 @@ tracing-tree = "0.4.0" tokio-util = { version = "0.7.14", features = ["compat"] } [features] -default = ["tokio", "sparse", "protocol"] +default = ["tokio", "sparse"] #default = ["tokio", "sparse"] uint24 = [] -protocol = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" ] diff --git a/src/constants.rs b/src/constants.rs index 73d0748..1efbbed 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -6,12 +6,3 @@ pub(crate) const DEFAULT_KEEPALIVE: u32 = 10; /// v10: Protocol name pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; - -// 16,78MB is the max encrypted wire message size (will be much smaller usually). -// This limitation stems from the 24bit header. -#[cfg(not(feature = "protocol"))] -pub(crate) const MAX_MESSAGE_SIZE: u64 = 0xFFFFFF; - -/// Default timeout (in seconds) -#[cfg(not(feature = "protocol"))] -pub(crate) const DEFAULT_TIMEOUT: u32 = 20; diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 53c291f..aa096c3 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -72,103 +72,6 @@ impl DecryptCipher { } } -#[cfg(not(feature = "protocol"))] -mod encrypt_cipher { - use super::*; - use crate::util::{stat_uint24_le, write_uint24_le, UINT_24_LENGTH}; - const HEADER_MSG_LEN: usize = UINT_24_LENGTH + STREAM_ID_LENGTH + Header::BYTES; - - pub(crate) struct EncryptCipher { - push_stream: PushStream, - } - - impl std::fmt::Debug for EncryptCipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "EncryptCipher(crypto_secretstream)") - } - } - - impl EncryptCipher { - pub(crate) fn from_handshake_tx( - handshake_result: &HandshakeResult, - ) -> std::io::Result<(Self, Vec)> { - let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] - .try_into() - .expect("split_tx with incorrect length"); - let key = Key::from(key); - - let mut header_message: [u8; HEADER_MSG_LEN] = [0; HEADER_MSG_LEN]; - write_uint24_le(STREAM_ID_LENGTH + Header::BYTES, &mut header_message); - write_stream_id( - &handshake_result.handshake_hash, - handshake_result.is_initiator, - &mut header_message[UINT_24_LENGTH..UINT_24_LENGTH + STREAM_ID_LENGTH], - ); - - let (header, push_stream) = PushStream::init(OsRng, &key); - let header = header.as_ref(); - header_message[UINT_24_LENGTH + STREAM_ID_LENGTH..].copy_from_slice(header); - let msg = header_message.to_vec(); - Ok((Self { push_stream }, msg)) - } - - /// Get the length needed for encryption, that includes padding. - pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { - // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe - // extra room. - // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ - plaintext_len + 2 * 15 - } - - /// Encrypts message in the given buffer to the same buffer, returns number of bytes - /// of total message. - /// NB: we expect the first 3 bytes of the buffer to a size header. - /// The encrypted buffer will also be written prepended with a size header, with it's new size. - pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { - let stat = stat_uint24_le(buf); - if let Some((header_len, body_len)) = stat { - let mut to_encrypt = buf[header_len..header_len + body_len as usize].to_vec(); - self.push_stream - .push(&mut to_encrypt, &[], Tag::Message) - .map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) - })?; - let encrypted_len = to_encrypt.len(); - write_uint24_le(encrypted_len, buf); - buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); - Ok(header_len + encrypted_len) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Could not encrypt invalid data, len: {}", buf.len()), - )) - } - } - } - - impl DecryptCipher { - pub(crate) fn decrypt( - &mut self, - buf: &mut [u8], - header_len: usize, - body_len: usize, - ) -> io::Result { - let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; - let decrypted_len = to_decrypt.len(); - write_uint24_le(decrypted_len, buf); - let decrypted_end = header_len + to_decrypt.len(); - buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); - // Set extra bytes in the buffer to 0 - // Why? - let encrypted_end = header_len + body_len; - buf[decrypted_end..encrypted_end].fill(0x00); - Ok(decrypted_end) - } - } -} -#[cfg(not(feature = "protocol"))] -pub(crate) use encrypt_cipher::*; - // NB: These values come from Javascript-side // // const [NS_INITIATOR, NS_RESPONDER] = crypto.namespace('hyperswarm/secret-stream', 2) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 72c9da3..492a9a4 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -112,10 +112,6 @@ impl Handshake { Ok(None) } } - #[cfg(not(feature = "protocol"))] - pub(crate) fn start(&mut self) -> Result>> { - Ok(self.start_raw()?.map(|x| crate::util::wrap_uint24_le(&x))) - } pub(crate) fn complete(&self) -> bool { self.complete @@ -178,11 +174,6 @@ impl Handshake { self.complete = true; Ok(tx_buf) } - // reads in `msg` without framing bytes, but emits msg WITH framing bytes - #[cfg(not(feature = "protocol"))] - pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { - Ok(self.read_raw(msg)?.map(|x| crate::util::wrap_uint24_le(&x))) - } pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { if !self.complete() { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 3de592a..9e49c0a 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,9 +1,5 @@ mod cipher; mod curve; mod handshake; -#[cfg(not(feature = "protocol"))] -pub(crate) use cipher::{DecryptCipher, EncryptCipher, RawEncryptCipher}; - -#[cfg(feature = "protocol")] pub(crate) use cipher::{DecryptCipher, RawEncryptCipher}; pub(crate) use handshake::{Handshake, HandshakeResult}; diff --git a/src/lib.rs b/src/lib.rs index c13ccae..3602517 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,17 +124,12 @@ mod crypto; mod duplex; mod framing; mod message; -#[cfg(feature = "protocol")] mod mqueue; mod noise; mod protocol; -#[cfg(not(feature = "protocol"))] -mod reader; #[cfg(test)] mod test_utils; mod util; -#[cfg(not(feature = "protocol"))] -mod writer; /// The wire messages used by the protocol. pub mod schema; @@ -144,7 +139,6 @@ pub use channels::Channel; pub use framing::Uint24LELengthPrefixedFraming; pub use noise::{encrypted_framed_message_channel, Encrypted, Event as NoiseEvent}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() -#[cfg(feature = "protocol")] pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, }; diff --git a/src/message/mod.rs b/src/message/mod.rs index 1526f3a..dc42d7a 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -1,11 +1,2 @@ -#[cfg(feature = "protocol")] mod modern; - -#[cfg(feature = "protocol")] pub use modern::*; - -#[cfg(not(feature = "protocol"))] -mod old; - -#[cfg(not(feature = "protocol"))] -pub use old::*; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 7382df8..d24738c 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,13 +1,3 @@ -#[cfg(feature = "protocol")] mod modern; -#[cfg(feature = "protocol")] pub(crate) use modern::Options; -#[cfg(feature = "protocol")] pub use modern::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; - -#[cfg(not(feature = "protocol"))] -mod old; -#[cfg(not(feature = "protocol"))] -pub(crate) use old::Options; -#[cfg(not(feature = "protocol"))] -pub use old::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; From 3548f43871b418506bef6a576e57400e7fda0dac Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:45:36 -0400 Subject: [PATCH 086/135] cargo fmt --- benches/throughput.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/benches/throughput.rs b/benches/throughput.rs index 6b9d6af..cc2c278 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -71,13 +71,15 @@ async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { // let kill_rx = &mut kill_rx; loop { match futures::future::select(incoming.next(), &mut kill_rx).await { - Either::Left((next, _)) => if let Some(Ok(stream)) = next { - let peer_addr = stream.peer_addr().unwrap(); - debug!("new connection from {}", peer_addr); - task::spawn(async move { - onconnection(stream.clone(), stream, false).await; - }); - }, + Either::Left((next, _)) => { + if let Some(Ok(stream)) = next { + let peer_addr = stream.peer_addr().unwrap(); + debug!("new connection from {}", peer_addr); + task::spawn(async move { + onconnection(stream.clone(), stream, false).await; + }); + } + } Either::Right((_, _)) => return, } } @@ -123,7 +125,9 @@ async fn onchannel(mut channel: Channel, is_initiator: bool) { async fn channel_server(channel: &mut Channel) { while let Some(message) = channel.next().await { - if let Message::Data(_) = message { channel.send(message).await.unwrap() } + if let Message::Data(_) = message { + channel.send(message).await.unwrap() + } } } From 72d1d9e0f3e4deff1bb88cf629d86005b9bd0c56 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:50:03 -0400 Subject: [PATCH 087/135] clippy fixes --- src/crypto/handshake.rs | 2 +- src/noise.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 492a9a4..10c111f 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -175,7 +175,7 @@ impl Handshake { Ok(tx_buf) } - pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { + pub(crate) fn get_result(&self) -> Result<&HandshakeResult> { if !self.complete() { Err(Error::new(ErrorKind::Other, "Handshake is not complete")) } else { diff --git a/src/noise.rs b/src/noise.rs index 40ce6ac..5b5b867 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -267,6 +267,7 @@ impl>> + Sink> + Send + Unpin + 'static /// Handle all message throughput. Sends, encrypts and decrypts messages /// Returns `true` `step` is already [`Step::Established`]. +#[allow(clippy::too_many_arguments)] fn poll_message_throughput< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -554,7 +555,7 @@ fn handle_setup_message( if handshake.complete() { debug!(initiator = %is_initiator, "Handshake completed"); - let handshake_result = match handshake.into_result() { + let handshake_result = match handshake.get_result() { Ok(x) => x, Err(e) => { error!("into-result error {e:?}"); From 1e3edd127c8d103d56b1a815856b5ee1f99ec0a0 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:53:02 -0400 Subject: [PATCH 088/135] rm nested messaged & protocol modules --- src/{message/modern.rs => message.rs} | 0 src/message/mod.rs | 2 -- src/{protocol/modern.rs => protocol.rs} | 0 src/protocol/mod.rs | 3 --- 4 files changed, 5 deletions(-) rename src/{message/modern.rs => message.rs} (100%) delete mode 100644 src/message/mod.rs rename src/{protocol/modern.rs => protocol.rs} (100%) delete mode 100644 src/protocol/mod.rs diff --git a/src/message/modern.rs b/src/message.rs similarity index 100% rename from src/message/modern.rs rename to src/message.rs diff --git a/src/message/mod.rs b/src/message/mod.rs deleted file mode 100644 index dc42d7a..0000000 --- a/src/message/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod modern; -pub use modern::*; diff --git a/src/protocol/modern.rs b/src/protocol.rs similarity index 100% rename from src/protocol/modern.rs rename to src/protocol.rs diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs deleted file mode 100644 index d24738c..0000000 --- a/src/protocol/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod modern; -pub(crate) use modern::Options; -pub use modern::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; From 323503633604f6d03a7229808b321fcdb840c910 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 12:08:45 -0400 Subject: [PATCH 089/135] remove encoder trait --- src/message.rs | 63 -------------------------------------------------- src/mqueue.rs | 11 +++++---- 2 files changed, 6 insertions(+), 68 deletions(-) diff --git a/src/message.rs b/src/message.rs index 1b68e24..f222a1e 100644 --- a/src/message.rs +++ b/src/message.rs @@ -15,20 +15,6 @@ const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; -/// Encode data into a buffer. -/// -/// This trait is implemented on data frames and their components -/// (channel messages, messages, and individual message types through prost). -pub(crate) trait Encoder: Sized + fmt::Debug { - /// Calculates the length that the encoded message needs. - fn encoded_len(&self) -> Result; - - /// Encodes the message to a buffer. - /// - /// An error will be returned if the buffer does not have sufficient capacity. - fn encoder_encode(&self, buf: &mut [u8]) -> Result; -} - pub(crate) fn decode_framed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { @@ -174,55 +160,6 @@ fn decode_u24(buffer: &[u8]) -> Result<(usize, &[u8]), EncodingError> { Ok((out as usize, rest)) } -impl Encoder for Vec { - fn encoded_len(&self) -> Result { - Ok(prencode_channel_messages(self)? + UINT24_HEADER_LEN) - } - - #[instrument(skip_all)] - fn encoder_encode(&self, buf: &mut [u8]) -> Result { - let body_len = prencode_channel_messages(self)?; - let mut buf = checked_write_uint24_le(body_len, buf)?; - // skip the u24 we just wrote - match self.len().cmp(&1) { - std::cmp::Ordering::Less => {} - std::cmp::Ordering::Equal => { - trace!("Encoding single ChannelMessage {}", self[0]); - if let Message::Open(_) = &self[0].message { - // This is a special case with 0x00, 0x01 intro bytes - buf = write_array(&[0, 1], buf)?; - self[0].encode(buf)?; - } else if let Message::Close(_) = &self[0].message { - // This is a special case with 0x00, 0x03 intro bytes - buf = write_array(&[0, 3], buf)?; - self[0].encode(buf)?; - } else { - self[0].encode(buf)?; - } - } - std::cmp::Ordering::Greater => { - // Two intro bytes 0x00 0x00, then channel id, then lengths - buf = write_array(&[0, 0], buf)?; - let mut current_channel: u64 = self[0].channel; - buf = current_channel.encode(buf)?; - for message in self.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - buf = write_array(&[0], buf)?; - buf = message.channel.encode(buf)?; - current_channel = message.channel; - } - let message_length = message.message.encoded_size()?; - buf = (message_length as u32).encode(buf)?; - buf = message.encode(buf)?; - } - } - } - Ok(UINT24_HEADER_LEN + body_len) - } -} - /// A protocol message. #[derive(Debug, Clone, PartialEq)] #[allow(missing_docs)] diff --git a/src/mqueue.rs b/src/mqueue.rs index 9be4ab7..cd2237b 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -9,10 +9,11 @@ use std::{ }; use futures::{Sink, Stream}; +use hypercore::encoding::CompactEncoding as _; use tracing::{error, instrument}; use crate::{ - message::{decode_framed_channel_messages, ChannelMessage, Encoder as _}, + message::{decode_framed_channel_messages, ChannelMessage}, noise::EncryptionInfo, NoiseEvent, }; @@ -81,16 +82,16 @@ impl + Sink> + Send + Unpin + 'static> Mes messages.push(msg); } - let mut buf = vec![0; messages.encoded_len()?]; - match messages.encoder_encode(&mut buf) { - Ok(_) => {} + let buf = match messages.to_encoded_bytes() { + Ok(x) => x, Err(e) => { error!(error = ?e, "error encoding messages"); // TODO this would probably be a programming error. // if so, this sholud just be an unwrap/expect return Poll::Ready(Err(e.into())); } - } + }; + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf) { todo!() } From b9e5b490ebece2165c3a28266d1f92899aa3da5c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 12:34:11 -0400 Subject: [PATCH 090/135] clippy fixes --- src/framing.rs | 2 +- src/test_utils.rs | 1 + tests/js_interop.rs | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 7daef38..12d3c41 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -16,7 +16,7 @@ use crate::util::{stat_uint24_le, wrap_uint24_le}; const BUF_SIZE: usize = 1024 * 64; const _HEADER_LEN: usize = 3; -/// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream +/// take a `AsyncWrite` of length prefixed messages and emit them as a Stream pub struct Uint24LELengthPrefixedFraming { io: IO, /// Data from [`Self::io`]'s [`AsyncRead`] interface to be sent out via the [`Stream`] interface. diff --git a/src/test_utils.rs b/src/test_utils.rs index 5309529..24256cb 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -188,6 +188,7 @@ fn result_channel() -> (Sender>, impl Stream>> (tx, rx.map(Ok)) } +#[allow(clippy::type_complexity)] pub(crate) fn create_result_connected() -> ( Moo>>, impl Sink>>, Moo>>, impl Sink>>, diff --git a/tests/js_interop.rs b/tests/js_interop.rs index d703734..41bad94 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -767,7 +767,10 @@ async fn on_replication_message( for i in 0..new_info.contiguous_length { let value = String::from_utf8(hypercore.get(i).await?.unwrap()).unwrap(); let line = format!("{} {}\n", i, value); - writer.write(line.as_bytes()).await?; + let n_written = writer.write(line.as_bytes()).await?; + if line.len() != n_written { + panic!("Couldn't write all write all bytse"); + } } writer.flush().await?; true From 2f6f694cc029c10dcbab1a621bbd8fa28d5cb384 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 2 May 2025 18:41:34 -0400 Subject: [PATCH 091/135] add compact_encoding dependency --- Cargo.toml | 3 +++ src/message.rs | 2 +- src/mqueue.rs | 4 ++-- src/schema.rs | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0862678..fe7ccda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,9 @@ path = "../core" #version = "0.14.0" #default-features = false +[dependencies.compact-encoding] +path = "../compact-encoding" + [dev-dependencies] async-std = { version = "1.12.0", features = ["attributes", "unstable"] } diff --git a/src/message.rs b/src/message.rs index f222a1e..aa6ca24 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,6 +1,6 @@ use crate::schema::*; use crate::util::{stat_uint24_le, write_uint24_le}; -use hypercore::encoding::{ +use compact_encoding::{ decode_usize, take_array, take_array_mut, write_array, CompactEncoding, EncodingError, EncodingErrorKind, VecEncodable, }; diff --git a/src/mqueue.rs b/src/mqueue.rs index cd2237b..e5df5b8 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -8,8 +8,8 @@ use std::{ task::{Context, Poll}, }; +use compact_encoding::CompactEncoding as _; use futures::{Sink, Stream}; -use hypercore::encoding::CompactEncoding as _; use tracing::{error, instrument}; use crate::{ @@ -92,7 +92,7 @@ impl + Sink> + Send + Unpin + 'static> Mes } }; - if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf) { + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf.to_vec()) { todo!() } diff --git a/src/schema.rs b/src/schema.rs index c58a40b..7d6fb58 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,4 +1,4 @@ -use hypercore::encoding::{ +use compact_encoding::{ map_encode, sum_encoded_size, take_array, take_array_mut, write_array, write_slice, CompactEncoding, EncodingError, }; From f1188c787cae57f6f879f28068618606180afedd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 5 May 2025 13:15:16 -0400 Subject: [PATCH 092/135] remove decode macro this came from hypercore but has been removed --- src/schema.rs | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index 7d6fb58..049a590 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,10 +1,9 @@ use compact_encoding::{ - map_encode, sum_encoded_size, take_array, take_array_mut, write_array, write_slice, + map_decode, map_encode, sum_encoded_size, take_array, take_array_mut, write_array, write_slice, CompactEncoding, EncodingError, }; use hypercore::{ - decode, DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, - RequestUpgrade, + DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, RequestUpgrade, }; use tracing::instrument; @@ -50,9 +49,8 @@ impl CompactEncoding for Open { where Self: Sized, { - let (channel, rest) = u64::decode(buffer)?; - let (protocol, rest) = String::decode(rest)?; - let (discovery_key, rest) = >::decode(rest)?; + let ((channel, protocol, discovery_key), rest) = + map_decode!(buffer, [u64, String, Vec]); // TODO this is a CLEAR bug it assumes nothing is encoded after this message let (capability, rest) = if !rest.is_empty() { let (_, rest) = take_array::<1>(rest)?; @@ -62,7 +60,7 @@ impl CompactEncoding for Open { (None, rest) }; Ok(( - Open { + Self { channel, protocol, discovery_key, @@ -93,7 +91,8 @@ impl CompactEncoding for Close { where Self: Sized, { - decode!(Close, buffer, {channel: u64}) + let (channel, rest) = u64::decode(buffer)?; + Ok((Self { channel }, rest)) } } @@ -138,10 +137,7 @@ impl CompactEncoding for Synchronize { Self: Sized, { let ([flags], rest) = take_array::<1>(buffer)?; - dbg!(flags); - let (fork, rest) = u64::decode(rest)?; - let (length, rest) = u64::decode(rest)?; - let (remote_length, rest) = u64::decode(rest)?; + let ((fork, length, remote_length), rest) = map_decode!(rest, [u64, u64, u64]); let can_upgrade = flags & 1 != 0; let uploading = flags & 2 != 0; let downloading = flags & 4 != 0; @@ -234,8 +230,7 @@ impl CompactEncoding for Request { Self: Sized, { let ([flags], rest) = take_array::<1>(buffer)?; - let (id, rest) = u64::decode(rest)?; - let (fork, rest) = u64::decode(rest)?; + let ((id, fork), rest) = map_decode!(rest, [u64, u64]); let (block, rest) = maybe_decode!(flags & 1 != 0, RequestBlock, rest); let (hash, rest) = maybe_decode!(flags & 2 != 0, RequestBlock, rest); @@ -345,8 +340,7 @@ impl CompactEncoding for Data { Self: Sized, { let ([flags], rest) = take_array::<1>(buffer)?; - let (request, rest) = u64::decode(rest)?; - let (fork, rest) = u64::decode(rest)?; + let ((request, fork), rest) = map_decode!(rest, [u64, u64]); let (block, rest) = maybe_decode!(flags & 1 != 0, DataBlock, rest); let (hash, rest) = maybe_decode!(flags & 2 != 0, DataHash, rest); let (seek, rest) = maybe_decode!(flags & 4 != 0, DataSeek, rest); @@ -398,7 +392,8 @@ impl CompactEncoding for NoData { where Self: Sized, { - decode!(NoData, buffer, { request: u64 }) + let (request, rest) = u64::decode(buffer)?; + Ok((Self { request }, rest)) } } @@ -424,7 +419,8 @@ impl CompactEncoding for Want { where Self: Sized, { - decode!(Self, buffer, { start: u64, length: u64 }) + let ((start, length), rest) = map_decode!(buffer, [u64, u64]); + Ok((Self { start, length }, rest)) } } @@ -450,7 +446,8 @@ impl CompactEncoding for Unwant { where Self: Sized, { - decode!(Self, buffer, { start: u64, length: u64 }) + let ((start, length), rest) = map_decode!(buffer, [u64, u64]); + Ok((Self { start, length }, rest)) } } @@ -475,7 +472,8 @@ impl CompactEncoding for Bitfield { where Self: Sized, { - decode!(Self, buffer, { start: u64, bitfield: Vec }) + let ((start, bitfield), rest) = map_decode!(buffer, [u64, Vec]); + Ok((Self { start, bitfield }, rest)) } } @@ -556,6 +554,7 @@ impl CompactEncoding for Extension { where Self: Sized, { - decode!(Self, buffer, { name: String, message: Vec }) + let ((name, message), rest) = map_decode!(buffer, [String, Vec]); + Ok((Self { name, message }, rest)) } } From c15aff9595fcf6a5f37367f37905fa7d0e93dce0 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:34:53 -0400 Subject: [PATCH 093/135] update compact-encoding version --- Cargo.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fe7ccda..3c31ea0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,16 +40,13 @@ sha2 = "0.10" curve25519-dalek = "4" crypto_secretstream = "0.2" futures = "0.3.31" +compact-encoding = "2" [dependencies.hypercore] path = "../core" #version = "0.14.0" #default-features = false -[dependencies.compact-encoding] -path = "../compact-encoding" - - [dev-dependencies] async-std = { version = "1.12.0", features = ["attributes", "unstable"] } async-compat = "0.2.1" From 6a44f819b8ff64503e348db8f716226cc84975bd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:35:03 -0400 Subject: [PATCH 094/135] Remove unused features --- Cargo.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3c31ea0..d53f5f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,8 +67,6 @@ tokio-util = { version = "0.7.14", features = ["compat"] } [features] default = ["tokio", "sparse"] -#default = ["tokio", "sparse"] -uint24 = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" ] From e7188075128efd2aae3d24f68a0c102a33b2208d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:35:55 -0400 Subject: [PATCH 095/135] remove unused uint24 feature --- src/util.rs | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/util.rs b/src/util.rs index 579a0fd..7e70336 100644 --- a/src/util.rs +++ b/src/util.rs @@ -29,33 +29,6 @@ pub(crate) fn map_channel_err(err: async_channel::SendError) -> Error { } pub(crate) const UINT_24_LENGTH: usize = 3; -#[cfg(feature = "uint24")] -mod uint24 { - use super::UINT_24_LENGTH; - pub struct Uint24LE([u8; UINT_24_LENGTH]); - impl Uint24LE { - pub const MAX_USIZE: usize = 16777215; - pub const SIZE: usize = UINT_24_LENGTH; - } - - impl AsRef<[u8; 3]> for Uint24LE { - fn as_ref(&self) -> &[u8; 3] { - &self.0 - } - } - - // TODO we are using std::io::Error everywhere so I won't add a new one but this isn't ideal - impl TryFrom for Uint24LE { - type Error = Error; - - fn try_from(n: usize) -> Result { - if n > Self::MAX_USIZE { - todo!() - } - Ok(Self([(n & 255) as u8, (n >> 8) as u8, (n >> 16) as u8])) - } - } -} #[inline] pub(crate) fn wrap_uint24_le(data: &[u8]) -> Vec { let mut buf: Vec = vec![0; 3]; From ec47b843ac69fb60e0eba54d21d9f4bf98ce9e28 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:49:34 -0400 Subject: [PATCH 096/135] remove use of test_log just use our own logger as needed --- tests/js_interop.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 41bad94..28f40ba 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -24,14 +24,13 @@ use async_std::{ task::{self, sleep}, test as async_test, }; -use test_log::test; #[cfg(feature = "tokio")] use tokio::{ fs::{metadata, File}, io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, net::{TcpListener, TcpStream}, sync::Mutex, - task, test as async_test, + task, time::sleep, }; @@ -59,28 +58,28 @@ const TEST_SET_SERVER_WRITER: &str = "sw"; const TEST_SET_CLIENT_WRITER: &str = "cw"; const TEST_SET_SIMPLE: &str = "simple"; -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_ncns_simple_server_writer() -> Result<()> { js_interop_ncns_simple(true, 8101).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_ncns_simple_client_writer() -> Result<()> { js_interop_ncns_simple(false, 8102).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_rcns_simple_server_writer() -> Result<()> { js_interop_rcns_simple(true, 8103).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically async fn js_interop_rcns_simple_client_writer() -> Result<()> { @@ -88,28 +87,29 @@ async fn js_interop_rcns_simple_client_writer() -> Result<()> { Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_ncrs_simple_server_writer() -> Result<()> { js_interop_ncrs_simple(true, 8105).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_ncrs_simple_client_writer() -> Result<()> { js_interop_ncrs_simple(false, 8106).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_rcrs_simple_server_writer() -> Result<()> { + _util::log(); js_interop_rcrs_simple(true, 8107).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically async fn js_interop_rcrs_simple_client_writer() -> Result<()> { From 3357e6a9175a6ae3192c283e3e5b82418c74b8f0 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:50:38 -0400 Subject: [PATCH 097/135] remove test-log dep --- Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index d53f5f6..a651734 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,6 @@ duplexify = "1.1.0" sluice = "0.5.4" futures = "0.3.13" log = "0.4" -test-log = { version = "0.2.11", default-features = false, features = ["trace"] } tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } tracing-tree = "0.4.0" tokio-util = { version = "0.7.14", features = ["compat"] } From 778a192066c1ed272497db35e825711f0a24e815 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 02:07:38 -0400 Subject: [PATCH 098/135] Add instrument to some funcs --- src/message.rs | 2 ++ src/noise.rs | 4 ++-- src/protocol.rs | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/message.rs b/src/message.rs index aa6ca24..80208ff 100644 --- a/src/message.rs +++ b/src/message.rs @@ -54,6 +54,7 @@ pub(crate) fn decode_framed_channel_messages( } Ok((combined_messages, index)) } +#[instrument(skip_all err)] pub(crate) fn decode_unframed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { @@ -448,6 +449,7 @@ impl ChannelMessage { /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it + #[instrument(err, skip(buf))] pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { if buf.len() <= 1 { return Err(io::Error::new( diff --git a/src/noise.rs b/src/noise.rs index 5b5b867..9f35393 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -228,7 +228,7 @@ impl>> + Sink> + Send + Unpin + 'static { type Item = Event; - #[instrument(skip(cx), fields(initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -502,7 +502,7 @@ fn reset_encrypted( } /// handle setup messages: if any are incorrect (cause an error) the state is reset -#[instrument(skip_all, fields(initiator = %is_initiator))] +#[instrument(err, skip_all, fields(initiator = %is_initiator))] fn handle_setup_message( step: &mut Step, msg: &[u8], diff --git a/src/protocol.rs b/src/protocol.rs index 42bd5d6..955ef7c 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -298,7 +298,7 @@ where } /// Poll for inbound messages and processs them. - #[instrument(skip_all)] + #[instrument(skip_all, err)] fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { match self.io.poll_inbound(cx) { From b5ed63ef9d917aad24f3476a4db52bb36ced77f8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 15:38:52 -0400 Subject: [PATCH 099/135] remove redundant 'simple' from every func name --- tests/js_interop.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 28f40ba..ebd03ac 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -60,50 +60,50 @@ const TEST_SET_SIMPLE: &str = "simple"; #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_simple_server_writer() -> Result<()> { +async fn js_interop_ncns_server_writer() -> Result<()> { js_interop_ncns_simple(true, 8101).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_simple_client_writer() -> Result<()> { +async fn js_interop_ncns_client_writer() -> Result<()> { js_interop_ncns_simple(false, 8102).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcns_simple_server_writer() -> Result<()> { - js_interop_rcns_simple(true, 8103).await?; +async fn js_interop_rcns_server_writer() -> Result<()> { + js_interop_rcns(true, 8103).await?; Ok(()) } #[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcns_simple_client_writer() -> Result<()> { - js_interop_rcns_simple(false, 8104).await?; +async fn js_interop_rcns_client_writer() -> Result<()> { + js_interop_rcns(false, 8104).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_simple_server_writer() -> Result<()> { +async fn js_interop_ncrs_server_writer() -> Result<()> { js_interop_ncrs_simple(true, 8105).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_simple_client_writer() -> Result<()> { +async fn js_interop_ncrs_client_writer() -> Result<()> { js_interop_ncrs_simple(false, 8106).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcrs_simple_server_writer() -> Result<()> { +async fn js_interop_rcrs_server_writer() -> Result<()> { _util::log(); js_interop_rcrs_simple(true, 8107).await?; Ok(()) @@ -112,7 +112,7 @@ async fn js_interop_rcrs_simple_server_writer() -> Result<()> { #[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcrs_simple_client_writer() -> Result<()> { +async fn js_interop_rcrs_client_writer() -> Result<()> { js_interop_rcrs_simple(false, 8108).await?; Ok(()) } @@ -156,7 +156,7 @@ async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcns_simple(server_writer: bool, port: u32) -> Result<()> { +async fn js_interop_rcns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", From 4c0402fe7fe78433eeec80e3fcb5c76cd0fa3b39 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 7 May 2025 11:21:26 -0400 Subject: [PATCH 100/135] rm async_std test wrappers --- tests/js_interop.rs | 76 --------------------------------------------- 1 file changed, 76 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index ebd03ac..41d8160 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -15,15 +15,6 @@ use std::sync::Once; #[cfg(feature = "tokio")] use async_compat::CompatExt; -#[cfg(feature = "async-std")] -use async_std::{ - fs::{metadata, File}, - io::{prelude::BufReadExt, BufReader, BufWriter, WriteExt}, - net::{TcpListener, TcpStream}, - sync::Mutex, - task::{self, sleep}, - test as async_test, -}; #[cfg(feature = "tokio")] use tokio::{ fs::{metadata, File}, @@ -441,40 +432,6 @@ pub fn get_test_key_pair(include_secret: bool) -> PartialKeypair { PartialKeypair { public, secret } } -#[cfg(feature = "async-std")] -async fn on_replication_connection( - stream: TcpStream, - is_initiator: bool, - hypercore: Arc, -) -> Result<()> { - let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); - while let Some(event) = protocol.next().await { - let event = event?; - match event { - Event::Handshake(_) => { - if is_initiator { - protocol.open(*hypercore.key()).await?; - } - } - Event::DiscoveryKey(dkey) => { - if hypercore.discovery_key == dkey { - protocol.open(*hypercore.key()).await?; - } else { - panic!("Invalid discovery key"); - } - } - Event::Channel(channel) => { - hypercore.on_replication_peer(channel); - } - Event::Close(_dkey) => { - break; - } - _ => {} - } - } - Ok(()) -} - #[cfg(feature = "tokio")] async fn on_replication_connection( stream: TcpStream, @@ -850,39 +807,6 @@ impl RustServer { } } -impl Drop for RustServer { - fn drop(&mut self) { - #[cfg(feature = "async-std")] - if let Some(handle) = self.handle.take() { - task::block_on(handle.cancel()); - } - } -} - -#[cfg(feature = "async-std")] -pub async fn tcp_server( - port: u32, - onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, - context: C, -) -> Result<()> -where - F: Future> + Send, - C: Clone + Send + 'static, -{ - let listener = TcpListener::bind(&format!("localhost:{}", port)).await?; - let mut incoming = listener.incoming(); - while let Some(Ok(stream)) = incoming.next().await { - let context = context.clone(); - let _peer_addr = stream.peer_addr().unwrap(); - task::spawn(async move { - onconnection(stream, false, context) - .await - .expect("Should return ok"); - }); - } - Ok(()) -} - #[cfg(feature = "tokio")] pub async fn tcp_server( port: u32, From 917b8374d9612fbcafedef73c284812ad0a00d44 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 8 May 2025 13:47:46 -0400 Subject: [PATCH 101/135] instrument and rename vec_encoded_size for cm --- src/message.rs | 51 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/src/message.rs b/src/message.rs index 80208ff..ca09465 100644 --- a/src/message.rs +++ b/src/message.rs @@ -7,7 +7,7 @@ use compact_encoding::{ use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::{instrument, trace, warn}; +use tracing::{debug, instrument, trace, warn}; const UINT24_HEADER_LEN: usize = 3; const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; @@ -41,9 +41,7 @@ pub(crate) fn decode_framed_channel_messages( body_len, length ); } - for message in msgs { - combined_messages.push(message); - } + combined_messages.extend(msgs); index += header_len + body_len as usize; } else { return Err(io::Error::new( @@ -122,7 +120,7 @@ pub(crate) fn decode_unframed_channel_messages( } } -fn prencode_channel_messages(messages: &[ChannelMessage]) -> Result { +fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result { Ok(match messages { [] => 0, [msg] => match msg.message { @@ -454,7 +452,7 @@ impl ChannelMessage { if buf.len() <= 1 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, - "received empty message", + format!("received empty message [{buf:?}]"), )); } let (message, buf) = ::decode(buf)?; @@ -493,13 +491,15 @@ impl CompactEncoding for ChannelMessage { } impl VecEncodable for ChannelMessage { + #[instrument(skip_all, ret)] fn vec_encoded_size(vec: &[Self]) -> Result where Self: Sized, { - Ok(prencode_channel_messages(vec)? + UINT24_HEADER_LEN) + Ok(vec_channel_messages_encoded_size(vec)? + UINT24_HEADER_LEN) } + #[instrument(skip_all)] fn vec_encode<'a>(vec: &[Self], buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> where Self: Sized, @@ -568,6 +568,8 @@ impl VecEncodable for ChannelMessage { #[cfg(test)] mod tests { + use crate::test_utils::log; + use super::*; use hypercore::{ DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, @@ -673,4 +675,39 @@ mod tests { }; Ok(()) } + + #[test] + fn extras() -> Result<(), EncodingError> { + let one = Message::Synchronize(Synchronize { + fork: 0, + length: 4, + remote_length: 0, + downloading: true, + uploading: true, + can_upgrade: true, + }); + let two = Message::Range(Range { + drop: false, + start: 0, + length: 4, + }); + let msgs = vec![ChannelMessage::new(1, one), ChannelMessage::new(1, two)]; + let buff = msgs.to_encoded_bytes()?; + + let res = as CompactEncoding>::decode(&buff); + assert!(res.is_err()); + log(); + + let buff = msgs.to_encoded_bytes()?; + let (res2, _size) = decode_framed_channel_messages(&buff).unwrap(); + assert_eq!(res2, msgs); + + // from js interop tests + // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] + // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] + // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] + // [23, 0, 0, 0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] + //assert!(res2.is_ok()); + Ok(()) + } } From d7bd06d29bad8aa396addbb4aa3be9419be96017 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 8 May 2025 16:02:33 -0400 Subject: [PATCH 102/135] fix Vec encoding --- src/message.rs | 62 ++++++++++++++++++++++++++++++++++---------------- src/schema.rs | 1 - 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/message.rs b/src/message.rs index ca09465..ff2ea2b 100644 --- a/src/message.rs +++ b/src/message.rs @@ -15,6 +15,7 @@ const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; +#[instrument(skip_all)] pub(crate) fn decode_framed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { @@ -75,14 +76,22 @@ pub(crate) fn decode_unframed_channel_messages( return Err(io::Error::new( io::ErrorKind::InvalidData, format!( - "received invalid message length: [{channel_message_length}] but we have [{}] remaining bytes. Initial buffer size [{og_len}]", + "received invalid message length: [{channel_message_length}] +\tbut we have [{}] remaining bytes. +\tInitial buffer size [{og_len}]", buf.len() ), )); } // Then the actual message let channel_message; - (channel_message, buf) = ChannelMessage::decode(buf, current_channel)?; + let bl = buf.len(); + (channel_message, buf) = ChannelMessage::decode_with_channel(buf, current_channel)?; + trace!( + "Decoded ChannelMessage::{:?} using [{} bytes]", + channel_message.message, + bl - buf.len() + ); messages.push(channel_message); // After that, if there is an extra 0x00, that means the channel // changed. This works because of LE encoding, and channels starting @@ -128,25 +137,25 @@ fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result msg.encoded_size()?, }, msgs => { - let mut out = 2; + let mut out = MULTI_MESSAGE_PREFIX.len(); let mut current_channel: u64 = messages[0].channel; out += current_channel.encoded_size()?; for message in msgs.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel - out += 1 + message.channel.encoded_size()?; + out += CHANNEL_CHANGE_SEPERATOR.len() + message.channel.encoded_size()?; current_channel = message.channel; } let message_length = message.message.encoded_size()?; - out += message.encoded_size()? + message_length; + out += message_length + (message_length as u64).encoded_size()?; } out } }) } -fn checked_write_uint24_le(n: usize, buf: &mut [u8]) -> Result<&mut [u8], EncodingError> { +fn encode_usize_as_u24(n: usize, buf: &mut [u8]) -> Result<&mut [u8], EncodingError> { let (header, rest) = take_array_mut::(buf)?; write_uint24_le(n, header); Ok(rest) @@ -239,6 +248,7 @@ impl CompactEncoding for Message { #[instrument(skip_all, fields(name = self.name()))] fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + debug!("Encoding {self:?}"); let rest = if let Self::Open(_) | Self::Close(_) = &self { buffer } else { @@ -394,6 +404,7 @@ impl ChannelMessage { /// bytes in it #[instrument(skip_all, err)] pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { + debug!("Decode ChannelMessage::Open"); let og_len = buf.len(); if og_len <= 5 { return Err(io::Error::new( @@ -417,6 +428,7 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { + debug!("Decode ChannelMessage::Close"); let og_len = buf.len(); if buf.is_empty() { return Err(io::Error::new( @@ -441,6 +453,10 @@ impl ChannelMessage { //::decode(buf) let (channel, buf) = u64::decode(buf)?; let (message, buf) = ::decode(buf)?; + debug!( + "Decode ChannelMessage{{ channel: {channel}, message: {} }}", + message.name() + ); Ok((Self { channel, message }, buf)) } /// Decode a normal channel message from a buffer. @@ -448,7 +464,7 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it #[instrument(err, skip(buf))] - pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { + pub(crate) fn decode_with_channel(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { if buf.len() <= 1 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, @@ -504,8 +520,13 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - let body_len = prencode_channel_messages(vec)?; - let mut buffer = checked_write_uint24_le(body_len, buffer)?; + let in_buf_len = buffer.len(); + trace!( + "Vec::encode to buf.len() = [{}]", + buffer.len() + ); + let body_len = vec_channel_messages_encoded_size(vec)?; + let mut buffer = encode_usize_as_u24(body_len, buffer)?; match vec { [] => Ok(buffer), [msg] => { @@ -527,9 +548,10 @@ impl VecEncodable for ChannelMessage { current_channel = msg.channel; } let msg_len = msg.message.encoded_size()?; - buffer = (msg_len as u32).encode(buffer)?; + buffer = (msg_len as u64).encode(buffer)?; buffer = msg.message.encode(buffer)?; } + trace!("wrote [{}] bytes to buffer", in_buf_len - buffer.len()); Ok(buffer) } } @@ -541,16 +563,19 @@ impl VecEncodable for ChannelMessage { { let mut index = 0; let mut combined_messages: Vec = vec![]; + let mut rest = buffer; while index < buffer.len() { // There might be zero bytes in between, and with LE, the next message will // start with a non-zero - if buffer[index] == 0 { + if rest[index] == 0 { index += 1; continue; } - let (frame_len, next_frame_start) = decode_u24(&buffer[index..])?; - let (msgs, length) = decode_unframed_channel_messages(&next_frame_start[..frame_len]) + let frame_len; + (frame_len, rest) = decode_u24(&rest[index..])?; + let (msgs, length) = decode_unframed_channel_messages(&rest[..frame_len]) .map_err(|e| EncodingError::external(&format!("{e}")))?; + rest = &rest[length..]; if length != frame_len { warn!( "Did not know what to do with all the bytes, got {frame_len} but decoded {length}. \ @@ -561,7 +586,7 @@ impl VecEncodable for ChannelMessage { combined_messages.extend(msgs); index += UINT24_HEADER_LEN + frame_len; } - todo!() + Ok((combined_messages, rest)) } } @@ -692,13 +717,12 @@ mod tests { length: 4, }); let msgs = vec![ChannelMessage::new(1, one), ChannelMessage::new(1, two)]; - let buff = msgs.to_encoded_bytes()?; - - let res = as CompactEncoding>::decode(&buff); - assert!(res.is_err()); log(); - let buff = msgs.to_encoded_bytes()?; + let (result, rest) = as CompactEncoding>::decode(&buff)?; + assert!(rest.is_empty()); + assert_eq!(result, msgs); + let (res2, _size) = decode_framed_channel_messages(&buff).unwrap(); assert_eq!(res2, msgs); diff --git a/src/schema.rs b/src/schema.rs index 049a590..d1bec1e 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -122,7 +122,6 @@ impl CompactEncoding for Synchronize { let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; flags |= if self.uploading { 2 } else { 0 }; flags |= if self.downloading { 4 } else { 0 }; - dbg!(flags); let rest = write_array(&[flags], buffer)?; Ok(map_encode!( rest, From 3da6c62f2659ac259ad41ac8e6382335e3c95fc2 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 9 May 2025 14:51:11 -0400 Subject: [PATCH 103/135] remove redundant names --- tests/js_interop.rs | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 41d8160..7764141 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -51,64 +51,64 @@ const TEST_SET_SIMPLE: &str = "simple"; #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_server_writer() -> Result<()> { - js_interop_ncns_simple(true, 8101).await?; +async fn ncns_server_writer() -> Result<()> { + ncns(true, 8101).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_client_writer() -> Result<()> { - js_interop_ncns_simple(false, 8102).await?; +async fn ncns_client_writer() -> Result<()> { + ncns(false, 8102).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcns_server_writer() -> Result<()> { - js_interop_rcns(true, 8103).await?; +async fn rcns_server_writer() -> Result<()> { + rcns(true, 8103).await?; Ok(()) } #[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcns_client_writer() -> Result<()> { - js_interop_rcns(false, 8104).await?; +async fn rcns_client_writer() -> Result<()> { + rcns(false, 8104).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_server_writer() -> Result<()> { - js_interop_ncrs_simple(true, 8105).await?; +async fn ncrs_server_writer() -> Result<()> { + ncrs(true, 8105).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_client_writer() -> Result<()> { - js_interop_ncrs_simple(false, 8106).await?; +async fn ncrs_client_writer() -> Result<()> { + ncrs(false, 8106).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcrs_server_writer() -> Result<()> { +async fn rcrs_server_writer() -> Result<()> { _util::log(); - js_interop_rcrs_simple(true, 8107).await?; + rcrs(true, 8107).await?; Ok(()) } #[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -#[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcrs_client_writer() -> Result<()> { - js_interop_rcrs_simple(false, 8108).await?; +//#[ignore] // FIXME this tests hangs sporadically +async fn rcrs_client_writer() -> Result<()> { + rcrs(false, 8108).await?; Ok(()) } -async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { +async fn ncns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -147,7 +147,7 @@ async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcns(server_writer: bool, port: u32) -> Result<()> { +async fn rcns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -192,7 +192,7 @@ async fn js_interop_rcns(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_ncrs_simple(server_writer: bool, port: u32) -> Result<()> { +async fn ncrs(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -238,7 +238,7 @@ async fn js_interop_ncrs_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcrs_simple(server_writer: bool, port: u32) -> Result<()> { +async fn rcrs(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", From 74033f1ece4e3a1906ae7f785912f87f7b80cf8b Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 10 May 2025 20:33:06 -0400 Subject: [PATCH 104/135] RMME --- tests/js_interop.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 7764141..e0dc216 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,4 +1,4 @@ -use _util::wait_for_localhost_port; +use _util::{log, wait_for_localhost_port}; use anyhow::Result; use futures::Future; use futures_lite::stream::StreamExt; @@ -66,6 +66,7 @@ async fn ncns_client_writer() -> Result<()> { #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn rcns_server_writer() -> Result<()> { + log(); rcns(true, 8103).await?; Ok(()) } @@ -159,7 +160,9 @@ async fn rcns(server_writer: bool, port: u32) -> Result<()> { }, TEST_SET_SIMPLE ); + dbg!(); let (result_path, writer_path, reader_path) = prepare_test_set(&test_set); + dbg!(); let item_count = 4; let item_size = 4; let data_char = '1'; From d87c498d801f88d46d0e37718fe8eec7a3d85d06 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 18 May 2025 01:39:06 -0400 Subject: [PATCH 105/135] More logging rm unused --- src/noise.rs | 20 ++++++++++++-------- src/protocol.rs | 1 + src/test_utils.rs | 10 ---------- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 9f35393..8185162 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -170,7 +170,7 @@ impl< *is_initiator, flush, ); - poll_outgoing_encrypted_messages(io, cx, encrypted_tx, *is_initiator, flush); + poll_outgoing_encrypted_messages(io, cx, encrypted_tx, *is_initiator, flush, step); // check if we've done all possible work if did_as_much_as_possible( @@ -202,6 +202,7 @@ impl< } /// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. +#[instrument(skip_all, ret)] fn did_as_much_as_possible< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -214,7 +215,7 @@ fn did_as_much_as_possible< is_initiator: bool, ) -> bool { // No incoming encrypted messages available. - poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator).is_pending() + poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator, step).is_pending() // We're unable to send any anymore encrypted/setup messages either because we have none or the `Sink` is unavailable. && (encrypted_tx.is_empty() || Sink::poll_ready(Pin::new(io), cx).is_pending()) // No encrypted messages waiting to be decrypted. @@ -228,7 +229,7 @@ impl>> + Sink> + Send + Unpin + 'static { type Item = Event; - #[instrument(skip_all, fields(initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator, ret, err))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -268,6 +269,7 @@ impl>> + Sink> + Send + Unpin + 'static /// Handle all message throughput. Sends, encrypts and decrypts messages /// Returns `true` `step` is already [`Step::Established`]. #[allow(clippy::too_many_arguments)] +#[instrument(skip_all, ret)] fn poll_message_throughput< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -281,8 +283,8 @@ fn poll_message_throughput< is_initiator: bool, flush: &mut bool, ) -> bool { - poll_outgoing_encrypted_messages(io, cx, encrypted_tx, is_initiator, flush); - let _ = poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator); + poll_outgoing_encrypted_messages(io, cx, encrypted_tx, is_initiator, flush, step); + let _ = poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator, step); if let Step::Established((encryptor, decryptor, ..)) = step { // decrypt incomming msgs poll_decrypt(decryptor, encrypted_rx, plain_rx, is_initiator); @@ -369,11 +371,12 @@ fn poll_outgoing_encrypted_messages< encrypted_tx: &mut VecDeque>, is_initiator: bool, flush: &mut bool, + step: &Step ) { // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, msg_len = encrypted_out.len(), "TX message"); + trace!(initiator = %is_initiator, msg_len = encrypted_out.len(), step = %step, "TX message"); if let Err(_e) = Sink::start_send(Pin::new(io), encrypted_out) { error!("Error polling encyrpted side io") } @@ -407,11 +410,12 @@ fn poll_incomming_encrypted_messages< cx: &mut Context<'_>, encrypted_rx: &mut VecDeque>>, is_initiator: bool, + step: &Step, ) -> Poll<()> { // pull in any incomming encrypted messages let mut got_some = false; while let Poll::Ready(Some(encrypted_msg)) = Stream::poll_next(Pin::new(io), cx) { - trace!(initiator = %is_initiator, "RX message"); + trace!(initiator = %is_initiator, step = %step, "RX message"); encrypted_rx.push_back(encrypted_msg); got_some = true; } @@ -437,8 +441,8 @@ fn poll_decrypt( trace!(initiator = %is_initiator, "encrypted_rx dequeue decrypt"); match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!(initiator = %is_initiator, "plain rx queue"); plain_rx.push_back(Event::from(Ok(plain_msg))); + trace!(initiator = %is_initiator, n_plain_rx_msgs = plain_rx.len(), "plain_rx enqueue"); } Err(e) => { error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") diff --git a/src/protocol.rs b/src/protocol.rs index 955ef7c..615cd53 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -348,6 +348,7 @@ where } } + #[instrument(skip_all)] fn on_inbound_channel_messages(&mut self, channel_messages: Vec) -> Result<()> { for channel_message in channel_messages { self.on_inbound_message(channel_message)? diff --git a/src/test_utils.rs b/src/test_utils.rs index 24256cb..8a4dd74 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -197,13 +197,3 @@ pub(crate) fn create_result_connected() -> ( let b = Moo::from(result_channel()); a.connect(b) } - -#[tokio::test] -async fn foo() -> Result<(), Box> { - let a = Moo::from(result_channel()); - let b = Moo::from(result_channel()); - let (mut left, mut right) = a.connect(b); - left.send(b"hello".to_vec()).await?; - assert_eq!(right.next().await.unwrap()?, b"hello".to_vec()); - Ok(()) -} From 539a0179fe93787f912a574ac37df4c06a53dfea Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 18 May 2025 01:40:51 -0400 Subject: [PATCH 106/135] Notes --- src/schema.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/schema.rs b/src/schema.rs index d1bec1e..bc8f141 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -299,6 +299,10 @@ macro_rules! opt_encoded_size { }; } +// TODO we could write a macro where it takes a $cond that returns an opt. +// if the option is Some(T) then do T::encode(buf) +// also if some add $flag. +// This would simplify some of these impls macro_rules! opt_encoded_bytes { ($opt:expr, $buf:ident) => { if let Some(thing) = $opt { From 39d0dbec60eba60cf5ffa5bbbf21b239044c0c84 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 18:18:02 -0400 Subject: [PATCH 107/135] use checked get --- src/message.rs | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/message.rs b/src/message.rs index ff2ea2b..beabb8d 100644 --- a/src/message.rs +++ b/src/message.rs @@ -7,7 +7,7 @@ use compact_encoding::{ use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::{debug, instrument, trace, warn}; +use tracing::{debug, error, instrument, trace, warn}; const UINT24_HEADER_LEN: usize = 3; const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; @@ -31,19 +31,38 @@ pub(crate) fn decode_framed_channel_messages( let stat = stat_uint24_le(&buf[index..]); if let Some((header_len, body_len)) = stat { - let (msgs, length) = decode_unframed_channel_messages( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ + dbg!(&body_len); + if let Some(frame_body) = + buf.get(index + header_len..index + header_len + body_len as usize) + { + let (msgs, length) = decode_unframed_channel_messages(frame_body)?; + if length != body_len as usize { + warn!( + "Did not know what to do with all the bytes, got {} but decoded {}. \ This may be because the peer implements a newer protocol version \ that has extra fields.", - body_len, length + body_len, length + ); + } + combined_messages.extend(msgs); + index += header_len + body_len as usize; + } else { + error!( + "Could not get bytes for whole frame. +frame_header_length + frame_body_length \t= [{}] +remaining buffer (after current index) \t= [{}] +total_buffer_len \t= [{}] +current_index \t= [{}] +buffer_starts_with \t= [{:?}] +", + header_len + (body_len as usize), + buf.len() - index, + buf.len(), + index, + &buf, ); + return Ok((combined_messages, index)); } - combined_messages.extend(msgs); - index += header_len + body_len as usize; } else { return Err(io::Error::new( io::ErrorKind::InvalidData, From d97bc646e46bf86b21ac61ef3908e4e0b037cc24 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 18:18:31 -0400 Subject: [PATCH 108/135] RMME --- src/message.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/message.rs b/src/message.rs index beabb8d..69b18b1 100644 --- a/src/message.rs +++ b/src/message.rs @@ -753,4 +753,19 @@ mod tests { //assert!(res2.is_ok()); Ok(()) } + + #[test] + fn foo() -> Result<(), EncodingError> { + log(); + let buf = vec![ + 0, 1, 1, 15, 104, 121, 112, 101, 114, 99, 111, 114, 101, 47, 97, 108, 112, 104, 97, 32, + 23, 228, 138, 218, 81, 18, 123, 111, 160, 195, 104, 154, 55, 116, 18, 132, 44, 229, 77, + 118, 217, 54, 41, 162, 97, 118, 95, 4, 213, 142, 79, 124, 1, 89, 165, 64, 201, 94, 50, + 58, 137, 153, 119, 156, 234, 18, 164, 157, 161, 49, 16, 28, 206, 84, 241, 0, 245, 14, + 143, 129, 9, 151, 247, 29, 10, + ]; + let res = decode_unframed_channel_messages(&buf).unwrap(); + dbg!(&res); + Ok(()) + } } From 592699981e25456c466ec01ddba9653d8520d4b8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 18:21:29 -0400 Subject: [PATCH 109/135] un-ignore tests --- src/message.rs | 1 + tests/js_interop.rs | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/message.rs b/src/message.rs index 69b18b1..41cd267 100644 --- a/src/message.rs +++ b/src/message.rs @@ -754,6 +754,7 @@ mod tests { Ok(()) } + // TODO RMME #[test] fn foo() -> Result<(), EncodingError> { log(); diff --git a/tests/js_interop.rs b/tests/js_interop.rs index e0dc216..320436a 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -66,14 +66,12 @@ async fn ncns_client_writer() -> Result<()> { #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn rcns_server_writer() -> Result<()> { - log(); rcns(true, 8103).await?; Ok(()) } #[tokio::test] -//#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -#[ignore] // FIXME this tests hangs sporadically +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn rcns_client_writer() -> Result<()> { rcns(false, 8104).await?; Ok(()) From 81e0dad195678093966ce8af3a51a1a6c369db99 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 18:21:43 -0400 Subject: [PATCH 110/135] rm debug --- src/message.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/message.rs b/src/message.rs index 41cd267..3805158 100644 --- a/src/message.rs +++ b/src/message.rs @@ -31,7 +31,6 @@ pub(crate) fn decode_framed_channel_messages( let stat = stat_uint24_le(&buf[index..]); if let Some((header_len, body_len)) = stat { - dbg!(&body_len); if let Some(frame_body) = buf.get(index + header_len..index + header_len + body_len as usize) { From fb14f920265d875c9bb353a42694ea69c494f215 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:40:46 -0400 Subject: [PATCH 111/135] rm framing stuff from messages --- src/message.rs | 156 ++++++++----------------------------------------- 1 file changed, 25 insertions(+), 131 deletions(-) diff --git a/src/message.rs b/src/message.rs index 3805158..637e99e 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,76 +1,18 @@ use crate::schema::*; -use crate::util::{stat_uint24_le, write_uint24_le}; use compact_encoding::{ - decode_usize, take_array, take_array_mut, write_array, CompactEncoding, EncodingError, - EncodingErrorKind, VecEncodable, + decode_usize, take_array, write_array, CompactEncoding, EncodingError, EncodingErrorKind, + VecEncodable, }; use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::{debug, error, instrument, trace, warn}; +use tracing::{debug, instrument, trace, warn}; -const UINT24_HEADER_LEN: usize = 3; const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; -#[instrument(skip_all)] -pub(crate) fn decode_framed_channel_messages( - buf: &[u8], -) -> Result<(Vec, usize), io::Error> { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - if let Some(frame_body) = - buf.get(index + header_len..index + header_len + body_len as usize) - { - let (msgs, length) = decode_unframed_channel_messages(frame_body)?; - if length != body_len as usize { - warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, length - ); - } - combined_messages.extend(msgs); - index += header_len + body_len as usize; - } else { - error!( - "Could not get bytes for whole frame. -frame_header_length + frame_body_length \t= [{}] -remaining buffer (after current index) \t= [{}] -total_buffer_len \t= [{}] -current_index \t= [{}] -buffer_starts_with \t= [{:?}] -", - header_len + (body_len as usize), - buf.len() - index, - buf.len(), - index, - &buf, - ); - return Ok((combined_messages, index)); - } - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } - } - Ok((combined_messages, index)) -} #[instrument(skip_all err)] pub(crate) fn decode_unframed_channel_messages( buf: &[u8], @@ -142,7 +84,7 @@ pub(crate) fn decode_unframed_channel_messages( } else { Err(io::Error::new( io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), + format!("received too short message, {buf:?}"), )) } } @@ -173,19 +115,6 @@ fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result Result<&mut [u8], EncodingError> { - let (header, rest) = take_array_mut::(buf)?; - write_uint24_le(n, header); - Ok(rest) -} - -/// decode a u24 from `buffer` as a `usize` -fn decode_u24(buffer: &[u8]) -> Result<(usize, &[u8]), EncodingError> { - let (u24_bytes, rest) = take_array::(buffer)?; - let (_, out) = stat_uint24_le(&u24_bytes).expect("input garunteed to be long enough"); - Ok((out as usize, rest)) -} - /// A protocol message. #[derive(Debug, Clone, PartialEq)] #[allow(missing_docs)] @@ -530,7 +459,7 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - Ok(vec_channel_messages_encoded_size(vec)? + UINT24_HEADER_LEN) + Ok(vec_channel_messages_encoded_size(vec)?) } #[instrument(skip_all)] @@ -543,34 +472,33 @@ impl VecEncodable for ChannelMessage { "Vec::encode to buf.len() = [{}]", buffer.len() ); - let body_len = vec_channel_messages_encoded_size(vec)?; - let mut buffer = encode_usize_as_u24(body_len, buffer)?; + let mut rest = buffer; match vec { - [] => Ok(buffer), + [] => Ok(rest), [msg] => { - buffer = match msg.message { - Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, buffer)?, - Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, buffer)?, - _ => msg.channel.encode(buffer)?, + rest = match msg.message { + Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, rest)?, + Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, rest)?, + _ => msg.channel.encode(rest)?, }; - msg.message.encode(buffer) + msg.message.encode(rest) } msgs => { - buffer = write_array(&MULTI_MESSAGE_PREFIX, buffer)?; + rest = write_array(&MULTI_MESSAGE_PREFIX, rest)?; let mut current_channel: u64 = msgs[0].channel; - buffer = current_channel.encode(buffer)?; + rest = current_channel.encode(rest)?; for msg in msgs { if msg.channel != current_channel { - buffer = write_array(&CHANNEL_CHANGE_SEPERATOR, buffer)?; - buffer = msg.channel.encode(buffer)?; + rest = write_array(&CHANNEL_CHANGE_SEPERATOR, rest)?; + rest = msg.channel.encode(rest)?; current_channel = msg.channel; } let msg_len = msg.message.encoded_size()?; - buffer = (msg_len as u64).encode(buffer)?; - buffer = msg.message.encode(buffer)?; + rest = (msg_len as u64).encode(rest)?; + rest = msg.message.encode(rest)?; } - trace!("wrote [{}] bytes to buffer", in_buf_len - buffer.len()); - Ok(buffer) + trace!("wrote [{}] bytes to buffer", in_buf_len - rest.len()); + Ok(rest) } } } @@ -579,30 +507,13 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - let mut index = 0; let mut combined_messages: Vec = vec![]; let mut rest = buffer; - while index < buffer.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if rest[index] == 0 { - index += 1; - continue; - } - let frame_len; - (frame_len, rest) = decode_u24(&rest[index..])?; - let (msgs, length) = decode_unframed_channel_messages(&rest[..frame_len]) + while !rest.is_empty() { + let (msgs, length) = decode_unframed_channel_messages(rest) .map_err(|e| EncodingError::external(&format!("{e}")))?; rest = &rest[length..]; - if length != frame_len { - warn!( - "Did not know what to do with all the bytes, got {frame_len} but decoded {length}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - ); - } combined_messages.extend(msgs); - index += UINT24_HEADER_LEN + frame_len; } Ok((combined_messages, rest)) } @@ -662,7 +573,9 @@ mod tests { upgrade: Some(RequestUpgrade { start: 0, length: 10 - }) + }), + manifest: false, + priority: 0 }), Message::Cancel(Cancel { request: 1, @@ -741,9 +654,6 @@ mod tests { assert!(rest.is_empty()); assert_eq!(result, msgs); - let (res2, _size) = decode_framed_channel_messages(&buff).unwrap(); - assert_eq!(res2, msgs); - // from js interop tests // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] @@ -752,20 +662,4 @@ mod tests { //assert!(res2.is_ok()); Ok(()) } - - // TODO RMME - #[test] - fn foo() -> Result<(), EncodingError> { - log(); - let buf = vec![ - 0, 1, 1, 15, 104, 121, 112, 101, 114, 99, 111, 114, 101, 47, 97, 108, 112, 104, 97, 32, - 23, 228, 138, 218, 81, 18, 123, 111, 160, 195, 104, 154, 55, 116, 18, 132, 44, 229, 77, - 118, 217, 54, 41, 162, 97, 118, 95, 4, 213, 142, 79, 124, 1, 89, 165, 64, 201, 94, 50, - 58, 137, 153, 119, 156, 234, 18, 164, 157, 161, 49, 16, 28, 206, 84, 241, 0, 245, 14, - 143, 129, 9, 151, 247, 29, 10, - ]; - let res = decode_unframed_channel_messages(&buf).unwrap(); - dbg!(&res); - Ok(()) - } } From 3b9430b440856bc0eccc7a6ac8dbf069ce9d1aa3 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:54:28 -0400 Subject: [PATCH 112/135] Add manifest & priority to Request --- examples/replication.rs | 4 ++++ src/mqueue.rs | 6 +----- src/schema.rs | 22 ++++++++++++++++++++++ tests/js_interop.rs | 4 ++++ 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/examples/replication.rs b/examples/replication.rs index 459df9f..ac10df6 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -299,6 +299,8 @@ async fn onmessage( start: info.length, length: peer_state.remote_length - info.length, }), + manifest: false, + priority: 0, }; messages.push(Message::Request(msg)); } @@ -405,6 +407,8 @@ async fn onmessage( block: Some(request_block), seek: None, upgrade: None, + manifest: false, + priority: 0, })); } channel.send_batch(&messages).await.unwrap(); diff --git a/src/mqueue.rs b/src/mqueue.rs index e5df5b8..0997f36 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -12,11 +12,7 @@ use compact_encoding::CompactEncoding as _; use futures::{Sink, Stream}; use tracing::{error, instrument}; -use crate::{ - message::{decode_framed_channel_messages, ChannelMessage}, - noise::EncryptionInfo, - NoiseEvent, -}; +use crate::{message::ChannelMessage, noise::EncryptionInfo, NoiseEvent}; #[derive(Debug)] pub(crate) enum MqueueEvent { diff --git a/src/schema.rs b/src/schema.rs index bc8f141..b27f9e4 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -169,6 +169,13 @@ pub struct Request { pub seek: Option, /// Request upgrade pub upgrade: Option, + // TODO what is this + /// Request manifest + pub manifest: bool, + // TODO what is this + // this could prob be usize + /// Request priority + pub priority: u64, } macro_rules! maybe_decode { @@ -206,6 +213,8 @@ impl CompactEncoding for Request { flags |= if self.hash.is_some() { 2 } else { 0 }; flags |= if self.seek.is_some() { 4 } else { 0 }; flags |= if self.upgrade.is_some() { 8 } else { 0 }; + flags |= if self.manifest { 16 } else { 0 }; + flags |= if self.priority != 0 { 32 } else { 0 }; let mut rest = write_array(&[flags], buffer)?; rest = map_encode!(rest, self.id, self.fork); @@ -221,6 +230,11 @@ impl CompactEncoding for Request { if let Some(upgrade) = &self.upgrade { rest = upgrade.encode(rest)?; } + + if self.priority != 0 { + rest = self.priority.encode(rest)?; + } + Ok(rest) } @@ -235,6 +249,12 @@ impl CompactEncoding for Request { let (hash, rest) = maybe_decode!(flags & 2 != 0, RequestBlock, rest); let (seek, rest) = maybe_decode!(flags & 4 != 0, RequestSeek, rest); let (upgrade, rest) = maybe_decode!(flags & 8 != 0, RequestUpgrade, rest); + let manifest = flags & 16 != 0; + let (priority, rest) = if flags & 32 != 0 { + u64::decode(rest)? + } else { + (0, rest) + }; Ok(( Request { id, @@ -243,6 +263,8 @@ impl CompactEncoding for Request { hash, seek, upgrade, + manifest, + priority, }, rest, )) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 320436a..ee9d55f 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -605,6 +605,8 @@ async fn on_replication_message( start: info.length, length: peer_state.remote_length - info.length, }), + manifest: false, + priority: 0, }; messages.push(Message::Request(msg)); } @@ -716,6 +718,8 @@ async fn on_replication_message( block: Some(request_block), seek: None, upgrade: None, + manifest: false, + priority: 0, })); } let exit = if synced { From f535d3b4c1107d1527014278f6ad25e19c24bd5a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:54:48 -0400 Subject: [PATCH 113/135] rm unused framing stuff --- src/mqueue.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index 0997f36..87f14d5 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -24,16 +24,13 @@ impl From for MqueueEvent { fn from(e: NoiseEvent) -> Self { match e { NoiseEvent::Meta(einf) => Self::Meta(einf), - NoiseEvent::Decrypted(dec_res) => { - match dec_res { - Ok(encoded) => match decode_framed_channel_messages(&encoded) { - //assert_eq!(_n_read, encoded.len()); } - Ok((messsages, _n_read)) => Self::Message(Ok(messsages)), - Err(e) => Self::Message(Err(e)), - }, - Err(e) => Self::Message(Err(e)), - } - } + NoiseEvent::Decrypted(dec_res) => Self::Message(match dec_res { + Ok(encoded) => match >::decode(&encoded) { + Ok((messages, _rest)) => Ok(messages), // _rest.len() == 0 + Err(e) => Err(e.into()), + }, + Err(e) => Err(e), + }), } } } From b59077eeac30de34befe2ee31d82ac96557711c4 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:58:01 -0400 Subject: [PATCH 114/135] rm test logging --- tests/js_interop.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index ee9d55f..a52900c 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -158,9 +158,7 @@ async fn rcns(server_writer: bool, port: u32) -> Result<()> { }, TEST_SET_SIMPLE ); - dbg!(); let (result_path, writer_path, reader_path) = prepare_test_set(&test_set); - dbg!(); let item_count = 4; let item_size = 4; let data_char = '1'; From f6b37e55ca2ff1f427831c010fa8099213f9b177 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:59:08 -0400 Subject: [PATCH 115/135] cargo fmt --- src/noise.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/noise.rs b/src/noise.rs index 8185162..f51b0b7 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -371,7 +371,7 @@ fn poll_outgoing_encrypted_messages< encrypted_tx: &mut VecDeque>, is_initiator: bool, flush: &mut bool, - step: &Step + step: &Step, ) { // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { From 76b3a27803f63c6cef5647a78b3b61ffcb58b3c8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:59:34 -0400 Subject: [PATCH 116/135] cargo clippy --fix --- src/message.rs | 2 +- tests/js_interop.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/message.rs b/src/message.rs index 637e99e..1f0fdb5 100644 --- a/src/message.rs +++ b/src/message.rs @@ -459,7 +459,7 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - Ok(vec_channel_messages_encoded_size(vec)?) + vec_channel_messages_encoded_size(vec) } #[instrument(skip_all)] diff --git a/tests/js_interop.rs b/tests/js_interop.rs index a52900c..5ae6acb 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,4 +1,4 @@ -use _util::{log, wait_for_localhost_port}; +use _util::wait_for_localhost_port; use anyhow::Result; use futures::Future; use futures_lite::stream::StreamExt; From 2d1ef1b743d94e436e5dca9f40d58faf2ca07bb8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:05:10 -0400 Subject: [PATCH 117/135] RawEncCipher -> EncCipher --- src/crypto/cipher.rs | 6 +++--- src/crypto/mod.rs | 2 +- src/noise.rs | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index aa096c3..94e325e 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -102,17 +102,17 @@ fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { //NB "raw" here means UN-framed. No frame header. const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; -pub(crate) struct RawEncryptCipher { +pub(crate) struct EncryptCipher { push_stream: PushStream, } -impl std::fmt::Debug for RawEncryptCipher { +impl std::fmt::Debug for EncryptCipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "RawEncryptCipher(crypto_secretstream)") } } -impl RawEncryptCipher { +impl EncryptCipher { pub(crate) fn from_handshake_tx( handshake_result: &HandshakeResult, ) -> std::io::Result<(Self, Vec)> { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 9e49c0a..66bb62d 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,5 +1,5 @@ mod cipher; mod curve; mod handshake; -pub(crate) use cipher::{DecryptCipher, RawEncryptCipher}; +pub(crate) use cipher::{DecryptCipher, EncryptCipher}; pub(crate) use handshake::{Handshake, HandshakeResult}; diff --git a/src/noise.rs b/src/noise.rs index f51b0b7..7bac01a 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -10,7 +10,7 @@ use std::{ use tracing::{debug, error, instrument, trace, warn}; use crate::{ - crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}, + crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}, Uint24LELengthPrefixedFraming, }; @@ -27,8 +27,8 @@ pub fn encrypted_framed_message_channel), - SecretStream((RawEncryptCipher, HandshakeResult)), - Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), + SecretStream((EncryptCipher, HandshakeResult)), + Established((EncryptCipher, DecryptCipher, HandshakeResult)), } impl Step { @@ -458,7 +458,7 @@ fn poll_decrypt( #[instrument(skip_all)] fn poll_encrypt( - encryptor: &mut RawEncryptCipher, + encryptor: &mut EncryptCipher, encrypted_tx: &mut VecDeque>, plain_tx: &mut VecDeque>, is_initiator: bool, @@ -568,7 +568,7 @@ fn handle_setup_message( }; // The cipher will be put to use to the writer only after the peer's answer has come let (cipher, init_msg) = - match RawEncryptCipher::from_handshake_tx(handshake_result) { + match EncryptCipher::from_handshake_tx(handshake_result) { Ok(x) => x, Err(e) => { error!("from_handshake_tx error {e:?}"); From fe2eed1f073adf687b948ef86773444d15a334f6 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:10:02 -0400 Subject: [PATCH 118/135] Remove old notes --- src/crypto/cipher.rs | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 94e325e..ea11920 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -136,24 +136,9 @@ impl EncryptCipher { Ok((Self { push_stream }, msg)) } - // Possible API's: - // encrypted message is (tag + encrypted + mac ) - // to have *zero* alocations we could - // * take a buffer of the expected final length, plantext starts at 1 to 1 + planetext.len() - // * final length is 1 + plaintext.len() + mac.len() - // * we write tag to 0 - // * encrypt plain text part in place - // * write mac to end - // - // it would be akward to take an array like this. We could infer the plaintext via the buffer - // it's range would be (1..(buf.len() - mac.len())) - // encypt-in-place the palintext, - // For now... let's just return the encrypted buffer - // + // TODO make this work in-place /// Encrypts `msg` and returns the encrypted bytes pub(crate) fn encrypt(&mut self, msg: &[u8]) -> io::Result> { - // NB: the result is written in place to the provided, however the buffer must be able to - // grow, since the encrypted message is bigger. So here we convert the slice to a vec. let mut out = msg.to_vec(); self.push_stream .push(&mut out, &[], Tag::Message) From f92b5b2a83868551bd8685f3f1d6921cedd511e2 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:30:53 -0400 Subject: [PATCH 119/135] lint --- src/framing.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 12d3c41..02730b1 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -1,3 +1,7 @@ +//! Wrap bytes in length prefixed framing. +use crate::util::{stat_uint24_le, wrap_uint24_le}; +use futures::{Sink, Stream}; +use futures_lite::io::{AsyncRead, AsyncWrite}; use std::{ collections::VecDeque, fmt::Debug, @@ -5,14 +9,8 @@ use std::{ pin::Pin, task::{Context, Poll}, }; - -use futures::{Sink, Stream}; - -use futures_lite::io::{AsyncRead, AsyncWrite}; use tracing::{debug, error, info, instrument, trace, warn}; -use crate::util::{stat_uint24_le, wrap_uint24_le}; - const BUF_SIZE: usize = 1024 * 64; const _HEADER_LEN: usize = 3; From f4fb37180d865aa9b5fe8c8114eac09779183e28 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:31:03 -0400 Subject: [PATCH 120/135] rename tests --- src/message.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/message.rs b/src/message.rs index 1f0fdb5..09bd294 100644 --- a/src/message.rs +++ b/src/message.rs @@ -633,7 +633,7 @@ mod tests { } #[test] - fn extras() -> Result<(), EncodingError> { + fn enc_dec_vec_chan_message() -> Result<(), EncodingError> { let one = Message::Synchronize(Synchronize { fork: 0, length: 4, @@ -653,13 +653,6 @@ mod tests { let (result, rest) = as CompactEncoding>::decode(&buff)?; assert!(rest.is_empty()); assert_eq!(result, msgs); - - // from js interop tests - // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] - // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] - // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] - // [23, 0, 0, 0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] - //assert!(res2.is_ok()); Ok(()) } } From b247a36e0872458ca8c7c33d36fae6f0a24904c7 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:31:14 -0400 Subject: [PATCH 121/135] rm old notse --- src/protocol.rs | 1 - src/schema.rs | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 615cd53..4c10a1c 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -323,7 +323,6 @@ where } /// Poll for outbound messages and write them. - /// Reads messages from Self::outbound and sends them over io #[instrument(skip_all)] fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { diff --git a/src/schema.rs b/src/schema.rs index b27f9e4..49a0ac5 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -51,7 +51,8 @@ impl CompactEncoding for Open { { let ((channel, protocol, discovery_key), rest) = map_decode!(buffer, [u64, String, Vec]); - // TODO this is a CLEAR bug it assumes nothing is encoded after this message + // NB: Open/Close are only sent alone in their own Frame. So we're done when there is no + // more data let (capability, rest) = if !rest.is_empty() { let (_, rest) = take_array::<1>(rest)?; let (capability, rest) = take_array::<32>(rest)?; From cefc744c554f2f8079e7e92514627337457866e3 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:32:40 -0400 Subject: [PATCH 122/135] group imports ran: cargo +nightly fmt with: imports_granularity = "crate" in .rustformat.toml --- benches/pipe.rs | 12 ++++++------ benches/throughput.rs | 17 ++++++++++------- examples/replication.rs | 19 ++++++++----------- src/builder.rs | 3 +-- src/channels.rs | 33 +++++++++++++++++++-------------- src/crypto/cipher.rs | 3 +-- src/crypto/handshake.rs | 6 ++++-- src/duplex.rs | 8 +++++--- src/message.rs | 3 +-- src/protocol.rs | 40 +++++++++++++++++++++++----------------- src/util.rs | 9 +++++---- tests/_util.rs | 9 +++++---- tests/basic.rs | 3 +-- tests/js/mod.rs | 8 +++++--- tests/js_interop.rs | 15 +++++++-------- 15 files changed, 101 insertions(+), 87 deletions(-) diff --git a/benches/pipe.rs b/benches/pipe.rs index b726545..6f2a4b8 100644 --- a/benches/pipe.rs +++ b/benches/pipe.rs @@ -1,14 +1,14 @@ use async_std::task; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use futures::io::{AsyncRead, AsyncWrite}; -use futures::stream::StreamExt; -use hypercore_protocol::{schema::*, Duplex}; -use hypercore_protocol::{Channel, Event, Message, Protocol, ProtocolBuilder}; +use futures::{ + io::{AsyncRead, AsyncWrite}, + stream::StreamExt, +}; +use hypercore_protocol::{schema::*, Channel, Duplex, Event, Message, Protocol, ProtocolBuilder}; use log::*; use pretty_bytes::converter::convert as pretty_bytes; use sluice::pipe::pipe; -use std::io::Result; -use std::time::Instant; +use std::{io::Result, time::Instant}; const COUNT: u64 = 1000; const SIZE: u64 = 100; diff --git a/benches/throughput.rs b/benches/throughput.rs index cc2c278..1d9c4c0 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -1,11 +1,14 @@ -use async_std::net::{Shutdown, TcpListener, TcpStream}; -use async_std::task; +use async_std::{ + net::{Shutdown, TcpListener, TcpStream}, + task, +}; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use futures::future::Either; -use futures::io::{AsyncRead, AsyncWrite}; -use futures::stream::{FuturesUnordered, StreamExt}; -use hypercore_protocol::schema::*; -use hypercore_protocol::{Channel, Event, Message, ProtocolBuilder}; +use futures::{ + future::Either, + io::{AsyncRead, AsyncWrite}, + stream::{FuturesUnordered, StreamExt}, +}; +use hypercore_protocol::{schema::*, Channel, Event, Message, ProtocolBuilder}; use log::*; use std::time::Instant; diff --git a/examples/replication.rs b/examples/replication.rs index ac10df6..85d0d11 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -1,22 +1,19 @@ use anyhow::Result; -use async_std::net::{TcpListener, TcpStream}; -use async_std::prelude::*; -use async_std::sync::{Arc, Mutex}; -use async_std::task; +use async_std::{ + net::{TcpListener, TcpStream}, + prelude::*, + sync::{Arc, Mutex}, + task, +}; use futures_lite::stream::StreamExt; use hypercore::{ Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, VerifyingKey, }; -use std::collections::HashMap; -use std::convert::TryInto; -use std::env; -use std::fmt::Debug; -use std::sync::OnceLock; +use std::{collections::HashMap, convert::TryInto, env, fmt::Debug, sync::OnceLock}; use tracing::{error, info}; -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; +use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; fn main() { log(); diff --git a/src/builder.rs b/src/builder.rs index d797654..0b9127e 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,5 +1,4 @@ -use crate::Protocol; -use crate::{duplex::Duplex, protocol::Options}; +use crate::{duplex::Duplex, protocol::Options, Protocol}; use futures_lite::io::{AsyncRead, AsyncWrite}; /// Build a Protocol instance with options. diff --git a/src/channels.rs b/src/channels.rs index 1b94ece..f16ac7f 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -1,18 +1,23 @@ -use crate::message::ChannelMessage; -use crate::schema::*; -use crate::util::{map_channel_err, pretty_hash}; -use crate::Message; -use crate::{discovery_key, DiscoveryKey, Key}; +use crate::{ + discovery_key, + message::ChannelMessage, + schema::*, + util::{map_channel_err, pretty_hash}, + DiscoveryKey, Key, Message, +}; use async_channel::{Receiver, Sender, TrySendError}; -use futures_lite::ready; -use futures_lite::stream::Stream; -use std::collections::HashMap; -use std::fmt; -use std::io::{Error, ErrorKind, Result}; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::task::Poll; +use futures_lite::{ready, stream::Stream}; +use std::{ + collections::HashMap, + fmt, + io::{Error, ErrorKind, Result}, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::Poll, +}; use tracing::instrument; /// A protocol channel. diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index ea11920..20cb734 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -5,8 +5,7 @@ use blake2::{ }; use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; -use std::convert::TryInto; -use std::io; +use std::{convert::TryInto, io}; const STREAM_ID_LENGTH: usize = 32; const KEY_LENGTH: usize = 32; diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 10c111f..53f3889 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -3,8 +3,10 @@ use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, }; -use snow::resolvers::{DefaultResolver, FallbackResolver}; -use snow::{Builder, Error as SnowError, HandshakeState}; +use snow::{ + resolvers::{DefaultResolver, FallbackResolver}, + Builder, Error as SnowError, HandshakeState, +}; use std::io::{Error, ErrorKind, Result}; use tracing::instrument; diff --git a/src/duplex.rs b/src/duplex.rs index fe79c1b..7b0f1e5 100644 --- a/src/duplex.rs +++ b/src/duplex.rs @@ -1,7 +1,9 @@ use futures_lite::{AsyncRead, AsyncWrite}; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; #[derive(Clone, Debug, PartialEq)] /// Duplex IO stream from reader and writer halves. diff --git a/src/message.rs b/src/message.rs index 09bd294..7665df4 100644 --- a/src/message.rs +++ b/src/message.rs @@ -4,8 +4,7 @@ use compact_encoding::{ VecEncodable, }; use pretty_hash::fmt as pretty_fmt; -use std::fmt; -use std::io; +use std::{fmt, io}; use tracing::{debug, instrument, trace, warn}; const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; diff --git a/src/protocol.rs b/src/protocol.rs index 4c10a1c..e188baf 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,25 +1,31 @@ use async_channel::{Receiver, Sender}; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use futures_lite::stream::Stream; +use futures_lite::{ + io::{AsyncRead, AsyncWrite}, + stream::Stream, +}; use futures_timer::Delay; -use std::collections::VecDeque; -use std::convert::TryInto; -use std::fmt; -use std::io::{self, Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; +use std::{ + collections::VecDeque, + convert::TryInto, + fmt, + io::{self, Error, ErrorKind, Result}, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; use tracing::{debug, error, instrument, warn}; -use crate::channels::{Channel, ChannelMap}; -use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::HandshakeResult; -use crate::message::{ChannelMessage, Message}; -use crate::mqueue::{MessageIo, MqueueEvent}; -use crate::noise::EncryptionInfo; -use crate::util::{map_channel_err, pretty_hash}; use crate::{ - encrypted_framed_message_channel, schema::*, Encrypted, Uint24LELengthPrefixedFraming, + channels::{Channel, ChannelMap}, + constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}, + crypto::HandshakeResult, + encrypted_framed_message_channel, + message::{ChannelMessage, Message}, + mqueue::{MessageIo, MqueueEvent}, + noise::EncryptionInfo, + schema::*, + util::{map_channel_err, pretty_hash}, + Encrypted, Uint24LELengthPrefixedFraming, }; macro_rules! return_error { diff --git a/src/util.rs b/src/util.rs index 7e70336..5f243f2 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,11 +2,12 @@ use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, }; -use std::convert::TryInto; -use std::io::{Error, ErrorKind}; +use std::{ + convert::TryInto, + io::{Error, ErrorKind}, +}; -use crate::constants::DISCOVERY_NS_BUF; -use crate::DiscoveryKey; +use crate::{constants::DISCOVERY_NS_BUF, DiscoveryKey}; /// Calculate the discovery key of a key. /// diff --git a/tests/_util.rs b/tests/_util.rs index d15be38..fc299ca 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -1,11 +1,12 @@ use async_std::net::TcpStream; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use futures_lite::StreamExt; +use futures_lite::{ + io::{AsyncRead, AsyncWrite}, + StreamExt, +}; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; use instant::Duration; use std::io; -use tokio::io::DuplexStream; -use tokio::task::JoinHandle; +use tokio::{io::DuplexStream, task::JoinHandle}; #[allow(unused)] pub(crate) fn log() { diff --git a/tests/basic.rs b/tests/basic.rs index 280e5be..f0d2b77 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -3,8 +3,7 @@ use _util::{ event_discovery_key, next_event, }; use futures_lite::StreamExt; -use hypercore_protocol::{discovery_key, Event, Message}; -use hypercore_protocol::{schema::*, DiscoveryKey}; +use hypercore_protocol::{discovery_key, schema::*, DiscoveryKey, Event, Message}; use std::io; use tokio::task; diff --git a/tests/js/mod.rs b/tests/js/mod.rs index 8894b3d..b8cd6ec 100644 --- a/tests/js/mod.rs +++ b/tests/js/mod.rs @@ -1,8 +1,10 @@ use anyhow::Result; use instant::Duration; -use std::fs::{create_dir_all, remove_dir_all, remove_file}; -use std::path::Path; -use std::process::Command; +use std::{ + fs::{create_dir_all, remove_dir_all, remove_file}, + path::Path, + process::Command, +}; #[cfg(feature = "async-std")] use async_std::{ diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 5ae6acb..3c74112 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -2,16 +2,16 @@ use _util::wait_for_localhost_port; use anyhow::Result; use futures::Future; use futures_lite::stream::StreamExt; -use hypercore::SigningKey; use hypercore::{ - Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, + Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, SigningKey, Storage, VerifyingKey, PUBLIC_KEY_LENGTH, SECRET_KEY_LENGTH, }; use instant::Duration; -use std::fmt::Debug; -use std::path::Path; -use std::sync::Arc; -use std::sync::Once; +use std::{ + fmt::Debug, + path::Path, + sync::{Arc, Once}, +}; #[cfg(feature = "tokio")] use async_compat::CompatExt; @@ -25,8 +25,7 @@ use tokio::{ time::sleep, }; -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; +use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; pub mod _util; mod js; From f841a2bb846540ea6ff28424d748e3631344eae9 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:36:05 -0400 Subject: [PATCH 123/135] format code in docs --- src/lib.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3602517..7857bd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,17 +50,14 @@ //! //! ```no_run //! # async_std::task::block_on(async { -//! use hypercore_protocol::{ProtocolBuilder, Event, Message}; -//! use hypercore_protocol::schema::*; //! use async_std::prelude::*; +//! use hypercore_protocol::{schema::*, Event, Message, ProtocolBuilder}; //! // Start a tcp server. //! let listener = async_std::net::TcpListener::bind("localhost:8000").await.unwrap(); //! async_std::task::spawn(async move { //! let mut incoming = listener.incoming(); //! while let Some(Ok(stream)) = incoming.next().await { -//! async_std::task::spawn(async move { -//! onconnection(stream, false).await -//! }); +//! async_std::task::spawn(async move { onconnection(stream, false).await }); //! } //! }); //! @@ -69,7 +66,7 @@ //! onconnection(stream, true).await; //! //! /// Start Hypercore protocol on a TcpStream. -//! async fn onconnection (stream: async_std::net::TcpStream, is_initiator: bool) { +//! async fn onconnection(stream: async_std::net::TcpStream, is_initiator: bool) { //! // A peer either is the initiator or a connection or is being connected to. //! let name = if is_initiator { "dialer" } else { "listener" }; //! // A key for the channel we want to open. Usually, this is a pre-shared key that both peers @@ -86,7 +83,7 @@ //! // The handshake event is emitted after the protocol is fully established. //! Event::Handshake(_remote_key) => { //! protocol.open(key.clone()).await; -//! }, +//! } //! // A Channel event is emitted for each established channel. //! Event::Channel(mut channel) => { //! // A Channel can be sent to other tasks. @@ -97,7 +94,7 @@ //! eprintln!("{} received message: {:?}", name, message); //! } //! }); -//! }, +//! } //! _ => {} //! } //! } From 0ee4be6464e22a56150f491552188946ec4b3a0e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:38:43 -0400 Subject: [PATCH 124/135] rm unused --- tests/_util.rs | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/tests/_util.rs b/tests/_util.rs index fc299ca..b6f1d22 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -8,31 +8,6 @@ use instant::Duration; use std::io; use tokio::{io::DuplexStream, task::JoinHandle}; -#[allow(unused)] -pub(crate) fn log() { - static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); - START_LOGS.get_or_init(|| { - use tracing_subscriber::{ - layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter, - }; - let env_filter = EnvFilter::from_default_env(); // Reads `RUST_LOG` environment variable - - // Create the hierarchical layer from tracing_tree - let tree_layer = tracing_tree::HierarchicalLayer::new(2) // 2 spaces per indent level - .with_targets(true) - .with_bracketed_fields(true) - .with_indent_lines(true) - .with_span_modes(true) - .with_thread_ids(false) - .with_thread_names(false); - - tracing_subscriber::registry() - .with(env_filter) - .with(tree_layer) - .init(); - }); -} - type TokioDuplex = tokio_util::compat::Compat; pub(crate) fn duplex(channel_size: usize) -> (TokioDuplex, TokioDuplex) { @@ -111,7 +86,6 @@ where }) } -#[allow(unused)] pub async fn wait_for_localhost_port(port: u32) { const RETRY_TIMEOUT: u64 = 100_u64; const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; From 7f38fdab9d1eba4d368230873d7726715b250d83 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:51:55 -0400 Subject: [PATCH 125/135] rm unwraps --- src/framing.rs | 16 ++++++++-------- tests/basic.rs | 1 - tests/js_interop.rs | 1 - 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 02730b1..5bc8297 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -331,7 +331,7 @@ pub(crate) mod test { // NB this sluice pipe // for d in data { - rightlp.feed(d.to_vec()).await.unwrap(); + rightlp.feed(d.to_vec()).await?; } let rflush = spawn(async move { rightlp.flush().await.unwrap(); @@ -340,14 +340,14 @@ pub(crate) mod test { let mut result1 = vec![]; for _ in data { - result1.push(leftlp.next().await.unwrap().unwrap()); + result1.push(leftlp.next().await.unwrap()?); } let mut rightlp = rflush.await?; assert_eq!(result1, data); for d in data { - leftlp.feed(d.to_vec()).await.unwrap(); + leftlp.feed(d.to_vec()).await?; } let lflush = spawn(async move { leftlp.flush().await.unwrap(); @@ -356,7 +356,7 @@ pub(crate) mod test { let mut result2 = vec![]; for _ in data { - result2.push(rightlp.next().await.unwrap().unwrap()); + result2.push(rightlp.next().await.unwrap()?); } let mut leftlp = lflush.await?; assert_eq!(result2, data); @@ -365,13 +365,13 @@ pub(crate) mod test { let mut r4 = vec![]; for d in data { - rightlp.send(d.to_vec()).await.unwrap(); - leftlp.send(d.to_vec()).await.unwrap(); + rightlp.send(d.to_vec()).await?; + leftlp.send(d.to_vec()).await?; } for _ in data { - r3.push(rightlp.next().await.unwrap().unwrap()); - r4.push(leftlp.next().await.unwrap().unwrap()); + r3.push(rightlp.next().await.unwrap()?); + r4.push(leftlp.next().await.unwrap()?); } assert_eq!(r3, data); diff --git a/tests/basic.rs b/tests/basic.rs index f0d2b77..d713937 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -11,7 +11,6 @@ mod _util; #[tokio::test] async fn basic_protocol() -> anyhow::Result<()> { - _util::log(); let (proto_a, proto_b) = create_pair_memory2().await?; let next_a = next_event(proto_a); diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 3c74112..619841b 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -93,7 +93,6 @@ async fn ncrs_client_writer() -> Result<()> { #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn rcrs_server_writer() -> Result<()> { - _util::log(); rcrs(true, 8107).await?; Ok(()) } From 2904f01d90566ee88bc437eaa793a11a7ed0883e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 00:51:39 -0400 Subject: [PATCH 126/135] clean up noise module --- src/noise.rs | 605 ++++++++++++++++++++------------------------------- 1 file changed, 240 insertions(+), 365 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 7bac01a..4115cbe 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -39,16 +39,13 @@ impl Step { impl std::fmt::Display for Step { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Step::NotInitialized => "NotInitialized", - Step::Handshake(_) => "Handshake", - Step::SecretStream(_) => "SecretStream", - Step::Established(_) => "Established", - } - ) + let x = match self { + Step::NotInitialized => "NotInitialized", + Step::Handshake(_) => "Handshake", + Step::SecretStream(_) => "SecretStream", + Step::Established(_) => "Established", + }; + write!(f, "{}", x) } } @@ -111,6 +108,230 @@ where pub fn encryption_established(&self) -> bool { self.step.established() } + + /// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. + #[instrument(skip_all, ret)] + fn did_as_much_as_possible(&mut self, cx: &mut Context<'_>) -> bool { + // No incoming encrypted messages available. + self.poll_incomming_encrypted_messages(cx).is_pending() + // We're unable to send any anymore encrypted/setup messages either because we have none or the `Sink` is unavailable. + && (self.encrypted_tx.is_empty() || Sink::poll_ready(Pin::new(&mut self.io), cx).is_pending()) + // No encrypted messages waiting to be decrypted. + && self.encrypted_rx.is_empty() + // No plaint text messages waiting to be enccrypted or we're still setting up + && (self.plain_tx.is_empty() || !self.step.established()) + } + + /// Handle all message throughput. Sends, encrypts and decrypts messages + /// Returns `true` `step` is already [`Step::Established`]. + #[allow(clippy::too_many_arguments)] + #[instrument(skip_all, ret)] + fn poll_message_throughput(&mut self, cx: &mut Context<'_>) -> bool { + self.poll_outgoing_encrypted_messages(cx); + let _ = self.poll_incomming_encrypted_messages(cx); + if let Step::Established((encryptor, decryptor, ..)) = &mut self.step { + // decrypt incomming msgs + poll_decrypt( + decryptor, + &mut self.encrypted_rx, + &mut self.plain_rx, + self.is_initiator, + ); + // encrypt any pending plaintext outgoinng messages + poll_encrypt( + encryptor, + &mut self.encrypted_tx, + &mut self.plain_tx, + self.is_initiator, + &mut self.flush, + ); + true + } else { + self.poll_setup(); + false + } + } + #[instrument(skip_all, fields(initiator = %self.is_initiator))] + fn poll_setup(&mut self) { + // if we get an error, it could be because the other side reset, and is sending a new + // initialization message. + // If this is the case, we should retry this message after the error. + // But to avoid repeatedly retrying the first message, we should only retry if it is *not* the first msg. + // Still setting up + if let Ok(Some(msg)) = maybe_init(&mut self.step, self.is_initiator) { + // queue the init message to send first + trace!(initiator = %self.is_initiator,"queue initial msg"); + self.encrypted_tx.push_front(msg); + } + // TODO handle error + while let Some(enc_res) = self.encrypted_rx.pop_front() { + match enc_res { + Err(e) => { + error!("Recieved an error during setup encryption setup: {e:?}"); + break; + } + Ok(incoming_msg) => { + trace!(initiator = %self.is_initiator, "encrypted_rx dequeue recieved setup msg"); + if let Ok(msgs) = match self.handle_setup_message(&incoming_msg) { + Ok(x) => Ok(x), + Err(e) => { + error!("handle_setup_message error: {e:?}"); + Err(e) + } + } { + for msg in msgs.into_iter().rev() { + trace!(initiator = %self.is_initiator,"queue more setup msg"); + self.encrypted_tx.push_front(msg); + } + } + } + } + + if self.step.established() { + return; + } + } + } + #[instrument(skip_all, fields(initiator = %self.is_initiator))] + /// Fills `encrypted_rx` and drains `encrypted_tx`. + fn poll_outgoing_encrypted_messages(&mut self, cx: &mut Context<'_>) { + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(&mut self.io), cx) { + if let Some(encrypted_out) = self.encrypted_tx.pop_front() { + trace!(initiator = %self.is_initiator, msg_len = encrypted_out.len(), step = %self.step, "TX message"); + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), encrypted_out) { + error!("Error polling encyrpted side io") + } + + self.flush = true; + } else { + break; + } + } + if self.flush { + match Sink::poll_flush(Pin::new(&mut self.io), cx) { + Poll::Ready(Ok(())) => { + self.flush = false; + trace!(initiator = %self.is_initiator, "all flushed"); + } + Poll::Ready(Err(_e)) => { + error!(initiator = %self.is_initiator, "Error sending encrypted msg") + } + Poll::Pending => { + // flush not complete try again later + self.flush = true; + } + } + } + } + + fn poll_incomming_encrypted_messages(&mut self, cx: &mut Context<'_>) -> Poll<()> { + // pull in any incomming encrypted messages + let mut got_some = false; + while let Poll::Ready(Some(encrypted_msg)) = Stream::poll_next(Pin::new(&mut self.io), cx) { + trace!(initiator = %self.is_initiator, step = %self.step, "RX message"); + self.encrypted_rx.push_back(encrypted_msg); + got_some = true; + } + if got_some { + Poll::Ready(()) + } else { + Poll::Pending + } + } + /// handle setup messages: if any are incorrect (cause an error) the state is reset + #[instrument(err, skip_all, fields(initiator = %self.is_initiator))] + fn handle_setup_message(&mut self, msg: &[u8]) -> Result>> { + // this would only happen after reset with a bad message. + let mut first_message = false; + if let Step::NotInitialized = self.step { + first_message = true; + assert!(!self.is_initiator); + warn!(initiator = %self.is_initiator, "Encrypted state was reset"); + let mut handshake = Handshake::new(self.is_initiator)?; + let _ = handshake.start_raw()?; + self.step = Step::Handshake(Box::new(handshake)); + } + match &self.step { + Step::NotInitialized => { + unreachable!("should not happen") + } + Step::Handshake(_) => { + let mut out = vec![]; + if let Step::Handshake(mut handshake) = + replace(&mut self.step, Step::NotInitialized) + { + trace!("RX handshake msg"); + if let Some(response) = match handshake.read_raw(msg) { + Ok(x) => x, + Err(e) => { + let maybe_init_message = + (!first_message && !self.is_initiator).then_some(msg.to_vec()); + + self.reset_encrypted(maybe_init_message); + return Err(e); + } + } { + trace!( + initiator = %self.is_initiator, + "read message and emitting response", + ); + out.push(response); + } + + if handshake.complete() { + debug!(initiator = %self.is_initiator, "Handshake completed"); + let handshake_result = match handshake.get_result() { + Ok(x) => x, + Err(e) => { + error!("into-result error {e:?}"); + return Err(e); + } + }; + // The cipher will be put to use to the writer only after the peer's answer has come + let (cipher, init_msg) = + match EncryptCipher::from_handshake_tx(handshake_result) { + Ok(x) => x, + Err(e) => { + error!("from_handshake_tx error {e:?}"); + return Err(e); + } + }; + out.push(init_msg); + self.step = Step::SecretStream((cipher, handshake_result.clone())); + debug!(initiator = %self.is_initiator, "Step changed to {}", self.step); + } else { + self.step = Step::Handshake(handshake); + } + } + Ok(out) + } + Step::SecretStream(_) => { + if let Step::SecretStream((enc_cipher, hs_result)) = + replace(&mut self.step, Step::NotInitialized) + { + let dec_cipher = + DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; + self.plain_rx.push_back(Event::from(hs_result.clone())); + self.step = Step::Established((enc_cipher, dec_cipher, hs_result)); + debug!(initiator = %self.is_initiator, "Step changed to {}", self.step); + } + Ok(vec![]) + } + Step::Established((..)) => todo!(), + } + } + #[instrument(skip_all)] + fn reset_encrypted(&mut self, maybe_init_message: Option>) { + error!("Encrypted RESET"); + self.step = Step::NotInitialized; + self.encrypted_tx.clear(); + self.encrypted_rx.clear(); + if let Some(msg) = maybe_init_message { + self.encrypted_rx.push_front(Ok(msg)); + } + self.flush = false; + } } impl< @@ -139,51 +360,21 @@ impl< #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_flush( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { // The flow here can be understood as reading from the encrypted side moving those messages // through to the plaintext side, then reading new plaintext messages and moving them to // the encrypted side. // We do this repeatedly until there's nothing else to do - let Encrypted { - io, - step, - is_initiator, - encrypted_tx, - encrypted_rx, - plain_tx, - plain_rx, - flush, - .. - } = self.get_mut(); - loop { - poll_message_throughput( - io, - cx, - step, - encrypted_tx, - encrypted_rx, - plain_rx, - plain_tx, - *is_initiator, - flush, - ); - poll_outgoing_encrypted_messages(io, cx, encrypted_tx, *is_initiator, flush, step); + self.poll_message_throughput(cx); + self.poll_outgoing_encrypted_messages(cx); // check if we've done all possible work - if did_as_much_as_possible( - io, - cx, - step, - encrypted_tx, - encrypted_rx, - plain_tx, - *is_initiator, - ) { - if !step.established() || !encrypted_tx.is_empty() || *flush { - trace!(not_established = !step.established(), tx_msgs_waiting = !encrypted_tx.is_empty(), flush = ?flush, "not done flushing"); + if self.did_as_much_as_possible(cx) { + if !self.step.established() || !self.encrypted_tx.is_empty() || self.flush { + trace!(not_established = !self.step.established(), tx_msgs_waiting = !self.encrypted_tx.is_empty(), flush = ?self.flush, "not done flushing"); cx.waker().wake_by_ref(); return Poll::Pending; } @@ -201,60 +392,15 @@ impl< } } -/// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. -#[instrument(skip_all, ret)] -fn did_as_much_as_possible< - IO: Stream>> + Sink> + Send + Unpin + 'static, ->( - io: &mut IO, - cx: &mut Context<'_>, - step: &mut Step, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - plain_tx: &mut VecDeque>, - is_initiator: bool, -) -> bool { - // No incoming encrypted messages available. - poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator, step).is_pending() - // We're unable to send any anymore encrypted/setup messages either because we have none or the `Sink` is unavailable. - && (encrypted_tx.is_empty() || Sink::poll_ready(Pin::new(io), cx).is_pending()) - // No encrypted messages waiting to be decrypted. - && encrypted_rx.is_empty() - // No plaint text messages waiting to be enccrypted or we're still setting up - && (plain_tx.is_empty() || !step.established()) -} - impl>> + Sink> + Send + Unpin + 'static> Stream for Encrypted { type Item = Event; #[instrument(skip_all, fields(initiator = %self.is_initiator, ret, err))] - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let Encrypted { - io, - step, - is_initiator, - encrypted_tx, - encrypted_rx, - plain_tx, - plain_rx, - flush, - .. - } = self.get_mut(); - - if poll_message_throughput( - io, - cx, - step, - encrypted_tx, - encrypted_rx, - plain_rx, - plain_tx, - *is_initiator, - flush, - ) { - if let Some(msg) = plain_rx.pop_front() { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.poll_message_throughput(cx) { + if let Some(msg) = self.plain_rx.pop_front() { Poll::Ready(Some(msg)) } else { Poll::Pending @@ -266,166 +412,6 @@ impl>> + Sink> + Send + Unpin + 'static } } -/// Handle all message throughput. Sends, encrypts and decrypts messages -/// Returns `true` `step` is already [`Step::Established`]. -#[allow(clippy::too_many_arguments)] -#[instrument(skip_all, ret)] -fn poll_message_throughput< - IO: Stream>> + Sink> + Send + Unpin + 'static, ->( - io: &mut IO, - cx: &mut Context<'_>, - step: &mut Step, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - plain_rx: &mut VecDeque, - plain_tx: &mut VecDeque>, - is_initiator: bool, - flush: &mut bool, -) -> bool { - poll_outgoing_encrypted_messages(io, cx, encrypted_tx, is_initiator, flush, step); - let _ = poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator, step); - if let Step::Established((encryptor, decryptor, ..)) = step { - // decrypt incomming msgs - poll_decrypt(decryptor, encrypted_rx, plain_rx, is_initiator); - // encrypt any pending plaintext outgoinng messages - poll_encrypt(encryptor, encrypted_tx, plain_tx, is_initiator, flush); - true - } else { - poll_setup( - step, - encrypted_tx, - encrypted_rx, - plain_rx, - is_initiator, - flush, - ); - false - } -} - -#[instrument(skip_all, fields(initiator = %is_initiator))] -fn poll_setup( - step: &mut Step, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - plain_rx: &mut VecDeque, - is_initiator: bool, - flush: &mut bool, -) { - // if we get an error, it could be because the other side reset, and is sending a new - // initialization message. - // If this is the case, we should retry this message after the error. - // But to avoid repeatedly retrying the first message, we should only retry if it is *not* the first msg. - // Still setting up - if let Ok(Some(msg)) = maybe_init(step, is_initiator) { - // queue the init message to send first - trace!(initiator = %is_initiator,"queue initial msg"); - encrypted_tx.push_front(msg); - } - // TODO handle error - while let Some(enc_res) = encrypted_rx.pop_front() { - match enc_res { - Err(e) => { - error!("Recieved an error during setup encryption setup: {e:?}"); - break; - } - Ok(incoming_msg) => { - trace!(initiator = %is_initiator, "encrypted_rx dequeue recieved setup msg"); - if let Ok(msgs) = match handle_setup_message( - step, - &incoming_msg, - is_initiator, - encrypted_tx, - encrypted_rx, - plain_rx, - flush, - ) { - Ok(x) => Ok(x), - Err(e) => { - error!("handle_setup_message error: {e:?}"); - Err(e) - } - } { - for msg in msgs.into_iter().rev() { - trace!(initiator = %is_initiator,"queue more setup msg"); - encrypted_tx.push_front(msg); - } - } - } - } - - if step.established() { - return; - } - } -} - -#[instrument(skip_all, fields(initiator = %is_initiator))] -/// Fills `encrypted_rx` and drains `encrypted_tx`. -fn poll_outgoing_encrypted_messages< - IO: Stream>> + Sink> + Send + Unpin + 'static, ->( - io: &mut IO, - cx: &mut Context<'_>, - encrypted_tx: &mut VecDeque>, - is_initiator: bool, - flush: &mut bool, - step: &Step, -) { - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, msg_len = encrypted_out.len(), step = %step, "TX message"); - if let Err(_e) = Sink::start_send(Pin::new(io), encrypted_out) { - error!("Error polling encyrpted side io") - } - - *flush = true; - } else { - break; - } - } - if *flush { - match Sink::poll_flush(Pin::new(io), cx) { - Poll::Ready(Ok(())) => { - *flush = false; - trace!(initiator = %is_initiator, "all flushed"); - } - Poll::Ready(Err(_e)) => { - error!(initiator = %is_initiator, "Error sending encrypted msg") - } - Poll::Pending => { - // flush not complete try again later - *flush = true; - } - } - } -} - -fn poll_incomming_encrypted_messages< - IO: Stream>> + Sink> + Send + Unpin + 'static, ->( - io: &mut IO, - cx: &mut Context<'_>, - encrypted_rx: &mut VecDeque>>, - is_initiator: bool, - step: &Step, -) -> Poll<()> { - // pull in any incomming encrypted messages - let mut got_some = false; - while let Poll::Ready(Some(encrypted_msg)) = Stream::poll_next(Pin::new(io), cx) { - trace!(initiator = %is_initiator, step = %step, "RX message"); - encrypted_rx.push_back(encrypted_msg); - got_some = true; - } - if got_some { - Poll::Ready(()) - } else { - Poll::Pending - } -} - #[instrument(skip_all)] fn poll_decrypt( decryptor: &mut DecryptCipher, @@ -487,117 +473,6 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } -#[instrument(skip_all)] -fn reset_encrypted( - step: &mut Step, - maybe_init_message: Option>, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - flush: &mut bool, -) { - error!("Encrypted RESET"); - *step = Step::NotInitialized; - encrypted_tx.clear(); - encrypted_rx.clear(); - if let Some(msg) = maybe_init_message { - encrypted_rx.push_front(Ok(msg)); - } - *flush = false; -} - -/// handle setup messages: if any are incorrect (cause an error) the state is reset -#[instrument(err, skip_all, fields(initiator = %is_initiator))] -fn handle_setup_message( - step: &mut Step, - msg: &[u8], - is_initiator: bool, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - plain_rx: &mut VecDeque, - flush: &mut bool, -) -> Result>> { - // this would only happen after reset with a bad message. - let mut first_message = false; - if let Step::NotInitialized = step { - first_message = true; - assert!(!is_initiator); - warn!(initiator = %is_initiator, "Encrypted state was reset"); - let mut handshake = Handshake::new(is_initiator)?; - let _ = handshake.start_raw()?; - *step = Step::Handshake(Box::new(handshake)); - } - match &step { - Step::NotInitialized => { - unreachable!("should not happen") - } - Step::Handshake(_) => { - let mut out = vec![]; - if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { - trace!("RX handshake msg"); - if let Some(response) = match handshake.read_raw(msg) { - Ok(x) => x, - Err(e) => { - let maybe_init_message = - (!first_message && !is_initiator).then_some(msg.to_vec()); - - reset_encrypted( - step, - maybe_init_message, - encrypted_tx, - encrypted_rx, - flush, - ); - return Err(e); - } - } { - trace!( - initiator = %is_initiator, - "read message and emitting response", - ); - out.push(response); - } - - if handshake.complete() { - debug!(initiator = %is_initiator, "Handshake completed"); - let handshake_result = match handshake.get_result() { - Ok(x) => x, - Err(e) => { - error!("into-result error {e:?}"); - return Err(e); - } - }; - // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = - match EncryptCipher::from_handshake_tx(handshake_result) { - Ok(x) => x, - Err(e) => { - error!("from_handshake_tx error {e:?}"); - return Err(e); - } - }; - out.push(init_msg); - *step = Step::SecretStream((cipher, handshake_result.clone())); - debug!(initiator = %is_initiator, "Step changed to {step}"); - } else { - *step = Step::Handshake(handshake); - } - } - Ok(out) - } - Step::SecretStream(_) => { - if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) - { - let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; - plain_rx.push_back(Event::from(hs_result.clone())); - *step = Step::Established((enc_cipher, dec_cipher, hs_result)); - debug!(initiator = %is_initiator, "Step changed to {step}"); - } - Ok(vec![]) - } - Step::Established((..)) => todo!(), - } -} - impl std::fmt::Debug for Encrypted { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Encrypted") From 0b17c7aa6eb538c72ef753023c2e5ece77d00202 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 13:13:20 -0400 Subject: [PATCH 127/135] lint --- src/mqueue.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index 87f14d5..bb92824 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -90,9 +90,7 @@ impl + Sink> + Send + Unpin + 'static> Mes } match Sink::poll_flush(Pin::new(&mut self.io), cx) { - Poll::Ready(Err(_e)) => { - todo!() - } + Poll::Ready(Err(_e)) => todo!(), Poll::Pending => { cx.waker().wake_by_ref(); return Poll::Pending; From 47a690044af077946e3bd2eb7292cea4f5159b06 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 13:35:25 -0400 Subject: [PATCH 128/135] Remove log and env_log depndencies --- Cargo.toml | 2 -- benches/pipe.rs | 6 ++-- benches/throughput.rs | 8 +++-- examples/replication.rs | 65 ++++++++++++++--------------------------- src/test_utils.rs | 6 ++-- tests/js_interop.rs | 2 ++ 6 files changed, 37 insertions(+), 52 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a651734..173c907 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,7 +51,6 @@ path = "../core" async-std = { version = "1.12.0", features = ["attributes", "unstable"] } async-compat = "0.2.1" tokio = { version = "1.27.0", features = ["macros", "net", "process", "rt", "rt-multi-thread", "sync", "time"] } -env_logger = "0.7.1" anyhow = "1.0.28" instant = "0.1" criterion = { version = "0.4", features = ["async_std"] } @@ -59,7 +58,6 @@ pretty-bytes = "0.2.2" duplexify = "1.1.0" sluice = "0.5.4" futures = "0.3.13" -log = "0.4" tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } tracing-tree = "0.4.0" tokio-util = { version = "0.7.14", features = ["compat"] } diff --git a/benches/pipe.rs b/benches/pipe.rs index 6f2a4b8..9f87d84 100644 --- a/benches/pipe.rs +++ b/benches/pipe.rs @@ -1,3 +1,5 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use async_std::task; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use futures::{ @@ -5,17 +7,17 @@ use futures::{ stream::StreamExt, }; use hypercore_protocol::{schema::*, Channel, Duplex, Event, Message, Protocol, ProtocolBuilder}; -use log::*; use pretty_bytes::converter::convert as pretty_bytes; use sluice::pipe::pipe; use std::{io::Result, time::Instant}; +use tracing::{debug, error}; const COUNT: u64 = 1000; const SIZE: u64 = 100; const CONNS: u64 = 10; fn bench_throughput(c: &mut Criterion) { - env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); + test_utils::log(); let mut group = c.benchmark_group("pipe"); group.sample_size(10); group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS)); diff --git a/benches/throughput.rs b/benches/throughput.rs index 1d9c4c0..b19167e 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -1,3 +1,5 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use async_std::{ net::{Shutdown, TcpListener, TcpStream}, task, @@ -9,8 +11,8 @@ use futures::{ stream::{FuturesUnordered, StreamExt}, }; use hypercore_protocol::{schema::*, Channel, Event, Message, ProtocolBuilder}; -use log::*; use std::time::Instant; +use tracing::{debug, info, trace}; const PORT: usize = 11011; const SIZE: u64 = 1000; @@ -18,7 +20,7 @@ const COUNT: u64 = 200; const CLIENTS: usize = 1; fn bench_throughput(c: &mut Criterion) { - env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); + test_utils::log(); let address = format!("localhost:{}", PORT); let mut group = c.benchmark_group("throughput"); @@ -67,7 +69,7 @@ criterion_main!(server_benches); async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { let listener = TcpListener::bind(&address).await.unwrap(); - log::info!("listening on {}", listener.local_addr().unwrap()); + info!("listening on {}", listener.local_addr().unwrap()); let (kill_tx, mut kill_rx) = futures::channel::oneshot::channel(); task::spawn(async move { let mut incoming = listener.incoming(); diff --git a/examples/replication.rs b/examples/replication.rs index 85d0d11..35e2908 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -1,3 +1,5 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use anyhow::Result; use async_std::{ net::{TcpListener, TcpStream}, @@ -10,13 +12,13 @@ use hypercore::{ Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, VerifyingKey, }; -use std::{collections::HashMap, convert::TryInto, env, fmt::Debug, sync::OnceLock}; -use tracing::{error, info}; +use std::{collections::HashMap, convert::TryInto, env, fmt::Debug}; +use tracing::{error, info, instrument}; use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; fn main() { - log(); + test_utils::log(); if env::args().count() < 3 { usage(); } @@ -62,12 +64,11 @@ fn main() { hypercore_store.add(hypercore_wrapper); let hypercore_store = Arc::new(hypercore_store); - let result = match mode.as_ref() { + let _ = match mode.as_ref() { "server" => tcp_server(address, onconnection, hypercore_store).await, "client" => tcp_client(address, onconnection, hypercore_store).await, _ => panic!("{:?}", usage()), }; - log_if_error(&result); }); } @@ -81,6 +82,7 @@ fn usage() { // or once when connected (if client). // Unfortunately, everything that touches the hypercore_store or a hypercore has to be generic // at the moment. +#[instrument(skip_all, ret)] async fn onconnection( stream: TcpStream, is_initiator: bool, @@ -123,17 +125,17 @@ struct HypercoreStore { hypercores: HashMap>, } impl HypercoreStore { - pub fn new() -> Self { + fn new() -> Self { let hypercores = HashMap::new(); Self { hypercores } } - pub fn add(&mut self, hypercore: HypercoreWrapper) { + fn add(&mut self, hypercore: HypercoreWrapper) { let hdkey = hex::encode(hypercore.discovery_key); self.hypercores.insert(hdkey, Arc::new(hypercore)); } - pub fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc> { + fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc> { let hdkey = hex::encode(discovery_key); self.hypercores.get(&hdkey) } @@ -148,7 +150,7 @@ struct HypercoreWrapper { } impl HypercoreWrapper { - pub fn from_memory_hypercore(hypercore: Hypercore) -> Self { + fn from_memory_hypercore(hypercore: Hypercore) -> Self { let key = hypercore.key_pair().public.to_bytes(); HypercoreWrapper { key, @@ -157,11 +159,11 @@ impl HypercoreWrapper { } } - pub fn key(&self) -> &[u8; 32] { + fn key(&self) -> &[u8; 32] { &self.key } - pub fn onpeer(&self, mut channel: Channel) { + fn onpeer(&self, mut channel: Channel) { let mut peer_state = PeerState::default(); let mut hypercore = self.hypercore.clone(); task::spawn(async move { @@ -415,32 +417,9 @@ async fn onmessage( Ok(()) } -#[allow(unused)] -pub fn log() { - use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; - static START_LOGS: OnceLock<()> = OnceLock::new(); - START_LOGS.get_or_init(|| { - tracing_subscriber::fmt() - .with_target(true) - .with_line_number(true) - // print when instrumented funtion enters - .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) - .with_file(true) - .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable - .without_time() - .init(); - }); -} - -/// Log a result if it's an error. -pub fn log_if_error(result: &Result<()>) { - if let Err(err) = result.as_ref() { - log::error!("error: {}", err); - } -} - /// A simple async TCP server that calls an async function for each incoming connection. -pub async fn tcp_server( +#[instrument(skip_all, ret)] +async fn tcp_server( address: String, onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, context: C, @@ -450,22 +429,22 @@ where C: Clone + Send + 'static, { let listener = TcpListener::bind(&address).await?; - log::info!("listening on {}", listener.local_addr()?); + tracing::info!("listening on {}", listener.local_addr()?); let mut incoming = listener.incoming(); while let Some(Ok(stream)) = incoming.next().await { let context = context.clone(); let peer_addr = stream.peer_addr().unwrap(); - log::info!("new connection from {}", peer_addr); + tracing::info!("new connection from {}", peer_addr); task::spawn(async move { - let result = onconnection(stream, false, context).await; - log_if_error(&result); - log::info!("connection closed from {}", peer_addr); + let _ = onconnection(stream, false, context).await; + tracing::info!("connection closed from {}", peer_addr); }); } Ok(()) } /// A simple async TCP client that calls an async function when connected. +#[instrument(skip_all, ret)] pub async fn tcp_client( address: String, onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, @@ -475,8 +454,8 @@ where F: Future> + Send, C: Clone + Send + 'static, { - log::info!("attempting connection to {address}"); + tracing::info!("attempting connection to {address}"); let stream = TcpStream::connect(&address).await?; - log::info!("connected to {address}"); + tracing::info!("connected to {address}"); onconnection(stream, true, context).await } diff --git a/src/test_utils.rs b/src/test_utils.rs index 8a4dd74..b1fc32d 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,13 +1,13 @@ +#![allow(dead_code)] use std::{ io::{self, ErrorKind}, pin::Pin, task::{Context, Poll}, }; -//use async_channel::{unbounded, Receiver, io::Error, Sender}; use futures::{ channel::mpsc::{unbounded, UnboundedReceiver as Receiver, UnboundedSender as Sender}, - Sink, SinkExt, Stream, StreamExt, + Sink, Stream, StreamExt, }; #[derive(Debug)] @@ -99,6 +99,7 @@ pub(crate) fn log() { #[tokio::test] async fn way_one() { + use futures::SinkExt; let mut a = Io::default(); let _ = a.send(b"hello".into()).await; let Some(res) = a.next().await else { panic!() }; @@ -107,6 +108,7 @@ async fn way_one() { #[tokio::test] async fn split() { + use futures::SinkExt; let (mut left, mut right) = (TwoWay::default()).split_sides(); left.send(b"hello".to_vec()).await.unwrap(); diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 619841b..e115288 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,3 +1,5 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use _util::wait_for_localhost_port; use anyhow::Result; use futures::Future; From 13eeee77eb4c4fd52cd437a4fbd81de764a28e66 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 13:43:49 -0400 Subject: [PATCH 129/135] clean up test_utils --- src/noise.rs | 4 +++- src/test_utils.rs | 44 +++++++++++++++++++++++--------------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 4115cbe..e18f697 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -318,7 +318,9 @@ where } Ok(vec![]) } - Step::Established((..)) => todo!(), + Step::Established(_) => { + unreachable!("`handle_setup_message` should never be called when Step::Established") + } } } #[instrument(skip_all)] diff --git a/src/test_utils.rs b/src/test_utils.rs index b1fc32d..d8be13a 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -97,27 +97,6 @@ pub(crate) fn log() { }); } -#[tokio::test] -async fn way_one() { - use futures::SinkExt; - let mut a = Io::default(); - let _ = a.send(b"hello".into()).await; - let Some(res) = a.next().await else { panic!() }; - assert_eq!(res, b"hello"); -} - -#[tokio::test] -async fn split() { - use futures::SinkExt; - let (mut left, mut right) = (TwoWay::default()).split_sides(); - - left.send(b"hello".to_vec()).await.unwrap(); - let Some(res) = right.next().await else { - panic!(); - }; - assert_eq!(res, b"hello"); -} - pub(crate) struct Moo { receiver: Rx, sender: Tx, @@ -199,3 +178,26 @@ pub(crate) fn create_result_connected() -> ( let b = Moo::from(result_channel()); a.connect(b) } + +#[cfg(test)] +mod test_test_utils { + use super::*; + use futures::SinkExt; + #[tokio::test] + async fn way_one() { + let mut a = Io::default(); + let _ = a.send(b"hello".into()).await; + let Some(res) = a.next().await else { panic!() }; + assert_eq!(res, b"hello"); + } + + #[tokio::test] + async fn split() { + let (mut left, mut right) = (TwoWay::default()).split_sides(); + left.send(b"hello".to_vec()).await.unwrap(); + let Some(res) = right.next().await else { + panic!(); + }; + assert_eq!(res, b"hello"); + } +} From a40c4ff78139613a01fde5c996701acda434034a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 13:55:13 -0400 Subject: [PATCH 130/135] rename js_interop_tests to js_tests It was too much to type --- .github/workflows/ci.yml | 6 +++--- Cargo.toml | 4 ++-- README.md | 4 ++-- tests/js_interop.rs | 16 ++++++++-------- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b58d59..3842f94 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,9 +33,9 @@ jobs: cargo check --all-targets cargo check --all-targets --no-default-features --features tokio cargo check --all-targets --no-default-features --features async-std - cargo test --features js_interop_tests - cargo test --no-default-features --features js_interop_tests,tokio - cargo test --no-default-features --features js_interop_tests,async-std + cargo test --features js_tests + cargo test --no-default-features --features js_tests,tokio + cargo test --no-default-features --features js_tests,async-std cargo test --benches build-extra: diff --git a/Cargo.toml b/Cargo.toml index 173c907..df292c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,8 +73,8 @@ tokio = ["hypercore/tokio"] async-std = ["hypercore/async-std"] # Used only in interoperability tests under tests/js-interop which use the javascript version of hypercore # to verify that this crate works. To run them, use: -# cargo test --features js_interop_tests -js_interop_tests = [] +# cargo test --features js_tests +js_tests = [] [profile.bench] # debug = true diff --git a/README.md b/README.md index b8ed180..fada9df 100644 --- a/README.md +++ b/README.md @@ -72,10 +72,10 @@ node examples-nodejs/run.js node ## Development -To test interoperability with Javascript, enable the `js_interop_tests` feature: +To test interoperability with Javascript, enable the `js_tests` feature: ```bash -cargo test --features js_interop_tests +cargo test --features js_tests ``` Run benches with: diff --git a/tests/js_interop.rs b/tests/js_interop.rs index e115288..935382c 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -51,56 +51,56 @@ const TEST_SET_CLIENT_WRITER: &str = "cw"; const TEST_SET_SIMPLE: &str = "simple"; #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn ncns_server_writer() -> Result<()> { ncns(true, 8101).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn ncns_client_writer() -> Result<()> { ncns(false, 8102).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn rcns_server_writer() -> Result<()> { rcns(true, 8103).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn rcns_client_writer() -> Result<()> { rcns(false, 8104).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn ncrs_server_writer() -> Result<()> { ncrs(true, 8105).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn ncrs_client_writer() -> Result<()> { ncrs(false, 8106).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn rcrs_server_writer() -> Result<()> { rcrs(true, 8107).await?; Ok(()) } #[tokio::test] -//#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +//#[cfg_attr(not(feature = "js_tests"), ignore)] //#[ignore] // FIXME this tests hangs sporadically async fn rcrs_client_writer() -> Result<()> { rcrs(false, 8108).await?; From a3f3e6e2cb144014c9ad7d59d8d05bf930f7f9bf Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 14:12:03 -0400 Subject: [PATCH 131/135] refactor tsets --- src/test_utils.rs | 10 +++++----- tests/js_interop.rs | 9 +++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index d8be13a..c2bba98 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -180,12 +180,12 @@ pub(crate) fn create_result_connected() -> ( } #[cfg(test)] -mod test_test_utils { - use super::*; - use futures::SinkExt; +mod test { + #![allow(unused_imports)] // test's within tests confused clippy + use futures::{SinkExt, StreamExt}; #[tokio::test] async fn way_one() { - let mut a = Io::default(); + let mut a = super::Io::default(); let _ = a.send(b"hello".into()).await; let Some(res) = a.next().await else { panic!() }; assert_eq!(res, b"hello"); @@ -193,7 +193,7 @@ mod test_test_utils { #[tokio::test] async fn split() { - let (mut left, mut right) = (TwoWay::default()).split_sides(); + let (mut left, mut right) = (super::TwoWay::default()).split_sides(); left.send(b"hello".to_vec()).await.unwrap(); let Some(res) = right.next().await else { panic!(); diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 935382c..05174f6 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,7 +1,11 @@ +pub mod _util; #[path = "../src/test_utils.rs"] mod test_utils; + use _util::wait_for_localhost_port; use anyhow::Result; +#[cfg(feature = "tokio")] +use async_compat::CompatExt; use futures::Future; use futures_lite::stream::StreamExt; use hypercore::{ @@ -14,9 +18,6 @@ use std::{ path::Path, sync::{Arc, Once}, }; - -#[cfg(feature = "tokio")] -use async_compat::CompatExt; #[cfg(feature = "tokio")] use tokio::{ fs::{metadata, File}, @@ -29,7 +30,6 @@ use tokio::{ use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; -pub mod _util; mod js; use js::{cleanup, install, js_run_client, js_start_server, prepare_test_set}; @@ -40,6 +40,7 @@ fn init() { cleanup(); install(); }); + test_utils::log(); } const TEST_SET_NODE_CLIENT_NODE_SERVER: &str = "ncns"; From 5002b60c8bfe8c3a04445d6db0b435dd30ed3e9c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 16:07:48 -0400 Subject: [PATCH 132/135] RMME reset --- src/mqueue.rs | 13 +++++++++++-- src/noise.rs | 5 ++--- src/test_utils.rs | 5 +++-- tests/js_interop.rs | 14 ++++++++++++++ 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index bb92824..9a2d91a 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -10,7 +10,7 @@ use std::{ use compact_encoding::CompactEncoding as _; use futures::{Sink, Stream}; -use tracing::{error, instrument}; +use tracing::{error, info, instrument}; use crate::{message::ChannelMessage, noise::EncryptionInfo, NoiseEvent}; @@ -119,9 +119,18 @@ impl + Sink> + Send + Unpin + 'static> Str { type Item = MqueueEvent; + #[instrument(skip_all, ret)] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let _ = self.poll_outbound(cx); - self.poll_inbound(cx) + match self.poll_inbound(cx) { + Poll::Ready(Some(MqueueEvent::Message(Ok(x)))) => { + for m in x.iter() { + info!("RX ChannelMessage::{m}"); + } + Poll::Ready(Some(MqueueEvent::Message(Ok(x)))) + } + x => x, + } } } diff --git a/src/noise.rs b/src/noise.rs index e18f697..2c6e001 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -110,7 +110,7 @@ where } /// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. - #[instrument(skip_all, ret)] + #[instrument(name = "did_as_much_as_possible", skip_all, ret)] fn did_as_much_as_possible(&mut self, cx: &mut Context<'_>) -> bool { // No incoming encrypted messages available. self.poll_incomming_encrypted_messages(cx).is_pending() @@ -124,8 +124,7 @@ where /// Handle all message throughput. Sends, encrypts and decrypts messages /// Returns `true` `step` is already [`Step::Established`]. - #[allow(clippy::too_many_arguments)] - #[instrument(skip_all, ret)] + #[instrument(name = "poll_message_throughput", skip_all, ret)] fn poll_message_throughput(&mut self, cx: &mut Context<'_>) -> bool { self.poll_outgoing_encrypted_messages(cx); let _ = self.poll_incomming_encrypted_messages(cx); diff --git a/src/test_utils.rs b/src/test_utils.rs index c2bba98..2e5e994 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -86,9 +86,10 @@ pub(crate) fn log() { .with_targets(true) .with_bracketed_fields(true) .with_indent_lines(true) - .with_span_modes(true) .with_thread_ids(false) - .with_thread_names(false); + .with_thread_names(true) + //.with_span_modes(true) + ; tracing_subscriber::registry() .with(env_filter) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 05174f6..2b2ee38 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -27,6 +27,7 @@ use tokio::{ task, time::sleep, }; +use tracing::instrument; use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; @@ -186,6 +187,7 @@ async fn rcns(server_writer: bool, port: u32) -> Result<()> { &result_path, ) .await?; + dbg!(); assert_result(result_path, item_count, item_size, data_char).await?; drop(server); @@ -330,11 +332,15 @@ async fn run_client( data_path: &str, result_path: &str, ) -> Result<()> { + dbg!(); let hypercore = if is_writer { + dbg!(); create_writer_hypercore(data_count, data_size, data_char, data_path).await? } else { + dbg!(); create_reader_hypercore(data_path).await? }; + dbg!(); let hypercore_wrapper = HypercoreWrapper::from_disk_hypercore( hypercore, if is_writer { @@ -343,7 +349,9 @@ async fn run_client( Some(result_path.to_string()) }, ); + dbg!(); tcp_client(port, on_replication_connection, Arc::new(hypercore_wrapper)).await?; + dbg!(); Ok(()) } @@ -433,21 +441,26 @@ pub fn get_test_key_pair(include_secret: bool) -> PartialKeypair { } #[cfg(feature = "tokio")] +#[instrument(skip_all)] async fn on_replication_connection( stream: TcpStream, is_initiator: bool, hypercore: Arc, ) -> Result<()> { + use tracing::info; + let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream.compat()); while let Some(event) = protocol.next().await { let event = event?; match event { Event::Handshake(_) => { + info!("Event::Handshake"); if is_initiator { protocol.open(*hypercore.key()).await?; } } Event::DiscoveryKey(dkey) => { + info!("Event::DiscoveryKey"); if hypercore.discovery_key == dkey { protocol.open(*hypercore.key()).await?; } else { @@ -455,6 +468,7 @@ async fn on_replication_connection( } } Event::Channel(channel) => { + info!("Event::Channel"); hypercore.on_replication_peer(channel); } Event::Close(_dkey) => { From 083dfa9604dafc3044f1ddc0ba24d043add58b12 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 22 May 2025 19:41:16 -0400 Subject: [PATCH 133/135] docs & logging --- src/framing.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 5bc8297..760b7c6 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -9,7 +9,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{debug, error, info, instrument, trace, warn}; +use tracing::{error, info, instrument, trace, warn}; const BUF_SIZE: usize = 1024 * 64; const _HEADER_LEN: usize = 3; @@ -21,7 +21,7 @@ pub struct Uint24LELengthPrefixedFraming { to_stream: Vec, /// Data from the `Sink` interface to be written out to [`Self::io`]'s [`AsyncWrite`] interface. from_sink: VecDeque>, - /// The index in [`Self::to_stream`] of the last byte that was to the [`Stream`]. + /// The index in [`Self::to_stream`] of the last byte that was sent to the [`Stream`]. last_out_idx: usize, /// The index in [`Self::to_stream`] of the last byte that was read from [`Self::io`]'s /// [`AsyncRead`] @@ -80,7 +80,7 @@ where step, .. } = self.get_mut(); - debug!( + trace!( "Try to AsyncRead up to (buff_size[{}] - last_data_idx[{}]) = [{}]", to_stream.len(), *last_data_idx, @@ -92,7 +92,7 @@ where Poll::Pending => 0, }; // TODO handle if to_stream is full - debug!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); + trace!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); *last_data_idx += n_bytes_read; // grow buffer if it's full if *last_data_idx == to_stream.len() - 1 { @@ -121,7 +121,7 @@ where if let Step::Body { start, end } = step { let end = *end as usize; if end <= *last_data_idx { - debug!(frame_size = end - *start, "Frame ready"); + trace!(frame_size = end - *start, "Frame ready"); let out = to_stream[*start..end].to_vec(); *step = Step::Header; @@ -173,7 +173,7 @@ where from_sink.push_front(msg[n..].to_vec()); warn!("only wrote [{n} / {}] bytes of message", msg.len()); } - debug!("flushed whole message of N=[{n}] bytes"); + trace!("flushed whole message of N=[{n}] bytes"); } Poll::Ready(Err(e)) => { error!("Error flushing data"); @@ -181,7 +181,7 @@ where } } } else { - debug!("No more messages to flush"); + trace!("No more messages to flush"); return Poll::Ready(Ok(())); } } From f1e0eb37dc61ed01c4e11a16e8fa9b81f063a3bf Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 22 May 2025 19:47:27 -0400 Subject: [PATCH 134/135] Allow unused for test utils --- tests/_util.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/_util.rs b/tests/_util.rs index b6f1d22..78c89e4 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -86,6 +86,7 @@ where }) } +#[allow(unused)] pub async fn wait_for_localhost_port(port: u32) { const RETRY_TIMEOUT: u64 = 100_u64; const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; From 8ef2e2f8b716c74baf1f846a20be6558bc5120b1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 22 May 2025 20:05:12 -0400 Subject: [PATCH 135/135] Don't open 2 channels with same peer in tests --- tests/js_interop.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 2b2ee38..d81a812 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -450,25 +450,28 @@ async fn on_replication_connection( use tracing::info; let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream.compat()); + let mut channel_opened = false; while let Some(event) = protocol.next().await { let event = event?; match event { Event::Handshake(_) => { info!("Event::Handshake"); - if is_initiator { + if is_initiator && !channel_opened { protocol.open(*hypercore.key()).await?; + channel_opened = true; } } Event::DiscoveryKey(dkey) => { info!("Event::DiscoveryKey"); - if hypercore.discovery_key == dkey { + if hypercore.discovery_key == dkey && !channel_opened { protocol.open(*hypercore.key()).await?; + channel_opened = true; } else { panic!("Invalid discovery key"); } } Event::Channel(channel) => { - info!("Event::Channel"); + info!("Event::Channel is_initiator = {is_initiator}"); hypercore.on_replication_peer(channel); } Event::Close(_dkey) => {