Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support peer_addr and local_addr #9

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn trace_error(error: io::Error) -> io::Error {
error
}

pub struct ImplementReadTrait<'a, T>(pub &'a mut T);
pub struct ImplementReadTrait<'a, T>(pub &'a T);

impl Read for ImplementReadTrait<'_, TcpStream> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
Expand All @@ -36,7 +36,7 @@ impl Read for ImplementReadTrait<'_, TcpStream> {
}
}

pub struct ImplementWriteTrait<'a, T>(pub &'a mut T);
pub struct ImplementWriteTrait<'a, T>(pub &'a T);

impl Write for ImplementWriteTrait<'_, TcpStream> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Expand All @@ -54,18 +54,12 @@ impl Write for ImplementWriteTrait<'_, TcpStream> {
}
}

pub fn read_tls(
tcp: &mut TcpStream,
tls: &mut Connection,
) -> io::Result<usize> {
pub fn read_tls(tcp: &TcpStream, tls: &mut Connection) -> io::Result<usize> {
let mut read = ImplementReadTrait(tcp);
tls.read_tls(&mut read)
}

pub fn write_tls(
tcp: &mut TcpStream,
tls: &mut Connection,
) -> io::Result<usize> {
pub fn write_tls(tcp: &TcpStream, tls: &mut Connection) -> io::Result<usize> {
let mut write = ImplementWriteTrait(tcp);
tls.write_tls(&mut write)
}
16 changes: 10 additions & 6 deletions src/connection_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ impl ConnectionStream {
(self.tcp, self.tls)
}

pub(crate) fn tcp_stream(&self) -> &TcpStream {
&self.tcp
}

#[cfg(test)]
pub fn plaintext_bytes_to_read(&self) -> usize {
self
Expand All @@ -123,7 +127,7 @@ impl ConnectionStream {
StreamProgress::Error
} else if self.tls.wants_read() {
loop {
match read_tls(&mut self.tcp, &mut self.tls) {
match read_tls(&self.tcp, &mut self.tls) {
Ok(n) => {
if n == 0 {
self.rd_error = Some(ErrorKind::UnexpectedEof);
Expand Down Expand Up @@ -169,7 +173,7 @@ impl ConnectionStream {
} else if self.tls.wants_write() {
loop {
debug_assert!(self.tls.wants_write());
match write_tls(&mut self.tcp, &mut self.tls) {
match write_tls(&self.tcp, &mut self.tls) {
Ok(n) => {
assert!(n > 0);
break StreamProgress::MadeProgress;
Expand Down Expand Up @@ -580,10 +584,10 @@ mod tests {
ClientConnection::new(client_config().into(), server_name())
.unwrap()
.into();
let server = spawn(handshake_task(server, tls_server));
let client = spawn(handshake_task(client, tls_client));
let (tcp_client, tls_client) = client.await.unwrap().unwrap();
let (tcp_server, tls_server) = server.await.unwrap().unwrap();
let server = spawn(handshake_task(server.into(), tls_server));
let client = spawn(handshake_task(client.into(), tls_client));
let (tcp_client, tls_client) = client.await.unwrap().unwrap().reclaim();
let (tcp_server, tls_server) = server.await.unwrap().unwrap().reclaim();
assert!(!tls_client.is_handshaking());
assert!(!tls_server.is_handshaking());

Expand Down
55 changes: 39 additions & 16 deletions src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use rustls::Connection;
use std::io;
use std::io::ErrorKind;
use std::sync::Arc;

use tokio::net::TcpStream;

Expand All @@ -10,7 +11,7 @@ use crate::adapter::write_tls;
use crate::TestOptions;

async fn try_read<'a, 'b>(
tcp: &'a mut TcpStream,
tcp: &'a TcpStream,
tls: &'b mut Connection,
) -> io::Result<()> {
match read_tls(tcp, tls) {
Expand All @@ -37,7 +38,7 @@ async fn try_read<'a, 'b>(
}

async fn try_write<'a, 'b>(
tcp: &'a mut TcpStream,
tcp: &'a TcpStream,
tls: &'b mut Connection,
) -> io::Result<()> {
match write_tls(tcp, tls) {
Expand All @@ -52,19 +53,41 @@ async fn try_write<'a, 'b>(
Ok(())
}

#[derive(Debug)]
pub(crate) struct HandshakeResult(Arc<TcpStream>, pub Connection);

impl HandshakeResult {
#[cfg(test)]
pub fn reclaim(self) -> (TcpStream, Connection) {
(
Arc::into_inner(self.0).expect("Failed to reclaim TCP"),
self.1,
)
}

pub fn reclaim2(self, tcp: Arc<TcpStream>) -> (TcpStream, Connection) {
drop(tcp);
(
Arc::into_inner(self.0).expect("Failed to reclaim TCP"),
self.1,
)
}
}

/// Performs a handshake and returns the [`TcpStream`]/[`Connection`] pair if successful.
pub async fn handshake_task(
tcp: TcpStream,
#[cfg(test)]
pub(crate) async fn handshake_task(
tcp: Arc<TcpStream>,
tls: Connection,
) -> io::Result<(TcpStream, Connection)> {
) -> io::Result<HandshakeResult> {
handshake_task_internal(tcp, tls, TestOptions::default()).await
}

pub(crate) async fn handshake_task_internal(
mut tcp: TcpStream,
tcp: Arc<TcpStream>,
mut tls: Connection,
test_options: TestOptions,
) -> io::Result<(TcpStream, Connection)> {
) -> io::Result<HandshakeResult> {
#[cfg(not(test))]
{
_ = test_options;
Expand All @@ -83,7 +106,7 @@ pub(crate) async fn handshake_task_internal(
if test_options.slow_handshake_write {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
match try_write(&mut tcp, &mut tls).await {
match try_write(&tcp, &mut tls).await {
Ok(()) => {}
Err(err) => {
struct WriteSink();
Expand Down Expand Up @@ -112,7 +135,7 @@ pub(crate) async fn handshake_task_internal(
return Err(err);
} else {
// Not handshaking, no write interest, pretend we succeeded and pick up the error later.
return Ok((tcp, tls));
return Ok(HandshakeResult(tcp, tls));
}
}
}
Expand All @@ -129,10 +152,10 @@ pub(crate) async fn handshake_task_internal(
if test_options.slow_handshake_read {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
try_read(&mut tcp, &mut tls).await?;
try_read(&tcp, &mut tls).await?;
}
}
Ok((tcp, tls))
Ok(HandshakeResult(tcp, tls))
}

#[cfg(test)]
Expand All @@ -144,7 +167,7 @@ mod tests {
use crate::tests::TestResult;
use rustls::ClientConnection;
use rustls::ServerConnection;
use tokio::spawn;
use tokio::task::spawn;

#[tokio::test]
async fn test_handshake() -> TestResult {
Expand All @@ -156,10 +179,10 @@ mod tests {
ClientConnection::new(client_config().into(), server_name())
.unwrap()
.into();
let server = spawn(handshake_task(server, tls_server));
let client = spawn(handshake_task(client, tls_client));
let (tcp_client, tls_client) = client.await.unwrap().unwrap();
let (tcp_server, tls_server) = server.await.unwrap().unwrap();
let server = spawn(handshake_task(server.into(), tls_server));
let client = spawn(handshake_task(client.into(), tls_client));
let (tcp_client, tls_client) = client.await.unwrap().unwrap().reclaim();
let (tcp_server, tls_server) = server.await.unwrap().unwrap().reclaim();
assert!(!tls_client.is_handshaking());
assert!(!tls_server.is_handshaking());
// Don't let these drop until the handshake finishes on both sides
Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ mod stream;
#[cfg(test)]
mod system_test;

pub use handshake::handshake_task;
pub use stream::TlsHandshake;
pub use stream::TlsStream;

Expand Down
Loading