Skip to content

Commit

Permalink
Implement more control over graceful shutdown
Browse files Browse the repository at this point in the history
This implements the "instant shutdown" flag describe in rwf2#180.

rwf2#180 (comment)
  • Loading branch information
notriddle committed Aug 16, 2020
1 parent 549c924 commit 15546fa
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 20 deletions.
41 changes: 41 additions & 0 deletions core/lib/src/response/response.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{io, fmt, str};
use std::time::Duration;
use std::borrow::Cow;
use std::pin::Pin;

Expand Down Expand Up @@ -235,6 +236,19 @@ impl<'r> ResponseBuilder<'r> {
self
}

/// Sets the finish-on-shutdown delay.
///
/// If this response is set to finish on shutdown, then Rocket will wait
/// until the body is completely served, or the timer runs out. Otherwise,
/// the client may experience a partially-completed download.
///
/// The finish-on-shutdown delay defaults to `0`.
#[inline(always)]
pub fn wait_on_shutdown(&mut self, wait_on_shutdown: Duration) -> &mut ResponseBuilder<'r> {
self.response.set_wait_on_shutdown(wait_on_shutdown);
self
}

/// Sets the status of the `Response` being built to a custom status
/// constructed from the `code` and `reason` phrase.
///
Expand Down Expand Up @@ -598,6 +612,7 @@ pub struct Response<'r> {
status: Option<Status>,
headers: HeaderMap<'r>,
body: Option<ResponseBody<'r>>,
wait_on_shutdown: Duration,
}

impl<'r> Response<'r> {
Expand All @@ -624,6 +639,7 @@ impl<'r> Response<'r> {
status: None,
headers: HeaderMap::new(),
body: None,
wait_on_shutdown: Duration::from_millis(0),
}
}

Expand Down Expand Up @@ -658,6 +674,31 @@ impl<'r> Response<'r> {
ResponseBuilder::new(other)
}

/// Returns the finish-on-shutdown delay.
///
/// If this response is set to a value other than zero, then Rocket
/// will wait until the body is completely served, or the timer
/// runs out. Otherwise, the client may experience a
/// partially-completed download.
///
/// The finish-on-shutdown delay defaults to `0`.
#[inline(always)]
pub fn wait_on_shutdown(&self) -> Duration {
self.wait_on_shutdown
}

/// Sets the finish-on-shutdown delay.
///
/// If this response is set to finish on shutdown, then Rocket will wait
/// until the body is completely served. Otherwise, the client may
/// experience a partially-completed download.
///
/// The finish-on-shutdown delay defaults to `0`.
#[inline(always)]
pub fn set_wait_on_shutdown(&mut self, wait_on_shutdown: Duration) {
self.wait_on_shutdown = wait_on_shutdown;
}

/// Returns the status of `self`.
///
/// # Example
Expand Down
63 changes: 47 additions & 16 deletions core/lib/src/rocket.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::{io, mem};
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::collections::HashMap;
use std::time::Duration;

#[allow(unused_imports)]
use futures::future::FutureExt;
use futures::stream::StreamExt;
use futures::future::{Future, BoxFuture};
use tokio::sync::{mpsc, oneshot};
use tokio::sync::{broadcast, oneshot};
use ref_cast::RefCast;

use yansi::Paint;
Expand Down Expand Up @@ -41,7 +43,7 @@ pub struct Rocket {
default_catcher: Option<Catcher>,
catchers: HashMap<u16, Catcher>,
fairings: Fairings,
shutdown_receiver: Option<mpsc::Receiver<()>>,
shutdown_receiver: Option<broadcast::Receiver<()>>,
pub(crate) shutdown_handle: Shutdown,
}

