Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TLSAcceptor and Builder #186

Merged
merged 9 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ rustls = { version = "0.21.0", default-features = false }
tokio = "1.0"
tokio-rustls = { version = "0.24.0", default-features = false }
webpki-roots = { version = "0.23", optional = true }
futures-util = { version = "0.3" }

[dev-dependencies]
futures-util = { version = "0.3.1", default-features = false }
Expand All @@ -28,7 +29,8 @@ rustls-pemfile = "1.0.0"
tokio = { version = "1.0", features = ["io-std", "macros", "net", "rt-multi-thread"] }

[features]
default = ["native-tokio", "http1", "tls12", "logging"]
default = ["native-tokio", "http1", "tls12", "logging", "acceptor"]
acceptor = ["hyper/server", "tokio-runtime"]
http1 = ["hyper/http1"]
http2 = ["hyper/http2"]
webpki-tokio = ["tokio-runtime", "webpki-roots"]
Expand Down
148 changes: 17 additions & 131 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,14 @@
//! Certificate and private key are hardcoded to sample files.
//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2,
//! otherwise HTTP/1.1 will be used.
use core::task::{Context, Poll};
use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};

#![cfg(feature = "acceptor")]

use hyper::server::conn::AddrIncoming;
Licenser marked this conversation as resolved.
Show resolved Hide resolved
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::vec::Vec;
use std::{env, fs, io, sync};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::ServerConfig;
use std::{env, fs, io};

fn main() {
// Serve an echo service over HTTPS, with proper error handling.
Expand All @@ -30,148 +25,39 @@ fn error(err: String) -> io::Error {
io::Error::new(io::ErrorKind::Other, err)
}

#[cfg(feature = "acceptor")]
Licenser marked this conversation as resolved.
Show resolved Hide resolved
#[tokio::main]
async fn run_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
use hyper_rustls::TlsAcceptor;
Licenser marked this conversation as resolved.
Show resolved Hide resolved
// First parameter is port number (optional, defaults to 1337)
let port = match env::args().nth(1) {
Some(ref p) => p.to_owned(),
None => "1337".to_owned(),
};
let addr = format!("127.0.0.1:{}", port).parse()?;

// Load public certificate.
let certs = load_certs("examples/sample.pem")?;
// Load private key.
let key = load_private_key("examples/sample.rsa")?;
// Build TLS configuration.
let tls_cfg = {
// Load public certificate.
let certs = load_certs("examples/sample.pem")?;
// Load private key.
let key = load_private_key("examples/sample.rsa")?;
// Do not use client certificate authentication.
let mut cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| error(format!("{}", e)))?;
// Configure ALPN to accept HTTP/2, HTTP/1.1, and HTTP/1.0 in that order.
cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];
sync::Arc::new(cfg)
};

// Create a TCP listener via tokio.
let incoming = AddrIncoming::bind(&addr)?;
let acceptor = TlsAcceptor::builder()
.with_single_cert(certs, key)
.map_err(|e| error(format!("{}", e)))?
.with_all_versions_alpn()
.with_incoming(incoming);
let service = make_service_fn(|_| async { Ok::<_, io::Error>(service_fn(echo)) });
let server = Server::builder(TlsAcceptor::new(tls_cfg, incoming)).serve(service);
let server = Server::builder(acceptor).serve(service);

// Run the future, keep going until an error occurs.
println!("Starting to serve on https://{}.", addr);
server.await?;
Ok(())
}

enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
pub struct TlsStream {
state: State,
}

impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
state: State::Handshaking(accept),
}
}
}

impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}

impl TlsAcceptor {
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}

// Custom echo service, handling two different routes and a
// catch-all 404 responder.
async fn echo(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
Expand Down
139 changes: 139 additions & 0 deletions src/acceptor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use core::task::{Context, Poll};
djc marked this conversation as resolved.
Show resolved Hide resolved
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;

use futures_util::ready;
Licenser marked this conversation as resolved.
Show resolved Hide resolved
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
};
use rustls::ServerConfig;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

mod builder;
pub use builder::AcceptorBuilder;
use builder::WantsTlsConfig;

enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
Licenser marked this conversation as resolved.
Show resolved Hide resolved
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
Licenser marked this conversation as resolved.
Show resolved Hide resolved
pub struct TlsStream {
state: State,
}

impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
state: State::Handshaking(accept),
}
}
}

impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}

/// A TLS acceptor that can be used with hyper servers.
pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}

/// A Acceptor for the `https` scheme.
Licenser marked this conversation as resolved.
Show resolved Hide resolved
impl TlsAcceptor {
/// Provides a builder for a `TlsAcceptor`.
pub fn builder() -> AcceptorBuilder<WantsTlsConfig> {
AcceptorBuilder::new()
}
/// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}

impl<C, I> From<(C, I)> for TlsAcceptor
where
C: Into<Arc<ServerConfig>>,
I: Into<AddrIncoming>,
{
fn from((config, incoming): (C, I)) -> TlsAcceptor {
TlsAcceptor::new(config.into(), incoming.into())
}
}

impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}
Loading