Skip to content

Commit

Permalink
Also use it in read_response, add tests for too long msg
Browse files Browse the repository at this point in the history
  • Loading branch information
fl0rek committed Sep 11, 2023
1 parent 5ea68a5 commit 37590d9
Showing 1 changed file with 181 additions and 28 deletions.
209 changes: 181 additions & 28 deletions node/src/exchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ use prost::Message;
use crate::utils::stream_protocol_id;

/// Max request size in bytes
const REQUEST_SIZE_MAXIMUM: u64 = 1024;
const REQUEST_SIZE_MAXIMUM: usize = 1024;
/// Max response size in bytes
const RESPONSE_SIZE_MAXIMUM: u64 = 10 * 1024 * 1024;
const RESPONSE_SIZE_MAXIMUM: usize = 10 * 1024 * 1024;
/// Maximum length of the protobuf length delimiter in bytes
const PROTOBUF_MAX_LENGTH_DELIMITER_LEN: usize = 10;

Expand All @@ -34,15 +34,11 @@ pub fn new_behaviour(network: &str) -> Behaviour {
#[derive(Clone, Copy, Debug, Default)]
pub struct HeaderCodec;

#[async_trait]
impl Codec for HeaderCodec {
type Protocol = StreamProtocol;
type Request = HeaderRequest;
type Response = HeaderResponse;

async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
impl HeaderCodec {
async fn read_raw_message<T, R>(io: &mut T, max_len: usize) -> io::Result<R>
where
T: AsyncRead + Unpin + Send,
R: Message + Default,
{
let mut buf = bytes::BytesMut::with_capacity(512);
buf.resize(PROTOBUF_MAX_LENGTH_DELIMITER_LEN, 0);
Expand All @@ -64,12 +60,7 @@ impl Codec for HeaderCodec {
}
};

// const value conversion, safe for 32bit and larger usize
if len
> REQUEST_SIZE_MAXIMUM
.try_into()
.expect("usize too small to hold REQUEST_SIZE_MAXIMUM value")
{
if len > max_len {
return Err(io::Error::new(
io::ErrorKind::Other,
ReadHeaderError::RequestTooLarge(len),
Expand All @@ -86,7 +77,21 @@ impl Codec for HeaderCodec {
io.read_exact(&mut buf[read..]).await?;
}

Ok(HeaderRequest::decode(&buf[length_delimiter_len..])?)
Ok(R::decode(&buf[length_delimiter_len..])?)
}
}

#[async_trait]
impl Codec for HeaderCodec {
type Protocol = StreamProtocol;
type Request = HeaderRequest;
type Response = HeaderResponse;

async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
where
T: AsyncRead + Unpin + Send,
{
HeaderCodec::read_raw_message(io, REQUEST_SIZE_MAXIMUM).await
}

async fn read_response<T>(
Expand All @@ -97,11 +102,7 @@ impl Codec for HeaderCodec {
where
T: AsyncRead + Unpin + Send,
{
let mut vec = Vec::new();

io.take(RESPONSE_SIZE_MAXIMUM).read_to_end(&mut vec).await?;

Ok(HeaderResponse::decode_length_delimited(&vec[..])?)
HeaderCodec::read_raw_message(io, RESPONSE_SIZE_MAXIMUM).await
}

async fn write_request<T>(
Expand Down Expand Up @@ -137,7 +138,7 @@ impl Codec for HeaderCodec {
}
}

#[derive(Debug, thiserror::Error)]
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum ReadHeaderError {
#[error("stream closed while trying to get header length")]
StreamClosed,
Expand All @@ -147,16 +148,18 @@ pub enum ReadHeaderError {

#[cfg(test)]
mod tests {
use super::HeaderCodec;
use celestia_proto::p2p::pb::{header_request::Data, HeaderRequest};
use super::{HeaderCodec, ReadHeaderError, REQUEST_SIZE_MAXIMUM, RESPONSE_SIZE_MAXIMUM};
use bytes::BytesMut;
use celestia_proto::p2p::pb::{header_request::Data, HeaderRequest, HeaderResponse};
use futures::io::{AsyncRead, AsyncReadExt, Cursor, Error};
use futures::task::{Context, Poll};
use libp2p::{request_response::Codec, swarm::StreamProtocol};
use prost::Message;
use prost::{encode_length_delimiter, Message};
use std::io::ErrorKind;
use std::pin::Pin;

#[tokio::test]
async fn test_decode_empty() {
async fn test_decode_header_request_empty() {
let header_request = HeaderRequest {
amount: 0,
data: None,
Expand All @@ -177,7 +180,85 @@ mod tests {
}

#[tokio::test]
async fn test_decode_data() {
async fn test_decode_header_response_empty() {
let header_response = HeaderResponse {
body: vec![],
status_code: 0,
};

let encoded_header_response = header_response.encode_length_delimited_to_vec();
let mut reader = Cursor::new(encoded_header_response);

let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
let mut codec = HeaderCodec {};

let decoded_header_response = codec
.read_response(&stream_protocol, &mut reader)
.await
.unwrap();

assert_eq!(header_response, decoded_header_response);
}

#[tokio::test]
async fn test_decode_header_request_too_large() {
let too_long_message_len = REQUEST_SIZE_MAXIMUM + 1;
let mut length_delimiter_buffer = BytesMut::new();
prost::encode_length_delimiter(REQUEST_SIZE_MAXIMUM + 1, &mut length_delimiter_buffer)
.unwrap();
let mut reader = Cursor::new(length_delimiter_buffer);

let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
let mut codec = HeaderCodec {};

let decoding_error = codec
.read_request(&stream_protocol, &mut reader)
.await
.expect_err("expected error for too large request");

assert_eq!(decoding_error.kind(), ErrorKind::Other);
let inner_err = decoding_error
.get_ref()
.unwrap()
.downcast_ref::<ReadHeaderError>()
.unwrap();
assert_eq!(
inner_err,
&ReadHeaderError::RequestTooLarge(too_long_message_len)
);
}

#[tokio::test]
async fn test_decode_header_response_too_large() {
let too_long_message_len = RESPONSE_SIZE_MAXIMUM + 1;
let mut length_delimiter_buffer = bytes::BytesMut::new();
encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
dbg!(&length_delimiter_buffer);
let mut reader = Cursor::new(length_delimiter_buffer);

let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
let mut codec = HeaderCodec {};

let decoding_error = codec
.read_response(&stream_protocol, &mut reader)
.await
.expect_err("expected error for too large request");
dbg!(&decoding_error);

assert_eq!(decoding_error.kind(), ErrorKind::Other);
let inner_err = decoding_error
.get_ref()
.unwrap()
.downcast_ref::<ReadHeaderError>()
.unwrap();
assert_eq!(
inner_err,
&ReadHeaderError::RequestTooLarge(too_long_message_len)
);
}

#[tokio::test]
async fn test_decode_header_request_data() {
let data = b"9999888877776666555544443333222211110000";
let header_request = HeaderRequest {
amount: 1,
Expand All @@ -197,7 +278,27 @@ mod tests {
}

#[tokio::test]
async fn test_decode_chunked_data() {
async fn test_decode_header_response_data() {
let data = b"9999888877776666555544443333222211110000";
let header_response = HeaderResponse {
body: data.to_vec(),
status_code: 1,
};
let encoded_header_response = header_response.encode_length_delimited_to_vec();
let mut reader = Cursor::new(encoded_header_response.clone());

let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
let mut codec = HeaderCodec {};

let decoded_header_response = codec
.read_response(&stream_protocol, &mut reader)
.await
.unwrap();
assert_eq!(header_response, decoded_header_response);
}

#[tokio::test]
async fn test_decode_header_request_chunked_data() {
let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
let header_request = HeaderRequest {
amount: 1,
Expand Down Expand Up @@ -248,6 +349,58 @@ mod tests {
}
}

#[tokio::test]
async fn test_decode_header_response_chunked_data() {
let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
let header_response = HeaderResponse {
body: data.to_vec(),
status_code: 2,
};
let encoded_header_response = header_response.encode_length_delimited_to_vec();

let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
let mut codec = HeaderCodec {};
{
let mut reader =
ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_response.clone()));
let decoded_header_response = codec
.read_response(&stream_protocol, &mut reader)
.await
.unwrap();
assert_eq!(header_response, decoded_header_response);
}
{
let mut reader =
ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_response.clone()));
let decoded_header_response = codec
.read_response(&stream_protocol, &mut reader)
.await
.unwrap();

assert_eq!(header_response, decoded_header_response);
}
{
let mut reader =
ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_response.clone()));
let decoded_header_response = codec
.read_response(&stream_protocol, &mut reader)
.await
.unwrap();

assert_eq!(header_response, decoded_header_response);
}
{
let mut reader =
ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_response.clone()));
let decoded_header_response = codec
.read_response(&stream_protocol, &mut reader)
.await
.unwrap();

assert_eq!(header_response, decoded_header_response);
}
}

#[tokio::test]
async fn test_chunky_async_read() {
let read_data = "FOO123";
Expand Down

0 comments on commit 37590d9

Please sign in to comment.