Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the possibility to assume that multistream-select will succeed #1121

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 155 additions & 11 deletions misc/multistream-select/src/dialer_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//! Contains the `dialer_select_proto` code, which allows selecting a protocol thanks to
//! `multistream-select` for the dialer.

use futures::{future::Either, prelude::*, stream::StreamFuture};
use futures::{prelude::*, stream::StreamFuture};
use crate::protocol::{
Dialer,
DialerFuture,
Expand All @@ -33,10 +33,6 @@ use std::mem;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{Negotiated, ProtocolChoiceError};

/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer
/// either sequentially of by considering all protocols in parallel.
pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;

/// Helps selecting a protocol amongst the ones supported.
///
/// This function expects a socket and a list of protocols. It uses the `multistream-select`
Expand All @@ -55,11 +51,159 @@ where
I::Item: AsRef<[u8]>
{
let iter = protocols.into_iter();
// We choose between the "serial" and "parallel" strategies based on the number of protocols.
if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
Either::A(dialer_select_proto_serial(inner, iter))

// We choose which negotiation strategy to use based on the number of protocols.
if iter.size_hint().0 == 1 && iter.size_hint().1.map(|n| n == 1).unwrap_or(false) {
let protocol = iter.into_iter().next()
.expect("iterator hint returns a minimum length of 1 but its actual length is 0");
DialerSelectFuture::One(dialer_select_proto_one(inner, protocol))
} else if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
DialerSelectFuture::Serial(dialer_select_proto_serial(inner, iter))
} else {
Either::B(dialer_select_proto_parallel(inner, iter))
DialerSelectFuture::Parallel(dialer_select_proto_parallel(inner, iter))
}
}

/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer
/// either sequentially of by considering all protocols in parallel.
pub enum DialerSelectFuture<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
/// We try a single protocol.
One(DialerSelectOne<R, I::Item>),
/// We try protocols one by one.
Serial(DialerSelectSeq<R, I>),
/// We try protocols in parallel.
Parallel(DialerSelectPar<R, I>),
}

impl<R, I> Future for DialerSelectFuture<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]> + Clone
{
type Item = (I::Item, Negotiated<R>);
type Error = ProtocolChoiceError;

fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self {
DialerSelectFuture::One(fut) => fut.poll(),
DialerSelectFuture::Serial(fut) => fut.poll(),
DialerSelectFuture::Parallel(fut) => fut.poll(),
}
}
}

/// Tries to negotiate a specific protocol on the substream. Assumes that the negotiation will
/// succeed and immediately returns after starting the negotiation. If it turns out that the
/// negotiation did not succeed, the stream will return an error.
pub fn dialer_select_proto_one<R, P>(inner: R, protocol: P) -> DialerSelectOne<R, P>
where
R: AsyncRead + AsyncWrite,
P: AsRef<[u8]>
{
DialerSelectOne {
inner: DialerSelectOneState::AwaitDialer { dialer_fut: Dialer::dial(inner), proto_name: protocol }
}
}

/// Future returned by `dialer_select_proto_one`.
pub struct DialerSelectOne<R, P>
where
R: AsyncRead + AsyncWrite,
P: AsRef<[u8]>
{
inner: DialerSelectOneState<R, P>
}

enum DialerSelectOneState<R, P>
where
R: AsyncRead + AsyncWrite,
P: AsRef<[u8]>
{
AwaitDialer {
dialer_fut: DialerFuture<R, P>,
proto_name: P
},
SendProtocol {
dialer: Dialer<R, P>,
proto_name: P,
},
FlushProtocol {
dialer: Dialer<R, P>,
proto_name: P,
},
Undefined
}

impl<R, P> Future for DialerSelectOne<R, P>
where
R: AsyncRead + AsyncWrite,
P: AsRef<[u8]> + Clone
{
type Item = (P, Negotiated<R>);
type Error = ProtocolChoiceError;

fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.inner, DialerSelectOneState::Undefined) {
DialerSelectOneState::AwaitDialer { mut dialer_fut, proto_name } => {
let dialer = match dialer_fut.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectOneState::AwaitDialer { dialer_fut, proto_name };
return Ok(Async::NotReady)
}
};
self.inner = DialerSelectOneState::SendProtocol {
dialer,
proto_name
}
}
DialerSelectOneState::SendProtocol { mut dialer, proto_name } => {
trace!("sending {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest {
name: proto_name.clone()
};
match dialer.start_send(req)? {
AsyncSink::Ready => {
self.inner = DialerSelectOneState::FlushProtocol {
dialer,
proto_name
}
}
AsyncSink::NotReady(_) => {
self.inner = DialerSelectOneState::SendProtocol {
dialer,
proto_name
};
return Ok(Async::NotReady)
}
}
}
DialerSelectOneState::FlushProtocol { mut dialer, proto_name } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
let stream = Negotiated::negotiating(dialer, proto_name.as_ref().to_vec());
return Ok(Async::Ready((proto_name, stream)));
}
Async::NotReady => {
self.inner = DialerSelectOneState::FlushProtocol {
dialer,
proto_name,
};
return Ok(Async::NotReady)
}
}
}
DialerSelectOneState::Undefined =>
panic!("DialerSelectOneState::poll called after completion")
}
}
}
}

