Skip to content

Commit

Permalink
Merge pull request #206 from Berrysoft/dev/split
Browse files Browse the repository at this point in the history
feat(io,net): add split
  • Loading branch information
Berrysoft authored Feb 4, 2024
2 parents a26dbe5 + 502685f commit 4cabfec
Show file tree
Hide file tree
Showing 9 changed files with 372 additions and 3 deletions.
1 change: 1 addition & 0 deletions compio-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ repository = { workspace = true }

[dependencies]
compio-buf = { workspace = true, features = ["arrayvec"] }
futures-util = { workspace = true }
paste = { workspace = true }

[dev-dependencies]
Expand Down
2 changes: 2 additions & 0 deletions compio-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,13 @@ mod buffer;
#[cfg(feature = "compat")]
pub mod compat;
mod read;
mod split;
pub mod util;
mod write;

pub(crate) type IoResult<T> = std::io::Result<T>;

pub use read::*;
pub use split::*;
pub use util::{copy, null, repeat};
pub use write::*;
75 changes: 75 additions & 0 deletions compio-io/src/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use std::sync::Arc;

use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use futures_util::lock::Mutex;

use crate::{AsyncRead, AsyncWrite, IoResult};

/// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
/// [`AsyncRead`] and [`AsyncWrite`] handles.
pub fn split<T: AsyncRead + AsyncWrite>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) {
let stream = Arc::new(Mutex::new(stream));
(ReadHalf(stream.clone()), WriteHalf(stream))
}

/// The readable half of a value returned from [`split`].
#[derive(Debug)]
pub struct ReadHalf<T>(Arc<Mutex<T>>);

impl<T: Unpin> ReadHalf<T> {
/// Reunites with a previously split [`WriteHalf`].
///
/// # Panics
///
/// If this [`ReadHalf`] and the given [`WriteHalf`] do not originate from
/// the same [`split`] operation this method will panic.
/// This can be checked ahead of time by comparing the stored pointer
/// of the two halves.
#[track_caller]
pub fn unsplit(self, w: WriteHalf<T>) -> T {
if Arc::ptr_eq(&self.0, &w.0) {
drop(w);
let inner = Arc::try_unwrap(self.0).expect("`Arc::try_unwrap` failed");
inner.into_inner()
} else {
#[cold]
fn panic_unrelated() -> ! {
panic!("Unrelated `WriteHalf` passed to `ReadHalf::unsplit`.")
}

panic_unrelated()
}
}
}

impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.lock().await.read(buf).await
}

async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
self.0.lock().await.read_vectored(buf).await
}
}

/// The writable half of a value returned from [`split`].
#[derive(Debug)]
pub struct WriteHalf<T>(Arc<Mutex<T>>);

impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.lock().await.write(buf).await
}

async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.lock().await.write_vectored(buf).await
}

async fn flush(&mut self) -> IoResult<()> {
self.0.lock().await.flush().await
}

async fn shutdown(&mut self) -> IoResult<()> {
self.0.lock().await.shutdown().await
}
}
18 changes: 17 additions & 1 deletion compio-io/tests/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::Cursor;

use compio_buf::{arrayvec::ArrayVec, BufResult, IoBuf, IoBufMut};
use compio_io::{
AsyncRead, AsyncReadAt, AsyncReadAtExt, AsyncReadExt, AsyncWrite, AsyncWriteAt,
split, AsyncRead, AsyncReadAt, AsyncReadAtExt, AsyncReadExt, AsyncWrite, AsyncWriteAt,
AsyncWriteAtExt, AsyncWriteExt,
};

Expand Down Expand Up @@ -355,3 +355,19 @@ async fn read_to_end_at() {
assert_eq!(len, 4);
assert_eq!(buf, [4, 5, 1, 4]);
}

#[tokio::test]
async fn split_unsplit() {
let src = Cursor::new([1, 1, 4, 5, 1, 4]);
let (mut read, mut write) = split(src);

let (len, buf) = read.read([0, 0, 0]).await.unwrap();
assert_eq!(len, 3);
assert_eq!(buf, [1, 1, 4]);

let (len, _) = write.write([2, 2, 2]).await.unwrap();
assert_eq!(len, 3);

let src = read.unsplit(write);
assert_eq!(src.into_inner(), [1, 1, 4, 2, 2, 2]);
}
2 changes: 2 additions & 0 deletions compio-net/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

mod resolve;
mod socket;
pub(crate) mod split;
mod tcp;
mod udp;
mod unix;

pub use resolve::ToSocketAddrsAsync;
pub(crate) use resolve::{each_addr, first_addr_buf};
pub(crate) use socket::*;
pub use split::*;
pub use tcp::*;
pub use udp::*;
pub use unix::*;
135 changes: 135 additions & 0 deletions compio-net/src/split.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use std::{error::Error, fmt, io, ops::Deref, sync::Arc};

use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
use compio_io::{AsyncRead, AsyncWrite};

pub(crate) fn split<T>(stream: &T) -> (ReadHalf<T>, WriteHalf<T>)
where
for<'a> &'a T: AsyncRead + AsyncWrite,
{
(ReadHalf(stream), WriteHalf(stream))
}

/// Borrowed read half.
#[derive(Debug)]
pub struct ReadHalf<'a, T>(&'a T);

impl<T> AsyncRead for ReadHalf<'_, T>
where
for<'a> &'a T: AsyncRead,
{
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.read(buf).await
}

