From 988e0936187931822f05defd4ebedf2f908d8599 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 22 May 2022 10:43:35 +0200 Subject: [PATCH] Automatically handle `http_body::LengthLimitError` --- axum-core/CHANGELOG.md | 6 ++ axum-core/Cargo.toml | 2 +- axum-core/src/extract/rejection.rs | 73 ++++++++++++++++++++++-- axum/Cargo.toml | 2 +- axum/src/extract/content_length_limit.rs | 5 ++ axum/src/routing/tests/mod.rs | 17 ++++++ 6 files changed, 97 insertions(+), 8 deletions(-) diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index 94b260945bc..1ccd3c36106 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,8 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **added:** Automatically handle `http_body::LengthLimitError` in `FailedToBufferBody` and map + such errors to `413 Payload Too Large` ([#1048]) +- **added:** `FailedToBufferBody::is_length_limit_error` to check if the underlying error is + `http_body::LengthLimitError`. Its source error can also be downcast to + `http_body::LengthLimitError` ([#1048]) - **fixed:** Use `impl IntoResponse` less in docs ([#1049]) +[#1048]: https://github.com/tokio-rs/axum/pull/1048 [#1049]: https://github.com/tokio-rs/axum/pull/1049 # 0.2.4 (02. May, 2022) diff --git a/axum-core/Cargo.toml b/axum-core/Cargo.toml index 9ce2a6423a4..2b817302c7b 100644 --- a/axum-core/Cargo.toml +++ b/axum-core/Cargo.toml @@ -15,7 +15,7 @@ async-trait = "0.1" bytes = "1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "0.2.7" -http-body = "0.4" +http-body = "0.4.5" mime = "0.3.16" [dev-dependencies] diff --git a/axum-core/src/extract/rejection.rs b/axum-core/src/extract/rejection.rs index fc81eda7db8..068799e554e 100644 --- a/axum-core/src/extract/rejection.rs +++ b/axum-core/src/extract/rejection.rs @@ -2,6 +2,7 @@ use crate::response::{IntoResponse, Response}; use http::StatusCode; +use http_body::LengthLimitError; use std::fmt; /// Rejection type used if you try and extract the request body more than @@ -28,12 +29,72 @@ impl fmt::Display for BodyAlreadyExtracted { impl std::error::Error for BodyAlreadyExtracted {} -define_rejection! { - #[status = BAD_REQUEST] - #[body = "Failed to buffer the request body"] - /// Rejection type for extractors that buffer the request body. Used if the - /// request body cannot be buffered due to an error. - pub struct FailedToBufferBody(Error); +/// Rejection type for extractors that buffer the request body. Used if the +/// request body cannot be buffered due to an error. +// TODO: in next major for axum-core make this a #[non_exhaustive] enum so we don't need the +// additional indirection +#[derive(Debug)] +pub struct FailedToBufferBody(FailedToBufferBodyInner); + +impl FailedToBufferBody { + /// Check if the body failed to be buffered because a length limit was hit. + /// + /// This can _only_ happen when you're using [`tower_http::limit::RequestBodyLimitLayer`] or + /// otherwise wrapping request bodies in [`http_body::Limited`]. + pub fn is_length_limit_error(&self) -> bool { + matches!(self.0, FailedToBufferBodyInner::LengthLimitError(_)) + } +} + +#[derive(Debug)] +enum FailedToBufferBodyInner { + Unknown(crate::Error), + LengthLimitError(LengthLimitError), +} + +impl FailedToBufferBody { + pub(crate) fn from_err(err: E) -> Self + where + E: Into, + { + let err = err.into(); + match err.downcast::() { + Ok(err) => Self(FailedToBufferBodyInner::LengthLimitError(*err)), + Err(err) => Self(FailedToBufferBodyInner::Unknown(crate::Error::new(err))), + } + } +} + +impl crate::response::IntoResponse for FailedToBufferBody { + fn into_response(self) -> crate::response::Response { + match self.0 { + FailedToBufferBodyInner::Unknown(err) => ( + http::StatusCode::BAD_REQUEST, + format!(concat!("Failed to buffer the request body", ": {}"), err), + ) + .into_response(), + FailedToBufferBodyInner::LengthLimitError(err) => ( + StatusCode::PAYLOAD_TOO_LARGE, + format!(concat!("Failed to buffer the request body", ": {}"), err), + ) + .into_response(), + } + } +} + +impl std::fmt::Display for FailedToBufferBody { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Failed to buffer the request body") + } +} + +impl std::error::Error for FailedToBufferBody { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.0 { + FailedToBufferBodyInner::Unknown(err) => Some(err), + FailedToBufferBodyInner::LengthLimitError(err) => Some(err), + } + } } define_rejection! { diff --git a/axum/Cargo.toml b/axum/Cargo.toml index e575543d88c..b5b7347c470 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -81,7 +81,7 @@ features = [ ] [dev-dependencies.tower-http] -version = "0.3.0" +version = "0.3.4" features = ["full"] [package.metadata.docs.rs] diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index 6412324b2c5..ce2272df311 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -29,6 +29,11 @@ use std::ops::Deref; /// ``` /// /// This requires the request to have a `Content-Length` header. +/// +/// If you want to limit the size of request bodies without requiring a `Content-Length` header +/// consider using [`tower_http::limit::RequestBodyLimit`]. +/// +/// [`tower_http::limit::RequestBodyLimit`]: https://docs.rs/tower-http/latest/tower_http/limit/struct.RequestBodyLimit.html #[derive(Debug, Clone)] pub struct ContentLengthLimit(pub T); diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 768ffe0261c..d112ae7af6e 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -699,3 +699,20 @@ async fn routes_must_start_with_slash() { let app = Router::new().route(":foo", get(|| async {})); TestClient::new(app); } + +#[tokio::test] +async fn limited_body() { + const LIMIT: usize = 3; + + let app = Router::new() + .route("/", post(|_: Bytes| async {})) + .layer(tower_http::limit::RequestBodyLimitLayer::new(LIMIT)); + + let client = TestClient::new(app); + + let res = client.post("/").body("a".repeat(LIMIT)).send().await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client.post("/").body("a".repeat(LIMIT * 2)).send().await; + assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); +}