diff --git a/src/proto/conn.rs b/src/proto/conn.rs index c3f8bff0bf..c7999488c2 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -186,7 +186,8 @@ where I: AsyncRead + AsyncWrite, let was_mid_parse = !self.io.read_buf().is_empty(); return if was_mid_parse || must_error { debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len()); - Err(e) + self.on_parse_error(e) + .map(|()| Async::NotReady) } else { debug!("read eof"); Ok(Async::Ready(None)) @@ -213,7 +214,8 @@ where I: AsyncRead + AsyncWrite, Err(e) => { debug!("decoder error = {:?}", e); self.state.close_read(); - return Err(e); + return self.on_parse_error(e) + .map(|()| Async::NotReady); } }; @@ -548,6 +550,27 @@ where I: AsyncRead + AsyncWrite, Ok(AsyncSink::Ready) } + // When we get a parse error, depending on what side we are, we might be able + // to write a response before closing the connection. + // + // - Client: there is nothing we can do + // - Server: if Response hasn't been written yet, we can send a 4xx response + fn on_parse_error(&mut self, err: ::Error) -> ::Result<()> { + match self.state.writing { + Writing::Init => { + if let Some(msg) = T::on_error(&err) { + self.write_head(msg, false); + self.state.error = Some(err); + return Ok(()); + } + } + _ => (), + } + + // fallback is pass the error back up + Err(err) + } + fn write_queued(&mut self) -> Poll<(), io::Error> { trace!("Conn::write_queued()"); let state = match self.state.writing { diff --git a/src/proto/h1/parse.rs b/src/proto/h1/parse.rs index 0ca640cc68..b0f5df171f 100644 --- a/src/proto/h1/parse.rs +++ b/src/proto/h1/parse.rs @@ -150,6 +150,26 @@ impl Http1Transaction for ServerTransaction { ret } + fn on_error(err: &::Error) -> Option> { + let status = match err { + &::Error::Method | + &::Error::Version | + &::Error::Header | + &::Error::Uri(_) => { + StatusCode::BadRequest + }, + &::Error::TooLarge => { + StatusCode::RequestHeaderFieldsTooLarge + } + _ => return None, + }; + + debug!("sending automatic response ({}) for parse error", status); + let mut msg = MessageHead::default(); + msg.subject = status; + Some(msg) + } + fn should_error_on_parse_eof() -> bool { false } @@ -317,6 +337,11 @@ impl Http1Transaction for ClientTransaction { Ok(body) } + fn on_error(_err: &::Error) -> Option> { + // we can't tell the server about any errors it creates + None + } + fn should_error_on_parse_eof() -> bool { true } diff --git a/src/proto/mod.rs b/src/proto/mod.rs index a562b3804c..54576b44ef 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -149,6 +149,7 @@ pub trait Http1Transaction { fn parse(bytes: &mut BytesMut) -> ParseResult; fn decoder(head: &MessageHead, method: &mut Option<::Method>) -> ::Result>; fn encode(head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> ::Result; + fn on_error(err: &::Error) -> Option>; fn should_error_on_parse_eof() -> bool; fn should_read_first() -> bool; diff --git a/tests/server.rs b/tests/server.rs index 159271802d..d6a70f4590 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -900,6 +900,64 @@ fn returning_1xx_response_is_error() { core.run(fut).unwrap_err(); } +#[test] +fn parse_errors_send_4xx_response() { + let mut core = Core::new().unwrap(); + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + tcp.write_all(b"GE T / HTTP/1.1\r\n\r\n").unwrap(); + let mut buf = [0; 256]; + tcp.read(&mut buf).unwrap(); + + let expected = "HTTP/1.1 400 "; + assert_eq!(s(&buf[..expected.len()]), expected); + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| unreachable!()) + .and_then(|(item, _incoming)| { + let (socket, _) = item.unwrap(); + Http::::new() + .serve_connection(socket, HelloWorld) + .map(|_| ()) + }); + + core.run(fut).unwrap_err(); +} + +#[test] +fn illegal_request_length_returns_400_response() { + let mut core = Core::new().unwrap(); + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + tcp.write_all(b"POST / HTTP/1.1\r\nContent-Length: foo\r\n\r\n").unwrap(); + let mut buf = [0; 256]; + tcp.read(&mut buf).unwrap(); + + let expected = "HTTP/1.1 400 "; + assert_eq!(s(&buf[..expected.len()]), expected); + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| unreachable!()) + .and_then(|(item, _incoming)| { + let (socket, _) = item.unwrap(); + Http::::new() + .serve_connection(socket, HelloWorld) + .map(|_| ()) + }); + + core.run(fut).unwrap_err(); +} + #[test] fn remote_addr() { let server = serve();