Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve performance of Exchange #104

Merged
merged 10 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

276 changes: 125 additions & 151 deletions node/src/exchange.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::sync::Arc;
use std::task::{Context, Poll};
use tracing::warn;

use async_trait::async_trait;
use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
Expand All @@ -15,7 +16,8 @@ use libp2p::{
},
Multiaddr, PeerId, StreamProtocol,
};
use prost::{length_delimiter_len, Message};
use prost::Message;
use tracing::debug;
use tracing::instrument;

mod client;
Expand All @@ -33,8 +35,6 @@ use crate::utils::{protocol_id, OneshotResultSender};
const REQUEST_SIZE_MAXIMUM: usize = 1024;
/// Max response size in bytes
const RESPONSE_SIZE_MAXIMUM: usize = 10 * 1024 * 1024;
/// Maximum length of the protobuf length delimiter in bytes
const PROTOBUF_MAX_LENGTH_DELIMITER_LEN: usize = 10;

type RequestType = HeaderRequest;
type ResponseType = Vec<HeaderResponse>;
Expand Down Expand Up @@ -278,96 +278,6 @@ where
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct HeaderCodec;

/// Failure reasons that `HeaderCodec::read_message` wraps inside an `io::Error`
/// while reading a length-delimited message from a stream.
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum ReadHeaderError {
/// The reader reported EOF after some bytes were consumed but before a
/// complete length delimiter could be decoded.
#[error("stream closed while trying to get header length")]
StreamClosed,
/// No valid length delimiter could be decoded even though at least
/// `PROTOBUF_MAX_LENGTH_DELIMITER_LEN` bytes were available.
#[error("varint overflow")]
VarintOverflow,
/// The decoded payload length (carried in the variant) exceeds the caller's limit.
// NOTE(review): the message says "request" while the variant is named
// `ResponseTooLarge`; the variant is produced for whatever `max_len` the
// caller passes, so the wording may be misleading. Confirm before changing.
#[error("request too large: {0}")]
ResponseTooLarge(usize),
}

impl HeaderCodec {
/// Reads a single length-delimited protobuf message of type `T` from `reader`.
///
/// `buf` is a scratch buffer that is carried between calls: any bytes read
/// past the end of the current message are left at its front so the next
/// call can continue from them.
///
/// Returns `Ok(None)` when the reader reports EOF and no bytes of a new
/// message are pending, and `Ok(Some(msg))` when one message was decoded.
///
/// # Errors
///
/// * `InvalidData` / [`ReadHeaderError::VarintOverflow`] when no length
///   delimiter can be decoded from `PROTOBUF_MAX_LENGTH_DELIMITER_LEN` or more bytes.
/// * `UnexpectedEof` / [`ReadHeaderError::StreamClosed`] when EOF arrives
///   in the middle of the length delimiter.
/// * `Other` / [`ReadHeaderError::ResponseTooLarge`] when the announced
///   payload length is greater than `max_len`.
/// * Any error from the underlying reads or from `T::decode`, propagated via `?`.
async fn read_message<R, T>(
reader: &mut R,
buf: &mut Vec<u8>,
max_len: usize,
) -> io::Result<Option<T>>
where
R: AsyncRead + Unpin + Send,
T: Message + Default,
{
let mut read_len = buf.len(); // buf might have data from previous iterations

if read_len < 512 {
// resize to increase the chance of reading all the data in one go
buf.resize(512, 0)
}

// Keep reading until the bytes gathered so far contain a decodable length delimiter.
let data_len = loop {
if let Ok(len) = prost::decode_length_delimiter(&buf[..read_len]) {
break len;
}

// Decoding failed although enough bytes for the longest allowed delimiter
// are present, so more data cannot help.
if read_len >= PROTOBUF_MAX_LENGTH_DELIMITER_LEN {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
ReadHeaderError::VarintOverflow,
));
}

match reader.read(&mut buf[read_len..]).await? {
// A zero-length read means the reader reached EOF.
0 => {
// check if we're between Messages, in which case it's ok to stop
if read_len == 0 {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
ReadHeaderError::StreamClosed,
));
}
}
n => read_len += n,
};
};

// truncate buffer to the data that was actually read
buf.truncate(read_len);

// Total size of this message on the wire: delimiter bytes plus payload bytes.
let length_delimiter_len = length_delimiter_len(data_len);
let single_message_len = length_delimiter_len + data_len;

// Reject oversized payloads before growing the buffer to hold them.
if data_len > max_len {
return Err(io::Error::new(
io::ErrorKind::Other,
ReadHeaderError::ResponseTooLarge(data_len),
));
}

