Skip to content

Commit

Permalink
Remove unsafe from proto::h2
Browse files Browse the repository at this point in the history
Back in #2523, @nox introduced the notion of an UpgradedSendStream, to
support the CONNECT method of HTTP/2. This used `unsafe {}` to support
`http_body::Body`, where `Body::Data` did not implement `Send`, since
the `Data` type wouldn't be sent across the stream once upgraded.

Unfortunately, according to this [thread], I think this may be undefined
behavior, because this relies on us requiring the transmute to execute.

This patch fixes the potential UB by adding the unncessary `Send`
constraints. It appears that all the internal users of
`UpgradeSendStream` already work with `http_body::Body` types that have
`Send`-able `Data` constraints. We can add this constraint without
breaking any external APIs, which lets us remove the `unsafe {}` blocks.

[thread]: https://users.rust-lang.org/t/is-a-reference-to-impossible-value-considered-ub/31383
  • Loading branch information
erickt committed May 10, 2022
1 parent faf24c6 commit 0397af1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/proto/h2/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ where
let (pending, on_upgrade) = crate::upgrade::pending();
let io = H2Upgraded {
ping,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
send_stream: UpgradedSendStream::new(send_stream),
recv_stream,
buf: Bytes::new(),
};
Expand Down
50 changes: 9 additions & 41 deletions src/proto/h2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use http::HeaderMap;
use pin_project_lite::pin_project;
use std::error::Error as StdError;
use std::io::{self, Cursor, IoSlice};
use std::mem;
use std::task::Context;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::{debug, trace, warn};
Expand Down Expand Up @@ -409,63 +408,32 @@ fn h2_to_io_error(e: h2::Error) -> io::Error {
}
}

struct UpgradedSendStream<B>(SendStream<SendBuf<Neutered<B>>>);
struct UpgradedSendStream<B: Buf>(SendStream<SendBuf<B>>);

impl<B> UpgradedSendStream<B>
where
B: Buf,
{
unsafe fn new(inner: SendStream<SendBuf<B>>) -> Self {
assert_eq!(mem::size_of::<B>(), mem::size_of::<Neutered<B>>());
Self(mem::transmute(inner))
fn new(inner: SendStream<SendBuf<B>>) -> Self {
Self(inner)
}

fn reserve_capacity(&mut self, cnt: usize) {
unsafe { self.as_inner_unchecked().reserve_capacity(cnt) }
self.0.reserve_capacity(cnt)
}

fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<usize, h2::Error>>> {
unsafe { self.as_inner_unchecked().poll_capacity(cx) }
self.0.poll_capacity(cx)
}

fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll<Result<h2::Reason, h2::Error>> {
unsafe { self.as_inner_unchecked().poll_reset(cx) }
self.0.poll_reset(cx)
}

fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> {
let send_buf = SendBuf::Cursor(Cursor::new(buf.into()));
unsafe {
self.as_inner_unchecked()
.send_data(send_buf, end_of_stream)
.map_err(h2_to_io_error)
}
}

unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream<SendBuf<B>> {
&mut *(&mut self.0 as *mut _ as *mut _)
}
}

#[repr(transparent)]
struct Neutered<B> {
_inner: B,
impossible: Impossible,
}

enum Impossible {}

unsafe impl<B> Send for Neutered<B> {}

impl<B> Buf for Neutered<B> {
fn remaining(&self) -> usize {
match self.impossible {}
}

fn chunk(&self) -> &[u8] {
match self.impossible {}
}

fn advance(&mut self, _cnt: usize) {
match self.impossible {}
self.0
.send_data(send_buf, end_of_stream)
.map_err(h2_to_io_error)
}
}
6 changes: 3 additions & 3 deletions src/proto/h2/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ impl<F, B, E> H2Stream<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
B: HttpBody,
B::Data: 'static,
B::Data: Send + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
E: Into<Box<dyn StdError + Send + Sync>>,
{
Expand Down Expand Up @@ -489,7 +489,7 @@ where
H2Upgraded {
ping: connect_parts.ping,
recv_stream: connect_parts.recv_stream,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
send_stream: UpgradedSendStream::new(send_stream),
buf: Bytes::new(),
},
Bytes::new(),
Expand Down Expand Up @@ -527,7 +527,7 @@ impl<F, B, E> Future for H2Stream<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
B: HttpBody,
B::Data: 'static,
B::Data: Send + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
E: Into<Box<dyn StdError + Send + Sync>>,
{
Expand Down

0 comments on commit 0397af1

Please sign in to comment.