diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index d9a3e7a1fe2..787ec4c4574 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -19,11 +19,12 @@ // DEALINGS IN THE SOFTWARE. use crate::ConnectedPoint; -use crate::upgrade::{UpgradeInfo, InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName}; +use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError}; +use crate::upgrade::{ProtocolName, NegotiatedComplete}; use futures::{future::Either, prelude::*}; use log::debug; use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; -use std::mem; +use std::{iter, mem}; use tokio_io::{AsyncRead, AsyncWrite}; /// Applies an upgrade to the inbound and outbound direction of a connection or substream. @@ -46,10 +47,10 @@ where C: AsyncRead + AsyncWrite, U: InboundUpgrade, { - let iter = UpgradeInfoIterWrap(up); + let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); let future = multistream_select::listener_select_proto(conn, iter); InboundUpgradeApply { - inner: InboundUpgradeApplyState::Init { future } + inner: InboundUpgradeApplyState::Init { future, upgrade: up } } } @@ -78,10 +79,11 @@ where enum InboundUpgradeApplyState where C: AsyncRead + AsyncWrite, - U: InboundUpgrade + U: InboundUpgrade, { Init { - future: ListenerSelectFuture, NameWrap>, + future: ListenerSelectFuture>, + upgrade: U, }, Upgrade { future: U::Future @@ -100,16 +102,16 @@ where fn poll(&mut self) -> Poll { loop { match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) { - InboundUpgradeApplyState::Init { mut future } => { - let (info, connection, upgrade) = match future.poll()? { + InboundUpgradeApplyState::Init { mut future, upgrade } => { + let (info, io) = match future.poll()? { Async::Ready(x) => x, Async::NotReady => { - self.inner = InboundUpgradeApplyState::Init { future }; + self.inner = InboundUpgradeApplyState::Init { future, upgrade }; return Ok(Async::NotReady) } }; self.inner = InboundUpgradeApplyState::Upgrade { - future: upgrade.0.upgrade_inbound(connection, info.0) + future: upgrade.upgrade_inbound(io, info.0) }; } InboundUpgradeApplyState::Upgrade { mut future } => { @@ -153,6 +155,11 @@ where future: DialerSelectFuture::IntoIter>>, upgrade: U }, + AwaitNegotiated { + io: NegotiatedComplete, + upgrade: U, + protocol: U::Info + }, Upgrade { future: U::Future }, @@ -178,8 +185,24 @@ where return Ok(Async::NotReady) } }; + self.inner = OutboundUpgradeApplyState::AwaitNegotiated { + io: connection.complete(), + protocol: info.0, + upgrade + }; + } + OutboundUpgradeApplyState::AwaitNegotiated { mut io, protocol, upgrade } => { + let io = match io.poll()? { + Async::NotReady => { + self.inner = OutboundUpgradeApplyState::AwaitNegotiated { + io, protocol, upgrade + }; + return Ok(Async::NotReady) + } + Async::Ready(io) => io + }; self.inner = OutboundUpgradeApplyState::Upgrade { - future: upgrade.upgrade_outbound(connection, info.0) + future: upgrade.upgrade_outbound(io, protocol) }; } OutboundUpgradeApplyState::Upgrade { mut future } => { @@ -205,23 +228,7 @@ where } } -/// Wraps around a `UpgradeInfo` and satisfies the requirement of `listener_select_proto`. -struct UpgradeInfoIterWrap(U); - -impl<'a, U> IntoIterator for &'a UpgradeInfoIterWrap -where - U: UpgradeInfo -{ - type Item = NameWrap; - type IntoIter = NameWrapIter<::IntoIter>; - - fn into_iter(self) -> Self::IntoIter { - self.0.protocol_info().into_iter().map(NameWrap) - } -} - -type NameWrapIter = - std::iter::Map::Item) -> NameWrap<::Item>>; +type NameWrapIter = iter::Map::Item) -> NameWrap<::Item>>; /// Wrapper type to expose an `AsRef<[u8]>` impl for all types implementing `ProtocolName`. #[derive(Clone)] diff --git a/core/src/upgrade/error.rs b/core/src/upgrade/error.rs index 6dd3082e0cb..de0ecadbd51 100644 --- a/core/src/upgrade/error.rs +++ b/core/src/upgrade/error.rs @@ -18,14 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use multistream_select::ProtocolChoiceError; +use multistream_select::NegotiationError; use std::fmt; /// Error that can happen when upgrading a connection or substream to use a protocol. #[derive(Debug)] pub enum UpgradeError { /// Error during the negotiation process. - Select(ProtocolChoiceError), + Select(NegotiationError), /// Error during the post-negotiation handshake. Apply(E), } @@ -73,8 +73,8 @@ where } } -impl From for UpgradeError { - fn from(e: ProtocolChoiceError) -> Self { +impl From for UpgradeError { + fn from(e: NegotiationError) -> Self { UpgradeError::Select(e) } } diff --git a/core/src/upgrade/mod.rs b/core/src/upgrade/mod.rs index 6a40d211969..7403655f513 100644 --- a/core/src/upgrade/mod.rs +++ b/core/src/upgrade/mod.rs @@ -68,7 +68,7 @@ mod transfer; use futures::future::Future; -pub use multistream_select::Negotiated; +pub use multistream_select::{Negotiated, NegotiatedComplete, NegotiationError, ProtocolError}; pub use self::{ apply::{apply, apply_inbound, apply_outbound, InboundUpgradeApply, OutboundUpgradeApply}, denied::DeniedUpgrade, diff --git a/core/tests/network_dial_error.rs b/core/tests/network_dial_error.rs index 5484168d1d1..e6e3db0ca1c 100644 --- a/core/tests/network_dial_error.rs +++ b/core/tests/network_dial_error.rs @@ -18,6 +18,8 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +mod util; + use futures::{future, prelude::*}; use libp2p_core::identity; use libp2p_core::multiaddr::multiaddr; @@ -167,6 +169,7 @@ fn deny_incoming_connec() { #[test] fn dial_self() { + // Check whether dialing ourselves correctly fails. // // Dialing the same address we're listening should result in three events: @@ -191,7 +194,13 @@ fn dial_self() { .map_outbound(move |muxer| (peer_id, muxer)) .map_inbound(move |muxer| (peer_id2, muxer)); upgrade::apply(out.stream, upgrade, endpoint) + }) + .and_then(|(peer, mplex), _| { + // Gracefully close the connection to allow protocol + // negotiation to complete. + util::CloseMuxer::new(mplex).map(move |mplex| (peer, mplex)) }); + Network::new(transport, local_public_key.into()) }; @@ -243,7 +252,9 @@ fn dial_self() { assert_eq!(*inc.listen_addr(), address); inc.accept(TestHandler::default().into_node_handler_builder()); }, - Async::Ready(ev) => unreachable!("{:?}", ev), + Async::Ready(ev) => { + panic!("Unexpected event: {:?}", ev) + } Async::NotReady => break Ok(Async::NotReady), } } diff --git a/core/tests/network_simult.rs b/core/tests/network_simult.rs index cc9ebdfec80..0b5c23839d6 100644 --- a/core/tests/network_simult.rs +++ b/core/tests/network_simult.rs @@ -18,9 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +mod util; + use futures::{future, prelude::*}; use libp2p_core::identity; -use libp2p_core::nodes::network::{Network, NetworkEvent, IncomingError}; +use libp2p_core::nodes::{Network, NetworkEvent, Peer}; +use libp2p_core::nodes::network::IncomingError; use libp2p_core::{Transport, upgrade, upgrade::OutboundUpgradeExt, upgrade::InboundUpgradeExt}; use libp2p_swarm::{ ProtocolsHandler, @@ -118,6 +121,11 @@ fn raw_swarm_simultaneous_connect() { .map_outbound(move |muxer| (peer_id, muxer)) .map_inbound(move |muxer| (peer_id2, muxer)); upgrade::apply(out.stream, upgrade, endpoint) + }) + .and_then(|(peer, mplex), _| { + // Gracefully close the connection to allow protocol + // negotiation to complete. + util::CloseMuxer::new(mplex).map(move |mplex| (peer, mplex)) }); Network::new(transport, local_public_key.into_peer_id()) }; @@ -134,6 +142,11 @@ fn raw_swarm_simultaneous_connect() { .map_outbound(move |muxer| (peer_id, muxer)) .map_inbound(move |muxer| (peer_id2, muxer)); upgrade::apply(out.stream, upgrade, endpoint) + }) + .and_then(|(peer, mplex), _| { + // Gracefully close the connection to allow protocol + // negotiation to complete. + util::CloseMuxer::new(mplex).map(move |mplex| (peer, mplex)) }); Network::new(transport, local_public_key.into_peer_id()) }; @@ -164,14 +177,14 @@ fn raw_swarm_simultaneous_connect() { let mut reactor = tokio::runtime::current_thread::Runtime::new().unwrap(); - for _ in 0 .. 10 { + loop { let mut swarm1_step = 0; let mut swarm2_step = 0; let mut swarm1_dial_start = Delay::new(Instant::now() + Duration::new(0, rand::random::() % 50_000_000)); let mut swarm2_dial_start = Delay::new(Instant::now() + Duration::new(0, rand::random::() % 50_000_000)); - let future = future::poll_fn(|| -> Poll<(), io::Error> { + let future = future::poll_fn(|| -> Poll { loop { let mut swarm1_not_ready = false; let mut swarm2_not_ready = false; @@ -183,10 +196,11 @@ fn raw_swarm_simultaneous_connect() { match swarm1_dial_start.poll().unwrap() { Async::Ready(_) => { let handler = TestHandler::default().into_node_handler_builder(); - swarm1.peer(swarm2.local_peer_id().clone()).into_not_connected().unwrap() + swarm1.peer(swarm2.local_peer_id().clone()) + .into_not_connected() + .unwrap() .connect(swarm2_listen_addr.clone(), handler); swarm1_step = 1; - swarm1_not_ready = false; }, Async::NotReady => swarm1_not_ready = true, } @@ -196,10 +210,11 @@ fn raw_swarm_simultaneous_connect() { match swarm2_dial_start.poll().unwrap() { Async::Ready(_) => { let handler = TestHandler::default().into_node_handler_builder(); - swarm2.peer(swarm1.local_peer_id().clone()).into_not_connected().unwrap() + swarm2.peer(swarm1.local_peer_id().clone()) + .into_not_connected() + .unwrap() .connect(swarm1_listen_addr.clone(), handler); swarm2_step = 1; - swarm2_not_ready = false; }, Async::NotReady => swarm2_not_ready = true, } @@ -207,12 +222,19 @@ fn raw_swarm_simultaneous_connect() { if rand::random::() < 0.1 { match swarm1.poll() { - Async::Ready(NetworkEvent::IncomingConnectionError { error: IncomingError::DeniedLowerPriority, .. }) => { + Async::Ready(NetworkEvent::IncomingConnectionError { + error: IncomingError::DeniedLowerPriority, .. + }) => { assert_eq!(swarm1_step, 2); swarm1_step = 3; }, Async::Ready(NetworkEvent::Connected { conn_info, .. }) => { assert_eq!(conn_info, *swarm2.local_peer_id()); + if swarm1_step == 0 { + // The connection was established before + // swarm1 started dialing; discard the test run. + return Ok(Async::Ready(false)) + } assert_eq!(swarm1_step, 1); swarm1_step = 2; }, @@ -224,19 +246,26 @@ fn raw_swarm_simultaneous_connect() { Async::Ready(NetworkEvent::IncomingConnection(inc)) => { inc.accept(TestHandler::default().into_node_handler_builder()); }, - Async::Ready(_) => unreachable!(), + Async::Ready(ev) => panic!("swarm1: unexpected event: {:?}", ev), Async::NotReady => swarm1_not_ready = true, } } if rand::random::() < 0.1 { match swarm2.poll() { - Async::Ready(NetworkEvent::IncomingConnectionError { error: IncomingError::DeniedLowerPriority, .. }) => { + Async::Ready(NetworkEvent::IncomingConnectionError { + error: IncomingError::DeniedLowerPriority, .. + }) => { assert_eq!(swarm2_step, 2); swarm2_step = 3; }, Async::Ready(NetworkEvent::Connected { conn_info, .. }) => { assert_eq!(conn_info, *swarm1.local_peer_id()); + if swarm2_step == 0 { + // The connection was established before + // swarm2 started dialing; discard the test run. + return Ok(Async::Ready(false)) + } assert_eq!(swarm2_step, 1); swarm2_step = 2; }, @@ -248,14 +277,14 @@ fn raw_swarm_simultaneous_connect() { Async::Ready(NetworkEvent::IncomingConnection(inc)) => { inc.accept(TestHandler::default().into_node_handler_builder()); }, - Async::Ready(_) => unreachable!(), + Async::Ready(ev) => panic!("swarm2: unexpected event: {:?}", ev), Async::NotReady => swarm2_not_ready = true, } } // TODO: make sure that >= 5 is correct if swarm1_step + swarm2_step >= 5 { - return Ok(Async::Ready(())); + return Ok(Async::Ready(true)); } if swarm1_not_ready && swarm2_not_ready { @@ -264,11 +293,23 @@ fn raw_swarm_simultaneous_connect() { } }); - reactor.block_on(future).unwrap(); - - // We now disconnect them again. - swarm1.peer(swarm2.local_peer_id().clone()).into_connected().unwrap().close(); - swarm2.peer(swarm1.local_peer_id().clone()).into_connected().unwrap().close(); + if reactor.block_on(future).unwrap() { + // The test exercised what we wanted to exercise: a simultaneous connect. + break + } else { + // The test did not trigger a simultaneous connect; ensure the nodes + // are disconnected and re-run the test. + match swarm1.peer(swarm2.local_peer_id().clone()) { + Peer::Connected(p) => p.close(), + Peer::PendingConnect(p) => p.interrupt(), + x => panic!("Unexpected state for swarm1: {:?}", x) + } + match swarm2.peer(swarm1.local_peer_id().clone()) { + Peer::Connected(p) => p.close(), + Peer::PendingConnect(p) => p.interrupt(), + x => panic!("Unexpected state for swarm2: {:?}", x) + } + } } } } diff --git a/core/tests/util.rs b/core/tests/util.rs new file mode 100644 index 00000000000..b43442822cb --- /dev/null +++ b/core/tests/util.rs @@ -0,0 +1,47 @@ + +#![allow(dead_code)] + +use futures::prelude::*; +use libp2p_core::muxing::StreamMuxer; + +pub struct CloseMuxer { + state: CloseMuxerState, +} + +impl CloseMuxer { + pub fn new(m: M) -> CloseMuxer { + CloseMuxer { + state: CloseMuxerState::Close(m) + } + } +} + +pub enum CloseMuxerState { + Close(M), + Done, +} + +impl Future for CloseMuxer +where + M: StreamMuxer, + M::Error: From +{ + type Item = M; + type Error = M::Error; + + fn poll(&mut self) -> Poll { + loop { + match std::mem::replace(&mut self.state, CloseMuxerState::Done) { + CloseMuxerState::Close(muxer) => { + if muxer.close()?.is_not_ready() { + self.state = CloseMuxerState::Close(muxer); + return Ok(Async::NotReady) + } + return Ok(Async::Ready(muxer)) + } + CloseMuxerState::Done => panic!() + } + } + } +} + diff --git a/misc/multistream-select/Cargo.toml b/misc/multistream-select/Cargo.toml index 83de968e89d..b7a9ab0cafe 100644 --- a/misc/multistream-select/Cargo.toml +++ b/misc/multistream-select/Cargo.toml @@ -20,3 +20,5 @@ unsigned-varint = { version = "0.2.2" } [dev-dependencies] tokio = "0.1" tokio-tcp = "0.1" +quickcheck = "0.8" +rand = "0.6" diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index 59bda8ad396..dc39f753230 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -18,31 +18,41 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -//! Contains the `dialer_select_proto` code, which allows selecting a protocol thanks to -//! `multistream-select` for the dialer. +//! Protocol negotiation strategies for the peer acting as the dialer. -use futures::{future::Either, prelude::*, stream::StreamFuture}; -use crate::protocol::{Dialer, DialerFuture, Request, Response}; -use log::trace; -use std::mem; +use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version}; +use futures::{future::Either, prelude::*}; +use log::debug; +use std::{io, iter, mem, convert::TryFrom}; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::{Negotiated, ProtocolChoiceError}; +use crate::{Negotiated, NegotiationError}; -/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer -/// either sequentially or by considering all protocols in parallel. -pub type DialerSelectFuture = Either, DialerSelectPar>; - -/// Helps selecting a protocol amongst the ones supported. +/// Returns a `Future` that negotiates a protocol on the given I/O stream +/// for a peer acting as the _dialer_ (or _initiator_). /// -/// This function expects a socket and a list of protocols. It uses the `multistream-select` -/// protocol to choose with the remote a protocol amongst the ones produced by the iterator. +/// This function is given an I/O stream and a list of protocols and returns a +/// computation that performs the protocol negotiation with the remote. The +/// returned `Future` resolves with the name of the negotiated protocol and +/// a [`Negotiated`] I/O stream. /// -/// The iterator must produce a tuple of a protocol name advertised to the remote, a function that -/// checks whether a protocol name matches the protocol, and a protocol "identifier" of type `P` -/// (you decide what `P` is). The parameters of the match function are the name proposed by the -/// remote, and the protocol name that we passed (so that you don't have to clone the name). On -/// success, the function returns the identifier (of type `P`), plus the socket which now uses that -/// chosen protocol. +/// The chosen message flow for protocol negotiation depends on the numbers +/// of supported protocols given. That is, this function delegates to +/// [`dialer_select_proto_serial`] or [`dialer_select_proto_parallel`] +/// based on the number of protocols given. The number of protocols is +/// determined through the `size_hint` of the given iterator and thus +/// an inaccurate size estimate may result in a suboptimal choice. +/// +/// > **Note**: When multiple `DialerSelectFuture`s are composed, i.e. a +/// > dialer performs multiple, nested protocol negotiations with just a +/// > single supported protocol (0-RTT negotiations), a listener that +/// > does not support one of the intermediate protocols may still process +/// > the request data associated with a supported follow-up protocol. +/// > See \[[1]\]. To avoid this behaviour, a dialer should ensure completion +/// > of the previous negotiation before starting the next negotiation, +/// > which can be accomplished by waiting for the future returned by +/// > [`Negotiated::complete`] to resolve. +/// +/// [1]: https://github.com/multiformats/go-multistream/issues/20 pub fn dialer_select_proto(inner: R, protocols: I) -> DialerSelectFuture where R: AsyncRead + AsyncWrite, @@ -58,371 +68,261 @@ where } } -/// Helps selecting a protocol amongst the ones supported. +/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer +/// either trying protocols in-order, or by requesting all protocols supported +/// by the remote upfront, from which the first protocol found in the dialer's +/// list of protocols is selected. +pub type DialerSelectFuture = Either, DialerSelectPar>; + +/// Returns a `Future` that negotiates a protocol on the given I/O stream. +/// +/// Just like [`dialer_select_proto`] but always using an iterative message flow, +/// trying the given list of supported protocols one-by-one. /// -/// Same as `dialer_select_proto`. Tries protocols one by one. The iterator doesn't need to produce -/// match functions, because it's not needed. +/// This strategy is preferable if the dialer only supports a few protocols. pub fn dialer_select_proto_serial(inner: R, protocols: I) -> DialerSelectSeq where R: AsyncRead + AsyncWrite, I: IntoIterator, I::Item: AsRef<[u8]> { - let protocols = protocols.into_iter(); + let protocols = protocols.into_iter().peekable(); DialerSelectSeq { - inner: DialerSelectSeqState::AwaitDialer { - dialer_fut: Dialer::dial(inner), - protocols + protocols, + state: SeqState::SendHeader { + io: MessageIO::new(inner) } } } +/// Returns a `Future` that negotiates a protocol on the given I/O stream. +/// +/// Just like [`dialer_select_proto`] but always using a message flow that first +/// requests all supported protocols from the remote, selecting the first +/// protocol from the given list of supported protocols that is supported +/// by the remote. +/// +/// This strategy may be beneficial if the dialer supports many protocols +/// and it is unclear whether the remote supports one of the first few. +pub fn dialer_select_proto_parallel(inner: R, protocols: I) -> DialerSelectPar +where + R: AsyncRead + AsyncWrite, + I: IntoIterator, + I::Item: AsRef<[u8]> +{ + let protocols = protocols.into_iter(); + DialerSelectPar { + protocols, + state: ParState::SendHeader { + io: MessageIO::new(inner) + } + } +} -/// Future, returned by `dialer_select_proto_serial` which selects a protocol -/// and dialer sequentially. +/// A `Future` returned by [`dialer_select_proto_serial`] which negotiates +/// a protocol iteratively by considering one protocol after the other. pub struct DialerSelectSeq where R: AsyncRead + AsyncWrite, I: Iterator, I::Item: AsRef<[u8]> { - inner: DialerSelectSeqState + // TODO: It would be nice if eventually N = I::Item = Protocol. + protocols: iter::Peekable, + state: SeqState } -enum DialerSelectSeqState +enum SeqState where R: AsyncRead + AsyncWrite, - I: Iterator, - I::Item: AsRef<[u8]> + N: AsRef<[u8]> { - AwaitDialer { - dialer_fut: DialerFuture, - protocols: I - }, - NextProtocol { - dialer: Dialer, - proto_name: I::Item, - protocols: I - }, - FlushProtocol { - dialer: Dialer, - proto_name: I::Item, - protocols: I - }, - AwaitProtocol { - stream: StreamFuture>, - proto_name: I::Item, - protocols: I - }, - Undefined + SendHeader { io: MessageIO, }, + SendProtocol { io: MessageIO, protocol: N }, + FlushProtocol { io: MessageIO, protocol: N }, + AwaitProtocol { io: MessageIO, protocol: N }, + Done } impl Future for DialerSelectSeq where R: AsyncRead + AsyncWrite, I: Iterator, - I::Item: AsRef<[u8]> + Clone + I::Item: AsRef<[u8]> { type Item = (I::Item, Negotiated); - type Error = ProtocolChoiceError; + type Error = NegotiationError; fn poll(&mut self) -> Poll { loop { - match mem::replace(&mut self.inner, DialerSelectSeqState::Undefined) { - DialerSelectSeqState::AwaitDialer { mut dialer_fut, mut protocols } => { - let dialer = match dialer_fut.poll()? { - Async::Ready(d) => d, - Async::NotReady => { - self.inner = DialerSelectSeqState::AwaitDialer { dialer_fut, protocols }; - return Ok(Async::NotReady) - } - }; - let proto_name = protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?; - self.inner = DialerSelectSeqState::NextProtocol { - dialer, - protocols, - proto_name + match mem::replace(&mut self.state, SeqState::Done) { + SeqState::SendHeader { mut io } => { + if io.start_send(Message::Header(Version::V1))?.is_not_ready() { + self.state = SeqState::SendHeader { io }; + return Ok(Async::NotReady) } + let protocol = self.protocols.next().ok_or(NegotiationError::Failed)?; + self.state = SeqState::SendProtocol { io, protocol }; } - DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => { - trace!("sending {:?}", proto_name.as_ref()); - let req = Request::Protocol { name: proto_name.clone() }; - match dialer.start_send(req)? { - AsyncSink::Ready => { - self.inner = DialerSelectSeqState::FlushProtocol { - dialer, - proto_name, - protocols - } - } - AsyncSink::NotReady(_) => { - self.inner = DialerSelectSeqState::NextProtocol { - dialer, - protocols, - proto_name - }; - return Ok(Async::NotReady) - } + SeqState::SendProtocol { mut io, protocol } => { + let p = Protocol::try_from(protocol.as_ref())?; + if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() { + self.state = SeqState::SendProtocol { io, protocol }; + return Ok(Async::NotReady) + } + debug!("Dialer: Proposed protocol: {}", p); + if self.protocols.peek().is_some() { + self.state = SeqState::FlushProtocol { io, protocol } + } else { + debug!("Dialer: Expecting proposed protocol: {}", p); + let io = Negotiated::expecting(io.into_reader(), p); + return Ok(Async::Ready((protocol, io))) } } - DialerSelectSeqState::FlushProtocol { mut dialer, proto_name, protocols } => { - match dialer.poll_complete()? { - Async::Ready(()) => { - let stream = dialer.into_future(); - self.inner = DialerSelectSeqState::AwaitProtocol { - stream, - proto_name, - protocols - } - } - Async::NotReady => { - self.inner = DialerSelectSeqState::FlushProtocol { - dialer, - proto_name, - protocols - }; - return Ok(Async::NotReady) - } + SeqState::FlushProtocol { mut io, protocol } => { + if io.poll_complete()?.is_not_ready() { + self.state = SeqState::FlushProtocol { io, protocol }; + return Ok(Async::NotReady) } + self.state = SeqState::AwaitProtocol { io, protocol } } - DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, mut protocols } => { - let (m, r) = match stream.poll() { - Ok(Async::Ready(x)) => x, - Ok(Async::NotReady) => { - self.inner = DialerSelectSeqState::AwaitProtocol { - stream, - proto_name, - protocols - }; + SeqState::AwaitProtocol { mut io, protocol } => { + let msg = match io.poll()? { + Async::NotReady => { + self.state = SeqState::AwaitProtocol { io, protocol }; return Ok(Async::NotReady) } - Err((e, _)) => return Err(ProtocolChoiceError::from(e)) + Async::Ready(None) => + return Err(NegotiationError::from( + io::Error::from(io::ErrorKind::UnexpectedEof))), + Async::Ready(Some(msg)) => msg, }; - trace!("received {:?}", m); - match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? { - Response::Protocol { ref name } - if name.as_ref() == proto_name.as_ref() => - { - return Ok(Async::Ready((proto_name, Negotiated(r.into_inner())))) + + match msg { + Message::Header(Version::V1) => { + self.state = SeqState::AwaitProtocol { io, protocol }; } - Response::ProtocolNotAvailable => { - let proto_name = protocols.next() - .ok_or(ProtocolChoiceError::NoProtocolFound)?; - self.inner = DialerSelectSeqState::NextProtocol { - dialer: r, - protocols, - proto_name - } + Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { + debug!("Dialer: Received confirmation for protocol: {}", p); + let (io, remaining) = io.into_inner(); + let io = Negotiated::completed(io, remaining); + return Ok(Async::Ready((protocol, io))) } - _ => return Err(ProtocolChoiceError::UnexpectedMessage) + Message::NotAvailable => { + debug!("Dialer: Received rejection of protocol: {}", + String::from_utf8_lossy(protocol.as_ref())); + let protocol = self.protocols.next() + .ok_or(NegotiationError::Failed)?; + self.state = SeqState::SendProtocol { io, protocol } + } + _ => return Err(ProtocolError::InvalidMessage.into()) } } - DialerSelectSeqState::Undefined => - panic!("DialerSelectSeqState::poll called after completion") + SeqState::Done => panic!("SeqState::poll called after completion") } } } } -/// Helps selecting a protocol amongst the ones supported. -/// -/// Same as `dialer_select_proto`. Queries the list of supported protocols from the remote, then -/// chooses the most appropriate one. -pub fn dialer_select_proto_parallel(inner: R, protocols: I) -> DialerSelectPar -where - R: AsyncRead + AsyncWrite, - I: IntoIterator, - I::Item: AsRef<[u8]> -{ - let protocols = protocols.into_iter(); - DialerSelectPar { - inner: DialerSelectParState::AwaitDialer { dialer_fut: Dialer::dial(inner), protocols } - } -} - -/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in -/// parallel, by first requesting the list of protocols supported by the remote endpoint and -/// then selecting the most appropriate one by applying a match predicate to the result. +/// A `Future` returned by [`dialer_select_proto_parallel`] which negotiates +/// a protocol selectively by considering all supported protocols of the remote +/// "in parallel". pub struct DialerSelectPar where R: AsyncRead + AsyncWrite, I: Iterator, I::Item: AsRef<[u8]> { - inner: DialerSelectParState + protocols: I, + state: ParState } -enum DialerSelectParState +enum ParState where R: AsyncRead + AsyncWrite, - I: Iterator, - I::Item: AsRef<[u8]> + N: AsRef<[u8]> { - AwaitDialer { - dialer_fut: DialerFuture, - protocols: I - }, - ProtocolList { - dialer: Dialer, - protocols: I - }, - FlushListRequest { - dialer: Dialer, - protocols: I - }, - AwaitListResponse { - stream: StreamFuture>, - protocols: I, - }, - Protocol { - dialer: Dialer, - proto_name: I::Item - }, - FlushProtocol { - dialer: Dialer, - proto_name: I::Item - }, - AwaitProtocol { - stream: StreamFuture>, - proto_name: I::Item - }, - Undefined + SendHeader { io: MessageIO }, + SendProtocolsRequest { io: MessageIO }, + Flush { io: MessageIO }, + RecvProtocols { io: MessageIO }, + SendProtocol { io: MessageIO, protocol: N }, + Done } impl Future for DialerSelectPar where R: AsyncRead + AsyncWrite, I: Iterator, - I::Item: AsRef<[u8]> + Clone + I::Item: AsRef<[u8]> { type Item = (I::Item, Negotiated); - type Error = ProtocolChoiceError; + type Error = NegotiationError; fn poll(&mut self) -> Poll { loop { - match mem::replace(&mut self.inner, DialerSelectParState::Undefined) { - DialerSelectParState::AwaitDialer { mut dialer_fut, protocols } => { - match dialer_fut.poll()? { - Async::Ready(dialer) => { - self.inner = DialerSelectParState::ProtocolList { dialer, protocols } - } - Async::NotReady => { - self.inner = DialerSelectParState::AwaitDialer { dialer_fut, protocols }; - return Ok(Async::NotReady) - } + match mem::replace(&mut self.state, ParState::Done) { + ParState::SendHeader { mut io } => { + if io.start_send(Message::Header(Version::V1))?.is_not_ready() { + self.state = ParState::SendHeader { io }; + return Ok(Async::NotReady) } + self.state = ParState::SendProtocolsRequest { io }; } - DialerSelectParState::ProtocolList { mut dialer, protocols } => { - trace!("requesting protocols list"); - match dialer.start_send(Request::ListProtocols)? { - AsyncSink::Ready => { - self.inner = DialerSelectParState::FlushListRequest { - dialer, - protocols - } - } - AsyncSink::NotReady(_) => { - self.inner = DialerSelectParState::ProtocolList { dialer, protocols }; - return Ok(Async::NotReady) - } + ParState::SendProtocolsRequest { mut io } => { + if io.start_send(Message::ListProtocols)?.is_not_ready() { + self.state = ParState::SendProtocolsRequest { io }; + return Ok(Async::NotReady) } + debug!("Dialer: Requested supported protocols."); + self.state = ParState::Flush { io } } - DialerSelectParState::FlushListRequest { mut dialer, protocols } => { - match dialer.poll_complete()? { - Async::Ready(()) => { - self.inner = DialerSelectParState::AwaitListResponse { - stream: dialer.into_future(), - protocols - } - } - Async::NotReady => { - self.inner = DialerSelectParState::FlushListRequest { - dialer, - protocols - }; - return Ok(Async::NotReady) - } + ParState::Flush { mut io } => { + if io.poll_complete()?.is_not_ready() { + self.state = ParState::Flush { io }; + return Ok(Async::NotReady) } + self.state = ParState::RecvProtocols { io } } - DialerSelectParState::AwaitListResponse { mut stream, protocols } => { - let (resp, dialer) = match stream.poll() { - Ok(Async::Ready(x)) => x, - Ok(Async::NotReady) => { - self.inner = DialerSelectParState::AwaitListResponse { stream, protocols }; + ParState::RecvProtocols { mut io } => { + let msg = match io.poll()? { + Async::NotReady => { + self.state = ParState::RecvProtocols { io }; return Ok(Async::NotReady) } - Err((e, _)) => return Err(ProtocolChoiceError::from(e)) + Async::Ready(None) => + return Err(NegotiationError::from( + io::Error::from(io::ErrorKind::UnexpectedEof))), + Async::Ready(Some(msg)) => msg, }; - trace!("protocols list response: {:?}", resp); - let supported = - if let Some(Response::SupportedProtocols { protocols }) = resp { - protocols - } else { - return Err(ProtocolChoiceError::UnexpectedMessage) - }; - let mut found = None; - for local_name in protocols { - for remote_name in &supported { - if remote_name.as_ref() == local_name.as_ref() { - found = Some(local_name); - break; - } - } - if found.is_some() { - break; - } - } - let proto_name = found.ok_or(ProtocolChoiceError::NoProtocolFound)?; - self.inner = DialerSelectParState::Protocol { dialer, proto_name } - } - DialerSelectParState::Protocol { mut dialer, proto_name } => { - trace!("Requesting protocol: {:?}", proto_name.as_ref()); - let req = Request::Protocol { name: proto_name.clone() }; - match dialer.start_send(req)? { - AsyncSink::Ready => { - self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name } - } - AsyncSink::NotReady(_) => { - self.inner = DialerSelectParState::Protocol { dialer, proto_name }; - return Ok(Async::NotReady) - } - } - } - DialerSelectParState::FlushProtocol { mut dialer, proto_name } => { - match dialer.poll_complete()? { - Async::Ready(()) => { - self.inner = DialerSelectParState::AwaitProtocol { - stream: dialer.into_future(), - proto_name - } + + match &msg { + Message::Header(Version::V1) => { + self.state = ParState::RecvProtocols { io } } - Async::NotReady => { - self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }; - return Ok(Async::NotReady) + Message::Protocols(supported) => { + let protocol = self.protocols.by_ref() + .find(|p| supported.iter().any(|s| + s.as_ref() == p.as_ref())) + .ok_or(NegotiationError::Failed)?; + debug!("Dialer: Found supported protocol: {}", + String::from_utf8_lossy(protocol.as_ref())); + self.state = ParState::SendProtocol { io, protocol }; } + _ => return Err(ProtocolError::InvalidMessage.into()) } } - DialerSelectParState::AwaitProtocol { mut stream, proto_name } => { - let (resp, dialer) = match stream.poll() { - Ok(Async::Ready(x)) => x, - Ok(Async::NotReady) => { - self.inner = DialerSelectParState::AwaitProtocol { stream, proto_name }; - return Ok(Async::NotReady) - } - Err((e, _)) => return Err(ProtocolChoiceError::from(e)) - }; - trace!("received {:?}", resp); - match resp { - Some(Response::Protocol { ref name }) - if name.as_ref() == proto_name.as_ref() => - { - return Ok(Async::Ready((proto_name, Negotiated(dialer.into_inner())))) - } - _ => return Err(ProtocolChoiceError::UnexpectedMessage) + ParState::SendProtocol { mut io, protocol } => { + let p = Protocol::try_from(protocol.as_ref())?; + if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() { + self.state = ParState::SendProtocol { io, protocol }; + return Ok(Async::NotReady) } + debug!("Dialer: Expecting proposed protocol: {}", p); + let io = Negotiated::expecting(io.into_reader(), p); + return Ok(Async::Ready((protocol, io))) } - DialerSelectParState::Undefined => - panic!("DialerSelectParState::poll called after completion") + ParState::Done => panic!("ParState::poll called after completion") } } } diff --git a/misc/multistream-select/src/error.rs b/misc/multistream-select/src/error.rs index 1f72b5c0c8a..4d948de4490 100644 --- a/misc/multistream-select/src/error.rs +++ b/misc/multistream-select/src/error.rs @@ -20,58 +20,8 @@ //! Main `ProtocolChoiceError` error. -use crate::protocol::MultistreamSelectError; +pub use crate::protocol::ProtocolError; + use std::error::Error; use std::{fmt, io}; -/// Error that can happen when negotiating a protocol with the remote. -#[derive(Debug)] -pub enum ProtocolChoiceError { - /// Error in the protocol. - MultistreamSelectError(MultistreamSelectError), - - /// Received a message from the remote that makes no sense in the current context. - UnexpectedMessage, - - /// We don't support any protocol in common with the remote. - NoProtocolFound, -} - -impl From for ProtocolChoiceError { - fn from(err: MultistreamSelectError) -> ProtocolChoiceError { - ProtocolChoiceError::MultistreamSelectError(err) - } -} - -impl From for ProtocolChoiceError { - fn from(err: io::Error) -> ProtocolChoiceError { - MultistreamSelectError::from(err).into() - } -} - -impl Error for ProtocolChoiceError { - fn description(&self) -> &str { - match *self { - ProtocolChoiceError::MultistreamSelectError(_) => "error in the protocol", - ProtocolChoiceError::UnexpectedMessage => { - "received a message from the remote that makes no sense in the current context" - } - ProtocolChoiceError::NoProtocolFound => { - "we don't support any protocol in common with the remote" - } - } - } - - fn source(&self) -> Option<&(dyn Error + 'static)> { - match *self { - ProtocolChoiceError::MultistreamSelectError(ref err) => Some(err), - _ => None, - } - } -} - -impl fmt::Display for ProtocolChoiceError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(fmt, "{}", Error::description(self)) - } -} diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 4d4f2c02bb4..44dbfe95789 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -28,8 +28,8 @@ const MAX_LEN_BYTES: u16 = 2; const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; const DEFAULT_BUFFER_SIZE: usize = 64; -/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` resource to read -/// and write unsigned-varint prefixed frames. +/// A `Stream` and `Sink` for unsigned-varint length-delimited frames, +/// wrapping an underlying `AsyncRead + AsyncWrite` I/O resource. /// /// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint /// frame length). Frames mostly consist in a short protocol name, which is highly @@ -75,20 +75,70 @@ impl LengthDelimited { } } - /// Destroys the `LengthDelimited` and returns the underlying socket. + /// Returns a reference to the underlying I/O stream. + pub fn inner_ref(&self) -> &R { + &self.inner + } + + /// Returns a mutable reference to the underlying I/O stream. + /// + /// > **Note**: Care should be taken to not tamper with the underlying stream of data + /// > coming in, as it may corrupt the stream of frames. + pub fn inner_mut(&mut self) -> &mut R { + &mut self.inner + } + + /// Drops the `LengthDelimited` resource, yielding the underlying I/O stream + /// together with the remaining write buffer containing the uvi-framed data + /// that has not yet been written to the underlying I/O stream. /// - /// This method is guaranteed not to skip any data from the socket. + /// The returned remaining write buffer may be prepended to follow-up + /// protocol data to send with a single `write`. Either way, if non-empty, + /// the write buffer _must_ eventually be written to the I/O stream + /// _before_ any follow-up data, in order to maintain a correct data stream. /// /// # Panic /// - /// Will panic if called while there is data inside the read or write buffer. - /// **This can only happen if you call `poll()` manually**. Using this struct - /// as it is intended to be used (i.e. through the high-level `futures` API) - /// will always leave the object in a state in which `into_inner()` will not panic. - pub fn into_inner(self) -> R { - assert!(self.write_buffer.is_empty()); + /// Will panic if called while there is data in the read buffer. The read buffer is + /// guaranteed to be empty whenever `Stream::poll` yields a new `Bytes` frame. + pub fn into_inner(self) -> (R, BytesMut) { assert!(self.read_buffer.is_empty()); - self.inner + (self.inner, self.write_buffer) + } + + /// Converts the `LengthDelimited` into a `LengthDelimitedReader`, dropping the + /// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying + /// I/O stream. + /// + /// This is typically done if further uvi-framed messages are expected to be + /// received but no more such messages are written, allowing the writing of + /// follow-up protocol data to commence. + pub fn into_reader(self) -> LengthDelimitedReader { + LengthDelimitedReader { inner: self } + } + + /// Writes all buffered frame data to the underlying I/O stream, + /// _without flushing it_. + /// + /// After this method returns `Async::Ready`, the write buffer of frames + /// submitted to the `Sink` is guaranteed to be empty. + pub fn poll_write_buffer(&mut self) -> Poll<(), io::Error> + where + R: AsyncWrite + { + while !self.write_buffer.is_empty() { + let n = try_ready!(self.inner.poll_write(&self.write_buffer)); + + if n == 0 { + return Err(io::Error::new( + io::ErrorKind::WriteZero, + "Failed to write buffered frame.")) + } + + self.write_buffer.split_to(n); + } + + Ok(Async::Ready(())) } } @@ -204,18 +254,9 @@ where } fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - while !self.write_buffer.is_empty() { - let n = try_ready!(self.inner.poll_write(&self.write_buffer)); - - if n == 0 { - return Err(io::Error::new( - io::ErrorKind::WriteZero, - "Failed to write buffered frame.")) - } - - let _ = self.write_buffer.split_to(n); - } - + // Write all buffered frame data to the underlying I/O stream. + try_ready!(self.poll_write_buffer()); + // Flush the underlying I/O stream. try_ready!(self.inner.poll_flush()); return Ok(Async::Ready(())); } @@ -226,6 +267,102 @@ where } } +/// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited +/// frames on an underlying I/O resource combined with direct `AsyncWrite` access. +pub struct LengthDelimitedReader { + inner: LengthDelimited +} + +impl LengthDelimitedReader { + /// Destroys the `LengthDelimitedReader` and returns the underlying I/O stream. + /// + /// This method is guaranteed not to drop any data read from or not yet + /// submitted to the underlying I/O stream. + /// + /// # Panic + /// + /// Will panic if called while there is data in the read or write buffer. + /// The read buffer is guaranteed to be empty whenever `Stream::poll` yields + /// a new `Message`. The write buffer is guaranteed to be empty whenever + /// [`poll_write_buffer`] yields `Async::Ready` or after the `Sink` has been + /// completely flushed via [`Sink::poll_complete`]. + pub fn into_inner(self) -> (R, BytesMut) { + self.inner.into_inner() + } + + /// Returns a reference to the underlying I/O stream. + pub fn inner_ref(&self) -> &R { + self.inner.inner_ref() + } + + /// Returns a mutable reference to the underlying I/O stream. + /// + /// > **Note**: Care should be taken to not tamper with the underlying stream of data + /// > coming in, as it may corrupt the stream of frames. + pub fn inner_mut(&mut self) -> &mut R { + self.inner.inner_mut() + } +} + +impl Stream for LengthDelimitedReader +where + R: AsyncRead +{ + type Item = Bytes; + type Error = io::Error; + + fn poll(&mut self) -> Poll, Self::Error> { + self.inner.poll() + } +} + +impl io::Write for LengthDelimitedReader +where + R: AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + // Try to drain the write buffer together with writing `buf`. + if !self.inner.write_buffer.is_empty() { + let n = self.inner.write_buffer.len(); + self.inner.write_buffer.extend_from_slice(buf); + let result = self.inner.poll_write_buffer(); + let written = n - self.inner.write_buffer.len(); + if written == 0 { + if let Err(e) = result { + return Err(e) + } + return Err(io::ErrorKind::WouldBlock.into()) + } + if written < buf.len() { + if self.inner.write_buffer.len() > n { + self.inner.write_buffer.split_off(n); // Never grow the buffer. + } + return Ok(written) + } + return Ok(buf.len()) + } + + self.inner_mut().write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + match self.inner.poll_complete()? { + Async::Ready(()) => Ok(()), + Async::NotReady => Err(io::ErrorKind::WouldBlock.into()) + } + } +} + +impl AsyncWrite for LengthDelimitedReader +where + R: AsyncWrite +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + try_ready!(self.inner.poll_complete()); + self.inner_mut().shutdown() + } +} + #[cfg(test)] mod tests { use futures::{Future, Stream}; diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index b2bc054bfff..9dd89e3cbe2 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -18,24 +18,59 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -//! # Multistream-select +//! # Multistream-select Protocol Negotiation //! -//! This crate implements the `multistream-select` protocol, which is the protocol used by libp2p -//! to negotiate which protocol to use with the remote on a connection or substream. +//! This crate implements the `multistream-select` protocol, which is the protocol +//! used by libp2p to negotiate which application-layer protocol to use with the +//! remote on a connection or substream. //! -//! > **Note**: This crate is used by the internals of *libp2p*, and it is not required to -//! > understand it in order to use *libp2p*. +//! > **Note**: This crate is used primarily by core components of *libp2p* and it +//! > is usually not used directly on its own. //! -//! Whenever a new connection or a new multiplexed substream is opened, libp2p uses -//! `multistream-select` to negotiate with the remote which protocol to use. After a protocol has -//! been successfully negotiated, the stream (i.e. the connection or the multiplexed substream) -//! immediately stops using `multistream-select` and starts using the negotiated protocol. +//! ## Roles //! -//! ## Protocol explanation +//! Two peers using the multistream-select negotiation protocol on an I/O stream +//! are distinguished by their role as a _dialer_ (or _initiator_) or as a _listener_ +//! (or _responder_). Thereby the dialer plays the active part, driving the protocol, +//! whereas the listener reacts to the messages received. //! -//! The dialer has two options available: either request the list of protocols that the listener -//! supports, or suggest a protocol. If a protocol is suggested, the listener can either accept (by -//! answering with the same protocol name) or refuse the choice (by answering "not available"). +//! The dialer has two options: it can either pick a protocol from the complete list +//! of protocols that the listener supports, or it can directly suggest a protocol. +//! Either way, a selected protocol is sent to the listener who can either accept (by +//! echoing the same protocol) or reject (by responding with a message stating +//! "not available"). If a suggested protocol is not available, the dialer may +//! suggest another protocol. This process continues until a protocol is agreed upon, +//! yielding a [`Negotiated`](self::Negotiated) stream, or the dialer has run out of +//! alternatives. +//! +//! See [`dialer_select_proto`](self::dialer_select_proto) and +//! [`listener_select_proto`](self::listener_select_proto). +//! +//! ## [`Negotiated`](self::Negotiated) +//! +//! When a dialer or listener participating in a negotiation settles +//! on a protocol to use, the [`DialerSelectFuture`] respectively +//! [`ListenerSelectFuture`] yields a [`Negotiated`](self::Negotiated) +//! I/O stream. +//! +//! Notably, when a `DialerSelectFuture` resolves to a `Negotiated`, it may not yet +//! have written the last negotiation message to the underlying I/O stream and may +//! still be expecting confirmation for that protocol, despite having settled on +//! a protocol to use. +//! +//! Similarly, when a `ListenerSelectFuture` resolves to a `Negotiated`, it may not +//! yet have sent the last negotiation message despite having settled on a protocol +//! proposed by the dialer that it supports. +//! +//! +//! This behaviour allows both the dialer and the listener to send data +//! relating to the negotiated protocol together with the last negotiation +//! message(s), which, in the case of the dialer only supporting a single +//! protocol, results in 0-RTT negotiation. Note, however, that a dialer +//! that performs multiple 0-RTT negotiations in sequence for different +//! protocols layered on top of each other may trigger undesirable behaviour +//! for a listener not supporting one of the intermediate protocols. +//! See [`dialer_select_proto`](self::dialer_select_proto). //! //! ## Examples //! @@ -54,77 +89,28 @@ //! //! let client = TcpStream::connect(&"127.0.0.1:10333".parse().unwrap()) //! .from_err() -//! .and_then(move |connec| { +//! .and_then(move |io| { //! let protos = vec![b"/echo/1.0.0", b"/echo/2.5.0"]; -//! dialer_select_proto(connec, protos).map(|r| r.0) -//! }); +//! dialer_select_proto(io, protos) // .map(|r| r.0) +//! }) +//! .map(|(protocol, _io)| protocol); //! //! let mut rt = Runtime::new().unwrap(); -//! let negotiated_protocol = rt.block_on(client).expect("failed to find a protocol"); -//! println!("negotiated: {:?}", negotiated_protocol); +//! let protocol = rt.block_on(client).expect("failed to find a protocol"); +//! println!("Negotiated protocol: {:?}", protocol); //! # } //! ``` //! mod dialer_select; -mod error; mod length_delimited; mod listener_select; -mod tests; - +mod negotiated; mod protocol; +mod tests; -use futures::prelude::*; -use std::io; -use tokio_io::{AsyncRead, AsyncWrite}; - +pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError}; +pub use self::protocol::ProtocolError; pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture}; -pub use self::error::ProtocolChoiceError; pub use self::listener_select::{listener_select_proto, ListenerSelectFuture}; -/// A stream after it has been negotiated. -pub struct Negotiated(pub(crate) TInner); - -impl io::Read for Negotiated -where - TInner: io::Read -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) - } -} - -impl AsyncRead for Negotiated -where - TInner: AsyncRead -{ - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.0.prepare_uninitialized_buffer(buf) - } - - fn read_buf(&mut self, buf: &mut B) -> Poll { - self.0.read_buf(buf) - } -} - -impl io::Write for Negotiated -where - TInner: io::Write -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.0.flush() - } -} - -impl AsyncWrite for Negotiated -where - TInner: AsyncWrite -{ - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.0.shutdown() - } -} diff --git a/misc/multistream-select/src/listener_select.rs b/misc/multistream-select/src/listener_select.rs index 40ed92d057e..a62581158dd 100644 --- a/misc/multistream-select/src/listener_select.rs +++ b/misc/multistream-select/src/listener_select.rs @@ -18,167 +18,196 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -//! Contains the `listener_select_proto` code, which allows selecting a protocol thanks to -//! `multistream-select` for the listener. +//! Protocol negotiation strategies for the peer acting as the listener +//! in a multistream-select protocol negotiation. -use futures::{prelude::*, sink, stream::StreamFuture}; -use crate::protocol::{ - Request, - Response, - Listener, - ListenerFuture, -}; -use log::{debug, trace}; -use std::mem; +use futures::prelude::*; +use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version}; +use log::{debug, warn}; +use smallvec::SmallVec; +use std::{io, iter::FromIterator, mem, convert::TryFrom}; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::{Negotiated, ProtocolChoiceError}; +use crate::{Negotiated, NegotiationError}; -/// Helps selecting a protocol amongst the ones supported. +/// Returns a `Future` that negotiates a protocol on the given I/O stream +/// for a peer acting as the _listener_ (or _responder_). /// -/// This function expects a socket and an iterator of the list of supported protocols. The iterator -/// must be clonable (i.e. iterable multiple times), because the list may need to be accessed -/// multiple times. -/// -/// The iterator must produce tuples of the name of the protocol that is advertised to the remote, -/// a function that will check whether a remote protocol matches ours, and an identifier for the -/// protocol of type `P` (you decide what `P` is). The parameters of the function are the name -/// proposed by the remote, and the protocol name that we passed (so that you don't have to clone -/// the name). -/// -/// On success, returns the socket and the identifier of the chosen protocol (of type `P`). The -/// socket now uses this protocol. -pub fn listener_select_proto(inner: R, protocols: I) -> ListenerSelectFuture +/// This function is given an I/O stream and a list of protocols and returns a +/// computation that performs the protocol negotiation with the remote. The +/// returned `Future` resolves with the name of the negotiated protocol and +/// a [`Negotiated`] I/O stream. +pub fn listener_select_proto(inner: R, protocols: I) -> ListenerSelectFuture where R: AsyncRead + AsyncWrite, - for<'r> &'r I: IntoIterator, - X: AsRef<[u8]> + I: IntoIterator, + I::Item: AsRef<[u8]> { + let protocols = protocols.into_iter().filter_map(|n| + match Protocol::try_from(n.as_ref()) { + Ok(p) => Some((n, p)), + Err(e) => { + warn!("Listener: Ignoring invalid protocol: {} due to {}", + String::from_utf8_lossy(n.as_ref()), e); + None + } + }); ListenerSelectFuture { - inner: ListenerSelectState::AwaitListener { - listener_fut: Listener::listen(inner), - protocols + protocols: SmallVec::from_iter(protocols), + state: State::RecvHeader { + io: MessageIO::new(inner) } } } -/// Future, returned by `listener_select_proto` which selects a protocol among the ones supported. -pub struct ListenerSelectFuture +/// The `Future` returned by [`listener_select_proto`] that performs a +/// multistream-select protocol negotiation on an underlying I/O stream. +pub struct ListenerSelectFuture where R: AsyncRead + AsyncWrite, - for<'a> &'a I: IntoIterator, - X: AsRef<[u8]> + N: AsRef<[u8]> { - inner: ListenerSelectState + // TODO: It would be nice if eventually N = Protocol, which has a + // few more implications on the API. + protocols: SmallVec<[(N, Protocol); 8]>, + state: State } -enum ListenerSelectState +enum State where R: AsyncRead + AsyncWrite, - for<'a> &'a I: IntoIterator, - X: AsRef<[u8]> + N: AsRef<[u8]> { - AwaitListener { - listener_fut: ListenerFuture, - protocols: I + RecvHeader { io: MessageIO }, + SendHeader { io: MessageIO }, + RecvMessage { io: MessageIO }, + SendMessage { + io: MessageIO, + message: Message, + protocol: Option }, - Incoming { - stream: StreamFuture>, - protocols: I - }, - Outgoing { - sender: sink::Send>, - protocols: I, - outcome: Option - }, - Undefined + Flush { io: MessageIO }, + Done } -impl Future for ListenerSelectFuture +impl Future for ListenerSelectFuture where R: AsyncRead + AsyncWrite, - for<'a> &'a I: IntoIterator, - X: AsRef<[u8]> + Clone + N: AsRef<[u8]> + Clone { - type Item = (X, Negotiated, I); - type Error = ProtocolChoiceError; + type Item = (N, Negotiated); + type Error = NegotiationError; fn poll(&mut self) -> Poll { loop { - match mem::replace(&mut self.inner, ListenerSelectState::Undefined) { - ListenerSelectState::AwaitListener { mut listener_fut, protocols } => { - let listener = match listener_fut.poll()? { - Async::Ready(l) => l, + match mem::replace(&mut self.state, State::Done) { + State::RecvHeader { mut io } => { + match io.poll()? { + Async::Ready(Some(Message::Header(Version::V1))) => { + self.state = State::SendHeader { io } + } + Async::Ready(Some(Message::Header(Version::V2))) => { + // The V2 protocol is not yet supported and not even + // yet fully specified or implemented anywhere. For + // now we just return 'na' to force any dialer to + // fall back to V1, according to the current plans + // for the "transition period". + // + // See: https://github.com/libp2p/specs/pull/95. + self.state = State::SendMessage { + io, + message: Message::NotAvailable, + protocol: None, + } + } + Async::Ready(Some(_)) => { + return Err(ProtocolError::InvalidMessage.into()) + } + Async::Ready(None) => + return Err(NegotiationError::from( + ProtocolError::IoError( + io::ErrorKind::UnexpectedEof.into()))), Async::NotReady => { - self.inner = ListenerSelectState::AwaitListener { listener_fut, protocols }; + self.state = State::RecvHeader { io }; return Ok(Async::NotReady) } - }; - let stream = listener.into_future(); - self.inner = ListenerSelectState::Incoming { stream, protocols }; + } } - ListenerSelectState::Incoming { mut stream, protocols } => { - let (msg, listener) = match stream.poll() { - Ok(Async::Ready(x)) => x, + State::SendHeader { mut io } => { + if io.start_send(Message::Header(Version::V1))?.is_not_ready() { + return Ok(Async::NotReady) + } + self.state = State::RecvMessage { io }; + } + State::RecvMessage { mut io } => { + let msg = match io.poll() { + Ok(Async::Ready(Some(msg))) => msg, + Ok(Async::Ready(None)) => + return Err(NegotiationError::from( + ProtocolError::IoError( + io::ErrorKind::UnexpectedEof.into()))), Ok(Async::NotReady) => { - self.inner = ListenerSelectState::Incoming { stream, protocols }; + self.state = State::RecvMessage { io }; return Ok(Async::NotReady) } - Err((e, _)) => return Err(ProtocolChoiceError::from(e)) + Err(e) => return Err(e.into()) }; + match msg { - Some(Request::ListProtocols) => { - trace!("protocols list response: {:?}", protocols - .into_iter() - .map(|p| p.as_ref().into()) - .collect::>>()); - let supported = protocols.into_iter().collect(); - let msg = Response::SupportedProtocols { protocols: supported }; - let sender = listener.send(msg); - self.inner = ListenerSelectState::Outgoing { - sender, - protocols, - outcome: None - } + Message::ListProtocols => { + let supported = self.protocols.iter().map(|(_,p)| p).cloned().collect(); + let message = Message::Protocols(supported); + self.state = State::SendMessage { io, message, protocol: None } } - Some(Request::Protocol { name }) => { - let mut outcome = None; - let mut send_back = Response::ProtocolNotAvailable; - for supported in &protocols { - if name.as_ref() == supported.as_ref() { - send_back = Response::Protocol { - name: supported.clone() - }; - outcome = Some(supported); - break; + Message::Protocol(p) => { + let protocol = self.protocols.iter().find_map(|(name, proto)| { + if &p == proto { + Some(name.clone()) + } else { + None } - } - trace!("requested: {:?}, supported: {}", name, outcome.is_some()); - let sender = listener.send(send_back); - self.inner = ListenerSelectState::Outgoing { sender, protocols, outcome } - } - None => { - debug!("no protocol request received"); - return Err(ProtocolChoiceError::NoProtocolFound) + }); + + let message = if protocol.is_some() { + debug!("Listener: confirming protocol: {}", p); + Message::Protocol(p.clone()) + } else { + debug!("Listener: rejecting protocol: {}", + String::from_utf8_lossy(p.as_ref())); + Message::NotAvailable + }; + + self.state = State::SendMessage { io, message, protocol }; } + _ => return Err(ProtocolError::InvalidMessage.into()) } } - ListenerSelectState::Outgoing { mut sender, protocols, outcome } => { - let listener = match sender.poll()? { - Async::Ready(l) => l, - Async::NotReady => { - self.inner = ListenerSelectState::Outgoing { sender, protocols, outcome }; - return Ok(Async::NotReady) + State::SendMessage { mut io, message, protocol } => { + if let AsyncSink::NotReady(message) = io.start_send(message)? { + self.state = State::SendMessage { io, message, protocol }; + return Ok(Async::NotReady) + }; + // If a protocol has been selected, finish negotiation. + // Otherwise flush the sink and expect to receive another + // message. + self.state = match protocol { + Some(protocol) => { + debug!("Listener: sent confirmed protocol: {}", + String::from_utf8_lossy(protocol.as_ref())); + let (io, remaining) = io.into_inner(); + let io = Negotiated::completed(io, remaining); + return Ok(Async::Ready((protocol, io))) } + None => State::Flush { io } }; - if let Some(p) = outcome { - return Ok(Async::Ready((p, Negotiated(listener.into_inner()), protocols))) - } else { - let stream = listener.into_future(); - self.inner = ListenerSelectState::Incoming { stream, protocols } + } + State::Flush { mut io } => { + if io.poll_complete()?.is_not_ready() { + self.state = State::Flush { io }; + return Ok(Async::NotReady) } + self.state = State::RecvMessage { io } } - ListenerSelectState::Undefined => - panic!("ListenerSelectState::poll called after completion") + State::Done => panic!("State::poll called after completion") } } } diff --git a/misc/multistream-select/src/negotiated.rs b/misc/multistream-select/src/negotiated.rs new file mode 100644 index 00000000000..cb63adcf685 --- /dev/null +++ b/misc/multistream-select/src/negotiated.rs @@ -0,0 +1,338 @@ +// Copyright 2019 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use bytes::BytesMut; +use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError}; +use futures::{prelude::*, Async, try_ready}; +use log::debug; +use tokio_io::{AsyncRead, AsyncWrite}; +use std::{mem, io, fmt, error::Error}; + +/// An I/O stream that has settled on an (application-layer) protocol to use. +/// +/// A `Negotiated` represents an I/O stream that has _settled_ on a protocol +/// to use. In particular, it is not implied that all of the protocol negotiation +/// frames have yet been sent and / or received, just that the selected protocol +/// is fully determined. This is to allow the last protocol negotiation frames +/// sent by a peer to be combined in a single write, possibly piggy-backing +/// data from the negotiated protocol on top. +/// +/// Reading from a `Negotiated` I/O stream that still has pending negotiation +/// protocol data to send implicitly triggers flushing of all yet unsent data. +pub struct Negotiated { + state: State +} + +/// A `Future` that waits on the completion of protocol negotiation. +pub struct NegotiatedComplete { + inner: Option> +} + +impl Future for NegotiatedComplete { + type Item = Negotiated; + type Error = NegotiationError; + + fn poll(&mut self) -> Poll { + let mut io = self.inner.take().expect("NegotiatedFuture called after completion."); + if io.poll()?.is_not_ready() { + self.inner = Some(io); + return Ok(Async::NotReady) + } + return Ok(Async::Ready(io)) + } +} + +impl Negotiated { + /// Creates a `Negotiated` in state [`State::Complete`], possibly + /// with `remaining` data to be sent. + pub(crate) fn completed(io: TInner, remaining: BytesMut) -> Self { + Negotiated { state: State::Completed { io, remaining } } + } + + /// Creates a `Negotiated` in state [`State::Expecting`] that is still + /// expecting confirmation of the given `protocol`. + pub(crate) fn expecting(io: MessageReader, protocol: Protocol) -> Self { + Negotiated { state: State::Expecting { io, protocol } } + } + + /// Polls the `Negotiated` for completion. + fn poll(&mut self) -> Poll<(), NegotiationError> + where + TInner: AsyncRead + AsyncWrite + { + // Flush any pending negotiation data. + match self.poll_flush() { + Ok(Async::Ready(())) => {}, + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(e) => { + // If the remote closed the stream, it is important to still + // continue reading the data that was sent, if any. + if e.kind() != io::ErrorKind::WriteZero { + return Err(e.into()) + } + } + } + + if let State::Completed { remaining, .. } = &mut self.state { + let _ = remaining.take(); // Drop remaining data flushed above. + return Ok(Async::Ready(())) + } + + // Read outstanding protocol negotiation messages. + loop { + match mem::replace(&mut self.state, State::Invalid) { + State::Expecting { mut io, protocol } => { + let msg = match io.poll() { + Ok(Async::Ready(Some(msg))) => msg, + Ok(Async::NotReady) => { + self.state = State::Expecting { io, protocol }; + return Ok(Async::NotReady) + } + Ok(Async::Ready(None)) => { + self.state = State::Expecting { io, protocol }; + return Err(ProtocolError::IoError( + io::ErrorKind::UnexpectedEof.into()).into()) + } + Err(err) => { + self.state = State::Expecting { io, protocol }; + return Err(err.into()) + } + }; + + if let Message::Header(Version::V1) = &msg { + self.state = State::Expecting { io, protocol }; + continue + } + + if let Message::Protocol(p) = &msg { + if p.as_ref() == protocol.as_ref() { + debug!("Negotiated: Received confirmation for protocol: {}", p); + let (io, remaining) = io.into_inner(); + self.state = State::Completed { io, remaining }; + return Ok(Async::Ready(())) + } + } + + return Err(NegotiationError::Failed) + } + + _ => panic!("Negotiated: Invalid state") + } + } + } + + /// Returns a `NegotiatedComplete` future that waits for protocol + /// negotiation to complete. + pub fn complete(self) -> NegotiatedComplete { + NegotiatedComplete { inner: Some(self) } + } +} + +/// The states of a `Negotiated` I/O stream. +enum State { + /// In this state, a `Negotiated` is still expecting to + /// receive confirmation of the protocol it as settled on. + Expecting { io: MessageReader, protocol: Protocol }, + + /// In this state, a protocol has been agreed upon and may + /// only be pending the sending of the final acknowledgement, + /// which is prepended to / combined with the next write for + /// efficiency. + Completed { io: R, remaining: BytesMut }, + + /// Temporary state while moving the `io` resource from + /// `Expecting` to `Completed`. + Invalid, +} + +impl io::Read for Negotiated +where + R: AsyncRead + AsyncWrite +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + loop { + if let State::Completed { io, remaining } = &mut self.state { + // If protocol negotiation is complete and there is no + // remaining data to be flushed, commence with reading. + if remaining.is_empty() { + return io.read(buf) + } + } + + // Poll the `Negotiated`, driving protocol negotiation to completion, + // including flushing of any remaining data. + let result = self.poll(); + + // There is still remaining data to be sent before data relating + // to the negotiated protocol can be read. + if let Ok(Async::NotReady) = result { + return Err(io::ErrorKind::WouldBlock.into()) + } + + if let Err(err) = result { + return Err(err.into()) + } + } + } +} + +impl AsyncRead for Negotiated +where + TInner: AsyncRead + AsyncWrite +{ + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + match &self.state { + State::Completed { io, .. } => + io.prepare_uninitialized_buffer(buf), + State::Expecting { io, .. } => + io.inner_ref().prepare_uninitialized_buffer(buf), + State::Invalid => panic!("Negotiated: Invalid state") + } + } +} + +impl io::Write for Negotiated +where + TInner: AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + match &mut self.state { + State::Completed { io, ref mut remaining } => { + if !remaining.is_empty() { + // Try to write `buf` together with `remaining` for efficiency, + // regardless of whether the underlying I/O stream is buffered. + // Every call to `write` may imply a syscall and separate + // network packet. + let remaining_len = remaining.len(); + remaining.extend_from_slice(buf); + match io.write(&remaining) { + Err(e) => { + remaining.split_off(buf.len()); + debug_assert_eq!(remaining.len(), remaining_len); + Err(e) + } + Ok(n) => { + remaining.split_to(n); + if !remaining.is_empty() { + let written = if n < buf.len() { + remaining.split_off(remaining_len); + n + } else { + buf.len() + }; + debug_assert!(remaining.len() <= remaining_len); + Ok(written) + } else { + Ok(buf.len()) + } + } + } + } else { + io.write(buf) + } + }, + State::Expecting { io, .. } => io.write(buf), + State::Invalid => panic!("Negotiated: Invalid state") + } + } + + fn flush(&mut self) -> io::Result<()> { + match &mut self.state { + State::Completed { io, ref mut remaining } => { + while !remaining.is_empty() { + let n = io.write(remaining)?; + if n == 0 { + return Err(io::Error::new( + io::ErrorKind::WriteZero, + "Failed to write remaining buffer.")) + } + remaining.split_to(n); + } + io.flush() + }, + State::Expecting { io, .. } => io.flush(), + State::Invalid => panic!("Negotiated: Invalid state") + } + } +} + +impl AsyncWrite for Negotiated +where + TInner: AsyncWrite + AsyncRead +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + // Ensure all data has been flushed and expected negotiation messages + // have been received. + try_ready!(self.poll().map_err(Into::::into)); + // Continue with the shutdown of the underlying I/O stream. + match &mut self.state { + State::Completed { io, .. } => io.shutdown(), + State::Expecting { io, .. } => io.shutdown(), + State::Invalid => panic!("Negotiated: Invalid state") + } + } +} + +/// Error that can happen when negotiating a protocol with the remote. +#[derive(Debug)] +pub enum NegotiationError { + /// A protocol error occurred during the negotiation. + ProtocolError(ProtocolError), + + /// Protocol negotiation failed because no protocol could be agreed upon. + Failed, +} + +impl From for NegotiationError { + fn from(err: ProtocolError) -> NegotiationError { + NegotiationError::ProtocolError(err) + } +} + +impl From for NegotiationError { + fn from(err: io::Error) -> NegotiationError { + ProtocolError::from(err).into() + } +} + +impl Into for NegotiationError { + fn into(self) -> io::Error { + if let NegotiationError::ProtocolError(e) = self { + return e.into() + } + io::Error::new(io::ErrorKind::Other, self) + } +} + +impl Error for NegotiationError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + NegotiationError::ProtocolError(err) => Some(err), + _ => None, + } + } +} + +impl fmt::Display for NegotiationError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(fmt, "{}", Error::description(self)) + } +} + diff --git a/misc/multistream-select/src/protocol.rs b/misc/multistream-select/src/protocol.rs new file mode 100644 index 00000000000..4af91917510 --- /dev/null +++ b/misc/multistream-select/src/protocol.rs @@ -0,0 +1,486 @@ +// Copyright 2017 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! Multistream-select protocol messages an I/O operations for +//! constructing protocol negotiation flows. +//! +//! A protocol negotiation flow is constructed by using the +//! `Stream` and `Sink` implementations of `MessageIO` and +//! `MessageReader`. + +use bytes::{Bytes, BytesMut, BufMut}; +use crate::length_delimited::{LengthDelimited, LengthDelimitedReader}; +use futures::{prelude::*, try_ready}; +use log::trace; +use std::{io, fmt, error::Error, convert::TryFrom}; +use tokio_io::{AsyncRead, AsyncWrite}; +use unsigned_varint as uvi; + +/// The maximum number of supported protocols that can be processed. +const MAX_PROTOCOLS: usize = 1000; + +/// The maximum length (in bytes) of a protocol name. +/// +/// This limit is necessary in order to be able to unambiguously parse +/// response messages without knowledge of the corresponding request. +/// 140 comes about from 3 * 47 = 141, where 47 is the ascii/utf8 +/// encoding of the `/` character and an encoded protocol name is +/// at least 3 bytes long (uvi-length followed by `/` and `\n`). +/// Hence a protocol list response message with 47 protocols is at least +/// 141 bytes long and thus such a response cannot be mistaken for a +/// single protocol response. See `Message::decode`. +const MAX_PROTOCOL_LEN: usize = 140; + +/// The encoded form of a multistream-select 1.0.0 header message. +const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n"; +/// The encoded form of a multistream-select 2.0.0 header message. +const MSG_MULTISTREAM_2_0: &[u8] = b"/multistream/2.0.0\n"; +/// The encoded form of a multistream-select 'na' message. +const MSG_PROTOCOL_NA: &[u8] = b"na\n"; +/// The encoded form of a multistream-select 'ls' message. +const MSG_LS: &[u8] = b"ls\n"; + +/// The known multistream-select protocol versions. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Version { + /// The first and currently still the only deployed version + /// of multistream-select. + V1, + /// Draft: https://github.com/libp2p/specs/pull/95 + V2, +} + +/// A protocol (name) exchanged during protocol negotiation. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Protocol(Bytes); + +impl AsRef<[u8]> for Protocol { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +impl TryFrom for Protocol { + type Error = ProtocolError; + + fn try_from(value: Bytes) -> Result { + if !value.as_ref().starts_with(b"/") || value.len() > MAX_PROTOCOL_LEN { + return Err(ProtocolError::InvalidProtocol) + } + Ok(Protocol(value)) + } +} + +impl TryFrom<&[u8]> for Protocol { + type Error = ProtocolError; + + fn try_from(value: &[u8]) -> Result { + Self::try_from(Bytes::from(value)) + } +} + +impl fmt::Display for Protocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", String::from_utf8_lossy(&self.0)) + } +} + +/// A multistream-select protocol message. +/// +/// Multistream-select protocol messages are exchanged with the goal +/// of agreeing on a application-layer protocol to use on an I/O stream. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Message { + /// A header message identifies the multistream-select protocol + /// that the sender wishes to speak. + Header(Version), + /// A protocol message identifies a protocol request or acknowledgement. + Protocol(Protocol), + /// A message through which a peer requests the complete list of + /// supported protocols from the remote. + ListProtocols, + /// A message listing all supported protocols of a peer. + Protocols(Vec), + /// A message signaling that a requested protocol is not available. + NotAvailable, +} + +impl Message { + /// Encodes a `Message` into its byte representation. + pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> { + match self { + Message::Header(Version::V1) => { + dest.reserve(MSG_MULTISTREAM_1_0.len()); + dest.put(MSG_MULTISTREAM_1_0); + Ok(()) + } + Message::Header(Version::V2) => { + dest.reserve(MSG_MULTISTREAM_2_0.len()); + dest.put(MSG_MULTISTREAM_2_0); + Ok(()) + } + Message::Protocol(p) => { + let len = p.0.as_ref().len() + 1; // + 1 for \n + dest.reserve(len); + dest.put(p.0.as_ref()); + dest.put(&b"\n"[..]); + Ok(()) + } + Message::ListProtocols => { + dest.reserve(MSG_LS.len()); + dest.put(MSG_LS); + Ok(()) + } + Message::Protocols(ps) => { + let mut buf = uvi::encode::usize_buffer(); + let mut out_msg = Vec::from(uvi::encode::usize(ps.len(), &mut buf)); + for p in ps { + out_msg.extend(uvi::encode::usize(p.0.as_ref().len() + 1, &mut buf)); // +1 for '\n' + out_msg.extend_from_slice(p.0.as_ref()); + out_msg.push(b'\n') + } + dest.reserve(out_msg.len()); + dest.put(out_msg); + Ok(()) + } + Message::NotAvailable => { + dest.reserve(MSG_PROTOCOL_NA.len()); + dest.put(MSG_PROTOCOL_NA); + Ok(()) + } + } + } + + /// Decodes a `Message` from its byte representation. + pub fn decode(mut msg: Bytes) -> Result { + if msg == MSG_MULTISTREAM_1_0 { + return Ok(Message::Header(Version::V1)) + } + + if msg == MSG_MULTISTREAM_2_0 { + return Ok(Message::Header(Version::V2)) + } + + if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') && msg.len() <= MAX_PROTOCOL_LEN { + let p = Protocol::try_from(msg.split_to(msg.len() - 1))?; + return Ok(Message::Protocol(p)); + } + + if msg == MSG_PROTOCOL_NA { + return Ok(Message::NotAvailable); + } + + if msg == MSG_LS { + return Ok(Message::ListProtocols) + } + + // At this point, it must be a varint number of protocols, i.e. + // a `Protocols` message. + let (num_protocols, mut remaining) = uvi::decode::usize(&msg)?; + if num_protocols > MAX_PROTOCOLS { + return Err(ProtocolError::TooManyProtocols) + } + let mut protocols = Vec::with_capacity(num_protocols); + for _ in 0 .. num_protocols { + let (len, rem) = uvi::decode::usize(remaining)?; + if len == 0 || len > rem.len() || rem[len - 1] != b'\n' { + return Err(ProtocolError::InvalidMessage) + } + let p = Protocol::try_from(Bytes::from(&rem[.. len - 1]))?; + protocols.push(p); + remaining = &rem[len ..] + } + + return Ok(Message::Protocols(protocols)); + } +} + +/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s. +pub struct MessageIO { + inner: LengthDelimited, +} + +impl MessageIO { + /// Constructs a new `MessageIO` resource wrapping the given I/O stream. + pub fn new(inner: R) -> MessageIO + where + R: AsyncRead + AsyncWrite + { + Self { inner: LengthDelimited::new(inner) } + } + + /// Converts the `MessageIO` into a `MessageReader`, dropping the + /// `Message`-oriented `Sink` in favour of direct `AsyncWrite` access + /// to the underlying I/O stream. + /// + /// This is typically done if further negotiation messages are expected to be + /// received but no more messages are written, allowing the writing of + /// follow-up protocol data to commence. + pub fn into_reader(self) -> MessageReader { + MessageReader { inner: self.inner.into_reader() } + } + + /// Drops the `MessageIO` resource, yielding the underlying I/O stream + /// together with the remaining write buffer containing the protocol + /// negotiation frame data that has not yet been written to the I/O stream. + /// + /// The returned remaining write buffer may be prepended to follow-up + /// protocol data to send with a single `write`. Either way, if non-empty, + /// the write buffer _must_ eventually be written to the I/O stream + /// _before_ any follow-up data, in order for protocol negotiation to + /// complete cleanly. + /// + /// # Panics + /// + /// Panics if the read buffer is not empty, meaning that an incoming + /// protocol negotiation frame has been partially read. The read buffer + /// is guaranteed to be empty whenever [`MessageIO::poll`] returned + /// a message. + pub fn into_inner(self) -> (R, BytesMut) { + self.inner.into_inner() + } +} + +impl Sink for MessageIO +where + R: AsyncWrite, +{ + type SinkItem = Message; + type SinkError = ProtocolError; + + fn start_send(&mut self, msg: Self::SinkItem) -> StartSend { + let mut buf = BytesMut::new(); + msg.encode(&mut buf)?; + match self.inner.start_send(buf.freeze())? { + AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(msg)), + AsyncSink::Ready => Ok(AsyncSink::Ready), + } + } + + fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { + Ok(self.inner.poll_complete()?) + } + + fn close(&mut self) -> Poll<(), Self::SinkError> { + Ok(self.inner.close()?) + } +} + +impl Stream for MessageIO +where + R: AsyncRead +{ + type Item = Message; + type Error = ProtocolError; + + fn poll(&mut self) -> Poll, Self::Error> { + poll_stream(&mut self.inner) + } +} + +/// A `MessageReader` implements a `Stream` of `Message`s on an underlying +/// I/O resource combined with direct `AsyncWrite` access. +pub struct MessageReader { + inner: LengthDelimitedReader +} + +impl MessageReader { + /// Drops the `MessageReader` resource, yielding the underlying I/O stream + /// together with the remaining write buffer containing the protocol + /// negotiation frame data that has not yet been written to the I/O stream. + /// + /// The returned remaining write buffer may be prepended to follow-up + /// protocol data to send with a single `write`. Either way, if non-empty, + /// the write buffer _must_ eventually be written to the I/O stream + /// _before_ any follow-up data, in order for protocol negotiation to + /// complete cleanly. + /// + /// # Panics + /// + /// Panics if the read buffer is not empty, meaning that an incoming + /// protocol negotiation frame has been partially read. The read buffer + /// is guaranteed to be empty whenever [`MessageReader::poll`] returned + /// a message. + pub fn into_inner(self) -> (R, BytesMut) { + self.inner.into_inner() + } + + /// Returns a reference to the underlying I/O stream. + pub fn inner_ref(&self) -> &R { + self.inner.inner_ref() + } +} + +impl Stream for MessageReader +where + R: AsyncRead +{ + type Item = Message; + type Error = ProtocolError; + + fn poll(&mut self) -> Poll, Self::Error> { + poll_stream(&mut self.inner) + } +} + +impl io::Write for MessageReader +where + R: AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + self.inner.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.inner.flush() + } +} + +impl AsyncWrite for MessageReader +where + TInner: AsyncWrite +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.inner.shutdown() + } +} + +fn poll_stream(stream: &mut S) -> Poll, ProtocolError> +where + S: Stream, +{ + let msg = if let Some(msg) = try_ready!(stream.poll()) { + Message::decode(msg)? + } else { + return Ok(Async::Ready(None)) + }; + + trace!("Received message: {:?}", msg); + + Ok(Async::Ready(Some(msg))) +} + +/// A protocol error. +#[derive(Debug)] +pub enum ProtocolError { + /// I/O error. + IoError(io::Error), + + /// Received an invalid message from the remote. + InvalidMessage, + + /// A protocol (name) is invalid. + InvalidProtocol, + + /// Too many protocols have been returned by the remote. + TooManyProtocols, +} + +impl From for ProtocolError { + fn from(err: io::Error) -> ProtocolError { + ProtocolError::IoError(err) + } +} + +impl Into for ProtocolError { + fn into(self) -> io::Error { + if let ProtocolError::IoError(e) = self { + return e + } + return io::ErrorKind::InvalidData.into() + } +} + +impl From for ProtocolError { + fn from(err: uvi::decode::Error) -> ProtocolError { + Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string())) + } +} + +impl Error for ProtocolError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match *self { + ProtocolError::IoError(ref err) => Some(err), + _ => None, + } + } +} + +impl fmt::Display for ProtocolError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + ProtocolError::IoError(e) => + write!(fmt, "I/O error: {}", e), + ProtocolError::InvalidMessage => + write!(fmt, "Received an invalid message."), + ProtocolError::InvalidProtocol => + write!(fmt, "A protocol (name) is invalid."), + ProtocolError::TooManyProtocols => + write!(fmt, "Too many protocols received.") + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use quickcheck::*; + use rand::Rng; + use rand::distributions::Alphanumeric; + use std::iter; + + impl Arbitrary for Protocol { + fn arbitrary(g: &mut G) -> Protocol { + let n = g.gen_range(1, g.size()); + let p: String = iter::repeat(()) + .map(|()| g.sample(Alphanumeric)) + .take(n) + .collect(); + Protocol(Bytes::from(format!("/{}", p))) + } + } + + impl Arbitrary for Message { + fn arbitrary(g: &mut G) -> Message { + match g.gen_range(0, 5) { + 0 => Message::Header(Version::V1), + 1 => Message::NotAvailable, + 2 => Message::ListProtocols, + 3 => Message::Protocol(Protocol::arbitrary(g)), + 4 => Message::Protocols(Vec::arbitrary(g)), + _ => panic!() + } + } + } + + #[test] + fn encode_decode_message() { + fn prop(msg: Message) { + let mut buf = BytesMut::new(); + msg.encode(&mut buf).expect(&format!("Encoding message failed: {:?}", msg)); + match Message::decode(buf.freeze()) { + Ok(m) => assert_eq!(m, msg), + Err(e) => panic!("Decoding failed: {:?}", e) + } + } + quickcheck(prop as fn(_)) + } +} + diff --git a/misc/multistream-select/src/protocol/dialer.rs b/misc/multistream-select/src/protocol/dialer.rs deleted file mode 100644 index 28da191a490..00000000000 --- a/misc/multistream-select/src/protocol/dialer.rs +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2017 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Contains the `Dialer` wrapper, which allows raw communications with a listener. - -use super::*; - -use bytes::{Bytes, BytesMut}; -use crate::length_delimited::LengthDelimited; -use crate::protocol::{Request, Response, MultistreamSelectError}; -use futures::{prelude::*, sink, Async, StartSend, try_ready}; -use tokio_io::{AsyncRead, AsyncWrite}; -use std::marker; -use unsigned_varint as uvi; - -/// The maximum number of supported protocols that can be processed. -const MAX_PROTOCOLS: usize = 1000; - -/// Wraps around a `AsyncRead+AsyncWrite`. -/// Assumes that we're on the dialer's side. Produces and accepts messages. -pub struct Dialer { - inner: LengthDelimited, - handshake_finished: bool, - _protocol_name: marker::PhantomData, -} - -impl Dialer -where - R: AsyncRead + AsyncWrite, - N: AsRef<[u8]> -{ - pub fn dial(inner: R) -> DialerFuture { - let io = LengthDelimited::new(inner); - let mut buf = BytesMut::new(); - Header::Multistream10.encode(&mut buf); - DialerFuture { - inner: io.send(buf.freeze()), - _protocol_name: marker::PhantomData, - } - } - - /// Grants back the socket. Typically used after a `ProtocolAck` has been received. - pub fn into_inner(self) -> R { - self.inner.into_inner() - } -} - -impl Sink for Dialer -where - R: AsyncRead + AsyncWrite, - N: AsRef<[u8]> -{ - type SinkItem = Request; - type SinkError = MultistreamSelectError; - - fn start_send(&mut self, request: Self::SinkItem) -> StartSend { - let mut msg = BytesMut::new(); - request.encode(&mut msg)?; - match self.inner.start_send(msg.freeze())? { - AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(request)), - AsyncSink::Ready => Ok(AsyncSink::Ready), - } - } - - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - Ok(self.inner.poll_complete()?) - } - - fn close(&mut self) -> Poll<(), Self::SinkError> { - Ok(self.inner.close()?) - } -} - -impl Stream for Dialer -where - R: AsyncRead + AsyncWrite -{ - type Item = Response; - type Error = MultistreamSelectError; - - fn poll(&mut self) -> Poll, Self::Error> { - loop { - let mut msg = match self.inner.poll() { - Ok(Async::Ready(Some(msg))) => msg, - Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => return Err(err.into()), - }; - - if !self.handshake_finished { - if msg == MSG_MULTISTREAM_1_0 { - self.handshake_finished = true; - continue; - } else { - return Err(MultistreamSelectError::FailedHandshake); - } - } - - if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') { - let len = msg.len(); - let name = msg.split_to(len - 1); - return Ok(Async::Ready(Some( - Response::Protocol { name } - ))); - } else if msg == MSG_PROTOCOL_NA { - return Ok(Async::Ready(Some(Response::ProtocolNotAvailable))); - } else { - // A varint number of protocols - let (num_protocols, mut remaining) = uvi::decode::usize(&msg)?; - if num_protocols > MAX_PROTOCOLS { // TODO: configurable limit - return Err(MultistreamSelectError::TooManyProtocols) - } - let mut protocols = Vec::with_capacity(num_protocols); - for _ in 0 .. num_protocols { - let (len, rem) = uvi::decode::usize(remaining)?; - if len == 0 || len > rem.len() || rem[len - 1] != b'\n' { - return Err(MultistreamSelectError::UnknownMessage) - } - protocols.push(Bytes::from(&rem[.. len - 1])); - remaining = &rem[len ..] - } - return Ok(Async::Ready(Some( - Response::SupportedProtocols { protocols }, - ))); - } - } - } -} - -/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`. -pub struct DialerFuture> { - inner: sink::Send>, - _protocol_name: marker::PhantomData, -} - -impl> Future for DialerFuture { - type Item = Dialer; - type Error = MultistreamSelectError; - - fn poll(&mut self) -> Poll { - let inner = try_ready!(self.inner.poll()); - Ok(Async::Ready(Dialer { - inner, - handshake_finished: false, - _protocol_name: marker::PhantomData, - })) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::runtime::current_thread::Runtime; - use tokio_tcp::{TcpListener, TcpStream}; - use futures::Future; - use futures::{Sink, Stream}; - - #[test] - fn wrong_proto_name() { - let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); - let listener_addr = listener.local_addr().unwrap(); - - let server = listener - .incoming() - .into_future() - .map(|_| ()) - .map_err(|(e, _)| e.into()); - - let client = TcpStream::connect(&listener_addr) - .from_err() - .and_then(move |stream| Dialer::dial(stream)) - .and_then(move |dialer| { - let name = b"invalid_name"; - dialer.send(Request::Protocol { name }) - }); - - let mut rt = Runtime::new().unwrap(); - match rt.block_on(server.join(client)) { - Err(MultistreamSelectError::InvalidProtocolName) => (), - _ => panic!(), - } - } -} diff --git a/misc/multistream-select/src/protocol/error.rs b/misc/multistream-select/src/protocol/error.rs deleted file mode 100644 index f6686ee9fbd..00000000000 --- a/misc/multistream-select/src/protocol/error.rs +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2017 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Contains the error structs for the low-level protocol handling. - -use std::error::Error; -use std::fmt; -use std::io; -use unsigned_varint::decode; - -/// Error at the multistream-select layer of communication. -#[derive(Debug)] -pub enum MultistreamSelectError { - /// I/O error. - IoError(io::Error), - - /// The remote doesn't use the same multistream-select protocol as we do. - FailedHandshake, - - /// Received an unknown message from the remote. - UnknownMessage, - - /// Protocol names must always start with `/`, otherwise this error is returned. - InvalidProtocolName, - - /// Too many protocols have been returned by the remote. - TooManyProtocols, -} - -impl From for MultistreamSelectError { - fn from(err: io::Error) -> MultistreamSelectError { - MultistreamSelectError::IoError(err) - } -} - -impl From for MultistreamSelectError { - fn from(err: decode::Error) -> MultistreamSelectError { - Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string())) - } -} - -impl Error for MultistreamSelectError { - fn description(&self) -> &str { - match *self { - MultistreamSelectError::IoError(_) => "I/O error", - MultistreamSelectError::FailedHandshake => { - "the remote doesn't use the same multistream-select protocol as we do" - } - MultistreamSelectError::UnknownMessage => "received an unknown message from the remote", - MultistreamSelectError::InvalidProtocolName => { - "protocol names must always start with `/`, otherwise this error is returned" - } - MultistreamSelectError::TooManyProtocols => - "Too many protocols." - } - } - - fn source(&self) -> Option<&(dyn Error + 'static)> { - match *self { - MultistreamSelectError::IoError(ref err) => Some(err), - _ => None, - } - } -} - -impl fmt::Display for MultistreamSelectError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - write!(fmt, "{}", Error::description(self)) - } -} diff --git a/misc/multistream-select/src/protocol/listener.rs b/misc/multistream-select/src/protocol/listener.rs deleted file mode 100644 index 243304edcff..00000000000 --- a/misc/multistream-select/src/protocol/listener.rs +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright 2017 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Contains the `Listener` wrapper, which allows raw communications with a dialer. - -use super::*; - -use bytes::{Bytes, BytesMut}; -use crate::length_delimited::LengthDelimited; -use crate::protocol::{Request, Response, MultistreamSelectError}; -use futures::{prelude::*, sink, stream::StreamFuture}; -use log::{debug, trace}; -use std::{marker, mem}; -use tokio_io::{AsyncRead, AsyncWrite}; - -/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and -/// accepts messages. -pub struct Listener { - inner: LengthDelimited, - _protocol_name: marker::PhantomData, -} - -impl Listener -where - R: AsyncRead + AsyncWrite, - N: AsRef<[u8]> -{ - /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the - /// future returns a `Listener`. - pub fn listen(inner: R) -> ListenerFuture { - let inner = LengthDelimited::new(inner); - ListenerFuture { - inner: ListenerFutureState::Await { inner: inner.into_future() }, - _protocol_name: marker::PhantomData, - } - } - - /// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a - /// `ProtocolAck` has been sent back. - pub fn into_inner(self) -> R { - self.inner.into_inner() - } -} - -impl Sink for Listener -where - R: AsyncRead + AsyncWrite, - N: AsRef<[u8]> -{ - type SinkItem = Response; - type SinkError = MultistreamSelectError; - - fn start_send(&mut self, response: Self::SinkItem) -> StartSend { - let mut msg = BytesMut::new(); - response.encode(&mut msg)?; - match self.inner.start_send(msg.freeze())? { - AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(response)), - AsyncSink::Ready => Ok(AsyncSink::Ready) - } - } - - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - Ok(self.inner.poll_complete()?) - } - - fn close(&mut self) -> Poll<(), Self::SinkError> { - Ok(self.inner.close()?) - } -} - -impl Stream for Listener -where - R: AsyncRead + AsyncWrite, -{ - type Item = Request; - type Error = MultistreamSelectError; - - fn poll(&mut self) -> Poll, Self::Error> { - let mut msg = match self.inner.poll() { - Ok(Async::Ready(Some(msg))) => msg, - Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => return Err(err.into()), - }; - - if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') { - let len = msg.len(); - let name = msg.split_to(len - 1); - Ok(Async::Ready(Some( - Request::Protocol { name }, - ))) - } else if msg == MSG_LS { - Ok(Async::Ready(Some( - Request::ListProtocols, - ))) - } else { - Err(MultistreamSelectError::UnknownMessage) - } - } -} - - -/// Future, returned by `Listener::new` which performs the handshake and returns -/// the `Listener` if successful. -pub struct ListenerFuture { - inner: ListenerFutureState, - _protocol_name: marker::PhantomData, -} - -enum ListenerFutureState { - Await { - inner: StreamFuture> - }, - Reply { - sender: sink::Send> - }, - Undefined -} - -impl> Future for ListenerFuture { - type Item = Listener; - type Error = MultistreamSelectError; - - fn poll(&mut self) -> Poll { - loop { - match mem::replace(&mut self.inner, ListenerFutureState::Undefined) { - ListenerFutureState::Await { mut inner } => { - let (msg, socket) = - match inner.poll() { - Ok(Async::Ready(x)) => x, - Ok(Async::NotReady) => { - self.inner = ListenerFutureState::Await { inner }; - return Ok(Async::NotReady) - } - Err((e, _)) => return Err(MultistreamSelectError::from(e)) - }; - if msg.as_ref().map(|b| &b[..]) != Some(MSG_MULTISTREAM_1_0) { - debug!("Unexpected message: {:?}", msg); - return Err(MultistreamSelectError::FailedHandshake) - } - trace!("sending back /multistream/ to finish the handshake"); - let mut frame = BytesMut::new(); - Header::Multistream10.encode(&mut frame); - let sender = socket.send(frame.freeze()); - self.inner = ListenerFutureState::Reply { sender } - } - ListenerFutureState::Reply { mut sender } => { - let listener = match sender.poll()? { - Async::Ready(x) => x, - Async::NotReady => { - self.inner = ListenerFutureState::Reply { sender }; - return Ok(Async::NotReady) - } - }; - return Ok(Async::Ready(Listener { - inner: listener, - _protocol_name: marker::PhantomData - })) - } - ListenerFutureState::Undefined => - panic!("ListenerFutureState::poll called after completion") - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::runtime::current_thread::Runtime; - use tokio_tcp::{TcpListener, TcpStream}; - use bytes::Bytes; - use futures::Future; - use futures::{Sink, Stream}; - - #[test] - fn wrong_proto_name() { - let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); - let listener_addr = listener.local_addr().unwrap(); - - let server = listener - .incoming() - .into_future() - .map_err(|(e, _)| e.into()) - .and_then(move |(connec, _)| Listener::listen(connec.unwrap())) - .and_then(|listener| { - let name = Bytes::from("invalid-proto"); - listener.send(Response::Protocol { name }) - }); - - let client = TcpStream::connect(&listener_addr) - .from_err() - .and_then(move |stream| Dialer::<_, Bytes>::dial(stream)); - - let mut rt = Runtime::new().unwrap(); - match rt.block_on(server.join(client)) { - Err(MultistreamSelectError::InvalidProtocolName) => (), - _ => panic!(), - } - } -} diff --git a/misc/multistream-select/src/protocol/mod.rs b/misc/multistream-select/src/protocol/mod.rs deleted file mode 100644 index 5b1fca7153b..00000000000 --- a/misc/multistream-select/src/protocol/mod.rs +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2017 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Contains lower-level structs to handle the multistream protocol. - -const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n"; -const MSG_PROTOCOL_NA: &[u8] = b"na\n"; -const MSG_LS: &[u8] = b"ls\n"; - -mod dialer; -mod error; -mod listener; - -pub use self::dialer::{Dialer, DialerFuture}; -pub use self::error::MultistreamSelectError; -pub use self::listener::{Listener, ListenerFuture}; - -use bytes::{BytesMut, BufMut}; -use unsigned_varint as uvi; - -pub enum Header { - Multistream10 -} - -impl Header { - fn encode(&self, dest: &mut BytesMut) { - match self { - Header::Multistream10 => { - dest.reserve(MSG_MULTISTREAM_1_0.len()); - dest.put(MSG_MULTISTREAM_1_0); - } - } - } -} - -/// Message sent from the dialer to the listener. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Request { - /// The dialer wants us to use a protocol. - /// - /// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start - /// communicating in the new protocol. - Protocol { - /// Name of the protocol. - name: N - }, - - /// The dialer requested the list of protocols that the listener supports. - ListProtocols, -} - -impl> Request { - fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> { - match self { - Request::Protocol { name } => { - if !name.as_ref().starts_with(b"/") { - return Err(MultistreamSelectError::InvalidProtocolName) - } - let len = name.as_ref().len() + 1; // + 1 for \n - dest.reserve(len); - dest.put(name.as_ref()); - dest.put(&b"\n"[..]); - Ok(()) - } - Request::ListProtocols => { - dest.reserve(MSG_LS.len()); - dest.put(MSG_LS); - Ok(()) - } - } - } -} - - -/// Message sent from the listener to the dialer. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Response { - /// The protocol requested by the dialer is accepted. The socket immediately starts using the - /// new protocol. - Protocol { name: N }, - - /// The protocol requested by the dialer is not supported or available. - ProtocolNotAvailable, - - /// Response to the request for the list of protocols. - SupportedProtocols { - /// The list of protocols. - // TODO: use some sort of iterator - protocols: Vec, - }, -} - -impl> Response { - fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> { - match self { - Response::Protocol { name } => { - if !name.as_ref().starts_with(b"/") { - return Err(MultistreamSelectError::InvalidProtocolName) - } - let len = name.as_ref().len() + 1; // + 1 for \n - dest.reserve(len); - dest.put(name.as_ref()); - dest.put(&b"\n"[..]); - Ok(()) - } - Response::SupportedProtocols { protocols } => { - let mut buf = uvi::encode::usize_buffer(); - let mut out_msg = Vec::from(uvi::encode::usize(protocols.len(), &mut buf)); - for p in protocols { - out_msg.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n' - out_msg.extend_from_slice(p.as_ref()); - out_msg.push(b'\n') - } - dest.reserve(out_msg.len()); - dest.put(out_msg); - Ok(()) - } - Response::ProtocolNotAvailable => { - dest.reserve(MSG_PROTOCOL_NA.len()); - dest.put(MSG_PROTOCOL_NA); - Ok(()) - } - } - } -} - - diff --git a/misc/multistream-select/src/tests.rs b/misc/multistream-select/src/tests.rs index dbfc0588d7e..95e7c151849 100644 --- a/misc/multistream-select/src/tests.rs +++ b/misc/multistream-select/src/tests.rs @@ -22,65 +22,13 @@ #![cfg(test)] -use crate::ProtocolChoiceError; +use crate::NegotiationError; use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial}; -use crate::protocol::{Dialer, Request, Listener, Response}; use crate::{dialer_select_proto, listener_select_proto}; use futures::prelude::*; use tokio::runtime::current_thread::Runtime; use tokio_tcp::{TcpListener, TcpStream}; - -/// Holds a `Vec` and satifies the iterator requirements of `listener_select_proto`. -struct VecRefIntoIter(Vec); - -impl<'a, T> IntoIterator for &'a VecRefIntoIter -where T: Clone -{ - type Item = T; - type IntoIter = std::vec::IntoIter; - fn into_iter(self) -> Self::IntoIter { - self.0.clone().into_iter() - } -} - -#[test] -fn negotiate_with_self_succeeds() { - let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); - let listener_addr = listener.local_addr().unwrap(); - - let server = listener - .incoming() - .into_future() - .map_err(|(e, _)| e.into()) - .and_then(move |(connec, _)| Listener::listen(connec.unwrap())) - .and_then(|l| l.into_future().map_err(|(e, _)| e)) - .and_then(|(msg, rest)| { - let proto = match msg { - Some(Request::Protocol { name }) => name, - _ => panic!(), - }; - rest.send(Response::Protocol { name: proto }) - }); - - let client = TcpStream::connect(&listener_addr) - .from_err() - .and_then(move |stream| Dialer::dial(stream)) - .and_then(move |dialer| { - let name = b"/hello/1.0.0"; - dialer.send(Request::Protocol { name }) - }) - .and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e)) - .and_then(move |(msg, _)| { - let proto = match msg { - Some(Response::Protocol { name }) => name, - _ => panic!(), - }; - assert_eq!(proto, "/hello/1.0.0"); - Ok(()) - }); - let mut rt = Runtime::new().unwrap(); - let _ = rt.block_on(server.join(client)).unwrap(); -} +use tokio_io::io as nio; #[test] fn select_proto_basic() { @@ -94,18 +42,32 @@ fn select_proto_basic() { .map_err(|(e, _)| e.into()) .and_then(move |connec| { let protos = vec![b"/proto1", b"/proto2"]; - listener_select_proto(connec, VecRefIntoIter(protos)).map(|r| r.0) + listener_select_proto(connec, protos) + }) + .and_then(|(proto, io)| { + nio::write_all(io, b"pong").from_err().map(move |_| proto) }); let client = TcpStream::connect(&listener_addr) .from_err() .and_then(move |connec| { let protos = vec![b"/proto3", b"/proto2"]; - dialer_select_proto(connec, protos).map(|r| r.0) + dialer_select_proto(connec, protos) + }) + .and_then(|(proto, io)| { + nio::write_all(io, b"ping").from_err().map(move |(io, _)| (proto, io)) + }) + .and_then(|(proto, io)| { + nio::read_exact(io, [0; 4]).from_err().map(move |(_, msg)| { + assert_eq!(&msg, b"pong"); + proto + }) }); + let mut rt = Runtime::new().unwrap(); let (dialer_chosen, listener_chosen) = rt.block_on(client.join(server)).unwrap(); + assert_eq!(dialer_chosen, b"/proto2"); assert_eq!(listener_chosen, b"/proto2"); } @@ -122,19 +84,22 @@ fn no_protocol_found() { .map_err(|(e, _)| e.into()) .and_then(move |connec| { let protos = vec![b"/proto1", b"/proto2"]; - listener_select_proto(connec, VecRefIntoIter(protos)).map(|r| r.0) - }); + listener_select_proto(connec, protos) + }) + .and_then(|(proto, io)| io.complete().map(move |_| proto)); let client = TcpStream::connect(&listener_addr) .from_err() .and_then(move |connec| { let protos = vec![b"/proto3", b"/proto4"]; - dialer_select_proto(connec, protos).map(|r| r.0) - }); + dialer_select_proto(connec, protos) + }) + .and_then(|(proto, io)| io.complete().map(move |_| proto)); + let mut rt = Runtime::new().unwrap(); match rt.block_on(client.join(server)) { - Err(ProtocolChoiceError::NoProtocolFound) => (), - _ => panic!(), + Err(NegotiationError::Failed) => (), + e => panic!("{:?}", e), } } @@ -150,19 +115,22 @@ fn select_proto_parallel() { .map_err(|(e, _)| e.into()) .and_then(move |connec| { let protos = vec![b"/proto1", b"/proto2"]; - listener_select_proto(connec, VecRefIntoIter(protos)).map(|r| r.0) - }); + listener_select_proto(connec, protos) + }) + .and_then(|(proto, io)| io.complete().map(move |_| proto)); let client = TcpStream::connect(&listener_addr) .from_err() .and_then(move |connec| { let protos = vec![b"/proto3", b"/proto2"]; - dialer_select_proto_parallel(connec, protos.into_iter()).map(|r| r.0) - }); + dialer_select_proto_parallel(connec, protos.into_iter()) + }) + .and_then(|(proto, io)| io.complete().map(move |_| proto)); let mut rt = Runtime::new().unwrap(); let (dialer_chosen, listener_chosen) = rt.block_on(client.join(server)).unwrap(); + assert_eq!(dialer_chosen, b"/proto2"); assert_eq!(listener_chosen, b"/proto2"); } @@ -179,19 +147,22 @@ fn select_proto_serial() { .map_err(|(e, _)| e.into()) .and_then(move |connec| { let protos = vec![b"/proto1", b"/proto2"]; - listener_select_proto(connec, VecRefIntoIter(protos)).map(|r| r.0) - }); + listener_select_proto(connec, protos) + }) + .and_then(|(proto, io)| io.complete().map(move |_| proto)); let client = TcpStream::connect(&listener_addr) .from_err() .and_then(move |connec| { let protos = vec![b"/proto3", b"/proto2"]; - dialer_select_proto_serial(connec, protos.into_iter()).map(|r| r.0) - }); + dialer_select_proto_serial(connec, protos.into_iter()) + }) + .and_then(|(proto, io)| io.complete().map(move |_| proto)); let mut rt = Runtime::new().unwrap(); let (dialer_chosen, listener_chosen) = rt.block_on(client.join(server)).unwrap(); + assert_eq!(dialer_chosen, b"/proto2"); assert_eq!(listener_chosen, b"/proto2"); } diff --git a/muxers/mplex/src/lib.rs b/muxers/mplex/src/lib.rs index a1fe1636d6d..8806b031551 100644 --- a/muxers/mplex/src/lib.rs +++ b/muxers/mplex/src/lib.rs @@ -27,7 +27,7 @@ use bytes::Bytes; use libp2p_core::{ Endpoint, StreamMuxer, - upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, Negotiated} + upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, Negotiated}, }; use log::{debug, trace}; use parking_lot::Mutex; diff --git a/protocols/deflate/tests/test.rs b/protocols/deflate/tests/test.rs index 0abed2857d6..7ea9b116570 100644 --- a/protocols/deflate/tests/test.rs +++ b/protocols/deflate/tests/test.rs @@ -85,7 +85,7 @@ where .unwrap() .map_err(|e| panic!("client error: {}", e)) .and_then(move |server| { - io::write_all(server, message2).and_then(|(client, _)| io::flush(client)) + io::write_all(server, message2).and_then(|(client, _)| io::shutdown(client)) }) .map(|_| ()); diff --git a/protocols/floodsub/src/protocol.rs b/protocols/floodsub/src/protocol.rs index 532a0f88f80..e6951321dc5 100644 --- a/protocols/floodsub/src/protocol.rs +++ b/protocols/floodsub/src/protocol.rs @@ -49,7 +49,7 @@ impl UpgradeInfo for FloodsubConfig { impl InboundUpgrade for FloodsubConfig where - TSocket: AsyncRead, + TSocket: AsyncRead + AsyncWrite, { type Output = FloodsubRpc; type Error = FloodsubDecodeError; @@ -164,7 +164,7 @@ impl UpgradeInfo for FloodsubRpc { impl OutboundUpgrade for FloodsubRpc where - TSocket: AsyncWrite, + TSocket: AsyncWrite + AsyncRead, { type Output = (); type Error = io::Error; diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index 320483cb81a..71e5ab7ef20 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -35,6 +35,7 @@ use libp2p_core::{ either::EitherOutput, upgrade::{self, InboundUpgrade, OutboundUpgrade, Negotiated} }; +use log::trace; use multihash::Multihash; use std::{borrow::Cow, error, fmt, io, time::Duration}; use tokio_io::{AsyncRead, AsyncWrite}; @@ -80,7 +81,6 @@ where KadRequestMsg, Option, ), - /// Waiting to send a message to the remote. /// Waiting to flush the substream so that the data arrives to the remote. OutPendingFlush(KadOutStreamSink, Option), /// Waiting for an answer back from the remote. @@ -830,7 +830,14 @@ where None, false, ), - Ok(Async::Ready(None)) | Err(_) => (None, None, false), + Ok(Async::Ready(None)) => { + trace!("Inbound substream: EOF"); + (None, None, false) + } + Err(e) => { + trace!("Inbound substream error: {:?}", e); + (None, None, false) + }, }, SubstreamState::InWaitingUser(id, substream) => ( Some(SubstreamState::InWaitingUser(id, substream)),