Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add layer that limits body size #271

Merged
merged 19 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Added

- Add `Timeout` middleware ([#270])
- Add `RequestBodyLimit` middleware ([#271])

[#270]: https://github.com/tower-rs/tower-http/pull/270
[#271]: https://github.com/tower-rs/tower-http/pull/271

## Changed

Expand Down
4 changes: 3 additions & 1 deletion tower-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ bytes = "1"
futures-core = "0.3"
futures-util = { version = "0.3.14", default_features = false, features = [] }
http = "0.2.2"
http-body = "0.4.1"
http-body = "0.4.5"
pin-project-lite = "0.2.7"
tower-layer = "0.3"
tower-service = "0.3"
Expand Down Expand Up @@ -62,6 +62,7 @@ full = [
"decompression-full",
"follow-redirect",
"fs",
"limit",
"map-request-body",
"map-response-body",
"metrics",
Expand All @@ -82,6 +83,7 @@ catch-panic = ["tracing", "futures-util/std"]
cors = []
follow-redirect = ["iri-string", "tower/util"]
fs = ["tokio/fs", "tokio-util/io", "tokio/io-util", "mime_guess", "mime", "percent-encoding", "httpdate", "set-status", "futures-util/alloc"]
limit = []
map-request-body = []
map-response-body = []
metrics = ["tokio/time"]
Expand Down
20 changes: 20 additions & 0 deletions tower-http/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,18 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
) -> ServiceBuilder<
Stack<crate::catch_panic::CatchPanicLayer<crate::catch_panic::DefaultResponseForPanic>, L>,
>;

/// Intercept requests with over-sized payloads and convert them into
/// `413 Payload Too Large` responses.
///
/// See [`tower_http::limit`] for more details.
///
/// [`tower_http::limit`]: crate::limit
#[cfg(feature = "limit")]
fn request_body_limit(
self,
limit: usize,
) -> ServiceBuilder<Stack<crate::limit::RequestBodyLimitLayer, L>>;
}

impl<L> crate::sealed::Sealed<L> for ServiceBuilder<L> {}
Expand Down Expand Up @@ -558,4 +570,12 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
> {
self.layer(crate::catch_panic::CatchPanicLayer::new())
}

#[cfg(feature = "limit")]
fn request_body_limit(
self,
limit: usize,
) -> ServiceBuilder<Stack<crate::limit::RequestBodyLimitLayer, L>> {
self.layer(crate::limit::RequestBodyLimitLayer::new(limit))
}
}
3 changes: 3 additions & 0 deletions tower-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ pub mod trace;
#[cfg(feature = "follow-redirect")]
pub mod follow_redirect;

#[cfg(feature = "limit")]
pub mod limit;

#[cfg(feature = "metrics")]
pub mod metrics;

Expand Down
107 changes: 107 additions & 0 deletions tower-http/src/limit/body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use bytes::Bytes;
use http::{HeaderMap, HeaderValue, Response, StatusCode};
use http_body::{Body, Full, SizeHint};
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// Response body for [`RequestBodyLimit`].
///
/// [`RequestBodyLimit`]: super::RequestBodyLimit
pub struct ResponseBody<B> {
#[pin]
inner: ResponseBodyInner<B>
}
}

impl<B> ResponseBody<B> {
fn payload_too_large() -> Self {
Self {
inner: ResponseBodyInner::PayloadTooLarge {
body: Full::from(BODY),
},
}
}

pub(crate) fn new(body: B) -> Self {
Self {
inner: ResponseBodyInner::Body { body },
}
}
}

pin_project! {
#[project = BodyProj]
enum ResponseBodyInner<B> {
PayloadTooLarge {
#[pin]
body: Full<Bytes>,
},
Body {
#[pin]
body: B
}
}
}

impl<B> Body for ResponseBody<B>
where
B: Body<Data = Bytes>,
{
type Data = Bytes;
type Error = B::Error;

fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
match self.project().inner.project() {
BodyProj::PayloadTooLarge { body } => body.poll_data(cx).map_err(|err| match err {}),
BodyProj::Body { body } => body.poll_data(cx),
}
}

fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
match self.project().inner.project() {
BodyProj::PayloadTooLarge { body } => {
body.poll_trailers(cx).map_err(|err| match err {})
}
BodyProj::Body { body } => body.poll_trailers(cx),
}
}

fn is_end_stream(&self) -> bool {
match &self.inner {
ResponseBodyInner::PayloadTooLarge { body } => body.is_end_stream(),
ResponseBodyInner::Body { body } => body.is_end_stream(),
}
}

fn size_hint(&self) -> SizeHint {
match &self.inner {
ResponseBodyInner::PayloadTooLarge { body } => body.size_hint(),
ResponseBodyInner::Body { body } => body.size_hint(),
}
}
}

const BODY: &[u8] = b"length limit exceeded";

pub(crate) fn create_error_response<B>() -> Response<ResponseBody<B>>
where
B: Body,
{
let mut res = Response::new(ResponseBody::payload_too_large());
*res.status_mut() = StatusCode::PAYLOAD_TOO_LARGE;

#[allow(clippy::declare_interior_mutable_const)]
const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8");
res.headers_mut()
.insert(http::header::CONTENT_TYPE, TEXT_PLAIN);

res
}
61 changes: 61 additions & 0 deletions tower-http/src/limit/future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use super::body::create_error_response;
use super::ResponseBody;
use futures_core::ready;
use http::Response;
use http_body::Body;
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// Response future for [`RequestBodyLimit`].
///
/// [`RequestBodyLimit`]: super::RequestBodyLimit
pub struct ResponseFuture<F> {
#[pin]
inner: ResponseFutureInner<F>,
}
}

impl<F> ResponseFuture<F> {
pub(crate) fn payload_too_large() -> Self {
Self {
inner: ResponseFutureInner::PayloadTooLarge,
}
}

pub(crate) fn new(future: F) -> Self {
Self {
inner: ResponseFutureInner::Future { future },
}
}
}

pin_project! {
#[project = ResFutProj]
enum ResponseFutureInner<F> {
PayloadTooLarge,
Future {
#[pin]
future: F,
}
}
}

impl<ResBody, F, E> Future for ResponseFuture<F>
where
ResBody: Body,
F: Future<Output = Result<Response<ResBody>, E>>,
{
type Output = Result<Response<ResponseBody<ResBody>>, E>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = match self.project().inner.project() {
ResFutProj::PayloadTooLarge => create_error_response(),
ResFutProj::Future { future } => ready!(future.poll(cx))?.map(ResponseBody::new),
};

Poll::Ready(Ok(res))
}
}
32 changes: 32 additions & 0 deletions tower-http/src/limit/layer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use super::RequestBodyLimit;
use tower_layer::Layer;

/// Layer that applies the [`RequestBodyLimit`] middleware that intercepts requests
/// with body lengths greater than the configured limit and converts them into
/// `413 Payload Too Large` responses.
///
/// See the [module docs](crate::limit) for an example.
///
/// [`RequestBodyLimit`]: super::RequestBodyLimit
#[derive(Clone, Copy, Debug)]
pub struct RequestBodyLimitLayer {
limit: usize,
}

impl RequestBodyLimitLayer {
/// Create a new `RequestBodyLimitLayer` with the given body length limit.
pub fn new(limit: usize) -> Self {
Self { limit }
}
}

impl<S> Layer<S> for RequestBodyLimitLayer {
type Service = RequestBodyLimit<S>;

fn layer(&self, inner: S) -> Self::Service {
RequestBodyLimit {
inner,
limit: self.limit,
}
}
}
Loading