Skip to content

Commit

Permalink
feat(h2): implement CONNECT support (fixes #2508)
Browse files Browse the repository at this point in the history
  • Loading branch information
nox committed Apr 28, 2021
1 parent d631b8c commit 3d27679
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 23 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ include = [
[lib]
crate-type = ["lib", "staticlib", "cdylib"]

[patch.crates-io]
h2 = { git = "https://github.com/hyperium/h2.git", branch = "master" }

[dependencies]
bytes = "1"
futures-core = { version = "0.3", default-features = false }
Expand All @@ -31,7 +34,7 @@ http = "0.2"
http-body = "0.4"
httpdate = "1.0"
httparse = "1.4"
h2 = { version = "0.3", optional = true }
h2 = { version = "0.3.2", optional = true }
itoa = "0.4.1"
tracing = { version = "0.1", default-features = false, features = ["std"] }
pin-project = "1.0"
Expand Down
148 changes: 137 additions & 11 deletions src/proto/h2/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
use bytes::Buf;
use h2::SendStream;
use bytes::{Buf, Bytes};
use h2::{RecvStream, SendStream};
use http::header::{
HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER,
TRANSFER_ENCODING, UPGRADE,
};
use http::HeaderMap;
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use std::error::Error as StdError;
use std::io::IoSlice;
use std::io::{self, Cursor, IoSlice};
use std::task::Context;

use crate::body::{DecodedLength, HttpBody};
use crate::common::{task, Future, Pin, Poll};
use crate::headers::content_length_parse_all;
use crate::proto::h2::ping::Recorder;

pub(crate) mod ping;

Expand Down Expand Up @@ -172,7 +175,7 @@ where
is_eos,
);

let buf = SendBuf(Some(chunk));
let buf = SendBuf::Buf(chunk);
me.body_tx
.send_data(buf, is_eos)
.map_err(crate::Error::new_body_write)?;
Expand Down Expand Up @@ -243,32 +246,155 @@ impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> {

fn send_eos_frame(&mut self) -> crate::Result<()> {
trace!("send body eos");
self.send_data(SendBuf(None), true)
self.send_data(SendBuf::None, true)
.map_err(crate::Error::new_body_write)
}
}

struct SendBuf<B>(Option<B>);
enum SendBuf<B> {
Buf(B),
Cursor(Cursor<Box<[u8]>>),
None,
}

impl<B: Buf> Buf for SendBuf<B> {
#[inline]
fn remaining(&self) -> usize {
self.0.as_ref().map(|b| b.remaining()).unwrap_or(0)
match *self {
Self::Buf(ref b) => b.remaining(),
Self::Cursor(ref c) => c.remaining(),
Self::None => 0,
}
}

#[inline]
fn chunk(&self) -> &[u8] {
self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[])
match *self {
Self::Buf(ref b) => b.chunk(),
Self::Cursor(ref c) => c.chunk(),
Self::None => &[],
}
}

#[inline]
fn advance(&mut self, cnt: usize) {
if let Some(b) = self.0.as_mut() {
b.advance(cnt)
match *self {
Self::Buf(ref mut b) => b.advance(cnt),
Self::Cursor(ref mut c) => c.advance(cnt),
Self::None => {},
}
}

fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0)
match *self {
Self::Buf(ref b) => b.chunks_vectored(dst),
Self::Cursor(ref c) => c.chunks_vectored(dst),
Self::None => 0,
}
}
}

// FIXME(nox): Should this type be public? I'm asking this because
// the HTTP/2 RFC says that a proxy that encounters a TCP error with the
// upstream peer should send back to the client a stream error with reason
// CONNECT_ERROR, so we need *something* to send that, but all the user
// gets is a hyper::upgrade::Upgraded, so you can't send anything but a
// data frame back.
struct H2Upgraded<B>
where
B: Buf,
{
ping: Recorder,
send_stream: SendStream<SendBuf<B>>,
recv_stream: RecvStream,
buf: Bytes,
}

impl<B> AsyncRead for H2Upgraded<B>
where
B: Buf,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
read_buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), io::Error>> {
if self.buf.is_empty() {
self.buf = match ready!(self.recv_stream.poll_data(cx)) {
None => return Poll::Ready(Ok(())),
Some(Ok(buf)) => {
self.ping.record_data(buf.len());
buf
}
Some(Err(e)) => {
return Poll::Ready(Err(h2_to_io_error(e)));
}
};
}
let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
read_buf.put_slice(&self.buf[..cnt]);
self.buf.advance(cnt);
let _ = self.recv_stream.flow_control().release_capacity(cnt);
Poll::Ready(Ok(()))
}
}

impl<B> AsyncWrite for H2Upgraded<B>
where
B: Buf,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
// FIXME(nox): PipeToSendStream does some weird stuff, first reserving
// one byte and then polling reset if the capacity is 0, should we do
// that here too? Should we poll reset somewhere?
self.send_stream.reserve_capacity(buf.len());
Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) {
None => Ok(0),
Some(Ok(cnt)) => self.write(&buf[..cnt], false).map(|()| cnt),
Some(Err(e)) => {
// FIXME(nox): Should all H2 errors be returned as is with a
// ErrorKind::Other, or should some be special-cased, say for
// example, CANCEL?
Err(h2_to_io_error(e))
},
})
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
Poll::Ready(self.write(&[], true))
}
}

