Skip to content

Commit

Permalink
with_shutdown for Tokio app
Browse files Browse the repository at this point in the history
  • Loading branch information
Grinkers committed Jan 4, 2024
1 parent 426ceee commit 34b7b5c
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 96 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ jobs:
with:
command: check
args: --manifest-path examples/host/Cargo.toml

- name: Check monitor example
if: always()
uses: actions-rs/cargo@v1
Expand Down Expand Up @@ -168,3 +168,10 @@ jobs:
with:
command: check
args: --manifest-path examples/shutdown/Cargo.toml

- name: Check shutdown-tokio example
if: always()
uses: actions-rs/cargo@v1
with:
command: check
args: --manifest-path examples/shutdown-tokio/Cargo.toml
11 changes: 11 additions & 0 deletions examples/shutdown-tokio/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "shutdown-tokio"
version = "0.1.0"
edition = "2021"

[dependencies]
humphrey = { path = "../../humphrey", features = ["tokio"] }
tokio = { version = "1", features = ["full"] }
tokio-util = "0.7"

[workspace]
31 changes: 31 additions & 0 deletions examples/shutdown-tokio/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use humphrey::http::{Request, Response, StatusCode};
use humphrey::App;
use tokio_util::sync::CancellationToken;

use std::error::Error;
use std::thread::{sleep, spawn};
use std::time::Duration;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let cancel = CancellationToken::new();
let cloned_cancel = cancel.clone();
let app: App<()> = App::new()
.with_shutdown(cloned_cancel)
.with_stateless_route("/", hello);

// Shutdown the main app after 5 seconds
spawn(move || {
sleep(Duration::from_secs(5));
cancel.cancel();
});

// Returns after shutdown signal
app.run("0.0.0.0:8080").await?;

Ok(())
}

async fn hello(_: Request) -> Response {
Response::new(StatusCode::OK, "Hello, world! - tokio")
}
2 changes: 1 addition & 1 deletion examples/shutdown/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() -> Result<(), Box<dyn Error>> {
.with_shutdown(app_rx)
.with_stateless_route("/hello", |_| Response::new(StatusCode::OK, "Hello world!"));

// Shutdown both the main app after 5 seconds
// Shutdown the main app after 5 seconds
spawn(move || {
sleep(Duration::from_secs(5));
let _ = shutdown_app.send(());
Expand Down
6 changes: 5 additions & 1 deletion humphrey/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ optional = true
version = "^0.24.1"
optional = true

[dependencies.tokio-util]
version = "0.7"
optional = true

[features]
tls = ["rustls", "rustls-native-certs", "rustls-pemfile"]
tokio = ["dep:tokio", "futures", "tokio-rustls"]
tokio = ["dep:tokio", "futures", "tokio-rustls", "tokio-util"]

[lib]
doctest = false
225 changes: 132 additions & 93 deletions humphrey/src/tokio/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::sync::Arc;

use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_util::sync::CancellationToken;

#[cfg(feature = "tls")]
use rustls::ServerConfig;
Expand All @@ -42,6 +43,7 @@ where
tls_config: Option<Arc<ServerConfig>>,
#[cfg(feature = "tls")]
force_https: bool,
shutdown: Option<CancellationToken>,
}

/// Represents a function able to calculate whether a connection will be accepted.
Expand Down Expand Up @@ -98,6 +100,7 @@ where
tls_config: None,
#[cfg(feature = "tls")]
force_https: false,
shutdown: None,
}
}

Expand All @@ -114,6 +117,7 @@ where
tls_config: None,
#[cfg(feature = "tls")]
force_https: false,
shutdown: None,
}
}

Expand All @@ -129,51 +133,65 @@ where
let error_handler = Arc::new(self.error_handler);

