tower-batch: wake waiting workers on close to avoid hangs #1908

Merged · 2 commits · Mar 17, 2021
85 changes: 84 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions tower-batch/Cargo.toml
@@ -22,3 +22,5 @@ tracing = "0.1.25"
 zebra-test = { path = "../zebra-test/" }
 tower-fallback = { path = "../tower-fallback/" }
 color-eyre = "0.5.10"
+tokio-test = "0.4.1"
+tower-test = "0.4.0"
35 changes: 29 additions & 6 deletions tower-batch/src/service.rs
@@ -73,21 +73,44 @@ where
         T::Error: Send + Sync,
         Request: Send + 'static,
     {
+        let (batch, worker) = Self::pair(service, max_items, max_latency);
+        tokio::spawn(worker.run());
+        batch
+    }
+
+    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
+    ///
+    /// This is useful if you do not want to spawn directly onto the `tokio`
+    /// runtime but instead want to use your own executor. This will return the
+    /// `Batch` and the background `Worker` that you can then spawn.
+    pub fn pair(
+        service: T,
+        max_items: usize,
+        max_latency: std::time::Duration,
+    ) -> (Self, Worker<T, Request>)
+    where
+        T: Send + 'static,
+        T::Error: Send + Sync,
+        Request: Send + 'static,
+    {
+        let (tx, rx) = mpsc::unbounded_channel();
+
         // The semaphore bound limits the maximum number of concurrent requests
         // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
         // used their semaphore reservation in a `call` yet).
         // We choose a bound that allows callers to check readiness for every item in
         // a batch, then actually submit those items.
         let bound = max_items;
-        let (tx, rx) = mpsc::unbounded_channel();
-        let (handle, worker) = Worker::new(service, rx, max_items, max_latency);
-        tokio::spawn(worker.run());
-        let semaphore = Semaphore::new(bound);
-        Batch {
+        let (semaphore, close) = Semaphore::new_with_close(bound);
+
+        let (handle, worker) = Worker::new(service, rx, max_items, max_latency, close);
+        let batch = Batch {
             tx,
             semaphore,
             handle,
-        }
+        };
+
+        (batch, worker)
     }

     fn get_worker_error(&self) -> crate::BoxError {
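For context on how the new constructor is meant to be used, here is a minimal caller-side sketch. It is not part of this PR: the `service_fn` inner service, the batch parameters, and the `tokio::spawn` target are illustrative stand-ins, and it assumes `BatchControl` and `Worker::run` are reachable from outside the crate.

use std::time::Duration;

use futures::future::poll_fn;
use tower::{service_fn, Service};
use tower_batch::{Batch, BatchControl};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Stand-in inner service: it just prints each item and flush signal.
    let inner = service_fn(|ctrl: BatchControl<u32>| async move {
        match ctrl {
            BatchControl::Item(n) => println!("batched item {}", n),
            BatchControl::Flush => println!("flushing batch"),
        }
        Ok::<_, std::convert::Infallible>(())
    });

    // Unlike `Batch::new`, `pair` does not spawn anything itself: the caller
    // decides where the background worker runs.
    let (mut batch, worker) = Batch::pair(inner, 10, Duration::from_millis(50));
    let worker_task = tokio::spawn(worker.run());

    // Normal tower usage: reserve capacity with `poll_ready`, then `call`.
    // The semaphore bound (`max_items`) caps how many of these reservations
    // can be outstanding at once.
    let mut responses = Vec::new();
    for n in 0..10u32 {
        poll_fn(|cx| batch.poll_ready(cx)).await?;
        responses.push(batch.call(n));
    }
    for response in responses {
        response.await?;
    }

    // Dropping the last `Batch` handle lets the worker task exit cleanly.
    drop(batch);
    worker_task.await?;
    Ok(())
}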
31 changes: 29 additions & 2 deletions tower-batch/src/worker.rs
@@ -1,4 +1,7 @@
-use std::sync::{Arc, Mutex};
+use std::{
+    pin::Pin,
+    sync::{Arc, Mutex},
+};

 use futures::future::TryFutureExt;
 use pin_project::pin_project;
@@ -10,6 +13,8 @@ use tokio::{
 use tower::{Service, ServiceExt};
 use tracing_futures::Instrument;

+use crate::semaphore;
+
 use super::{
     error::{Closed, ServiceError},
     message::{self, Message},
@@ -23,7 +28,7 @@ use super::{
 /// as part of the public API. This is the "sealed" pattern to include "private"
 /// types in public traits that are not meant for consumers of the library to
 /// implement (only call).
-#[pin_project]
+#[pin_project(PinnedDrop)]
 #[derive(Debug)]
 pub struct Worker<T, Request>
 where
@@ -36,6 +41,7 @@ where
     handle: Handle,
     max_items: usize,
     max_latency: std::time::Duration,
+    close: Option<semaphore::Close>,
 }

 /// Get the error out
@@ -54,6 +60,7 @@
         rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
         max_items: usize,
         max_latency: std::time::Duration,
+        close: semaphore::Close,
     ) -> (Handle, Worker<T, Request>) {
         let handle = Handle {
             inner: Arc::new(Mutex::new(None)),
@@ -66,6 +73,7 @@ where
             failed: None,
             max_items,
             max_latency,
+            close: Some(close),
         };

         (handle, worker)
@@ -88,6 +96,12 @@ where
                 .as_ref()
                 .expect("Worker::failed did not set self.failed?")
                 .clone()));
+
+            // Wake any tasks waiting on channel capacity.
+            if let Some(close) = self.close.take() {
+                tracing::debug!("waking pending tasks");
+                close.close();
+            }
         }
     }
 }
@@ -221,3 +235,16 @@ impl Clone for Handle {
         }
     }
 }
+
+#[pin_project::pinned_drop]
+impl<T, Request> PinnedDrop for Worker<T, Request>
+where
+    T: Service<BatchControl<Request>>,
+    T::Error: Into<crate::BoxError>,
+{
+    fn drop(mut self: Pin<&mut Self>) {
+        if let Some(close) = self.as_mut().close.take() {
+            close.close();
+        }
+    }
+}
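The `PinnedDrop` impl above, together with the close on the failure path, is what prevents the hang this PR fixes: when the worker goes away, the semaphore is closed, so tasks parked in `poll_ready` waiting for batch capacity are woken with an error instead of waiting forever. Below is a standalone sketch of that mechanism using `tokio::sync::Semaphore` from tokio 1.x rather than the crate's vendored `semaphore` module; it illustrates the close-wakes-waiters behaviour and is not code from this PR.

use std::{sync::Arc, time::Duration};

use tokio::sync::Semaphore;

// Assumes tokio with the "macros", "rt", "time", and "sync" features enabled.
#[tokio::main]
async fn main() {
    // Zero permits: an acquirer parks, like a caller waiting for batch capacity.
    let semaphore = Arc::new(Semaphore::new(0));

    let waiter = {
        let semaphore = semaphore.clone();
        tokio::spawn(async move {
            // Without a close, this acquire could wait forever once the worker
            // side is gone; closing the semaphore wakes it with an error.
            match semaphore.acquire().await {
                Ok(_permit) => println!("got a permit"),
                Err(_closed) => println!("semaphore closed: woken instead of hanging"),
            }
        })
    };

    // Stand-in for the worker's failure/drop path shown above: close the
    // semaphore so every pending acquire is woken.
    tokio::time::sleep(Duration::from_millis(10)).await;
    semaphore.close();

    waiter.await.expect("waiter task panicked");
}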