diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index c74667b986..a40ee222de 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **breaking:** Allow `Error: Into` for `Route::{layer, route_layer}` ([#924]) - **breaking:** Remove `extractor_middleware` which was previously deprecated. Use `axum::middleware::from_extractor` instead ([#1077]) +- **breaking:** `MethodRouter` now panics on overlapping routes ([#1102]) [#1077]: https://github.com/tokio-rs/axum/pull/1077 [#1078]: https://github.com/tokio-rs/axum/pull/1078 @@ -43,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1095]: https://github.com/tokio-rs/axum/pull/1095 [#1098]: https://github.com/tokio-rs/axum/pull/1098 [#924]: https://github.com/tokio-rs/axum/pull/924 +[#1102]: https://github.com/tokio-rs/axum/pull/1102 # 0.5.7 (08. June, 2022) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 264dedaacf..4edb8bdb66 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -912,6 +912,32 @@ impl MethodRouter { S: Service, Response = Response, Error = E> + Clone + Send + 'static, S::Future: Send + 'static, { + macro_rules! set_service { + ( + $filter:ident, + $svc:ident, + $allow_header:ident, + [ + $( + ($out:ident, $variant:ident, [$($method:literal),+]) + ),+ + $(,)? + ] + ) => { + $( + if $filter.contains(MethodFilter::$variant) { + if $out.is_some() { + panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", stringify!($variant)) + } + $out = $svc.clone(); + $( + append_allow_header(&mut $allow_header, $method); + )+ + } + )+ + } + } + // written with a pattern match like this to ensure we update all fields let Self { mut get, @@ -927,39 +953,21 @@ impl MethodRouter { _request_body: _, } = self; let svc = Some(Route::new(svc)); - if filter.contains(MethodFilter::GET) { - get = svc.clone(); - append_allow_header(&mut allow_header, "GET"); - append_allow_header(&mut allow_header, "HEAD"); - } - if filter.contains(MethodFilter::HEAD) { - append_allow_header(&mut allow_header, "HEAD"); - head = svc.clone(); - } - if filter.contains(MethodFilter::DELETE) { - append_allow_header(&mut allow_header, "DELETE"); - delete = svc.clone(); - } - if filter.contains(MethodFilter::OPTIONS) { - append_allow_header(&mut allow_header, "OPTIONS"); - options = svc.clone(); - } - if filter.contains(MethodFilter::PATCH) { - append_allow_header(&mut allow_header, "PATCH"); - patch = svc.clone(); - } - if filter.contains(MethodFilter::POST) { - append_allow_header(&mut allow_header, "POST"); - post = svc.clone(); - } - if filter.contains(MethodFilter::PUT) { - append_allow_header(&mut allow_header, "PUT"); - put = svc.clone(); - } - if filter.contains(MethodFilter::TRACE) { - append_allow_header(&mut allow_header, "TRACE"); - trace = svc; - } + set_service!( + filter, + svc, + allow_header, + [ + (get, GET, ["GET", "HEAD"]), + (head, HEAD, ["HEAD"]), + (delete, DELETE, ["DELETE"]), + (options, OPTIONS, ["OPTIONS"]), + (patch, PATCH, ["PATCH"]), + (post, POST, ["POST"]), + (put, PUT, ["PUT"]), + (trace, TRACE, ["TRACE"]), + ] + ); Self { get, head, @@ -1294,6 +1302,32 @@ mod tests { assert_eq!(headers[ALLOW], "GET,POST"); } + #[tokio::test] + #[should_panic( + expected = "Overlapping method route. Cannot add two method routes that both handle `GET`" + )] + async fn handler_overlaps() { + let _: MethodRouter = get(ok).get(ok); + } + + #[tokio::test] + #[should_panic( + expected = "Overlapping method route. Cannot add two method routes that both handle `POST`" + )] + async fn service_overlaps() { + let _: MethodRouter = post_service(ok.into_service()).post_service(ok.into_service()); + } + + #[tokio::test] + async fn get_head_does_not_overlap() { + let _: MethodRouter = get(ok).head(ok); + } + + #[tokio::test] + async fn head_get_does_not_overlap() { + let _: MethodRouter = head(ok).get(ok); + } + async fn call(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String) where S: Service, Response = Response, Error = Infallible>,