Skip to content

Commit

Permalink
Add example of an updating TLS resolver.
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Mar 26, 2024
1 parent e4e46ef commit 3a8305a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
1 change: 1 addition & 0 deletions core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,4 @@ version_check = "0.9.1"
tokio = { version = "1", features = ["macros", "io-std"] }
figment = { version = "0.10", features = ["test"] }
pretty_assertions = "1"
arc-swap = "1.7"
1 change: 0 additions & 1 deletion core/lib/src/listener/tls.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::io;
use std::sync::Arc;

use serde::Deserialize;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::LazyConfigAcceptor;
use rustls::server::{Acceptor, ServerConfig};
Expand Down
47 changes: 46 additions & 1 deletion core/lib/src/tls/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,13 @@ impl fairing::Fairing for Fairing {

#[cfg(test)]
mod tests {
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::collections::HashMap;
use std::time::UNIX_EPOCH;
use arc_swap::ArcSwap;
use either::Either;
use serde::Deserialize;
use crate::http::uri::Host;
use crate::tls::{TlsConfig, ServerConfig, Resolver, ClientHello};
Expand Down Expand Up @@ -69,10 +74,49 @@ mod tests {
}
}

struct UpdatingResolver {
timestamp: AtomicU64,
tls_config: TlsConfig,
server_config: ArcSwap<ServerConfig>
}

impl TryFrom<TlsConfig> for UpdatingResolver {
type Error = crate::tls::Error;

fn try_from(tls_config: TlsConfig) -> Result<Self, Self::Error> {
Ok(UpdatingResolver {
timestamp: AtomicU64::new(0),
server_config: ArcSwap::new(Arc::new(tls_config.to_server_config()?)),
tls_config,
})
}
}

#[crate::async_trait]
impl Resolver for UpdatingResolver {
async fn resolve(&self, _: ClientHello<'_>) -> Option<Arc<ServerConfig>> {
if let Either::Left(path) = self.tls_config.certs() {
let metadata = tokio::fs::metadata(&path).await.ok()?;
let modtime = metadata.modified().ok()?;
let timestamp = modtime.duration_since(UNIX_EPOCH).ok()?.as_secs();
let old_timestamp = self.timestamp.load(Ordering::Acquire);
if timestamp > old_timestamp {
let new_config = self.tls_config.to_server_config().ok()?;
self.server_config.store(Arc::new(new_config));
self.timestamp.store(timestamp, Ordering::Release);
}
}

Some(self.server_config.load_full())
}
}

#[test]
fn test_config() {
figment::Jail::expect_with(|jail| {
use crate::fs::relative;
use figment::Figment;
use figment::providers::{Toml, Format};

let cert_path = relative!("../../examples/tls/private/rsa_sha256_cert.pem");
let key_path = relative!("../../examples/tls/private/rsa_sha256_key.pem");
Expand All @@ -87,7 +131,8 @@ mod tests {
key = "{key_path}"
"#))?;

let config = crate::Config::figment().extract::<SniConfig>()?;
let toml = Toml::file("Rocket.toml").nested();
let config: SniConfig = Figment::from(toml).extract().unwrap();
assert!(config.sni.contains_key(&Host::parse("api.rocket.rs").unwrap()));
assert!(config.sni.contains_key(&Host::parse("blob.rocket.rs").unwrap()));
Ok(())
Expand Down

0 comments on commit 3a8305a

Please sign in to comment.