diff --git a/Cargo.toml b/Cargo.toml index f5b9ddf..e889e3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "stubborn-io" -version = "0.3.2" +version = "0.3.3" authors = ["David Raifaizen "] edition = "2021" description = "io traits/structs that automatically recover from potential disconnections/interruptions." diff --git a/src/tokio/io.rs b/src/tokio/io.rs index a3bbe9d..d4d1473 100644 --- a/src/tokio/io.rs +++ b/src/tokio/io.rs @@ -95,12 +95,24 @@ enum Status { FailedAndExhausted, // the way one feels after programming in dynamically typed languages } +#[inline] +fn poll_err( + kind: ErrorKind, + reason: impl Into>, +) -> Poll> { + let io_err = io::Error::new(kind, reason); + Poll::Ready(Err(io_err)) +} + fn exhausted_err() -> Poll> { - let io_err = io::Error::new( + poll_err( ErrorKind::NotConnected, "Disconnected. Connection attempts have been exhausted.", - ); - Poll::Ready(Err(io_err)) + ) +} + +fn disconnected_err() -> Poll> { + poll_err(ErrorKind::NotConnected, "Underlying I/O is disconnected.") } impl Deref for StubbornIo { @@ -381,7 +393,7 @@ where poll } - Status::Disconnected(_) => Poll::Pending, + Status::Disconnected(_) => disconnected_err(), Status::FailedAndExhausted => exhausted_err(), } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs new file mode 100644 index 0000000..7c31127 --- /dev/null +++ b/tests/integration_tests.rs @@ -0,0 +1,28 @@ +use std::time::Duration; + +use stubborn_io::StubbornTcpStream; +use tokio::{io::AsyncWriteExt, sync::oneshot}; + +#[tokio::test] +async fn back_to_back_shutdown_attempts() { + let (port_tx, port_rx) = oneshot::channel(); + tokio::spawn(async move { + let mut streams = Vec::new(); + let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + port_tx.send(addr).unwrap(); + loop { + let (stream, _addr) = listener.accept().await.unwrap(); + streams.push(stream); + } + }); + let addr = port_rx.await.unwrap(); + let mut connection = StubbornTcpStream::connect(addr).await.unwrap(); + + connection.shutdown().await.unwrap(); + let elapsed = tokio::time::timeout(Duration::from_secs(5), connection.shutdown()).await; + + let result = elapsed.unwrap(); + let error = result.unwrap_err(); + assert_eq!(error.kind(), std::io::ErrorKind::NotConnected); +}