diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 8fc9a1d70f..6b89611422 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -18,7 +18,7 @@ use std::{ task::{Context, Poll}, }; use tower::{ - util::{BoxCloneService, MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot}, + util::{BoxCloneService, MapErrLayer, MapResponseLayer, Oneshot}, ServiceExt, }; use tower_layer::Layer; @@ -73,7 +73,6 @@ impl Route { NewError: 'static, { let layer = ( - MapRequestLayer::new(|req: Request<_>| req.map(Body::new)), MapErrLayer::new(Into::into), MapResponseLayer::new(IntoResponse::into_response), layer, @@ -113,7 +112,7 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { let req = req.map(Body::new); - RouteFuture::from_future(self.oneshot_inner(req)) + RouteFuture::from_future(self.oneshot_inner(req)).not_top_level() } } @@ -124,6 +123,7 @@ pin_project! { kind: RouteFutureKind, strip_body: bool, allow_header: Option, + top_level: bool, } } @@ -151,6 +151,7 @@ impl RouteFuture { kind: RouteFutureKind::Future { future }, strip_body: false, allow_header: None, + top_level: true, } } @@ -163,6 +164,11 @@ impl RouteFuture { self.allow_header = Some(allow_header); self } + + pub(crate) fn not_top_level(mut self) -> Self { + self.top_level = false; + self + } } impl Future for RouteFuture { @@ -183,16 +189,16 @@ impl Future for RouteFuture { } }; - set_allow_header(res.headers_mut(), this.allow_header); + if *this.top_level { + set_allow_header(res.headers_mut(), this.allow_header); - // make sure to set content-length before removing the body - set_content_length(res.size_hint(), res.headers_mut()); + // make sure to set content-length before removing the body + set_content_length(res.size_hint(), res.headers_mut()); - let res = if *this.strip_body { - res.map(|_| Body::empty()) - } else { - res - }; + if *this.strip_body { + *res.body_mut() = Body::empty(); + } + } Poll::Ready(Ok(res)) } diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index c7ae1f7040..7fbffbb8a3 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1087,3 +1087,22 @@ async fn locks_mutex_very_little() { assert_eq!(num, 1); } } + +#[crate::test] +async fn middleware_adding_body() { + let app = Router::new() + .route("/", get(())) + .layer(MapResponseLayer::new(|mut res: Response| -> Response { + // If there is a content-length header, its value will be zero and axum will avoid + // overwriting it. But this means our content-length doesn't match the length of the + // body, which leads to panics in Hyper. Thus we have to ensure that axum doesn't add + // on content-length headers until after middleware has been run. + assert!(!res.headers().contains_key("content-length")); + *res.body_mut() = "…".into(); + res + })); + + let client = TestClient::new(app); + let res = client.get("/").await; + assert_eq!(res.text().await, "…"); +}