diff --git a/util/BUILD b/util/BUILD index a0cdc228d..18c0127fa 100644 --- a/util/BUILD +++ b/util/BUILD @@ -1,6 +1,6 @@ # Copyright 2020 Nathan (Blaise) Bruer. All rights reserved. -load("@io_bazel_rules_rust//rust:rust.bzl", "rust_library") +load("@io_bazel_rules_rust//rust:rust.bzl", "rust_library", "rust_test") rust_library( name = "error", @@ -37,3 +37,14 @@ rust_library( ], visibility = ["//visibility:public"], ) + +rust_test( + name = "utils_tests", + srcs = ["tests/async_fixed_buffer_tests.rs"], + deps = [ + ":async_fixed_buffer", + ":error", + "//third_party:tokio", + "//third_party:pretty_assertions", + ], +) diff --git a/util/async_fixed_buffer.rs b/util/async_fixed_buffer.rs index d90b3a7d8..9a97caf05 100644 --- a/util/async_fixed_buffer.rs +++ b/util/async_fixed_buffer.rs @@ -17,6 +17,7 @@ #![forbid(unsafe_code)] +use std::sync::Arc; use std::sync::Mutex; use core::pin::Pin; @@ -27,8 +28,8 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; pub struct AsyncFixedBuf { inner: FixedBuf, - waker: Mutex>, - did_shutdown: AtomicBool, + waker: Arc>>, + did_shutdown: Arc, write_amt: AtomicUsize, read_amt: AtomicUsize, } @@ -42,20 +43,37 @@ impl AsyncFixedBuf { pub fn new(mem: T) -> Self { AsyncFixedBuf { inner: FixedBuf::new(mem), - waker: Mutex::new(None), - did_shutdown: AtomicBool::new(false), + waker: Arc::new(Mutex::new(None)), + did_shutdown: Arc::new(AtomicBool::new(false)), write_amt: AtomicUsize::new(0), read_amt: AtomicUsize::new(0), } } + // Utility method that can be used to get a lambda that will close the + // stream. This is useful for the reader to close the stream. + pub fn get_closer(&mut self) -> Box { + let did_shutdown = self.did_shutdown.clone(); + let waker = self.waker.clone(); + Box::new(move || { + if did_shutdown.load(Ordering::Relaxed) { + return; + } + did_shutdown.store(true, Ordering::Relaxed); + let mut waker = waker.lock().unwrap(); + if let Some(w) = waker.take() { + w.wake() + } + }) + } + fn park(&mut self, new_waker: &Waker) { - let waker = self.waker.get_mut().unwrap(); + let mut waker = self.waker.lock().unwrap(); *waker = Some(new_waker.clone()); } fn wake(&mut self) { - let waker = self.waker.get_mut().unwrap(); + let mut waker = self.waker.lock().unwrap(); if let Some(w) = waker.take() { w.wake() } @@ -98,23 +116,23 @@ impl + Unpin> tokio::io::AsyncRead for AsyncFixedBuf { impl> tokio::io::AsyncWrite for AsyncFixedBuf { fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + if self.did_shutdown.load(Ordering::Relaxed) { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "Receiver disconnected", + ))); + } match self.inner.writable() { Some(writable_slice) => { let write_amt = buf.len().min(writable_slice.len()); - let mut result = Ok(write_amt); if write_amt > 0 { writable_slice[..write_amt].clone_from_slice(&buf[..write_amt]); self.inner.wrote(write_amt); - } else if self.did_shutdown.load(Ordering::Relaxed) { - result = Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "Receiver disconnected", - )); } self.wake(); self.write_amt.fetch_add(write_amt, Ordering::Relaxed); - Poll::Ready(result) + Poll::Ready(Ok(write_amt)) } None => { self.park(cx.waker()); diff --git a/util/tests/async_fixed_buffer_tests.rs b/util/tests/async_fixed_buffer_tests.rs new file mode 100644 index 000000000..35cd1baf8 --- /dev/null +++ b/util/tests/async_fixed_buffer_tests.rs @@ -0,0 +1,56 @@ +// Copyright 2020 Nathan (Blaise) Bruer. All rights reserved. + +use async_fixed_buffer::AsyncFixedBuf; +use error::{make_err, Code, Error, ResultExt}; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; + +#[cfg(test)] +mod async_fixed_buffer_tests { + use super::*; + use pretty_assertions::assert_eq; // Must be declared in every module. + + #[tokio::test] + async fn get_closer_closes_read_stream_early() -> Result<(), Error> { + let mut raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 32].into_boxed_slice()); + let mut stream_closer = raw_fixed_buffer.get_closer(); + let (mut rx, mut tx) = tokio::io::split(raw_fixed_buffer); + + tx.write_all(&vec![255u8; 4]).await?; + + let read_spawn = tokio::spawn(async move { + let mut read_buffer = vec![0u8; 5]; + rx.read_exact(&mut read_buffer[..]).await + }); + // Wait a few cycles to ensure we are in a read loop. + for _ in 0..100 { + tokio::task::yield_now().await; + } + stream_closer(); + let read_result = read_spawn.await.err_tip(|| "Failed to join thread")?; + let err: Error = read_result.unwrap_err().into(); + assert_eq!(err, make_err!(Code::Internal, "Sender disconnected")); + Ok(()) + } + + #[tokio::test] + async fn get_closer_closes_write_stream_early() -> Result<(), Error> { + let mut raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 4].into_boxed_slice()); + let mut stream_closer = raw_fixed_buffer.get_closer(); + let (_, mut tx) = tokio::io::split(raw_fixed_buffer); + + let write_spawn = tokio::spawn(async move { + let mut read_buffer = vec![0u8; 5]; + tx.write_all(&mut read_buffer[..]).await + }); + // Wait a few cycles to ensure we are in a read loop. + for _ in 0..100 { + tokio::task::yield_now().await; + } + stream_closer(); + let read_result = write_spawn.await.err_tip(|| "Failed to join thread")?; + let err: Error = read_result.unwrap_err().into(); + assert_eq!(err, make_err!(Code::Internal, "Receiver disconnected")); + Ok(()) + } +}