Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove trailing slash redirects #1119

Merged
merged 18 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning].

# Unreleased

- None.
- **added:** Add `RouterExt::route_with_tsr` for adding routes with an
additional "trailing slash redirect" route ([#1119])

[#1119]: https://github.com/tokio-rs/axum/pull/1119

# 0.3.5 (27. June, 2022)

Expand Down
89 changes: 88 additions & 1 deletion axum-extra/src/routing/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
//! Additional types for defining routes.

use axum::{handler::Handler, Router};
use axum::{
handler::Handler,
http::Request,
response::{Redirect, Response},
Router,
};
use std::{convert::Infallible, future::ready};
use tower_service::Service;

mod resource;

Expand Down Expand Up @@ -126,6 +133,33 @@ pub trait RouterExt<B>: sealed::Sealed {
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;

/// Add another route to the router with an additional "trailing slash redirect" route.
///
/// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a
/// route for `/foo/` that redirects to `/foo`.
///
/// If you add a route _with_ a trailing slash, such as `/bar/`, this method will also add a
/// route for `/bar` that redirects to `/bar/`.
///
/// # Example
///
/// ```
/// use axum::{Router, routing::get};
/// use axum_extra::routing::RouterExt;
///
/// let app = Router::new()
/// // `/foo/` will rediret to `/foo`
/// .route_with_tsr("/foo", get(|| async {}))
/// // `/bar` will rediret to `/bar/`
/// .route_with_tsr("/bar/", get(|| async {}));
/// # let _: Router = app;
/// ```
fn route_with_tsr<T>(self, path: &str, service: T) -> Self
where
T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static,
T::Future: Send + 'static,
Self: Sized;
}

impl<B> RouterExt<B> for Router<B>
Expand Down Expand Up @@ -211,9 +245,62 @@ where
{
self.route(P::PATH, axum::routing::trace(handler))
}

fn route_with_tsr<T>(mut self, path: &str, service: T) -> Self
where
T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static,
T::Future: Send + 'static,
Self: Sized,
{
self = self.route(path, service);

let rediret = Redirect::permanent(path);

if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route(
path_without_trailing_slash,
(move || ready(rediret.clone())).into_service(),
)
} else {
self.route(
&format!("{}/", path),
(move || ready(rediret.clone())).into_service(),
)
}
}
}

mod sealed {
pub trait Sealed {}
impl<B> Sealed for axum::Router<B> {}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{http::StatusCode, routing::get};

#[tokio::test]
async fn test_tsr() {
let app = Router::new()
.route_with_tsr("/foo", get(|| async {}))
.route_with_tsr("/bar/", get(|| async {}));

let client = TestClient::new(app);

let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);

let res = client.get("/foo/").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/foo");

let res = client.get("/bar/").send().await;
assert_eq!(res.status(), StatusCode::OK);

let res = client.get("/bar").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/bar/");
}
}