diff --git a/Cargo.toml b/Cargo.toml index c5b44046..64d0a3cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ exclude = ["/.github", "/examples", "/scripts"] [dependencies] tokio = "1.0" -rustls = { version = "0.22", default-features = false } +rustls = { version = "0.23", default-features = false, features = ["std"] } pki-types = { package = "rustls-pki-types", version = "1" } [features] diff --git a/src/common/handshake.rs b/src/common/handshake.rs index ac78165a..d541f992 100644 --- a/src/common/handshake.rs +++ b/src/common/handshake.rs @@ -4,10 +4,11 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::{io, mem}; +use rustls::server::AcceptedAlert; use rustls::{ConnectionCommon, SideData}; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::common::{Stream, TlsState}; +use crate::common::{Stream, SyncWriteAdapter, TlsState}; pub(crate) trait IoSession { type Io; @@ -21,7 +22,15 @@ pub(crate) trait IoSession { pub(crate) enum MidHandshake { Handshaking(IS), End, - Error { io: IS::Io, error: io::Error }, + SendAlert { + io: IS::Io, + alert: AcceptedAlert, + error: io::Error, + }, + Error { + io: IS::Io, + error: io::Error, + }, } impl Future for MidHandshake @@ -38,6 +47,15 @@ where let mut stream = match mem::replace(this, MidHandshake::End) { MidHandshake::Handshaking(stream) => stream, + MidHandshake::SendAlert { + mut io, + mut alert, + error, + } => { + let mut writer = SyncWriteAdapter { io: &mut io, cx }; + let _ = alert.write(&mut writer); // best effort + return Poll::Ready(Err((error, io))); + } // Starting the handshake returned an error; fail the future immediately. MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))), _ => panic!("unexpected polling after handshake"), diff --git a/src/common/mod.rs b/src/common/mod.rs index 664f8f93..442e295d 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -118,41 +118,7 @@ where } pub fn write_io(&mut self, cx: &mut Context) -> Poll> { - struct Writer<'a, 'b, T> { - io: &'a mut T, - cx: &'a mut Context<'b>, - } - - impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> { - #[inline] - fn poll_with( - &mut self, - f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll>, - ) -> io::Result { - match f(Pin::new(self.io), self.cx) { - Poll::Ready(result) => result, - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), - } - } - } - - impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - self.poll_with(|io, cx| io.poll_write(cx, buf)) - } - - #[inline] - fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs)) - } - - fn flush(&mut self) -> io::Result<()> { - self.poll_with(|io, cx| io.poll_flush(cx)) - } - } - - let mut writer = Writer { io: self.io, cx }; + let mut writer = SyncWriteAdapter { io: self.io, cx }; match self.session.write_tls(&mut writer) { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, @@ -360,5 +326,43 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> { } } +/// An adapter that implements a [`Write`] interface for [`AsyncWrite`] types and an +/// associated [`Context`]. +/// +/// Turns `Poll::Pending` into `WouldBlock`. +pub struct SyncWriteAdapter<'a, 'b, T> { + pub io: &'a mut T, + pub cx: &'a mut Context<'b>, +} + +impl<'a, 'b, T: Unpin> SyncWriteAdapter<'a, 'b, T> { + #[inline] + fn poll_with( + &mut self, + f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll>, + ) -> io::Result { + match f(Pin::new(self.io), self.cx) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), + } + } +} + +impl<'a, 'b, T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'a, 'b, T> { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.poll_with(|io, cx| io.poll_write(cx, buf)) + } + + #[inline] + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs)) + } + + fn flush(&mut self) -> io::Result<()> { + self.poll_with(|io, cx| io.poll_flush(cx)) + } +} + #[cfg(test)] mod test_stream; diff --git a/src/lib.rs b/src/lib.rs index ccf7f7e1..1b3119c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -288,8 +288,10 @@ where return Poll::Ready(Ok(StartHandshake { accepted, io })); } Ok(None) => continue, - Err(err) => { - return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))) + Err((err, mut alert)) => { + let mut writer = common::SyncWriteAdapter { io, cx }; + let _ = alert.write(&mut writer); // best effort + return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))); } } } @@ -319,9 +321,10 @@ where { let mut conn = match self.accepted.into_connection(config) { Ok(conn) => conn, - Err(error) => { - return Accept(MidHandshake::Error { + Err((error, alert)) => { + return Accept(MidHandshake::SendAlert { io: self.io, + alert, // TODO(eliza): should this really return an `io::Error`? // Probably not... error: io::Error::new(io::ErrorKind::Other, error), @@ -361,6 +364,7 @@ impl Connect { pub fn get_ref(&self) -> Option<&IO> { match &self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), + MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), MidHandshake::End => None, } @@ -369,6 +373,7 @@ impl Connect { pub fn get_mut(&mut self) -> Option<&mut IO> { match &mut self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), + MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), MidHandshake::End => None, } @@ -384,6 +389,7 @@ impl Accept { pub fn get_ref(&self) -> Option<&IO> { match &self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), + MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), MidHandshake::End => None, } @@ -392,6 +398,7 @@ impl Accept { pub fn get_mut(&mut self) -> Option<&mut IO> { match &mut self.0 { MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), + MidHandshake::SendAlert { io, .. } => Some(io), MidHandshake::Error { io, .. } => Some(io), MidHandshake::End => None, } diff --git a/tests/test.rs b/tests/test.rs index 5598fd46..8d3921b7 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -223,12 +223,16 @@ async fn lazy_config_acceptor_take_io() -> Result<(), rustls::Error> { } let server_msg = b"message from server"; + let fatal_alert_decode_error = b"\x15\x03\x03\x00\x02\x02\x32"; let some_io = acceptor.take_io(); assert!(some_io.is_some(), "Expected Some(io)"); some_io.unwrap().write_all(server_msg).await.unwrap(); - assert_eq!(rx.await.unwrap(), server_msg); + assert_eq!( + rx.await.unwrap(), + [&fatal_alert_decode_error[..], &server_msg[..]].concat() + ); assert!( acceptor.take_io().is_none(),