diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 0000000000..be98f3607f --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,95 @@ +use std::cmp; +use std::iter; +use std::io::{self, Read, BufRead, Cursor}; + +pub struct BufReader { + buf: Cursor>, + inner: R +} + +const INIT_BUFFER_SIZE: usize = 4096; +const MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100; + +impl BufReader { + pub fn new(rdr: R) -> BufReader { + BufReader::with_capacity(rdr, INIT_BUFFER_SIZE) + } + + pub fn with_capacity(rdr: R, cap: usize) -> BufReader { + BufReader { + buf: Cursor::new(Vec::with_capacity(cap)), + inner: rdr + } + } + + pub fn get_ref(&self) -> &R { &self.inner } + + pub fn get_mut(&mut self) -> &mut R { &mut self.inner } + + pub fn get_buf(&self) -> &[u8] { + self.buf.get_ref() + } + + pub fn into_inner(self) -> R { self.inner } + + pub fn read_into_buf(&mut self) -> io::Result { + let v = self.buf.get_mut(); + reserve(v); + let inner = &mut self.inner; + with_end_to_cap(v, |b| inner.read(b)) + } +} + +impl Read for BufReader { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if self.buf.get_ref().len() == self.buf.position() as usize && + buf.len() >= self.buf.get_ref().capacity() { + return self.inner.read(buf); + } + try!(self.fill_buf()); + self.buf.read(buf) + } +} + +impl BufRead for BufReader { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + if self.buf.position() as usize == self.buf.get_ref().len() { + self.buf.set_position(0); + let v = self.buf.get_mut(); + v.truncate(0); + let inner = &mut self.inner; + try!(with_end_to_cap(v, |b| inner.read(b))); + } + self.buf.fill_buf() + } + + fn consume(&mut self, amt: usize) { + self.buf.consume(amt) + } +} + +fn with_end_to_cap(v: &mut Vec, f: F) -> io::Result + where F: FnOnce(&mut [u8]) -> io::Result +{ + let len = v.len(); + let new_area = v.capacity() - len; + v.extend(iter::repeat(0).take(new_area)); + match f(&mut v[len..]) { + Ok(n) => { + v.truncate(len + n); + Ok(n) + } + Err(e) => { + v.truncate(len); + Err(e) + } + } +} + +#[inline] +fn reserve(v: &mut Vec) { + let cap = v.capacity(); + if v.len() == cap { + v.reserve(cmp::min(cap * 4, MAX_BUFFER_SIZE) - cap); + } +} diff --git a/src/client/response.rs b/src/client/response.rs index a450258e56..c5ef7a42de 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -1,8 +1,9 @@ //! Client Responses -use std::io::{self, Read, BufReader}; +use std::io::{self, Read}; use std::num::FromPrimitive; use std::marker::PhantomData; +use buffer::BufReader; use header; use header::{ContentLength, TransferEncoding}; use header::Encoding::Chunked; @@ -103,9 +104,10 @@ impl Read for Response { mod tests { use std::borrow::Cow::Borrowed; use std::boxed::BoxAny; - use std::io::{self, Read, BufReader}; + use std::io::{self, Read}; use std::marker::PhantomData; + use buffer::BufReader; use header::Headers; use header::TransferEncoding; use header::Encoding; diff --git a/src/http.rs b/src/http.rs index bef6321edb..50648c036c 100644 --- a/src/http.rs +++ b/src/http.rs @@ -5,12 +5,13 @@ use std::io::{self, Read, Write, BufRead}; use httparse; +use buffer::BufReader; use header::Headers; use method::Method; use uri::RequestUri; use version::HttpVersion::{self, Http10, Http11}; use HttpError:: HttpTooLargeError; -use HttpResult; +use {HttpError, HttpResult}; use self::HttpReader::{SizedReader, ChunkedReader, EofReader, EmptyReader}; use self::HttpWriter::{ThroughWriter, ChunkedWriter, SizedWriter, EmptyWriter}; @@ -307,56 +308,88 @@ impl Write for HttpWriter { } } +const MAX_HEADERS: usize = 100; + /// Parses a request into an Incoming message head. -pub fn parse_request(buf: &mut T) -> HttpResult> { - let (inc, len) = { - let slice = try!(buf.fill_buf()); - let mut headers = [httparse::Header { name: "", value: b"" }; 64]; - let mut req = httparse::Request::new(&mut headers); - match try!(req.parse(slice)) { +#[inline] +pub fn parse_request(buf: &mut BufReader) -> HttpResult> { + parse::(buf) +} + +/// Parses a response into an Incoming message head. +#[inline] +pub fn parse_response(buf: &mut BufReader) -> HttpResult> { + parse::(buf) +} + +fn parse, I>(rdr: &mut BufReader) -> HttpResult> { + loop { + match try!(try_parse::(rdr)) { + httparse::Status::Complete((inc, len)) => { + rdr.consume(len); + return Ok(inc); + }, + _partial => () + } + match try!(rdr.read_into_buf()) { + 0 => return Err(HttpTooLargeError), + _ => () + } + } +} + +fn try_parse, I>(rdr: &mut BufReader) -> TryParseResult { + let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; + ::try_parse(&mut headers, rdr.get_buf()) +} + +#[doc(hidden)] +trait TryParse { + type Subject; + fn try_parse<'a>(headers: &'a mut [httparse::Header<'a>], buf: &'a [u8]) -> TryParseResult; +} + +type TryParseResult = Result, usize)>, HttpError>; + +impl<'a> TryParse for httparse::Request<'a> { + type Subject = (Method, RequestUri); + + fn try_parse<'b>(headers: &'b mut [httparse::Header<'b>], buf: &'b [u8]) -> TryParseResult<(Method, RequestUri)> { + let mut req = httparse::Request::new(headers); + Ok(match try!(req.parse(buf)) { httparse::Status::Complete(len) => { - (Incoming { + httparse::Status::Complete((Incoming { version: if req.version.unwrap() == 1 { Http11 } else { Http10 }, subject: ( try!(req.method.unwrap().parse()), try!(req.path.unwrap().parse()) ), headers: try!(Headers::from_raw(req.headers)) - }, len) + }, len)) }, - _ => { - // request head is bigger than a BufRead's buffer? 400 that! - return Err(HttpTooLargeError) - } - } - }; - buf.consume(len); - Ok(inc) + httparse::Status::Partial => httparse::Status::Partial + }) + } } -/// Parses a response into an Incoming message head. -pub fn parse_response(buf: &mut T) -> HttpResult> { - let (inc, len) = { - let mut headers = [httparse::Header { name: "", value: b"" }; 64]; - let mut res = httparse::Response::new(&mut headers); - match try!(res.parse(try!(buf.fill_buf()))) { +impl<'a> TryParse for httparse::Response<'a> { + type Subject = RawStatus; + + fn try_parse<'b>(headers: &'b mut [httparse::Header<'b>], buf: &'b [u8]) -> TryParseResult { + let mut res = httparse::Response::new(headers); + Ok(match try!(res.parse(buf)) { httparse::Status::Complete(len) => { - (Incoming { + httparse::Status::Complete((Incoming { version: if res.version.unwrap() == 1 { Http11 } else { Http10 }, subject: RawStatus( res.code.unwrap(), res.reason.unwrap().to_owned().into_cow() ), headers: try!(Headers::from_raw(res.headers)) - }, len) + }, len)) }, - _ => { - // response head is bigger than a BufRead's buffer? - return Err(HttpTooLargeError) - } - } - }; - buf.consume(len); - Ok(inc) + httparse::Status::Partial => httparse::Status::Partial + }) + } } /// An Incoming Message head. Includes request/status line, and headers. @@ -456,19 +489,30 @@ mod tests { read_err("1;no CRLF"); } + #[test] + fn test_parse_incoming() { + use buffer::BufReader; + use mock::MockStream; + + use super::parse_request; + let mut raw = MockStream::with_input(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let mut buf = BufReader::new(&mut raw); + parse_request(&mut buf).unwrap(); + } + use test::Bencher; #[bench] fn bench_parse_incoming(b: &mut Bencher) { - use std::io::BufReader; + use buffer::BufReader; use mock::MockStream; use super::parse_request; + let mut raw = MockStream::with_input(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); + let mut buf = BufReader::new(&mut raw); b.iter(|| { - let mut raw = MockStream::with_input(b"GET /echo HTTP/1.1\r\nHost: hyper.rs\r\n\r\n"); - let mut buf = BufReader::new(&mut raw); - parse_request(&mut buf).unwrap(); + buf.get_mut().read.set_position(0); }); } } diff --git a/src/lib.rs b/src/lib.rs index 076b3682ae..d3febe1ba0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -168,7 +168,8 @@ macro_rules! inspect( #[cfg(test)] #[macro_use] mod mock; - +#[doc(hidden)] +pub mod buffer; pub mod client; pub mod error; pub mod method; diff --git a/src/server/mod.rs b/src/server/mod.rs index f6e1658859..4a78d6ebff 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,5 +1,5 @@ //! HTTP Server -use std::io::{BufReader, BufWriter, Write}; +use std::io::{BufWriter, Write}; use std::marker::PhantomData; use std::net::{SocketAddr, ToSocketAddrs}; use std::path::Path; @@ -14,6 +14,7 @@ pub use net::{Fresh, Streaming}; use HttpError::HttpIoError; use {HttpResult}; +use buffer::BufReader; use header::{Headers, Connection, Expect}; use header::ConnectionOption::{Close, KeepAlive}; use method::Method; @@ -227,6 +228,7 @@ mod tests { Host: example.domain\r\n\ Expect: 100-continue\r\n\ Content-Length: 10\r\n\ + Connection: close\r\n\ \r\n\ 1234567890\ "); diff --git a/src/server/request.rs b/src/server/request.rs index a0fd37daa1..d572cddeb3 100644 --- a/src/server/request.rs +++ b/src/server/request.rs @@ -2,10 +2,11 @@ //! //! These are requests that a `hyper::Server` receives, and include its method, //! target URI, headers, and message body. -use std::io::{self, Read, BufReader}; +use std::io::{self, Read}; use std::net::SocketAddr; use {HttpResult}; +use buffer::BufReader; use net::NetworkStream; use version::{HttpVersion}; use method::Method::{self, Get, Head}; @@ -81,12 +82,13 @@ impl<'a, 'b> Read for Request<'a, 'b> { #[cfg(test)] mod tests { + use buffer::BufReader; use header::{Host, TransferEncoding, Encoding}; use net::NetworkStream; use mock::MockStream; use super::Request; - use std::io::{self, Read, BufReader}; + use std::io::{self, Read}; use std::net::SocketAddr; fn sock(s: &str) -> SocketAddr {