From c82390ba54112c8d1accd372b7fcf94a6ec18afc Mon Sep 17 00:00:00 2001
From: Diggory Blake <diggsey@googlemail.com>
Date: Tue, 2 Feb 2021 01:38:42 +0000
Subject: [PATCH] Implement tests to expose sequencing bugs

---
 src/server/mod.rs   |   2 +
 tests/accept.rs     |  42 +++++++++-
 tests/continue.rs   | 187 +++++++++++++++++++++++++++++++++++++++++++-
 tests/test_utils.rs |   2 +-
 4 files changed, 227 insertions(+), 6 deletions(-)

diff --git a/src/server/mod.rs b/src/server/mod.rs
index 1cfa4e9..71c7787 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -162,6 +162,8 @@ where
         let bytes_written = io::copy(&mut encoder, &mut self.io).await?;
         log::trace!("wrote {} response bytes", bytes_written);
 
+        async_std::task::sleep(Duration::from_millis(1)).await;
+
         let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?;
         log::trace!(
             "discarded {} unread request body bytes",
diff --git a/tests/accept.rs b/tests/accept.rs
index 92283a8..a57ff0c 100644
--- a/tests/accept.rs
+++ b/tests/accept.rs
@@ -1,7 +1,10 @@
 mod test_utils;
 mod accept {
+    use std::time::Duration;
+
     use super::test_utils::TestServer;
     use async_h1::{client::Encoder, server::ConnectionStatus};
+    use async_std::future::timeout;
     use async_std::io::{self, prelude::WriteExt, Cursor};
     use http_types::{headers::CONNECTION, Body, Request, Response, Result};
 
@@ -17,7 +20,7 @@ mod accept {
         let content_length = 10;
 
         let request_str = format!(
-            "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
+            "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}",
             content_length,
             std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
         );
@@ -33,6 +36,39 @@ mod accept {
         Ok(())
     }
 
+    #[async_std::test]
+    async fn pipelined() -> Result<()> {
+        let mut server = TestServer::new(|req| async {
+            let mut response = Response::new(200);
+            let len = req.len();
+            response.set_body(Body::from_reader(req, len));
+            Ok(response)
+        });
+
+        let content_length = 10;
+
+        let request_str = format!(
+            "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}",
+            content_length,
+            std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
+        );
+
+        server.write_all(request_str.as_bytes()).await?;
+        server.write_all(request_str.as_bytes()).await?;
+        assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
+        assert_eq!(
+            timeout(Duration::from_secs(1), server.accept_one()).await??,
+            ConnectionStatus::KeepAlive
+        );
+
+        server.close();
+        assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
+
+        assert!(server.all_read());
+
+        Ok(())
+    }
+
     #[async_std::test]
     async fn request_close() -> Result<()> {
         let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
@@ -74,7 +110,7 @@ mod accept {
         let content_length = 10;
 
         let request_str = format!(
-            "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
+            "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}",
             content_length,
             std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
         );
@@ -130,7 +166,7 @@ mod accept {
         let content_length = 10000;
 
         let request_str = format!(
-            "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
+            "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}",
             content_length,
             std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
         );
diff --git a/tests/continue.rs b/tests/continue.rs
index 933fbfe..ad54ea8 100644
--- a/tests/continue.rs
+++ b/tests/continue.rs
@@ -1,9 +1,12 @@
 mod test_utils;
 
+use async_h1::server::ConnectionStatus;
+use async_std::future::timeout;
+use async_std::io::BufReader;
 use async_std::{io, prelude::*, task};
-use http_types::Result;
+use http_types::{Response, Result};
 use std::time::Duration;
-use test_utils::TestIO;
+use test_utils::{TestIO, TestServer};
 
 const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\
 Host: example.com\r\n\
@@ -52,3 +55,183 @@ async fn test_without_expect_when_not_reading_body() -> Result<()> {
 
     Ok(())
 }
+
+#[async_std::test]
+async fn test_accept_unread_body() -> Result<()> {
+    let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
+
+    server.write_all(REQUEST_WITH_EXPECT).await?;
+    assert_eq!(
+        timeout(Duration::from_secs(1), server.accept_one()).await??,
+        ConnectionStatus::KeepAlive
+    );
+
+    server.write_all(REQUEST_WITH_EXPECT).await?;
+    assert_eq!(
+        timeout(Duration::from_secs(1), server.accept_one()).await??,
+        ConnectionStatus::KeepAlive
+    );
+
+    server.close();
+    assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
+
+    assert!(server.all_read());
+
+    Ok(())
+}
+
+#[async_std::test]
+async fn test_echo_server() -> Result<()> {
+    let mut server = TestServer::new(|mut req| async move {
+        let mut resp = Response::new(200);
+        resp.set_body(req.take_body());
+        Ok(resp)
+    });
+
+    server.write_all(REQUEST_WITH_EXPECT).await?;
+    server.write_all(b"0123456789").await?;
+    assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
+
+    task::sleep(SLEEP_DURATION).await; // wait for "continue" to be sent
+
+    server.close();
+
+    assert!(server
+        .client
+        .read
+        .to_string()
+        .starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
+
+    assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
+
+    assert!(server.all_read());
+
+    Ok(())
+}
+
+#[async_std::test]
+async fn test_delayed_read() -> Result<()> {
+    let mut server = TestServer::new(|mut req| async move {
+        let mut body = req.take_body();
+        task::spawn(async move {
+            let mut buf = Vec::new();
+            body.read_to_end(&mut buf).await.unwrap();
+        });
+        Ok(Response::new(200))
+    });
+
+    server.write_all(REQUEST_WITH_EXPECT).await?;
+    assert_eq!(
+        timeout(Duration::from_secs(1), server.accept_one()).await??,
+        ConnectionStatus::KeepAlive
+    );
+    server.write_all(b"0123456789").await?;
+
+    server.write_all(REQUEST_WITH_EXPECT).await?;
+    assert_eq!(
+        timeout(Duration::from_secs(1), server.accept_one()).await??,
+        ConnectionStatus::KeepAlive
+    );
+    server.write_all(b"0123456789").await?;
+
+    server.close();
+    assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
+
+    assert!(server.all_read());
+
+    Ok(())
+}
+
+#[async_std::test]
+async fn test_accept_fast_unread_sequential_requests() -> Result<()> {
+    let mut server = TestServer::new(|_| async move { Ok(Response::new(200)) });
+    let mut client = server.client.clone();
+
+    task::spawn(async move {
+        let mut reader = BufReader::new(client.clone());
+        for _ in 0..10 {
+            let mut buf = String::new();
+            client.write_all(REQUEST_WITH_EXPECT).await.unwrap();
+
+            while !buf.ends_with("\r\n\r\n") {
+                reader.read_line(&mut buf).await.unwrap();
+            }
+
+            assert!(buf.starts_with("HTTP/1.1 200 OK\r\n"));
+        }
+        client.close();
+    });
+
+    for _ in 0..10 {
+        assert_eq!(
+            timeout(Duration::from_secs(1), server.accept_one()).await??,
+            ConnectionStatus::KeepAlive
+        );
+    }
+
+    assert_eq!(server.accept_one().await?, ConnectionStatus::Close);
+
+    assert!(server.all_read());
+
+    Ok(())
+}
+
+#[async_std::test]
+async fn test_accept_partial_read_sequential_requests() -> Result<()> {
+    const LARGE_REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\
+        Host: example.com\r\n\
+        Content-Length: 1000\r\n\
+        Expect: 100-continue\r\n\r\n";
+
+    let mut server = TestServer::new(|mut req| async move {
+        let mut body = req.take_body();
+        let mut buf = [0];
+        body.read(&mut buf).await.unwrap();
+        Ok(Response::new(200))
+    });
+    let mut client = server.client.clone();
+
+    task::spawn(async move {
+        let mut reader = BufReader::new(client.clone());
+        for _ in 0..10 {
+            let mut buf = String::new();
+            client.write_all(LARGE_REQUEST_WITH_EXPECT).await.unwrap();
+
+            // Wait for body to be requested
+            while !buf.ends_with("\r\n\r\n") {
+                reader.read_line(&mut buf).await.unwrap();
+            }
+            assert!(buf.starts_with("HTTP/1.1 100 Continue\r\n"));
+
+            // Write body
+            for _ in 0..100 {
+                client.write_all(b"0123456789").await.unwrap();
+            }
+
+            // Wait for response
+            buf.clear();
+            while !buf.ends_with("\r\n\r\n") {
+                reader.read_line(&mut buf).await.unwrap();
+            }
+
+            assert!(buf.starts_with("HTTP/1.1 200 OK\r\n"));
+        }
+        client.close();
+    });
+
+    for _ in 0..10 {
+        assert_eq!(
+            timeout(Duration::from_secs(1), server.accept_one()).await??,
+            ConnectionStatus::KeepAlive
+        );
+    }
+
+    assert_eq!(
+        timeout(Duration::from_secs(1), server.accept_one()).await??,
+        ConnectionStatus::Close
+    );
+
+    assert!(server.all_read());
+
+    Ok(())
+}
diff --git a/tests/test_utils.rs b/tests/test_utils.rs
index 8194590..034d4cd 100644
--- a/tests/test_utils.rs
+++ b/tests/test_utils.rs
@@ -19,7 +19,7 @@ use async_dup::Arc;
 pub struct TestServer<F, Fut> {
     server: Server<TestIO, F, Fut>,
     #[pin]
-    client: TestIO,
+    pub(crate) client: TestIO,
 }
 
 impl<F, Fut> TestServer<F, Fut>