Skip to content

Commit

Permalink
Fix async_fixed_buffers to add get_closer()
Browse files Browse the repository at this point in the history
  • Loading branch information
allada committed Jan 11, 2021
1 parent 3c147cd commit 9225b1f
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 14 deletions.
13 changes: 12 additions & 1 deletion util/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright 2020 Nathan (Blaise) Bruer. All rights reserved.

load("@io_bazel_rules_rust//rust:rust.bzl", "rust_library")
load("@io_bazel_rules_rust//rust:rust.bzl", "rust_library", "rust_test")

rust_library(
name = "error",
Expand Down Expand Up @@ -37,3 +37,14 @@ rust_library(
],
visibility = ["//visibility:public"],
)

rust_test(
name = "utils_tests",
srcs = ["tests/async_fixed_buffer_tests.rs"],
deps = [
":async_fixed_buffer",
":error",
"//third_party:tokio",
"//third_party:pretty_assertions",
],
)
44 changes: 31 additions & 13 deletions util/async_fixed_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#![forbid(unsafe_code)]

use std::sync::Arc;
use std::sync::Mutex;

use core::pin::Pin;
Expand All @@ -27,8 +28,8 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

pub struct AsyncFixedBuf<T> {
inner: FixedBuf<T>,
waker: Mutex<Option<Waker>>,
did_shutdown: AtomicBool,
waker: Arc<Mutex<Option<Waker>>>,
did_shutdown: Arc<AtomicBool>,
write_amt: AtomicUsize,
read_amt: AtomicUsize,
}
Expand All @@ -42,20 +43,37 @@ impl<T> AsyncFixedBuf<T> {
pub fn new(mem: T) -> Self {
AsyncFixedBuf {
inner: FixedBuf::new(mem),
waker: Mutex::new(None),
did_shutdown: AtomicBool::new(false),
waker: Arc::new(Mutex::new(None)),
did_shutdown: Arc::new(AtomicBool::new(false)),
write_amt: AtomicUsize::new(0),
read_amt: AtomicUsize::new(0),
}
}

// Utility method that can be used to get a lambda that will close the
// stream. This is useful for the reader to close the stream.
pub fn get_closer(&mut self) -> Box<dyn FnMut() + Sync + Send> {
let did_shutdown = self.did_shutdown.clone();
let waker = self.waker.clone();
Box::new(move || {
if did_shutdown.load(Ordering::Relaxed) {
return;
}
did_shutdown.store(true, Ordering::Relaxed);
let mut waker = waker.lock().unwrap();
if let Some(w) = waker.take() {
w.wake()
}
})
}

fn park(&mut self, new_waker: &Waker) {
let waker = self.waker.get_mut().unwrap();
let mut waker = self.waker.lock().unwrap();
*waker = Some(new_waker.clone());
}

fn wake(&mut self) {
let waker = self.waker.get_mut().unwrap();
let mut waker = self.waker.lock().unwrap();
if let Some(w) = waker.take() {
w.wake()
}
Expand Down Expand Up @@ -98,23 +116,23 @@ impl<T: AsRef<[u8]> + Unpin> tokio::io::AsyncRead for AsyncFixedBuf<T> {

impl<T: AsMut<[u8]>> tokio::io::AsyncWrite for AsyncFixedBuf<T> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
if self.did_shutdown.load(Ordering::Relaxed) {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Receiver disconnected",
)));
}
match self.inner.writable() {
Some(writable_slice) => {
let write_amt = buf.len().min(writable_slice.len());
let mut result = Ok(write_amt);
if write_amt > 0 {
writable_slice[..write_amt].clone_from_slice(&buf[..write_amt]);
self.inner.wrote(write_amt);
} else if self.did_shutdown.load(Ordering::Relaxed) {
result = Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"Receiver disconnected",
));
}

self.wake();
self.write_amt.fetch_add(write_amt, Ordering::Relaxed);
Poll::Ready(result)
Poll::Ready(Ok(write_amt))
}
None => {
self.park(cx.waker());
Expand Down
56 changes: 56 additions & 0 deletions util/tests/async_fixed_buffer_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright 2020 Nathan (Blaise) Bruer. All rights reserved.

use async_fixed_buffer::AsyncFixedBuf;
use error::{make_err, Code, Error, ResultExt};
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;

#[cfg(test)]
mod async_fixed_buffer_tests {
use super::*;
use pretty_assertions::assert_eq; // Must be declared in every module.

#[tokio::test]
async fn get_closer_closes_read_stream_early() -> Result<(), Error> {
let mut raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 32].into_boxed_slice());
let mut stream_closer = raw_fixed_buffer.get_closer();
let (mut rx, mut tx) = tokio::io::split(raw_fixed_buffer);

tx.write_all(&vec![255u8; 4]).await?;

let read_spawn = tokio::spawn(async move {
let mut read_buffer = vec![0u8; 5];
rx.read_exact(&mut read_buffer[..]).await
});
// Wait a few cycles to ensure we are in a read loop.
for _ in 0..100 {
tokio::task::yield_now().await;
}
stream_closer();
let read_result = read_spawn.await.err_tip(|| "Failed to join thread")?;
let err: Error = read_result.unwrap_err().into();
assert_eq!(err, make_err!(Code::Internal, "Sender disconnected"));
Ok(())
}

#[tokio::test]
async fn get_closer_closes_write_stream_early() -> Result<(), Error> {
let mut raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 4].into_boxed_slice());
let mut stream_closer = raw_fixed_buffer.get_closer();
let (_, mut tx) = tokio::io::split(raw_fixed_buffer);

let write_spawn = tokio::spawn(async move {
let mut read_buffer = vec![0u8; 5];
tx.write_all(&mut read_buffer[..]).await
});
// Wait a few cycles to ensure we are in a read loop.
for _ in 0..100 {
tokio::task::yield_now().await;
}
stream_closer();
let read_result = write_spawn.await.err_tip(|| "Failed to join thread")?;
let err: Error = read_result.unwrap_err().into();
assert_eq!(err, make_err!(Code::Internal, "Receiver disconnected"));
Ok(())
}
}

0 comments on commit 9225b1f

Please sign in to comment.