Skip to content

Commit

Permalink
do not let Body read beyond its length
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr committed Nov 15, 2020
1 parent e88d8fc commit 892daba
Showing 1 changed file with 94 additions and 4 deletions.
98 changes: 94 additions & 4 deletions src/body.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures_lite::{io, prelude::*};
use futures_lite::{io, prelude::*, ready};
use serde::{de::DeserializeOwned, Serialize};

use std::fmt::{self, Debug};
Expand Down Expand Up @@ -56,6 +56,7 @@ pin_project_lite::pin_project! {
reader: Box<dyn AsyncBufRead + Unpin + Send + Sync + 'static>,
mime: Mime,
length: Option<usize>,
bytes_read: usize
}
}

Expand All @@ -78,6 +79,7 @@ impl Body {
reader: Box::new(io::empty()),
mime: mime::BYTE_STREAM,
length: Some(0),
bytes_read: 0,
}
}

Expand Down Expand Up @@ -108,6 +110,7 @@ impl Body {
reader: Box::new(reader),
mime: mime::BYTE_STREAM,
length: len,
bytes_read: 0,
}
}

Expand Down Expand Up @@ -151,6 +154,7 @@ impl Body {
mime: mime::BYTE_STREAM,
length: Some(bytes.len()),
reader: Box::new(io::Cursor::new(bytes)),
bytes_read: 0,
}
}

Expand Down Expand Up @@ -200,6 +204,7 @@ impl Body {
mime: mime::PLAIN,
length: Some(s.len()),
reader: Box::new(io::Cursor::new(s.into_bytes())),
bytes_read: 0,
}
}

Expand Down Expand Up @@ -245,6 +250,7 @@ impl Body {
length: Some(bytes.len()),
reader: Box::new(io::Cursor::new(bytes)),
mime: mime::JSON,
bytes_read: 0,
};
Ok(body)
}
Expand Down Expand Up @@ -309,6 +315,7 @@ impl Body {
length: Some(bytes.len()),
reader: Box::new(io::Cursor::new(bytes)),
mime: mime::FORM,
bytes_read: 0,
};
Ok(body)
}
Expand Down Expand Up @@ -377,6 +384,7 @@ impl Body {
mime,
length: Some(len as usize),
reader: Box::new(io::BufReader::new(file)),
bytes_read: 0,
})
}

Expand Down Expand Up @@ -418,6 +426,7 @@ impl Debug for Body {
f.debug_struct("Body")
.field("reader", &"<hidden>")
.field("length", &self.length)
.field("bytes_read", &self.bytes_read)
.finish()
}
}
Expand Down Expand Up @@ -459,15 +468,25 @@ impl AsyncRead for Body {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.reader).poll_read(cx, buf)
let mut buf = match self.length {
None => buf,
Some(length) if length == self.bytes_read => return Poll::Ready(Ok(0)),
Some(length) => {
let max_len = (length - self.bytes_read).min(buf.len());
&mut buf[0..max_len]
}
};

let bytes = ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf))?;
self.bytes_read += bytes;
Poll::Ready(Ok(bytes))
}
}

impl AsyncBufRead for Body {
#[allow(missing_doc_code_examples)]
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&'_ [u8]>> {
let this = self.project();
this.reader.poll_fill_buf(cx)
self.project().reader.poll_fill_buf(cx)
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
Expand Down Expand Up @@ -500,6 +519,7 @@ fn guess_ext(path: &std::path::Path) -> Option<Mime> {
#[cfg(test)]
mod test {
use super::*;
use async_std::io::Cursor;
use serde::Deserialize;

#[async_std::test]
Expand All @@ -523,4 +543,74 @@ mod test {
let res = body.into_form::<Foo>().await;
assert_eq!(res.unwrap_err().status(), 422);
}

async fn read_with_buffers_of_size<R>(reader: &mut R, size: usize) -> crate::Result<String>
where
R: AsyncRead + Unpin,
{
let mut return_buffer = vec![];
loop {
let mut buf = vec![0; size];
match reader.read(&mut buf).await? {
0 => break Ok(String::from_utf8(return_buffer)?),
bytes_read => return_buffer.extend_from_slice(&buf[..bytes_read]),
}
}
}

#[async_std::test]
async fn attempting_to_read_past_length_with_shorter_buffer() -> crate::Result<()> {
for buf_len in 1..13 {
let mut body = Body::from_reader(Cursor::new("hello world"), Some(5));
assert_eq!(
read_with_buffers_of_size(&mut body, buf_len).await?,
"hello"
);
assert_eq!(body.bytes_read, 5);
}

Ok(())
}

#[async_std::test]
async fn attempting_to_read_when_length_is_greater_than_content() -> crate::Result<()> {
for buf_len in 1..13 {
let mut body = Body::from_reader(Cursor::new("hello world"), Some(15));
assert_eq!(
read_with_buffers_of_size(&mut body, buf_len).await?,
"hello world"
);
assert_eq!(body.bytes_read, 11);
}

Ok(())
}

#[async_std::test]
async fn attempting_to_read_when_length_is_exactly_right() -> crate::Result<()> {
for buf_len in 1..13 {
let mut body = Body::from_reader(Cursor::new("hello world"), Some(11));
assert_eq!(
read_with_buffers_of_size(&mut body, buf_len).await?,
"hello world"
);
assert_eq!(body.bytes_read, 11);
}

Ok(())
}

#[async_std::test]
async fn reading_in_various_buffer_lengths_when_there_is_no_length() -> crate::Result<()> {
for buf_len in 1..13 {
let mut body = Body::from_reader(Cursor::new("hello world"), None);
assert_eq!(
read_with_buffers_of_size(&mut body, buf_len).await?,
"hello world"
);
assert_eq!(body.bytes_read, 11);
}

Ok(())
}
}

0 comments on commit 892daba

Please sign in to comment.