diff --git a/src/client/mod.rs b/src/client/mod.rs index c3efaff415..1dc3ddfea5 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -249,8 +249,12 @@ where C: Connect, let pool_key = Rc::new(domain.to_string()); self.connector.connect(url) .map(move |io| { - let (tx, rx) = mpsc::channel(1); - let pooled = pool.pooled(pool_key, RefCell::new(tx)); + let (tx, rx) = mpsc::channel(0); + let tx = HyperClient { + tx: RefCell::new(tx), + should_close: true, + }; + let pooled = pool.pooled(pool_key, tx); let conn = proto::Conn::<_, _, proto::ClientTransaction, _>::new(io, pooled.clone()); let dispatch = proto::dispatch::Dispatcher::new(proto::dispatch::Client::new(rx), conn); handle.spawn(dispatch.map_err(|err| error!("no_proto error: {}", err))); @@ -269,9 +273,10 @@ where C: Connect, e.into() }); - let resp = race.and_then(move |client| { + let resp = race.and_then(move |mut client| { let (callback, rx) = oneshot::channel(); - client.borrow_mut().start_send((head, body, callback)).unwrap(); + client.tx.borrow_mut().start_send(proto::dispatch::ClientMsg::Request(head, body, callback)).unwrap(); + client.should_close = false; rx.then(|res| { match res { Ok(Ok(res)) => Ok(res), @@ -309,7 +314,29 @@ impl fmt::Debug for Client { } type ProtoClient = ClientProxy, Message, ::Error>; -type HyperClient = RefCell<::futures::sync::mpsc::Sender<(RequestHead, Option, ::futures::sync::oneshot::Sender<::Result<::Response>>)>>; + +struct HyperClient { + tx: RefCell<::futures::sync::mpsc::Sender>>, + should_close: bool, +} + +impl Clone for HyperClient { + fn clone(&self) -> HyperClient { + HyperClient { + tx: self.tx.clone(), + should_close: self.should_close, + } + } +} + +impl Drop for HyperClient { + fn drop(&mut self) { + if self.should_close { + self.should_close = false; + let _ = self.tx.borrow_mut().try_send(proto::dispatch::ClientMsg::Close); + } + } +} enum Dispatch { Proto(Pool>), diff --git a/src/proto/conn.rs b/src/proto/conn.rs index 205334c7e5..2bdcef758e 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -511,6 +511,11 @@ where I: AsyncRead + AsyncWrite, } + pub fn close_and_shutdown(&mut self) -> Poll<(), io::Error> { + try_ready!(self.flush()); + self.shutdown() + } + pub fn shutdown(&mut self) -> Poll<(), io::Error> { match self.io.io_mut().shutdown() { Ok(Async::NotReady) => Ok(Async::NotReady), @@ -625,8 +630,7 @@ where I: AsyncRead + AsyncWrite, #[inline] fn close(&mut self) -> Poll<(), Self::SinkError> { - try_ready!(self.poll_complete()); - self.shutdown() + self.close_and_shutdown() } } diff --git a/src/proto/dispatch.rs b/src/proto/dispatch.rs index b7de2f0858..0133f7989d 100644 --- a/src/proto/dispatch.rs +++ b/src/proto/dispatch.rs @@ -13,6 +13,7 @@ pub struct Dispatcher { dispatch: D, body_tx: Option, body_rx: Option, + is_closing: bool, } pub trait Dispatch { @@ -34,7 +35,12 @@ pub struct Client { rx: ClientRx, } -type ClientRx = mpsc::Receiver<(RequestHead, Option, oneshot::Sender<::Result<::Response>>)>; +pub enum ClientMsg { + Request(RequestHead, Option, oneshot::Sender<::Result<::Response>>), + Close, +} + +type ClientRx = mpsc::Receiver>; impl Dispatcher where @@ -51,6 +57,7 @@ where dispatch: dispatch, body_tx: None, body_rx: None, + is_closing: false, } } @@ -60,7 +67,9 @@ where fn poll_read(&mut self) -> Poll<(), ::Error> { loop { - if self.conn.can_read_head() { + if self.is_closing { + return Ok(Async::Ready(())); + } else if self.conn.can_read_head() { match self.conn.read_head() { Ok(Async::Ready(Some((head, has_body)))) => { let body = if has_body { @@ -149,12 +158,16 @@ where fn poll_write(&mut self) -> Poll<(), ::Error> { loop { - if self.body_rx.is_none() && self.dispatch.should_poll() { + if self.is_closing { + return Ok(Async::Ready(())); + } else if self.body_rx.is_none() && self.dispatch.should_poll() { if let Some((head, body)) = try_ready!(self.dispatch.poll_msg()) { self.conn.write_head(head, body.is_some()); self.body_rx = body; } else { - self.conn.close_write(); + self.is_closing = true; + //self.conn.close_read(); + //self.conn.close_write(); return Ok(Async::Ready(())); } } else if self.conn.has_queued_body() { @@ -190,6 +203,16 @@ where }) } + fn poll_close(&mut self) -> Poll<(), ::Error> { + debug_assert!(self.is_closing); + + try_ready!(self.conn.close_and_shutdown()); + self.conn.close_read(); + self.conn.close_write(); + self.is_closing = false; + Ok(Async::Ready(())) + } + fn is_done(&self) -> bool { let read_done = self.conn.is_read_closed(); @@ -224,6 +247,10 @@ where self.poll_write()?; self.poll_flush()?; + if self.is_closing { + self.poll_close()?; + } + if self.is_done() { try_ready!(self.conn.shutdown()); trace!("Dispatch::poll done"); @@ -285,6 +312,7 @@ where // ===== impl Client ===== + impl Client { pub fn new(rx: ClientRx) -> Client { Client { @@ -305,11 +333,13 @@ where fn poll_msg(&mut self) -> Poll)>, ::Error> { match self.rx.poll() { - Ok(Async::Ready(Some((head, body, cb)))) => { + Ok(Async::Ready(Some(ClientMsg::Request(head, body, cb)))) => { self.callback = Some(cb); Ok(Async::Ready(Some((head, body)))) }, + Ok(Async::Ready(Some(ClientMsg::Close))) | Ok(Async::Ready(None)) => { + trace!("client tx closed"); // user has dropped sender handle Ok(Async::Ready(None)) }, diff --git a/tests/client.rs b/tests/client.rs index 17db0748b7..f86b8f78f0 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -607,7 +607,7 @@ mod dispatch_impl { } #[test] - fn drop_client_closes_connection() { + fn dropped_client_closes_connection() { // https://github.com/hyperium/hyper/issues/1353 let _ = pretty_env_logger::init(); @@ -653,6 +653,57 @@ mod dispatch_impl { assert_eq!(closes.load(Ordering::Relaxed), 1); } + + #[test] + fn drop_client_closes_idle_connections() { + let _ = pretty_env_logger::init(); + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let closes = Arc::new(AtomicUsize::new(0)); + + let (tx1, rx1) = oneshot::channel(); + let (_client_drop_tx, client_drop_rx) = oneshot::channel::<()>(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + let body =[b'x'; 64]; + write!(sock, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).expect("write head"); + let _ = sock.write_all(&body); + let _ = tx1.send(()); + + // prevent this thread from closing until end of test, so the connection + // stays open and idle until Client is dropped + let _ = client_drop_rx.wait(); + }); + + let uri = format!("http://{}/a", addr).parse().unwrap(); + + let client = Client::configure() + .connector(DebugConnector(HttpConnector::new(1, &handle), closes.clone())) + .no_proto() + .build(&handle); + let res = client.get(uri).and_then(move |res| { + assert_eq!(res.status(), hyper::StatusCode::Ok); + res.body().concat2() + }); + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + core.run(res.join(rx).map(|r| r.0)).unwrap(); + + // not closed yet, just idle + assert_eq!(closes.load(Ordering::Relaxed), 0); + drop(client); + core.run(Timeout::new(Duration::from_millis(100), &handle).unwrap()).unwrap(); + + assert_eq!(closes.load(Ordering::Relaxed), 1); + } + #[test] fn no_keep_alive_closes_connection() { // https://github.com/hyperium/hyper/issues/1383