From c4bb4db5c219459b37d796f9aa2b3cdc93325621 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 28 Jan 2020 16:23:03 -0800 Subject: [PATCH] fix(http1): only send `100 Continue` if request body is polled Before, if a client request included an `Expect: 100-continue` header, the `100 Continue` response was sent immediately. However, this is problematic if the service is going to reply with some 4xx status code and reject the body. This change delays the automatic sending of the `100 Continue` status until the service has call `poll_data` on the request body once. --- src/body/body.rs | 173 ++++++++++++++++++++++++++++++++++----- src/common/mod.rs | 1 + src/common/watch.rs | 73 +++++++++++++++++ src/proto/h1/conn.rs | 46 ++++++++--- src/proto/h1/dispatch.rs | 8 +- src/proto/h1/mod.rs | 19 +++++ tests/server.rs | 51 ++++++++++++ 7 files changed, 332 insertions(+), 39 deletions(-) create mode 100644 src/common/watch.rs diff --git a/src/body/body.rs b/src/body/body.rs index 3308d3b3bd..939b4f5689 100644 --- a/src/body/body.rs +++ b/src/body/body.rs @@ -11,7 +11,7 @@ use futures_util::TryStreamExt; use http::HeaderMap; use http_body::{Body as HttpBody, SizeHint}; -use crate::common::{task, Future, Never, Pin, Poll}; +use crate::common::{task, watch, Future, Never, Pin, Poll}; use crate::proto::DecodedLength; use crate::upgrade::OnUpgrade; @@ -33,7 +33,7 @@ enum Kind { Once(Option), Chan { content_length: DecodedLength, - abort_rx: oneshot::Receiver<()>, + want_tx: watch::Sender, rx: mpsc::Receiver>, }, H2 { @@ -79,12 +79,14 @@ enum DelayEof { /// Useful when wanting to stream chunks from another thread. See /// [`Body::channel`](Body::channel) for more. #[must_use = "Sender does nothing unless sent on"] -#[derive(Debug)] pub struct Sender { - abort_tx: oneshot::Sender<()>, + want_rx: watch::Receiver, tx: BodySender, } +const WANT_PENDING: usize = 1; +const WANT_READY: usize = 2; + impl Body { /// Create an empty `Body` stream. /// @@ -106,17 +108,22 @@ impl Body { /// Useful when wanting to stream chunks from another thread. #[inline] pub fn channel() -> (Sender, Body) { - Self::new_channel(DecodedLength::CHUNKED) + Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false) } - pub(crate) fn new_channel(content_length: DecodedLength) -> (Sender, Body) { + pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Body) { let (tx, rx) = mpsc::channel(0); - let (abort_tx, abort_rx) = oneshot::channel(); - let tx = Sender { abort_tx, tx }; + // If wanter is true, `Sender::poll_ready()` won't becoming ready + // until the `Body` has been polled for data once. + let want = if wanter { WANT_PENDING } else { WANT_READY }; + + let (want_tx, want_rx) = watch::channel(want); + + let tx = Sender { want_rx, tx }; let rx = Body::new(Kind::Chan { content_length, - abort_rx, + want_tx, rx, }); @@ -236,11 +243,9 @@ impl Body { Kind::Chan { content_length: ref mut len, ref mut rx, - ref mut abort_rx, + ref mut want_tx, } => { - if let Poll::Ready(Ok(())) = Pin::new(abort_rx).poll(cx) { - return Poll::Ready(Some(Err(crate::Error::new_body_write_aborted()))); - } + want_tx.send(WANT_READY); match ready!(Pin::new(rx).poll_next(cx)?) { Some(chunk) => { @@ -460,19 +465,29 @@ impl From> for Body { impl Sender { /// Check to see if this `Sender` can send more data. pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - match self.abort_tx.poll_canceled(cx) { - Poll::Ready(()) => return Poll::Ready(Err(crate::Error::new_closed())), - Poll::Pending => (), // fallthrough - } - + // Check if the receiver end has tried polling for the body yet + ready!(self.poll_want(cx)?); self.tx .poll_ready(cx) .map_err(|_| crate::Error::new_closed()) } + fn poll_want(&mut self, cx: &mut task::Context<'_>) -> Poll> { + match self.want_rx.load(cx) { + WANT_READY => Poll::Ready(Ok(())), + WANT_PENDING => Poll::Pending, + watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())), + unexpected => unreachable!("want_rx value: {}", unexpected), + } + } + + async fn ready(&mut self) -> crate::Result<()> { + futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await + } + /// Send data on this channel when it is ready. pub async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> { - futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await?; + self.ready().await?; self.tx .try_send(Ok(chunk)) .map_err(|_| crate::Error::new_closed()) @@ -498,8 +513,11 @@ impl Sender { /// Aborts the body in an abnormal fashion. pub fn abort(self) { - // TODO(sean): this can just be `self.tx.clone().try_send()` - let _ = self.abort_tx.send(()); + let _ = self + .tx + // clone so the send works even if buffer is full + .clone() + .try_send(Err(crate::Error::new_body_write_aborted())); } pub(crate) fn send_error(&mut self, err: crate::Error) { @@ -507,11 +525,29 @@ impl Sender { } } +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[derive(Debug)] + struct Open; + #[derive(Debug)] + struct Closed; + + let mut builder = f.debug_tuple("Sender"); + match self.want_rx.peek() { + watch::CLOSED => builder.field(&Closed), + _ => builder.field(&Open), + }; + + builder.finish() + } +} + #[cfg(test)] mod tests { use std::mem; + use std::task::Poll; - use super::{Body, Sender}; + use super::{Body, DecodedLength, HttpBody, Sender}; #[test] fn test_size_of() { @@ -541,4 +577,97 @@ mod tests { "Option" ); } + + #[tokio::test] + async fn channel_abort() { + let (tx, mut rx) = Body::channel(); + + tx.abort(); + + let err = rx.data().await.unwrap().unwrap_err(); + assert!(err.is_body_write_aborted(), "{:?}", err); + } + + #[tokio::test] + async fn channel_abort_when_buffer_is_full() { + let (mut tx, mut rx) = Body::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + // buffer is full, but can still send abort + tx.abort(); + + let chunk1 = rx.data().await.expect("item 1").expect("chunk 1"); + assert_eq!(chunk1, "chunk 1"); + + let err = rx.data().await.unwrap().unwrap_err(); + assert!(err.is_body_write_aborted(), "{:?}", err); + } + + #[test] + fn channel_buffers_one() { + let (mut tx, _rx) = Body::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + + // buffer is now full + let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2"); + assert_eq!(chunk2, "chunk 2"); + } + + #[tokio::test] + async fn channel_empty() { + let (_, mut rx) = Body::channel(); + + assert!(rx.data().await.is_none()); + } + + #[test] + fn channel_ready() { + let (mut tx, _rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ false); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_ready(), "tx is ready immediately"); + } + + #[test] + fn channel_wanter() { + let (mut tx, mut rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + let mut rx_data = tokio_test::task::spawn(rx.data()); + + assert!( + tx_ready.poll().is_pending(), + "tx isn't ready before rx has been polled" + ); + + assert!(rx_data.poll().is_pending(), "poll rx.data"); + assert!(tx_ready.is_woken(), "rx poll wakes tx"); + + assert!( + tx_ready.poll().is_ready(), + "tx is ready after rx has been polled" + ); + } + + #[test] + fn channel_notices_closure() { + let (mut tx, rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!( + tx_ready.poll().is_pending(), + "tx isn't ready before rx has been polled" + ); + + drop(rx); + assert!(tx_ready.is_woken(), "dropping rx wakes tx"); + + match tx_ready.poll() { + Poll::Ready(Err(ref e)) if e.is_closed() => (), + unexpected => panic!("tx poll ready unexpected: {:?}", unexpected), + } + } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 394e549895..3716a56c67 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -14,6 +14,7 @@ pub(crate) mod io; mod lazy; mod never; pub(crate) mod task; +pub(crate) mod watch; pub use self::exec::Executor; pub(crate) use self::exec::{BoxSendFuture, Exec}; diff --git a/src/common/watch.rs b/src/common/watch.rs new file mode 100644 index 0000000000..ba17d551cb --- /dev/null +++ b/src/common/watch.rs @@ -0,0 +1,73 @@ +//! An SPSC broadcast channel. +//! +//! - The value can only be a `usize`. +//! - The consumer is only notified if the value is different. +//! - The value `0` is reserved for closed. + +use futures_util::task::AtomicWaker; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::task; + +type Value = usize; + +pub(crate) const CLOSED: usize = 0; + +pub(crate) fn channel(initial: Value) -> (Sender, Receiver) { + debug_assert!( + initial != CLOSED, + "watch::channel initial state of 0 is reserved" + ); + + let shared = Arc::new(Shared { + value: AtomicUsize::new(initial), + waker: AtomicWaker::new(), + }); + + ( + Sender { + shared: shared.clone(), + }, + Receiver { shared }, + ) +} + +pub(crate) struct Sender { + shared: Arc, +} + +pub(crate) struct Receiver { + shared: Arc, +} + +struct Shared { + value: AtomicUsize, + waker: AtomicWaker, +} + +impl Sender { + pub(crate) fn send(&mut self, value: Value) { + if self.shared.value.swap(value, Ordering::SeqCst) != value { + self.shared.waker.wake(); + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.send(CLOSED); + } +} + +impl Receiver { + pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { + self.shared.waker.register(cx.waker()); + self.shared.value.load(Ordering::SeqCst) + } + + pub(crate) fn peek(&self) -> Value { + self.shared.value.load(Ordering::Relaxed) + } +} diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 0575536827..8f3532b9d4 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -8,7 +8,7 @@ use http::{HeaderMap, Method, Version}; use tokio::io::{AsyncRead, AsyncWrite}; use super::io::Buffered; -use super::{/*Decode,*/ Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext,}; +use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants}; use crate::common::{task, Pin, Poll, Unpin}; use crate::headers::connection_keep_alive; use crate::proto::{BodyLength, DecodedLength, MessageHead}; @@ -114,7 +114,7 @@ where pub fn can_read_body(&self) -> bool { match self.state.reading { - Reading::Body(..) => true, + Reading::Body(..) | Reading::Continue(..) => true, _ => false, } } @@ -129,10 +129,10 @@ where read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE } - pub fn poll_read_head( + pub(super) fn poll_read_head( &mut self, cx: &mut task::Context<'_>, - ) -> Poll, DecodedLength, bool)>>> { + ) -> Poll, DecodedLength, Wants)>>> { debug_assert!(self.can_read_head()); trace!("Conn::read_head"); @@ -156,23 +156,28 @@ where self.state.keep_alive &= msg.keep_alive; self.state.version = msg.head.version; + let mut wants = if msg.wants_upgrade { + Wants::UPGRADE + } else { + Wants::EMPTY + }; + if msg.decode == DecodedLength::ZERO { - if log_enabled!(log::Level::Debug) && msg.expect_continue { + if msg.expect_continue { debug!("ignoring expect-continue since body is empty"); } self.state.reading = Reading::KeepAlive; if !T::should_read_first() { self.try_keep_alive(cx); } + } else if msg.expect_continue { + self.state.reading = Reading::Continue(Decoder::new(msg.decode)); + wants = wants.add(Wants::EXPECT); } else { - if msg.expect_continue { - let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; - self.io.headers_buf().extend_from_slice(cont); - } self.state.reading = Reading::Body(Decoder::new(msg.decode)); - }; + } - Poll::Ready(Some(Ok((msg.head, msg.decode, msg.wants_upgrade)))) + Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) } fn on_read_head_error(&mut self, e: crate::Error) -> Poll>> { @@ -239,7 +244,19 @@ where } } } - _ => unreachable!("read_body invalid state: {:?}", self.state.reading), + Reading::Continue(ref decoder) => { + // Write the 100 Continue if not already responded... + if let Writing::Init = self.state.writing { + trace!("automatically sending 100 Continue"); + let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.io.headers_buf().extend_from_slice(cont); + } + + // And now recurse once in the Reading::Body state... + self.state.reading = Reading::Body(decoder.clone()); + return self.poll_read_body(cx); + } + _ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading), }; self.state.reading = reading; @@ -346,7 +363,9 @@ where // would finish. match self.state.reading { - Reading::Body(..) | Reading::KeepAlive | Reading::Closed => return, + Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => { + return + } Reading::Init => (), }; @@ -711,6 +730,7 @@ struct State { #[derive(Debug)] enum Reading { Init, + Continue(Decoder), Body(Decoder), KeepAlive, Closed, diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 07d05fca16..ff5bf01832 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -4,7 +4,7 @@ use bytes::{Buf, Bytes}; use http::{Request, Response, StatusCode}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::Http1Transaction; +use super::{Http1Transaction, Wants}; use crate::body::{Body, Payload}; use crate::common::{task, Future, Never, Pin, Poll, Unpin}; use crate::proto::{ @@ -235,16 +235,16 @@ where } // dispatch is ready for a message, try to read one match ready!(self.conn.poll_read_head(cx)) { - Some(Ok((head, body_len, wants_upgrade))) => { + Some(Ok((head, body_len, wants))) => { let mut body = match body_len { DecodedLength::ZERO => Body::empty(), other => { - let (tx, rx) = Body::new_channel(other); + let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT)); self.body_tx = Some(tx); rx } }; - if wants_upgrade { + if wants.contains(Wants::UPGRADE) { body.set_on_upgrade(self.conn.on_upgrade()); } self.dispatch.recv_msg(Ok((head, body)))?; diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 39efb8e7b8..2d0bf39bc9 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -74,3 +74,22 @@ pub(crate) struct Encode<'a, T> { req_method: &'a mut Option, title_case_headers: bool, } + +/// Extra flags that a request "wants", like expect-continue or upgrades. +#[derive(Clone, Copy, Debug)] +struct Wants(u8); + +impl Wants { + const EMPTY: Wants = Wants(0b00); + const EXPECT: Wants = Wants(0b01); + const UPGRADE: Wants = Wants(0b10); + + #[must_use] + fn add(self, other: Wants) -> Wants { + Wants(self.0 | other.0) + } + + fn contains(&self, other: Wants) -> bool { + (self.0 & other.0) == other.0 + } +} diff --git a/tests/server.rs b/tests/server.rs index 054dfc8aaf..59ef0f6fee 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -785,6 +785,57 @@ fn expect_continue_but_no_body_is_ignored() { assert_eq!(&resp[..expected.len()], expected); } +#[tokio::test] +async fn expect_continue_waits_for_body_poll() { + let _ = pretty_env_logger::try_init(); + let mut listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + + let child = thread::spawn(move || { + let mut tcp = connect(&addr); + + tcp.write_all( + b"\ + POST /foo HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Expect: 100-continue\r\n\ + Content-Length: 100\r\n\ + Connection: Close\r\n\ + \r\n\ + ", + ) + .expect("write"); + + let expected = "HTTP/1.1 400 Bad Request\r\n"; + let mut resp = String::new(); + tcp.read_to_string(&mut resp).expect("read"); + + assert_eq!(&resp[..expected.len()], expected); + }); + + let (socket, _) = listener.accept().await.expect("accept"); + + Http::new() + .serve_connection( + socket, + service_fn(|req| { + assert_eq!(req.headers()["expect"], "100-continue"); + // But! We're never going to poll the body! + tokio::time::delay_for(Duration::from_millis(50)).map(move |_| { + // Move and drop the req, so we don't auto-close + drop(req); + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(hyper::Body::empty()) + }) + }), + ) + .await + .expect("serve_connection"); + + child.join().expect("client thread"); +} + #[test] fn pipeline_disabled() { let server = serve();