Skip to content

Commit

Permalink
Add TLSAcceptor and Builder
Browse files Browse the repository at this point in the history
Signed-off-by: Heinz N. Gies <heinz@licenser.net>
  • Loading branch information
Licenser committed Feb 3, 2023
1 parent 21c4d37 commit 3881c3b
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 134 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ rustls = { version = "0.20.1", default-features = false }
tokio = "1.0"
tokio-rustls = { version = "0.23", default-features = false }
webpki-roots = { version = "0.22", optional = true }
futures-util = { version = "0.3" }

[dev-dependencies]
futures-util = { version = "0.3.1", default-features = false }
Expand All @@ -31,7 +32,7 @@ http1 = ["hyper/http1"]
http2 = ["hyper/http2"]
webpki-tokio = ["tokio-runtime", "webpki-roots"]
native-tokio = ["tokio-runtime", "rustls-native-certs"]
tokio-runtime = ["hyper/runtime"]
tokio-runtime = ["hyper/runtime"]
tls12 = ["tokio-rustls/tls12", "rustls/tls12"]
logging = ["log", "tokio-rustls/logging", "rustls/logging"]

Expand Down
144 changes: 13 additions & 131 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,12 @@
//! Certificate and private key are hardcoded to sample files.
//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2,
//! otherwise HTTP/1.1 will be used.
use core::task::{Context, Poll};
use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use hyper::server::conn::AddrIncoming;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use hyper_rustls::TlsAcceptor;
use std::vec::Vec;
use std::{env, fs, io, sync};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::ServerConfig;
use std::{env, fs, io};

fn main() {
// Serve an echo service over HTTPS, with proper error handling.
Expand All @@ -39,139 +32,28 @@ async fn run_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
};
let addr = format!("127.0.0.1:{}", port).parse()?;

// Load public certificate.
let certs = load_certs("examples/sample.pem")?;
// Load private key.
let key = load_private_key("examples/sample.rsa")?;
// Build TLS configuration.
let tls_cfg = {
// Load public certificate.
let certs = load_certs("examples/sample.pem")?;
// Load private key.
let key = load_private_key("examples/sample.rsa")?;
// Do not use client certificate authentication.
let mut cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| error(format!("{}", e)))?;
// Configure ALPN to accept HTTP/2, HTTP/1.1 in that order.
cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
sync::Arc::new(cfg)
};

// Create a TCP listener via tokio.
let incoming = AddrIncoming::bind(&addr)?;
let acceptor = TlsAcceptor::builder()
.with_single_cert(certs, key)
.map_err(|e| error(format!("{}", e)))?
.with_http_alpn()
.with_incoming(incoming);
let service = make_service_fn(|_| async { Ok::<_, io::Error>(service_fn(echo)) });
let server = Server::builder(TlsAcceptor::new(tls_cfg, incoming)).serve(service);
let server = Server::builder(acceptor).serve(service);

// Run the future, keep going until an error occurs.
println!("Starting to serve on https://{}.", addr);
server.await?;
Ok(())
}

enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
pub struct TlsStream {
state: State,
}

impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
state: State::Handshaking(accept),
}
}
}

impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}

impl TlsAcceptor {
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}

// Custom echo service, handling two different routes and a
// catch-all 404 responder.
async fn echo(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
Expand Down
137 changes: 137 additions & 0 deletions src/acceptor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use core::task::{Context, Poll};
use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use rustls::ServerConfig;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

mod builder;

pub use builder::AcceptorBuilder;

use self::builder::WantsTlsConfig;
enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
pub struct TlsStream {
state: State,
}

impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
state: State::Handshaking(accept),
}
}
}

impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

/// A TLS acceptor that can be used with hyper servers.
pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}

/// A Acceptor for the `https` scheme.
impl TlsAcceptor {
/// Provides a builder for a `TlsAcceptor`.
pub fn builder() -> AcceptorBuilder<WantsTlsConfig> {
AcceptorBuilder::new()
}
/// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}

impl<C, I> From<(C, I)> for TlsAcceptor
where
C: Into<Arc<ServerConfig>>,
I: Into<AddrIncoming>,
{
fn from((config, incoming): (C, I)) -> TlsAcceptor {
TlsAcceptor::new(config.into(), incoming.into())
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}
Loading

0 comments on commit 3881c3b

Please sign in to comment.