diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index b7c619683c..563c2662ce 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -175,6 +175,13 @@ where } } + #[cfg(feature = "server")] + pub(crate) fn has_initial_read_write_state(&self) -> bool { + matches!(self.state.reading, Reading::Init) + && matches!(self.state.writing, Writing::Init) + && self.io.read_buf().is_empty() + } + fn should_error_on_eof(&self) -> bool { // If we're idle, it's probably just the connection closing gracefully. T::should_error_on_parse_eof() && !self.state.is_idle() diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 6141b296f8..32ef001f11 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -82,7 +82,11 @@ where #[cfg(feature = "server")] pub(crate) fn disable_keep_alive(&mut self) { self.conn.disable_keep_alive(); - if self.conn.is_write_closed() { + + // If keep alive has been disabled and no read or write has been seen on + // the connection yet, we must be in a state where the server is being asked to + // shut down before any data has been seen on the connection + if self.conn.is_write_closed() || self.conn.has_initial_read_write_state() { self.close(); } } diff --git a/tests/server.rs b/tests/server.rs index 7a1a5dd430..b412de038d 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -31,6 +31,7 @@ use hyper::body::{Body, Incoming as IncomingBody}; use hyper::server::conn::{http1, http2}; use hyper::service::{service_fn, Service}; use hyper::{Method, Request, Response, StatusCode, Uri, Version}; +use tokio::pin; mod support; @@ -1139,11 +1140,17 @@ async fn disable_keep_alive_mid_request() { let child = thread::spawn(move || { let mut req = connect(&addr); req.write_all(b"GET / HTTP/1.1\r\n").unwrap(); + thread::sleep(Duration::from_millis(10)); tx1.send(()).unwrap(); rx2.recv().unwrap(); req.write_all(b"Host: localhost\r\n\r\n").unwrap(); let mut buf = vec![]; req.read_to_end(&mut buf).unwrap(); + assert!( + buf.starts_with(b"HTTP/1.1 200 OK\r\n"), + "should receive OK response, but buf: {:?}", + buf, + ); }); let (socket, _) = listener.accept().await.unwrap(); @@ -2152,6 +2159,31 @@ async fn max_buf_size() { .expect_err("should TooLarge error"); } +#[cfg(feature = "http1")] +#[tokio::test] +async fn graceful_shutdown_before_first_request_no_block() { + let (listener, addr) = setup_tcp_listener(); + + tokio::spawn(async move { + let socket = listener.accept().await.unwrap().0; + + let future = http1::Builder::new().serve_connection(socket, HelloWorld); + pin!(future); + future.as_mut().graceful_shutdown(); + + future.await.unwrap(); + }); + + let mut stream = TkTcpStream::connect(addr).await.unwrap(); + + let mut buf = vec![]; + + tokio::time::timeout(Duration::from_secs(5), stream.read_to_end(&mut buf)) + .await + .expect("timed out waiting for graceful shutdown") + .expect("error receiving response"); +} + #[test] fn streaming_body() { use futures_util::StreamExt;