diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 4cb60f6717..57474c5303 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -65,10 +65,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Implement `Handler` for `T: IntoResponse` ([#2140]) - **added:** Implement `IntoResponse` for `(R,) where R: IntoResponse` ([#2143]) - **changed:** For SSE, add space between field and value for compatibility ([#2149]) +- **added:** Add `NestedPath` extractor ([#1924]) -[#2021]: https://github.com/tokio-rs/axum/pull/2021 -[#2014]: https://github.com/tokio-rs/axum/pull/2014 -[#2030]: https://github.com/tokio-rs/axum/pull/2030 [#1664]: https://github.com/tokio-rs/axum/pull/1664 [#1751]: https://github.com/tokio-rs/axum/pull/1751 [#1762]: https://github.com/tokio-rs/axum/pull/1762 @@ -76,8 +74,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1835]: https://github.com/tokio-rs/axum/pull/1835 [#1850]: https://github.com/tokio-rs/axum/pull/1850 [#1868]: https://github.com/tokio-rs/axum/pull/1868 +[#1924]: https://github.com/tokio-rs/axum/pull/1924 [#1956]: https://github.com/tokio-rs/axum/pull/1956 [#1972]: https://github.com/tokio-rs/axum/pull/1972 +[#2014]: https://github.com/tokio-rs/axum/pull/2014 +[#2021]: https://github.com/tokio-rs/axum/pull/2021 +[#2030]: https://github.com/tokio-rs/axum/pull/2030 [#2058]: https://github.com/tokio-rs/axum/pull/2058 [#2073]: https://github.com/tokio-rs/axum/pull/2073 [#2096]: https://github.com/tokio-rs/axum/pull/2096 diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index fce1b01081..719083d11f 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -11,6 +11,7 @@ pub mod rejection; pub mod ws; mod host; +pub(crate) mod nested_path; mod raw_form; mod raw_query; mod request_parts; @@ -26,6 +27,7 @@ pub use axum_macros::{FromRef, FromRequest, FromRequestParts}; #[allow(deprecated)] pub use self::{ host::Host, + nested_path::NestedPath, path::{Path, RawPathParams}, raw_form::RawForm, raw_query::RawQuery, diff --git a/axum/src/extract/nested_path.rs b/axum/src/extract/nested_path.rs new file mode 100644 index 0000000000..f31fe3faba --- /dev/null +++ b/axum/src/extract/nested_path.rs @@ -0,0 +1,265 @@ +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use crate::extract::Request; +use async_trait::async_trait; +use axum_core::extract::FromRequestParts; +use http::request::Parts; +use tower_layer::{layer_fn, Layer}; +use tower_service::Service; + +use super::rejection::NestedPathRejection; + +/// Access the path the matched the route is nested at. +/// +/// This can for example be used when doing redirects. +/// +/// # Example +/// +/// ``` +/// use axum::{ +/// Router, +/// extract::NestedPath, +/// routing::get, +/// }; +/// +/// let api = Router::new().route( +/// "/users", +/// get(|path: NestedPath| async move { +/// // `path` will be "/api" because thats what this +/// // router is nested at when we build `app` +/// let path = path.as_str(); +/// }) +/// ); +/// +/// let app = Router::new().nest("/api", api); +/// # let _: Router = app; +/// ``` +#[derive(Debug, Clone)] +pub struct NestedPath(Arc); + +impl NestedPath { + /// Returns a `str` representation of the path. + pub fn as_str(&self) -> &str { + &self.0 + } +} + +#[async_trait] +impl FromRequestParts for NestedPath +where + S: Send + Sync, +{ + type Rejection = NestedPathRejection; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + match parts.extensions.get::() { + Some(nested_path) => Ok(nested_path.clone()), + None => Err(NestedPathRejection), + } + } +} + +#[derive(Clone)] +pub(crate) struct SetNestedPath { + inner: S, + path: Arc, +} + +impl SetNestedPath { + pub(crate) fn layer(path: &str) -> impl Layer + Clone { + let path = Arc::from(path); + layer_fn(move |inner| Self { + inner, + path: Arc::clone(&path), + }) + } +} + +impl Service> for SetNestedPath +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + if let Some(prev) = req.extensions_mut().get_mut::() { + let new_path = if prev.as_str() == "/" { + Arc::clone(&self.path) + } else { + format!("{}{}", prev.as_str().trim_end_matches('/'), self.path).into() + }; + prev.0 = new_path; + } else { + req.extensions_mut() + .insert(NestedPath(Arc::clone(&self.path))); + }; + + self.inner.call(req) + } +} + +#[cfg(test)] +mod tests { + use axum_core::response::Response; + use http::StatusCode; + + use crate::{ + extract::{NestedPath, Request}, + middleware::{from_fn, Next}, + routing::get, + test_helpers::*, + Router, + }; + + #[crate::test] + async fn one_level_of_nesting() { + let api = Router::new().route( + "/users", + get(|nested_path: NestedPath| { + assert_eq!(nested_path.as_str(), "/api"); + async {} + }), + ); + + let app = Router::new().nest("/api", api); + + let client = TestClient::new(app); + + let res = client.get("/api/users").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[crate::test] + async fn one_level_of_nesting_with_trailing_slash() { + let api = Router::new().route( + "/users", + get(|nested_path: NestedPath| { + assert_eq!(nested_path.as_str(), "/api/"); + async {} + }), + ); + + let app = Router::new().nest("/api/", api); + + let client = TestClient::new(app); + + let res = client.get("/api/users").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[crate::test] + async fn two_levels_of_nesting() { + let api = Router::new().route( + "/users", + get(|nested_path: NestedPath| { + assert_eq!(nested_path.as_str(), "/api/v2"); + async {} + }), + ); + + let app = Router::new().nest("/api", Router::new().nest("/v2", api)); + + let client = TestClient::new(app); + + let res = client.get("/api/v2/users").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[crate::test] + async fn two_levels_of_nesting_with_trailing_slash() { + let api = Router::new().route( + "/users", + get(|nested_path: NestedPath| { + assert_eq!(nested_path.as_str(), "/api/v2"); + async {} + }), + ); + + let app = Router::new().nest("/api/", Router::new().nest("/v2", api)); + + let client = TestClient::new(app); + + let res = client.get("/api/v2/users").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[crate::test] + async fn nested_at_root() { + let api = Router::new().route( + "/users", + get(|nested_path: NestedPath| { + assert_eq!(nested_path.as_str(), "/"); + async {} + }), + ); + + let app = Router::new().nest("/", api); + + let client = TestClient::new(app); + + let res = client.get("/users").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[crate::test] + async fn deeply_nested_from_root() { + let api = Router::new().route( + "/users", + get(|nested_path: NestedPath| { + assert_eq!(nested_path.as_str(), "/api"); + async {} + }), + ); + + let app = Router::new().nest("/", Router::new().nest("/api", api)); + + let client = TestClient::new(app); + + let res = client.get("/api/users").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[crate::test] + async fn in_fallbacks() { + let api = Router::new().fallback(get(|nested_path: NestedPath| { + assert_eq!(nested_path.as_str(), "/api"); + async {} + })); + + let app = Router::new().nest("/api", api); + + let client = TestClient::new(app); + + let res = client.get("/api/doesnt-exist").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[crate::test] + async fn in_middleware() { + async fn middleware(nested_path: NestedPath, req: Request, next: Next) -> Response { + assert_eq!(nested_path.as_str(), "/api"); + next.run(req).await + } + + let api = Router::new() + .route("/users", get(|| async {})) + .layer(from_fn(middleware)); + + let app = Router::new().nest("/api", api); + + let client = TestClient::new(app); + + let res = client.get("/api/users").send().await; + assert_eq!(res.status(), StatusCode::OK); + } +} diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index 2ef75f9761..cba76af054 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -207,3 +207,12 @@ composite_rejection! { MatchedPathMissing, } } + +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "The matched route is not nested"] + /// Rejection type for [`NestedPath`](super::NestedPath). + /// + /// This rejection is used if the matched route wasn't nested. + pub struct NestedPathRejection; +} diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index bd6d91ad7d..b4ef4cb412 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -1,4 +1,4 @@ -use crate::extract::Request; +use crate::extract::{nested_path::SetNestedPath, Request}; use axum_core::response::IntoResponse; use matchit::MatchError; use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc}; @@ -162,10 +162,10 @@ where pub(super) fn nest( &mut self, - path: &str, + path_to_nest_at: &str, router: PathRouter, ) -> Result<(), Cow<'static, str>> { - let prefix = validate_nest_path(path); + let prefix = validate_nest_path(path_to_nest_at); let PathRouter { routes, @@ -181,7 +181,11 @@ where let path = path_for_nested_route(prefix, inner_path); - match endpoint.layer(StripPrefix::layer(prefix)) { + let layer = ( + StripPrefix::layer(prefix), + SetNestedPath::layer(path_to_nest_at), + ); + match endpoint.layer(layer) { Endpoint::MethodRouter(method_router) => { self.route(&path, method_router)?; } @@ -194,13 +198,17 @@ where Ok(()) } - pub(super) fn nest_service(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>> + pub(super) fn nest_service( + &mut self, + path_to_nest_at: &str, + svc: T, + ) -> Result<(), Cow<'static, str>> where T: Service + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - let path = validate_nest_path(path); + let path = validate_nest_path(path_to_nest_at); let prefix = path; let path = if path.ends_with('/') { @@ -209,7 +217,11 @@ where format!("{path}/*{NEST_TAIL_PARAM}") }; - let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix))); + let layer = ( + StripPrefix::layer(prefix), + SetNestedPath::layer(path_to_nest_at), + ); + let endpoint = Endpoint::Route(Route::new(layer.layer(svc))); self.route_endpoint(&path, endpoint.clone())?; diff --git a/axum/src/routing/strip_prefix.rs b/axum/src/routing/strip_prefix.rs index 671c4de773..0b06db4d28 100644 --- a/axum/src/routing/strip_prefix.rs +++ b/axum/src/routing/strip_prefix.rs @@ -14,13 +14,6 @@ pub(super) struct StripPrefix { } impl StripPrefix { - pub(super) fn new(inner: S, prefix: &str) -> Self { - Self { - inner, - prefix: prefix.into(), - } - } - pub(super) fn layer(prefix: &str) -> impl Layer + Clone { let prefix = Arc::from(prefix); layer_fn(move |inner| Self {