From 34b7b5c516a98fb70effb0b901a67c2028902ae9 Mon Sep 17 00:00:00 2001 From: Kevin Nakamura Date: Thu, 4 Jan 2024 11:57:52 +0900 Subject: [PATCH] with_shutdown for Tokio app --- .github/workflows/examples.yml | 9 +- examples/shutdown-tokio/Cargo.toml | 11 ++ examples/shutdown-tokio/src/main.rs | 31 ++++ examples/shutdown/src/main.rs | 2 +- humphrey/Cargo.toml | 6 +- humphrey/src/tokio/app.rs | 225 ++++++++++++++++------------ 6 files changed, 188 insertions(+), 96 deletions(-) create mode 100644 examples/shutdown-tokio/Cargo.toml create mode 100644 examples/shutdown-tokio/src/main.rs diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 8ad1469..5d88092 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -91,7 +91,7 @@ jobs: with: command: check args: --manifest-path examples/host/Cargo.toml - + - name: Check monitor example if: always() uses: actions-rs/cargo@v1 @@ -168,3 +168,10 @@ jobs: with: command: check args: --manifest-path examples/shutdown/Cargo.toml + + - name: Check shutdown-tokio example + if: always() + uses: actions-rs/cargo@v1 + with: + command: check + args: --manifest-path examples/shutdown-tokio/Cargo.toml diff --git a/examples/shutdown-tokio/Cargo.toml b/examples/shutdown-tokio/Cargo.toml new file mode 100644 index 0000000..c3fd3d8 --- /dev/null +++ b/examples/shutdown-tokio/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "shutdown-tokio" +version = "0.1.0" +edition = "2021" + +[dependencies] +humphrey = { path = "../../humphrey", features = ["tokio"] } +tokio = { version = "1", features = ["full"] } +tokio-util = "0.7" + +[workspace] diff --git a/examples/shutdown-tokio/src/main.rs b/examples/shutdown-tokio/src/main.rs new file mode 100644 index 0000000..654c3ec --- /dev/null +++ b/examples/shutdown-tokio/src/main.rs @@ -0,0 +1,31 @@ +use humphrey::http::{Request, Response, StatusCode}; +use humphrey::App; +use tokio_util::sync::CancellationToken; + +use std::error::Error; +use std::thread::{sleep, spawn}; +use std::time::Duration; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let cancel = CancellationToken::new(); + let cloned_cancel = cancel.clone(); + let app: App<()> = App::new() + .with_shutdown(cloned_cancel) + .with_stateless_route("/", hello); + + // Shutdown the main app after 5 seconds + spawn(move || { + sleep(Duration::from_secs(5)); + cancel.cancel(); + }); + + // Returns after shutdown signal + app.run("0.0.0.0:8080").await?; + + Ok(()) +} + +async fn hello(_: Request) -> Response { + Response::new(StatusCode::OK, "Hello, world! - tokio") +} diff --git a/examples/shutdown/src/main.rs b/examples/shutdown/src/main.rs index 2fe3bae..f07f0ac 100644 --- a/examples/shutdown/src/main.rs +++ b/examples/shutdown/src/main.rs @@ -13,7 +13,7 @@ fn main() -> Result<(), Box> { .with_shutdown(app_rx) .with_stateless_route("/hello", |_| Response::new(StatusCode::OK, "Hello world!")); - // Shutdown both the main app after 5 seconds + // Shutdown the main app after 5 seconds spawn(move || { sleep(Duration::from_secs(5)); let _ = shutdown_app.send(()); diff --git a/humphrey/Cargo.toml b/humphrey/Cargo.toml index d7b27c1..ad2aa2e 100644 --- a/humphrey/Cargo.toml +++ b/humphrey/Cargo.toml @@ -37,9 +37,13 @@ optional = true version = "^0.24.1" optional = true +[dependencies.tokio-util] +version = "0.7" +optional = true + [features] tls = ["rustls", "rustls-native-certs", "rustls-pemfile"] -tokio = ["dep:tokio", "futures", "tokio-rustls"] +tokio = ["dep:tokio", "futures", "tokio-rustls", "tokio-util"] [lib] doctest = false diff --git a/humphrey/src/tokio/app.rs b/humphrey/src/tokio/app.rs index a9e8eeb..ace9d37 100644 --- a/humphrey/src/tokio/app.rs +++ b/humphrey/src/tokio/app.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use tokio_util::sync::CancellationToken; #[cfg(feature = "tls")] use rustls::ServerConfig; @@ -42,6 +43,7 @@ where tls_config: Option>, #[cfg(feature = "tls")] force_https: bool, + shutdown: Option, } /// Represents a function able to calculate whether a connection will be accepted. @@ -98,6 +100,7 @@ where tls_config: None, #[cfg(feature = "tls")] force_https: false, + shutdown: None, } } @@ -114,6 +117,7 @@ where tls_config: None, #[cfg(feature = "tls")] force_https: false, + shutdown: None, } } @@ -129,51 +133,65 @@ where let error_handler = Arc::new(self.error_handler); loop { - match socket.accept().await { - Ok((mut stream, _)) => { - let cloned_state = self.state.clone(); - - // Check that the client is allowed to connect - if (self.connection_condition)(&mut stream, cloned_state) { - let cloned_state = self.state.clone(); - let cloned_monitor = self.monitor.clone(); - let cloned_subapps = subapps.clone(); - let cloned_default_subapp = default_subapp.clone(); - let cloned_error_handler = error_handler.clone(); - - cloned_monitor.send( - Event::new(EventType::ConnectionSuccess) - .with_peer_result(stream.peer_addr()), - ); - - // Spawn a new thread to handle the connection - tokio::spawn(async move { - cloned_monitor.send( - Event::new(EventType::ThreadPoolProcessStarted) - .with_peer_result(stream.peer_addr()), - ); + let shutdown = async { + if let Some(sd) = self.shutdown.clone() { + sd.cancelled().await + } else { + loop { + let _ = tokio::time::sleep(std::time::Duration::from_secs(9999)); + } + } + }; + tokio::select! { + () = shutdown => { break Ok(()); } + s = socket.accept() => { + match s { + Ok((mut stream, _)) => { + let cloned_state = self.state.clone(); + + // Check that the client is allowed to connect + if (self.connection_condition)(&mut stream, cloned_state) { + let cloned_state = self.state.clone(); + let cloned_monitor = self.monitor.clone(); + let cloned_subapps = subapps.clone(); + let cloned_default_subapp = default_subapp.clone(); + let cloned_error_handler = error_handler.clone(); + + cloned_monitor.send( + Event::new(EventType::ConnectionSuccess) + .with_peer_result(stream.peer_addr()), + ); + + // Spawn a new thread to handle the connection + tokio::spawn(async move { + cloned_monitor.send( + Event::new(EventType::ThreadPoolProcessStarted) + .with_peer_result(stream.peer_addr()), + ); - client_handler( - Stream::Tcp(stream), - cloned_subapps, - cloned_default_subapp, - cloned_error_handler, - cloned_state, - cloned_monitor, - ) - .await - }); - } else { - self.monitor.send( - Event::new(EventType::ConnectionDenied) - .with_peer_result(stream.peer_addr()), - ); + client_handler( + Stream::Tcp(stream), + cloned_subapps, + cloned_default_subapp, + cloned_error_handler, + cloned_state, + cloned_monitor, + ) + .await + }); + } else { + self.monitor.send( + Event::new(EventType::ConnectionDenied) + .with_peer_result(stream.peer_addr()), + ); + } + } + Err(e) => self + .monitor + .send(Event::new(EventType::ConnectionError).with_info(e.to_string())), } } - Err(e) => self - .monitor - .send(Event::new(EventType::ConnectionError).with_info(e.to_string())), - } + }; } } @@ -203,60 +221,75 @@ where let acceptor = TlsAcceptor::from(tls_config); loop { - match socket.accept().await { - Ok((mut sock, _)) => { - let cloned_state = self.state.clone(); - - // Check that the client is allowed to connect - if (self.connection_condition)(&mut sock, cloned_state) { - let cloned_state = self.state.clone(); - let cloned_subapps = subapps.clone(); - let cloned_default_subapp = default_subapp.clone(); - let cloned_error_handler = error_handler.clone(); - let cloned_monitor = self.monitor.clone(); - let cloned_acceptor = acceptor.clone(); - - cloned_monitor.send( - Event::new(EventType::ConnectionSuccess) - .with_peer_result(sock.peer_addr()), - ); - - // Spawn a new thread to handle the connection - tokio::spawn(async move { - cloned_monitor.send( - Event::new(EventType::ThreadPoolProcessStarted) - .with_peer_result(sock.peer_addr()), - ); - - match cloned_acceptor.accept(sock).await { - Ok(tls_stream) => { - let stream = Stream::Tls(tls_stream); - - client_handler( - stream, - cloned_subapps, - cloned_default_subapp, - cloned_error_handler, - cloned_state, - cloned_monitor, - ) - .await - } - Err(e) => cloned_monitor.send( - Event::new(EventType::ConnectionError).with_info(e.to_string()), - ), + let shutdown = async { + if let Some(sd) = self.shutdown.clone() { + sd.cancelled().await + } else { + loop { + let _ = tokio::time::sleep(std::time::Duration::from_secs(9999)); + } + } + }; + + tokio::select! { + () = shutdown => { break Ok(()); } + s = socket.accept() => { + match s { + Ok((mut sock, _)) => { + let cloned_state = self.state.clone(); + + // Check that the client is allowed to connect + if (self.connection_condition)(&mut sock, cloned_state) { + let cloned_state = self.state.clone(); + let cloned_subapps = subapps.clone(); + let cloned_default_subapp = default_subapp.clone(); + let cloned_error_handler = error_handler.clone(); + let cloned_monitor = self.monitor.clone(); + let cloned_acceptor = acceptor.clone(); + + cloned_monitor.send( + Event::new(EventType::ConnectionSuccess) + .with_peer_result(sock.peer_addr()), + ); + + // Spawn a new thread to handle the connection + tokio::spawn(async move { + cloned_monitor.send( + Event::new(EventType::ThreadPoolProcessStarted) + .with_peer_result(sock.peer_addr()), + ); + + match cloned_acceptor.accept(sock).await { + Ok(tls_stream) => { + let stream = Stream::Tls(tls_stream); + + client_handler( + stream, + cloned_subapps, + cloned_default_subapp, + cloned_error_handler, + cloned_state, + cloned_monitor, + ) + .await + } + Err(e) => cloned_monitor.send( + Event::new(EventType::ConnectionError).with_info(e.to_string()), + ), + } + }); + } else { + self.monitor.send( + Event::new(EventType::ConnectionDenied) + .with_peer_result(sock.peer_addr()), + ); } - }); - } else { - self.monitor.send( - Event::new(EventType::ConnectionDenied) - .with_peer_result(sock.peer_addr()), - ); + } + Err(e) => self + .monitor + .send(Event::new(EventType::ConnectionError).with_info(e.to_string())), } } - Err(e) => self - .monitor - .send(Event::new(EventType::ConnectionError).with_info(e.to_string())), } } } @@ -417,6 +450,12 @@ where self } + /// Registers a shutdown signal to gracefully shutdown the app, ending the run/run_tls loop. + pub fn with_shutdown(mut self, cancel_token: CancellationToken) -> Self { + self.shutdown = Some(cancel_token); + self + } + /// Gets a reference to the app's state. /// This should only be used in the main thread, as the state is passed to request handlers otherwise. pub fn get_state(&self) -> Arc {