diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 57474c5303..a549c5eaff 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -66,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Implement `IntoResponse` for `(R,) where R: IntoResponse` ([#2143]) - **changed:** For SSE, add space between field and value for compatibility ([#2149]) - **added:** Add `NestedPath` extractor ([#1924]) +- **added:** Add `FixNestedRedirect` middleware [#1664]: https://github.com/tokio-rs/axum/pull/1664 [#1751]: https://github.com/tokio-rs/axum/pull/1751 diff --git a/axum/src/extract/nested_path.rs b/axum/src/extract/nested_path.rs index f31fe3faba..b20614df00 100644 --- a/axum/src/extract/nested_path.rs +++ b/axum/src/extract/nested_path.rs @@ -45,6 +45,13 @@ impl NestedPath { pub fn as_str(&self) -> &str { &self.0 } + + pub(crate) fn extract(parts: &mut Parts) -> Result { + match parts.extensions.get::() { + Some(nested_path) => Ok(nested_path.clone()), + None => Err(NestedPathRejection), + } + } } #[async_trait] @@ -55,10 +62,7 @@ where type Rejection = NestedPathRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - match parts.extensions.get::() { - Some(nested_path) => Ok(nested_path.clone()), - None => Err(NestedPathRejection), - } + Self::extract(parts) } } diff --git a/axum/src/middleware/fix_nested_redirect.rs b/axum/src/middleware/fix_nested_redirect.rs new file mode 100644 index 0000000000..e2667e4963 --- /dev/null +++ b/axum/src/middleware/fix_nested_redirect.rs @@ -0,0 +1,351 @@ +use std::convert::Infallible; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts}; +use http::header::LOCATION; +use http::{HeaderValue, Request}; +use pin_project_lite::pin_project; +use tower_layer::Layer; +use tower_service::Service; + +use crate::extract::NestedPath; + +/// Middleware for fixing redirects from nested service to include the path they're nested at. +/// +/// # Example +/// +/// ``` +/// use axum::{ +/// Router, +/// routing::get, +/// middleware::FixNestedRedirectLayer, +/// response::Redirect, +/// }; +/// +/// let api = Router::new() +/// // redirect from `/old` to `/new` +/// .route("/old", get(|| async { Redirect::to("/new") })) +/// .route("/new", get(|| async { /* ... */ })); +/// +/// let app = Router::new() +/// .nest( +/// "/api", +/// // make sure the redirects include `/api`, i.e. `location: /api/new` +/// api.layer(FixNestedRedirectLayer::default()), +/// ); +/// # let _: Router = app; +/// ``` +/// +/// # Multiple levels of nesting +/// +/// If you're nesting multiple levels of routers make sure to add `FixNestedRedirectLayer` at the +/// inner most level: +/// +/// ``` +/// use axum::{ +/// Router, +/// routing::get, +/// middleware::FixNestedRedirectLayer, +/// response::Redirect, +/// }; +/// +/// let users_api = Router::new() +/// // redirect from `/old` to `/new` +/// .route("/old", get(|| async { Redirect::to("/new") })) +/// .route("/new", get(|| async { /* ... */ })); +/// +/// let api = Router::new() +/// .nest( +/// "/users", +/// // add the middleware at the inner most level +/// users_api.layer(FixNestedRedirectLayer::default()), +/// ); +/// +/// let app = Router::new() +/// // don't add the middleware here +/// .nest("/api", api); +/// # let _: Router = app; +/// ``` +/// +/// # Opt-out +/// +/// Individual handlers can opt-out by including `FixNestedRedirectOptOut` in the response: +/// +/// ``` +/// use axum::{ +/// Router, +/// routing::get, +/// middleware::{FixNestedRedirectLayer, FixNestedRedirectOptOut}, +/// response::Redirect, +/// }; +/// +/// let api = Router::new() +/// .route("/foo", get(|| async { +/// // this redirect will go to `/somewhere` and not `/api/somewhere` +/// (FixNestedRedirectOptOut, Redirect::to("/somewhere")) +/// })); +/// +/// let app = Router::new() +/// .nest( +/// "/api", +/// api.layer(FixNestedRedirectLayer::default()), +/// ); +/// # let _: Router = app; +/// ``` +/// +/// # Using with `ServeDir` +/// +/// `FixNestedRedirectLayer` can also be used with tower-http's [`ServeDir`]: +/// +/// ``` +/// use axum::{ +/// Router, +/// middleware::FixNestedRedirect, +/// }; +/// use tower_http::services::ServeDir; +/// +/// let app = Router::new().nest_service( +/// "/assets", +/// FixNestedRedirect::new(ServeDir::new("/assets")), +/// ); +/// # let _: Router = app; +/// ``` +/// +/// [`ServeDir`]: tower_http::services::ServeDir +#[derive(Clone, Debug, Default)] +#[non_exhaustive] +pub struct FixNestedRedirectLayer; + +impl Layer for FixNestedRedirectLayer { + type Service = FixNestedRedirect; + + fn layer(&self, inner: S) -> Self::Service { + FixNestedRedirect::new(inner) + } +} + +/// Service for fixing redirects from nested services. +/// +/// See [`FixNestedRedirectLayer`] for more details. +#[derive(Clone, Debug)] +pub struct FixNestedRedirect { + inner: S, +} + +impl FixNestedRedirect { + /// Create a new `FixNestedRedirect`. + pub fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service> for FixNestedRedirect +where + S: Service, Response = Response>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let (mut parts, body) = req.into_parts(); + let nested_path = NestedPath::extract(&mut parts).ok(); + let req = Request::from_parts(parts, body); + ResponseFuture { + future: self.inner.call(req), + nested_path, + } + } +} + +pin_project! { + /// Response future for [`FixNestedRedirect`]. + /// + /// See [`FixNestedRedirectLayer`] for more details. + pub struct ResponseFuture { + #[pin] + future: F, + nested_path: Option, + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match futures_util::ready!(this.future.poll(cx)) { + Ok(res) => { + let (mut parts, body) = res.into_parts(); + if parts.extensions.get::().is_none() { + fix_nested_redirect(&mut parts, this.nested_path.take()); + } + let res = Response::from_parts(parts, body); + Poll::Ready(Ok(res)) + } + Err(err) => Poll::Ready(Err(err)), + } + } +} + +fn fix_nested_redirect( + parts: &mut http::response::Parts, + nested_path: Option, +) -> Option<()> { + if !parts.status.is_redirection() { + return Some(()); + } + + let location = parts.headers.get(LOCATION)?.to_str().ok()?; + + // not sure if there is a more robust way to detect an absolute uri 🤔 + if location.starts_with("https://") + || location.starts_with("http://") + || location.starts_with("//") + { + return Some(()); + } + + let nested_path = nested_path?; + + let new_location = format!("{}{}", nested_path.as_str().trim_end_matches('/'), location); + let new_location = HeaderValue::from_str(&new_location).ok()?; + parts.headers.insert(LOCATION, new_location); + + Some(()) +} + +/// Response extension used to opt-out of [`FixNestedRedirectLayer`] changing the `Location` +/// header. +/// +/// See [`FixNestedRedirectLayer`] for more details. +#[derive(Copy, Clone, Debug)] +pub struct FixNestedRedirectOptOut; + +impl IntoResponseParts for FixNestedRedirectOptOut { + type Error = Infallible; + + fn into_response_parts(self, mut res: ResponseParts) -> Result { + res.extensions_mut().insert(self); + Ok(res) + } +} + +impl IntoResponse for FixNestedRedirectOptOut { + fn into_response(self) -> Response { + (self, ()).into_response() + } +} + +#[cfg(test)] +mod tests { + use http::StatusCode; + use tower_http::services::ServeDir; + + use crate::{ + middleware::{FixNestedRedirect, FixNestedRedirectLayer, FixNestedRedirectOptOut}, + response::Redirect, + routing::get, + test_helpers::TestClient, + Router, + }; + + #[crate::test] + async fn one_level() { + let api = Router::new().route("/old", get(|| async { Redirect::to("/new") })); + let app = Router::new().nest("/api", api.layer(FixNestedRedirectLayer)); + + let client = TestClient::new(app); + + let res = client.get("/api/old").send().await; + assert_eq!(res.status(), StatusCode::SEE_OTHER); + assert_eq!(res.headers()["location"], "/api/new"); + } + + #[crate::test] + async fn one_level_with_trailing_slash() { + let api = Router::new().route("/old", get(|| async { Redirect::to("/new") })); + let app = Router::new().nest("/api/", api.layer(FixNestedRedirectLayer)); + + let client = TestClient::new(app); + + let res = client.get("/api/old").send().await; + assert_eq!(res.status(), StatusCode::SEE_OTHER); + assert_eq!(res.headers()["location"], "/api/new"); + } + + #[crate::test] + async fn two_levels() { + let users = Router::new().route("/old", get(|| async { Redirect::to("/new") })); + let api = Router::new().nest("/users", users.layer(FixNestedRedirectLayer)); + let app = Router::new().nest("/api", api); + + let client = TestClient::new(app); + + let res = client.get("/api/users/old").send().await; + assert_eq!(res.status(), StatusCode::SEE_OTHER); + assert_eq!(res.headers()["location"], "/api/users/new"); + } + + #[crate::test] + async fn opt_out() { + let api = Router::new().route( + "/old", + get(|| async { + ( + FixNestedRedirectOptOut, + Redirect::to("/other/non/api/route"), + ) + }), + ); + let app = Router::new().nest("/api", api.layer(FixNestedRedirectLayer)); + + let client = TestClient::new(app); + + let res = client.get("/api/old").send().await; + assert_eq!(res.status(), StatusCode::SEE_OTHER); + assert_eq!(res.headers()["location"], "/other/non/api/route"); + } + + #[crate::test] + async fn absolute_uri() { + let api = Router::new() + .route("/old", get(|| async { Redirect::to("http://example.com") })) + .route("/old2", get(|| async { Redirect::to("//example.com") })); + let app = Router::new().nest("/api", api.layer(FixNestedRedirectLayer)); + + let client = TestClient::new(app); + + let res = client.get("/api/old").send().await; + assert_eq!(res.status(), StatusCode::SEE_OTHER); + assert_eq!(res.headers()["location"], "http://example.com"); + + let res = client.get("/api/old2").send().await; + assert_eq!(res.status(), StatusCode::SEE_OTHER); + assert_eq!(res.headers()["location"], "//example.com"); + } + + #[crate::test] + async fn using_serve_dir() { + let app = Router::new().nest_service( + "/public", + FixNestedRedirect::new(ServeDir::new(std::env::var("CARGO_MANIFEST_DIR").unwrap())), + ); + + let client = TestClient::new(app); + + let res = client.get("/public/src").send().await; + assert!(res.status().is_redirection()); + assert_eq!(res.headers()["location"], "/public/src/"); + } +} diff --git a/axum/src/middleware/mod.rs b/axum/src/middleware/mod.rs index 22dab1433e..6d70b97fa3 100644 --- a/axum/src/middleware/mod.rs +++ b/axum/src/middleware/mod.rs @@ -2,11 +2,15 @@ //! #![doc = include_str!("../docs/middleware.md")] +mod fix_nested_redirect; mod from_extractor; mod from_fn; mod map_request; mod map_response; +pub use self::fix_nested_redirect::{ + FixNestedRedirect, FixNestedRedirectLayer, FixNestedRedirectOptOut, +}; pub use self::from_extractor::{ from_extractor, from_extractor_with_state, FromExtractor, FromExtractorLayer, }; @@ -22,6 +26,7 @@ pub use crate::extension::AddExtension; pub mod future { //! Future types. + pub use super::fix_nested_redirect::ResponseFuture as FixNestedRedirectFuture; pub use super::from_extractor::ResponseFuture as FromExtractorResponseFuture; pub use super::from_fn::ResponseFuture as FromFnResponseFuture; pub use super::map_request::ResponseFuture as MapRequestResponseFuture;