diff --git a/axum/src/boxed.rs b/axum/src/boxed.rs index ac8336715d..cf40715bd4 100644 --- a/axum/src/boxed.rs +++ b/axum/src/boxed.rs @@ -118,7 +118,7 @@ where (self.into_route)(self.router, state) } - fn call_with_state(mut self: Box, request: Request, state: S) -> RouteFuture { + fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture { self.router.call_with_state(request, state) } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index d9e84376b7..5a83be11b8 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1023,7 +1023,7 @@ where self } - pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { + pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture { macro_rules! call { ( $req:expr, @@ -1039,7 +1039,7 @@ where .strip_body($method == Method::HEAD); } MethodEndpoint::BoxedHandler(handler) => { - let mut route = handler.clone().into_route(state); + let route = handler.clone().into_route(state); return RouteFuture::from_future(route.oneshot_inner($req)) .strip_body($method == Method::HEAD); } @@ -1220,7 +1220,7 @@ where { type Future = InfallibleRouteFuture; - fn call(mut self, req: Request, state: S) -> Self::Future { + fn call(self, req: Request, state: S) -> Self::Future { InfallibleRouteFuture::new(self.call_with_state(req, state)) } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index bcc4de1fda..fdaac0eab3 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -17,6 +17,7 @@ use std::{ convert::Infallible, fmt, marker::PhantomData, + sync::Arc, task::{Context, Poll}, }; use tower_layer::Layer; @@ -59,23 +60,24 @@ pub(crate) struct RouteId(u32); /// The router type for composing handlers and services. #[must_use] pub struct Router { - path_router: PathRouter, - fallback_router: PathRouter, - default_fallback: bool, - catch_all_fallback: Fallback, + inner: Arc>, } impl Clone for Router { fn clone(&self) -> Self { Self { - path_router: self.path_router.clone(), - fallback_router: self.fallback_router.clone(), - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback.clone(), + inner: Arc::clone(&self.inner), } } } +struct RouterInner { + path_router: PathRouter, + fallback_router: PathRouter, + default_fallback: bool, + catch_all_fallback: Fallback, +} + impl Default for Router where S: Clone + Send + Sync + 'static, @@ -88,10 +90,10 @@ where impl fmt::Debug for Router { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") - .field("path_router", &self.path_router) - .field("fallback_router", &self.fallback_router) - .field("default_fallback", &self.default_fallback) - .field("catch_all_fallback", &self.catch_all_fallback) + .field("path_router", &self.inner.path_router) + .field("fallback_router", &self.inner.fallback_router) + .field("default_fallback", &self.inner.default_fallback) + .field("catch_all_fallback", &self.inner.catch_all_fallback) .finish() } } @@ -111,22 +113,57 @@ where /// all requests. pub fn new() -> Self { Self { - path_router: Default::default(), - fallback_router: PathRouter::new_fallback(), - default_fallback: true, - catch_all_fallback: Fallback::Default(Route::new(NotFound)), + inner: Arc::new(RouterInner { + path_router: Default::default(), + fallback_router: PathRouter::new_fallback(), + default_fallback: true, + catch_all_fallback: Fallback::Default(Route::new(NotFound)), + }), + } + } + + fn map_inner(self, f: F) -> Router + where + F: FnOnce(RouterInner) -> RouterInner, + { + Router { + inner: Arc::new(f(self.into_inner())), + } + } + + fn tap_inner_mut(self, f: F) -> Self + where + F: FnOnce(&mut RouterInner), + { + let mut inner = self.into_inner(); + f(&mut inner); + Router { + inner: Arc::new(inner), + } + } + + fn into_inner(self) -> RouterInner { + match Arc::try_unwrap(self.inner) { + Ok(inner) => inner, + Err(arc) => RouterInner { + path_router: arc.path_router.clone(), + fallback_router: arc.fallback_router.clone(), + default_fallback: arc.default_fallback, + catch_all_fallback: arc.catch_all_fallback.clone(), + }, } } #[doc = include_str!("../docs/routing/route.md")] #[track_caller] - pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self { - panic_on_err!(self.path_router.route(path, method_router)); - self + pub fn route(self, path: &str, method_router: MethodRouter) -> Self { + self.tap_inner_mut(|this| { + panic_on_err!(this.path_router.route(path, method_router)); + }) } #[doc = include_str!("../docs/routing/route_service.md")] - pub fn route_service(mut self, path: &str, service: T) -> Self + pub fn route_service(self, path: &str, service: T) -> Self where T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, @@ -142,14 +179,15 @@ where Err(service) => service, }; - panic_on_err!(self.path_router.route_service(path, service)); - self + self.tap_inner_mut(|this| { + panic_on_err!(this.path_router.route_service(path, service)); + }) } #[doc = include_str!("../docs/routing/nest.md")] #[track_caller] - pub fn nest(mut self, path: &str, router: Router) -> Self { - let Router { + pub fn nest(self, path: &str, router: Router) -> Self { + let RouterInner { path_router, fallback_router, default_fallback, @@ -157,76 +195,80 @@ where // requests with an empty path. If we were to inherit the catch-all fallback // it would end up matching `/{path}/*` which doesn't match empty paths. catch_all_fallback: _, - } = router; + } = router.into_inner(); - panic_on_err!(self.path_router.nest(path, path_router)); + self.tap_inner_mut(|this| { + panic_on_err!(this.path_router.nest(path, path_router)); - if !default_fallback { - panic_on_err!(self.fallback_router.nest(path, fallback_router)); - } - - self + if !default_fallback { + panic_on_err!(this.fallback_router.nest(path, fallback_router)); + } + }) } /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`. #[track_caller] - pub fn nest_service(mut self, path: &str, service: T) -> Self + pub fn nest_service(self, path: &str, service: T) -> Self where T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - panic_on_err!(self.path_router.nest_service(path, service)); - self + self.tap_inner_mut(|this| { + panic_on_err!(this.path_router.nest_service(path, service)); + }) } #[doc = include_str!("../docs/routing/merge.md")] #[track_caller] - pub fn merge(mut self, other: R) -> Self + pub fn merge(self, other: R) -> Self where R: Into>, { const PANIC_MSG: &str = "Failed to merge fallbacks. This is a bug in axum. Please file an issue"; - let Router { + let other: Router = other.into(); + let RouterInner { path_router, fallback_router: mut other_fallback, default_fallback, catch_all_fallback, - } = other.into(); - - panic_on_err!(self.path_router.merge(path_router)); - - match (self.default_fallback, default_fallback) { - // both have the default fallback - // use the one from other - (true, true) => { - self.fallback_router.merge(other_fallback).expect(PANIC_MSG); - } - // self has default fallback, other has a custom fallback - (true, false) => { - self.fallback_router.merge(other_fallback).expect(PANIC_MSG); - self.default_fallback = false; - } - // self has a custom fallback, other has a default - (false, true) => { - let fallback_router = std::mem::take(&mut self.fallback_router); - other_fallback.merge(fallback_router).expect(PANIC_MSG); - self.fallback_router = other_fallback; - } - // both have a custom fallback, not allowed - (false, false) => { - panic!("Cannot merge two `Router`s that both have a fallback") - } - }; - - self.catch_all_fallback = self - .catch_all_fallback - .merge(catch_all_fallback) - .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); - - self + } = other.into_inner(); + + self.map_inner(|mut this| { + panic_on_err!(this.path_router.merge(path_router)); + + match (this.default_fallback, default_fallback) { + // both have the default fallback + // use the one from other + (true, true) => { + this.fallback_router.merge(other_fallback).expect(PANIC_MSG); + } + // this has default fallback, other has a custom fallback + (true, false) => { + this.fallback_router.merge(other_fallback).expect(PANIC_MSG); + this.default_fallback = false; + } + // this has a custom fallback, other has a default + (false, true) => { + let fallback_router = std::mem::take(&mut this.fallback_router); + other_fallback.merge(fallback_router).expect(PANIC_MSG); + this.fallback_router = other_fallback; + } + // both have a custom fallback, not allowed + (false, false) => { + panic!("Cannot merge two `Router`s that both have a fallback") + } + }; + + this.catch_all_fallback = this + .catch_all_fallback + .merge(catch_all_fallback) + .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); + + this + }) } #[doc = include_str!("../docs/routing/layer.md")] @@ -238,12 +280,12 @@ where >::Error: Into + 'static, >::Future: Send + 'static, { - Router { - path_router: self.path_router.layer(layer.clone()), - fallback_router: self.fallback_router.layer(layer.clone()), - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)), - } + self.map_inner(|this| RouterInner { + path_router: this.path_router.layer(layer.clone()), + fallback_router: this.fallback_router.layer(layer.clone()), + default_fallback: this.default_fallback, + catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)), + }) } #[doc = include_str!("../docs/routing/route_layer.md")] @@ -256,68 +298,73 @@ where >::Error: Into + 'static, >::Future: Send + 'static, { - Router { - path_router: self.path_router.route_layer(layer), - fallback_router: self.fallback_router, - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback, - } + self.map_inner(|this| RouterInner { + path_router: this.path_router.route_layer(layer), + fallback_router: this.fallback_router, + default_fallback: this.default_fallback, + catch_all_fallback: this.catch_all_fallback, + }) } #[track_caller] #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback(mut self, handler: H) -> Self + pub fn fallback(self, handler: H) -> Self where H: Handler, T: 'static, { - self.catch_all_fallback = - Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); - self.fallback_endpoint(Endpoint::MethodRouter(any(handler))) + self.tap_inner_mut(|this| { + this.catch_all_fallback = + Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); + }) + .fallback_endpoint(Endpoint::MethodRouter(any(handler))) } /// Add a fallback [`Service`] to the router. /// /// See [`Router::fallback`] for more details. - pub fn fallback_service(mut self, service: T) -> Self + pub fn fallback_service(self, service: T) -> Self where T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { let route = Route::new(service); - self.catch_all_fallback = Fallback::Service(route.clone()); - self.fallback_endpoint(Endpoint::Route(route)) + self.tap_inner_mut(|this| { + this.catch_all_fallback = Fallback::Service(route.clone()); + }) + .fallback_endpoint(Endpoint::Route(route)) } - fn fallback_endpoint(mut self, endpoint: Endpoint) -> Self { - self.fallback_router.set_fallback(endpoint); - self.default_fallback = false; - self + fn fallback_endpoint(self, endpoint: Endpoint) -> Self { + self.tap_inner_mut(|this| { + this.fallback_router.set_fallback(endpoint); + this.default_fallback = false; + }) } #[doc = include_str!("../docs/routing/with_state.md")] pub fn with_state(self, state: S) -> Router { - Router { - path_router: self.path_router.with_state(state.clone()), - fallback_router: self.fallback_router.with_state(state.clone()), - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback.with_state(state), - } + self.map_inner(|this| RouterInner { + path_router: this.path_router.with_state(state.clone()), + fallback_router: this.fallback_router.with_state(state.clone()), + default_fallback: this.default_fallback, + catch_all_fallback: this.catch_all_fallback.with_state(state), + }) } - pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { - let (req, state) = match self.path_router.call_with_state(req, state) { + pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture { + let (req, state) = match self.inner.path_router.call_with_state(req, state) { Ok(future) => return future, Err((req, state)) => (req, state), }; - let (req, state) = match self.fallback_router.call_with_state(req, state) { + let (req, state) = match self.inner.fallback_router.call_with_state(req, state) { Ok(future) => return future, Err((req, state)) => (req, state), }; - self.catch_all_fallback.call_with_state(req, state) + self.inner.catch_all_fallback.call_with_state(req, state) } /// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type @@ -598,13 +645,13 @@ where } } - fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { + fn call_with_state(&self, req: Request, state: S) -> RouteFuture { match self { Fallback::Default(route) | Fallback::Service(route) => { RouteFuture::from_future(route.oneshot_inner(req)) } Fallback::BoxedHandler(handler) => { - let mut route = handler.clone().into_route(state); + let route = handler.clone().into_route(state); RouteFuture::from_future(route.oneshot_inner(req)) } } diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index 1f3f96e106..5b8a3b7b86 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -316,7 +316,7 @@ where } pub(super) fn call_with_state( - &mut self, + &self, mut req: Request, state: S, ) -> Result, (Request, S)> { @@ -349,7 +349,7 @@ where let endpoint = self .routes - .get_mut(&id) + .get(&id) .expect("no route for id. This is a bug in axum. Please file an issue"); match endpoint { diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index d0a7ae3bf8..e93d7f8d77 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -43,7 +43,7 @@ impl Route { } pub(crate) fn oneshot_inner( - &mut self, + &self, req: Request, ) -> Oneshot, Request> { self.0.clone().oneshot(req) diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 0621156b33..94cf7a0682 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -951,7 +951,7 @@ async fn state_isnt_cloned_too_much() { client.get("/").await; - assert_eq!(COUNT.load(Ordering::SeqCst), 5); + assert_eq!(COUNT.load(Ordering::SeqCst), 3); } #[crate::test]