Skip to content

Commit

Permalink
feat(rt): replace IO traits with hyper::rt ones
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmonstar committed May 17, 2023
1 parent 8552543 commit 3674759
Show file tree
Hide file tree
Showing 39 changed files with 861 additions and 197 deletions.
2 changes: 1 addition & 1 deletion benches/support/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mod tokiort;
pub use tokiort::{TokioExecutor, TokioTimer};
pub use tokiort::{TokioExecutor, TokioIo, TokioTimer};
146 changes: 146 additions & 0 deletions benches/support/tokiort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,149 @@ impl Future for TokioSleep {
// see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html

impl Sleep for TokioSleep {}

pin_project! {
#[derive(Debug)]
pub struct TokioIo<T> {
#[pin]
inner: T,
}
}

impl<T> TokioIo<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}

pub fn inner(self) -> T {
self.inner
}
}

impl<T> hyper::rt::AsyncRead for TokioIo<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<Result<(), std::io::Error>> {
let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other,
}
};

unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}

impl<T> hyper::rt::AsyncWrite for TokioIo<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}

fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}

impl<T> tokio::io::AsyncRead for TokioIo<T>
where
T: hyper::rt::AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
tbuf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<Result<(), std::io::Error>> {
//let init = tbuf.initialized().len();
let filled = tbuf.filled().len();
let sub_filled = unsafe {
let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());

match hyper::rt::AsyncRead::poll_read(self.project().inner, cx, buf.unfilled()) {
Poll::Ready(Ok(())) => buf.filled().len(),
other => return other,
}
};

let n_filled = filled + sub_filled;
// At least sub_filled bytes had to have been initialized.
let n_init = sub_filled;
unsafe {
tbuf.assume_init(n_init);
tbuf.set_filled(n_filled);
}

Poll::Ready(Ok(()))
}
}

impl<T> tokio::io::AsyncWrite for TokioIo<T>
where
T: hyper::rt::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::AsyncWrite::poll_write(self.project().inner, cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
hyper::rt::AsyncWrite::poll_flush(self.project().inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
hyper::rt::AsyncWrite::poll_shutdown(self.project().inner, cx)
}

fn is_write_vectored(&self) -> bool {
hyper::rt::AsyncWrite::is_write_vectored(&self.inner)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}
7 changes: 6 additions & 1 deletion examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ use hyper::Request;
use tokio::io::{self, AsyncWriteExt as _};
use tokio::net::TcpStream;

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

// A simple type alias so as to DRY.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

Expand Down Expand Up @@ -40,8 +44,9 @@ async fn fetch_url(url: hyper::Uri) -> Result<()> {
let port = url.port_u16().unwrap_or(80);
let addr = format!("{}:{}", host, port);
let stream = TcpStream::connect(addr).await?;
let io = TokioIo::new(stream);

let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?;
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
Expand Down
7 changes: 6 additions & 1 deletion examples/client_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ use hyper::{body::Buf, Request};
use serde::Deserialize;
use tokio::net::TcpStream;

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

// A simple type alias so as to DRY.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

Expand All @@ -29,8 +33,9 @@ async fn fetch_json(url: hyper::Uri) -> Result<Vec<User>> {
let addr = format!("{}:{}", host, port);

let stream = TcpStream::connect(addr).await?;
let io = TokioIo::new(stream);

let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?;
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
Expand Down
7 changes: 6 additions & 1 deletion examples/echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ use hyper::service::service_fn;
use hyper::{body::Body, Method, Request, Response, StatusCode};
use tokio::net::TcpListener;

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

/// This is our service handler. It receives a Request, routes on its
/// path, and returns a Future of a Response.
async fn echo(
Expand Down Expand Up @@ -92,10 +96,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
println!("Listening on http://{}", addr);
loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);

tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(stream, service_fn(echo))
.serve_connection(io, service_fn(echo))
.await
{
println!("Error serving connection: {:?}", err);
Expand Down
14 changes: 8 additions & 6 deletions examples/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ use hyper::{server::conn::http1, service::service_fn};
use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream};

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
pretty_env_logger::init();
Expand All @@ -20,6 +24,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);

// This is the `Service` that will handle the connection.
// `service_fn` is a helper to convert a function that
Expand All @@ -42,9 +47,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

async move {
let client_stream = TcpStream::connect(addr).await.unwrap();
let io = TokioIo::new(client_stream);

let (mut sender, conn) =
hyper::client::conn::http1::handshake(client_stream).await?;
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
Expand All @@ -56,10 +61,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});

tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(stream, service)
.await
{
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
println!("Failed to serve the connection: {:?}", err);
}
});
Expand Down
11 changes: 9 additions & 2 deletions examples/hello.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ use hyper::service::service_fn;
use hyper::{Request, Response};
use tokio::net::TcpListener;

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

