diff --git a/azure-pipelines.yml b/azure-pipelines.yml index b904f8e6..4ac92c4e 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -72,7 +72,7 @@ jobs: # cargo test --workspace --features all # displayName: TestStable - - job: Doc + - job: Clippy_Doc strategy: matrix: windows: @@ -90,16 +90,7 @@ jobs: cargo +nightly doc --workspace --all-features --no-deps displayName: Build docs - - job: Clippy - pool: - vmImage: ubuntu-latest - - steps: - script: | - rustup toolchain install nightly - rustup +nightly target install x86_64-pc-windows-msvc - rustup +nightly target install x86_64-unknown-linux-gnu - rustup +nightly target install x86_64-apple-darwin rustup +nightly component add clippy - cargo +nightly clippy --target x86_64-pc-windows-msvc --target x86_64-apple-darwin --target x86_64-unknown-linux-gnu --all-features --all-targets -- -Dwarnings - displayName: Run clippy for targets + cargo +nightly clippy --all-features --all-targets -- -Dwarnings + displayName: Run clippy diff --git a/compio-tls/Cargo.toml b/compio-tls/Cargo.toml index 2f28f9b2..851ccd2d 100644 --- a/compio-tls/Cargo.toml +++ b/compio-tls/Cargo.toml @@ -14,10 +14,15 @@ compio-buf = { workspace = true } compio-io = { workspace = true } native-tls = { version = "0.2.11", optional = true } +rustls = { version = "0.21.8", optional = true } [dev-dependencies] compio-net = { workspace = true } compio-runtime = { workspace = true } +compio-macros = { workspace = true } + +rustls-native-certs = "0.6.3" [features] default = ["native-tls"] +all = ["native-tls", "rustls"] diff --git a/compio-tls/src/adapter.rs b/compio-tls/src/adapter.rs deleted file mode 100644 index bf523a5c..00000000 --- a/compio-tls/src/adapter.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::io; - -use compio_io::{AsyncRead, AsyncWrite}; -use native_tls::HandshakeError; - -use crate::{wrapper::StreamWrapper, TlsStream}; - -/// A wrapper around a [`native_tls::TlsConnector`], providing an async -/// `connect` method. -/// -/// ```rust -/// use compio_io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; -/// use compio_net::TcpStream; -/// use compio_tls::TlsConnector; -/// -/// # compio_runtime::block_on(async { -/// let connector = TlsConnector::from(native_tls::TlsConnector::new().unwrap()); -/// -/// let stream = TcpStream::connect("www.example.com:443").await.unwrap(); -/// let mut stream = connector.connect("www.example.com", stream).await.unwrap(); -/// -/// stream -/// .write_all("GET / HTTP/1.1\r\nHost:www.example.com\r\nConnection: close\r\n\r\n") -/// .await -/// .unwrap(); -/// stream.flush().await.unwrap(); -/// let (_, res) = stream.read_to_end(vec![]).await.unwrap(); -/// println!("{}", String::from_utf8_lossy(&res)); -/// # }) -/// ``` -#[derive(Debug, Clone)] -pub struct TlsConnector(native_tls::TlsConnector); - -impl From for TlsConnector { - fn from(value: native_tls::TlsConnector) -> Self { - Self(value) - } -} - -impl TlsConnector { - /// Connects the provided stream with this connector, assuming the provided - /// domain. - /// - /// This function will internally call `TlsConnector::connect` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used for clients who have already established, for - /// example, a TCP connection to a remote server. That stream is then - /// provided here to perform the client half of a connection to a - /// TLS-powered server. - pub async fn connect( - &self, - domain: &str, - stream: S, - ) -> io::Result> { - handshake(self.0.connect(domain, StreamWrapper::new(stream))).await - } -} - -/// A wrapper around a [`native_tls::TlsAcceptor`], providing an async `accept` -/// method. -#[derive(Clone)] -pub struct TlsAcceptor(native_tls::TlsAcceptor); - -impl From for TlsAcceptor { - fn from(value: native_tls::TlsAcceptor) -> Self { - Self(value) - } -} - -impl TlsAcceptor { - /// Accepts a new client connection with the provided stream. - /// - /// This function will internally call `TlsAcceptor::accept` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used after a new socket has been accepted from a - /// `TcpListener`. That socket is then passed to this function to perform - /// the server half of accepting a client connection. - pub async fn accept(&self, stream: S) -> io::Result> { - handshake(self.0.accept(StreamWrapper::new(stream))).await - } -} - -async fn handshake( - mut res: Result>, HandshakeError>>, -) -> io::Result> { - loop { - match res { - Ok(mut s) => { - s.get_mut().flush_write_buf().await?; - return Ok(TlsStream::from(s)); - } - Err(e) => match e { - HandshakeError::Failure(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), - HandshakeError::WouldBlock(mut mid_stream) => { - if mid_stream.get_mut().flush_write_buf().await? == 0 { - mid_stream.get_mut().fill_read_buf().await?; - } - res = mid_stream.handshake(); - } - }, - } - } -} diff --git a/compio-tls/src/adapter/mod.rs b/compio-tls/src/adapter/mod.rs new file mode 100644 index 00000000..6912144a --- /dev/null +++ b/compio-tls/src/adapter/mod.rs @@ -0,0 +1,171 @@ +use std::io; + +use compio_io::{AsyncRead, AsyncWrite}; + +use crate::{wrapper::StreamWrapper, TlsStream}; + +#[cfg(feature = "rustls")] +mod rtls; + +#[derive(Debug, Clone)] +enum TlsConnectorInner { + #[cfg(feature = "native-tls")] + NativeTls(native_tls::TlsConnector), + #[cfg(feature = "rustls")] + Rustls(rtls::TlsConnector), +} + +/// A wrapper around a [`native_tls::TlsConnector`] or [`rustls::ClientConfig`], +/// providing an async `connect` method. +#[derive(Debug, Clone)] +pub struct TlsConnector(TlsConnectorInner); + +#[cfg(feature = "native-tls")] +impl From for TlsConnector { + fn from(value: native_tls::TlsConnector) -> Self { + Self(TlsConnectorInner::NativeTls(value)) + } +} + +#[cfg(feature = "rustls")] +impl From> for TlsConnector { + fn from(value: std::sync::Arc) -> Self { + Self(TlsConnectorInner::Rustls(rtls::TlsConnector(value))) + } +} + +impl TlsConnector { + /// Connects the provided stream with this connector, assuming the provided + /// domain. + /// + /// This function will internally call `TlsConnector::connect` to connect + /// the stream and returns a future representing the resolution of the + /// connection operation. The returned future will resolve to either + /// `TlsStream` or `Error` depending if it's successful or not. + /// + /// This is typically used for clients who have already established, for + /// example, a TCP connection to a remote server. That stream is then + /// provided here to perform the client half of a connection to a + /// TLS-powered server. + pub async fn connect( + &self, + domain: &str, + stream: S, + ) -> io::Result> { + match &self.0 { + #[cfg(feature = "native-tls")] + TlsConnectorInner::NativeTls(c) => { + handshake_native_tls(c.connect(domain, StreamWrapper::new(stream))).await + } + #[cfg(feature = "rustls")] + TlsConnectorInner::Rustls(c) => handshake_rustls(c.connect(domain, stream)).await, + } + } +} + +#[derive(Clone)] +enum TlsAcceptorInner { + #[cfg(feature = "native-tls")] + NativeTls(native_tls::TlsAcceptor), + #[cfg(feature = "rustls")] + Rustls(rtls::TlsAcceptor), +} + +/// A wrapper around a [`native_tls::TlsAcceptor`] or [`rustls::ServerConfig`], +/// providing an async `accept` method. +#[derive(Clone)] +pub struct TlsAcceptor(TlsAcceptorInner); + +#[cfg(feature = "native-tls")] +impl From for TlsAcceptor { + fn from(value: native_tls::TlsAcceptor) -> Self { + Self(TlsAcceptorInner::NativeTls(value)) + } +} + +#[cfg(feature = "rustls")] +impl From> for TlsAcceptor { + fn from(value: std::sync::Arc) -> Self { + Self(TlsAcceptorInner::Rustls(rtls::TlsAcceptor(value))) + } +} + +impl TlsAcceptor { + /// Accepts a new client connection with the provided stream. + /// + /// This function will internally call `TlsAcceptor::accept` to connect + /// the stream and returns a future representing the resolution of the + /// connection operation. The returned future will resolve to either + /// `TlsStream` or `Error` depending if it's successful or not. + /// + /// This is typically used after a new socket has been accepted from a + /// `TcpListener`. That socket is then passed to this function to perform + /// the server half of accepting a client connection. + pub async fn accept(&self, stream: S) -> io::Result> { + match &self.0 { + #[cfg(feature = "native-tls")] + TlsAcceptorInner::NativeTls(c) => { + handshake_native_tls(c.accept(StreamWrapper::new(stream))).await + } + #[cfg(feature = "rustls")] + TlsAcceptorInner::Rustls(c) => handshake_rustls(c.accept(stream)).await, + } + } +} + +#[cfg(feature = "native-tls")] +async fn handshake_native_tls( + mut res: Result< + native_tls::TlsStream>, + native_tls::HandshakeError>, + >, +) -> io::Result> { + use native_tls::HandshakeError; + + loop { + match res { + Ok(mut s) => { + s.get_mut().flush_write_buf().await?; + return Ok(TlsStream::from(s)); + } + Err(e) => match e { + HandshakeError::Failure(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), + HandshakeError::WouldBlock(mut mid_stream) => { + if mid_stream.get_mut().flush_write_buf().await? == 0 { + mid_stream.get_mut().fill_read_buf().await?; + } + res = mid_stream.handshake(); + } + }, + } + } +} + +#[cfg(feature = "rustls")] +async fn handshake_rustls( + mut res: Result, rtls::HandshakeError>, +) -> io::Result> +where + C: std::ops::DerefMut>, +{ + use rtls::HandshakeError; + + loop { + match res { + Ok(mut s) => { + s.flush().await?; + return Ok(s); + } + Err(e) => match e { + HandshakeError::Rustls(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), + HandshakeError::System(e) => return Err(e), + HandshakeError::WouldBlock(mut mid_stream) => { + if mid_stream.get_mut().flush_write_buf().await? == 0 { + mid_stream.get_mut().fill_read_buf().await?; + } + res = mid_stream.handshake::(); + } + }, + } + } +} diff --git a/compio-tls/src/adapter/rtls.rs b/compio-tls/src/adapter/rtls.rs new file mode 100644 index 00000000..350a8c9c --- /dev/null +++ b/compio-tls/src/adapter/rtls.rs @@ -0,0 +1,136 @@ +use std::{io, ops::DerefMut, sync::Arc}; + +use compio_io::{AsyncRead, AsyncWrite}; +use rustls::{ + ClientConfig, ClientConnection, ConnectionCommon, Error, ServerConfig, ServerConnection, + ServerName, +}; + +use crate::{wrapper::StreamWrapper, TlsStream}; + +pub enum HandshakeError { + Rustls(Error), + System(io::Error), + WouldBlock(MidStream), +} + +pub struct MidStream { + stream: StreamWrapper, + conn: C, + result_fn: fn(StreamWrapper, C) -> TlsStream, +} + +impl MidStream { + pub fn new( + stream: StreamWrapper, + conn: C, + result_fn: fn(StreamWrapper, C) -> TlsStream, + ) -> Self { + Self { + stream, + conn, + result_fn, + } + } + + pub fn get_mut(&mut self) -> &mut StreamWrapper { + &mut self.stream + } + + pub fn handshake(mut self) -> Result, HandshakeError> + where + C: DerefMut>, + S: AsyncRead + AsyncWrite, + { + loop { + let mut write_would_block = false; + let mut read_would_block = false; + + while self.conn.wants_write() { + match self.conn.write_tls(&mut self.stream) { + Ok(_) => { + write_would_block = true; + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + write_would_block = true; + break; + } + Err(e) => return Err(HandshakeError::System(e)), + } + } + + while !self.stream.is_eof() && self.conn.wants_read() { + match self.conn.read_tls(&mut self.stream) { + Ok(_) => { + self.conn + .process_new_packets() + .map_err(HandshakeError::Rustls)?; + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + read_would_block = true; + break; + } + Err(e) => return Err(HandshakeError::System(e)), + } + } + + return match (self.stream.is_eof(), self.conn.is_handshaking()) { + (true, true) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + Err(HandshakeError::System(err)) + } + (_, false) => Ok((self.result_fn)(self.stream, self.conn)), + (_, true) if write_would_block || read_would_block => { + Err(HandshakeError::WouldBlock(self)) + } + _ => continue, + }; + } + } +} + +#[derive(Debug, Clone)] +pub struct TlsConnector(pub Arc); + +impl TlsConnector { + #[allow(clippy::result_large_err)] + pub fn connect( + &self, + domain: &str, + stream: S, + ) -> Result, HandshakeError> { + let conn = ClientConnection::new( + self.0.clone(), + ServerName::try_from(domain) + .map_err(|e| HandshakeError::System(io::Error::new(io::ErrorKind::Other, e)))?, + ) + .map_err(HandshakeError::Rustls)?; + + MidStream::new( + StreamWrapper::new(stream), + conn, + TlsStream::::new_rustls_client, + ) + .handshake() + } +} + +#[derive(Debug, Clone)] +pub struct TlsAcceptor(pub Arc); + +impl TlsAcceptor { + #[allow(clippy::result_large_err)] + pub fn accept( + &self, + stream: S, + ) -> Result, HandshakeError> { + let conn = ServerConnection::new(self.0.clone()).map_err(HandshakeError::Rustls)?; + + MidStream::new( + StreamWrapper::new(stream), + conn, + TlsStream::::new_rustls_server, + ) + .handshake() + } +} diff --git a/compio-tls/src/stream.rs b/compio-tls/src/stream/mod.rs similarity index 51% rename from compio-tls/src/stream.rs rename to compio-tls/src/stream/mod.rs index c648246e..db8afa92 100644 --- a/compio-tls/src/stream.rs +++ b/compio-tls/src/stream/mod.rs @@ -5,6 +5,60 @@ use compio_io::{AsyncRead, AsyncWrite}; use crate::StreamWrapper; +#[cfg(feature = "rustls")] +mod rtls; + +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +enum TlsStreamInner { + #[cfg(feature = "native-tls")] + NativeTls(native_tls::TlsStream>), + #[cfg(feature = "rustls")] + Rustls(rtls::TlsStream>), +} + +impl TlsStreamInner { + fn get_mut(&mut self) -> &mut StreamWrapper { + match self { + #[cfg(feature = "native-tls")] + Self::NativeTls(s) => s.get_mut(), + #[cfg(feature = "rustls")] + Self::Rustls(s) => s.get_mut(), + } + } +} + +impl io::Read for TlsStreamInner { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + #[cfg(feature = "native-tls")] + Self::NativeTls(s) => io::Read::read(s, buf), + #[cfg(feature = "rustls")] + Self::Rustls(s) => io::Read::read(s, buf), + } + } +} + +impl io::Write for TlsStreamInner { + fn write(&mut self, buf: &[u8]) -> io::Result { + match self { + #[cfg(feature = "native-tls")] + Self::NativeTls(s) => io::Write::write(s, buf), + #[cfg(feature = "rustls")] + Self::Rustls(s) => io::Write::write(s, buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self { + #[cfg(feature = "native-tls")] + Self::NativeTls(s) => io::Write::flush(s), + #[cfg(feature = "rustls")] + Self::Rustls(s) => io::Write::flush(s), + } + } +} + /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. /// @@ -13,15 +67,29 @@ use crate::StreamWrapper; /// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written /// to a `TlsStream` are encrypted when passing through to `S`. #[derive(Debug)] -pub struct TlsStream(native_tls::TlsStream>); +pub struct TlsStream(TlsStreamInner); +impl TlsStream { + #[cfg(feature = "rustls")] + pub(crate) fn new_rustls_client(s: StreamWrapper, conn: rustls::ClientConnection) -> Self { + Self(TlsStreamInner::Rustls(rtls::TlsStream::new_client(s, conn))) + } + + #[cfg(feature = "rustls")] + pub(crate) fn new_rustls_server(s: StreamWrapper, conn: rustls::ServerConnection) -> Self { + Self(TlsStreamInner::Rustls(rtls::TlsStream::new_server(s, conn))) + } +} + +#[cfg(feature = "native-tls")] +#[doc(hidden)] impl From>> for TlsStream { fn from(value: native_tls::TlsStream>) -> Self { - Self(value) + Self(TlsStreamInner::NativeTls(value)) } } -impl AsyncRead for TlsStream { +impl AsyncRead for TlsStream { async fn read(&mut self, mut buf: B) -> BufResult { let slice: &mut [MaybeUninit] = buf.as_mut_slice(); slice.fill(MaybeUninit::new(0)); @@ -35,6 +103,10 @@ impl AsyncRead for TlsStream { return BufResult(Ok(res), buf); } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + match self.flush().await { + Ok(()) => {} + Err(e) => return BufResult(Err(e), buf), + } match self.0.get_mut().fill_read_buf().await { Ok(_) => continue, Err(e) => return BufResult(Err(e), buf), diff --git a/compio-tls/src/stream/rtls.rs b/compio-tls/src/stream/rtls.rs new file mode 100644 index 00000000..6b3c2cec --- /dev/null +++ b/compio-tls/src/stream/rtls.rs @@ -0,0 +1,108 @@ +use std::io; + +use rustls::{ClientConnection, Error, IoState, Reader, ServerConnection, Writer}; + +#[derive(Debug)] +enum TlsConnection { + Client(ClientConnection), + Server(ServerConnection), +} + +impl TlsConnection { + pub fn reader(&mut self) -> Reader<'_> { + match self { + Self::Client(c) => c.reader(), + Self::Server(c) => c.reader(), + } + } + + pub fn writer(&mut self) -> Writer<'_> { + match self { + Self::Client(c) => c.writer(), + Self::Server(c) => c.writer(), + } + } + + pub fn process_new_packets(&mut self) -> Result { + match self { + Self::Client(c) => c.process_new_packets(), + Self::Server(c) => c.process_new_packets(), + } + } + + pub fn read_tls(&mut self, rd: &mut dyn io::Read) -> io::Result { + match self { + Self::Client(c) => c.read_tls(rd), + Self::Server(c) => c.read_tls(rd), + } + } + + pub fn write_tls(&mut self, wr: &mut dyn io::Write) -> io::Result { + match self { + Self::Client(c) => c.write_tls(wr), + Self::Server(c) => c.write_tls(wr), + } + } +} + +#[derive(Debug)] +pub struct TlsStream { + inner: S, + conn: TlsConnection, +} + +impl TlsStream { + pub fn new_client(inner: S, conn: ClientConnection) -> Self { + Self { + inner, + conn: TlsConnection::Client(conn), + } + } + + pub fn new_server(inner: S, conn: ServerConnection) -> Self { + Self { + inner, + conn: TlsConnection::Server(conn), + } + } + + pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } +} + +impl io::Read for TlsStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + loop { + match self.conn.reader().read(buf) { + Ok(len) => { + return Ok(len); + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.conn.read_tls(&mut self.inner)?; + let state = self + .conn + .process_new_packets() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + if state.tls_bytes_to_write() > 0 { + io::Write::flush(self)?; + } + } + Err(e) => return Err(e), + } + } + } +} + +impl io::Write for TlsStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.flush()?; + self.conn.writer().write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.conn.write_tls(&mut self.inner)?; + self.inner.flush()?; + Ok(()) + } +} diff --git a/compio-tls/src/wrapper.rs b/compio-tls/src/wrapper.rs index 0e9e2c3f..014b5b0f 100644 --- a/compio-tls/src/wrapper.rs +++ b/compio-tls/src/wrapper.rs @@ -27,6 +27,10 @@ impl StreamWrapper { } } + pub fn is_eof(&self) -> bool { + self.eof + } + pub fn get_ref(&self) -> &S { &self.stream } diff --git a/compio-tls/tests/connect.rs b/compio-tls/tests/connect.rs new file mode 100644 index 00000000..070e2ac2 --- /dev/null +++ b/compio-tls/tests/connect.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use compio_io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use compio_net::TcpStream; +use compio_tls::TlsConnector; + +async fn connect(connector: TlsConnector) { + let stream = TcpStream::connect("www.example.com:443").await.unwrap(); + let mut stream = connector.connect("www.example.com", stream).await.unwrap(); + + stream + .write_all("GET / HTTP/1.1\r\nHost:www.example.com\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + stream.flush().await.unwrap(); + let (_, res) = stream.read_to_end(vec![]).await.unwrap(); + println!("{}", String::from_utf8_lossy(&res)); +} + +#[cfg(feature = "native-tls")] +#[compio_macros::test] +async fn native() { + let connector = TlsConnector::from(native_tls::TlsConnector::new().unwrap()); + + connect(connector).await; +} + +#[cfg(feature = "rustls")] +#[compio_macros::test] +async fn rtls() { + let mut store = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().unwrap() { + store.add(&rustls::Certificate(cert.0)).unwrap(); + } + + let connector = TlsConnector::from(Arc::new( + rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(store) + .with_no_client_auth(), + )); + + connect(connector).await; +}