if read_len < single_message_len {
// we need to read more
buf.resize(single_message_len, 0);
reader
.read_exact(&mut buf[read_len..single_message_len])
.await?;
}

// Decode only the payload, skipping the delimiter bytes at the front.
let val = T::decode(&buf[length_delimiter_len..single_message_len])?;

// we may have read past one message when trying to get the length delimiter, so keep
// the partially read data of the next message in the buffer
if single_message_len < read_len {
buf.drain(..single_message_len);
} else {
buf.clear();
}

Ok(Some(val))
}
}

#[async_trait]
impl Codec for HeaderCodec {
type Protocol = StreamProtocol;
Expand All @@ -378,13 +288,14 @@ impl Codec for HeaderCodec {
where
T: AsyncRead + Unpin + Send,
{
let mut buf = Vec::new();
let data = read_up_to(io, REQUEST_SIZE_MAXIMUM).await?;

if data.len() >= REQUEST_SIZE_MAXIMUM {
debug!("Message filled the whole buffer (len: {})", data.len());
}

HeaderCodec::read_message(io, &mut buf, REQUEST_SIZE_MAXIMUM)
.await?
.ok_or_else(|| {
io::Error::new(io::ErrorKind::UnexpectedEof, ReadHeaderError::StreamClosed)
})
parse_header_request(&data)
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "invalid request"))
}