async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
self.0.read_vectored(buf).await
}
}

/// Borrowed write half.
#[derive(Debug)]
pub struct WriteHalf<'a, T>(&'a T);

impl<T> AsyncWrite for WriteHalf<'_, T>
where
for<'a> &'a T: AsyncWrite,
{
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.write(buf).await
}

async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.write_vectored(buf).await
}

async fn flush(&mut self) -> io::Result<()> {
self.0.flush().await
}

async fn shutdown(&mut self) -> io::Result<()> {
self.0.shutdown().await
}
}

pub(crate) fn into_split<T>(stream: T) -> (OwnedReadHalf<T>, OwnedWriteHalf<T>)
where
for<'a> &'a T: AsyncRead + AsyncWrite,
{
let stream = Arc::new(stream);
(OwnedReadHalf(stream.clone()), OwnedWriteHalf(stream))
}

/// Owned read half.
#[derive(Debug)]
pub struct OwnedReadHalf<T>(Arc<T>);

impl<T: Unpin> OwnedReadHalf<T> {
/// Attempts to put the two halves of a `TcpStream` back together and
/// recover the original socket. Succeeds only if the two halves
/// originated from the same call to `into_split`.
pub fn reunite(self, w: OwnedWriteHalf<T>) -> Result<T, ReuniteError<T>> {
if Arc::ptr_eq(&self.0, &w.0) {
drop(w);
Ok(Arc::try_unwrap(self.0)
.ok()
.expect("`Arc::try_unwrap` failed"))
} else {
Err(ReuniteError(self, w))
}
}
}

impl<T> AsyncRead for OwnedReadHalf<T>
where
for<'a> &'a T: AsyncRead,
{
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.deref().read(buf).await
}

async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
self.0.deref().read_vectored(buf).await
}
}

/// Owned write half.
#[derive(Debug)]
pub struct OwnedWriteHalf<T>(Arc<T>);

impl<T> AsyncWrite for OwnedWriteHalf<T>
where
for<'a> &'a T: AsyncWrite,
{
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.deref().write(buf).await
}

async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
self.0.deref().write_vectored(buf).await
}

async fn flush(&mut self) -> io::Result<()> {
self.0.deref().flush().await
}

async fn shutdown(&mut self) -> io::Result<()> {
self.0.deref().shutdown().await
}
}

/// Error indicating that two halves were not from the same socket, and thus
/// could not be reunited.
#[derive(Debug)]
pub struct ReuniteError<T>(pub OwnedReadHalf<T>, pub OwnedWriteHalf<T>);

impl<T> fmt::Display for ReuniteError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same socket"
)
}
}

impl<T: fmt::Debug> Error for ReuniteError<T> {}
21 changes: 20 additions & 1 deletion compio-net/src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use compio_io::{AsyncRead, AsyncWrite};
use compio_runtime::{impl_attachable, impl_try_as_raw_fd};
use socket2::{Protocol, SockAddr, Type};

use crate::{Socket, ToSocketAddrsAsync};
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, ToSocketAddrsAsync, WriteHalf};

/// A TCP socket server, listening for connections.
///
Expand Down Expand Up @@ -203,6 +203,25 @@ impl TcpStream {
.local_addr()
.map(|addr| addr.as_socket().expect("should be SocketAddr"))
}

/// Splits a [`TcpStream`] into a read half and a write half, which can be
/// used to read and write the stream concurrently.
///
/// This method is more efficient than
/// [`into_split`](TcpStream::into_split), but the halves cannot
/// be moved into independently spawned tasks.
pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
crate::split(self)
}

/// Splits a [`TcpStream`] into a read half and a write half, which can be
/// used to read and write the stream concurrently.
///
/// Unlike [`split`](TcpStream::split), the owned halves can be moved to
/// separate tasks, however this comes at the cost of a heap allocation.
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
crate::into_split(self)
}
}

impl AsyncRead for TcpStream {
Expand Down
21 changes: 20 additions & 1 deletion compio-net/src/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use compio_io::{AsyncRead, AsyncWrite};
use compio_runtime::{impl_attachable, impl_try_as_raw_fd};
use socket2::{Domain, SockAddr, Type};

use crate::Socket;
use crate::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, Socket, WriteHalf};

/// A Unix socket server, listening for connections.
///
Expand Down Expand Up @@ -159,6 +159,25 @@ impl UnixStream {
pub fn local_addr(&self) -> io::Result<SockAddr> {
self.inner.local_addr()
}

/// Splits a [`UnixStream`] into a read half and a write half, which can be
/// used to read and write the stream concurrently.
///
/// This method is more efficient than
/// [`into_split`](UnixStream::into_split), but the halves cannot
/// be moved into independently spawned tasks.
pub fn split(&self) -> (ReadHalf<Self>, WriteHalf<Self>) {
crate::split(self)
}

/// Splits a [`UnixStream`] into a read half and a write half, which can be
/// used to read and write the stream concurrently.
///
/// Unlike [`split`](UnixStream::split), the owned halves can be moved to
/// separate tasks, however this comes at the cost of a heap allocation.
pub fn into_split(self) -> (OwnedReadHalf<Self>, OwnedWriteHalf<Self>) {
crate::into_split(self)
}
}

impl AsyncRead for UnixStream {
Expand Down
Loading

0 comments on commit 4cabfec

Please sign in to comment.