diff --git a/Cargo.lock b/Cargo.lock
index 02bdde00..aeffb6c6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -505,6 +505,7 @@ name = "celestia-node"
 version = "0.1.0"
 dependencies = [
  "async-trait",
+ "bytes",
  "celestia-proto",
  "celestia-rpc",
  "celestia-types",
diff --git a/node/Cargo.toml b/node/Cargo.toml
index 4723a451..b6abc866 100644
--- a/node/Cargo.toml
+++ b/node/Cargo.toml
@@ -31,5 +31,6 @@ getrandom = { version = "0.2.10", features = ["js"] }
 wasm-bindgen-futures = "0.4.37"
 
 [dev-dependencies]
+bytes = "1.4.0"
 celestia-rpc = { workspace = true }
 dotenvy = "0.15.7"
diff --git a/node/src/exchange.rs b/node/src/exchange.rs
index 70342366..e151ecab 100644
--- a/node/src/exchange.rs
+++ b/node/src/exchange.rs
@@ -5,17 +5,20 @@ use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
 use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use libp2p::request_response::{self, Codec, ProtocolSupport};
 use libp2p::StreamProtocol;
+use prost::length_delimiter_len;
 use prost::Message;
 
 use crate::utils::stream_protocol_id;
 
 /// Max request size in bytes
-const REQUEST_SIZE_MAXIMUM: u64 = 1024;
+const REQUEST_SIZE_MAXIMUM: usize = 1024;
 /// Max response size in bytes
-const RESPONSE_SIZE_MAXIMUM: u64 = 10 * 1024 * 1024;
+const RESPONSE_SIZE_MAXIMUM: usize = 10 * 1024 * 1024;
+/// Maximum length of the protobuf length delimiter in bytes
+const PROTOBUF_MAX_LENGTH_DELIMITER_LEN: usize = 10;
 
 pub type Behaviour = request_response::Behaviour<HeaderCodec>;
-pub type Event = request_response::Event<HeaderRequest, HeaderResponse>;
+pub type Event = request_response::Event<HeaderRequest, Vec<HeaderResponse>>;
 
 /// Create a new [`Behaviour`]
 pub fn new_behaviour(network: &str) -> Behaviour {
@@ -31,21 +34,113 @@ pub fn new_behaviour(network: &str) -> Behaviour {
 #[derive(Clone, Copy, Debug, Default)]
 pub struct HeaderCodec;
 
+#[derive(Debug, thiserror::Error, PartialEq)]
+pub enum ReadHeaderError {
+    #[error("stream closed while trying to get header length")]
+    StreamClosed,
+    #[error("varint overflow")]
+    VarintOverflow,
+    #[error("request too large: {0}")]
+    ResponseTooLarge(usize),
+}
+
+impl HeaderCodec {
+    async fn read_message<R, T>(
+        reader: &mut R,
+        buf: &mut Vec<u8>,
+        max_len: usize,
+    ) -> io::Result<Option<T>>
+    where
+        R: AsyncRead + Unpin + Send,
+        T: Message + Default,
+    {
+        let mut read_len = buf.len(); // buf might have data from previous iterations
+
+        if read_len < 512 {
+            // resize to increase the chance of reading all the data in one go
+            buf.resize(512, 0)
+        }
+
+        let data_len = loop {
+            if let Ok(len) = prost::decode_length_delimiter(&buf[..read_len]) {
+                break len;
+            }
+
+            if read_len >= PROTOBUF_MAX_LENGTH_DELIMITER_LEN {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    ReadHeaderError::VarintOverflow,
+                ));
+            }
+
+            match reader.read(&mut buf[read_len..]).await? {
+                0 => {
+                    // check if we're between Messages, in which case it's ok to stop
+                    if read_len == 0 {
+                        return Ok(None);
+                    } else {
+                        return Err(io::Error::new(
+                            io::ErrorKind::UnexpectedEof,
+                            ReadHeaderError::StreamClosed,
+                        ));
+                    }
+                }
+                n => read_len += n,
+            };
+        };
+
+        // truncate buffer to the data that was actually read
+        buf.truncate(read_len);
+
+        let length_delimiter_len = length_delimiter_len(data_len);
+        let single_message_len = length_delimiter_len + data_len;
+
+        if data_len > max_len {
+            return Err(io::Error::new(
+                io::ErrorKind::Other,
+                ReadHeaderError::ResponseTooLarge(data_len),
+            ));
+        }
+
+        if read_len < single_message_len {
+            // we need to read more
+            buf.resize(single_message_len, 0);
+            reader
+                .read_exact(&mut buf[read_len..single_message_len])
+                .await?;
+        }
+
+        let val = T::decode(&buf[length_delimiter_len..single_message_len])?;
+
+        // we've read past one message when trying to get length delimiter, need to handle
+        // partially read data in the buffer
+        if single_message_len < read_len {
+            buf.drain(..single_message_len);
+        } else {
+            buf.clear();
+        }
+
+        Ok(Some(val))
+    }
+}
+
 #[async_trait]
 impl Codec for HeaderCodec {
     type Protocol = StreamProtocol;
     type Request = HeaderRequest;
-    type Response = HeaderResponse;
+    type Response = Vec<HeaderResponse>;
 
     async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
     where
         T: AsyncRead + Unpin + Send,
     {
-        let mut vec = Vec::new();
+        let mut buf = Vec::new();
 
-        io.take(REQUEST_SIZE_MAXIMUM).read_to_end(&mut vec).await?;
-
-        Ok(HeaderRequest::decode_length_delimited(&vec[..])?)
+        HeaderCodec::read_message(io, &mut buf, REQUEST_SIZE_MAXIMUM)
+            .await?
+            .ok_or_else(|| {
+                io::Error::new(io::ErrorKind::UnexpectedEof, ReadHeaderError::StreamClosed)
+            })
     }
 
     async fn read_response(
@@ -56,11 +151,18 @@ impl Codec for HeaderCodec {
     where
         T: AsyncRead + Unpin + Send,
     {
-        let mut vec = Vec::new();
-
-        io.take(RESPONSE_SIZE_MAXIMUM).read_to_end(&mut vec).await?;
-
-        Ok(HeaderResponse::decode_length_delimited(&vec[..])?)
+        let mut messages = vec![];
+        let mut buf = Vec::new();
+        loop {
+            match HeaderCodec::read_message(io, &mut buf, RESPONSE_SIZE_MAXIMUM).await {
+                Ok(None) => break,
+                Ok(Some(msg)) => messages.push(msg),
+                Err(e) => {
+                    return Err(e);
+                }
+            };
+        }
+        Ok(messages)
     }
 
     async fn write_request(
@@ -83,15 +185,338 @@ impl Codec for HeaderCodec {
         &mut self,
         _: &Self::Protocol,
         io: &mut T,
-        resp: Self::Response,
+        resps: Self::Response,
     ) -> io::Result<()>
     where
         T: AsyncWrite + Unpin + Send,
     {
-        let data = resp.encode_length_delimited_to_vec();
+        for resp in resps {
+            let data = resp.encode_length_delimited_to_vec();
 
-        io.write_all(data.as_ref()).await?;
+            io.write_all(&data).await?;
+        }
 
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::{HeaderCodec, ReadHeaderError, REQUEST_SIZE_MAXIMUM, RESPONSE_SIZE_MAXIMUM};
+    use bytes::BytesMut;
+    use celestia_proto::p2p::pb::{header_request::Data, HeaderRequest, HeaderResponse};
+    use futures::io::{AsyncRead, AsyncReadExt, Cursor, Error};
+    use futures::task::{Context, Poll};
+    use libp2p::{request_response::Codec, swarm::StreamProtocol};
+    use prost::{encode_length_delimiter, Message};
+    use std::io::ErrorKind;
+    use std::pin::Pin;
+
+    #[tokio::test]
+    async fn test_decode_header_request_empty() {
+        let header_request = HeaderRequest {
+            amount: 0,
+            data: None,
+        };
+
+        let encoded_header_request = header_request.encode_length_delimited_to_vec();
+
+        let mut reader = Cursor::new(encoded_header_request);
+
+        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
+        let mut codec = HeaderCodec {};
+
+        let decoded_header_request = codec
+            .read_request(&stream_protocol, &mut reader)
+            .await
+            .unwrap();
+
+        assert_eq!(header_request, decoded_header_request);
+    }
+
+    #[tokio::test]
+    async fn test_decode_multiple_small_header_response() {
+        const MSG_COUNT: usize = 10;
+        let header_response = HeaderResponse {
+            body: vec![1, 2, 3],
+            status_code: 1,
+        };
+
+        let encoded_header_response = header_response.encode_length_delimited_to_vec();
+
+        let mut multi_msg = vec![];
+        for _ in 0..MSG_COUNT {
+            multi_msg.extend_from_slice(&encoded_header_response);
+        }
+        let mut reader = Cursor::new(multi_msg);
+
+        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
+        let mut codec = HeaderCodec {};
+
+        let decoded_header_response = codec
+            .read_response(&stream_protocol, &mut reader)
+            .await
+            .unwrap();
+
+        for decoded_header in decoded_header_response.iter() {
+            assert_eq!(&header_response, decoded_header);
+        }
+        assert_eq!(decoded_header_response.len(), MSG_COUNT);
+    }
+
+    #[tokio::test]
+    async fn test_decode_header_request_too_large() {
+        let too_long_message_len = REQUEST_SIZE_MAXIMUM + 1;
+        let mut length_delimiter_buffer = BytesMut::new();
+        prost::encode_length_delimiter(REQUEST_SIZE_MAXIMUM + 1, &mut length_delimiter_buffer)
+            .unwrap();
+        let mut reader = Cursor::new(length_delimiter_buffer);
+
+        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
+        let mut codec = HeaderCodec {};
+
+        let decoding_error = codec
+            .read_request(&stream_protocol, &mut reader)
+            .await
+            .expect_err("expected error for too large request");
+
+        assert_eq!(decoding_error.kind(), ErrorKind::Other);
+        let inner_err = decoding_error
+            .get_ref()
+            .unwrap()
+            .downcast_ref::<ReadHeaderError>()
+            .unwrap();
+        assert_eq!(
+            inner_err,
+            &ReadHeaderError::ResponseTooLarge(too_long_message_len)
+        );
+    }
+
+    #[tokio::test]
+    async fn test_decode_header_response_too_large() {
+        let too_long_message_len = RESPONSE_SIZE_MAXIMUM + 1;
+        let mut length_delimiter_buffer = BytesMut::new();
+        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
+        let mut reader = Cursor::new(length_delimiter_buffer);
+
+        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
+        let mut codec = HeaderCodec {};
+
+        let decoding_error = codec
+            .read_response(&stream_protocol, &mut reader)
+            .await
+            .expect_err("expected error for too large request");
+
+        assert_eq!(decoding_error.kind(), ErrorKind::Other);
+        let inner_err = decoding_error
+            .get_ref()
+            .unwrap()
+            .downcast_ref::<ReadHeaderError>()
+            .unwrap();
+        assert_eq!(
+            inner_err,
+            &ReadHeaderError::ResponseTooLarge(too_long_message_len)
+        );
+    }
+
+    #[tokio::test]
+    async fn test_invalid_varint() {
+        // 10 consecutive bytes with continuation bit set + 1 byte, which is longer than allowed
+        // for length delimiter
+        let varint = [
+            0b1000_0000,
+            0b1000_0000,
+            0b1000_0000,
+            0b1000_0000,
+            0b1000_0000,
+            0b1000_0000,
+            0b1000_0000,
+            0b1000_0000,
+            0b1000_0000,
+            0b1000_0000,
+            0b0000_0001,
+        ];
+        let mut reader = Cursor::new(varint);
+
+        let mut buf = vec![];
+        let decoding_error =
+            HeaderCodec::read_message::<_, HeaderRequest>(&mut reader, &mut buf, 512)
+                .await
+                .expect_err("expected varint overflow");
+
+        assert_eq!(decoding_error.kind(), ErrorKind::InvalidData);
+        let inner_err = decoding_error
+            .get_ref()
+            .unwrap()
+            .downcast_ref::<ReadHeaderError>()
+            .unwrap();
+        assert_eq!(inner_err, &ReadHeaderError::VarintOverflow);
+    }
+
+    #[tokio::test]
+    async fn test_decode_header_double_response_data() {
+        let mut header_response_buffer = BytesMut::with_capacity(512);
+        let header_response0 = HeaderResponse {
+            body: b"9999888877776666555544443333222211110000".to_vec(),
+            status_code: 1,
+        };
+        let header_response1 = HeaderResponse {
+            body: b"0000111122223333444455556666777788889999".to_vec(),
+            status_code: 2,
+        };
+        header_response0
+            .encode_length_delimited(&mut header_response_buffer)
+            .unwrap();
+        header_response1
+            .encode_length_delimited(&mut header_response_buffer)
+            .unwrap();
+        let mut reader = Cursor::new(header_response_buffer);
+
+        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
+        let mut codec = HeaderCodec {};
+
+        let decoded_header_response = codec
+            .read_response(&stream_protocol, &mut reader)
+            .await
+            .unwrap();
+        assert_eq!(header_response0, decoded_header_response[0]);
+        assert_eq!(header_response1, decoded_header_response[1]);
+    }
+
+    #[tokio::test]
+    async fn test_decode_header_request_chunked_data() {
+        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+        let header_request = HeaderRequest {
+            amount: 1,
+            data: Some(Data::Hash(data.to_vec())),
+        };
+        let encoded_header_request = header_request.encode_length_delimited_to_vec();
+
+        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
+        let mut codec = HeaderCodec {};
+        {
+            let mut reader =
+                ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_request.clone()));
+            let decoded_header_request = codec
+                .read_request(&stream_protocol, &mut reader)
+                .await
+                .unwrap();
+            assert_eq!(header_request, decoded_header_request);
+        }
+        {
+            let mut reader =
+                ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_request.clone()));
+            let decoded_header_request = codec
+                .read_request(&stream_protocol, &mut reader)
+                .await
+                .unwrap();
+
+            assert_eq!(header_request, decoded_header_request);
+        }
+        {
+            let mut reader =
+                ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_request.clone()));
+            let decoded_header_request = codec
+                .read_request(&stream_protocol, &mut reader)
+                .await
+                .unwrap();
+
+            assert_eq!(header_request, decoded_header_request);
+        }
+        {
+            let mut reader =
+                ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_request.clone()));
+            let decoded_header_request = codec
+                .read_request(&stream_protocol, &mut reader)
+                .await
+                .unwrap();
+
+            assert_eq!(header_request, decoded_header_request);
+        }
+    }
+
+    #[tokio::test]
+    async fn test_decode_header_response_chunked_data() {
+        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+        let header_response = HeaderResponse {
+            body: data.to_vec(),
+            status_code: 2,
+        };
+        let encoded_header_response = header_response.encode_length_delimited_to_vec();
+
+        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
+        let mut codec = HeaderCodec {};
+        {
+            let mut reader =
+                ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_response.clone()));
+            let decoded_header_response = codec
+                .read_response(&stream_protocol, &mut reader)
+                .await
+                .unwrap();
+            assert_eq!(header_response, decoded_header_response[0]);
+        }
+        {
+            let mut reader =
+                ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_response.clone()));
+            let decoded_header_response = codec
+                .read_response(&stream_protocol, &mut reader)
+                .await
+                .unwrap();
+
+            assert_eq!(header_response, decoded_header_response[0]);
+        }
+        {
+            let mut reader =
+                ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_response.clone()));
+            let decoded_header_response = codec
+                .read_response(&stream_protocol, &mut reader)
+                .await
+                .unwrap();
+
+            assert_eq!(header_response, decoded_header_response[0]);
+        }
+        {
+            let mut reader =
+                ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_response.clone()));
+            let decoded_header_response = codec
+                .read_response(&stream_protocol, &mut reader)
+                .await
+                .unwrap();
+
+            assert_eq!(header_response, decoded_header_response[0]);
+        }
+    }
+
+    #[tokio::test]
+    async fn test_chunky_async_read() {
+        let read_data = "FOO123";
+        let cur0 = Cursor::new(read_data);
+        let mut chunky = ChunkyAsyncRead::<_, 3>::new(cur0);
+
+        let mut output_buffer: BytesMut = b"BAR987".as_ref().into();
+
+        let _ = chunky.read(&mut output_buffer[..]).await.unwrap();
+        assert_eq!(output_buffer, b"FOO987".as_ref());
+    }
+
+    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
+        inner: T,
+    }
+
+    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
+        fn new(inner: T) -> Self {
+            ChunkyAsyncRead { inner }
+        }
+    }
+
+    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
+        fn poll_read(
+            mut self: Pin<&mut Self>,
+            cx: &mut Context<'_>,
+            buf: &mut [u8],
+        ) -> Poll<Result<usize, Error>> {
+            let len = buf.len().min(CHUNK_SIZE);
+            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..len])
+        }
+    }
+}
diff --git a/node/src/p2p.rs b/node/src/p2p.rs
index 7e8cce5c..7f2e1815 100644
--- a/node/src/p2p.rs
+++ b/node/src/p2p.rs
@@ -232,13 +232,15 @@ impl Worker {
                     response,
                 },
             } => {
-                debug!(
-                    "Response for request: {request_id}, from peer: {peer}, status: {:?}",
-                    response.status_code()
-                );
-                let header = ExtendedHeader::decode(&response.body[..]).unwrap();
-                // TODO: Forward response back with one shot channel
-                debug!("Header: {header:?}");
+                for r in response {
+                    debug!(
+                        "Response for request: {request_id}, from peer: {peer}, status: {:?}",
+                        r.status_code()
+                    );
+                    let header = ExtendedHeader::decode(&r.body[..]).unwrap();
+                    // TODO: Forward response back with one shot channel
+                    debug!("Header: {header:?}");
+                }
             }
             _ => debug!("Unhandled header_ex event: {ev:?}"),
         }