loop {
match socket.accept().await {
Ok((mut stream, _)) => {
let cloned_state = self.state.clone();

// Check that the client is allowed to connect
if (self.connection_condition)(&mut stream, cloned_state) {
let cloned_state = self.state.clone();
let cloned_monitor = self.monitor.clone();
let cloned_subapps = subapps.clone();
let cloned_default_subapp = default_subapp.clone();
let cloned_error_handler = error_handler.clone();

cloned_monitor.send(
Event::new(EventType::ConnectionSuccess)
.with_peer_result(stream.peer_addr()),
);

// Spawn a new thread to handle the connection
tokio::spawn(async move {
cloned_monitor.send(
Event::new(EventType::ThreadPoolProcessStarted)
.with_peer_result(stream.peer_addr()),
);
let shutdown = async {
if let Some(sd) = self.shutdown.clone() {
sd.cancelled().await
} else {
loop {
let _ = tokio::time::sleep(std::time::Duration::from_secs(9999));
}
}
};
tokio::select! {
() = shutdown => { break Ok(()); }
s = socket.accept() => {
match s {
Ok((mut stream, _)) => {
let cloned_state = self.state.clone();

// Check that the client is allowed to connect
if (self.connection_condition)(&mut stream, cloned_state) {
let cloned_state = self.state.clone();
let cloned_monitor = self.monitor.clone();
let cloned_subapps = subapps.clone();
let cloned_default_subapp = default_subapp.clone();
let cloned_error_handler = error_handler.clone();

cloned_monitor.send(
Event::new(EventType::ConnectionSuccess)
.with_peer_result(stream.peer_addr()),
);

// Spawn a new thread to handle the connection
tokio::spawn(async move {
cloned_monitor.send(
Event::new(EventType::ThreadPoolProcessStarted)
.with_peer_result(stream.peer_addr()),
);

client_handler(
Stream::Tcp(stream),
cloned_subapps,
cloned_default_subapp,
cloned_error_handler,
cloned_state,
cloned_monitor,
)
.await
});
} else {
self.monitor.send(
Event::new(EventType::ConnectionDenied)
.with_peer_result(stream.peer_addr()),
);
client_handler(
Stream::Tcp(stream),
cloned_subapps,
cloned_default_subapp,
cloned_error_handler,
cloned_state,
cloned_monitor,
)
.await
});
} else {
self.monitor.send(
Event::new(EventType::ConnectionDenied)
.with_peer_result(stream.peer_addr()),
);
}
}
Err(e) => self
.monitor
.send(Event::new(EventType::ConnectionError).with_info(e.to_string())),
}
}
Err(e) => self
.monitor
.send(Event::new(EventType::ConnectionError).with_info(e.to_string())),
}
};
}
}

Expand Down Expand Up @@ -203,60 +221,75 @@ where
let acceptor = TlsAcceptor::from(tls_config);

loop {
match socket.accept().await {
Ok((mut sock, _)) => {
let cloned_state = self.state.clone();

// Check that the client is allowed to connect
if (self.connection_condition)(&mut sock, cloned_state) {
let cloned_state = self.state.clone();
let cloned_subapps = subapps.clone();
let cloned_default_subapp = default_subapp.clone();
let cloned_error_handler = error_handler.clone();
let cloned_monitor = self.monitor.clone();
let cloned_acceptor = acceptor.clone();

cloned_monitor.send(
Event::new(EventType::ConnectionSuccess)
.with_peer_result(sock.peer_addr()),
);

// Spawn a new thread to handle the connection
tokio::spawn(async move {
cloned_monitor.send(
Event::new(EventType::ThreadPoolProcessStarted)
.with_peer_result(sock.peer_addr()),
);

match cloned_acceptor.accept(sock).await {
Ok(tls_stream) => {
let stream = Stream::Tls(tls_stream);

client_handler(
stream,
cloned_subapps,
cloned_default_subapp,
cloned_error_handler,
cloned_state,
cloned_monitor,
)
.await
}
Err(e) => cloned_monitor.send(
Event::new(EventType::ConnectionError).with_info(e.to_string()),
),
let shutdown = async {
if let Some(sd) = self.shutdown.clone() {
sd.cancelled().await
} else {
loop {
let _ = tokio::time::sleep(std::time::Duration::from_secs(9999));
}
}
};

tokio::select! {
() = shutdown => { break Ok(()); }
s = socket.accept() => {
match s {
Ok((mut sock, _)) => {
let cloned_state = self.state.clone();

// Check that the client is allowed to connect
if (self.connection_condition)(&mut sock, cloned_state) {
let cloned_state = self.state.clone();
let cloned_subapps = subapps.clone();
let cloned_default_subapp = default_subapp.clone();
let cloned_error_handler = error_handler.clone();
let cloned_monitor = self.monitor.clone();
let cloned_acceptor = acceptor.clone();

cloned_monitor.send(
Event::new(EventType::ConnectionSuccess)
.with_peer_result(sock.peer_addr()),
);

// Spawn a new thread to handle the connection
tokio::spawn(async move {
cloned_monitor.send(
Event::new(EventType::ThreadPoolProcessStarted)
.with_peer_result(sock.peer_addr()),
);

match cloned_acceptor.accept(sock).await {
Ok(tls_stream) => {
let stream = Stream::Tls(tls_stream);

client_handler(
stream,
cloned_subapps,
cloned_default_subapp,
cloned_error_handler,
cloned_state,
cloned_monitor,
)
.await
}
Err(e) => cloned_monitor.send(
Event::new(EventType::ConnectionError).with_info(e.to_string()),
),
}
});
} else {
self.monitor.send(
Event::new(EventType::ConnectionDenied)
.with_peer_result(sock.peer_addr()),
);
}
});
} else {
self.monitor.send(
Event::new(EventType::ConnectionDenied)
.with_peer_result(sock.peer_addr()),
);
}
Err(e) => self
.monitor
.send(Event::new(EventType::ConnectionError).with_info(e.to_string())),
}
}
Err(e) => self
.monitor
.send(Event::new(EventType::ConnectionError).with_info(e.to_string())),
}
}
}
Expand Down Expand Up @@ -417,6 +450,12 @@ where
self
}

/// Registers a shutdown signal to gracefully shutdown the app, ending the run/run_tls loop.
pub fn with_shutdown(mut self, cancel_token: CancellationToken) -> Self {
self.shutdown = Some(cancel_token);
self
}

/// Gets a reference to the app's state.
/// This should only be used in the main thread, as the state is passed to request handlers otherwise.
pub fn get_state(&self) -> Arc<State> {
Expand Down

0 comments on commit 34b7b5c

Please sign in to comment.