From 892daba8f51a271898752912a1e301cfed3711d2 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Sat, 14 Nov 2020 17:58:48 -0800 Subject: [PATCH] do not let Body read beyond its length --- src/body.rs | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 94 insertions(+), 4 deletions(-) diff --git a/src/body.rs b/src/body.rs index 4a2c42e6..c0a324e1 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,4 +1,4 @@ -use futures_lite::{io, prelude::*}; +use futures_lite::{io, prelude::*, ready}; use serde::{de::DeserializeOwned, Serialize}; use std::fmt::{self, Debug}; @@ -56,6 +56,7 @@ pin_project_lite::pin_project! { reader: Box, mime: Mime, length: Option, + bytes_read: usize } } @@ -78,6 +79,7 @@ impl Body { reader: Box::new(io::empty()), mime: mime::BYTE_STREAM, length: Some(0), + bytes_read: 0, } } @@ -108,6 +110,7 @@ impl Body { reader: Box::new(reader), mime: mime::BYTE_STREAM, length: len, + bytes_read: 0, } } @@ -151,6 +154,7 @@ impl Body { mime: mime::BYTE_STREAM, length: Some(bytes.len()), reader: Box::new(io::Cursor::new(bytes)), + bytes_read: 0, } } @@ -200,6 +204,7 @@ impl Body { mime: mime::PLAIN, length: Some(s.len()), reader: Box::new(io::Cursor::new(s.into_bytes())), + bytes_read: 0, } } @@ -245,6 +250,7 @@ impl Body { length: Some(bytes.len()), reader: Box::new(io::Cursor::new(bytes)), mime: mime::JSON, + bytes_read: 0, }; Ok(body) } @@ -309,6 +315,7 @@ impl Body { length: Some(bytes.len()), reader: Box::new(io::Cursor::new(bytes)), mime: mime::FORM, + bytes_read: 0, }; Ok(body) } @@ -377,6 +384,7 @@ impl Body { mime, length: Some(len as usize), reader: Box::new(io::BufReader::new(file)), + bytes_read: 0, }) } @@ -418,6 +426,7 @@ impl Debug for Body { f.debug_struct("Body") .field("reader", &"") .field("length", &self.length) + .field("bytes_read", &self.bytes_read) .finish() } } @@ -459,15 +468,25 @@ impl AsyncRead for Body { cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - Pin::new(&mut self.reader).poll_read(cx, buf) + let mut buf = match self.length { + None => buf, + Some(length) if length == self.bytes_read => return Poll::Ready(Ok(0)), + Some(length) => { + let max_len = (length - self.bytes_read).min(buf.len()); + &mut buf[0..max_len] + } + }; + + let bytes = ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf))?; + self.bytes_read += bytes; + Poll::Ready(Ok(bytes)) } } impl AsyncBufRead for Body { #[allow(missing_doc_code_examples)] fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - this.reader.poll_fill_buf(cx) + self.project().reader.poll_fill_buf(cx) } fn consume(mut self: Pin<&mut Self>, amt: usize) { @@ -500,6 +519,7 @@ fn guess_ext(path: &std::path::Path) -> Option { #[cfg(test)] mod test { use super::*; + use async_std::io::Cursor; use serde::Deserialize; #[async_std::test] @@ -523,4 +543,74 @@ mod test { let res = body.into_form::().await; assert_eq!(res.unwrap_err().status(), 422); } + + async fn read_with_buffers_of_size(reader: &mut R, size: usize) -> crate::Result + where + R: AsyncRead + Unpin, + { + let mut return_buffer = vec![]; + loop { + let mut buf = vec![0; size]; + match reader.read(&mut buf).await? { + 0 => break Ok(String::from_utf8(return_buffer)?), + bytes_read => return_buffer.extend_from_slice(&buf[..bytes_read]), + } + } + } + + #[async_std::test] + async fn attempting_to_read_past_length_with_shorter_buffer() -> crate::Result<()> { + for buf_len in 1..13 { + let mut body = Body::from_reader(Cursor::new("hello world"), Some(5)); + assert_eq!( + read_with_buffers_of_size(&mut body, buf_len).await?, + "hello" + ); + assert_eq!(body.bytes_read, 5); + } + + Ok(()) + } + + #[async_std::test] + async fn attempting_to_read_when_length_is_greater_than_content() -> crate::Result<()> { + for buf_len in 1..13 { + let mut body = Body::from_reader(Cursor::new("hello world"), Some(15)); + assert_eq!( + read_with_buffers_of_size(&mut body, buf_len).await?, + "hello world" + ); + assert_eq!(body.bytes_read, 11); + } + + Ok(()) + } + + #[async_std::test] + async fn attempting_to_read_when_length_is_exactly_right() -> crate::Result<()> { + for buf_len in 1..13 { + let mut body = Body::from_reader(Cursor::new("hello world"), Some(11)); + assert_eq!( + read_with_buffers_of_size(&mut body, buf_len).await?, + "hello world" + ); + assert_eq!(body.bytes_read, 11); + } + + Ok(()) + } + + #[async_std::test] + async fn reading_in_various_buffer_lengths_when_there_is_no_length() -> crate::Result<()> { + for buf_len in 1..13 { + let mut body = Body::from_reader(Cursor::new("hello world"), None); + assert_eq!( + read_with_buffers_of_size(&mut body, buf_len).await?, + "hello world" + ); + assert_eq!(body.bytes_read, 11); + } + + Ok(()) + } }