Expand Down Expand Up @@ -207,7 +351,7 @@ where
ListenerToDialerMessage::ProtocolAck { ref name }
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, Negotiated(r.into_inner()))))
return Ok(Async::Ready((proto_name, Negotiated::finished(r.into_inner()))))
}
ListenerToDialerMessage::NotAvailable => {
let proto_name = protocols.next()
Expand Down Expand Up @@ -423,7 +567,7 @@ where
Some(ListenerToDialerMessage::ProtocolAck { ref name })
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, Negotiated(dialer.into_inner()))))
return Ok(Async::Ready((proto_name, Negotiated::finished(dialer.into_inner()))))
}
_ => return Err(ProtocolChoiceError::UnexpectedMessage)
}
Expand Down
50 changes: 33 additions & 17 deletions misc/multistream-select/src/length_delimited.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,20 @@ enum State {
ReadingData { frame_len: u16 },
}

impl<R, C> LengthDelimited<R, C>
where
R: AsyncWrite,
C: Encoder
{
pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
LengthDelimited {
inner: FramedWrite::new(inner, codec),
internal_buffer: {
let mut v = SmallVec::new();
v.push(0);
v
},
internal_buffer_pos: 0,
state: State::ReadingLength
}
impl<R, C> LengthDelimited<R, C> {
/// Grants access to the underlying socket.
///
/// Be extra careful when you use this method in order to not trigger logic errors.
pub fn get_ref(&self) -> &R {
self.inner.get_ref()
}

/// Grants access to the underlying socket.
///
/// This method is only ever intended to be used for writing. Be extra careful when you use it
/// in order to not trigger logic errors.
pub fn get_mut(&mut self) -> &mut R {
self.inner.get_mut()
}

/// Destroys the `LengthDelimited` and returns the underlying socket.
Expand All @@ -81,14 +79,32 @@ where
/// you call `poll()` manually**. Using this struct as it is intended to be used (i.e. through
/// the modifiers provided by the `futures` crate) will always leave the object in a state in
/// which `into_inner()` will not panic.
#[inline]
pub fn into_inner(self) -> R {
assert_eq!(self.state, State::ReadingLength);
assert_eq!(self.internal_buffer_pos, 0);
self.inner.into_inner()
}
}

impl<R, C> LengthDelimited<R, C>
where
R: AsyncWrite,
C: Encoder
{
pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
LengthDelimited {
inner: FramedWrite::new(inner, codec),
internal_buffer: {
let mut v = SmallVec::new();
v.push(0);
v
},
internal_buffer_pos: 0,
state: State::ReadingLength
}
}
}

impl<R, C> Stream for LengthDelimited<R, C>
where
R: AsyncRead
Expand Down
55 changes: 3 additions & 52 deletions misc/multistream-select/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,60 +70,11 @@ mod dialer_select;
mod error;
mod length_delimited;
mod listener_select;
mod tests;

mod negotiated;
mod protocol;

use futures::prelude::*;
use std::io;
mod tests;

pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture};
pub use self::error::ProtocolChoiceError;
pub use self::listener_select::{listener_select_proto, ListenerSelectFuture};

/// A stream after it has been negotiated.
pub struct Negotiated<TInner>(pub(crate) TInner);

impl<TInner> io::Read for Negotiated<TInner>
where
TInner: io::Read
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}

impl<TInner> tokio_io::AsyncRead for Negotiated<TInner>
where
TInner: tokio_io::AsyncRead
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}

fn read_buf<B: bytes::BufMut>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
self.0.read_buf(buf)
}
}

impl<TInner> io::Write for Negotiated<TInner>
where
TInner: io::Write
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}

fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}

impl<TInner> tokio_io::AsyncWrite for Negotiated<TInner>
where
TInner: tokio_io::AsyncWrite
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.0.shutdown()
}
}
pub use self::negotiated::Negotiated;
2 changes: 1 addition & 1 deletion misc/multistream-select/src/listener_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ where
}
};
if let Some(p) = outcome {
return Ok(Async::Ready((p, Negotiated(listener.into_inner()), protocols)))
return Ok(Async::Ready((p, Negotiated::finished(listener.into_inner()), protocols)))
} else {
let stream = listener.into_future();
self.inner = ListenerSelectState::Incoming { stream, protocols }
Expand Down
Loading