// An async function that consumes a request, does nothing with it and returns a
// response.
async fn hello(_: Request<hyper::body::Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
Expand All @@ -35,7 +39,10 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// has work to do. In this case, a connection arrives on the port we are listening on and
// the task is woken up, at which point the task is then put back on a thread, and is
// driven forward by the runtime, eventually yielding a TCP stream.
let (stream, _) = listener.accept().await?;
let (tcp, _) = listener.accept().await?;
// Use an adapter to access something implementing `tokio::io` traits as if they implement
// `hyper::rt` IO traits.
let io = TokioIo::new(tcp);

// Spin up a new task in Tokio so we can continue to listen for new TCP connection on the
// current task without waiting for the processing of the HTTP1 connection we just received
Expand All @@ -44,7 +51,7 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Handle the connection from the client using HTTP1 and pass any
// HTTP requests received on that connection to the `hello` function
if let Err(err) = http1::Builder::new()
.serve_connection(stream, service_fn(hello))
.serve_connection(io, service_fn(hello))
.await
{
println!("Error serving connection: {:?}", err);
Expand Down
13 changes: 10 additions & 3 deletions examples/http_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ use hyper::{Method, Request, Response};

use tokio::net::{TcpListener, TcpStream};

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

// To try this example:
// 1. cargo run --example http_proxy
// 2. config http_proxy in command line
Expand All @@ -28,12 +32,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);

tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.serve_connection(stream, service_fn(proxy))
.serve_connection(io, service_fn(proxy))
.with_upgrades()
.await
{
Expand Down Expand Up @@ -88,11 +93,12 @@ async fn proxy(
let addr = format!("{}:{}", host, port);

let stream = TcpStream::connect(addr).await.unwrap();
let io = TokioIo::new(stream);

let (mut sender, conn) = Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.handshake(stream)
.handshake(io)
.await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
Expand Down Expand Up @@ -123,9 +129,10 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {

// Create a TCP connection to host:port, build a tunnel between the connection and
// the upgraded connection
async fn tunnel(mut upgraded: Upgraded, addr: String) -> std::io::Result<()> {
async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> {
// Connect to remote server
let mut server = TcpStream::connect(addr).await?;
let mut upgraded = TokioIo::new(upgraded);

// Proxying data
let (from_client, from_server) =
Expand Down
10 changes: 8 additions & 2 deletions examples/multi_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ use hyper::service::service_fn;
use hyper::{Request, Response};
use tokio::net::TcpListener;

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

static INDEX1: &[u8] = b"The 1st service!";
static INDEX2: &[u8] = b"The 2nd service!";

Expand All @@ -33,10 +37,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let listener = TcpListener::bind(addr1).await.unwrap();
loop {
let (stream, _) = listener.accept().await.unwrap();
let io = TokioIo::new(stream);

tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(stream, service_fn(index1))
.serve_connection(io, service_fn(index1))
.await
{
println!("Error serving connection: {:?}", err);
Expand All @@ -49,10 +54,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let listener = TcpListener::bind(addr2).await.unwrap();
loop {
let (stream, _) = listener.accept().await.unwrap();
let io = TokioIo::new(stream);

tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(stream, service_fn(index2))
.serve_connection(io, service_fn(index2))
.await
{
println!("Error serving connection: {:?}", err);
Expand Down
Loading

0 comments on commit 3674759

Please sign in to comment.