Skip to content

Commit

Permalink
fix(server): send 400 responses on parse errors before closing connec…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
seanmonstar committed Jan 23, 2018
1 parent 44c34ce commit 7cb72d2
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 2 deletions.
27 changes: 25 additions & 2 deletions src/proto/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ where I: AsyncRead + AsyncWrite,
let was_mid_parse = !self.io.read_buf().is_empty();
return if was_mid_parse || must_error {
debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len());
Err(e)
self.on_parse_error(e)
.map(|()| Async::NotReady)
} else {
debug!("read eof");
Ok(Async::Ready(None))
Expand All @@ -213,7 +214,8 @@ where I: AsyncRead + AsyncWrite,
Err(e) => {
debug!("decoder error = {:?}", e);
self.state.close_read();
return Err(e);
return self.on_parse_error(e)
.map(|()| Async::NotReady);
}
};

Expand Down Expand Up @@ -548,6 +550,27 @@ where I: AsyncRead + AsyncWrite,
Ok(AsyncSink::Ready)
}

// When we get a parse error, depending on what side we are, we might be able
// to write a response before closing the connection.
//
// - Client: there is nothing we can do
// - Server: if Response hasn't been written yet, we can send a 4xx response
fn on_parse_error(&mut self, err: ::Error) -> ::Result<()> {
match self.state.writing {
Writing::Init => {
if let Some(msg) = T::on_error(&err) {
self.write_head(msg, false);
self.state.error = Some(err);
return Ok(());
}
}
_ => (),
}

// fallback is pass the error back up
Err(err)
}

fn write_queued(&mut self) -> Poll<(), io::Error> {
trace!("Conn::write_queued()");
let state = match self.state.writing {
Expand Down
25 changes: 25 additions & 0 deletions src/proto/h1/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,26 @@ impl Http1Transaction for ServerTransaction {
ret
}

fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>> {
let status = match err {
&::Error::Method |
&::Error::Version |
&::Error::Header |
&::Error::Uri(_) => {
StatusCode::BadRequest
},
&::Error::TooLarge => {
StatusCode::RequestHeaderFieldsTooLarge
}
_ => return None,
};

debug!("sending automatic response ({}) for parse error", status);
let mut msg = MessageHead::default();
msg.subject = status;
Some(msg)
}

fn should_error_on_parse_eof() -> bool {
false
}
Expand Down Expand Up @@ -317,6 +337,11 @@ impl Http1Transaction for ClientTransaction {
Ok(body)
}

fn on_error(_err: &::Error) -> Option<MessageHead<Self::Outgoing>> {
// we can't tell the server about any errors it creates
None
}

fn should_error_on_parse_eof() -> bool {
true
}
Expand Down
1 change: 1 addition & 0 deletions src/proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ pub trait Http1Transaction {
fn parse(bytes: &mut BytesMut) -> ParseResult<Self::Incoming>;
fn decoder(head: &MessageHead<Self::Incoming>, method: &mut Option<::Method>) -> ::Result<Option<h1::Decoder>>;
fn encode(head: MessageHead<Self::Outgoing>, has_body: bool, method: &mut Option<Method>, dst: &mut Vec<u8>) -> ::Result<h1::Encoder>;
fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>>;

fn should_error_on_parse_eof() -> bool;
fn should_read_first() -> bool;
Expand Down
58 changes: 58 additions & 0 deletions tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,64 @@ fn returning_1xx_response_is_error() {
core.run(fut).unwrap_err();
}

#[test]
fn parse_errors_send_4xx_response() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();

thread::spawn(move || {
let mut tcp = connect(&addr);
tcp.write_all(b"GE T / HTTP/1.1\r\n\r\n").unwrap();
let mut buf = [0; 256];
tcp.read(&mut buf).unwrap();

let expected = "HTTP/1.1 400 ";
assert_eq!(s(&buf[..expected.len()]), expected);
});

let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new()
.serve_connection(socket, HelloWorld)
.map(|_| ())
});

core.run(fut).unwrap_err();
}

#[test]
fn illegal_request_length_returns_400_response() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();

thread::spawn(move || {
let mut tcp = connect(&addr);
tcp.write_all(b"POST / HTTP/1.1\r\nContent-Length: foo\r\n\r\n").unwrap();
let mut buf = [0; 256];
tcp.read(&mut buf).unwrap();

let expected = "HTTP/1.1 400 ";
assert_eq!(s(&buf[..expected.len()]), expected);
});

let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new()
.serve_connection(socket, HelloWorld)
.map(|_| ())
});

core.run(fut).unwrap_err();
}

#[test]
fn remote_addr() {
let server = serve();
Expand Down

0 comments on commit 7cb72d2

Please sign in to comment.