Skip to content

Commit

Permalink
Add Either::as_pin_mut and Either::as_pin_ref (#2691)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaseizinger authored Jan 21, 2023
1 parent 59c5095 commit 41478f5
Showing 1 changed file with 39 additions and 19 deletions.
58 changes: 39 additions & 19 deletions futures-util/src/future/either.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,31 @@ pub enum Either<A, B> {
}

impl<A, B> Either<A, B> {
fn project(self: Pin<&mut Self>) -> Either<Pin<&mut A>, Pin<&mut B>> {
/// Convert `Pin<&Either<A, B>>` to `Either<Pin<&A>, Pin<&B>>`,
/// pinned projections of the inner variants.
pub fn as_pin_ref(self: Pin<&Self>) -> Either<Pin<&A>, Pin<&B>> {
// SAFETY: We can use `new_unchecked` because the `inner` parts are
// guaranteed to be pinned, as they come from `self` which is pinned.
unsafe {
match self.get_unchecked_mut() {
Either::Left(a) => Either::Left(Pin::new_unchecked(a)),
Either::Right(b) => Either::Right(Pin::new_unchecked(b)),
match *Pin::get_ref(self) {
Either::Left(ref inner) => Either::Left(Pin::new_unchecked(inner)),
Either::Right(ref inner) => Either::Right(Pin::new_unchecked(inner)),
}
}
}

/// Convert `Pin<&mut Either<A, B>>` to `Either<Pin<&mut A>, Pin<&mut B>>`,
/// pinned projections of the inner variants.
pub fn as_pin_mut(self: Pin<&mut Self>) -> Either<Pin<&mut A>, Pin<&mut B>> {
// SAFETY: `get_unchecked_mut` is fine because we don't move anything.
// We can use `new_unchecked` because the `inner` parts are guaranteed
// to be pinned, as they come from `self` which is pinned, and we never
// offer an unpinned `&mut A` or `&mut B` through `Pin<&mut Self>`. We
// also don't have an implementation of `Drop`, nor manual `Unpin`.
unsafe {
match *Pin::get_unchecked_mut(self) {
Either::Left(ref mut inner) => Either::Left(Pin::new_unchecked(inner)),
Either::Right(ref mut inner) => Either::Right(Pin::new_unchecked(inner)),
}
}
}
Expand Down Expand Up @@ -85,7 +105,7 @@ where
type Output = A::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll(cx),
Either::Right(x) => x.poll(cx),
}
Expand Down Expand Up @@ -113,7 +133,7 @@ where
type Item = A::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_next(cx),
Either::Right(x) => x.poll_next(cx),
}
Expand Down Expand Up @@ -149,28 +169,28 @@ where
type Error = A::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_ready(cx),
Either::Right(x) => x.poll_ready(cx),
}
}

fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.start_send(item),
Either::Right(x) => x.start_send(item),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_flush(cx),
Either::Right(x) => x.poll_flush(cx),
}
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_close(cx),
Either::Right(x) => x.poll_close(cx),
}
Expand Down Expand Up @@ -198,7 +218,7 @@ mod if_std {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_read(cx, buf),
Either::Right(x) => x.poll_read(cx, buf),
}
Expand All @@ -209,7 +229,7 @@ mod if_std {
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<Result<usize>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_read_vectored(cx, bufs),
Either::Right(x) => x.poll_read_vectored(cx, bufs),
}
Expand All @@ -226,7 +246,7 @@ mod if_std {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_write(cx, buf),
Either::Right(x) => x.poll_write(cx, buf),
}
Expand All @@ -237,21 +257,21 @@ mod if_std {
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_write_vectored(cx, bufs),
Either::Right(x) => x.poll_write_vectored(cx, bufs),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_flush(cx),
Either::Right(x) => x.poll_flush(cx),
}
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_close(cx),
Either::Right(x) => x.poll_close(cx),
}
Expand All @@ -268,7 +288,7 @@ mod if_std {
cx: &mut Context<'_>,
pos: SeekFrom,
) -> Poll<Result<u64>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_seek(cx, pos),
Either::Right(x) => x.poll_seek(cx, pos),
}
Expand All @@ -281,14 +301,14 @@ mod if_std {
B: AsyncBufRead,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.poll_fill_buf(cx),
Either::Right(x) => x.poll_fill_buf(cx),
}
}

fn consume(self: Pin<&mut Self>, amt: usize) {
match self.project() {
match self.as_pin_mut() {
Either::Left(x) => x.consume(amt),
Either::Right(x) => x.consume(amt),
}
Expand Down

0 comments on commit 41478f5

Please sign in to comment.