Skip to content

Commit

Permalink
Take rustls 0.23
Browse files Browse the repository at this point in the history
- track new alert-sending API for Acceptor.
  • Loading branch information
ctz authored and djc committed Mar 5, 2024
1 parent 096b161 commit 0587801
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ exclude = ["/.github", "/examples", "/scripts"]

[dependencies]
tokio = "1.0"
rustls = { version = "0.22", default-features = false }
rustls = { version = "0.23", default-features = false, features = ["std"] }
pki-types = { package = "rustls-pki-types", version = "1" }

[features]
Expand Down
22 changes: 20 additions & 2 deletions src/common/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem};

use rustls::server::AcceptedAlert;
use rustls::{ConnectionCommon, SideData};
use tokio::io::{AsyncRead, AsyncWrite};

use crate::common::{Stream, TlsState};
use crate::common::{Stream, SyncWriteAdapter, TlsState};

pub(crate) trait IoSession {
type Io;
Expand All @@ -21,7 +22,15 @@ pub(crate) trait IoSession {
pub(crate) enum MidHandshake<IS: IoSession> {
Handshaking(IS),
End,
Error { io: IS::Io, error: io::Error },
SendAlert {
io: IS::Io,
alert: AcceptedAlert,
error: io::Error,
},
Error {
io: IS::Io,
error: io::Error,
},
}

impl<IS, SD> Future for MidHandshake<IS>
Expand All @@ -38,6 +47,15 @@ where

let mut stream = match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => stream,
MidHandshake::SendAlert {
mut io,
mut alert,
error,
} => {
let mut writer = SyncWriteAdapter { io: &mut io, cx };
let _ = alert.write(&mut writer); // best effort
return Poll::Ready(Err((error, io)));
}
// Starting the handshake returned an error; fail the future immediately.
MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))),
_ => panic!("unexpected polling after handshake"),
Expand Down
15 changes: 11 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,10 @@ where
return Poll::Ready(Ok(StartHandshake { accepted, io }));
}
Ok(None) => continue,
Err(err) => {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
Err((err, mut alert)) => {
let mut writer = common::SyncWriteAdapter { io, cx };
let _ = alert.write(&mut writer); // best effort
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)));
}
}
}
Expand Down Expand Up @@ -319,9 +321,10 @@ where
{
let mut conn = match self.accepted.into_connection(config) {
Ok(conn) => conn,
Err(error) => {
return Accept(MidHandshake::Error {
Err((error, alert)) => {
return Accept(MidHandshake::SendAlert {
io: self.io,
alert,
// TODO(eliza): should this really return an `io::Error`?
// Probably not...
error: io::Error::new(io::ErrorKind::Other, error),
Expand Down Expand Up @@ -361,6 +364,7 @@ impl<IO> Connect<IO> {
pub fn get_ref(&self) -> Option<&IO> {
match &self.0 {
MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
MidHandshake::SendAlert { io, .. } => Some(io),
MidHandshake::Error { io, .. } => Some(io),
MidHandshake::End => None,
}
Expand All @@ -369,6 +373,7 @@ impl<IO> Connect<IO> {
pub fn get_mut(&mut self) -> Option<&mut IO> {
match &mut self.0 {
MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
MidHandshake::SendAlert { io, .. } => Some(io),
MidHandshake::Error { io, .. } => Some(io),
MidHandshake::End => None,
}
Expand All @@ -384,6 +389,7 @@ impl<IO> Accept<IO> {
pub fn get_ref(&self) -> Option<&IO> {
match &self.0 {
MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
MidHandshake::SendAlert { io, .. } => Some(io),
MidHandshake::Error { io, .. } => Some(io),
MidHandshake::End => None,
}
Expand All @@ -392,6 +398,7 @@ impl<IO> Accept<IO> {
pub fn get_mut(&mut self) -> Option<&mut IO> {
match &mut self.0 {
MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
MidHandshake::SendAlert { io, .. } => Some(io),
MidHandshake::Error { io, .. } => Some(io),
MidHandshake::End => None,
}
Expand Down
6 changes: 5 additions & 1 deletion tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,16 @@ async fn lazy_config_acceptor_take_io() -> Result<(), rustls::Error> {
}

let server_msg = b"message from server";
let fatal_alert_decode_error = b"\x15\x03\x03\x00\x02\x02\x32";

let some_io = acceptor.take_io();
assert!(some_io.is_some(), "Expected Some(io)");
some_io.unwrap().write_all(server_msg).await.unwrap();

assert_eq!(rx.await.unwrap(), server_msg);
assert_eq!(
rx.await.unwrap(),
[&fatal_alert_decode_error[..], &server_msg[..]].concat()
);

assert!(
acceptor.take_io().is_none(),
Expand Down

0 comments on commit 0587801

Please sign in to comment.