diff --git a/Cargo.lock b/Cargo.lock index 4a568eb00..409b5711c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -542,6 +542,53 @@ dependencies = [ "zeroize", ] +[[package]] +name = "axum" +version = "0.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" +dependencies = [ + "async-trait", + "axum-core", + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "hyper 0.14.23", + "itoa 1.0.4", + "matchit", + "memchr", + "mime 0.3.16", + "percent-encoding 2.2.0", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-http", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime 0.3.16", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.66" @@ -1328,6 +1375,7 @@ dependencies = [ "aws-smithy-client", "aws-smithy-http", "aws-smithy-types-convert", + "axum", "backtrace", "base64 0.13.1", "bzip2", @@ -1348,12 +1396,14 @@ dependencies = [ "grass", "hostname", "http", + "hyper 0.14.23", "indoc", "iron", "kuchiki", "log 0.4.17", "lol_html", "memmap2", + "mime 0.3.16", "mime_guess 2.0.4", "mockito", "once_cell", @@ -1376,6 +1426,7 @@ dependencies = [ "sentry", "sentry-anyhow", "sentry-panic", + "sentry-tower", "sentry-tracing", "serde", "serde_cbor", @@ -1395,6 +1446,9 @@ dependencies = [ "time 0.3.17", "tokio", "toml", + "tower", + "tower-http", + "tower-service", "tracing", "tracing-log", "tracing-subscriber", @@ -2472,6 +2526,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.8.0" @@ -2939,6 +2999,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" +[[package]] +name = "matchit" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" + [[package]] name = "maybe-async" version = "0.2.6" @@ -4528,6 +4594,20 @@ dependencies = [ "sentry-core", ] +[[package]] +name = "sentry-tower" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "359fdd1be4c5ecf03ffa7b21bc865836159590d0f6d114bbabd1e9c9be4c513e" +dependencies = [ + "http", + "pin-project", + "sentry-core", + "tower-layer", + "tower-service", + "url 2.3.1", +] + [[package]] name = "sentry-tracing" version = "0.27.0" @@ -4844,6 +4924,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" + [[package]] name = "synstructure" version = "0.12.6" @@ -5212,6 +5298,26 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-layer" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index da7250cdd..441073488 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ consistency_check = ["crates-index", "rayon"] sentry = "0.27.0" sentry-panic = "0.27.0" sentry-tracing = "0.27.0" +sentry-tower = { version = "0.27.0", features = ["http"] } sentry-anyhow = { version = "0.27.0", features = ["backtrace"] } log = "0.4" tracing = "0.1.37" @@ -89,6 +90,14 @@ memmap2 = "0.5.0" iron = "0.6" router = "0.6" +# axum dependencies +axum = "0.5.17" +hyper = { version = "0.14.15", default-features = false } +tower = "0.4.11" +tower-service = "0.3.2" +tower-http = { version = "0.3.4", features = ["trace"] } +mime = "0.3.16" + # NOTE: if you change this, also double-check that the comment in `queue_builder::remove_tempdirs` is still accurate. tempfile = "3.1.0" diff --git a/src/bin/cratesfyi.rs b/src/bin/cratesfyi.rs index 61ca3e4f5..654480461 100644 --- a/src/bin/cratesfyi.rs +++ b/src/bin/cratesfyi.rs @@ -13,10 +13,11 @@ use docs_rs::utils::{ get_config, queue_builder, remove_crate_priority, set_crate_priority, ConfigName, }; use docs_rs::{ - BuildQueue, Config, Context, Index, Metrics, PackageKind, RustwideBuilder, Server, Storage, + start_web_server, BuildQueue, Config, Context, Index, Metrics, PackageKind, RustwideBuilder, + Storage, }; use once_cell::sync::OnceCell; -use tokio::runtime::Runtime; +use tokio::runtime::{Builder, Runtime}; use tracing_log::LogTracer; use tracing_subscriber::{filter::Directive, prelude::*, EnvFilter}; @@ -156,10 +157,10 @@ impl CommandLine { } Self::StartWebServer { socket_addr } => { // Blocks indefinitely - let _ = Server::start(Some(&socket_addr), &ctx)?; + start_web_server(Some(&socket_addr), &ctx)?; } Self::Daemon { registry_watcher } => { - docs_rs::utils::start_daemon(&ctx, registry_watcher == Toggle::Enabled)?; + docs_rs::utils::start_daemon(ctx, registry_watcher == Toggle::Enabled)?; } Self::Database { subcommand } => subcommand.handle_args(ctx)?, Self::Queue { subcommand } => subcommand.handle_args(ctx)?, @@ -536,6 +537,7 @@ enum DeleteSubcommand { }, } +#[derive(Clone)] struct BinContext { build_queue: OnceCell>, storage: OnceCell>, @@ -597,7 +599,11 @@ impl Context for BinContext { fn cdn(self) -> CdnBackend = CdnBackend::new(&self.config()?, &self.runtime()?); fn config(self) -> Config = Config::from_env()?; fn metrics(self) -> Metrics = Metrics::new()?; - fn runtime(self) -> Runtime = Runtime::new()?; + fn runtime(self) -> Runtime = { + Builder::new_multi_thread() + .enable_all() + .build()? + }; fn index(self) -> Index = { let config = self.config()?; let path = config.registry_index_path.clone(); diff --git a/src/index/mod.rs b/src/index/mod.rs index 99b120e7b..c98b4917a 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -3,7 +3,6 @@ use std::{path::PathBuf, process::Command}; use anyhow::Context; use crates_index_diff::git; -use tracing::debug; use url::Url; use self::api::Api; @@ -90,6 +89,7 @@ impl Index { #[cfg(feature = "consistency_check")] pub(crate) fn crates(&self) -> Result { + use tracing::debug; // First ensure the index is up to date, peeking will pull the latest changes without // affecting anything else. debug!("Updating index"); diff --git a/src/lib.rs b/src/lib.rs index eda3a5c0a..7ee70ab65 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ pub use self::docbuilder::RustwideBuilder; pub use self::index::Index; pub use self::metrics::Metrics; pub use self::storage::Storage; -pub use self::web::Server; +pub use self::web::start_web_server; mod build_queue; pub mod cdn; diff --git a/src/test/mod.rs b/src/test/mod.rs index af15c1a58..2c99b7375 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -6,20 +6,31 @@ use crate::db::{Pool, PoolClient}; use crate::error::Result; use crate::repositories::RepositoryStatsUpdater; use crate::storage::{Storage, StorageKind}; -use crate::web::{cache, Server}; +use crate::web::{ + build_axum_app, build_strangler_service, cache, page::TemplateData, start_iron_server, +}; use crate::{BuildQueue, Config, Context, Index, Metrics}; use anyhow::Context as _; use fn_error_context::context; -use iron::headers::CacheControl; +use iron::{headers::CacheControl, Listening}; use once_cell::unsync::OnceCell; use postgres::Client as Connection; use reqwest::{ blocking::{Client, ClientBuilder, RequestBuilder, Response}, Method, }; -use std::{fs, net::SocketAddr, panic, str::FromStr, sync::Arc, time::Duration}; -use tokio::runtime::Runtime; -use tracing::{debug, error}; +use std::thread::{self, JoinHandle}; +use std::{ + fs, + net::{SocketAddr, TcpListener}, + panic, + str::FromStr, + sync::Arc, + time::Duration, +}; +use tokio::runtime::{Builder, Runtime}; +use tokio::sync::oneshot::Sender; +use tracing::{debug, error, instrument, trace}; #[track_caller] pub(crate) fn wrapper(f: impl FnOnce(&TestEnvironment) -> Result<()>) { @@ -240,7 +251,7 @@ impl TestEnvironment { fn cleanup(self) { if let Some(frontend) = self.frontend.into_inner() { - frontend.server.leak(); + frontend.shutdown(); } if let Some(storage) = self.storage.get() { storage @@ -340,7 +351,14 @@ impl TestEnvironment { } pub(crate) fn runtime(&self) -> Arc { self.runtime - .get_or_init(|| Arc::new(Runtime::new().expect("failed to initialize runtime"))) + .get_or_init(|| { + Arc::new( + Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to initialize runtime"), + ) + }) .clone() } @@ -504,12 +522,16 @@ impl Drop for TestDatabase { } pub(crate) struct TestFrontend { - server: Server, + iron_server: Listening, + axum_server_thread: JoinHandle<()>, + axum_server_shutdown_signal: Sender<()>, + axum_server_address: SocketAddr, pub(crate) client: Client, pub(crate) client_no_redirect: Client, } impl TestFrontend { + #[instrument(skip_all)] fn new(context: &dyn Context) -> Self { fn build(f: impl FnOnce(ClientBuilder) -> ClientBuilder) -> Client { let base = Client::builder() @@ -521,24 +543,92 @@ impl TestFrontend { f(base).build().unwrap() } + debug!("loading template data"); + let template_data = + Arc::new(TemplateData::new(&mut context.pool().unwrap().get().unwrap()).unwrap()); + + debug!("starting iron server"); + let iron_server = start_iron_server(context, template_data.clone(), Some(1)) + .expect("could not start iron server"); + + debug!("binding local TCP port for axum"); + let axum_listener = + TcpListener::bind("127.0.0.1:0".parse::().unwrap()).unwrap(); + + let axum_addr = axum_listener.local_addr().unwrap(); + debug!("bound to local address: {}", axum_addr); + + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + + debug!("building axum app"); + let axum_app = build_axum_app(context, template_data).expect("could not build axum app"); + + let handle = thread::spawn({ + let runtime = context.runtime().unwrap(); + move || { + runtime.block_on(async { + axum::Server::from_tcp(axum_listener) + .unwrap() + .serve( + axum_app + .fallback( + build_strangler_service(iron_server.socket) + .expect("could not build strangler service"), + ) + .into_make_service(), + ) + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .await + .expect("error from axum server") + }) + } + }); + Self { - server: Server::start(Some("127.0.0.1:0"), context) - .expect("failed to start the web server"), + iron_server, + axum_server_address: axum_addr, + axum_server_thread: handle, + axum_server_shutdown_signal: tx, client: build(|b| b), client_no_redirect: build(|b| b.redirect(reqwest::redirect::Policy::none())), } } + #[instrument(skip_all)] + fn shutdown(self) { + trace!("sending axum shutdown signal"); + self.axum_server_shutdown_signal + .send(()) + .expect("could not send shutdown signal"); + + trace!("joining axum server thread"); + self.axum_server_thread + .join() + .expect("could not join axum background thread"); + + trace!("forgetting about iron"); + // Iron is bugged, and it never closes the server even when the listener is dropped. To + // avoid never-ending tests this method forgets about the server, leaking it and allowing the + // program to end. + // + // The OS will then close all the dangling servers once the process exits. + // + // https://docs.rs/iron/0.5/iron/struct.Listening.html#method.close + std::mem::forget(self.iron_server); + } + fn build_url(&self, url: &str) -> String { if url.is_empty() || url.starts_with('/') { - format!("http://{}{}", self.server.addr(), url) + format!("http://{}{}", self.axum_server_address, url) } else { url.to_owned() } } pub(crate) fn server_addr(&self) -> SocketAddr { - self.server.addr() + self.axum_server_address } pub(crate) fn get(&self, url: &str) -> RequestBuilder { diff --git a/src/utils/daemon.rs b/src/utils/daemon.rs index b78246b43..22e616fb4 100644 --- a/src/utils/daemon.rs +++ b/src/utils/daemon.rs @@ -4,6 +4,7 @@ use crate::{ utils::{queue_builder, report_error}, + web::start_web_server, BuildQueue, Config, Context, Index, RustwideBuilder, }; use anyhow::{anyhow, Context as _, Error}; @@ -94,21 +95,26 @@ pub fn start_background_repository_stats_updater(context: &dyn Context) -> Resul Ok(()) } -pub fn start_daemon(context: &dyn Context, enable_registry_watcher: bool) -> Result<(), Error> { +pub fn start_daemon( + context: C, + enable_registry_watcher: bool, +) -> Result<(), Error> { // Start the web server before doing anything more expensive // Please check with an administrator before changing this (see #1172 for context). info!("Starting web server"); - let server = crate::Server::start(None, context)?; - let server_thread = thread::spawn(|| drop(server)); + let webserver_thread = thread::spawn({ + let context = context.clone(); + move || start_web_server(None, &context) + }); if enable_registry_watcher { // check new crates every minute - start_registry_watcher(context)?; + start_registry_watcher(&context)?; } // build new crates every minute let build_queue = context.build_queue()?; - let rustwide_builder = RustwideBuilder::init(context)?; + let rustwide_builder = RustwideBuilder::init(&context)?; thread::Builder::new() .name("build queue reader".to_string()) .spawn(move || { @@ -116,14 +122,13 @@ pub fn start_daemon(context: &dyn Context, enable_registry_watcher: bool) -> Res }) .unwrap(); - start_background_repository_stats_updater(context)?; + start_background_repository_stats_updater(&context)?; - // Never returns; `server` blocks indefinitely when dropped - // NOTE: if a anyhow occurred earlier in `start_daemon`, the server will _not_ be joined - + // NOTE: if a error occurred earlier in `start_daemon`, the server will _not_ be joined - // instead it will get killed when the process exits. - server_thread + webserver_thread .join() - .map_err(|_| anyhow!("web server panicked")) + .map_err(|err| anyhow!("web server panicked: {:?}", err))? } pub(crate) fn cron(name: &'static str, interval: Duration, exec: F) -> Result<(), Error> diff --git a/src/web/cache.rs b/src/web/cache.rs index c01bfbae0..e12b59e9c 100644 --- a/src/web/cache.rs +++ b/src/web/cache.rs @@ -1,9 +1,14 @@ use super::STATIC_FILE_CACHE_DURATION; use crate::config::Config; +use axum::{ + http::Request as AxumHttpRequest, middleware::Next, response::Response as AxumResponse, +}; +use http::header::CACHE_CONTROL; use iron::{ headers::{CacheControl, CacheDirective}, AfterMiddleware, IronResult, Request, Response, }; +use std::sync::Arc; #[cfg(test)] pub const NO_CACHE: &str = "max-age=0"; @@ -112,6 +117,40 @@ impl AfterMiddleware for CacheMiddleware { } } +pub(crate) async fn cache_middleware(req: AxumHttpRequest, next: Next) -> AxumResponse { + let config = req + .extensions() + .get::>() + .cloned() + .expect("missing config extension in request"); + + let mut response = next.run(req).await; + + let cache = response + .extensions() + .get::() + .unwrap_or(&CachePolicy::NoCaching); + + if cfg!(test) { + assert!( + !response.headers().contains_key(CACHE_CONTROL), + "handlers should never set their own caching headers and only use CachePolicy to control caching." + ); + } + + let directives = cache.render(&config); + if !directives.is_empty() { + response.headers_mut().insert( + CACHE_CONTROL, + CacheControl(directives) + .to_string() + .parse() + .expect("cache-control header could not be parsed"), + ); + } + response +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/web/csp.rs b/src/web/csp.rs index 3298dd412..b53590c46 100644 --- a/src/web/csp.rs +++ b/src/web/csp.rs @@ -1,6 +1,9 @@ use crate::config::Config; +use axum::{ + http::Request as AxumHttpRequest, middleware::Next, response::Response as AxumResponse, +}; use iron::{AfterMiddleware, BeforeMiddleware, IronResult, Request, Response}; -use std::fmt::Write; +use std::{fmt::Write, sync::Arc}; pub(super) struct Csp { nonce: String, @@ -137,6 +140,52 @@ impl AfterMiddleware for CspMiddleware { } } +pub(crate) async fn csp_middleware(mut req: AxumHttpRequest, next: Next) -> AxumResponse { + let csp_report_only = req + .extensions() + .get::>() + .expect("missing config extension in request") + .csp_report_only; + + let csp = Arc::new(Csp::new()); + req.extensions_mut().insert(csp.clone()); + + let mut response = next.run(req).await; + + let content_type = response + .headers() + .get("Content-Type") + .map(|header| header.as_bytes()); + + let preset = match content_type { + Some(b"text/html; charset=utf-8") => ContentType::Html, + Some(b"text/svg+xml") => ContentType::Svg, + _ => ContentType::Other, + }; + + let rendered = csp.render(preset); + + if let Some(rendered) = rendered { + let mut headers = response.headers_mut().clone(); + headers.insert( + // The Report-Only header tells the browser to just log CSP failures instead of + // actually enforcing them. This is useful to check if the CSP works without + // impacting production traffic. + if csp_report_only { + http::header::CONTENT_SECURITY_POLICY_REPORT_ONLY + } else { + http::header::CONTENT_SECURITY_POLICY + }, + rendered + .parse() + .expect("rendered CSP could not be parsed into header value"), + ); + *response.headers_mut() = headers; + } + + response +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/web/error.rs b/src/web/error.rs index 048d8b7ef..8014c1346 100644 --- a/src/web/error.rs +++ b/src/web/error.rs @@ -1,6 +1,12 @@ +use std::borrow::Cow; + use crate::{ db::PoolError, - web::{page::WebPage, releases::Search, ErrorPage}, + web::{page::WebPage, releases::Search, AxumErrorPage, ErrorPage}, +}; +use axum::{ + http::StatusCode, + response::{IntoResponse, Response as AxumResponse}, }; use iron::{status::Status, Handler, IronError, IronResult, Request, Response}; @@ -42,7 +48,7 @@ impl From for IronError { impl Handler for Nope { fn handle(&self, req: &mut Request) -> IronResult { - match *self { + match self { Nope::ResourceNotFound => { // user tried to navigate to a resource (doc page/file) that doesn't exist // TODO: Display the attempted page @@ -132,6 +138,102 @@ impl From for IronError { } } +#[derive(Debug, thiserror::Error)] +#[allow(dead_code)] // FIXME: remove after iron is gone +pub enum AxumNope { + #[error("Requested resource not found")] + ResourceNotFound, + #[error("Requested build not found")] + BuildNotFound, + #[error("Requested crate not found")] + CrateNotFound, + #[error("Requested owner not found")] + OwnerNotFound, + #[error("Requested crate does not have specified version")] + VersionNotFound, + // #[error("Search yielded no results")] + // NoResults, + #[error("Internal server error")] + InternalServerError, + #[error("internal error")] + InternalError(#[from] anyhow::Error), +} + +impl IntoResponse for AxumNope { + fn into_response(self) -> AxumResponse { + match self { + AxumNope::ResourceNotFound => { + // user tried to navigate to a resource (doc page/file) that doesn't exist + AxumErrorPage { + title: "Requested resource does not exist", + message: "no such resource".into(), + status: StatusCode::NOT_FOUND, + } + .into_response() + } + + AxumNope::BuildNotFound => AxumErrorPage { + title: "The requested build does not exist", + message: "no such build".into(), + status: StatusCode::NOT_FOUND, + } + .into_response(), + + AxumNope::CrateNotFound => { + // user tried to navigate to a crate that doesn't exist + // TODO: Display the attempted crate and a link to a search for said crate + AxumErrorPage { + title: "The requested crate does not exist", + message: "no such crate".into(), + status: StatusCode::NOT_FOUND, + } + .into_response() + } + + AxumNope::OwnerNotFound => AxumErrorPage { + title: "The requested owner does not exist", + message: "no such owner".into(), + status: StatusCode::NOT_FOUND, + } + .into_response(), + + AxumNope::VersionNotFound => { + // user tried to navigate to a crate with a version that does not exist + // TODO: Display the attempted crate and version + AxumErrorPage { + title: "The requested version does not exist", + message: "no such version for this crate".into(), + status: StatusCode::NOT_FOUND, + } + .into_response() + } + // AxumNope::NoResults => { + // todo!("to be implemented when search-handler is migrated to axum") + // } + AxumNope::InternalServerError => { + // something went wrong, details should have been logged + AxumErrorPage { + title: "Internal server error", + message: "internal server error".into(), + status: StatusCode::INTERNAL_SERVER_ERROR, + } + .into_response() + } + AxumNope::InternalError(source) => { + let web_error = crate::web::AxumErrorPage { + title: "Internal Server Error", + message: Cow::Owned(source.to_string()), + status: StatusCode::INTERNAL_SERVER_ERROR, + }; + + crate::utils::report_error(&source); + + web_error.into_response() + } + } + } +} + #[cfg(test)] mod tests { use crate::test::wrapper; diff --git a/src/web/metrics.rs b/src/web/metrics.rs index ab6d80731..899aec7cf 100644 --- a/src/web/metrics.rs +++ b/src/web/metrics.rs @@ -1,11 +1,18 @@ use crate::db::Pool; use crate::BuildQueue; use crate::Metrics; +use axum::{ + extract::MatchedPath, http::Request as AxumRequest, middleware::Next, response::IntoResponse, +}; use iron::headers::ContentType; use iron::prelude::*; use iron::status::Status; use prometheus::{Encoder, HistogramVec, TextEncoder}; -use std::time::{Duration, Instant}; +use std::{ + borrow::Cow, + sync::Arc, + time::{Duration, Instant}, +}; #[cfg(test)] use tracing::debug; @@ -32,6 +39,52 @@ fn duration_to_seconds(d: Duration) -> f64 { d.as_secs() as f64 + nanos } +/// Request recorder middleware +/// +/// Looks similar, but *is not* a usable middleware / layer +/// since we need the route-name. +/// +/// Can be used like: +/// ```ignore +/// get(handler).route_layer(middleware::from_fn(|request, next| async { +/// request_recorder(request, next, Some("static resource")).await +/// })) +/// ``` +pub(crate) async fn request_recorder( + request: AxumRequest, + next: Next, + route_name: Option<&str>, +) -> impl IntoResponse { + let route_name = if let Some(rn) = route_name { + Cow::Borrowed(rn) + } else if let Some(path) = request.extensions().get::() { + Cow::Owned(path.as_str().to_string()) + } else { + Cow::Owned(request.uri().path().to_string()) + }; + + let metrics = request + .extensions() + .get::>() + .expect("metrics missing in request extensions") + .clone(); + + let start = Instant::now(); + let result = next.run(request).await; + let resp_time = duration_to_seconds(start.elapsed()); + + metrics + .routes_visited + .with_label_values(&[&route_name]) + .inc(); + metrics + .response_time + .with_label_values(&[&route_name]) + .observe(resp_time); + + result +} + pub(super) struct RequestRecorder { handler: Box, route_name: String, diff --git a/src/web/mod.rs b/src/web/mod.rs index 282c87e24..a9d0b4336 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -1,12 +1,12 @@ //! Web interface of docs.rs -pub(crate) mod page; +pub mod page; use crate::utils::get_correct_docsrs_style_file; use crate::utils::report_error; use anyhow::anyhow; use serde_json::Value; -use tracing::info; +use tracing::{info, instrument}; /// ctry! (cratesfyitry) is extremely similar to try! and itry! /// except it returns an error page response instead of plain Err. @@ -90,9 +90,15 @@ mod rustdoc; mod sitemap; mod source; mod statics; +mod strangler; -use crate::{impl_webpage, Context}; +use crate::{impl_axum_webpage, impl_webpage, Context}; use anyhow::Error; +use axum::{ + extract::Extension, + http::{uri::Authority, StatusCode}, + middleware, Router as AxumRouter, +}; use chrono::{DateTime, Utc}; use csp::CspMiddleware; use error::Nope; @@ -112,13 +118,16 @@ use semver::{Version, VersionReq}; use serde::Serialize; use std::borrow::Borrow; use std::{borrow::Cow, net::SocketAddr, sync::Arc}; +use strangler::StranglerService; +use tower::ServiceBuilder; +use tower_http::trace::TraceLayer; /// Duration of static files for staticfile and DatabaseFileHandler (in seconds) const STATIC_FILE_CACHE_DURATION: u64 = 60 * 60 * 24 * 30 * 12; // 12 months const DEFAULT_BIND: &str = "0.0.0.0:3000"; -struct MainHandler { +pub(crate) struct MainHandler { shared_resource_handler: Box, router_handler: Box, inject_extensions: InjectExtensions, @@ -136,7 +145,10 @@ impl MainHandler { chain } - fn new(template_data: Arc, context: &dyn Context) -> Result { + pub(crate) fn new( + template_data: Arc, + context: &dyn Context, + ) -> Result { let inject_extensions = InjectExtensions::new(context, template_data)?; let routes = routes::build_routes(); @@ -224,9 +236,7 @@ impl Handler for MainHandler { .or_else(|e| { let err = if let Some(err) = e.error.downcast_ref::() { *err - } else if e.error.downcast_ref::().is_some() - || e.response.status == Some(status::NotFound) - { + } else if e.error.is::() || e.response.status == Some(status::NotFound) { error::Nope::ResourceNotFound } else if e.response.status == Some(status::InternalServerError) { report_error(&anyhow!("internal server error: {}", e.error)); @@ -416,51 +426,72 @@ fn match_version( Err(Nope::VersionNotFound) } -#[must_use = "`Server` blocks indefinitely when dropped"] -pub struct Server { - inner: Listening, +#[instrument(skip_all)] +pub(crate) fn build_axum_app( + context: &dyn Context, + template_data: Arc, +) -> Result { + Ok(routes::build_axum_routes().layer( + // It’s recommended to use tower::ServiceBuilder to apply multiple middleware at once, + // instead of calling Router::layer repeatedly: + ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(sentry_tower::NewSentryLayer::new_from_top()) + .layer(sentry_tower::SentryHttpLayer::with_transaction()) + .layer(Extension(context.pool()?)) + .layer(Extension(context.build_queue()?)) + .layer(Extension(context.metrics()?)) + .layer(Extension(context.config()?)) + .layer(Extension(context.storage()?)) + .layer(Extension(template_data)) + .layer(middleware::from_fn(csp::csp_middleware)) + .layer(middleware::from_fn( + page::web_page::render_templates_middleware, + )) + .layer(middleware::from_fn(cache::cache_middleware)), + )) } -impl Server { - pub fn start(addr: Option<&str>, context: &dyn Context) -> Result { - // Initialize templates - let template_data = Arc::new(TemplateData::new(&mut *context.pool()?.get()?)?); - let server = Self::start_inner(addr.unwrap_or(DEFAULT_BIND), template_data, context)?; - info!("Running docs.rs web server on http://{}", server.addr()); - Ok(server) +#[instrument(skip_all)] +pub(crate) fn build_strangler_service(addr: SocketAddr) -> Result { + Ok(StranglerService::new(Authority::try_from( + addr.to_string(), + )?)) +} + +#[instrument(skip_all)] +pub(crate) fn start_iron_server( + context: &dyn Context, + template_data: Arc, + threads: Option, +) -> Result { + let mut iron = Iron::new(MainHandler::new(template_data, context)?); + if let Some(threads) = threads { + iron.threads = threads; } + Ok(iron.http("127.0.0.1:0")?) +} - fn start_inner( - addr: &str, - template_data: Arc, - context: &dyn Context, - ) -> Result { - let mut iron = Iron::new(MainHandler::new(template_data, context)?); - if cfg!(test) { - iron.threads = 1; - } - let inner = iron - .http(addr) - .unwrap_or_else(|_| panic!("Failed to bind to socket on {}", addr)); +#[instrument(skip_all)] +pub fn start_web_server(addr: Option<&str>, context: &dyn Context) -> Result<(), Error> { + let template_data = Arc::new(TemplateData::new(&mut *context.pool()?.get()?)?); - Ok(Server { inner }) - } + let iron_server = start_iron_server(context, template_data.clone(), None)?; - pub(crate) fn addr(&self) -> SocketAddr { - self.inner.socket - } + let axum_addr: SocketAddr = addr.unwrap_or(DEFAULT_BIND).parse()?; - /// Iron is bugged, and it never closes the server even when the listener is dropped. To - /// avoid never-ending tests this method forgets about the server, leaking it and allowing the - /// program to end. - /// - /// The OS will then close all the dangling servers once the process exits. - /// - /// https://docs.rs/iron/0.5/iron/struct.Listening.html#method.close - #[cfg(test)] - pub(crate) fn leak(self) { - std::mem::forget(self.inner); - } + context.runtime()?.block_on(async { + axum::Server::bind(&axum_addr) + .serve( + build_axum_app(context, template_data)? + .fallback(build_strangler_service(iron_server.socket)?) + .into_make_service(), + ) + .await?; + Ok::<(), Error>(()) + })?; + + Ok(()) } /// Converts Timespec to nice readable relative time string @@ -632,6 +663,21 @@ impl_webpage! { status = |err| err.status, } +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct AxumErrorPage { + /// The title of the page + pub title: &'static str, + /// The error message, displayed as a description + pub message: Cow<'static, str>, + #[serde(skip)] + pub status: StatusCode, +} + +impl_axum_webpage! { + AxumErrorPage = "error.html", + status = |err| err.status, +} + #[cfg(test)] mod test { use super::*; diff --git a/src/web/page/mod.rs b/src/web/page/mod.rs index e09529675..5d61cf81f 100644 --- a/src/web/page/mod.rs +++ b/src/web/page/mod.rs @@ -1,5 +1,5 @@ mod templates; -mod web_page; +pub(crate) mod web_page; pub(crate) use templates::TemplateData; pub(super) use web_page::WebPage; diff --git a/src/web/page/templates.rs b/src/web/page/templates.rs index 6b42bc019..d0b662a97 100644 --- a/src/web/page/templates.rs +++ b/src/web/page/templates.rs @@ -42,7 +42,7 @@ fn load_rustc_resource_suffix(conn: &mut Client) -> Result { anyhow::bail!("failed to parse the rustc version"); } -pub(super) fn load_templates(conn: &mut Client) -> Result { +fn load_templates(conn: &mut Client) -> Result { // This uses a custom function to find the templates in the filesystem instead of Tera's // builtin way (passing a glob expression to Tera::new), speeding up the startup of the // application and running the tests. diff --git a/src/web/page/web_page.rs b/src/web/page/web_page.rs index 1a7f0e098..113bc1a47 100644 --- a/src/web/page/web_page.rs +++ b/src/web/page/web_page.rs @@ -1,12 +1,20 @@ use super::TemplateData; use crate::{ ctry, - web::{cache::CachePolicy, csp::Csp}, + web::{cache::CachePolicy, csp::Csp, error::AxumNope}, }; +use anyhow::anyhow; +use axum::{ + body::{boxed, Body}, + http::Request as AxumRequest, + middleware::Next, + response::{IntoResponse, Response as AxumResponse}, +}; +use http::header::CONTENT_LENGTH; use iron::{headers::ContentType, response::Response, status::Status, IronResult, Request}; use serde::Serialize; -use std::borrow::Cow; -use tera::Context; +use std::{borrow::Cow, sync::Arc}; +use tera::{Context, Tera}; /// When making using a custom status, use a closure that coerces to a `fn(&Self) -> Status` #[macro_export] @@ -38,6 +46,50 @@ macro_rules! impl_webpage { }; } +#[macro_export] +macro_rules! impl_axum_webpage { + ($page:ty = $template:literal $(, status = $status:expr)? $(, content_type = $content_type:expr)? $(,)?) => { + $crate::impl_axum_webpage!($page = |_| ::std::borrow::Cow::Borrowed($template) $(, status = $status)? $(, content_type = $content_type)?); + }; + + ($page:ty = $template:expr $(, status = $status:expr)? $(, content_type = $content_type:expr)? $(,)?) => { + impl axum::response::IntoResponse for $page + { + fn into_response(self) -> ::axum::response::Response { + // set a default content type, eventually override from the page + #[allow(unused_mut, unused_assignments)] + let mut ct: &'static str = ::mime::TEXT_HTML_UTF_8.as_ref(); + $( + ct = $content_type; + )? + + let mut response = ::axum::http::Response::builder() + .header(::axum::http::header::CONTENT_TYPE, ct) + $( + .status({ + let status: fn(&$page) -> ::axum::http::StatusCode = $status; + (status)(&self) + }) + )? + // this empty body will be replaced in `render_templates_middleware` using + // the data from `DelayedTemplateRender` below. + .body(::axum::body::boxed(::axum::body::Body::empty())) + .unwrap(); + + response.extensions_mut().insert($crate::web::page::web_page::DelayedTemplateRender { + context: ::tera::Context::from_serialize(&self) + .expect("could not create tera context from web-page"), + template: { + let template: fn(&Self) -> ::std::borrow::Cow<'static, str> = $template; + template(&self).to_string() + }, + }); + response + } + } + }; +} + #[derive(Serialize)] struct TemplateContext<'a, T> { csp_nonce: &'a str, @@ -105,3 +157,62 @@ pub trait WebPage: Serialize + Sized { None } } + +/// adding this to the axum response extensions will lead +/// to the template being rendered, adding the csp_nonce to +/// the context. +pub(crate) struct DelayedTemplateRender { + pub template: String, + pub context: Context, +} + +fn render_response(mut response: AxumResponse, templates: &Tera, csp_nonce: &str) -> AxumResponse { + if let Some(render) = response.extensions().get::() { + let mut context = render.context.clone(); + context.insert("csp_nonce", &csp_nonce); + + let rendered = match templates.render(&render.template, &context) { + Ok(content) => content, + Err(err) => { + if response.status().is_server_error() { + // avoid infinite loop if error.html somehow fails to load + panic!("error while serving error page: {:?}", err); + } else { + return render_response( + AxumNope::InternalError(anyhow!(err)).into_response(), + templates, + csp_nonce, + ); + } + } + }; + let content_length = rendered.len(); + *response.body_mut() = boxed(Body::from(rendered)); + response + .headers_mut() + .insert(CONTENT_LENGTH, content_length.into()); + response + } else { + response + } +} + +pub(crate) async fn render_templates_middleware( + req: AxumRequest, + next: Next, +) -> AxumResponse { + let templates: Arc = req + .extensions() + .get::>() + .expect("template data request extension not found") + .clone(); + + let csp_nonce = req + .extensions() + .get::>() + .expect("csp request extension not found") + .nonce() + .to_owned(); + + render_response(next.run(req).await, &templates.templates, &csp_nonce) +} diff --git a/src/web/routes.rs b/src/web/routes.rs index f0ed7c42f..be316e8fa 100644 --- a/src/web/routes.rs +++ b/src/web/routes.rs @@ -1,27 +1,72 @@ use crate::web::page::WebPage; +use super::metrics::request_recorder; use super::{cache::CachePolicy, metrics::RequestRecorder}; -use ::std::borrow::Cow; +use axum::{ + handler::Handler as AxumHandler, middleware, response::Redirect, routing::get, + routing::MethodRouter, Router as AxumRouter, +}; use iron::middleware::Handler; -use router::Router; -use std::collections::HashSet; +use router::Router as IronRouter; +use std::{borrow::Cow, collections::HashSet, convert::Infallible}; +use tracing::instrument; + +#[instrument(skip_all)] +fn get_static(handler: H) -> MethodRouter +where + H: AxumHandler, + B: Send + 'static + hyper::body::HttpBody, + T: 'static, +{ + get(handler).route_layer(middleware::from_fn(|request, next| async { + request_recorder(request, next, Some("static resource")).await + })) +} + +#[instrument(skip_all)] +fn get_internal(handler: H) -> MethodRouter +where + H: AxumHandler, + B: Send + 'static + hyper::body::HttpBody, + T: 'static, +{ + get(handler).route_layer(middleware::from_fn(|request, next| async { + request_recorder(request, next, None).await + })) +} + +pub(super) fn build_axum_routes() -> AxumRouter { + AxumRouter::new() + // Well known resources, robots.txt and favicon.ico support redirection, the sitemap.xml + // must live at the site root: + // https://developers.google.com/search/reference/robots_txt#handling-http-result-codes + // https://support.google.com/webmasters/answer/183668?hl=en + .route( + "/robots.txt", + get_static(|| async { Redirect::permanent("/-/static/robots.txt") }), + ) + .route( + "/favicon.ico", + get_static(|| async { Redirect::permanent("/-/static/favicon.ico") }), + ) + .route( + "/sitemap.xml", + get_internal(super::sitemap::sitemapindex_handler), + ) + .route( + "/-/sitemap/:letter/sitemap.xml", + get_internal(super::sitemap::sitemap_handler), + ) + .route( + "/about/builds", + get_internal(super::sitemap::about_builds_handler), + ) +} // REFACTOR: Break this into smaller initialization functions pub(super) fn build_routes() -> Routes { let mut routes = Routes::new(); - // Well known resources, robots.txt and favicon.ico support redirection, the sitemap.xml - // must live at the site root: - // https://developers.google.com/search/reference/robots_txt#handling-http-result-codes - // https://support.google.com/webmasters/answer/183668?hl=en - routes.static_resource("/robots.txt", PermanentRedirect("/-/static/robots.txt")); - routes.static_resource("/favicon.ico", PermanentRedirect("/-/static/favicon.ico")); - routes.internal_page("/sitemap.xml", super::sitemap::sitemapindex_handler); - routes.internal_page( - "/-/sitemap/:letter/sitemap.xml", - super::sitemap::sitemap_handler, - ); - // This should not need to be served from the root as we reference the inner path in links, // but clients might have cached the url and need to update it. routes.static_resource( @@ -58,7 +103,6 @@ pub(super) fn build_routes() -> Routes { routes.internal_page("/about", super::sitemap::about_handler); routes.internal_page("/about/metrics", super::metrics::metrics_handler); - routes.internal_page("/about/builds", super::sitemap::about_builds_handler); routes.internal_page("/about/:subpage", super::sitemap::about_handler); routes.internal_page("/releases", super::releases::recent_releases_handler); @@ -201,8 +245,8 @@ impl Routes { self.page_prefixes.clone() } - pub(super) fn iron_router(mut self) -> Router { - let mut router = Router::new(); + pub(super) fn iron_router(mut self) -> IronRouter { + let mut router = IronRouter::new(); for (pattern, handler) in self.get.drain(..) { router.get(&pattern, handler, calculate_id(&pattern)); } diff --git a/src/web/sitemap.rs b/src/web/sitemap.rs index 8cec32d11..7a5b7b1df 100644 --- a/src/web/sitemap.rs +++ b/src/web/sitemap.rs @@ -1,19 +1,19 @@ use crate::{ db::Pool, docbuilder::Limits, - impl_webpage, + impl_axum_webpage, impl_webpage, utils::{get_config, ConfigName}, - web::error::Nope, - web::page::WebPage, + web::{error::AxumNope, page::WebPage}, }; -use chrono::{DateTime, TimeZone, Utc}; -use iron::{ - headers::ContentType, - mime::{Mime, SubLevel, TopLevel}, - IronResult, Request, Response, +use anyhow::Context; +use axum::{ + extract::{Extension, Path}, + response::IntoResponse, }; -use router::Router; +use chrono::{DateTime, TimeZone, Utc}; +use iron::{IronResult, Request as IronRequest, Response as IronResponse}; use serde::Serialize; +use tokio::task::spawn_blocking; /// sitemap index #[derive(Debug, Clone, PartialEq, Eq, Serialize)] @@ -21,15 +21,15 @@ struct SitemapIndexXml { sitemaps: Vec, } -impl_webpage! { +impl_axum_webpage! { SitemapIndexXml = "core/sitemapindex.xml", - content_type = ContentType(Mime(TopLevel::Application, SubLevel::Xml, vec![])), + content_type = "application/xml", } -pub fn sitemapindex_handler(req: &mut Request) -> IronResult { +pub(crate) async fn sitemapindex_handler() -> impl IntoResponse { let sitemaps: Vec = ('a'..='z').collect(); - SitemapIndexXml { sitemaps }.into_response(req) + SitemapIndexXml { sitemaps } } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] @@ -45,26 +45,25 @@ struct SitemapXml { releases: Vec, } -impl_webpage! { +impl_axum_webpage! { SitemapXml = "core/sitemap.xml", - content_type = ContentType(Mime(TopLevel::Application, SubLevel::Xml, vec![])), + content_type = "application/xml", } -pub fn sitemap_handler(req: &mut Request) -> IronResult { - let router = extension!(req, Router); - let letter = cexpect!(req, router.find("letter")); - +pub(crate) async fn sitemap_handler( + Path(letter): Path, + Extension(pool): Extension, +) -> Result { if letter.len() != 1 { - return Err(Nope::ResourceNotFound.into()); + return Err(AxumNope::ResourceNotFound); } else if let Some(ch) = letter.chars().next() { if !(ch.is_ascii_lowercase()) { - return Err(Nope::ResourceNotFound.into()); + return Err(AxumNope::ResourceNotFound); } } - - let mut conn = extension!(req, Pool).get()?; - let query = conn - .query( + let releases = spawn_blocking(move || -> anyhow::Result<_> { + let mut conn = pool.get()?; + let query = conn.query( "SELECT crates.name, releases.target_name, MAX(releases.release_time) as release_time @@ -76,25 +75,27 @@ pub fn sitemap_handler(req: &mut Request) -> IronResult { GROUP BY crates.name, releases.target_name ", &[&format!("{}%", letter)], - ) - .unwrap(); - - let releases = query - .into_iter() - .map(|row| SitemapRow { - crate_name: row.get("name"), - target_name: row.get("target_name"), - last_modified: row - .get::<_, DateTime>("release_time") - // On Aug 27 2022 we added `` to all pages, - // so they should all get recrawled if they haven't been since then. - .max(Utc.ymd(2022, 8, 28).and_hms(0, 0, 0)) - .format("%+") - .to_string(), - }) - .collect(); - - SitemapXml { releases }.into_response(req) + )?; + + Ok(query + .into_iter() + .map(|row| SitemapRow { + crate_name: row.get("name"), + target_name: row.get("target_name"), + last_modified: row + .get::<_, DateTime>("release_time") + // On Aug 27 2022 we added `` to all pages, + // so they should all get recrawled if they haven't been since then. + .max(Utc.ymd(2022, 8, 28).and_hms(0, 0, 0)) + .format("%+") + .to_string(), + }) + .collect()) + }) + .await + .context("failed to join thread")??; + + Ok(SitemapXml { releases }) } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] @@ -107,22 +108,23 @@ struct AboutBuilds { active_tab: &'static str, } -impl_webpage!(AboutBuilds = "core/about/builds.html"); - -pub fn about_builds_handler(req: &mut Request) -> IronResult { - let mut conn = extension!(req, Pool).get()?; +impl_axum_webpage!(AboutBuilds = "core/about/builds.html"); - let rustc_version = ctry!( - req, +pub(crate) async fn about_builds_handler( + Extension(pool): Extension, +) -> Result { + let rustc_version = spawn_blocking(move || -> anyhow::Result<_> { + let mut conn = pool.get()?; get_config::(&mut conn, ConfigName::RustcVersion) - ); + }) + .await + .context("failed to join thread")??; - AboutBuilds { + Ok(AboutBuilds { rustc_version, limits: Limits::default(), active_tab: "builds", - } - .into_response(req) + }) } #[derive(Serialize)] @@ -134,7 +136,7 @@ struct AboutPage<'a> { impl_webpage!(AboutPage<'_> = |this: &AboutPage| this.template.clone().into()); -pub fn about_handler(req: &mut Request) -> IronResult { +pub fn about_handler(req: &mut IronRequest) -> IronResult { use super::ErrorPage; use iron::status::Status; diff --git a/src/web/strangler.rs b/src/web/strangler.rs new file mode 100644 index 000000000..f2484eddf --- /dev/null +++ b/src/web/strangler.rs @@ -0,0 +1,258 @@ +/// This implements a simplified strangler-service, +/// using code from +/// https://github.com/MidasLamb/axum-strangler/ +/// +/// because +/// * axum-strangler breaks redirects in our current implementation: +/// https://github.com/MidasLamb/axum-strangler/issues/4 +/// * it adds quite some dependencies it only needs for supporting WebSockets (which we don't +/// need). +/// * it has more dependencies itself doesn't need (reqwest) +/// +/// We might be able to switch back to using the library when +/// * the host/redirect problem is fixed +/// * websocket suport is hidden behind a feature +use std::{ + convert::Infallible, + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use axum::{ + extract::RequestParts, + http::{uri::Authority, Uri}, +}; +use tower_service::Service; + +/// Service that forwards all requests to another service +/// ```ignore +/// #[tokio::main] +/// async fn main() -> Result<(), Box> { +/// let strangler_svc = StranglerService::new( +/// axum::http::uri::Authority::from_static("127.0.0.1:3333"), +/// ); +/// let router = axum::Router::new().fallback(strangler_svc); +/// axum::Server::bind(&"127.0.0.1:0".parse()?) +/// .serve(router.into_make_service()) +/// # .with_graceful_shutdown(async { +/// # // Shut down immediately +/// # }) +/// .await?; +/// Ok(()) +/// } +/// ``` +#[derive(Clone)] +pub struct StranglerService { + http_client: hyper::Client, + inner: Arc, +} + +impl StranglerService { + /// Construct a new `StranglerService`. + /// The `strangled_authority` is the host & port of the service to be strangled. + pub fn new(strangled_authority: Authority) -> Self { + Self { + http_client: hyper::Client::new(), + inner: Arc::new(InnerStranglerService { + strangled_authority, + }), + } + } +} + +struct InnerStranglerService { + strangled_authority: axum::http::uri::Authority, +} + +impl Service> for StranglerService { + type Response = axum::response::Response; + type Error = Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: axum::http::Request) -> Self::Future { + let http_client = self.http_client.clone(); + let inner = self.inner.clone(); + + let fut = forward_call_to_strangled(http_client, inner, req); + Box::pin(fut) + } +} + +#[tracing::instrument(skip_all, fields(req.path = %req.uri()))] // Note that we set the path to the + // "full" uri, as host etc gets + // removed by axum already. +async fn forward_call_to_strangled( + http_client: hyper::Client, + inner: Arc, + req: axum::http::Request, +) -> Result { + tracing::debug!("handling a request"); + let mut request_parts = RequestParts::new(req); + let req: Result, _> = request_parts.extract().await; + let mut req = req.unwrap(); + + let uri: Uri = { + // Not really anything to do, because this could just not be a websocket + // request. + let strangled_authority = inner.strangled_authority.clone(); + let strangled_scheme = axum::http::uri::Scheme::HTTP; + Uri::builder() + .authority(strangled_authority) + .scheme(strangled_scheme) + .path_and_query(req.uri().path_and_query().cloned().unwrap()) + .build() + .unwrap() + }; + + *req.uri_mut() = uri; + + let r = http_client.request(req).await.unwrap(); + + let mut response_builder = axum::response::Response::builder(); + response_builder = response_builder.status(r.status()); + + if let Some(headers) = response_builder.headers_mut() { + *headers = r.headers().clone(); + } + + let response = response_builder + .body(axum::body::boxed(r)) + .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR); + + match response { + Ok(response) => Ok(response), + Err(_) => todo!(), + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use axum::{http::uri::Authority, routing::get, Extension, Router}; + + /// Create a mock service that's not connecting to anything. + fn make_svc() -> StranglerService { + StranglerService::new(Authority::from_static("127.0.0.1:0")) + } + + #[tokio::test] + async fn can_be_used_as_fallback() { + let router = Router::new().fallback(make_svc()); + axum::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(router.into_make_service()); + } + + #[tokio::test] + async fn can_be_used_for_a_route() { + let router = Router::new().route("/api", make_svc()); + axum::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(router.into_make_service()); + } + + #[derive(Clone)] + struct StopChannel(Arc>); + + struct StartupHelper { + strangler_port: u16, + strangler_joinhandle: tokio::task::JoinHandle<()>, + stranglee_joinhandle: tokio::task::JoinHandle<()>, + } + + async fn start_up_strangler_and_strangled(strangled_router: Router) -> StartupHelper { + let (tx, mut rx_1) = tokio::sync::broadcast::channel::<()>(1); + let mut rx_2 = tx.subscribe(); + let tx_arc = Arc::new(tx); + let stop_channel = StopChannel(tx_arc); + + let stranglee_tcp = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let stranglee_port = stranglee_tcp.local_addr().unwrap().port(); + + let strangler_tcp = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let strangler_port = strangler_tcp.local_addr().unwrap().port(); + + let client = hyper::Client::new(); + let strangler_svc = StranglerService { + http_client: client, + inner: Arc::new(InnerStranglerService { + strangled_authority: axum::http::uri::Authority::try_from(format!( + "127.0.0.1:{}", + stranglee_port + )) + .unwrap(), + }), + }; + + let background_stranglee_handle = tokio::spawn(async move { + axum::Server::from_tcp(stranglee_tcp) + .unwrap() + .serve( + strangled_router + .layer(axum::Extension(stop_channel)) + .into_make_service(), + ) + .with_graceful_shutdown(async { + rx_1.recv().await.ok(); + }) + .await + .unwrap(); + }); + + let background_strangler_handle = tokio::spawn(async move { + let router = Router::new().fallback(strangler_svc); + axum::Server::from_tcp(strangler_tcp) + .unwrap() + .serve(router.into_make_service()) + .with_graceful_shutdown(async { + rx_2.recv().await.ok(); + }) + .await + .unwrap(); + }); + + StartupHelper { + strangler_port, + strangler_joinhandle: background_strangler_handle, + stranglee_joinhandle: background_stranglee_handle, + } + } + + #[tokio::test] + async fn proxies_strangled_http_service() { + let router = Router::new().route( + "/api/something", + get( + |Extension(StopChannel(tx_arc)): Extension| async move { + tx_arc.send(()).unwrap(); + "I'm being strangled" + }, + ), + ); + + let StartupHelper { + strangler_port, + strangler_joinhandle, + stranglee_joinhandle, + } = start_up_strangler_and_strangled(router).await; + + let c = reqwest::Client::new(); + let r = c + .get(format!("http://127.0.0.1:{}/api/something", strangler_port)) + .send() + .await + .unwrap() + .text() + .await + .unwrap(); + + assert_eq!(r, "I'm being strangled"); + + stranglee_joinhandle.await.unwrap(); + strangler_joinhandle.await.unwrap(); + } +}