Skip to content

Commit

Permalink
Add method_not_allowed_fallback to router (#2903)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
  • Loading branch information
Lachstec and jplatte committed Nov 14, 2024
1 parent a59a82c commit c5a3c66
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 1 deletion.
38 changes: 38 additions & 0 deletions axum/src/docs/routing/method_not_allowed_fallback.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
Add a fallback [`Handler`] for the case where a route exists, but the method of the request is not supported.

Sets a fallback on all previously registered [`MethodRouter`]s,
to be called when no matching method handler is set.

```rust,no_run
use axum::{response::IntoResponse, routing::get, Router};
async fn hello_world() -> impl IntoResponse {
"Hello, world!\n"
}
async fn default_fallback() -> impl IntoResponse {
"Default fallback\n"
}
async fn handle_405() -> impl IntoResponse {
"Method not allowed fallback"
}
#[tokio::main]
async fn main() {
let router = Router::new()
.route("/", get(hello_world))
.fallback(default_fallback)
.method_not_allowed_fallback(handle_405);
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();
}
```

The fallback only applies if there is a `MethodRouter` registered for a given path,
but the method used in the request is not specified. In the example, a `GET` on
`http://localhost:3000` causes the `hello_world` handler to react, while issuing a
`POST` triggers `handle_405`. Calling an entirely different route, like `http://localhost:3000/hello`
causes `default_fallback` to run.
13 changes: 13 additions & 0 deletions axum/src/routing/method_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,19 @@ where
self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
self
}

/// Add a fallback [`Handler`] if no custom one has been provided.
pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
S: Send + Sync + 'static,
{
match self.fallback {
Fallback::Default(_) => self.fallback(handler),
_ => self,
}
}
}

impl MethodRouter<(), Infallible> {
Expand Down
12 changes: 12 additions & 0 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,18 @@ where
.fallback_endpoint(Endpoint::Route(route))
}

#[doc = include_str!("../docs/routing/method_not_allowed_fallback.md")]
pub fn method_not_allowed_fallback<H, T>(self, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
{
tap_inner!(self, mut this => {
this.path_router
.method_not_allowed_fallback(handler.clone())
})
}

fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
tap_inner!(self, mut this => {
this.fallback_router.set_fallback(endpoint);
Expand Down
17 changes: 16 additions & 1 deletion axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::extract::{nested_path::SetNestedPath, Request};
use crate::{
extract::{nested_path::SetNestedPath, Request},
handler::Handler,
};
use axum_core::response::IntoResponse;
use matchit::MatchError;
use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
Expand Down Expand Up @@ -79,6 +82,18 @@ where
Ok(())
}

pub(super) fn method_not_allowed_fallback<H, T>(&mut self, handler: H)
where
H: Handler<T, S>,
T: 'static,
{
for (_, endpoint) in self.routes.iter_mut() {
if let Endpoint::MethodRouter(rt) = endpoint {
*rt = rt.clone().default_fallback(handler.clone());
}
}
}

pub(super) fn route_service<T>(
&mut self,
path: &str,
Expand Down
45 changes: 45 additions & 0 deletions axum/src/routing/tests/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,51 @@ async fn merge_router_with_fallback_into_empty() {
assert_eq!(res.text().await, "outer");
}

#[crate::test]
async fn mna_fallback_with_existing_fallback() {
let app = Router::new()
.route(
"/",
get(|| async { "test" }).fallback(|| async { "index fallback" }),
)
.route("/path", get(|| async { "path" }))
.method_not_allowed_fallback(|| async { "method not allowed fallback" });

let client = TestClient::new(app);
let index_fallback = client.post("/").await;
let method_not_allowed_fallback = client.post("/path").await;

assert_eq!(index_fallback.text().await, "index fallback");
assert_eq!(
method_not_allowed_fallback.text().await,
"method not allowed fallback"
);
}

#[crate::test]
async fn mna_fallback_with_state() {
let app = Router::new()
.route("/", get(|| async { "index" }))
.method_not_allowed_fallback(|State(state): State<&'static str>| async move { state })
.with_state("state");

let client = TestClient::new(app);
let res = client.post("/").await;
assert_eq!(res.text().await, "state");
}

#[crate::test]
async fn mna_fallback_with_unused_state() {
let app = Router::new()
.route("/", get(|| async { "index" }))
.with_state(())
.method_not_allowed_fallback(|| async move { "bla" });

let client = TestClient::new(app);
let res = client.post("/").await;
assert_eq!(res.text().await, "bla");
}

#[crate::test]
async fn state_isnt_cloned_too_much_with_fallback() {
let state = CountingCloneableState::new();
Expand Down

0 comments on commit c5a3c66

Please sign in to comment.