Skip to content

Commit

Permalink
Add TLSAcceptor and Builder (#186)
Browse files Browse the repository at this point in the history
Signed-off-by: Heinz N. Gies <heinz@licenser.net>
Co-authored-by: Dirkjan Ochtman <dirkjan@ochtman.nl>
Co-authored-by: Daniel McCarney <daniel@binaryparadox.net>
  • Loading branch information
3 people authored May 26, 2023
1 parent a0dd811 commit 286e1fa
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 142 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ rustls = { version = "0.21.0", default-features = false }
tokio = "1.0"
tokio-rustls = { version = "0.24.0", default-features = false }
webpki-roots = { version = "0.23", optional = true }
futures-util = { version = "0.3" }

[dev-dependencies]
futures-util = { version = "0.3.1", default-features = false }
Expand All @@ -28,7 +29,8 @@ rustls-pemfile = "1.0.0"
tokio = { version = "1.0", features = ["io-std", "macros", "net", "rt-multi-thread"] }

[features]
default = ["native-tokio", "http1", "tls12", "logging"]
default = ["native-tokio", "http1", "tls12", "logging", "acceptor"]
acceptor = ["hyper/server", "tokio-runtime"]
http1 = ["hyper/http1"]
http2 = ["hyper/http2"]
webpki-tokio = ["tokio-runtime", "webpki-roots"]
Expand All @@ -45,7 +47,7 @@ required-features = ["native-tokio", "http1"]
[[example]]
name = "server"
path = "examples/server.rs"
required-features = ["tokio-runtime"]
required-features = ["tokio-runtime", "acceptor"]

[package.metadata.docs.rs]
all-features = true
Expand Down
150 changes: 18 additions & 132 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,16 @@
//! Certificate and private key are hardcoded to sample files.
//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2,
//! otherwise HTTP/1.1 will be used.
use core::task::{Context, Poll};
use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};

#![cfg(feature = "acceptor")]

use std::vec::Vec;
use std::{env, fs, io};

use hyper::server::conn::AddrIncoming;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::vec::Vec;
use std::{env, fs, io, sync};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::ServerConfig;
use hyper_rustls::TlsAcceptor;

fn main() {
// Serve an echo service over HTTPS, with proper error handling.
Expand All @@ -39,139 +36,28 @@ async fn run_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
};
let addr = format!("127.0.0.1:{}", port).parse()?;

// Load public certificate.
let certs = load_certs("examples/sample.pem")?;
// Load private key.
let key = load_private_key("examples/sample.rsa")?;
// Build TLS configuration.
let tls_cfg = {
// Load public certificate.
let certs = load_certs("examples/sample.pem")?;
// Load private key.
let key = load_private_key("examples/sample.rsa")?;
// Do not use client certificate authentication.
let mut cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| error(format!("{}", e)))?;
// Configure ALPN to accept HTTP/2, HTTP/1.1, and HTTP/1.0 in that order.
cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];
sync::Arc::new(cfg)
};

// Create a TCP listener via tokio.
let incoming = AddrIncoming::bind(&addr)?;
let acceptor = TlsAcceptor::builder()
.with_single_cert(certs, key)
.map_err(|e| error(format!("{}", e)))?
.with_all_versions_alpn()
.with_incoming(incoming);
let service = make_service_fn(|_| async { Ok::<_, io::Error>(service_fn(echo)) });
let server = Server::builder(TlsAcceptor::new(tls_cfg, incoming)).serve(service);
let server = Server::builder(acceptor).serve(service);

// Run the future, keep going until an error occurs.
println!("Starting to serve on https://{}.", addr);
server.await?;
Ok(())
}

enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
pub struct TlsStream {
state: State,
}

impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
state: State::Handshaking(accept),
}
}
}

impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}

impl TlsAcceptor {
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}

// Custom echo service, handling two different routes and a
// catch-all 404 responder.
async fn echo(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
Expand Down
139 changes: 139 additions & 0 deletions src/acceptor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use core::task::{Context, Poll};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;

use futures_util::ready;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
};
use rustls::ServerConfig;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

mod builder;
pub use builder::AcceptorBuilder;
use builder::WantsTlsConfig;

enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite by handshaking with tokio_rustls::Accept first
pub struct TlsStream {
state: State,
}

impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
state: State::Handshaking(accept),
}
}
}

impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

/// A TLS acceptor that can be used with hyper servers.
pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}

/// An Acceptor for the `https` scheme.
impl TlsAcceptor {
/// Provides a builder for a `TlsAcceptor`.
pub fn builder() -> AcceptorBuilder<WantsTlsConfig> {
AcceptorBuilder::new()
}
/// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}

impl<C, I> From<(C, I)> for TlsAcceptor
where
C: Into<Arc<ServerConfig>>,
I: Into<AddrIncoming>,
{
fn from((config, incoming): (C, I)) -> TlsAcceptor {
TlsAcceptor::new(config.into(), incoming.into())
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}
Loading

0 comments on commit 286e1fa

Please sign in to comment.