Skip to content

Commit

Permalink
Add NestedPath (#1924)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored Sep 17, 2023
1 parent 449e4c1 commit 20f48af
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 17 deletions.
8 changes: 5 additions & 3 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Implement `Handler` for `T: IntoResponse` ([#2140])
- **added:** Implement `IntoResponse` for `(R,) where R: IntoResponse` ([#2143])
- **changed:** For SSE, add space between field and value for compatibility ([#2149])
- **added:** Add `NestedPath` extractor ([#1924])

[#2021]: https://github.com/tokio-rs/axum/pull/2021
[#2014]: https://github.com/tokio-rs/axum/pull/2014
[#2030]: https://github.com/tokio-rs/axum/pull/2030
[#1664]: https://github.com/tokio-rs/axum/pull/1664
[#1751]: https://github.com/tokio-rs/axum/pull/1751
[#1762]: https://github.com/tokio-rs/axum/pull/1762
[#1789]: https://github.com/tokio-rs/axum/pull/1789
[#1835]: https://github.com/tokio-rs/axum/pull/1835
[#1850]: https://github.com/tokio-rs/axum/pull/1850
[#1868]: https://github.com/tokio-rs/axum/pull/1868
[#1924]: https://github.com/tokio-rs/axum/pull/1924
[#1956]: https://github.com/tokio-rs/axum/pull/1956
[#1972]: https://github.com/tokio-rs/axum/pull/1972
[#2014]: https://github.com/tokio-rs/axum/pull/2014
[#2021]: https://github.com/tokio-rs/axum/pull/2021
[#2030]: https://github.com/tokio-rs/axum/pull/2030
[#2058]: https://github.com/tokio-rs/axum/pull/2058
[#2073]: https://github.com/tokio-rs/axum/pull/2073
[#2096]: https://github.com/tokio-rs/axum/pull/2096
Expand Down
2 changes: 2 additions & 0 deletions axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod rejection;
pub mod ws;

mod host;
pub(crate) mod nested_path;
mod raw_form;
mod raw_query;
mod request_parts;
Expand All @@ -26,6 +27,7 @@ pub use axum_macros::{FromRef, FromRequest, FromRequestParts};
#[allow(deprecated)]
pub use self::{
host::Host,
nested_path::NestedPath,
path::{Path, RawPathParams},
raw_form::RawForm,
raw_query::RawQuery,
Expand Down
265 changes: 265 additions & 0 deletions axum/src/extract/nested_path.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
use std::{
sync::Arc,
task::{Context, Poll},
};

use crate::extract::Request;
use async_trait::async_trait;
use axum_core::extract::FromRequestParts;
use http::request::Parts;
use tower_layer::{layer_fn, Layer};
use tower_service::Service;

use super::rejection::NestedPathRejection;

/// Access the path the matched the route is nested at.
///
/// This can for example be used when doing redirects.
///
/// # Example
///
/// ```
/// use axum::{
/// Router,
/// extract::NestedPath,
/// routing::get,
/// };
///
/// let api = Router::new().route(
/// "/users",
/// get(|path: NestedPath| async move {
/// // `path` will be "/api" because thats what this
/// // router is nested at when we build `app`
/// let path = path.as_str();
/// })
/// );
///
/// let app = Router::new().nest("/api", api);
/// # let _: Router = app;
/// ```
#[derive(Debug, Clone)]
pub struct NestedPath(Arc<str>);

impl NestedPath {
/// Returns a `str` representation of the path.
pub fn as_str(&self) -> &str {
&self.0
}
}

#[async_trait]
impl<S> FromRequestParts<S> for NestedPath
where
S: Send + Sync,
{
type Rejection = NestedPathRejection;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
match parts.extensions.get::<Self>() {
Some(nested_path) => Ok(nested_path.clone()),
None => Err(NestedPathRejection),
}
}
}

#[derive(Clone)]
pub(crate) struct SetNestedPath<S> {
inner: S,
path: Arc<str>,
}

impl<S> SetNestedPath<S> {
pub(crate) fn layer(path: &str) -> impl Layer<S, Service = Self> + Clone {
let path = Arc::from(path);
layer_fn(move |inner| Self {
inner,
path: Arc::clone(&path),
})
}
}

impl<S, B> Service<Request<B>> for SetNestedPath<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: Request<B>) -> Self::Future {
if let Some(prev) = req.extensions_mut().get_mut::<NestedPath>() {
let new_path = if prev.as_str() == "/" {
Arc::clone(&self.path)
} else {
format!("{}{}", prev.as_str().trim_end_matches('/'), self.path).into()
};
prev.0 = new_path;
} else {
req.extensions_mut()
.insert(NestedPath(Arc::clone(&self.path)));
};

self.inner.call(req)
}
}

#[cfg(test)]
mod tests {
use axum_core::response::Response;
use http::StatusCode;

use crate::{
extract::{NestedPath, Request},
middleware::{from_fn, Next},
routing::get,
test_helpers::*,
Router,
};

#[crate::test]
async fn one_level_of_nesting() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api");
async {}
}),
);

let app = Router::new().nest("/api", api);

let client = TestClient::new(app);

let res = client.get("/api/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn one_level_of_nesting_with_trailing_slash() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api/");
async {}
}),
);

let app = Router::new().nest("/api/", api);

let client = TestClient::new(app);

let res = client.get("/api/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn two_levels_of_nesting() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api/v2");
async {}
}),
);

let app = Router::new().nest("/api", Router::new().nest("/v2", api));

let client = TestClient::new(app);

let res = client.get("/api/v2/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn two_levels_of_nesting_with_trailing_slash() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api/v2");
async {}
}),
);

let app = Router::new().nest("/api/", Router::new().nest("/v2", api));

let client = TestClient::new(app);

let res = client.get("/api/v2/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn nested_at_root() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/");
async {}
}),
);

let app = Router::new().nest("/", api);

let client = TestClient::new(app);

let res = client.get("/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn deeply_nested_from_root() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api");
async {}
}),
);

let app = Router::new().nest("/", Router::new().nest("/api", api));

let client = TestClient::new(app);

let res = client.get("/api/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn in_fallbacks() {
let api = Router::new().fallback(get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api");
async {}
}));

let app = Router::new().nest("/api", api);

let client = TestClient::new(app);

let res = client.get("/api/doesnt-exist").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn in_middleware() {
async fn middleware(nested_path: NestedPath, req: Request, next: Next) -> Response {
assert_eq!(nested_path.as_str(), "/api");
next.run(req).await
}

let api = Router::new()
.route("/users", get(|| async {}))
.layer(from_fn(middleware));

let app = Router::new().nest("/api", api);

let client = TestClient::new(app);

let res = client.get("/api/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
}
9 changes: 9 additions & 0 deletions axum/src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,12 @@ composite_rejection! {
MatchedPathMissing,
}
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "The matched route is not nested"]
/// Rejection type for [`NestedPath`](super::NestedPath).
///
/// This rejection is used if the matched route wasn't nested.
pub struct NestedPathRejection;
}
26 changes: 19 additions & 7 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::extract::Request;
use crate::extract::{nested_path::SetNestedPath, Request};
use axum_core::response::IntoResponse;
use matchit::MatchError;
use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
Expand Down Expand Up @@ -162,10 +162,10 @@ where

pub(super) fn nest(
&mut self,
path: &str,
path_to_nest_at: &str,
router: PathRouter<S, IS_FALLBACK>,
) -> Result<(), Cow<'static, str>> {
let prefix = validate_nest_path(path);
let prefix = validate_nest_path(path_to_nest_at);

let PathRouter {
routes,
Expand All @@ -181,7 +181,11 @@ where

let path = path_for_nested_route(prefix, inner_path);

match endpoint.layer(StripPrefix::layer(prefix)) {
let layer = (
StripPrefix::layer(prefix),
SetNestedPath::layer(path_to_nest_at),
);
match endpoint.layer(layer) {
Endpoint::MethodRouter(method_router) => {
self.route(&path, method_router)?;
}
Expand All @@ -194,13 +198,17 @@ where
Ok(())
}

pub(super) fn nest_service<T>(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>>
pub(super) fn nest_service<T>(
&mut self,
path_to_nest_at: &str,
svc: T,
) -> Result<(), Cow<'static, str>>
where
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
let path = validate_nest_path(path);
let path = validate_nest_path(path_to_nest_at);
let prefix = path;

let path = if path.ends_with('/') {
Expand All @@ -209,7 +217,11 @@ where
format!("{path}/*{NEST_TAIL_PARAM}")
};

let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix)));
let layer = (
StripPrefix::layer(prefix),
SetNestedPath::layer(path_to_nest_at),
);
let endpoint = Endpoint::Route(Route::new(layer.layer(svc)));

self.route_endpoint(&path, endpoint.clone())?;

Expand Down
7 changes: 0 additions & 7 deletions axum/src/routing/strip_prefix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ pub(super) struct StripPrefix<S> {
}

impl<S> StripPrefix<S> {
pub(super) fn new(inner: S, prefix: &str) -> Self {
Self {
inner,
prefix: prefix.into(),
}
}

pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
let prefix = Arc::from(prefix);
layer_fn(move |inner| Self {
Expand Down

0 comments on commit 20f48af

Please sign in to comment.