Skip to content

Commit

Permalink
multistream-select: Less allocations. (#800)
Browse files Browse the repository at this point in the history
  • Loading branch information
twittner authored Jan 9, 2019
1 parent aedf9c0 commit f195925
Show file tree
Hide file tree
Showing 9 changed files with 460 additions and 365 deletions.
1 change: 1 addition & 0 deletions core/src/upgrade/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ type NameWrapIter<I> =
std::iter::Map<I, fn(<I as Iterator>::Item) -> NameWrap<<I as Iterator>::Item>>;

/// Wrapper type to expose an `AsRef<[u8]>` impl for all types implementing `ProtocolName`.
#[derive(Clone)]
struct NameWrap<N>(N);

impl<N: ProtocolName> AsRef<[u8]> for NameWrap<N> {
Expand Down
2 changes: 1 addition & 1 deletion core/src/upgrade/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl<T: AsRef<[u8]>> ProtocolName for T {
/// or both.
pub trait UpgradeInfo {
/// Opaque type representing a negotiable protocol.
type Info: ProtocolName;
type Info: ProtocolName + Clone;
/// Iterator returned by `protocol_info`.
type InfoIter: IntoIterator<Item = Self::Info>;

Expand Down
292 changes: 186 additions & 106 deletions misc/multistream-select/src/dialer_select.rs

Large diffs are not rendered by default.

161 changes: 65 additions & 96 deletions misc/multistream-select/src/length_delimited.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use bytes::Bytes;
use futures::{Async, Poll, Sink, StartSend, Stream};
use smallvec::SmallVec;
use std::{io::{Error as IoError, ErrorKind as IoErrorKind}, marker::PhantomData, u16};
use tokio_codec::FramedWrite;
use std::{io, u16};
use tokio_codec::{Encoder, FramedWrite};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec::UviBytes;
use unsigned_varint::decode;

/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
/// and write unsigned-varint prefixed frames.
///
/// We purposely only support a frame length of under 64kiB. Frames mostly consist
/// in a short protocol name, which is highly unlikely to be more than 64kiB long.
pub struct LengthDelimited<I, S> {
pub struct LengthDelimited<R, C> {
// The inner socket where data is pulled from.
inner: FramedWrite<S, UviBytes>,
inner: FramedWrite<R, C>,
// Intermediary buffer where we put either the length of the next frame of data, or the frame
// of data itself before it is returned.
// Must always contain enough space to read data from `inner`.
internal_buffer: SmallVec<[u8; 64]>,
// Number of bytes within `internal_buffer` that contain valid data.
internal_buffer_pos: usize,
// State of the decoder.
state: State,
marker: PhantomData<I>,
state: State
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
Expand All @@ -52,24 +52,21 @@ enum State {
ReadingData { frame_len: u16 },
}

impl<I, S> LengthDelimited<I, S>
impl<R, C> LengthDelimited<R, C>
where
S: AsyncWrite
R: AsyncWrite,
C: Encoder
{
pub fn new(inner: S) -> LengthDelimited<I, S> {
let mut encoder = UviBytes::default();
encoder.set_max_len(usize::from(u16::MAX));

pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
LengthDelimited {
inner: FramedWrite::new(inner, encoder),
inner: FramedWrite::new(inner, codec),
internal_buffer: {
let mut v = SmallVec::new();
v.push(0);
v
},
internal_buffer_pos: 0,
state: State::ReadingLength,
marker: PhantomData,
state: State::ReadingLength
}
}

Expand All @@ -85,20 +82,19 @@ where
/// the modifiers provided by the `futures` crate) will always leave the object in a state in
/// which `into_inner()` will not panic.
#[inline]
pub fn into_inner(self) -> S {
pub fn into_inner(self) -> R {
assert_eq!(self.state, State::ReadingLength);
assert_eq!(self.internal_buffer_pos, 0);
self.inner.into_inner()
}
}

impl<I, S> Stream for LengthDelimited<I, S>
impl<R, C> Stream for LengthDelimited<R, C>
where
S: AsyncRead,
I: for<'r> From<&'r [u8]>,
R: AsyncRead
{
type Item = I;
type Error = IoError;
type Item = Bytes;
type Error = io::Error;

fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
Expand All @@ -107,23 +103,21 @@ where

match self.state {
State::ReadingLength => {
match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
{
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
match self.inner.get_mut().read(slice) {
Ok(0) => {
// EOF
if self.internal_buffer_pos == 0 {
return Ok(Async::Ready(None));
} else {
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
return Err(io::ErrorKind::UnexpectedEof.into());
}
}
Ok(n) => {
debug_assert_eq!(n, 1);
self.internal_buffer_pos += n;
}
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady);
}
Err(err) => {
Expand All @@ -136,7 +130,10 @@ where
if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 {
// End of length prefix. Most of the time we will switch to reading data,
// but we need to handle a few corner cases first.
let frame_len = decode_length_prefix(&self.internal_buffer);
let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| {
log::debug!("invalid length prefix: {}", e);
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?;

if frame_len >= 1 {
self.state = State::ReadingData { frame_len };
Expand All @@ -154,33 +151,22 @@ where
}
} else if self.internal_buffer_pos >= 2 {
// Length prefix is too long. See module doc for info about max frame len.
return Err(IoError::new(
IoErrorKind::InvalidData,
"frame length too long",
));
return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long"));
} else {
// Prepare for next read.
self.internal_buffer.push(0);
}
}

State::ReadingData { frame_len } => {
match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
{
Ok(0) => {
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
}
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
match self.inner.get_mut().read(slice) {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
Ok(n) => self.internal_buffer_pos += n,
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => {
return Ok(Async::NotReady);
}
Err(err) => {
return Err(err);
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady)
}
Err(err) => return Err(err)
};

if self.internal_buffer_pos >= frame_len as usize {
// Finished reading the frame of data.
self.state = State::ReadingLength;
Expand All @@ -196,12 +182,13 @@ where
}
}

impl<I, S> Sink for LengthDelimited<I, S>
impl<R, C> Sink for LengthDelimited<R, C>
where
S: AsyncWrite
R: AsyncWrite,
C: Encoder
{
type SinkItem = <FramedWrite<S, UviBytes> as Sink>::SinkItem;
type SinkError = <FramedWrite<S, UviBytes> as Sink>::SinkError;
type SinkItem = <FramedWrite<R, C> as Sink>::SinkItem;
type SinkError = <FramedWrite<R, C> as Sink>::SinkError;

#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
Expand All @@ -219,42 +206,25 @@ where
}
}

fn decode_length_prefix(buf: &[u8]) -> u16 {
debug_assert!(buf.len() <= 2);

let mut sum = 0u16;

for &byte in buf.iter().rev() {
let byte = byte & 0x7f;
sum <<= 7;
debug_assert!(sum.checked_add(u16::from(byte)).is_some());
sum += u16::from(byte);
}

sum
}

#[cfg(test)]
mod tests {
use futures::{Future, Stream};
use crate::length_delimited::LengthDelimited;
use std::io::Cursor;
use std::io::ErrorKind;
use std::io::{Cursor, ErrorKind};
use unsigned_varint::codec::UviBytes;

#[test]
fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
}

#[test]
fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
}
Expand All @@ -266,8 +236,7 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter());
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed
.into_future()
.map(|(m, _)| m)
Expand All @@ -281,24 +250,24 @@ mod tests {
fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed
.into_future()
.map(|(m, _)| m)
.map_err(|(err, _)| err)
.wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::InvalidData),
_ => panic!(),

if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::InvalidData)
} else {
panic!()
}
}

#[test]
fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap();
assert_eq!(
recved,
Expand All @@ -315,36 +284,36 @@ mod tests {
#[test]
fn unexpected_eof_in_len() {
let data = vec![0x89];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}

#[test]
fn unexpected_eof_in_data() {
let data = vec![5];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}

#[test]
fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}
}
Expand Down
Loading

0 comments on commit f195925

Please sign in to comment.