Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tls): add rustls backend #134

Merged
merged 4 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ jobs:
# cargo test --workspace --features all
# displayName: TestStable

- job: Doc
- job: Clippy_Doc
strategy:
matrix:
windows:
Expand All @@ -90,16 +90,7 @@ jobs:
cargo +nightly doc --workspace --all-features --no-deps
displayName: Build docs

- job: Clippy
pool:
vmImage: ubuntu-latest

steps:
- script: |
rustup toolchain install nightly
rustup +nightly target install x86_64-pc-windows-msvc
rustup +nightly target install x86_64-unknown-linux-gnu
rustup +nightly target install x86_64-apple-darwin
rustup +nightly component add clippy
cargo +nightly clippy --target x86_64-pc-windows-msvc --target x86_64-apple-darwin --target x86_64-unknown-linux-gnu --all-features --all-targets -- -Dwarnings
displayName: Run clippy for targets
cargo +nightly clippy --all-features --all-targets -- -Dwarnings
displayName: Run clippy
5 changes: 5 additions & 0 deletions compio-tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ compio-buf = { workspace = true }
compio-io = { workspace = true }

native-tls = { version = "0.2.11", optional = true }
rustls = { version = "0.21.8", optional = true }

[dev-dependencies]
compio-net = { workspace = true }
compio-runtime = { workspace = true }
compio-macros = { workspace = true }

rustls-native-certs = "0.6.3"

[features]
default = ["native-tls"]
all = ["native-tls", "rustls"]
109 changes: 0 additions & 109 deletions compio-tls/src/adapter.rs

This file was deleted.

171 changes: 171 additions & 0 deletions compio-tls/src/adapter/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
use std::io;

use compio_io::{AsyncRead, AsyncWrite};

use crate::{wrapper::StreamWrapper, TlsStream};

#[cfg(feature = "rustls")]
mod rtls;

#[derive(Debug, Clone)]
enum TlsConnectorInner {
#[cfg(feature = "native-tls")]
NativeTls(native_tls::TlsConnector),
#[cfg(feature = "rustls")]
Rustls(rtls::TlsConnector),
}

/// A wrapper around a [`native_tls::TlsConnector`] or [`rustls::ClientConfig`],
/// providing an async `connect` method.
#[derive(Debug, Clone)]
pub struct TlsConnector(TlsConnectorInner);

#[cfg(feature = "native-tls")]
impl From<native_tls::TlsConnector> for TlsConnector {
fn from(value: native_tls::TlsConnector) -> Self {
Self(TlsConnectorInner::NativeTls(value))
}
}

#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ClientConfig>> for TlsConnector {
fn from(value: std::sync::Arc<rustls::ClientConfig>) -> Self {
Self(TlsConnectorInner::Rustls(rtls::TlsConnector(value)))
}
}

impl TlsConnector {
/// Connects the provided stream with this connector, assuming the provided
/// domain.
///
/// This function will internally call `TlsConnector::connect` to connect
/// the stream and returns a future representing the resolution of the
/// connection operation. The returned future will resolve to either
/// `TlsStream<S>` or `Error` depending if it's successful or not.
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to a remote server. That stream is then
/// provided here to perform the client half of a connection to a
/// TLS-powered server.
pub async fn connect<S: AsyncRead + AsyncWrite>(
&self,
domain: &str,
stream: S,
) -> io::Result<TlsStream<S>> {
match &self.0 {
#[cfg(feature = "native-tls")]
TlsConnectorInner::NativeTls(c) => {
handshake_native_tls(c.connect(domain, StreamWrapper::new(stream))).await
}
#[cfg(feature = "rustls")]
TlsConnectorInner::Rustls(c) => handshake_rustls(c.connect(domain, stream)).await,
}
}
}

#[derive(Clone)]
enum TlsAcceptorInner {
#[cfg(feature = "native-tls")]
NativeTls(native_tls::TlsAcceptor),
#[cfg(feature = "rustls")]
Rustls(rtls::TlsAcceptor),
}

/// A wrapper around a [`native_tls::TlsAcceptor`] or [`rustls::ServerConfig`],
/// providing an async `accept` method.
#[derive(Clone)]
pub struct TlsAcceptor(TlsAcceptorInner);

#[cfg(feature = "native-tls")]
impl From<native_tls::TlsAcceptor> for TlsAcceptor {
fn from(value: native_tls::TlsAcceptor) -> Self {
Self(TlsAcceptorInner::NativeTls(value))
}
}

#[cfg(feature = "rustls")]
impl From<std::sync::Arc<rustls::ServerConfig>> for TlsAcceptor {
fn from(value: std::sync::Arc<rustls::ServerConfig>) -> Self {
Self(TlsAcceptorInner::Rustls(rtls::TlsAcceptor(value)))
}
}

impl TlsAcceptor {
/// Accepts a new client connection with the provided stream.
///
/// This function will internally call `TlsAcceptor::accept` to connect
/// the stream and returns a future representing the resolution of the
/// connection operation. The returned future will resolve to either
/// `TlsStream<S>` or `Error` depending if it's successful or not.
///
/// This is typically used after a new socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of accepting a client connection.
pub async fn accept<S: AsyncRead + AsyncWrite>(&self, stream: S) -> io::Result<TlsStream<S>> {
match &self.0 {
#[cfg(feature = "native-tls")]
TlsAcceptorInner::NativeTls(c) => {
handshake_native_tls(c.accept(StreamWrapper::new(stream))).await
}
#[cfg(feature = "rustls")]
TlsAcceptorInner::Rustls(c) => handshake_rustls(c.accept(stream)).await,
}
}
}

#[cfg(feature = "native-tls")]
async fn handshake_native_tls<S: AsyncRead + AsyncWrite>(
mut res: Result<
native_tls::TlsStream<StreamWrapper<S>>,
native_tls::HandshakeError<StreamWrapper<S>>,
>,
) -> io::Result<TlsStream<S>> {
use native_tls::HandshakeError;

loop {
match res {
Ok(mut s) => {
s.get_mut().flush_write_buf().await?;
return Ok(TlsStream::from(s));
}
Err(e) => match e {
HandshakeError::Failure(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
HandshakeError::WouldBlock(mut mid_stream) => {
if mid_stream.get_mut().flush_write_buf().await? == 0 {
mid_stream.get_mut().fill_read_buf().await?;
}
res = mid_stream.handshake();
}
},
}
}
}

#[cfg(feature = "rustls")]
async fn handshake_rustls<S: AsyncRead + AsyncWrite, C, D>(
mut res: Result<TlsStream<S>, rtls::HandshakeError<S, C>>,
) -> io::Result<TlsStream<S>>
where
C: std::ops::DerefMut<Target = rustls::ConnectionCommon<D>>,
{
use rtls::HandshakeError;

loop {
match res {
Ok(mut s) => {
s.flush().await?;
return Ok(s);
}
Err(e) => match e {
HandshakeError::Rustls(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
HandshakeError::System(e) => return Err(e),
HandshakeError::WouldBlock(mut mid_stream) => {
if mid_stream.get_mut().flush_write_buf().await? == 0 {
mid_stream.get_mut().fill_read_buf().await?;
}
res = mid_stream.handshake::<D>();
}
},
}
}
}
Loading