impl<B> H2Upgraded<B>
where
B: Buf,
{
fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> {
let send_buf = SendBuf::Cursor(Cursor::new(buf.into()));
self.send_stream
.send_data(send_buf, end_of_stream)
.map_err(h2_to_io_error)
}
}

fn h2_to_io_error(e: h2::Error) -> io::Error {
if e.is_io() {
e.into_io().unwrap()
} else {
io::Error::new(io::ErrorKind::Other, e)
}
}
58 changes: 51 additions & 7 deletions src/proto/h2/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use std::marker::Unpin;
#[cfg(feature = "runtime")]
use std::time::Duration;

use bytes::Bytes;
use h2::server::{Connection, Handshake, SendResponse};
use h2::Reason;
use h2::{Reason, RecvStream};
use http::{Method, Request};
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};

Expand All @@ -13,9 +15,12 @@ use crate::body::HttpBody;
use crate::common::exec::ConnStreamExec;
use crate::common::{date, task, Future, Pin, Poll};
use crate::headers;
use crate::proto::h2::ping::Recorder;
use crate::proto::h2::H2Upgraded;
use crate::proto::Dispatched;
use crate::service::HttpService;

use crate::upgrade::{OnUpgrade, Pending, Upgraded};
use crate::{Body, Response};

// Our defaults are chosen for the "majority" case, which usually are not
Expand Down Expand Up @@ -269,8 +274,28 @@ where
// Record the headers received
ping.record_non_data();

let req = req.map(|stream| crate::Body::h2(stream, content_length, ping));
let fut = H2Stream::new(service.call(req), respond);
let is_connect = req.method() == Method::CONNECT;
let (mut parts, stream) = req.into_parts();
let (req, connect_parts) = if !is_connect {
(
Request::from_parts(
parts,
crate::Body::h2(stream, content_length, ping),
),
None,
)
} else {
// FIXME(nox): What happens to the request body? Should we check `content_length`?
let (pending, upgrade) = crate::upgrade::pending();
debug_assert!(parts.extensions.get::<OnUpgrade>().is_none());
parts.extensions.insert(upgrade);
(
Request::from_parts(parts, crate::Body::empty()),
Some((pending, ping, stream)),
)
};

let fut = H2Stream::new(service.call(req), connect_parts, respond);
exec.execute_h2stream(fut);
}
Some(Err(e)) => {
Expand Down Expand Up @@ -333,18 +358,22 @@ enum H2StreamState<F, B>
where
B: HttpBody,
{
Service(#[pin] F),
Service(#[pin] F, Option<(Pending, Recorder, RecvStream)>),
Body(#[pin] PipeToSendStream<B>),
}

impl<F, B> H2Stream<F, B>
where
B: HttpBody,
{
fn new(fut: F, respond: SendResponse<SendBuf<B::Data>>) -> H2Stream<F, B> {
fn new(
fut: F,
connect_parts: Option<(Pending, Recorder, RecvStream)>,
respond: SendResponse<SendBuf<B::Data>>,
) -> H2Stream<F, B> {
H2Stream {
reply: respond,
state: H2StreamState::Service(fut),
state: H2StreamState::Service(fut, connect_parts),
}
}
}
Expand Down Expand Up @@ -374,7 +403,7 @@ where
let mut me = self.project();
loop {
let next = match me.state.as_mut().project() {
H2StreamStateProj::Service(h) => {
H2StreamStateProj::Service(h, connect_parts) => {
let res = match h.poll(cx) {
Poll::Ready(Ok(r)) => r,
Poll::Pending => {
Expand Down Expand Up @@ -405,6 +434,21 @@ where
.entry(::http::header::DATE)
.or_insert_with(date::update_and_header_value);

if let Some((pending, ping, recv_stream)) = connect_parts.take() {
// FIXME(nox): What do we do about the response body? AFAIK h1 returns an error.
let send_stream = reply!(me, res, false);
pending.fulfill(Upgraded::new(
H2Upgraded {
ping,
recv_stream,
send_stream,
buf: Bytes::new(),
},
Bytes::new(),
));
return Poll::Ready(Ok(()));
}

// automatically set Content-Length from body...
if let Some(len) = body.size_hint().exact() {
headers::set_content_length_if_missing(res.headers_mut(), len);
Expand Down
9 changes: 5 additions & 4 deletions src/upgrade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
msg.on_upgrade()
}

#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) struct Pending {
tx: oneshot::Sender<crate::Result<Upgraded>>,
}

#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
let (tx, rx) = oneshot::channel();
(Pending { tx }, OnUpgrade { rx: Some(rx) })
Expand All @@ -76,7 +76,7 @@ pub(super) fn pending() -> (Pending, OnUpgrade) {
// ===== impl Upgraded =====

impl Upgraded {
#[cfg(any(feature = "http1", test))]
#[cfg(any(feature = "http1", feature = "http2", test))]
pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
Expand Down Expand Up @@ -187,13 +187,14 @@ impl fmt::Debug for OnUpgrade {

// ===== impl Pending =====

#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
impl Pending {
pub(super) fn fulfill(self, upgraded: Upgraded) {
trace!("pending upgrade fulfill");
let _ = self.tx.send(Ok(upgraded));
}

#[cfg(feature = "http1")]
/// Don't fulfill the pending Upgrade, but instead signal that
/// upgrades are handled manually.
pub(super) fn manual(self) {
Expand Down
Loading

0 comments on commit 3d27679

Please sign in to comment.