Skip to content
This repository has been archived by the owner on Nov 15, 2023. It is now read-only.

network: Detect early that NotificationOutSubstream was closed by the remote #13396

Merged
merged 5 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions client/network/src/protocol/notifications/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,9 @@ impl ConnectionHandler for NotifsHandler {
// performed before the code paths that can produce `Ready` (with some rare exceptions).
// Importantly, however, the flush is performed *after* notifications are queued with
// `Sink::start_send`.
// Note that we must call `poll_flush` on all substreams and not only on those we
// have called `Sink::start_send` on, because `NotificationsOutSubstream::poll_flush`
// also reports the substream termination (even if no data was written into it).
for protocol_index in 0..self.protocols.len() {
match &mut self.protocols[protocol_index].state {
State::Open { out_substream: out_substream @ Some(_), .. } => {
Expand Down Expand Up @@ -824,7 +827,7 @@ impl ConnectionHandler for NotifsHandler {
State::OpenDesiredByRemote { in_substream, pending_opening } =>
match NotificationsInSubstream::poll_process(Pin::new(in_substream), cx) {
Poll::Pending => {},
Poll::Ready(Ok(void)) => match void {},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(_)) => {
self.protocols[protocol_index].state =
State::Closed { pending_opening: *pending_opening };
Expand All @@ -840,7 +843,7 @@ impl ConnectionHandler for NotifsHandler {
cx,
) {
Poll::Pending => {},
Poll::Ready(Ok(void)) => match void {},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(_)) => *in_substream = None,
},
}
Expand Down
132 changes: 122 additions & 10 deletions client/network/src/protocol/notifications/upgrade/notifications.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
use log::{error, warn};
use sc_network_common::protocol::ProtocolName;
use std::{
convert::Infallible,
io, mem,
pin::Pin,
task::{Context, Poll},
Expand Down Expand Up @@ -221,10 +220,7 @@ where

/// Equivalent to `Stream::poll_next`, except that it only drives the handshake and is
/// guaranteed to not generate any notification.
pub fn poll_process(
self: Pin<&mut Self>,
cx: &mut Context,
) -> Poll<Result<Infallible, io::Error>> {
pub fn poll_process(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let mut this = self.project();

loop {
Expand All @@ -246,8 +242,10 @@ where
},
NotificationsInSubstreamHandshake::Flush => {
match Sink::poll_flush(this.socket.as_mut(), cx)? {
Poll::Ready(()) =>
*this.handshake = NotificationsInSubstreamHandshake::Sent,
Poll::Ready(()) => {
*this.handshake = NotificationsInSubstreamHandshake::Sent;
return Poll::Ready(Ok(()))
},
Poll::Pending => {
*this.handshake = NotificationsInSubstreamHandshake::Flush;
return Poll::Pending
Expand All @@ -260,7 +258,7 @@ where
st @ NotificationsInSubstreamHandshake::ClosingInResponseToRemote |
st @ NotificationsInSubstreamHandshake::BothSidesClosed => {
*this.handshake = st;
return Poll::Pending
return Poll::Ready(Ok(()))
},
}
}
Expand Down Expand Up @@ -443,6 +441,21 @@ where

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
let mut this = self.project();

// `Sink::poll_flush` does not expose stream closed error until we write something into
// the stream, so the code below makes sure we detect that the substream was closed
// even if we don't write anything into it.
match Stream::poll_next(this.socket.as_mut(), cx) {
Poll::Pending => {},
Poll::Ready(Some(_)) => {
error!(
target: "sub-libp2p",
"Unexpected incoming data in `NotificationsOutSubstream`",
);
},
Poll::Ready(None) => return Poll::Ready(Err(NotificationsOutError::Terminated)),
}

Sink::poll_flush(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io)
}

Expand Down Expand Up @@ -492,13 +505,21 @@ pub enum NotificationsOutError {
/// I/O error on the substream.
#[error(transparent)]
Io(#[from] io::Error),

/// End of incoming data detected on out substream.
#[error("substream was closed/reset")]
Terminated,
}

#[cfg(test)]
mod tests {
use super::{NotificationsIn, NotificationsInOpen, NotificationsOut, NotificationsOutOpen};
use futures::{channel::oneshot, prelude::*};
use super::{
NotificationsIn, NotificationsInOpen, NotificationsOut, NotificationsOutError,
NotificationsOutOpen,
};
use futures::{channel::oneshot, future, prelude::*};
use libp2p::core::upgrade;
use std::{pin::Pin, task::Poll};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::compat::TokioAsyncReadCompatExt;

Expand Down Expand Up @@ -691,4 +712,95 @@ mod tests {

client.await.unwrap();
}

#[tokio::test]
async fn send_handshake_without_polling_for_incoming_data() {
const PROTO_NAME: &str = "/test/proto/1";
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

let client = tokio::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let NotificationsOutOpen { handshake, .. } = upgrade::apply_outbound(
socket.compat(),
NotificationsOut::new(PROTO_NAME, Vec::new(), &b"initial message"[..], 1024 * 1024),
upgrade::Version::V1,
)
.await
.unwrap();

assert_eq!(handshake, b"hello world");
});

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

let (socket, _) = listener.accept().await.unwrap();
let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound(
socket.compat(),
NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
)
.await
.unwrap();

assert_eq!(handshake, b"initial message");
substream.send_handshake(&b"hello world"[..]);

// Actually send the handshake.
future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();

client.await.unwrap();
}

#[tokio::test]
async fn can_detect_dropped_out_substream_without_writing_data() {
const PROTO_NAME: &str = "/test/proto/1";
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

let client = tokio::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let NotificationsOutOpen { handshake, mut substream, .. } = upgrade::apply_outbound(
socket.compat(),
NotificationsOut::new(PROTO_NAME, Vec::new(), &b"initial message"[..], 1024 * 1024),
upgrade::Version::V1,
)
.await
.unwrap();

assert_eq!(handshake, b"hello world");

future::poll_fn(|cx| match Pin::new(&mut substream).poll_flush(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => {
cx.waker().wake_by_ref();
Poll::Pending
},
Poll::Ready(Err(e)) => {
assert!(matches!(e, NotificationsOutError::Terminated));
Poll::Ready(())
},
})
.await;
});

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

let (socket, _) = listener.accept().await.unwrap();
let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound(
socket.compat(),
NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
)
.await
.unwrap();

assert_eq!(handshake, b"initial message");

// Send the handhsake.
substream.send_handshake(&b"hello world"[..]);
future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();

drop(substream);

client.await.unwrap();
}
}