Skip to content

Commit

Permalink
Forward vectored writes (#45)
Browse files Browse the repository at this point in the history
* Migrate early-data test to rustls

* Replace `match` with `if let` on `TlsState::EarlyData`

* Extract client early data handling

* Forward vectored writes
  • Loading branch information
paolobarbolini authored Mar 17, 2024
1 parent 925a87f commit 3a153ac
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 168 deletions.
180 changes: 120 additions & 60 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::pin::Pin;
#[cfg(feature = "early-data")]
use std::task::Waker;
use std::task::{Context, Poll};

use rustls::ClientConnection;
Expand All @@ -20,7 +22,7 @@ pub struct TlsStream<IO> {
pub(crate) state: TlsState,

#[cfg(feature = "early-data")]
pub(crate) early_waker: Option<std::task::Waker>,
pub(crate) early_waker: Option<Waker>,
}

impl<IO> TlsStream<IO> {
Expand Down Expand Up @@ -152,78 +154,70 @@ where
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[allow(clippy::match_single_binding)]
match this.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(ref mut pos, ref mut data) => {
use std::io::Write;

// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) {
Ok(n) => n,
Err(err) => return Poll::Ready(Err(err)),
};
if len != 0 {
data.extend_from_slice(&buf[..len]);
return Poll::Ready(Ok(len));
}
}

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}

// end
this.state = TlsState::Stream;

if let Some(waker) = this.early_waker.take() {
waker.wake();
}

stream.as_mut_pin().poll_write(cx, buf)
#[cfg(feature = "early-data")]
{
let bufs = [io::IoSlice::new(buf)];
let written = ready!(poll_handle_early_data(
&mut this.state,
&mut stream,
&mut this.early_waker,
cx,
&bufs
))?;
if written != 0 {
return Poll::Ready(Ok(written));
}
_ => stream.as_mut_pin().poll_write(cx, buf),
}

stream.as_mut_pin().poll_write(cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
/// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`.
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[cfg(feature = "early-data")]
{
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}
let written = ready!(poll_handle_early_data(
&mut this.state,
&mut stream,
&mut this.early_waker,
cx,
bufs
))?;
if written != 0 {
return Poll::Ready(Ok(written));
}
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}
stream.as_mut_pin().poll_write_vectored(cx, bufs)
}

this.state = TlsState::Stream;
#[inline]
fn is_write_vectored(&self) -> bool {
true
}

if let Some(waker) = this.early_waker.take() {
waker.wake();
}
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[cfg(feature = "early-data")]
ready!(poll_handle_early_data(
&mut this.state,
&mut stream,
&mut this.early_waker,
cx,
&[]
))?;

stream.as_mut_pin().poll_flush(cx)
}
Expand All @@ -248,3 +242,69 @@ where
stream.as_mut_pin().poll_shutdown(cx)
}
}

#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
state: &mut TlsState,
stream: &mut Stream<IO, ClientConnection>,
early_waker: &mut Option<Waker>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
if let TlsState::EarlyData(pos, data) = state {
use std::io::Write;

// write early data
if let Some(mut early_data) = stream.session.early_data() {
let mut written = 0;

for buf in bufs {
if buf.is_empty() {
continue;
}

let len = match early_data.write(buf) {
Ok(0) => break,
Ok(n) => n,
Err(err) => return Poll::Ready(Err(err)),
};

written += len;
data.extend_from_slice(&buf[..len]);

if len < buf.len() {
break;
}
}

if written != 0 {
return Poll::Ready(Ok(written));
}
}

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}

// end
*state = TlsState::Stream;

if let Some(waker) = early_waker.take() {
waker.wake();
}
}

Poll::Ready(Ok(0))
}
37 changes: 37 additions & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,43 @@ where
Poll::Ready(Ok(pos))
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
if bufs.iter().all(|buf| buf.is_empty()) {
return Poll::Ready(Ok(0));
}

loop {
let mut would_block = false;
let written = self.session.writer().write_vectored(bufs)?;

while self.session.wants_write() {
match self.write_io(cx) {
Poll::Ready(Ok(0)) | Poll::Pending => {
would_block = true;
break;
}
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}

return match (written, would_block) {
(0, true) => Poll::Pending,
(0, false) => continue,
(n, _) => Poll::Ready(Ok(n)),
};
}
}

#[inline]
fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.session.writer().flush()?;
while self.session.wants_write() {
Expand Down
11 changes: 10 additions & 1 deletion src/common/test_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ impl AsyncWrite for Expected {

#[tokio::test]
async fn stream_good() -> io::Result<()> {
stream_good_impl(false).await
}

#[tokio::test]
async fn stream_good_vectored() -> io::Result<()> {
stream_good_impl(true).await
}

async fn stream_good_impl(vectored: bool) -> io::Result<()> {
const FILE: &[u8] = include_bytes!("../../README.md");

let (server, mut client) = make_pair();
Expand All @@ -139,7 +148,7 @@ async fn stream_good() -> io::Result<()> {
dbg!(stream.read_to_end(&mut buf).await)?;
assert_eq!(buf, FILE);

dbg!(stream.write_all(b"Hello World!").await)?;
dbg!(utils::write(&mut stream, b"Hello World!", vectored).await)?;
stream.session.send_close_notify();

dbg!(stream.shutdown().await)?;
Expand Down
20 changes: 20 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,26 @@ where
}
}

#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs),
TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs),
}
}

#[inline]
fn is_write_vectored(&self) -> bool {
match self {
TlsStream::Client(x) => x.is_write_vectored(),
TlsStream::Server(x) => x.is_write_vectored(),
}
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Expand Down
18 changes: 18 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ where
stream.as_mut_pin().poll_write(cx, buf)
}

/// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`.
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_write_vectored(cx, bufs)
}

#[inline]
fn is_write_vectored(&self) -> bool {
true
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream =
Expand Down
Loading

0 comments on commit 3a153ac

Please sign in to comment.