-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #206 from Berrysoft/dev/split
feat(io,net): add split
- Loading branch information
Showing
9 changed files
with
372 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
use std::sync::Arc; | ||
|
||
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; | ||
use futures_util::lock::Mutex; | ||
|
||
use crate::{AsyncRead, AsyncWrite, IoResult}; | ||
|
||
/// Splits a single value implementing `AsyncRead + AsyncWrite` into separate | ||
/// [`AsyncRead`] and [`AsyncWrite`] handles. | ||
pub fn split<T: AsyncRead + AsyncWrite>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) { | ||
let stream = Arc::new(Mutex::new(stream)); | ||
(ReadHalf(stream.clone()), WriteHalf(stream)) | ||
} | ||
|
||
/// The readable half of a value returned from [`split`]. | ||
#[derive(Debug)] | ||
pub struct ReadHalf<T>(Arc<Mutex<T>>); | ||
|
||
impl<T: Unpin> ReadHalf<T> { | ||
/// Reunites with a previously split [`WriteHalf`]. | ||
/// | ||
/// # Panics | ||
/// | ||
/// If this [`ReadHalf`] and the given [`WriteHalf`] do not originate from | ||
/// the same [`split`] operation this method will panic. | ||
/// This can be checked ahead of time by comparing the stored pointer | ||
/// of the two halves. | ||
#[track_caller] | ||
pub fn unsplit(self, w: WriteHalf<T>) -> T { | ||
if Arc::ptr_eq(&self.0, &w.0) { | ||
drop(w); | ||
let inner = Arc::try_unwrap(self.0).expect("`Arc::try_unwrap` failed"); | ||
inner.into_inner() | ||
} else { | ||
#[cold] | ||
fn panic_unrelated() -> ! { | ||
panic!("Unrelated `WriteHalf` passed to `ReadHalf::unsplit`.") | ||
} | ||
|
||
panic_unrelated() | ||
} | ||
} | ||
} | ||
|
||
impl<T: AsyncRead> AsyncRead for ReadHalf<T> { | ||
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> { | ||
self.0.lock().await.read(buf).await | ||
} | ||
|
||
async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> { | ||
self.0.lock().await.read_vectored(buf).await | ||
} | ||
} | ||
|
||
/// The writable half of a value returned from [`split`]. | ||
#[derive(Debug)] | ||
pub struct WriteHalf<T>(Arc<Mutex<T>>); | ||
|
||
impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { | ||
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> { | ||
self.0.lock().await.write(buf).await | ||
} | ||
|
||
async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> { | ||
self.0.lock().await.write_vectored(buf).await | ||
} | ||
|
||
async fn flush(&mut self) -> IoResult<()> { | ||
self.0.lock().await.flush().await | ||
} | ||
|
||
async fn shutdown(&mut self) -> IoResult<()> { | ||
self.0.lock().await.shutdown().await | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
use std::{error::Error, fmt, io, ops::Deref, sync::Arc}; | ||
|
||
use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; | ||
use compio_io::{AsyncRead, AsyncWrite}; | ||
|
||
pub(crate) fn split<T>(stream: &T) -> (ReadHalf<T>, WriteHalf<T>) | ||
where | ||
for<'a> &'a T: AsyncRead + AsyncWrite, | ||
{ | ||
(ReadHalf(stream), WriteHalf(stream)) | ||
} | ||
|
||
/// Borrowed read half. | ||
#[derive(Debug)] | ||
pub struct ReadHalf<'a, T>(&'a T); | ||
|
||
impl<T> AsyncRead for ReadHalf<'_, T> | ||
where | ||
for<'a> &'a T: AsyncRead, | ||
{ | ||
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> { | ||
self.0.read(buf).await | ||
} | ||
|
||
async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> { | ||
self.0.read_vectored(buf).await | ||
} | ||
} | ||
|
||
/// Borrowed write half. | ||
#[derive(Debug)] | ||
pub struct WriteHalf<'a, T>(&'a T); | ||
|
||
impl<T> AsyncWrite for WriteHalf<'_, T> | ||
where | ||
for<'a> &'a T: AsyncWrite, | ||
{ | ||
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> { | ||
self.0.write(buf).await | ||
} | ||
|
||
async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> { | ||
self.0.write_vectored(buf).await | ||
} | ||
|
||
async fn flush(&mut self) -> io::Result<()> { | ||
self.0.flush().await | ||
} | ||
|
||
async fn shutdown(&mut self) -> io::Result<()> { | ||
self.0.shutdown().await | ||
} | ||
} | ||
|
||
pub(crate) fn into_split<T>(stream: T) -> (OwnedReadHalf<T>, OwnedWriteHalf<T>) | ||
where | ||
for<'a> &'a T: AsyncRead + AsyncWrite, | ||
{ | ||
let stream = Arc::new(stream); | ||
(OwnedReadHalf(stream.clone()), OwnedWriteHalf(stream)) | ||
} | ||
|
||
/// Owned read half. | ||
#[derive(Debug)] | ||
pub struct OwnedReadHalf<T>(Arc<T>); | ||
|
||
impl<T: Unpin> OwnedReadHalf<T> { | ||
/// Attempts to put the two halves of a `TcpStream` back together and | ||
/// recover the original socket. Succeeds only if the two halves | ||
/// originated from the same call to `into_split`. | ||
pub fn reunite(self, w: OwnedWriteHalf<T>) -> Result<T, ReuniteError<T>> { | ||
if Arc::ptr_eq(&self.0, &w.0) { | ||
drop(w); | ||
Ok(Arc::try_unwrap(self.0) | ||
.ok() | ||
.expect("`Arc::try_unwrap` failed")) | ||
} else { | ||
Err(ReuniteError(self, w)) | ||
} | ||
} | ||
} | ||
|
||
impl<T> AsyncRead for OwnedReadHalf<T> | ||
where | ||
for<'a> &'a T: AsyncRead, | ||
{ | ||
async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> { | ||
self.0.deref().read(buf).await | ||
} | ||
|
||
async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> { | ||
self.0.deref().read_vectored(buf).await | ||
} | ||
} | ||
|
||
/// Owned write half. | ||
#[derive(Debug)] | ||
pub struct OwnedWriteHalf<T>(Arc<T>); | ||
|
||
impl<T> AsyncWrite for OwnedWriteHalf<T> | ||
where | ||
for<'a> &'a T: AsyncWrite, | ||
{ | ||
async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> { | ||
self.0.deref().write(buf).await | ||
} | ||
|
||
async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> { | ||
self.0.deref().write_vectored(buf).await | ||
} | ||
|
||
async fn flush(&mut self) -> io::Result<()> { | ||
self.0.deref().flush().await | ||
} | ||
|
||
async fn shutdown(&mut self) -> io::Result<()> { | ||
self.0.deref().shutdown().await | ||
} | ||
} | ||
|
||
/// Error indicating that two halves were not from the same socket, and thus | ||
/// could not be reunited. | ||
#[derive(Debug)] | ||
pub struct ReuniteError<T>(pub OwnedReadHalf<T>, pub OwnedWriteHalf<T>); | ||
|
||
impl<T> fmt::Display for ReuniteError<T> { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
write!( | ||
f, | ||
"tried to reunite halves that are not from the same socket" | ||
) | ||
} | ||
} | ||
|
||
impl<T: fmt::Debug> Error for ReuniteError<T> {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.