Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NestedPath #1924

Merged
merged 1 commit into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Implement `Handler` for `T: IntoResponse` ([#2140])
- **added:** Implement `IntoResponse` for `(R,) where R: IntoResponse` ([#2143])
- **changed:** For SSE, add space between field and value for compatibility ([#2149])
- **added:** Add `NestedPath` extractor ([#1924])

[#2021]: https://github.com/tokio-rs/axum/pull/2021
[#2014]: https://github.com/tokio-rs/axum/pull/2014
[#2030]: https://github.com/tokio-rs/axum/pull/2030
[#1664]: https://github.com/tokio-rs/axum/pull/1664
[#1751]: https://github.com/tokio-rs/axum/pull/1751
[#1762]: https://github.com/tokio-rs/axum/pull/1762
[#1789]: https://github.com/tokio-rs/axum/pull/1789
[#1835]: https://github.com/tokio-rs/axum/pull/1835
[#1850]: https://github.com/tokio-rs/axum/pull/1850
[#1868]: https://github.com/tokio-rs/axum/pull/1868
[#1924]: https://github.com/tokio-rs/axum/pull/1924
[#1956]: https://github.com/tokio-rs/axum/pull/1956
[#1972]: https://github.com/tokio-rs/axum/pull/1972
[#2014]: https://github.com/tokio-rs/axum/pull/2014
[#2021]: https://github.com/tokio-rs/axum/pull/2021
[#2030]: https://github.com/tokio-rs/axum/pull/2030
[#2058]: https://github.com/tokio-rs/axum/pull/2058
[#2073]: https://github.com/tokio-rs/axum/pull/2073
[#2096]: https://github.com/tokio-rs/axum/pull/2096
Expand Down
2 changes: 2 additions & 0 deletions axum/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod rejection;
pub mod ws;

mod host;
pub(crate) mod nested_path;
mod raw_form;
mod raw_query;
mod request_parts;
Expand All @@ -26,6 +27,7 @@ pub use axum_macros::{FromRef, FromRequest, FromRequestParts};
#[allow(deprecated)]
pub use self::{
host::Host,
nested_path::NestedPath,
path::{Path, RawPathParams},
raw_form::RawForm,
raw_query::RawQuery,
Expand Down
265 changes: 265 additions & 0 deletions axum/src/extract/nested_path.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
use std::{
sync::Arc,
task::{Context, Poll},
};

use crate::extract::Request;
use async_trait::async_trait;
use axum_core::extract::FromRequestParts;
use http::request::Parts;
use tower_layer::{layer_fn, Layer};
use tower_service::Service;

use super::rejection::NestedPathRejection;

/// Access the path the matched the route is nested at.
///
/// This can for example be used when doing redirects.
///
/// # Example
///
/// ```
/// use axum::{
/// Router,
/// extract::NestedPath,
/// routing::get,
/// };
///
/// let api = Router::new().route(
/// "/users",
/// get(|path: NestedPath| async move {
/// // `path` will be "/api" because thats what this
/// // router is nested at when we build `app`
/// let path = path.as_str();
/// })
/// );
///
/// let app = Router::new().nest("/api", api);
/// # let _: Router = app;
/// ```
#[derive(Debug, Clone)]
pub struct NestedPath(Arc<str>);

impl NestedPath {
/// Returns a `str` representation of the path.
pub fn as_str(&self) -> &str {
&self.0
}
}

#[async_trait]
impl<S> FromRequestParts<S> for NestedPath
where
S: Send + Sync,
{
type Rejection = NestedPathRejection;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
match parts.extensions.get::<Self>() {
Some(nested_path) => Ok(nested_path.clone()),
None => Err(NestedPathRejection),
}
}
}

#[derive(Clone)]
pub(crate) struct SetNestedPath<S> {
inner: S,
path: Arc<str>,
}

impl<S> SetNestedPath<S> {
pub(crate) fn layer(path: &str) -> impl Layer<S, Service = Self> + Clone {
let path = Arc::from(path);
layer_fn(move |inner| Self {
inner,
path: Arc::clone(&path),
})
}
}

impl<S, B> Service<Request<B>> for SetNestedPath<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;

#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: Request<B>) -> Self::Future {
if let Some(prev) = req.extensions_mut().get_mut::<NestedPath>() {
let new_path = if prev.as_str() == "/" {
Arc::clone(&self.path)
} else {
format!("{}{}", prev.as_str().trim_end_matches('/'), self.path).into()
};
prev.0 = new_path;
} else {
req.extensions_mut()
.insert(NestedPath(Arc::clone(&self.path)));
};

self.inner.call(req)
}
}

#[cfg(test)]
mod tests {
use axum_core::response::Response;
use http::StatusCode;

use crate::{
extract::{NestedPath, Request},
middleware::{from_fn, Next},
routing::get,
test_helpers::*,
Router,
};

#[crate::test]
async fn one_level_of_nesting() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api");
async {}
}),
);

let app = Router::new().nest("/api", api);

let client = TestClient::new(app);

let res = client.get("/api/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn one_level_of_nesting_with_trailing_slash() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api/");
async {}
}),
);

let app = Router::new().nest("/api/", api);

let client = TestClient::new(app);

let res = client.get("/api/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn two_levels_of_nesting() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api/v2");
async {}
}),
);

let app = Router::new().nest("/api", Router::new().nest("/v2", api));

let client = TestClient::new(app);

let res = client.get("/api/v2/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn two_levels_of_nesting_with_trailing_slash() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api/v2");
async {}
}),
);

let app = Router::new().nest("/api/", Router::new().nest("/v2", api));

let client = TestClient::new(app);

let res = client.get("/api/v2/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn nested_at_root() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/");
async {}
}),
);

let app = Router::new().nest("/", api);

let client = TestClient::new(app);

let res = client.get("/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn deeply_nested_from_root() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api");
async {}
}),
);

let app = Router::new().nest("/", Router::new().nest("/api", api));

let client = TestClient::new(app);

let res = client.get("/api/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn in_fallbacks() {
let api = Router::new().fallback(get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/api");
async {}
}));

let app = Router::new().nest("/api", api);

let client = TestClient::new(app);

let res = client.get("/api/doesnt-exist").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn in_middleware() {
async fn middleware(nested_path: NestedPath, req: Request, next: Next) -> Response {
assert_eq!(nested_path.as_str(), "/api");
next.run(req).await
}

let api = Router::new()
.route("/users", get(|| async {}))
.layer(from_fn(middleware));

let app = Router::new().nest("/api", api);

let client = TestClient::new(app);

let res = client.get("/api/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
}
9 changes: 9 additions & 0 deletions axum/src/extract/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,12 @@ composite_rejection! {
MatchedPathMissing,
}
}

define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "The matched route is not nested"]
/// Rejection type for [`NestedPath`](super::NestedPath).
///
/// This rejection is used if the matched route wasn't nested.
pub struct NestedPathRejection;
}
26 changes: 19 additions & 7 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::extract::Request;
use crate::extract::{nested_path::SetNestedPath, Request};
use axum_core::response::IntoResponse;
use matchit::MatchError;
use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
Expand Down Expand Up @@ -162,10 +162,10 @@ where

pub(super) fn nest(
&mut self,
path: &str,
path_to_nest_at: &str,
router: PathRouter<S, IS_FALLBACK>,
) -> Result<(), Cow<'static, str>> {
let prefix = validate_nest_path(path);
let prefix = validate_nest_path(path_to_nest_at);

let PathRouter {
routes,
Expand All @@ -181,7 +181,11 @@ where

let path = path_for_nested_route(prefix, inner_path);

match endpoint.layer(StripPrefix::layer(prefix)) {
let layer = (
StripPrefix::layer(prefix),
SetNestedPath::layer(path_to_nest_at),
);
match endpoint.layer(layer) {
Endpoint::MethodRouter(method_router) => {
self.route(&path, method_router)?;
}
Expand All @@ -194,13 +198,17 @@ where
Ok(())
}

pub(super) fn nest_service<T>(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>>
pub(super) fn nest_service<T>(
&mut self,
path_to_nest_at: &str,
svc: T,
) -> Result<(), Cow<'static, str>>
where
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
let path = validate_nest_path(path);
let path = validate_nest_path(path_to_nest_at);
let prefix = path;

let path = if path.ends_with('/') {
Expand All @@ -209,7 +217,11 @@ where
format!("{path}/*{NEST_TAIL_PARAM}")
};

let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix)));
let layer = (
StripPrefix::layer(prefix),
SetNestedPath::layer(path_to_nest_at),
);
let endpoint = Endpoint::Route(Route::new(layer.layer(svc)));

self.route_endpoint(&path, endpoint.clone())?;

Expand Down
7 changes: 0 additions & 7 deletions axum/src/routing/strip_prefix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ pub(super) struct StripPrefix<S> {
}

impl<S> StripPrefix<S> {
pub(super) fn new(inner: S, prefix: &str) -> Self {
Self {
inner,
prefix: prefix.into(),
}
}

pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
let prefix = Arc::from(prefix);
layer_fn(move |inner| Self {
Expand Down