diff --git a/Cargo.toml b/Cargo.toml index 85b70f4..ad1b618 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,8 @@ hyper = "=1.0.0-rc.4" futures-channel = "0.3" futures-util = { version = "0.3", default-features = false } http = "0.2" +http-body = "1.0.0-rc.2" +bytes = "1" once_cell = "1.14" @@ -30,9 +32,11 @@ tower-service = "0.3" tower = { version = "0.4", features = ["make", "util"] } [dev-dependencies] +hyper = { version = "1.0.0-rc.3", features = ["full"] } bytes = "1" http-body-util = "0.1.0-rc.3" tokio = { version = "1", features = ["macros", "test-util"] } +tokio-test = "0.4" [target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies] pnet_datalink = "0.27.2" @@ -50,6 +54,7 @@ http1 = ["hyper/http1"] http2 = ["hyper/http2"] tcp = [] +auto = ["hyper/server", "http1", "http2"] runtime = [] # internal features used in CI diff --git a/src/common/mod.rs b/src/common/mod.rs index 6eeabaf..9ba4244 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -14,6 +14,7 @@ pub mod exec; #[cfg(feature = "client")] mod lazy; pub(crate) mod never; +pub(crate) mod rewind; #[cfg(feature = "client")] mod sync; diff --git a/src/common/rewind.rs b/src/common/rewind.rs new file mode 100644 index 0000000..18d8f58 --- /dev/null +++ b/src/common/rewind.rs @@ -0,0 +1,161 @@ +use std::marker::Unpin; +use std::{cmp, io}; + +use bytes::{Buf, Bytes}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::{ + pin::Pin, + task::{self, Poll}, +}; + +/// Combine a buffer with an IO, rewinding reads to use the buffer. +#[derive(Debug)] +pub(crate) struct Rewind { + pre: Option, + inner: T, +} + +impl Rewind { + #[cfg(test)] + pub(crate) fn new(io: T) -> Self { + Rewind { + pre: None, + inner: io, + } + } + + #[allow(dead_code)] + pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { + Rewind { + pre: Some(buf), + inner: io, + } + } + + #[cfg(test)] + pub(crate) fn rewind(&mut self, bs: Bytes) { + debug_assert!(self.pre.is_none()); + self.pre = Some(bs); + } + + // pub(crate) fn into_inner(self) -> (T, Bytes) { + // (self.inner, self.pre.unwrap_or_else(Bytes::new)) + // } + + // pub(crate) fn get_mut(&mut self) -> &mut T { + // &mut self.inner + // } +} + +impl AsyncRead for Rewind +where + T: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if let Some(mut prefix) = self.pre.take() { + // If there are no remaining bytes, let the bytes get dropped. + if !prefix.is_empty() { + let copy_len = cmp::min(prefix.len(), buf.remaining()); + // TODO: There should be a way to do following two lines cleaner... + buf.put_slice(&prefix[..copy_len]); + prefix.advance(copy_len); + // Put back what's left + if !prefix.is_empty() { + self.pre = Some(prefix); + } + + return Poll::Ready(Ok(())); + } + } + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for Rewind +where + T: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +#[cfg(test)] +mod tests { + // FIXME: re-implement tests with `async/await`, this import should + // trigger a warning to remind us + use super::Rewind; + use bytes::Bytes; + use tokio::io::AsyncReadExt; + + #[cfg(not(miri))] + #[tokio::test] + async fn partial_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + // Read off some bytes, ensure we filled o1 + let mut buf = [0; 2]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // At this point we should have read everything that was in the MockStream + assert_eq!(&buf, &underlying); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn full_rewind() { + let underlying = [104, 101, 108, 108, 111]; + + let mock = tokio_test::io::Builder::new().read(&underlying).build(); + + let mut stream = Rewind::new(mock); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + + // Rewind the stream so that it is as if we never read in the first place. + stream.rewind(Bytes::copy_from_slice(&buf[..])); + + let mut buf = [0; 5]; + stream.read_exact(&mut buf).await.expect("read1"); + } +} diff --git a/src/lib.rs b/src/lib.rs index acaf64d..85f3d3b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ #![deny(missing_docs)] +#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))] //! hyper-util @@ -6,3 +7,6 @@ pub mod client; mod common; pub mod rt; +pub mod server; + +mod error; diff --git a/src/server/conn/auto.rs b/src/server/conn/auto.rs new file mode 100644 index 0000000..fe5f525 --- /dev/null +++ b/src/server/conn/auto.rs @@ -0,0 +1,506 @@ +//! Http1 or Http2 connection. + +use crate::{common::rewind::Rewind, rt::TokioIo}; +use bytes::Bytes; +use http::{Request, Response}; +use http_body::Body; +use hyper::{ + body::Incoming, + rt::{bounds::Http2ConnExec, Timer}, + server::conn::{http1, http2}, + service::Service, +}; +use std::{error::Error as StdError, marker::Unpin, time::Duration}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; + +type Result = std::result::Result>; + +const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +/// Http1 or Http2 connection builder. +pub struct Builder { + http1: http1::Builder, + http2: http2::Builder, +} + +impl Builder { + /// Create a new auto connection builder. + /// + /// `executor` parameter should be a type that implements + /// [`Executor`](hyper::rt::Executor) trait. + /// + /// # Example + /// + /// ``` + /// use hyper_util::{ + /// rt::tokio_executor::TokioExecutor, + /// server::conn::auto, + /// }; + /// + /// auto::Builder::new(TokioExecutor::new()); + /// ``` + pub fn new(executor: E) -> Self { + Self { + http1: http1::Builder::new(), + http2: http2::Builder::new(executor), + } + } + + /// Http1 configuration. + pub fn http1(&mut self) -> Http1Builder<'_, E> { + Http1Builder { inner: self } + } + + /// Http2 configuration. + pub fn http2(&mut self) -> Http2Builder<'_, E> { + Http2Builder { inner: self } + } + + /// Bind a connection together with a [`Service`]. + pub async fn serve_connection(&self, mut io: I, service: S) -> Result<()> + where + S: Service, Response = Response> + Send, + S::Future: Send + 'static, + S::Error: Into>, + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into>, + I: AsyncRead + AsyncWrite + Unpin + 'static, + E: Http2ConnExec, + { + enum Protocol { + H1, + H2, + } + + let mut buf = Vec::new(); + + let protocol = loop { + if buf.len() < 24 { + io.read_buf(&mut buf).await?; + + let len = buf.len().min(H2_PREFACE.len()); + + if buf[0..len] != H2_PREFACE[0..len] { + break Protocol::H1; + } + } else { + break Protocol::H2; + } + }; + + let io = TokioIo::new(Rewind::new_buffered(io, Bytes::from(buf))); + + match protocol { + Protocol::H1 => self.http1.serve_connection(io, service).await?, + Protocol::H2 => self.http2.serve_connection(io, service).await?, + } + + Ok(()) + } +} + +/// Http1 part of builder. +pub struct Http1Builder<'a, E> { + inner: &'a mut Builder, +} + +impl Http1Builder<'_, E> { + /// Http2 configuration. + pub fn http2(&mut self) -> Http2Builder<'_, E> { + Http2Builder { + inner: &mut self.inner, + } + } + + /// Set whether HTTP/1 connections should support half-closures. + /// + /// Clients can chose to shutdown their write-side while waiting + /// for the server to respond. Setting this to `true` will + /// prevent closing the connection immediately if `read` + /// detects an EOF in the middle of a request. + /// + /// Default is `false`. + pub fn half_close(&mut self, val: bool) -> &mut Self { + self.inner.http1.half_close(val); + self + } + + /// Enables or disables HTTP/1 keep-alive. + /// + /// Default is true. + pub fn keep_alive(&mut self, val: bool) -> &mut Self { + self.inner.http1.keep_alive(val); + self + } + + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn title_case_headers(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.title_case_headers(enabled); + self + } + + /// Set whether to support preserving original header cases. + /// + /// Currently, this will record the original cases received, and store them + /// in a private extension on the `Request`. It will also look for and use + /// such an extension in any provided `Response`. + /// + /// Since the relevant extension is still private, there is no way to + /// interact with the original cases. The only effect this can have now is + /// to forward the cases in a proxy-like fashion. + /// + /// Note that this setting does not affect HTTP/2. + /// + /// Default is false. + pub fn preserve_header_case(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.preserve_header_case(enabled); + self + } + + /// Set a timeout for reading client request headers. If a client does not + /// transmit the entire header within this time, the connection is closed. + /// + /// Default is None. + pub fn header_read_timeout(&mut self, read_timeout: Duration) -> &mut Self { + self.inner.http1.header_read_timeout(read_timeout); + self + } + + /// Set whether HTTP/1 connections should try to use vectored writes, + /// or always flatten into a single buffer. + /// + /// Note that setting this to false may mean more copies of body data, + /// but may also improve performance when an IO transport doesn't + /// support vectored writes well, such as most TLS implementations. + /// + /// Setting this to true will force hyper to use queued strategy + /// which may eliminate unnecessary cloning on some TLS backends + /// + /// Default is `auto`. In this mode hyper will try to guess which + /// mode to use + pub fn writev(&mut self, val: bool) -> &mut Self { + self.inner.http1.writev(val); + self + } + + /// Set the maximum buffer size for the connection. + /// + /// Default is ~400kb. + /// + /// # Panics + /// + /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. + pub fn max_buf_size(&mut self, max: usize) -> &mut Self { + self.inner.http1.max_buf_size(max); + self + } + + /// Aggregates flushes to better support pipelined responses. + /// + /// Experimental, may have bugs. + /// + /// Default is false. + pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self { + self.inner.http1.pipeline_flush(enabled); + self + } + + /// Set the timer used in background tasks. + pub fn timer(&mut self, timer: M) -> &mut Self + where + M: Timer + Send + Sync + 'static, + { + self.inner.http1.timer(timer); + self + } + + /// Bind a connection together with a [`Service`]. + pub async fn serve_connection(&self, io: I, service: S) -> Result<()> + where + S: Service, Response = Response> + Send, + S::Future: Send + 'static, + S::Error: Into>, + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into>, + I: AsyncRead + AsyncWrite + Unpin + 'static, + E: Http2ConnExec, + { + self.inner.serve_connection(io, service).await + } +} + +/// Http2 part of builder. +pub struct Http2Builder<'a, E> { + inner: &'a mut Builder, +} + +impl Http2Builder<'_, E> { + /// Http1 configuration. + pub fn http1(&mut self) -> Http1Builder<'_, E> { + Http1Builder { + inner: &mut self.inner, + } + } + + /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 + /// stream-level flow control. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE + pub fn initial_stream_window_size(&mut self, sz: impl Into>) -> &mut Self { + self.inner.http2.initial_stream_window_size(sz); + self + } + + /// Sets the max connection-level flow control for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn initial_connection_window_size(&mut self, sz: impl Into>) -> &mut Self { + self.inner.http2.initial_connection_window_size(sz); + self + } + + /// Sets whether to use an adaptive flow control. + /// + /// Enabling this will override the limits set in + /// `http2_initial_stream_window_size` and + /// `http2_initial_connection_window_size`. + pub fn adaptive_window(&mut self, enabled: bool) -> &mut Self { + self.inner.http2.adaptive_window(enabled); + self + } + + /// Sets the maximum frame size to use for HTTP2. + /// + /// Passing `None` will do nothing. + /// + /// If not set, hyper will use a default. + pub fn max_frame_size(&mut self, sz: impl Into>) -> &mut Self { + self.inner.http2.max_frame_size(sz); + self + } + + /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 + /// connections. + /// + /// Default is no limit (`std::u32::MAX`). Passing `None` will do nothing. + /// + /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS + pub fn max_concurrent_streams(&mut self, max: impl Into>) -> &mut Self { + self.inner.http2.max_concurrent_streams(max); + self + } + + /// Sets an interval for HTTP2 Ping frames should be sent to keep a + /// connection alive. + /// + /// Pass `None` to disable HTTP2 keep-alive. + /// + /// Default is currently disabled. + /// + /// # Cargo Feature + /// + pub fn keep_alive_interval(&mut self, interval: impl Into>) -> &mut Self { + self.inner.http2.keep_alive_interval(interval); + self + } + + /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. + /// + /// If the ping is not acknowledged within the timeout, the connection will + /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. + /// + /// Default is 20 seconds. + /// + /// # Cargo Feature + /// + pub fn keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { + self.inner.http2.keep_alive_timeout(timeout); + self + } + + /// Set the maximum write buffer size for each HTTP/2 stream. + /// + /// Default is currently ~400KB, but may change. + /// + /// # Panics + /// + /// The value must be no larger than `u32::MAX`. + pub fn max_send_buf_size(&mut self, max: usize) -> &mut Self { + self.inner.http2.max_send_buf_size(max); + self + } + + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + pub fn enable_connect_protocol(&mut self) -> &mut Self { + self.inner.http2.enable_connect_protocol(); + self + } + + /// Sets the max size of received header frames. + /// + /// Default is currently ~16MB, but may change. + pub fn max_header_list_size(&mut self, max: u32) -> &mut Self { + self.inner.http2.max_header_list_size(max); + self + } + + /// Set the timer used in background tasks. + pub fn timer(&mut self, timer: M) -> &mut Self + where + M: Timer + Send + Sync + 'static, + { + self.inner.http2.timer(timer); + self + } + + /// Bind a connection together with a [`Service`]. + pub async fn serve_connection(&self, io: I, service: S) -> Result<()> + where + S: Service, Response = Response> + Send, + S::Future: Send + 'static, + S::Error: Into>, + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into>, + I: AsyncRead + AsyncWrite + Unpin + 'static, + E: Http2ConnExec, + { + self.inner.serve_connection(io, service).await + } +} + +#[cfg(test)] +mod tests { + use crate::{ + rt::{tokio_executor::TokioExecutor, TokioIo}, + server::conn::auto, + }; + use http::{Request, Response}; + use http_body::Body; + use http_body_util::{BodyExt, Empty, Full}; + use hyper::{body, body::Bytes, client, service::service_fn}; + use std::{convert::Infallible, error::Error as StdError, net::SocketAddr}; + use tokio::net::{TcpListener, TcpStream}; + + const BODY: &[u8] = b"Hello, world!"; + + #[test] + fn configuration() { + // One liner. + auto::Builder::new(TokioExecutor::new()) + .http1() + .keep_alive(true) + .http2() + .keep_alive_interval(None); + // .serve_connection(io, service); + + // Using variable. + let mut builder = auto::Builder::new(TokioExecutor::new()); + + builder.http1().keep_alive(true); + builder.http2().keep_alive_interval(None); + // builder.serve_connection(io, service); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http1() { + let addr = start_server().await; + let mut sender = connect_h1(addr).await; + + let response = sender + .send_request(Request::new(Empty::::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http2() { + let addr = start_server().await; + let mut sender = connect_h2(addr).await; + + let response = sender + .send_request(Request::new(Empty::::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + async fn connect_h1(addr: SocketAddr) -> client::conn::http1::SendRequest + where + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into>, + { + let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); + let (sender, connection) = client::conn::http1::handshake(stream).await.unwrap(); + + tokio::spawn(connection); + + sender + } + + async fn connect_h2(addr: SocketAddr) -> client::conn::http2::SendRequest + where + B: Body + Unpin + Send + 'static, + B::Data: Send, + B::Error: Into>, + { + let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); + let (sender, connection) = client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(stream) + .await + .unwrap(); + + tokio::spawn(connection); + + sender + } + + async fn start_server() -> SocketAddr { + let addr: SocketAddr = ([127, 0, 0, 1], 0).into(); + let listener = TcpListener::bind(addr).await.unwrap(); + + let local_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + tokio::task::spawn(async move { + let _ = auto::Builder::new(TokioExecutor::new()) + .serve_connection(stream, service_fn(hello)) + .await; + }); + } + }); + + local_addr + } + + async fn hello(_req: Request) -> Result>, Infallible> { + Ok(Response::new(Full::new(Bytes::from(BODY)))) + } +} diff --git a/src/server/conn/mod.rs b/src/server/conn/mod.rs new file mode 100644 index 0000000..70057c8 --- /dev/null +++ b/src/server/conn/mod.rs @@ -0,0 +1,4 @@ +//! Connection utilities. + +#[cfg(feature = "auto")] +pub mod auto; diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 0000000..7b4515c --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,3 @@ +//! Server utilities. + +pub mod conn;