From 3674759837a4c6e0286a8af8dd5172d8a2d97650 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 17 May 2023 17:18:30 -0400 Subject: [PATCH] feat(rt): replace IO traits with hyper::rt ones --- benches/support/mod.rs | 2 +- benches/support/tokiort.rs | 146 ++++++++++++++++ examples/client.rs | 7 +- examples/client_json.rs | 7 +- examples/echo.rs | 7 +- examples/gateway.rs | 14 +- examples/hello.rs | 11 +- examples/http_proxy.rs | 13 +- examples/multi_server.rs | 10 +- examples/params.rs | 7 +- examples/send_file.rs | 7 +- examples/service_struct_impl.rs | 7 +- examples/single_threaded.rs | 7 +- examples/state.rs | 10 +- examples/upgrades.rs | 16 +- examples/web_api.rs | 13 +- src/client/conn/http1.rs | 2 +- src/client/conn/http2.rs | 2 +- src/client/conn/mod.rs | 68 ++++---- src/common/io/compat.rs | 150 +++++++++++++++++ src/common/io/mod.rs | 4 + src/common/io/rewind.rs | 13 +- src/ffi/io.rs | 4 +- src/proto/h1/conn.rs | 2 +- src/proto/h1/decode.rs | 8 +- src/proto/h1/dispatch.rs | 9 +- src/proto/h1/io.rs | 13 +- src/proto/h2/client.rs | 4 +- src/proto/h2/mod.rs | 26 +-- src/proto/h2/server.rs | 8 +- src/rt/io.rs | 285 ++++++++++++++++++++++++++++++++ src/rt/mod.rs | 4 + src/server/conn/http1.rs | 4 +- src/server/conn/http2.rs | 2 +- src/server/conn/mod.rs | 37 ----- src/upgrade.rs | 10 +- tests/client.rs | 32 ++-- tests/server.rs | 65 +++++--- tests/support/mod.rs | 22 ++- 39 files changed, 861 insertions(+), 197 deletions(-) create mode 100644 src/common/io/compat.rs create mode 100644 src/rt/io.rs diff --git a/benches/support/mod.rs b/benches/support/mod.rs index 48e8048e8b..85cb67fd33 100644 --- a/benches/support/mod.rs +++ b/benches/support/mod.rs @@ -1,2 +1,2 @@ mod tokiort; -pub use tokiort::{TokioExecutor, TokioTimer}; +pub use tokiort::{TokioExecutor, TokioIo, TokioTimer}; diff --git a/benches/support/tokiort.rs b/benches/support/tokiort.rs index 67ae3a91aa..1bb27e17e6 100644 --- a/benches/support/tokiort.rs +++ b/benches/support/tokiort.rs @@ -79,3 +79,149 @@ impl Future for TokioSleep { // see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html impl Sleep for TokioSleep {} + +pin_project! { + #[derive(Debug)] + pub struct TokioIo { + #[pin] + inner: T, + } +} + +impl TokioIo { + pub fn new(inner: T) -> Self { + Self { inner } + } + + pub fn inner(self) -> T { + self.inner + } +} + +impl hyper::rt::AsyncRead for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::AsyncWrite for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +impl tokio::io::AsyncRead for TokioIo +where + T: hyper::rt::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + + match hyper::rt::AsyncRead::poll_read(self.project().inner, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; + + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl tokio::io::AsyncWrite for TokioIo +where + T: hyper::rt::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + hyper::rt::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + hyper::rt::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + hyper::rt::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + hyper::rt::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + hyper::rt::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} diff --git a/examples/client.rs b/examples/client.rs index ffcc026719..046f59de02 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -8,6 +8,10 @@ use hyper::Request; use tokio::io::{self, AsyncWriteExt as _}; use tokio::net::TcpStream; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // A simple type alias so as to DRY. type Result = std::result::Result>; @@ -40,8 +44,9 @@ async fn fetch_url(url: hyper::Uri) -> Result<()> { let port = url.port_u16().unwrap_or(80); let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); diff --git a/examples/client_json.rs b/examples/client_json.rs index 4ba6787a6e..6a6753528c 100644 --- a/examples/client_json.rs +++ b/examples/client_json.rs @@ -7,6 +7,10 @@ use hyper::{body::Buf, Request}; use serde::Deserialize; use tokio::net::TcpStream; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // A simple type alias so as to DRY. type Result = std::result::Result>; @@ -29,8 +33,9 @@ async fn fetch_json(url: hyper::Uri) -> Result> { let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); diff --git a/examples/echo.rs b/examples/echo.rs index 7d3478a666..60d03b368d 100644 --- a/examples/echo.rs +++ b/examples/echo.rs @@ -10,6 +10,10 @@ use hyper::service::service_fn; use hyper::{body::Body, Method, Request, Response, StatusCode}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + /// This is our service handler. It receives a Request, routes on its /// path, and returns a Future of a Response. async fn echo( @@ -92,10 +96,11 @@ async fn main() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(echo)) + .serve_connection(io, service_fn(echo)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/gateway.rs b/examples/gateway.rs index 907f2fdba2..e0e3e053d0 100644 --- a/examples/gateway.rs +++ b/examples/gateway.rs @@ -4,6 +4,10 @@ use hyper::{server::conn::http1, service::service_fn}; use std::net::SocketAddr; use tokio::net::{TcpListener, TcpStream}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + #[tokio::main] async fn main() -> Result<(), Box> { pretty_env_logger::init(); @@ -20,6 +24,7 @@ async fn main() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); // This is the `Service` that will handle the connection. // `service_fn` is a helper to convert a function that @@ -42,9 +47,9 @@ async fn main() -> Result<(), Box> { async move { let client_stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(client_stream); - let (mut sender, conn) = - hyper::client::conn::http1::handshake(client_stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { println!("Connection failed: {:?}", err); @@ -56,10 +61,7 @@ async fn main() -> Result<(), Box> { }); tokio::task::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection(stream, service) - .await - { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { println!("Failed to serve the connection: {:?}", err); } }); diff --git a/examples/hello.rs b/examples/hello.rs index a11199adb8..d9d6b8c4c7 100644 --- a/examples/hello.rs +++ b/examples/hello.rs @@ -10,6 +10,10 @@ use hyper::service::service_fn; use hyper::{Request, Response}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // An async function that consumes a request, does nothing with it and returns a // response. async fn hello(_: Request) -> Result>, Infallible> { @@ -35,7 +39,10 @@ pub async fn main() -> Result<(), Box> { // has work to do. In this case, a connection arrives on the port we are listening on and // the task is woken up, at which point the task is then put back on a thread, and is // driven forward by the runtime, eventually yielding a TCP stream. - let (stream, _) = listener.accept().await?; + let (tcp, _) = listener.accept().await?; + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(tcp); // Spin up a new task in Tokio so we can continue to listen for new TCP connection on the // current task without waiting for the processing of the HTTP1 connection we just received @@ -44,7 +51,7 @@ pub async fn main() -> Result<(), Box> { // Handle the connection from the client using HTTP1 and pass any // HTTP requests received on that connection to the `hello` function if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(hello)) + .serve_connection(io, service_fn(hello)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/http_proxy.rs b/examples/http_proxy.rs index 0b4a6818b8..c36cc23778 100644 --- a/examples/http_proxy.rs +++ b/examples/http_proxy.rs @@ -12,6 +12,10 @@ use hyper::{Method, Request, Response}; use tokio::net::{TcpListener, TcpStream}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // To try this example: // 1. cargo run --example http_proxy // 2. config http_proxy in command line @@ -28,12 +32,13 @@ async fn main() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) - .serve_connection(stream, service_fn(proxy)) + .serve_connection(io, service_fn(proxy)) .with_upgrades() .await { @@ -88,11 +93,12 @@ async fn proxy( let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); let (mut sender, conn) = Builder::new() .preserve_header_case(true) .title_case_headers(true) - .handshake(stream) + .handshake(io) .await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -123,9 +129,10 @@ fn full>(chunk: T) -> BoxBody { // Create a TCP connection to host:port, build a tunnel between the connection and // the upgraded connection -async fn tunnel(mut upgraded: Upgraded, addr: String) -> std::io::Result<()> { +async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> { // Connect to remote server let mut server = TcpStream::connect(addr).await?; + let mut upgraded = TokioIo::new(upgraded); // Proxying data let (from_client, from_server) = diff --git a/examples/multi_server.rs b/examples/multi_server.rs index 5eb520dbdb..51e6c39ca7 100644 --- a/examples/multi_server.rs +++ b/examples/multi_server.rs @@ -11,6 +11,10 @@ use hyper::service::service_fn; use hyper::{Request, Response}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + static INDEX1: &[u8] = b"The 1st service!"; static INDEX2: &[u8] = b"The 2nd service!"; @@ -33,10 +37,11 @@ async fn main() -> Result<(), Box> { let listener = TcpListener::bind(addr1).await.unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(index1)) + .serve_connection(io, service_fn(index1)) .await { println!("Error serving connection: {:?}", err); @@ -49,10 +54,11 @@ async fn main() -> Result<(), Box> { let listener = TcpListener::bind(addr2).await.unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(index2)) + .serve_connection(io, service_fn(index2)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/params.rs b/examples/params.rs index a902867f2e..3ba39326a1 100644 --- a/examples/params.rs +++ b/examples/params.rs @@ -13,6 +13,10 @@ use std::convert::Infallible; use std::net::SocketAddr; use url::form_urlencoded; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + static INDEX: &[u8] = b"
Name:
Number:
"; static MISSING: &[u8] = b"Missing field"; static NOTNUMERIC: &[u8] = b"Number field is not numeric"; @@ -124,10 +128,11 @@ async fn main() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(param_example)) + .serve_connection(io, service_fn(param_example)) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/send_file.rs b/examples/send_file.rs index a4514eb52b..ec489ec34f 100644 --- a/examples/send_file.rs +++ b/examples/send_file.rs @@ -10,6 +10,10 @@ use http_body_util::Full; use hyper::service::service_fn; use hyper::{Method, Request, Response, Result, StatusCode}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + static INDEX: &str = "examples/send_file_index.html"; static NOTFOUND: &[u8] = b"Not Found"; @@ -24,10 +28,11 @@ async fn main() -> std::result::Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() - .serve_connection(stream, service_fn(response_examples)) + .serve_connection(io, service_fn(response_examples)) .await { println!("Failed to serve connection: {:?}", err); diff --git a/examples/service_struct_impl.rs b/examples/service_struct_impl.rs index 22cc2407dd..fc0f79356c 100644 --- a/examples/service_struct_impl.rs +++ b/examples/service_struct_impl.rs @@ -10,6 +10,10 @@ use std::net::SocketAddr; use std::pin::Pin; use std::sync::Mutex; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + type Counter = i32; #[tokio::main] @@ -21,11 +25,12 @@ async fn main() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() .serve_connection( - stream, + io, Svc { counter: Mutex::new(81818), }, diff --git a/examples/single_threaded.rs b/examples/single_threaded.rs index ee109d54fa..6757d294a7 100644 --- a/examples/single_threaded.rs +++ b/examples/single_threaded.rs @@ -13,6 +13,10 @@ use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + struct Body { // Our Body type is !Send and !Sync: _marker: PhantomData<*const ()>, @@ -64,6 +68,7 @@ async fn run() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); // For each connection, clone the counter to use in our service... let cnt = counter.clone(); @@ -77,7 +82,7 @@ async fn run() -> Result<(), Box> { tokio::task::spawn_local(async move { if let Err(err) = http2::Builder::new(LocalExec) - .serve_connection(stream, service) + .serve_connection(io, service) .await { println!("Error serving connection: {:?}", err); diff --git a/examples/state.rs b/examples/state.rs index 7d060efe1d..5263efdadc 100644 --- a/examples/state.rs +++ b/examples/state.rs @@ -12,6 +12,10 @@ use hyper::{server::conn::http1, service::service_fn}; use hyper::{Error, Response}; use tokio::net::TcpListener; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + #[tokio::main] async fn main() -> Result<(), Box> { pretty_env_logger::init(); @@ -26,6 +30,7 @@ async fn main() -> Result<(), Box> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); // Each connection could send multiple requests, so // the `Service` needs a clone to handle later requests. @@ -46,10 +51,7 @@ async fn main() -> Result<(), Box> { } }); - if let Err(err) = http1::Builder::new() - .serve_connection(stream, service) - .await - { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { println!("Error serving connection: {:?}", err); } } diff --git a/examples/upgrades.rs b/examples/upgrades.rs index 92a80d7567..f9754e5d49 100644 --- a/examples/upgrades.rs +++ b/examples/upgrades.rs @@ -16,11 +16,16 @@ use hyper::service::service_fn; use hyper::upgrade::Upgraded; use hyper::{Request, Response, StatusCode}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + // A simple type alias so as to DRY. type Result = std::result::Result>; /// Handle server-side I/O after HTTP upgraded. -async fn server_upgraded_io(mut upgraded: Upgraded) -> Result<()> { +async fn server_upgraded_io(upgraded: Upgraded) -> Result<()> { + let mut upgraded = TokioIo::new(upgraded); // we have an upgraded connection that we can read and // write on directly. // @@ -75,7 +80,8 @@ async fn server_upgrade(mut req: Request) -> Result Result<()> { +async fn client_upgraded_io(upgraded: Upgraded) -> Result<()> { + let mut upgraded = TokioIo::new(upgraded); // We've gotten an upgraded connection that we can read // and write directly on. Let's start out 'foobar' protocol. upgraded.write_all(b"foo=bar").await?; @@ -97,7 +103,8 @@ async fn client_upgrade_request(addr: SocketAddr) -> Result<()> { .unwrap(); let stream = TcpStream::connect(addr).await?; - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let io = TokioIo::new(stream); + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -146,10 +153,11 @@ async fn main() { tokio::select! { res = listener.accept() => { let (stream, _) = res.expect("Failed to accept"); + let io = TokioIo::new(stream); let mut rx = rx.clone(); tokio::task::spawn(async move { - let conn = http1::Builder::new().serve_connection(stream, service_fn(server_upgrade)); + let conn = http1::Builder::new().serve_connection(io, service_fn(server_upgrade)); // Don't forget to enable upgrades on the connection. let mut conn = conn.with_upgrades(); diff --git a/examples/web_api.rs b/examples/web_api.rs index 79834a0acd..91d9e9b72f 100644 --- a/examples/web_api.rs +++ b/examples/web_api.rs @@ -9,6 +9,10 @@ use hyper::service::service_fn; use hyper::{body::Incoming as IncomingBody, header, Method, Request, Response, StatusCode}; use tokio::net::{TcpListener, TcpStream}; +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + type GenericError = Box; type Result = std::result::Result; type BoxBody = http_body_util::combinators::BoxBody; @@ -30,8 +34,9 @@ async fn client_request_response() -> Result> { let host = req.uri().host().expect("uri has no host"); let port = req.uri().port_u16().expect("uri has no port"); let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?; + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -109,14 +114,12 @@ async fn main() -> Result<()> { println!("Listening on http://{}", addr); loop { let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); tokio::task::spawn(async move { let service = service_fn(move |req| response_examples(req)); - if let Err(err) = http1::Builder::new() - .serve_connection(stream, service) - .await - { + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { println!("Failed to serve connection: {:?}", err); } }); diff --git a/src/client/conn/http1.rs b/src/client/conn/http1.rs index ed87a991f9..0f72545228 100644 --- a/src/client/conn/http1.rs +++ b/src/client/conn/http1.rs @@ -6,7 +6,7 @@ use std::fmt; use bytes::Bytes; use http::{Request, Response}; use httparse::ParserConfig; -use tokio::io::{AsyncRead, AsyncWrite}; +use crate::rt::{AsyncRead, AsyncWrite}; use super::super::dispatch; use crate::body::{Body, Incoming as IncomingBody}; diff --git a/src/client/conn/http2.rs b/src/client/conn/http2.rs index c45b67dffd..c5e5fe4d9c 100644 --- a/src/client/conn/http2.rs +++ b/src/client/conn/http2.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use std::time::Duration; use http::{Request, Response}; -use tokio::io::{AsyncRead, AsyncWrite}; +use crate::rt::{AsyncRead, AsyncWrite}; use super::super::dispatch; use crate::body::{Body, Incoming as IncomingBody}; diff --git a/src/client/conn/mod.rs b/src/client/conn/mod.rs index f60bce4080..d099c8ae27 100644 --- a/src/client/conn/mod.rs +++ b/src/client/conn/mod.rs @@ -9,7 +9,9 @@ //! higher-level [Client](super) API. //! //! ## Example -//! A simple example that uses the `SendRequest` struct to talk HTTP over a Tokio TCP stream +//! +//! A simple example that uses the `SendRequest` struct to talk HTTP over some TCP stream. +//! //! ```no_run //! # #[cfg(all(feature = "client", feature = "http1"))] //! # mod rt { @@ -17,38 +19,38 @@ //! use http::{Request, StatusCode}; //! use http_body_util::Empty; //! use hyper::client::conn; -//! use tokio::net::TcpStream; -//! -//! #[tokio::main] -//! async fn main() -> Result<(), Box> { -//! let target_stream = TcpStream::connect("example.com:80").await?; -//! -//! let (mut request_sender, connection) = conn::http1::handshake(target_stream).await?; -//! -//! // spawn a task to poll the connection and drive the HTTP state -//! tokio::spawn(async move { -//! if let Err(e) = connection.await { -//! eprintln!("Error in connection: {}", e); -//! } -//! }); -//! -//! let request = Request::builder() -//! // We need to manually add the host header because SendRequest does not -//! .header("Host", "example.com") -//! .method("GET") -//! .body(Empty::::new())?; -//! let response = request_sender.send_request(request).await?; -//! assert!(response.status() == StatusCode::OK); -//! -//! let request = Request::builder() -//! .header("Host", "example.com") -//! .method("GET") -//! .body(Empty::::new())?; -//! let response = request_sender.send_request(request).await?; -//! assert!(response.status() == StatusCode::OK); -//! Ok(()) -//! } -//! +//! # use hyper::rt::{AsyncRead, AsyncWrite}; +//! # async fn run(tcp: I) -> Result<(), Box> +//! # where +//! # I: AsyncRead + AsyncWrite + Unpin + Send + 'static, +//! # { +//! let (mut request_sender, connection) = conn::http1::handshake(tcp).await?; +//! +//! // spawn a task to poll the connection and drive the HTTP state +//! tokio::spawn(async move { +//! if let Err(e) = connection.await { +//! eprintln!("Error in connection: {}", e); +//! } +//! }); +//! +//! let request = Request::builder() +//! // We need to manually add the host header because SendRequest does not +//! .header("Host", "example.com") +//! .method("GET") +//! .body(Empty::::new())?; +//! +//! let response = request_sender.send_request(request).await?; +//! assert!(response.status() == StatusCode::OK); +//! +//! let request = Request::builder() +//! .header("Host", "example.com") +//! .method("GET") +//! .body(Empty::::new())?; +//! +//! let response = request_sender.send_request(request).await?; +//! assert!(response.status() == StatusCode::OK); +//! # Ok(()) +//! # } //! # } //! ``` diff --git a/src/common/io/compat.rs b/src/common/io/compat.rs new file mode 100644 index 0000000000..f83e87647a --- /dev/null +++ b/src/common/io/compat.rs @@ -0,0 +1,150 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// This adapts from `hyper` IO traits to the ones in Tokio. +/// +/// This is currently used by `h2`, and by hyper internal unit tests. +#[derive(Debug)] +pub(crate) struct Compat(pub(crate) T); + +pub(crate) fn compat(io: T) -> Compat { + Compat(io) +} + +impl Compat { + fn p(self: Pin<&mut Self>) -> Pin<&mut T> { + // SAFETY: The simplest of projections. This is just + // a wrapper, we don't do anything that would undo the projection. + unsafe { self.map_unchecked_mut(|me| &mut me.0) } + } +} + +impl tokio::io::AsyncRead for Compat +where + T: crate::rt::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let (new_init, new_filled) = unsafe { + let mut buf = crate::rt::ReadBuf::uninit(tbuf.inner_mut()); + buf.set_init(init); + buf.set_filled(filled); + + match crate::rt::AsyncRead::poll_read(self.p(), cx, buf.unfilled()) { + Poll::Ready(Ok(())) => (buf.init_len(), buf.len()), + other => return other, + } + }; + + let n_init = new_init - init; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(new_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl tokio::io::AsyncWrite for Compat +where + T: crate::rt::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + crate::rt::AsyncWrite::poll_write(self.p(), cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + crate::rt::AsyncWrite::poll_flush(self.p(), cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + crate::rt::AsyncWrite::poll_shutdown(self.p(), cx) + } + + fn is_write_vectored(&self) -> bool { + crate::rt::AsyncWrite::is_write_vectored(&self.0) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + crate::rt::AsyncWrite::poll_write_vectored(self.p(), cx, bufs) + } +} + +#[cfg(test)] +impl crate::rt::AsyncRead for Compat +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: crate::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.p(), cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +impl crate::rt::AsyncWrite for Compat +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.p(), cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.p(), cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.p(), cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.0) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.p(), cx, bufs) + } +} diff --git a/src/common/io/mod.rs b/src/common/io/mod.rs index 2e6d506153..6ad07bb771 100644 --- a/src/common/io/mod.rs +++ b/src/common/io/mod.rs @@ -1,3 +1,7 @@ +#[cfg(any(feature = "http2", test))] +mod compat; mod rewind; +#[cfg(any(feature = "http2", test))] +pub(crate) use self::compat::{compat, Compat}; pub(crate) use self::rewind::Rewind; diff --git a/src/common/io/rewind.rs b/src/common/io/rewind.rs index 5642d897d1..f341cfb8af 100644 --- a/src/common/io/rewind.rs +++ b/src/common/io/rewind.rs @@ -2,9 +2,9 @@ use std::marker::Unpin; use std::{cmp, io}; use bytes::{Buf, Bytes}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::common::{task, Pin, Poll}; +use crate::rt::{AsyncRead, AsyncWrite, ReadBufCursor}; /// Combine a buffer with an IO, rewinding reads to use the buffer. #[derive(Debug)] @@ -51,7 +51,7 @@ where fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, + mut buf: ReadBufCursor<'_>, ) -> Poll> { if let Some(mut prefix) = self.pre.take() { // If there are no remaining bytes, let the bytes get dropped. @@ -109,6 +109,7 @@ where mod tests { // FIXME: re-implement tests with `async/await`, this import should // trigger a warning to remind us + use super::super::compat; use super::Rewind; use bytes::Bytes; use tokio::io::AsyncReadExt; @@ -120,14 +121,14 @@ mod tests { let mock = tokio_test::io::Builder::new().read(&underlying).build(); - let mut stream = Rewind::new(mock); + let mut stream = compat(Rewind::new(compat(mock))); // Read off some bytes, ensure we filled o1 let mut buf = [0; 2]; stream.read_exact(&mut buf).await.expect("read1"); // Rewind the stream so that it is as if we never read in the first place. - stream.rewind(Bytes::copy_from_slice(&buf[..])); + stream.0.rewind(Bytes::copy_from_slice(&buf[..])); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); @@ -143,13 +144,13 @@ mod tests { let mock = tokio_test::io::Builder::new().read(&underlying).build(); - let mut stream = Rewind::new(mock); + let mut stream = compat(Rewind::new(compat(mock))); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); // Rewind the stream so that it is as if we never read in the first place. - stream.rewind(Bytes::copy_from_slice(&buf[..])); + stream.0.rewind(Bytes::copy_from_slice(&buf[..])); let mut buf = [0; 5]; stream.read_exact(&mut buf).await.expect("read1"); diff --git a/src/ffi/io.rs b/src/ffi/io.rs index bff666dbcf..365a3d63ee 100644 --- a/src/ffi/io.rs +++ b/src/ffi/io.rs @@ -2,8 +2,8 @@ use std::ffi::c_void; use std::pin::Pin; use std::task::{Context, Poll}; +use crate::rt::{AsyncRead, AsyncWrite}; use libc::size_t; -use tokio::io::{AsyncRead, AsyncWrite}; use super::task::hyper_context; @@ -124,7 +124,7 @@ impl AsyncRead for hyper_io { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, + buf: &mut crate::rt::ReadBuf<'_>, ) -> Poll> { let buf_ptr = unsafe { buf.unfilled_mut() }.as_mut_ptr() as *mut u8; let buf_len = buf.remaining(); diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index b7c619683c..16755b281a 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -8,7 +8,7 @@ use bytes::{Buf, Bytes}; use http::header::{HeaderValue, CONNECTION}; use http::{HeaderMap, Method, Version}; use httparse::ParserConfig; -use tokio::io::{AsyncRead, AsyncWrite}; +use crate::rt::{AsyncRead, AsyncWrite}; use tracing::{debug, error, trace}; use super::io::Buffered; diff --git a/src/proto/h1/decode.rs b/src/proto/h1/decode.rs index 4077b22062..84ebf5af47 100644 --- a/src/proto/h1/decode.rs +++ b/src/proto/h1/decode.rs @@ -430,7 +430,7 @@ mod tests { use super::*; use std::pin::Pin; use std::time::Duration; - use tokio::io::{AsyncRead, ReadBuf}; + use crate::rt::{AsyncRead, ReadBuf}; impl<'a> MemRead for &'a [u8] { fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll> { @@ -450,7 +450,7 @@ mod tests { fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll> { let mut v = vec![0; len]; let mut buf = ReadBuf::new(&mut v); - ready!(Pin::new(self).poll_read(cx, &mut buf)?); + ready!(Pin::new(self).poll_read(cx, buf.unfilled())?); Poll::Ready(Ok(Bytes::copy_from_slice(&buf.filled()))) } } @@ -629,7 +629,7 @@ mod tests { async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String { let mut outs = Vec::new(); - let mut ins = if block_at == 0 { + let mut ins = crate::common::io::compat(if block_at == 0 { tokio_test::io::Builder::new() .wait(Duration::from_millis(10)) .read(content) @@ -640,7 +640,7 @@ mod tests { .wait(Duration::from_millis(10)) .read(&content[block_at..]) .build() - }; + }); let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin); diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index cd494581b9..18cec05e92 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -2,7 +2,7 @@ use std::error::Error as StdError; use bytes::{Buf, Bytes}; use http::Request; -use tokio::io::{AsyncRead, AsyncWrite}; +use crate::rt::{AsyncRead, AsyncWrite}; use tracing::{debug, trace}; use super::{Http1Transaction, Wants}; @@ -649,6 +649,7 @@ cfg_client! { #[cfg(test)] mod tests { use super::*; + use crate::common::io::compat; use crate::proto::h1::ClientTransaction; use std::time::Duration; @@ -662,7 +663,7 @@ mod tests { // Block at 0 for now, but we will release this response before // the request is ready to write later... let (mut tx, rx) = crate::client::dispatch::channel(); - let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(compat(io)); let mut dispatcher = Dispatcher::new(Client::new(rx), conn); // First poll is needed to allow tx to send... @@ -699,7 +700,7 @@ mod tests { .build_with_handle(); let (mut tx, rx) = crate::client::dispatch::channel(); - let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(compat(io)); conn.set_write_strategy_queue(); let dispatcher = Dispatcher::new(Client::new(rx), conn); @@ -730,7 +731,7 @@ mod tests { .build(); let (mut tx, rx) = crate::client::dispatch::channel(); - let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io); + let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(compat(io)); let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn)); // First poll is needed to allow tx to send... diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index da4101b6fb..5fc1afeecf 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -7,7 +7,7 @@ use std::marker::Unpin; use std::mem::MaybeUninit; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::rt::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::{debug, trace}; use super::{Http1Transaction, ParseContext, ParsedMessage}; @@ -251,7 +251,7 @@ where let dst = self.read_buf.chunk_mut(); let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; let mut buf = ReadBuf::uninit(dst); - match Pin::new(&mut self.io).poll_read(cx, &mut buf) { + match Pin::new(&mut self.io).poll_read(cx, buf.unfilled()) { Poll::Ready(Ok(_)) => { let n = buf.filled().len(); trace!("received {} bytes", n); @@ -662,6 +662,7 @@ enum WriteStrategy { #[cfg(test)] mod tests { + use crate::common::io::compat; use crate::common::time::Time; use super::*; @@ -717,7 +718,7 @@ mod tests { .wait(Duration::from_secs(1)) .build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); // We expect a `parse` to be not ready, and so can't await it directly. // Rather, this `poll_fn` will wrap the `Poll` result. @@ -862,7 +863,7 @@ mod tests { #[cfg(debug_assertions)] // needs to trigger a debug_assert fn write_buf_requires_non_empty_bufs() { let mock = Mock::new().build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); buffered.buffer(Cursor::new(Vec::new())); } @@ -897,7 +898,7 @@ mod tests { let mock = Mock::new().write(b"hello world, it's hyper!").build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); buffered.write_buf.set_strategy(WriteStrategy::Flatten); buffered.headers_buf().extend(b"hello "); @@ -956,7 +957,7 @@ mod tests { .write(b"hyper!") .build(); - let mut buffered = Buffered::<_, Cursor>>::new(mock); + let mut buffered = Buffered::<_, Cursor>>::new(compat(mock)); buffered.write_buf.set_strategy(WriteStrategy::Queue); // we have 4 buffers, and vec IO disabled, but explicitly said diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index 121e24dd84..6bea8538e4 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -8,7 +8,7 @@ use futures_util::stream::StreamExt as _; use h2::client::{Builder, SendRequest}; use h2::SendStream; use http::{Method, StatusCode}; -use tokio::io::{AsyncRead, AsyncWrite}; +use crate::rt::{AsyncRead, AsyncWrite}; use tracing::{debug, trace, warn}; use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; @@ -111,7 +111,7 @@ where B::Data: Send + 'static, { let (h2_tx, mut conn) = new_builder(config) - .handshake::<_, SendBuf>(io) + .handshake::<_, SendBuf>(crate::common::io::compat(io)) .await .map_err(crate::Error::new_h2)?; diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index 8def873cfc..ff978308e6 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -4,10 +4,10 @@ use http::header::{HeaderName, CONNECTION, TE, TRAILER, TRANSFER_ENCODING, UPGRA use http::HeaderMap; use pin_project_lite::pin_project; use std::error::Error as StdError; -use std::io::{self, Cursor, IoSlice}; +use std::io::{Cursor, IoSlice}; use std::mem; use std::task::Context; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::rt::{AsyncRead, AsyncWrite, ReadBufCursor}; use tracing::{debug, trace, warn}; use crate::body::Body; @@ -282,8 +282,8 @@ where fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - read_buf: &mut ReadBuf<'_>, - ) -> Poll> { + mut read_buf: ReadBufCursor<'_>, + ) -> Poll> { if self.buf.is_empty() { self.buf = loop { match ready!(self.recv_stream.poll_data(cx)) { @@ -299,7 +299,7 @@ where return Poll::Ready(match e.reason() { Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()), Some(Reason::STREAM_CLOSED) => { - Err(io::Error::new(io::ErrorKind::BrokenPipe, e)) + Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)) } _ => Err(h2_to_io_error(e)), }) @@ -323,7 +323,7 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], - ) -> Poll> { + ) -> Poll> { if buf.is_empty() { return Poll::Ready(Ok(0)); } @@ -348,7 +348,7 @@ where Poll::Ready(Err(h2_to_io_error( match ready!(self.send_stream.poll_reset(cx)) { Ok(Reason::NO_ERROR) | Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) } Ok(reason) => reason.into(), Err(e) => e, @@ -356,14 +356,14 @@ where ))) } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll> { if self.send_stream.write(&[], true).is_ok() { return Poll::Ready(Ok(())) } @@ -374,7 +374,7 @@ where return Poll::Ready(Ok(())) } Ok(Reason::CANCEL) | Ok(Reason::STREAM_CLOSED) => { - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) } Ok(reason) => reason.into(), Err(e) => e, @@ -383,11 +383,11 @@ where } } -fn h2_to_io_error(e: h2::Error) -> io::Error { +fn h2_to_io_error(e: h2::Error) -> std::io::Error { if e.is_io() { e.into_io().unwrap() } else { - io::Error::new(io::ErrorKind::Other, e) + std::io::Error::new(std::io::ErrorKind::Other, e) } } @@ -414,7 +414,7 @@ where unsafe { self.as_inner_unchecked().poll_reset(cx) } } - fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), std::io::Error> { let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); unsafe { self.as_inner_unchecked() diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index a5bd75f92c..cae5d8446f 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -8,7 +8,7 @@ use h2::server::{Connection, Handshake, SendResponse}; use h2::{Reason, RecvStream}; use http::{Method, Request}; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite}; +use crate::rt::{AsyncRead, AsyncWrite}; use tracing::{debug, trace, warn}; use super::{ping, PipeToSendStream, SendBuf}; @@ -89,7 +89,7 @@ where { Handshaking { ping_config: ping::Config, - hs: Handshake>, + hs: Handshake, SendBuf>, }, Serving(Serving), Closed, @@ -100,7 +100,7 @@ where B: Body, { ping: Option<(ping::Recorder, ping::Ponger)>, - conn: Connection>, + conn: Connection, SendBuf>, closing: Option, } @@ -132,7 +132,7 @@ where if config.enable_connect_protocol { builder.enable_connect_protocol(); } - let handshake = builder.handshake(io); + let handshake = builder.handshake(crate::common::io::compat(io)); let bdp = if config.adaptive_window { Some(config.initial_stream_window_size) diff --git a/src/rt/io.rs b/src/rt/io.rs new file mode 100644 index 0000000000..63100d6424 --- /dev/null +++ b/src/rt/io.rs @@ -0,0 +1,285 @@ +use std::fmt; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; + +// New IO traits? What?! Why, are you bonkers? +// +// I mean, yes, probably. But, here's the goals: +// +// 1. Supports poll-based IO operations. +// 2. Opt-in vectored IO. +// 3. Can use an optional buffer pool. +// 4. Able to add completion-based (uring) IO eventually. +// +// Frankly, the last point is the entire reason we're doing this. We want to +// have forwards-compatibility with an eventually stable io-uring runtime. We +// don't need that to work right away. But it must be possible to add in here +// without breaking hyper 1.0. +// +// While in here, if there's small tweaks to poll_read or poll_write that would +// allow even the "slow" path to be faster, such as if someone didn't remember +// to forward along an `is_completion` call. + +/// dox +pub trait Read { + /// dox + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: ReadBufCursor<'_>, + ) -> Poll>; +} + +/// dox +pub trait Write { + /// dox + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll>; + + /// dox + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + /// dox + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>; + + /// dox + fn is_write_vectored(&self) -> bool { + false + } + + /// dox + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + let buf = bufs + .iter() + .find(|b| !b.is_empty()) + .map_or(&[][..], |b| &**b); + self.poll_write(cx, buf) + } +} + +/// dox +pub struct ReadBuf<'a> { + raw: &'a mut [MaybeUninit], + filled: usize, + init: usize, +} + +/// dox +#[derive(Debug)] +pub struct ReadBufCursor<'a> { + buf: &'a mut ReadBuf<'a>, +} + +impl<'data> ReadBuf<'data> { + /// dox + #[inline] + #[cfg(test)] + pub(crate) fn new(raw: &'data mut [u8]) -> Self { + let len = raw.len(); + Self { + // SAFETY: We never de-init the bytes ourselves. + raw: unsafe { &mut *(raw as *mut [u8] as *mut [MaybeUninit]) }, + filled: 0, + init: len, + } + } + + /// dox + #[inline] + pub fn uninit(raw: &'data mut [MaybeUninit]) -> Self { + Self { + raw, + filled: 0, + init: 0, + } + } + + /// dox + #[inline] + pub fn filled(&self) -> &[u8] { + // SAFETY: We only slice the filled part of the buffer, which is always valid + unsafe { &*(&self.raw[0..self.filled] as *const [MaybeUninit] as *const [u8]) } + } + + /// dox + #[inline] + pub fn unfilled<'cursor>(&'cursor mut self) -> ReadBufCursor<'cursor> { + ReadBufCursor { + // SAFETY: self.buf is never re-assigned, so its safe to narrow + // the lifetime. + buf: unsafe { + std::mem::transmute::<&'cursor mut ReadBuf<'data>, &'cursor mut ReadBuf<'cursor>>( + self, + ) + }, + } + } + + #[inline] + pub(crate) unsafe fn set_init(&mut self, n: usize) { + self.init = self.init.max(n); + } + + #[inline] + pub(crate) unsafe fn set_filled(&mut self, n: usize) { + self.filled = self.filled.max(n); + } + + #[inline] + pub(crate) fn len(&self) -> usize { + self.filled + } + + #[inline] + pub(crate) fn init_len(&self) -> usize { + self.init + } + + #[inline] + fn remaining(&self) -> usize { + self.capacity() - self.filled + } + + #[inline] + fn capacity(&self) -> usize { + self.raw.len() + } +} + +impl<'data> fmt::Debug for ReadBuf<'data> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadBuf") + .field("filled", &self.filled) + .field("init", &self.init) + .field("capacity", &self.capacity()) + .finish() + } +} + +impl<'data> ReadBufCursor<'data> { + /// Access the unfilled part of the buffer. + /// + /// # Safety + /// + /// The caller must not uninitialize any bytes that may have been + /// initialized before. + #[inline] + pub unsafe fn as_mut(&mut self) -> &mut [MaybeUninit] { + &mut self.buf.raw[self.buf.filled..] + } + + /// Advance the `filled` cursor by `n` bytes. + /// + /// # Safety + /// + /// The caller must take care that `n` more bytes have been initialized. + #[inline] + pub unsafe fn advance(&mut self, n: usize) { + self.buf.filled = self.buf.filled.checked_add(n).expect("overflow"); + self.buf.init = self.buf.filled.max(self.buf.init); + } + + #[inline] + pub(crate) fn remaining(&self) -> usize { + self.buf.remaining() + } + + #[inline] + pub(crate) fn put_slice(&mut self, buf: &[u8]) { + assert!( + self.buf.remaining() >= buf.len(), + "buf.len() must fit in remaining()" + ); + + let amt = buf.len(); + // Cannot overflow, asserted above + let end = self.buf.filled + amt; + + // Safety: the length is asserted above + unsafe { + self.buf.raw[self.buf.filled..end] + .as_mut_ptr() + .cast::() + .copy_from_nonoverlapping(buf.as_ptr(), amt); + } + + if self.buf.init < end { + self.buf.init = end; + } + self.buf.filled = end; + } +} + +macro_rules! deref_async_read { + () => { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: ReadBufCursor<'_>, + ) -> Poll> { + Pin::new(&mut **self).poll_read(cx, buf) + } + }; +} + +impl Read for Box { + deref_async_read!(); +} + +impl Read for &mut T { + deref_async_read!(); +} + +macro_rules! deref_async_write { + () => { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut **self).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + (**self).is_write_vectored() + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut **self).poll_shutdown(cx) + } + }; +} + +impl Write for Box { + deref_async_write!(); +} + +impl Write for &mut T { + deref_async_write!(); +} diff --git a/src/rt/mod.rs b/src/rt/mod.rs index 803d010f40..cecd622cb1 100644 --- a/src/rt/mod.rs +++ b/src/rt/mod.rs @@ -6,6 +6,10 @@ //! to plug in other runtimes. pub mod bounds; +mod io; + +//pub use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +pub use self::io::{Read as AsyncRead, ReadBuf, ReadBufCursor, Write as AsyncWrite}; use std::{ future::Future, diff --git a/src/server/conn/http1.rs b/src/server/conn/http1.rs index 9ee2fe159f..76a9556e1a 100644 --- a/src/server/conn/http1.rs +++ b/src/server/conn/http1.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use std::time::Duration; use bytes::Bytes; -use tokio::io::{AsyncRead, AsyncWrite}; +use crate::rt::{AsyncRead, AsyncWrite}; use crate::body::{Body, Incoming as IncomingBody}; use crate::common::{task, Future, Pin, Poll, Unpin}; @@ -334,7 +334,7 @@ impl Builder { /// # use hyper::{body::Incoming, Request, Response}; /// # use hyper::service::Service; /// # use hyper::server::conn::http1::Builder; - /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use hyper::rt::{AsyncRead, AsyncWrite}; /// # async fn run(some_io: I, some_service: S) /// # where /// # I: AsyncRead + AsyncWrite + Unpin + Send + 'static, diff --git a/src/server/conn/http2.rs b/src/server/conn/http2.rs index 45e0760956..13c560304b 100644 --- a/src/server/conn/http2.rs +++ b/src/server/conn/http2.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use std::time::Duration; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite}; +use crate::rt::{AsyncRead, AsyncWrite}; use crate::body::{Body, Incoming as IncomingBody}; use crate::common::{task, Future, Pin, Poll, Unpin}; diff --git a/src/server/conn/mod.rs b/src/server/conn/mod.rs index 2e7157c5b8..ae69194fe8 100644 --- a/src/server/conn/mod.rs +++ b/src/server/conn/mod.rs @@ -7,43 +7,6 @@ //! //! This module is split by HTTP version. Both work similarly, but do have //! specific options on each builder. -//! -//! ## Example -//! -//! A simple example that prepares an HTTP/1 connection over a Tokio TCP stream. -//! -//! ```no_run -//! # #[cfg(feature = "http1")] -//! # mod rt { -//! use http::{Request, Response, StatusCode}; -//! use http_body_util::Full; -//! use hyper::{server::conn::http1, service::service_fn, body, body::Bytes}; -//! use std::{net::SocketAddr, convert::Infallible}; -//! use tokio::net::TcpListener; -//! -//! #[tokio::main] -//! async fn main() -> Result<(), Box> { -//! let addr: SocketAddr = ([127, 0, 0, 1], 8080).into(); -//! -//! let mut tcp_listener = TcpListener::bind(addr).await?; -//! loop { -//! let (tcp_stream, _) = tcp_listener.accept().await?; -//! tokio::task::spawn(async move { -//! if let Err(http_err) = http1::Builder::new() -//! .keep_alive(true) -//! .serve_connection(tcp_stream, service_fn(hello)) -//! .await { -//! eprintln!("Error while serving HTTP connection: {}", http_err); -//! } -//! }); -//! } -//! } -//! -//! async fn hello(_req: Request) -> Result>, Infallible> { -//! Ok(Response::new(Full::new(Bytes::from("Hello World!")))) -//! } -//! # } -//! ``` #[cfg(feature = "http1")] pub mod http1; diff --git a/src/upgrade.rs b/src/upgrade.rs index 1c7b5b01cd..3eab7e6632 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -45,8 +45,8 @@ use std::fmt; use std::io; use std::marker::Unpin; +use crate::rt::{AsyncRead, AsyncWrite, ReadBufCursor}; use bytes::Bytes; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::oneshot; #[cfg(any(feature = "http1", feature = "http2"))] use tracing::trace; @@ -152,7 +152,7 @@ impl AsyncRead for Upgraded { fn poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, + buf: ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.io).poll_read(cx, buf) } @@ -340,7 +340,9 @@ mod tests { fn upgraded_downcast() { let upgraded = Upgraded::new(Mock, Bytes::new()); - let upgraded = upgraded.downcast::>>().unwrap_err(); + let upgraded = upgraded + .downcast::>>>() + .unwrap_err(); upgraded.downcast::().unwrap(); } @@ -352,7 +354,7 @@ mod tests { fn poll_read( self: Pin<&mut Self>, _cx: &mut task::Context<'_>, - _buf: &mut ReadBuf<'_>, + _buf: ReadBufCursor<'_>, ) -> Poll> { unreachable!("Mock::poll_read") } diff --git a/tests/client.rs b/tests/client.rs index 842282c5bb..7b4444f11f 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -22,6 +22,7 @@ use hyper::{Method, Request, StatusCode, Uri, Version}; use bytes::Bytes; use futures_channel::oneshot; use futures_util::future::{self, FutureExt, TryFuture, TryFutureExt}; +use support::TokioIo; use tokio::net::TcpStream; mod support; @@ -36,8 +37,8 @@ where b.collect().await.map(|c| c.to_bytes()) } -fn tcp_connect(addr: &SocketAddr) -> impl Future> { - TcpStream::connect(*addr) +async fn tcp_connect(addr: &SocketAddr) -> std::io::Result> { + TcpStream::connect(*addr).await.map(TokioIo::new) } struct HttpInfo { @@ -312,7 +313,7 @@ macro_rules! test { req.headers_mut().append("Host", HeaderValue::from_str(&host).unwrap()); } - let (mut sender, conn) = builder.handshake(stream).await?; + let (mut sender, conn) = builder.handshake(TokioIo::new(stream)).await?; tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -1339,7 +1340,7 @@ mod conn { use futures_util::future::{self, poll_fn, FutureExt, TryFutureExt}; use http_body_util::{BodyExt, Empty, StreamBody}; use hyper::rt::Timer; - use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, ReadBuf}; + use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _}; use tokio::net::{TcpListener as TkTcpListener, TcpStream}; use hyper::body::{Body, Frame}; @@ -1349,7 +1350,7 @@ mod conn { use super::{concat, s, support, tcp_connect, FutureHyperExt}; - use support::{TokioExecutor, TokioTimer}; + use support::{TokioExecutor, TokioIo, TokioTimer}; fn setup_logger() { let _ = pretty_env_logger::try_init(); @@ -1773,7 +1774,7 @@ mod conn { } let parts = conn.into_parts(); - let mut io = parts.io; + let io = parts.io; let buf = parts.read_buf; assert_eq!(buf, b"foobar=ready"[..]); @@ -1785,6 +1786,7 @@ mod conn { })) .unwrap_err(); + let mut io = io.tcp.inner(); let mut vec = vec![]; rt.block_on(io.write_all(b"foo=bar")).unwrap(); rt.block_on(io.read_to_end(&mut vec)).unwrap(); @@ -1861,7 +1863,7 @@ mod conn { } let parts = conn.into_parts(); - let mut io = parts.io; + let io = parts.io; let buf = parts.read_buf; assert_eq!(buf, b"foobar=ready"[..]); @@ -1874,6 +1876,7 @@ mod conn { })) .unwrap_err(); + let mut io = io.tcp.inner(); let mut vec = vec![]; rt.block_on(io.write_all(b"foo=bar")).unwrap(); rt.block_on(io.read_to_end(&mut vec)).unwrap(); @@ -1895,6 +1898,7 @@ mod conn { tokio::select! { res = listener.accept() => { let (stream, _) = res.unwrap(); + let stream = TokioIo::new(stream); let service = service_fn(|_:Request| future::ok::<_, hyper::Error>(Response::new(Empty::::new()))); @@ -2077,7 +2081,7 @@ mod conn { // Spawn an HTTP2 server that reads the whole body and responds tokio::spawn(async move { - let sock = listener.accept().await.unwrap().0; + let sock = TokioIo::new(listener.accept().await.unwrap().0); hyper::server::conn::http2::Builder::new(TokioExecutor) .timer(TokioTimer) .serve_connection( @@ -2166,7 +2170,7 @@ mod conn { let res = client.send_request(req).await.expect("send_request"); assert_eq!(res.status(), StatusCode::OK); - let mut upgraded = hyper::upgrade::on(res).await.unwrap(); + let mut upgraded = TokioIo::new(hyper::upgrade::on(res).await.unwrap()); let mut vec = vec![]; upgraded.read_to_end(&mut vec).await.unwrap(); @@ -2264,7 +2268,7 @@ mod conn { ); } - async fn drain_til_eof(mut sock: T) -> io::Result<()> { + async fn drain_til_eof(mut sock: T) -> io::Result<()> { let mut buf = [0u8; 1024]; loop { let n = sock.read(&mut buf).await?; @@ -2276,11 +2280,11 @@ mod conn { } struct DebugStream { - tcp: TcpStream, + tcp: TokioIo, shutdown_called: bool, } - impl AsyncWrite for DebugStream { + impl hyper::rt::AsyncWrite for DebugStream { fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -2305,11 +2309,11 @@ mod conn { } } - impl AsyncRead for DebugStream { + impl hyper::rt::AsyncRead for DebugStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: hyper::rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.tcp).poll_read(cx, buf) } diff --git a/tests/server.rs b/tests/server.rs index 8561ab487c..05595776ed 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -22,8 +22,8 @@ use h2::{RecvStream, SendStream}; use http::header::{HeaderName, HeaderValue}; use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody}; use hyper::rt::Timer; -use support::{TokioExecutor, TokioTimer}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use hyper::rt::{AsyncRead, AsyncWrite}; +use support::{TokioExecutor, TokioIo, TokioTimer}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener as TkTcpListener, TcpListener, TcpStream as TkTcpStream}; @@ -958,6 +958,7 @@ async fn expect_continue_waits_for_body_poll() { }); let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection( @@ -1131,6 +1132,7 @@ async fn disable_keep_alive_mid_request() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let srv = http1::Builder::new().serve_connection(socket, HelloWorld); future::try_select(srv, rx1) .then(|r| match r { @@ -1178,7 +1180,7 @@ async fn disable_keep_alive_post_request() { let dropped2 = dropped.clone(); let (socket, _) = listener.accept().await.unwrap(); let transport = DebugStream { - stream: socket, + stream: TokioIo::new(socket), _debug: dropped2, }; let server = http1::Builder::new().serve_connection(transport, HelloWorld); @@ -1206,6 +1208,7 @@ async fn empty_parse_eof_does_not_return_error() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -1222,6 +1225,7 @@ async fn nonempty_parse_eof_returns_error() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -1245,6 +1249,7 @@ async fn http1_allow_half_close() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .half_close(true) .serve_connection( @@ -1272,6 +1277,7 @@ async fn disconnect_after_reading_request_before_responding() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .half_close(false) .serve_connection( @@ -1303,6 +1309,7 @@ async fn returning_1xx_response_is_error() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection( socket, @@ -1367,6 +1374,7 @@ async fn header_read_timeout_slow_writes() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new() .timer(TokioTimer) .header_read_timeout(Duration::from_secs(5)) @@ -1442,6 +1450,7 @@ async fn header_read_timeout_slow_writes_multiple_requests() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new() .timer(TokioTimer) .header_read_timeout(Duration::from_secs(5)) @@ -1488,6 +1497,7 @@ async fn upgrades() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new().serve_connection( socket, service_fn(|_| { @@ -1506,7 +1516,7 @@ async fn upgrades() { // wait so that we don't write until other side saw 101 response rx.await.unwrap(); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1541,6 +1551,7 @@ async fn http_connect() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); let conn = http1::Builder::new().serve_connection( socket, service_fn(|_| { @@ -1558,7 +1569,7 @@ async fn http_connect() { // wait so that we don't write until other side saw 101 response rx.await.unwrap(); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1611,6 +1622,7 @@ async fn upgrades_new() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, svc) .with_upgrades() @@ -1623,10 +1635,10 @@ async fn upgrades_new() { read_101_rx.await.unwrap(); let upgraded = on_upgrade.await.expect("on_upgrade"); - let parts = upgraded.downcast::().unwrap(); + let parts = upgraded.downcast::>().unwrap(); assert_eq!(parts.read_buf, "eagerly optimistic"); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1645,6 +1657,7 @@ async fn upgrades_ignored() { loop { let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); tokio::task::spawn(async move { http1::Builder::new() .serve_connection(socket, svc) @@ -1715,6 +1728,7 @@ async fn http_connect_new() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, svc) .with_upgrades() @@ -1727,10 +1741,10 @@ async fn http_connect_new() { read_200_rx.await.unwrap(); let upgraded = on_upgrade.await.expect("on_upgrade"); - let parts = upgraded.downcast::().unwrap(); + let parts = upgraded.downcast::>().unwrap(); assert_eq!(parts.read_buf, "eagerly optimistic"); - let mut io = parts.io; + let mut io = parts.io.inner(); io.write_all(b"foo=bar").await.unwrap(); let mut vec = vec![]; io.read_to_end(&mut vec).await.unwrap(); @@ -1776,7 +1790,7 @@ async fn h2_connect() { let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - let mut upgraded = on_upgrade.await.expect("on_upgrade"); + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); upgraded.write_all(b"Bread?").await.unwrap(); let mut vec = vec![]; @@ -1795,6 +1809,7 @@ async fn h2_connect() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -1868,7 +1883,7 @@ async fn h2_connect_multiplex() { assert!(upgrade_res.expect_err("upgrade cancelled").is_canceled()); return; } - let mut upgraded = upgrade_res.expect("upgrade successful"); + let mut upgraded = TokioIo::new(upgrade_res.expect("upgrade successful")); upgraded.write_all(b"Bread?").await.unwrap(); @@ -1904,6 +1919,7 @@ async fn h2_connect_multiplex() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -1955,7 +1971,7 @@ async fn h2_connect_large_body() { let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - let mut upgraded = on_upgrade.await.expect("on_upgrade"); + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); upgraded.write_all(b"Bread?").await.unwrap(); let mut vec = vec![]; @@ -1976,6 +1992,7 @@ async fn h2_connect_large_body() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -2026,7 +2043,7 @@ async fn h2_connect_empty_frames() { let on_upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - let mut upgraded = on_upgrade.await.expect("on_upgrade"); + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); upgraded.write_all(b"Bread?").await.unwrap(); let mut vec = vec![]; @@ -2045,6 +2062,7 @@ async fn h2_connect_empty_frames() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .serve_connection(socket, svc) //.with_upgrades() @@ -2067,6 +2085,7 @@ async fn parse_errors_send_4xx_response() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -2089,6 +2108,7 @@ async fn illegal_request_length_returns_400_response() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .serve_connection(socket, HelloWorld) .await @@ -2129,6 +2149,7 @@ async fn max_buf_size() { }); let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); http1::Builder::new() .max_buf_size(MAX) .serve_connection(socket, HelloWorld) @@ -2359,6 +2380,7 @@ async fn http2_keep_alive_detects_unresponsive_client() { }); let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); let err = http2::Builder::new(TokioExecutor) .timer(TokioTimer) @@ -2377,6 +2399,7 @@ async fn http2_keep_alive_with_responsive_client() { tokio::spawn(async move { let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .timer(TokioTimer) @@ -2387,7 +2410,7 @@ async fn http2_keep_alive_with_responsive_client() { .expect("serve_connection"); }); - let tcp = connect_async(addr).await; + let tcp = TokioIo::new(connect_async(addr).await); let (mut client, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) .handshake(tcp) .await @@ -2440,6 +2463,7 @@ async fn http2_keep_alive_count_server_pings() { tokio::spawn(async move { let (socket, _) = listener.accept().await.expect("accept"); + let socket = TokioIo::new(socket); http2::Builder::new(TokioExecutor) .timer(TokioTimer) @@ -2823,6 +2847,7 @@ impl ServeOptions { tokio::select! { res = listener.accept() => { let (stream, _) = res.unwrap(); + let stream = TokioIo::new(stream); tokio::task::spawn(async move { let msg_tx = msg_tx.clone(); @@ -2874,7 +2899,7 @@ fn has_header(msg: &str, name: &str) -> bool { msg[..n].contains(name) } -fn tcp_bind(addr: &SocketAddr) -> ::tokio::io::Result { +fn tcp_bind(addr: &SocketAddr) -> std::io::Result { let std_listener = StdTcpListener::bind(addr).unwrap(); std_listener.set_nonblocking(true).unwrap(); TcpListener::from_std(std_listener) @@ -2953,7 +2978,7 @@ impl AsyncRead for DebugStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: hyper::rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.stream).poll_read(cx, buf) } @@ -3010,9 +3035,11 @@ impl TestClient { let host = req.uri().host().expect("uri has no host"); let port = req.uri().port_u16().expect("uri has no port"); - let stream = TkTcpStream::connect(format!("{}:{}", host, port)) - .await - .unwrap(); + let stream = TokioIo::new( + TkTcpStream::connect(format!("{}:{}", host, port)) + .await + .unwrap(), + ); if self.http2_only { let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) diff --git a/tests/support/mod.rs b/tests/support/mod.rs index e7e1e8c6bd..c46eff89ea 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -21,7 +21,7 @@ pub use hyper::{HeaderMap, StatusCode}; pub use std::net::SocketAddr; mod tokiort; -pub use tokiort::{TokioExecutor, TokioTimer}; +pub use tokiort::{TokioExecutor, TokioIo, TokioTimer}; #[allow(unused_macros)] macro_rules! t { @@ -357,6 +357,7 @@ async fn async_test(cfg: __TestConfig) { loop { let (stream, _) = listener.accept().await.expect("server error"); + let io = TokioIo::new(stream); // Move a clone into the service_fn let serve_handles = serve_handles.clone(); @@ -386,12 +387,12 @@ async fn async_test(cfg: __TestConfig) { tokio::task::spawn(async move { if http2_only { server::conn::http2::Builder::new(TokioExecutor) - .serve_connection(stream, service) + .serve_connection(io, service) .await .expect("server error"); } else { server::conn::http1::Builder::new() - .serve_connection(stream, service) + .serve_connection(io, service) .await .expect("server error"); } @@ -425,10 +426,11 @@ async fn async_test(cfg: __TestConfig) { async move { let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); let res = if http2_only { let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) - .handshake(stream) + .handshake(io) .await .unwrap(); @@ -440,7 +442,7 @@ async fn async_test(cfg: __TestConfig) { sender.send_request(req).await.unwrap() } else { let (mut sender, conn) = hyper::client::conn::http1::Builder::new() - .handshake(stream) + .handshake(io) .await .unwrap(); @@ -508,6 +510,7 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) loop { let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); let service = service_fn(move |mut req| { async move { @@ -523,11 +526,12 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) let stream = TcpStream::connect(format!("{}:{}", uri, port)) .await .unwrap(); + let io = TokioIo::new(stream); let resp = if http2_only { let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) - .handshake(stream) + .handshake(io) .await .unwrap(); @@ -540,7 +544,7 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) sender.send_request(req).await? } else { let builder = hyper::client::conn::http1::Builder::new(); - let (mut sender, conn) = builder.handshake(stream).await.unwrap(); + let (mut sender, conn) = builder.handshake(io).await.unwrap(); tokio::task::spawn(async move { if let Err(err) = conn.await { @@ -569,12 +573,12 @@ async fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) if http2_only { server::conn::http2::Builder::new(TokioExecutor) - .serve_connection(stream, service) + .serve_connection(io, service) .await .unwrap(); } else { server::conn::http1::Builder::new() - .serve_connection(stream, service) + .serve_connection(io, service) .await .unwrap(); }