async fn read_response<T>(
Expand All @@ -395,18 +306,25 @@ impl Codec for HeaderCodec {
where
T: AsyncRead + Unpin + Send,
{
let mut messages = vec![];
let mut buf = Vec::new();
loop {
match HeaderCodec::read_message(io, &mut buf, RESPONSE_SIZE_MAXIMUM).await {
Ok(None) => break,
Ok(Some(msg)) => messages.push(msg),
Err(e) => {
return Err(e);
}
};
let data = read_up_to(io, RESPONSE_SIZE_MAXIMUM).await?;

if data.len() >= RESPONSE_SIZE_MAXIMUM {
debug!("Message filled the whole buffer (len: {})", data.len());
}

let mut data = &data[..];
let mut msgs = Vec::new();

while let Some((header, rest)) = parse_header_response(data) {
msgs.push(header);
data = rest;
}

if msgs.is_empty() {
return Err(io::Error::new(io::ErrorKind::Other, "invalid response"));
}
Ok(messages)

Ok(msgs)
}

async fn write_request<T>(
Expand All @@ -418,9 +336,11 @@ impl Codec for HeaderCodec {
where
T: AsyncWrite + Unpin + Send,
{
let data = req.encode_length_delimited_to_vec();
let mut buf = Vec::with_capacity(REQUEST_SIZE_MAXIMUM);

let _ = req.encode_length_delimited(&mut buf);

io.write_all(data.as_ref()).await?;
io.write_all(&buf).await?;

Ok(())
}
Expand All @@ -434,19 +354,96 @@ impl Codec for HeaderCodec {
where
T: AsyncWrite + Unpin + Send,
{
for resp in resps {
let data = resp.encode_length_delimited_to_vec();
let mut buf = Vec::with_capacity(RESPONSE_SIZE_MAXIMUM);

io.write_all(&data).await?;
for resp in resps {
if resp.encode_length_delimited(&mut buf).is_err() {
// Error on encoding means the buffer is full.
// We will send a partial response back.
debug!("Sending partial response");
break;
}
}

io.write_all(&buf).await?;

Ok(())
}
}

/// Collects bytes from `io` until it reports EOF or `limit` bytes have been
/// gathered, whichever happens first.
///
/// The returned vector holds exactly the bytes that were read, so its length
/// is at most `limit`. Read errors are propagated to the caller.
async fn read_up_to<T>(io: &mut T, limit: usize) -> io::Result<Vec<u8>>
where
    T: AsyncRead + Unpin + Send,
{
    let mut data = vec![0u8; limit];
    let mut filled = 0;

    // Keep reading only while there is free space left in the buffer.
    while filled < data.len() {
        match io.read(&mut data[filled..]).await? {
            // A zero-length read signals EOF.
            0 => break,
            n => filled += n,
        }
    }

    // Drop the unused, still-zeroed tail.
    data.truncate(filled);

    Ok(data)
}

/// Decodes a protobuf length delimiter from the front of `buf`.
///
/// On success returns the decoded length together with the bytes that follow
/// the delimiter. Returns `None` when `buf` is empty or the delimiter cannot
/// be decoded.
fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
    if buf.is_empty() {
        return None;
    }

    // Decoding through `&mut buf` advances the slice past the delimiter bytes,
    // so `buf` afterwards is the remainder of the input.
    match prost::decode_length_delimiter(&mut buf) {
        Ok(len) => Some((len, buf)),
        Err(_) => None,
    }
}

/// Parses one length-delimited `HeaderResponse` from the front of `buf`.
///
/// Returns the decoded message and the bytes that follow it, or `None` when
/// the delimiter is invalid, the payload is shorter than announced, or the
/// payload fails to decode.
fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
    let (len, rest) = parse_delimiter(buf)?;

    // The delimiter announces more bytes than are actually available.
    if len > rest.len() {
        debug!("Message is incomplete: {len}");
        return None;
    }

    let (body, remaining) = rest.split_at(len);
    let msg = HeaderResponse::decode(body).ok()?;

    Some((msg, remaining))
}

/// Parses one length-delimited `HeaderRequest` from the front of `buf`.
///
/// Returns `None` when the delimiter is invalid, the payload is shorter than
/// announced, or the payload fails to decode. Bytes after the message are ignored.
fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
    let (len, rest) = parse_delimiter(buf)?;

    // The delimiter announces more bytes than are actually available.
    if len > rest.len() {
        debug!("Message is incomplete: {len}");
        return None;
    }

    HeaderRequest::decode(&rest[..len]).ok()
}

#[cfg(test)]
mod tests {
use super::{HeaderCodec, ReadHeaderError, REQUEST_SIZE_MAXIMUM, RESPONSE_SIZE_MAXIMUM};
use super::*;
use bytes::BytesMut;
use celestia_proto::p2p::pb::{header_request::Data, HeaderRequest, HeaderResponse};
use futures::io::{AsyncRead, AsyncReadExt, Cursor, Error};
Expand Down Expand Up @@ -512,8 +509,7 @@ mod tests {
async fn test_decode_header_request_too_large() {
let too_long_message_len = REQUEST_SIZE_MAXIMUM + 1;
let mut length_delimiter_buffer = BytesMut::new();
prost::encode_length_delimiter(REQUEST_SIZE_MAXIMUM + 1, &mut length_delimiter_buffer)
.unwrap();
prost::encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
let mut reader = Cursor::new(length_delimiter_buffer);

let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
Expand All @@ -525,15 +521,6 @@ mod tests {
.expect_err("expected error for too large request");

assert_eq!(decoding_error.kind(), ErrorKind::Other);
let inner_err = decoding_error
.get_ref()
.unwrap()
.downcast_ref::<ReadHeaderError>()
.unwrap();
assert_eq!(
inner_err,
&ReadHeaderError::ResponseTooLarge(too_long_message_len)
);
}

#[tokio::test]
Expand All @@ -552,19 +539,10 @@ mod tests {
.expect_err("expected error for too large request");

assert_eq!(decoding_error.kind(), ErrorKind::Other);
let inner_err = decoding_error
.get_ref()
.unwrap()
.downcast_ref::<ReadHeaderError>()
.unwrap();
assert_eq!(
inner_err,
&ReadHeaderError::ResponseTooLarge(too_long_message_len)
);
}

#[tokio::test]
async fn test_invalid_varint() {
#[test]
fn test_invalid_varint() {
// 10 consecutive bytes with continuation bit set + 1 byte, which is longer than allowed
// for length delimiter
let varint = [
Expand All @@ -580,21 +558,17 @@ mod tests {
0b1000_0000,
0b0000_0001,
];
let mut reader = Cursor::new(varint);

let mut buf = vec![];
let decoding_error =
HeaderCodec::read_message::<_, HeaderRequest>(&mut reader, &mut buf, 512)
.await
.expect_err("expected varint overflow");
assert_eq!(parse_delimiter(&varint), None);
}

assert_eq!(decoding_error.kind(), ErrorKind::InvalidData);
let inner_err = decoding_error
.get_ref()
.unwrap()
.downcast_ref::<ReadHeaderError>()
.unwrap();
assert_eq!(inner_err, &ReadHeaderError::VarintOverflow);
#[test]
fn parse_trailing_zero_varint() {
    // Two-byte encoding of 1 (continuation byte followed by a zero byte),
    // then a single trailing payload byte.
    let padded_one = [0b1000_0001, 0b0000_0000, 0b1111_1111];
    let parsed = parse_delimiter(&padded_one);
    assert!(matches!(parsed, Some((1, [255]))));

    // Four-byte encoding of 0 with nothing after it.
    let padded_zero = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
    let parsed = parse_delimiter(&padded_zero);
    assert!(matches!(parsed, Some((0, []))));
}

#[tokio::test]
Expand Down
Loading
Loading