Skip to content

Commit

Permalink
fix: make poll_shutdown and poll_flush more robust (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac authored Nov 15, 2023
1 parent d9d39f2 commit 535de16
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 23 deletions.
46 changes: 32 additions & 14 deletions src/connection_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,12 @@ impl ConnectionStream {
return Poll::Ready(Ok(0));
}

// Writes after shutdown return NotConnected
if self.wants_close_sent {
self.wr_waker.take();
return Poll::Ready(Err(ErrorKind::NotConnected.into()));
}

self.poll_perform_write(cx, false, |tls| {
let n = tls.writer().write(buf).expect("Write will never fail");
trace!("w={n}");
Expand All @@ -346,6 +352,12 @@ impl ConnectionStream {
return Poll::Ready(Ok(0));
}

// Writes after shutdown return NotConnected
if self.wants_close_sent {
self.wr_waker.take();
return Poll::Ready(Err(ErrorKind::NotConnected.into()));
}

self.poll_perform_write(cx, false, |tls| {
// TODO(mmastrac): we should manually write individual bufs here as rustls is not optimal if internal buffers are full
let n = tls
Expand All @@ -366,12 +378,6 @@ impl ConnectionStream {
flushing: bool,
f: impl Fn(&mut Connection) -> T,
) -> Poll<io::Result<T>> {
// Writes after shutdown return NotConnected
if !flushing && self.wants_close_sent {
self.wr_waker.take();
return Poll::Ready(Err(ErrorKind::NotConnected.into()));
}

// First prepare to write
let res = loop {
let write = self.poll_write_only(PollContext::Explicit(cx));
Expand All @@ -388,12 +394,18 @@ impl ConnectionStream {
StreamProgress::NoInterest => {
// Write it
let n = f(&mut self.tls);
// Drain what we can
while self.poll_write_only(PollContext::Explicit(cx))
== StreamProgress::MadeProgress
{}
// And then return what we wrote
break Poll::Ready(Ok(n));
// Drain what we can, and then return what we wrote
break loop {
break match self.poll_write_only(PollContext::Explicit(cx)) {
StreamProgress::MadeProgress => continue,
StreamProgress::Error => {
Poll::Ready(Err(self.wr_error.unwrap().into()))
}
StreamProgress::RegisteredWaker if flushing => Poll::Pending,
StreamProgress::RegisteredWaker => Poll::Ready(Ok(n)),
StreamProgress::NoInterest => Poll::Ready(Ok(n)),
};
};
}
};
};
Expand Down Expand Up @@ -443,11 +455,15 @@ impl ConnectionStream {
// Immediate state change so we can error writes
self.wants_close_sent = true;
if !self.close_sent {
ready!(self.poll_flush(cx))?;
trace!("sending CloseNotify");
self.tls.send_close_notify();
self.close_sent = true;
}
ready!(self.poll_flush(cx))?;

// Don't shut down until we've flushed CloseNotify
ready!(self.poll_flush(cx)?);
debug_assert!(!self.tls.wants_write());

// Note that this is not technically an async call
// TODO(mmastrac): This is currently untested
let tcp_ref: &TcpStream = &self.tcp;
Expand All @@ -456,6 +472,8 @@ impl ConnectionStream {
let mut tcp_ptr = unsafe {
NonNull::new(tcp_ref as *const _ as *mut TcpStream).unwrap_unchecked()
};

trace!("poll_shutdown complete");
// SAFETY: We know that poll_shutdown never uses a mutable reference here
_ = Pin::new(unsafe { tcp_ptr.as_mut() }).poll_shutdown(cx);
Poll::Ready(Ok(()))
Expand Down
2 changes: 1 addition & 1 deletion src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async fn handshake_task_internal(
}
}

// return Err(err);
trace!("handshake error = {err:?}");

// This is a bit of sleight-of-hand: if the handshake fails to write because the other side is gone
// or otherwise errors, _BUT_ writing takes us out of handshaking mode, we treat this as a successful
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ macro_rules! trace {
($($args:expr),+) => {
if cfg!(feature="trace")
{
print!("[{:?}] ", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis());
println!($($args),+);
}
};
Expand Down
71 changes: 63 additions & 8 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::ready;
use std::thread::sleep;
use std::time::Duration;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::ReadBuf;
use tokio::net::TcpStream;
use tokio::spawn;
use tokio::sync::watch;
use tokio::task::spawn_blocking;
use tokio::task::JoinError;
use tokio::task::JoinHandle;

Expand Down Expand Up @@ -357,6 +360,26 @@ impl TlsStream {
}
}

pub fn linger(&self) -> Result<Option<Duration>, io::Error> {
match &self.state {
TlsStreamState::Open(stm) => stm.tcp_stream().linger(),
TlsStreamState::Handshaking { tcp, .. } => tcp.linger(),
TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
Err(std::io::ErrorKind::NotConnected.into())
}
}
}

pub fn set_linger(&self, dur: Option<Duration>) -> Result<(), io::Error> {
match &self.state {
TlsStreamState::Open(stm) => stm.tcp_stream().set_linger(dur),
TlsStreamState::Handshaking { tcp, .. } => tcp.set_linger(dur),
TlsStreamState::Closed | TlsStreamState::ClosedError(_) => {
Err(std::io::ErrorKind::NotConnected.into())
}
}
}

/// Returns the peer address of this socket.
pub fn peer_addr(&self) -> Result<std::net::SocketAddr, io::Error> {
match &self.state {
Expand Down Expand Up @@ -499,8 +522,8 @@ impl TlsStream {
}

pub async fn close(mut self) -> io::Result<()> {
let state = std::mem::replace(&mut self.state, TlsStreamState::Closed);
trace!("closing {self:?}");
let state = std::mem::replace(&mut self.state, TlsStreamState::Closed);
match state {
TlsStreamState::Handshaking {
handle,
Expand Down Expand Up @@ -800,18 +823,34 @@ impl Drop for TlsStream {
let state = std::mem::replace(&mut self.state, TlsStreamState::Closed);
match state {
TlsStreamState::Handshaking {
handle, write_buf, ..
handle,
write_buf,
tcp,
..
} => {
spawn(async move {
trace!("in drop task");
match handle.await {
Ok(Ok(result)) => {
drop(tcp);
// TODO(mmastrac): if we split ConnectionStream we can remove this Arc and use reclaim2
let (tcp, tls) = result.into_inner();
let mut stm = ConnectionStream::new(tcp, tls);
stm.write_buf_fully(&write_buf);
let res = poll_fn(|cx| stm.poll_shutdown(cx)).await;
trace!("{:?}", res);
trace!("shutdown handshake {:?}", res);
let (tcp, _) = stm.into_inner();
if let Ok(tcp) = tcp.into_std() {
spawn_blocking(move || {
// TODO(mmastrac): this should not be necessary with SO_LINGER but I cannot get that working
trace!("in drop tcp task");
// Drop the TCP stream here just in case close() blocks
_ = tcp.set_nonblocking(false);
sleep(Duration::from_secs(1));
drop(tcp);
trace!("done drop tcp task");
});
}
}
x @ Err(_) => {
trace!("{x:?}");
Expand All @@ -827,7 +866,18 @@ impl Drop for TlsStream {
spawn(async move {
trace!("in drop task");
let res = poll_fn(|cx| stm.poll_shutdown(cx)).await;
trace!("{:?}", res);
trace!("shutdown open {:?}", res);
let (tcp, _) = stm.into_inner();
if let Ok(tcp) = tcp.into_std() {
spawn_blocking(move || {
trace!("in drop tcp task");
// Drop the TCP stream here just in case close() blocks
_ = tcp.set_nonblocking(false);
sleep(Duration::from_secs(1));
drop(tcp);
trace!("done drop tcp task");
});
}
trace!("done drop task");
});
}
Expand Down Expand Up @@ -1817,7 +1867,7 @@ pub(super) mod tests {
#[case] swap: bool,
) -> TestResult {
const BUF_SIZE: usize = 1024;
const BUF_COUNT: usize = 10 * 1024;
const BUF_COUNT: usize = 1 * 1024;

let (server, client) = tls_pair_buffer_size(NonZeroUsize::new(65536)).await;
let (server, client) = if swap {
Expand Down Expand Up @@ -1848,6 +1898,7 @@ pub(super) mod tests {
let n = w.write(&buf).await.unwrap();
w.flush().await.unwrap();
buf = &mut buf[n..];
trace!("[TEST] wrote {n}");
}
w.shutdown().await.unwrap();
barrier2.wait().await;
Expand All @@ -1856,13 +1907,17 @@ pub(super) mod tests {

let r = a.await.unwrap();
let w = b.await.unwrap();
r.unsplit(w).close().await.unwrap();
drop(r.unsplit(w));
});
let b = spawn(async move {
let (mut r, _w) = client.into_split();
let mut buf = vec![0; BUF_SIZE];
for _i in 0..BUF_COUNT {
assert_eq!(BUF_SIZE, r.read_exact(&mut buf).await.unwrap());
for i in 0..BUF_COUNT {
let r = r.read_exact(&mut buf).await;
if let Err(e) = &r {
panic!("Failed to read after {i} of {BUF_COUNT} reads: {e:?}");
};
assert_eq!(BUF_SIZE, r.unwrap());
}
expect_eof_read(&mut r).await;
});
Expand Down

0 comments on commit 535de16

Please sign in to comment.