Expand Down Expand Up @@ -118,6 +120,7 @@ impl Rocket {

// Create a "dummy" instance of `Rocket` to use while mem-swapping `self`.
fn dummy() -> Rocket {
let (tx, _) = broadcast::channel(1);
Rocket {
manifest: vec![],
config: Config::development(),
Expand All @@ -126,7 +129,7 @@ impl Rocket {
catchers: HashMap::new(),
managed_state: Container::new(),
fairings: Fairings::new(),
shutdown_handle: Shutdown(mpsc::channel(1).0),
shutdown_handle: Shutdown(tx, Arc::new(AtomicBool::new(false))),
shutdown_receiver: None,
}
}
Expand Down Expand Up @@ -189,6 +192,10 @@ async fn hyper_service_fn(
// the response metadata (and a body channel) beforehand.
let (tx, rx) = oneshot::channel();

// The shutdown subscription needs to be opened before dispatching the request.
// Otherwise, if shutdown begins during initial request processing, we would miss it.
let shutdown_receiver = rocket.shutdown_handle.0.subscribe();

tokio::spawn(async move {
// Get all of the information from Hyper.
let (h_parts, h_body) = hyp_req.into_parts();
Expand All @@ -205,7 +212,7 @@ async fn hyper_service_fn(
// handler) instead of doing this.
let dummy = Request::new(&rocket, Method::Get, Origin::dummy());
let r = rocket.handle_error(Status::BadRequest, &dummy).await;
return rocket.issue_response(r, tx).await;
return rocket.issue_response(r, tx, shutdown_receiver).await;
}
};

Expand All @@ -215,7 +222,7 @@ async fn hyper_service_fn(
// Dispatch the request to get a response, then write that response out.
let token = rocket.preprocess_request(&mut req, &mut data).await;
let r = rocket.dispatch(token, &mut req, data).await;
rocket.issue_response(r, tx).await;
rocket.issue_response(r, tx, shutdown_receiver).await;
});

rx.await.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
Expand All @@ -227,14 +234,37 @@ impl Rocket {
&self,
response: Response<'_>,
tx: oneshot::Sender<hyper::Response<hyper::Body>>,
mut shutdown_receiver: broadcast::Receiver<()>,
) {
let wait_on_shutdown = response.wait_on_shutdown();
let result = self.write_response(response, tx);
match result.await {
Ok(()) => {
info_!("{}", Paint::green("Response succeeded."));
let mut shutdown_receiver = if wait_on_shutdown != Duration::from_millis(0) {
let (tx, rx) = broadcast::channel(1);
tokio::spawn(async move {
let _ = shutdown_receiver.recv().await;
tokio::time::delay_for(wait_on_shutdown).await;
tx.send(()).expect("there should always be at least one shutdown listener");
});
rx
} else {
shutdown_receiver
};
tokio::select!{
result = result => {
match result {
Ok(()) => {
info_!("{}", Paint::green("Response succeeded."));
}
Err(e) => {
error_!("Failed to write response: {:?}.", e);
}
}
}
Err(e) => {
error_!("Failed to write response: {:?}.", e);
// The error returned by `recv()` is discarded here.
// This is fine, because the only case where it returns that error is
// if the sender is dropped, which would indicate shutdown anyway.
_ = shutdown_receiver.recv() => {
info_!("{}", Paint::red("Response cancelled for shutdown."));
}
}
}
Expand Down Expand Up @@ -517,8 +547,7 @@ impl Rocket {
// listener.set_keepalive(timeout);

// We need to get this before moving `self` into an `Arc`.
let mut shutdown_receiver = self.shutdown_receiver
.take().expect("shutdown receiver has already been used");
let mut shutdown_receiver = self.shutdown_receiver.take().expect("a rocket will listen exactly once");

let rocket = Arc::new(self);
let service = hyper::make_service_fn(move |connection: &<L as Listener>::Connection| {
Expand All @@ -545,7 +574,9 @@ impl Rocket {
hyper::Server::builder(Incoming::from_listener(listener))
.executor(TokioExecutor)
.serve(service)
.with_graceful_shutdown(async move { shutdown_receiver.recv().await; })
// Discarding the error is fine, because it indicates that the sender has been dropped.
// If the sender is dropped, then the Rocket is dropped, and that means we're shutting down.
.with_graceful_shutdown(async move { let _ = shutdown_receiver.recv().await; })
.await
.map_err(|e| crate::error::Error::Run(Box::new(e)))
}
Expand Down Expand Up @@ -661,17 +692,17 @@ impl Rocket {
}

let managed_state = Container::new();
let (shutdown_sender, shutdown_receiver) = mpsc::channel(1);
let (shutdown_sender, shutdown_receiver) = broadcast::channel(1);

Rocket {
config, managed_state,
shutdown_handle: Shutdown(shutdown_sender),
shutdown_handle: Shutdown(shutdown_sender, Arc::new(AtomicBool::new(false))),
shutdown_receiver: Some(shutdown_receiver),
manifest: vec![],
router: Router::new(),
default_catcher: None,
catchers: HashMap::new(),
fairings: Fairings::new(),
shutdown_receiver: Some(shutdown_receiver),
}
}

Expand Down
34 changes: 30 additions & 4 deletions core/lib/src/shutdown.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::request::{FromRequest, Outcome, Request};
use tokio::sync::mpsc;
use tokio::sync::broadcast;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

/// A request guard to gracefully shutdown a Rocket server.
///
Expand Down Expand Up @@ -34,19 +36,43 @@ use tokio::sync::mpsc;
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Shutdown(pub(crate) mpsc::Sender<()>);
pub struct Shutdown(pub(crate) broadcast::Sender<()>, pub(crate) Arc<AtomicBool>);

impl Shutdown {
/// Notify Rocket to shut down gracefully. This function returns
/// immediately; pending requests will continue to run until completion
/// before the actual shutdown occurs.
#[inline]
pub fn shutdown(mut self) {
pub fn shutdown(self) {
self.1.store(true, Ordering::SeqCst);
// Intentionally ignore any error, as the only scenarios this can happen
// is sending too many shutdown requests or we're already shut down.
let _ = self.0.try_send(());
let _ = self.0.send(());
info!("Server shutdown requested, waiting for all pending requests to finish.");
}
pub async fn wait(self) {
// This uses four events:
//
// * the store event
// * the send event
//
// * the subscribe event
// * the load event
//
// Since both pairs of events are happening in parallel threads (potentially), we need to worry about
// all possible interleavings, but events within a single thread cannot be reordered (store comes before
// send, while subscribe comes before load, always). For this to work, either we store before we load,
// or we subscribe before we send.
//
// In a sequential ordering, either the store came first, or the subscribe did. If the store came
// before the subscribe, then it must have also come before the load, which means that the load will
// pick up on the atomic change. If the subscribe came before the store, then it also came before
// the send, meaning that the broadcast channel will kick us out instead.
let mut recv = self.0.subscribe();
if !self.1.load(Ordering::SeqCst) {
let _ = recv.recv().await;
}
}
}

#[crate::async_trait]
Expand Down
51 changes: 51 additions & 0 deletions core/lib/tests/graceful-shutdown-finish.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#[macro_use] extern crate rocket;

use rocket::Shutdown;
use rocket::response::Response;
use tokio::io::AsyncRead;

use std::pin::Pin;
use std::task::{Poll, Context};
use std::time::Duration;

struct AsyncReader(bool);

impl AsyncRead for AsyncReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context,
buf: &mut [u8]
) -> Poll<tokio::io::Result<usize>> {
if self.0 {
Poll::Ready(Ok(0))
} else {
buf[0] = b'a';
Pin::<&mut AsyncReader>::into_inner(self).0 = true;
Poll::Ready(Ok(1))
}
}
}

#[get("/test")]
fn test(shutdown: Shutdown) -> Response<'static> {
shutdown.shutdown();
Response::build()
.chunked_body(AsyncReader(false), 512)
.wait_on_shutdown(Duration::from_millis(u64::MAX))
.finalize()
}

mod tests {
use super::*;
use rocket::local::blocking::Client;

#[test]
fn graceful_shutdown_works() {
let rocket = rocket::ignite()
.mount("/", routes![test]);
let client = Client::new(rocket).unwrap();

let response = client.get("/test").dispatch();
assert_eq!(response.into_string().unwrap(), String::from("a"));
}
}
53 changes: 53 additions & 0 deletions core/lib/tests/graceful-shutdown-wait.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#[macro_use] extern crate rocket;

use rocket::Shutdown;
use rocket::response::Response;
use tokio::io::AsyncRead;

use std::pin::Pin;
use std::task::{Poll, Context};

struct AsyncReader;

impl AsyncRead for AsyncReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context,
_buf: &mut [u8]
) -> Poll<tokio::io::Result<usize>> {
Poll::Pending
}
}

#[get("/test-shutdown")]
async fn test(shutdown: Shutdown) -> Response<'static> {
shutdown.shutdown();
Response::build()
.chunked_body(AsyncReader, 512)
.finalize()
}

#[get("/test-wait")]
async fn test2(shutdown: Shutdown) -> Response<'static> {
shutdown.wait().await;
Response::build()
.chunked_body(AsyncReader, 512)
.finalize()
}

mod tests {
use super::*;
use rocket::local::asynchronous::Client;
use futures::join;

#[rocket::async_test]
async fn graceful_shutdown_works() {
let rocket = rocket::ignite()
.mount("/", routes![test, test2]);
let client = Client::new(rocket).await.unwrap();

let shutdown_response = client.get("/test-shutdown").dispatch();
let wait_response = client.get("/test-wait").dispatch();
let _ = join!(shutdown_response, wait_response);
}
}
Loading

0 comments on commit 15546fa

Please sign in to comment.