diff --git a/src/client/request.rs b/src/client/request.rs index 3f84dda41b..2811f0dd7c 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -101,11 +101,17 @@ impl Request { /// Consume a Fresh Request, writing the headers and method, /// returning a Streaming Request. pub fn start(mut self) -> ::Result> { - let head = try!(self.message.set_outgoing(RequestHead { + let head = match self.message.set_outgoing(RequestHead { headers: self.headers, method: self.method, url: self.url, - })); + }) { + Ok(head) => head, + Err(e) => { + let _ = self.message.close_connection(); + return Err(From::from(e)); + } + }; Ok(Request { method: head.method, @@ -134,17 +140,30 @@ impl Request { impl Write for Request { #[inline] fn write(&mut self, msg: &[u8]) -> io::Result { - self.message.write(msg) + match self.message.write(msg) { + Ok(n) => Ok(n), + Err(e) => { + let _ = self.message.close_connection(); + Err(e) + } + } } #[inline] fn flush(&mut self) -> io::Result<()> { - self.message.flush() + match self.message.flush() { + Ok(r) => Ok(r), + Err(e) => { + let _ = self.message.close_connection(); + Err(e) + } + } } } #[cfg(test)] mod tests { + use std::io::Write; use std::str::from_utf8; use url::Url; use method::Method::{Get, Head, Post}; @@ -237,4 +256,24 @@ mod tests { assert!(!s.contains("Content-Length:")); assert!(s.contains("Transfer-Encoding:")); } + + #[test] + fn test_write_error_closes() { + let url = Url::parse("http://hyper.rs").unwrap(); + let req = Request::with_connector( + Get, url, &mut MockConnector + ).unwrap(); + let mut req = req.start().unwrap(); + + req.message.downcast_mut::().unwrap() + .get_mut().downcast_mut::().unwrap() + .error_on_write = true; + + req.write(b"foo").unwrap(); + assert!(req.flush().is_err()); + + assert!(req.message.downcast_ref::().unwrap() + .get_ref().downcast_ref::().unwrap() + .is_closed); + } } diff --git a/src/client/response.rs b/src/client/response.rs index 46882d42b0..961a84b790 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -37,7 +37,13 @@ impl Response { /// Creates a new response received from the server on the given `HttpMessage`. pub fn with_message(url: Url, mut message: Box) -> ::Result { trace!("Response::with_message"); - let ResponseHead { headers, raw_status, version } = try!(message.get_incoming()); + let ResponseHead { headers, raw_status, version } = match message.get_incoming() { + Ok(head) => head, + Err(e) => { + let _ = message.close_connection(); + return Err(From::from(e)); + } + }; let status = status::StatusCode::from_u16(raw_status.0); debug!("version={:?}, status={:?}", version, status); debug!("headers={:?}", headers); @@ -54,6 +60,7 @@ impl Response { } /// Get the raw status code and reason. + #[inline] pub fn status_raw(&self) -> &RawStatus { &self.status_raw } @@ -68,6 +75,10 @@ impl Read for Response { self.is_drained = true; Ok(0) }, + Err(e) => { + let _ = self.message.close_connection(); + Err(e) + } r => r } } diff --git a/src/mock.rs b/src/mock.rs index 25ae9e8c7c..25f2ca0340 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -2,7 +2,7 @@ use std::fmt; use std::ascii::AsciiExt; use std::io::{self, Read, Write, Cursor}; use std::cell::RefCell; -use std::net::SocketAddr; +use std::net::{SocketAddr, Shutdown}; use std::sync::{Arc, Mutex}; #[cfg(feature = "timeouts")] use std::time::Duration; @@ -21,10 +21,13 @@ use net::{NetworkStream, NetworkConnector}; pub struct MockStream { pub read: Cursor>, pub write: Vec, + pub is_closed: bool, + pub error_on_write: bool, + pub error_on_read: bool, #[cfg(feature = "timeouts")] pub read_timeout: Cell>, #[cfg(feature = "timeouts")] - pub write_timeout: Cell> + pub write_timeout: Cell>, } impl fmt::Debug for MockStream { @@ -48,7 +51,10 @@ impl MockStream { pub fn with_input(input: &[u8]) -> MockStream { MockStream { read: Cursor::new(input.to_vec()), - write: vec![] + write: vec![], + is_closed: false, + error_on_write: false, + error_on_read: false, } } @@ -57,6 +63,9 @@ impl MockStream { MockStream { read: Cursor::new(input.to_vec()), write: vec![], + is_closed: false, + error_on_write: false, + error_on_read: false, read_timeout: Cell::new(None), write_timeout: Cell::new(None), } @@ -65,13 +74,21 @@ impl MockStream { impl Read for MockStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.read.read(buf) + if self.error_on_read { + Err(io::Error::new(io::ErrorKind::Other, "mock error")) + } else { + self.read.read(buf) + } } } impl Write for MockStream { fn write(&mut self, msg: &[u8]) -> io::Result { - Write::write(&mut self.write, msg) + if self.error_on_write { + Err(io::Error::new(io::ErrorKind::Other, "mock error")) + } else { + Write::write(&mut self.write, msg) + } } fn flush(&mut self) -> io::Result<()> { @@ -95,6 +112,11 @@ impl NetworkStream for MockStream { self.write_timeout.set(dur); Ok(()) } + + fn close(&mut self, _how: Shutdown) -> io::Result<()> { + self.is_closed = true; + Ok(()) + } } /// A wrapper around a `MockStream` that allows one to clone it and keep an independent copy to the @@ -144,6 +166,10 @@ impl NetworkStream for CloneableMockStream { fn set_write_timeout(&self, dur: Option) -> io::Result<()> { self.inner.lock().unwrap().set_write_timeout(dur) } + + fn close(&mut self, how: Shutdown) -> io::Result<()> { + NetworkStream::close(&mut *self.inner.lock().unwrap(), how) + } } impl CloneableMockStream {