diff --git a/Cargo.toml b/Cargo.toml
index dac533117a..81b49cbfdd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -29,7 +29,7 @@ iovec = "0.1"
log = "0.4"
net2 = { version = "0.2.32", optional = true }
time = "0.1"
-tokio = { version = "0.1.5", optional = true }
+tokio = { version = "0.1.7", optional = true }
tokio-executor = { version = "0.1.0", optional = true }
tokio-io = "0.1"
tokio-reactor = { version = "0.1", optional = true }
@@ -101,6 +101,12 @@ name = "send_file"
path = "examples/send_file.rs"
required-features = ["runtime"]
+[[example]]
+name = "upgrades"
+path = "examples/upgrades.rs"
+required-features = ["runtime"]
+
+
[[example]]
name = "web_api"
path = "examples/web_api.rs"
diff --git a/examples/README.md b/examples/README.md
index 90a425235f..b877ef84a8 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -16,4 +16,6 @@ Run examples with `cargo run --example example_name`.
* [`send_file`](send_file.rs) - A server that sends back content of files using tokio_fs to read the files asynchronously.
+* [`upgrades`](upgrades.rs) - A server and client demonstrating how to do HTTP upgrades (such as WebSockets or `CONNECT` tunneling).
+
* [`web_api`](web_api.rs) - A server consisting in a service that returns incoming POST request's content in the response in uppercase and a service that call that call the first service and includes the first service response in its own response.
diff --git a/examples/upgrades.rs b/examples/upgrades.rs
new file mode 100644
index 0000000000..3cee4940a9
--- /dev/null
+++ b/examples/upgrades.rs
@@ -0,0 +1,127 @@
+// Note: `hyper::upgrade` docs link to this upgrade.
+extern crate futures;
+extern crate hyper;
+extern crate tokio;
+
+use std::str;
+
+use futures::sync::oneshot;
+
+use hyper::{Body, Client, Request, Response, Server, StatusCode};
+use hyper::header::{UPGRADE, HeaderValue};
+use hyper::rt::{self, Future};
+use hyper::service::service_fn_ok;
+
+/// Our server HTTP handler to initiate HTTP upgrades.
+fn server_upgrade(req: Request
) -> Response {
+ let mut res = Response::new(Body::empty());
+
+ // Send a 400 to any request that doesn't have
+ // an `Upgrade` header.
+ if !req.headers().contains_key(UPGRADE) {
+ *res.status_mut() = StatusCode::BAD_REQUEST;
+ return res;
+ }
+
+ // Setup a future that will eventually receive the upgraded
+ // connection and talk a new protocol, and spawn the future
+ // into the runtime.
+ //
+ // Note: This can't possibly be fulfilled until the 101 response
+ // is returned below, so it's better to spawn this future instead
+ // waiting for it to complete to then return a response.
+ let on_upgrade = req
+ .into_body()
+ .on_upgrade()
+ .map_err(|err| eprintln!("upgrade error: {}", err))
+ .and_then(|upgraded| {
+ // We have an upgraded connection that we can read and
+ // write on directly.
+ //
+ // Since we completely control this example, we know exactly
+ // how many bytes the client will write, so just read exact...
+ tokio::io::read_exact(upgraded, vec![0; 7])
+ .and_then(|(upgraded, vec)| {
+ println!("server[foobar] recv: {:?}", str::from_utf8(&vec));
+
+ // And now write back the server 'foobar' protocol's
+ // response...
+ tokio::io::write_all(upgraded, b"bar=foo")
+ })
+ .map(|_| println!("server[foobar] sent"))
+ .map_err(|e| eprintln!("server foobar io error: {}", e))
+ });
+
+ rt::spawn(on_upgrade);
+
+
+ // Now return a 101 Response saying we agree to the upgrade to some
+ // made-up 'foobar' protocol.
+ *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
+ res.headers_mut().insert(UPGRADE, HeaderValue::from_static("foobar"));
+ res
+}
+
+fn main() {
+ // For this example, we just make a server and our own client to talk to
+ // it, so the exact port isn't important. Instead, let the OS give us an
+ // unused port.
+ let addr = ([127, 0, 0, 1], 0).into();
+
+ let server = Server::bind(&addr)
+ .serve(|| service_fn_ok(server_upgrade));
+
+ // We need the assigned address for the client to send it messages.
+ let addr = server.local_addr();
+
+
+ // For this example, a oneshot is used to signal that after 1 request,
+ // the server should be shutdown.
+ let (tx, rx) = oneshot::channel();
+
+ let server = server
+ .map_err(|e| eprintln!("server error: {}", e))
+ .select2(rx)
+ .then(|_| Ok(()));
+
+ rt::run(rt::lazy(move || {
+ rt::spawn(server);
+
+ let req = Request::builder()
+ .uri(format!("http://{}/", addr))
+ .header(UPGRADE, "foobar")
+ .body(Body::empty())
+ .unwrap();
+
+ Client::new()
+ .request(req)
+ .and_then(|res| {
+ if res.status() != StatusCode::SWITCHING_PROTOCOLS {
+ panic!("Our server didn't upgrade: {}", res.status());
+ }
+
+ res
+ .into_body()
+ .on_upgrade()
+ })
+ .map_err(|e| eprintln!("client error: {}", e))
+ .and_then(|upgraded| {
+ // We've gotten an upgraded connection that we can read
+ // and write directly on. Let's start out 'foobar' protocol.
+ tokio::io::write_all(upgraded, b"foo=bar")
+ .and_then(|(upgraded, _)| {
+ println!("client[foobar] sent");
+ tokio::io::read_to_end(upgraded, Vec::new())
+ })
+ .map(|(_upgraded, vec)| {
+ println!("client[foobar] recv: {:?}", str::from_utf8(&vec));
+
+
+ // Complete the oneshot so that the server stops
+ // listening and the process can close down.
+ let _ = tx.send(());
+ })
+ .map_err(|e| eprintln!("client foobar io error: {}", e))
+ })
+ }));
+}
diff --git a/src/body/body.rs b/src/body/body.rs
index 7be740bc9f..09f36b9b45 100644
--- a/src/body/body.rs
+++ b/src/body/body.rs
@@ -10,6 +10,7 @@ use http::HeaderMap;
use common::Never;
use super::{Chunk, Payload};
use super::internal::{FullDataArg, FullDataRet};
+use upgrade::OnUpgrade;
type BodySender = mpsc::Sender>;
@@ -21,15 +22,9 @@ type BodySender = mpsc::Sender>;
#[must_use = "streams do nothing unless polled"]
pub struct Body {
kind: Kind,
- /// Allow the client to pass a future to delay the `Body` from returning
- /// EOF. This allows the `Client` to try to put the idle connection
- /// back into the pool before the body is "finished".
- ///
- /// The reason for this is so that creating a new request after finishing
- /// streaming the body of a response could sometimes result in creating
- /// a brand new connection, since the pool didn't know about the idle
- /// connection yet.
- delayed_eof: Option,
+ /// Keep the extra bits in an `Option>`, so that
+ /// Body stays small in the common case (no extras needed).
+ extra: Option>,
}
enum Kind {
@@ -43,6 +38,19 @@ enum Kind {
Wrapped(Box> + Send>),
}
+struct Extra {
+ /// Allow the client to pass a future to delay the `Body` from returning
+ /// EOF. This allows the `Client` to try to put the idle connection
+ /// back into the pool before the body is "finished".
+ ///
+ /// The reason for this is so that creating a new request after finishing
+ /// streaming the body of a response could sometimes result in creating
+ /// a brand new connection, since the pool didn't know about the idle
+ /// connection yet.
+ delayed_eof: Option,
+ on_upgrade: OnUpgrade,
+}
+
type DelayEofUntil = oneshot::Receiver;
enum DelayEof {
@@ -89,7 +97,6 @@ impl Body {
Self::new_channel(None)
}
- #[inline]
pub(crate) fn new_channel(content_length: Option) -> (Sender, Body) {
let (tx, rx) = mpsc::channel(0);
let (abort_tx, abort_rx) = oneshot::channel();
@@ -139,10 +146,20 @@ impl Body {
Body::new(Kind::Wrapped(Box::new(mapped)))
}
+ /// Converts this `Body` into a `Future` of a pending HTTP upgrade.
+ ///
+ /// See [the `upgrade` module](::upgrade) for more.
+ pub fn on_upgrade(self) -> OnUpgrade {
+ self
+ .extra
+ .map(|ex| ex.on_upgrade)
+ .unwrap_or_else(OnUpgrade::none)
+ }
+
fn new(kind: Kind) -> Body {
Body {
kind: kind,
- delayed_eof: None,
+ extra: None,
}
}
@@ -150,23 +167,46 @@ impl Body {
Body::new(Kind::H2(recv))
}
+ pub(crate) fn set_on_upgrade(&mut self, upgrade: OnUpgrade) {
+ debug_assert!(!upgrade.is_none(), "set_on_upgrade with empty upgrade");
+ let extra = self.extra_mut();
+ debug_assert!(extra.on_upgrade.is_none(), "set_on_upgrade twice");
+ extra.on_upgrade = upgrade;
+ }
+
pub(crate) fn delayed_eof(&mut self, fut: DelayEofUntil) {
- self.delayed_eof = Some(DelayEof::NotEof(fut));
+ self.extra_mut().delayed_eof = Some(DelayEof::NotEof(fut));
+ }
+
+ fn take_delayed_eof(&mut self) -> Option {
+ self
+ .extra
+ .as_mut()
+ .and_then(|extra| extra.delayed_eof.take())
+ }
+
+ fn extra_mut(&mut self) -> &mut Extra {
+ self
+ .extra
+ .get_or_insert_with(|| Box::new(Extra {
+ delayed_eof: None,
+ on_upgrade: OnUpgrade::none(),
+ }))
}
fn poll_eof(&mut self) -> Poll, ::Error> {
- match self.delayed_eof.take() {
+ match self.take_delayed_eof() {
Some(DelayEof::NotEof(mut delay)) => {
match self.poll_inner() {
ok @ Ok(Async::Ready(Some(..))) |
ok @ Ok(Async::NotReady) => {
- self.delayed_eof = Some(DelayEof::NotEof(delay));
+ self.extra_mut().delayed_eof = Some(DelayEof::NotEof(delay));
ok
},
Ok(Async::Ready(None)) => match delay.poll() {
Ok(Async::Ready(never)) => match never {},
Ok(Async::NotReady) => {
- self.delayed_eof = Some(DelayEof::Eof(delay));
+ self.extra_mut().delayed_eof = Some(DelayEof::Eof(delay));
Ok(Async::NotReady)
},
Err(_done) => {
@@ -180,7 +220,7 @@ impl Body {
match delay.poll() {
Ok(Async::Ready(never)) => match never {},
Ok(Async::NotReady) => {
- self.delayed_eof = Some(DelayEof::Eof(delay));
+ self.extra_mut().delayed_eof = Some(DelayEof::Eof(delay));
Ok(Async::NotReady)
},
Err(_done) => {
diff --git a/src/client/conn.rs b/src/client/conn.rs
index e35d435c6e..e136fda350 100644
--- a/src/client/conn.rs
+++ b/src/client/conn.rs
@@ -9,6 +9,7 @@
//! higher-level [Client](super) API.
use std::fmt;
use std::marker::PhantomData;
+use std::mem;
use bytes::Bytes;
use futures::{Async, Future, Poll};
@@ -17,9 +18,21 @@ use tokio_io::{AsyncRead, AsyncWrite};
use body::Payload;
use common::Exec;
+use upgrade::Upgraded;
use proto;
use super::dispatch;
-use {Body, Request, Response, StatusCode};
+use {Body, Request, Response};
+
+type Http1Dispatcher = proto::dispatch::Dispatcher<
+ proto::dispatch::Client,
+ B,
+ T,
+ R,
+>;
+type ConnEither = Either<
+ Http1Dispatcher,
+ proto::h2::Client,
+>;
/// Returns a `Handshake` future over some IO.
///
@@ -48,15 +61,7 @@ where
T: AsyncRead + AsyncWrite + Send + 'static,
B: Payload + 'static,
{
- inner: Either<
- proto::dispatch::Dispatcher<
- proto::dispatch::Client,
- B,
- T,
- proto::ClientUpgradeTransaction,
- >,
- proto::h2::Client,
- >,
+ inner: Option>,
}
@@ -76,7 +81,9 @@ pub struct Builder {
/// If successful, yields a `(SendRequest, Connection)` pair.
#[must_use = "futures do nothing unless polled"]
pub struct Handshake {
- inner: HandshakeInner,
+ builder: Builder,
+ io: Option,
+ _marker: PhantomData,
}
/// A future returned by `SendRequest::send_request`.
@@ -112,27 +119,18 @@ pub struct Parts {
// ========== internal client api
/// A `Future` for when `SendRequest::poll_ready()` is ready.
+#[must_use = "futures do nothing unless polled"]
pub(super) struct WhenReady {
tx: Option>,
}
// A `SendRequest` that can be cloned to send HTTP2 requests.
// private for now, probably not a great idea of a type...
+#[must_use = "futures do nothing unless polled"]
pub(super) struct Http2SendRequest {
dispatch: dispatch::UnboundedSender, Response>,
}
-#[must_use = "futures do nothing unless polled"]
-pub(super) struct HandshakeNoUpgrades {
- inner: HandshakeInner,
-}
-
-struct HandshakeInner {
- builder: Builder,
- io: Option,
- _marker: PhantomData<(B, R)>,
-}
-
// ===== impl SendRequest
impl SendRequest
@@ -354,7 +352,7 @@ where
///
/// Only works for HTTP/1 connections. HTTP/2 connections will panic.
pub fn into_parts(self) -> Parts {
- let (io, read_buf, _) = match self.inner {
+ let (io, read_buf, _) = match self.inner.expect("already upgraded") {
Either::A(h1) => h1.into_inner(),
Either::B(_h2) => {
panic!("http2 cannot into_inner");
@@ -376,12 +374,12 @@ where
/// but it is not desired to actally shutdown the IO object. Instead you
/// would take it back using `into_parts`.
pub fn poll_without_shutdown(&mut self) -> Poll<(), ::Error> {
- match self.inner {
- Either::A(ref mut h1) => {
+ match self.inner.as_mut().expect("already upgraded") {
+ &mut Either::A(ref mut h1) => {
h1.poll_without_shutdown()
},
- Either::B(ref mut h2) => {
- h2.poll()
+ &mut Either::B(ref mut h2) => {
+ h2.poll().map(|x| x.map(|_| ()))
}
}
}
@@ -396,7 +394,22 @@ where
type Error = ::Error;
fn poll(&mut self) -> Poll {
- self.inner.poll()
+ match try_ready!(self.inner.poll()) {
+ Some(proto::Dispatched::Shutdown) |
+ None => {
+ Ok(Async::Ready(()))
+ },
+ Some(proto::Dispatched::Upgrade(pending)) => {
+ let h1 = match mem::replace(&mut self.inner, None) {
+ Some(Either::A(h1)) => h1,
+ _ => unreachable!("Upgrade expects h1"),
+ };
+
+ let (io, buf, _) = h1.into_inner();
+ pending.fulfill(Upgraded::new(Box::new(io), buf));
+ Ok(Async::Ready(()))
+ }
+ }
}
}
@@ -456,25 +469,9 @@ impl Builder {
B: Payload + 'static,
{
Handshake {
- inner: HandshakeInner {
- builder: self.clone(),
- io: Some(io),
- _marker: PhantomData,
- }
- }
- }
-
- pub(super) fn handshake_no_upgrades(&self, io: T) -> HandshakeNoUpgrades
- where
- T: AsyncRead + AsyncWrite + Send + 'static,
- B: Payload + 'static,
- {
- HandshakeNoUpgrades {
- inner: HandshakeInner {
- builder: self.clone(),
- io: Some(io),
- _marker: PhantomData,
- }
+ builder: self.clone(),
+ io: Some(io),
+ _marker: PhantomData,
}
}
}
@@ -489,64 +486,6 @@ where
type Item = (SendRequest, Connection);
type Error = ::Error;
- fn poll(&mut self) -> Poll {
- self.inner.poll()
- .map(|async| {
- async.map(|(tx, dispatch)| {
- (tx, Connection { inner: dispatch })
- })
- })
- }
-}
-
-impl fmt::Debug for Handshake {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- f.debug_struct("Handshake")
- .finish()
- }
-}
-
-impl Future for HandshakeNoUpgrades
-where
- T: AsyncRead + AsyncWrite + Send + 'static,
- B: Payload + 'static,
-{
- type Item = (SendRequest, Either<
- proto::h1::Dispatcher<
- proto::h1::dispatch::Client,
- B,
- T,
- proto::ClientTransaction,
- >,
- proto::h2::Client,
- >);
- type Error = ::Error;
-
- fn poll(&mut self) -> Poll {
- self.inner.poll()
- }
-}
-
-impl Future for HandshakeInner
-where
- T: AsyncRead + AsyncWrite + Send + 'static,
- B: Payload,
- R: proto::h1::Http1Transaction<
- Incoming=StatusCode,
- Outgoing=proto::RequestLine,
- >,
-{
- type Item = (SendRequest, Either<
- proto::h1::Dispatcher<
- proto::h1::dispatch::Client,
- B,
- T,
- R,
- >,
- proto::h2::Client,
- >);
- type Error = ::Error;
-
fn poll(&mut self) -> Poll {
let io = self.io.take().expect("polled more than once");
let (tx, rx) = dispatch::channel();
@@ -570,11 +509,20 @@ where
SendRequest {
dispatch: tx,
},
- either,
+ Connection {
+ inner: Some(either),
+ },
)))
}
}
+impl fmt::Debug for Handshake {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.debug_struct("Handshake")
+ .finish()
+ }
+}
+
// ===== impl ResponseFuture
impl Future for ResponseFuture {
diff --git a/src/client/mod.rs b/src/client/mod.rs
index 7cb7c0da56..d822ec9f77 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -268,7 +268,7 @@ where C: Connect + Sync + 'static,
.h1_writev(h1_writev)
.h1_title_case_headers(h1_title_case_headers)
.http2_only(pool_key.1 == Ver::Http2)
- .handshake_no_upgrades(io)
+ .handshake(io)
.and_then(move |(tx, conn)| {
executor.execute(conn.map_err(|e| {
debug!("client connection error: {}", e)
diff --git a/src/common/io/mod.rs b/src/common/io/mod.rs
new file mode 100644
index 0000000000..2e6d506153
--- /dev/null
+++ b/src/common/io/mod.rs
@@ -0,0 +1,3 @@
+mod rewind;
+
+pub(crate) use self::rewind::Rewind;
diff --git a/src/server/rewind.rs b/src/common/io/rewind.rs
similarity index 92%
rename from src/server/rewind.rs
rename to src/common/io/rewind.rs
index 6d2bf90eda..797dad9a74 100644
--- a/src/server/rewind.rs
+++ b/src/common/io/rewind.rs
@@ -1,26 +1,40 @@
+use std::cmp;
+use std::io::{self, Read, Write};
+
use bytes::{Buf, BufMut, Bytes, IntoBuf};
use futures::{Async, Poll};
-use std::io::{self, Read, Write};
-use std::cmp;
use tokio_io::{AsyncRead, AsyncWrite};
+/// Combine a buffer with an IO, rewinding reads to use the buffer.
#[derive(Debug)]
-pub struct Rewind {
+pub(crate) struct Rewind {
pre: Option,
inner: T,
}
impl Rewind {
- pub(super) fn new(tcp: T) -> Rewind {
+ pub(crate) fn new(io: T) -> Self {
Rewind {
pre: None,
- inner: tcp,
+ inner: io,
+ }
+ }
+
+ pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
+ Rewind {
+ pre: Some(buf),
+ inner: io,
}
}
- pub fn rewind(&mut self, bs: Bytes) {
+
+ pub(crate) fn rewind(&mut self, bs: Bytes) {
debug_assert!(self.pre.is_none());
self.pre = Some(bs);
}
+
+ pub(crate) fn into_inner(self) -> (T, Bytes) {
+ (self.inner, self.pre.unwrap_or_else(Bytes::new))
+ }
}
impl Read for Rewind
diff --git a/src/common/mod.rs b/src/common/mod.rs
index 124b2100e8..1bfa980fc9 100644
--- a/src/common/mod.rs
+++ b/src/common/mod.rs
@@ -1,5 +1,6 @@
mod buf;
mod exec;
+pub(crate) mod io;
mod never;
pub(crate) use self::buf::StaticBuf;
diff --git a/src/error.rs b/src/error.rs
index ab65e75523..337bea0c5b 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -61,6 +61,12 @@ pub(crate) enum Kind {
UnsupportedVersion,
/// User tried to create a CONNECT Request with the Client.
UnsupportedRequestMethod,
+
+ /// User tried polling for an upgrade that doesn't exist.
+ NoUpgrade,
+
+ /// User polled for an upgrade, but low-level API is not using upgrades.
+ ManualUpgrade,
}
#[derive(Debug, PartialEq)]
@@ -72,9 +78,6 @@ pub(crate) enum Parse {
Header,
TooLarge,
Status,
-
- /// A protocol upgrade was encountered, but not yet supported in hyper.
- UpgradeNotSupported,
}
/*
@@ -110,7 +113,8 @@ impl Error {
Kind::Service |
Kind::Closed |
Kind::UnsupportedVersion |
- Kind::UnsupportedRequestMethod => true,
+ Kind::UnsupportedRequestMethod |
+ Kind::NoUpgrade => true,
_ => false,
}
}
@@ -216,6 +220,14 @@ impl Error {
Error::new(Kind::UnsupportedRequestMethod, None)
}
+ pub(crate) fn new_user_no_upgrade() -> Error {
+ Error::new(Kind::NoUpgrade, None)
+ }
+
+ pub(crate) fn new_user_manual_upgrade() -> Error {
+ Error::new(Kind::ManualUpgrade, None)
+ }
+
pub(crate) fn new_user_new_service>(cause: E) -> Error {
Error::new(Kind::NewService, Some(cause.into()))
}
@@ -266,7 +278,6 @@ impl StdError for Error {
Kind::Parse(Parse::Header) => "invalid Header provided",
Kind::Parse(Parse::TooLarge) => "message head is too large",
Kind::Parse(Parse::Status) => "invalid Status provided",
- Kind::Parse(Parse::UpgradeNotSupported) => "unsupported protocol upgrade",
Kind::Incomplete => "message is incomplete",
Kind::MismatchedResponse => "response received without matching request",
Kind::Closed => "connection closed",
@@ -284,6 +295,8 @@ impl StdError for Error {
Kind::Http2 => "http2 general error",
Kind::UnsupportedVersion => "request has unsupported HTTP version",
Kind::UnsupportedRequestMethod => "request has unsupported HTTP method",
+ Kind::NoUpgrade => "no upgrade available",
+ Kind::ManualUpgrade => "upgrade expected but low level API in use",
Kind::Io => "an IO error occurred",
}
diff --git a/src/lib.rs b/src/lib.rs
index 968a1f772a..adddcbcb38 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -62,3 +62,4 @@ mod proto;
pub mod server;
pub mod service;
#[cfg(feature = "runtime")] pub mod rt;
+pub mod upgrade;
diff --git a/src/mock.rs b/src/mock.rs
index d8147c2cc9..b3c239c52d 100644
--- a/src/mock.rs
+++ b/src/mock.rs
@@ -79,6 +79,7 @@ pub struct AsyncIo {
inner: T,
max_read_vecs: usize,
num_writes: usize,
+ panic: bool,
park_tasks: bool,
task: Option,
}
@@ -93,6 +94,7 @@ impl AsyncIo {
inner: inner,
max_read_vecs: READ_VECS_CNT,
num_writes: 0,
+ panic: false,
park_tasks: false,
task: None,
}
@@ -110,6 +112,11 @@ impl AsyncIo {
self.error = Some(err);
}
+ #[cfg(feature = "nightly")]
+ pub fn panic(&mut self) {
+ self.panic = true;
+ }
+
pub fn max_read_vecs(&mut self, cnt: usize) {
assert!(cnt <= READ_VECS_CNT);
self.max_read_vecs = cnt;
@@ -185,6 +192,7 @@ impl, T: AsRef<[u8]>> PartialEq for AsyncIo {
impl Read for AsyncIo {
fn read(&mut self, buf: &mut [u8]) -> io::Result {
+ assert!(!self.panic, "AsyncIo::read panic");
self.blocked = false;
if let Some(err) = self.error.take() {
Err(err)
@@ -201,6 +209,7 @@ impl Read for AsyncIo {
impl Write for AsyncIo {
fn write(&mut self, data: &[u8]) -> io::Result {
+ assert!(!self.panic, "AsyncIo::write panic");
self.num_writes += 1;
if let Some(err) = self.error.take() {
trace!("AsyncIo::write error");
@@ -233,6 +242,7 @@ impl AsyncWrite for AsyncIo {
}
fn write_buf(&mut self, buf: &mut B) -> Poll {
+ assert!(!self.panic, "AsyncIo::write_buf panic");
if self.max_read_vecs == 0 {
return self.write_no_vecs(buf);
}
diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs
index 4c57e6e87e..925db1d8f8 100644
--- a/src/proto/h1/conn.rs
+++ b/src/proto/h1/conn.rs
@@ -8,9 +8,9 @@ use http::{HeaderMap, Method, Version};
use tokio_io::{AsyncRead, AsyncWrite};
use ::Chunk;
-use proto::{BodyLength, MessageHead};
+use proto::{BodyLength, DecodedLength, MessageHead};
use super::io::{Buffered};
-use super::{EncodedBuf, Encode, Encoder, Decode, Decoder, Http1Transaction, ParseContext};
+use super::{EncodedBuf, Encode, Encoder, /*Decode,*/ Decoder, Http1Transaction, ParseContext};
const H2_PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
@@ -44,6 +44,7 @@ where I: AsyncRead + AsyncWrite,
notify_read: false,
reading: Reading::Init,
writing: Writing::Init,
+ upgrade: None,
// We assume a modern world where the remote speaks HTTP/1.1.
// If they tell us otherwise, we'll downgrade in `read_head`.
version: Version::HTTP_11,
@@ -72,6 +73,10 @@ where I: AsyncRead + AsyncWrite,
self.io.into_inner()
}
+ pub fn pending_upgrade(&mut self) -> Option<::upgrade::Pending> {
+ self.state.upgrade.take()
+ }
+
pub fn is_read_closed(&self) -> bool {
self.state.is_read_closed()
}
@@ -114,80 +119,61 @@ where I: AsyncRead + AsyncWrite,
read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE
}
- pub fn read_head(&mut self) -> Poll, Option)>, ::Error> {
+ pub fn read_head(&mut self) -> Poll, DecodedLength, bool)>, ::Error> {
debug_assert!(self.can_read_head());
trace!("Conn::read_head");
- loop {
- let msg = match self.io.parse::(ParseContext {
- cached_headers: &mut self.state.cached_headers,
- req_method: &mut self.state.method,
- }) {
- Ok(Async::Ready(msg)) => msg,
- Ok(Async::NotReady) => return Ok(Async::NotReady),
- Err(e) => {
- // If we are currently waiting on a message, then an empty
- // message should be reported as an error. If not, it is just
- // the connection closing gracefully.
- let must_error = self.should_error_on_eof();
- self.state.close_read();
- self.io.consume_leading_lines();
- let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty();
- return if was_mid_parse || must_error {
- // We check if the buf contains the h2 Preface
- debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len());
- self.on_parse_error(e)
- .map(|()| Async::NotReady)
- } else {
- debug!("read eof");
- Ok(Async::Ready(None))
- };
- }
- };
+ let msg = match self.io.parse::(ParseContext {
+ cached_headers: &mut self.state.cached_headers,
+ req_method: &mut self.state.method,
+ }) {
+ Ok(Async::Ready(msg)) => msg,
+ Ok(Async::NotReady) => return Ok(Async::NotReady),
+ Err(e) => return self.on_read_head_error(e),
+ };
- self.state.version = msg.head.version;
- let head = msg.head;
- let decoder = match msg.decode {
- Decode::Normal(d) => {
- d
- },
- Decode::Final(d) => {
- trace!("final decoder, HTTP ending");
- debug_assert!(d.is_eof());
- self.state.close_read();
- d
- },
- Decode::Ignore => {
- // likely a 1xx message that we can ignore
- continue;
- }
- };
- debug!("incoming body is {}", decoder);
+ // Note: don't deconstruct `msg` into local variables, it appears
+ // the optimizer doesn't remove the extra copies.
- self.state.busy();
+ debug!("incoming body is {}", msg.decode);
+
+ self.state.busy();
+ self.state.keep_alive &= msg.keep_alive;
+ self.state.version = msg.head.version;
+
+ if msg.decode == DecodedLength::ZERO {
+ debug_assert!(!msg.expect_continue, "expect-continue needs a body");
+ self.state.reading = Reading::KeepAlive;
+ if !T::should_read_first() {
+ self.try_keep_alive();
+ }
+ } else {
if msg.expect_continue {
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
self.io.headers_buf().extend_from_slice(cont);
}
- let wants_keep_alive = msg.keep_alive;
- self.state.keep_alive &= wants_keep_alive;
-
- let content_length = decoder.content_length();
+ self.state.reading = Reading::Body(Decoder::new(msg.decode));
+ };
- if let Reading::Closed = self.state.reading {
- // actually want an `if not let ...`
- } else {
- self.state.reading = if content_length.is_none() {
- Reading::KeepAlive
- } else {
- Reading::Body(decoder)
- };
- }
- if content_length.is_none() {
- self.try_keep_alive();
- }
+ Ok(Async::Ready(Some((msg.head, msg.decode, msg.wants_upgrade))))
+ }
- return Ok(Async::Ready(Some((head, content_length))));
+ fn on_read_head_error(&mut self, e: ::Error) -> Poll, ::Error> {
+ // If we are currently waiting on a message, then an empty
+ // message should be reported as an error. If not, it is just
+ // the connection closing gracefully.
+ let must_error = self.should_error_on_eof();
+ self.state.close_read();
+ self.io.consume_leading_lines();
+ let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty();
+ if was_mid_parse || must_error {
+ // We check if the buf contains the h2 Preface
+ debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len());
+ self.on_parse_error(e)
+ .map(|()| Async::NotReady)
+ } else {
+ debug!("read eof");
+ Ok(Async::Ready(None))
}
}
@@ -612,6 +598,10 @@ where I: AsyncRead + AsyncWrite,
}
}
+ pub(super) fn on_upgrade(&mut self) -> ::upgrade::OnUpgrade {
+ self.state.prepare_upgrade()
+ }
+
// Used in h1::dispatch tests
#[cfg(test)]
pub(super) fn io_mut(&mut self) -> &mut I {
@@ -649,6 +639,8 @@ struct State {
reading: Reading,
/// State of allowed writes
writing: Writing,
+ /// An expected pending HTTP upgrade.
+ upgrade: Option<::upgrade::Pending>,
/// Either HTTP/1.0 or 1.1 connection
version: Version,
}
@@ -697,6 +689,7 @@ impl fmt::Debug for Writing {
impl ::std::ops::BitAndAssign for KA {
fn bitand_assign(&mut self, enabled: bool) {
if !enabled {
+ trace!("remote disabling keep-alive");
*self = KA::Disabled;
}
}
@@ -821,11 +814,53 @@ impl State {
_ => false
}
}
+
+ fn prepare_upgrade(&mut self) -> ::upgrade::OnUpgrade {
+ trace!("prepare possible HTTP upgrade");
+ debug_assert!(self.upgrade.is_none());
+ let (tx, rx) = ::upgrade::pending();
+ self.upgrade = Some(tx);
+ rx
+ }
}
#[cfg(test)]
//TODO: rewrite these using dispatch
mod tests {
+
+ #[cfg(feature = "nightly")]
+ #[bench]
+ fn bench_read_head_short(b: &mut ::test::Bencher) {
+ use super::*;
+ let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n";
+ let len = s.len();
+ b.bytes = len as u64;
+
+ let mut io = ::mock::AsyncIo::new_buf(Vec::new(), 0);
+ io.panic();
+ let mut conn = Conn::<_, ::Chunk, ::proto::h1::ServerTransaction>::new(io);
+ *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]);
+ conn.state.cached_headers = Some(HeaderMap::with_capacity(2));
+
+ b.iter(|| {
+ match conn.read_head().unwrap() {
+ Async::Ready(Some(x)) => {
+ ::test::black_box(&x);
+ let mut headers = x.0.headers;
+ headers.clear();
+ conn.state.cached_headers = Some(headers);
+ },
+ f => panic!("expected Ready(Some(..)): {:?}", f)
+ }
+
+
+ conn.io.read_buf_mut().reserve(1);
+ unsafe {
+ conn.io.read_buf_mut().set_len(len);
+ }
+ conn.state.reading = Reading::Init;
+ });
+ }
/*
use futures::{Async, Future, Stream, Sink};
use futures::future;
diff --git a/src/proto/h1/decode.rs b/src/proto/h1/decode.rs
index 03296878af..b8e2cac7d5 100644
--- a/src/proto/h1/decode.rs
+++ b/src/proto/h1/decode.rs
@@ -7,7 +7,7 @@ use futures::{Async, Poll};
use bytes::Bytes;
use super::io::MemRead;
-use super::BodyLength;
+use super::{DecodedLength};
use self::Kind::{Length, Chunked, Eof};
@@ -74,6 +74,14 @@ impl Decoder {
Decoder { kind: Kind::Eof(false) }
}
+ pub(super) fn new(len: DecodedLength) -> Self {
+ match len {
+ DecodedLength::CHUNKED => Decoder::chunked(),
+ DecodedLength::CLOSE_DELIMITED => Decoder::eof(),
+ length => Decoder::length(length.danger_len()),
+ }
+ }
+
// methods
pub fn is_eof(&self) -> bool {
@@ -85,16 +93,6 @@ impl Decoder {
}
}
- pub fn content_length(&self) -> Option {
- match self.kind {
- Length(0) |
- Chunked(ChunkedState::End, _) |
- Eof(true) => None,
- Length(len) => Some(BodyLength::Known(len)),
- _ => Some(BodyLength::Unknown),
- }
- }
-
pub fn decode(&mut self, body: &mut R) -> Poll {
trace!("decode; state={:?}", self.kind);
match self.kind {
@@ -152,16 +150,6 @@ impl fmt::Debug for Decoder {
}
}
-impl fmt::Display for Decoder {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- match self.kind {
- Kind::Length(n) => write!(f, "content-length ({} bytes)", n),
- Kind::Chunked(..) => f.write_str("chunked encoded"),
- Kind::Eof(..) => f.write_str("until end-of-file"),
- }
- }
-}
-
macro_rules! byte (
($rdr:ident) => ({
let buf = try_ready!($rdr.read_mem(1));
diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs
index ff411e81ee..290edfe20d 100644
--- a/src/proto/h1/dispatch.rs
+++ b/src/proto/h1/dispatch.rs
@@ -5,7 +5,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use body::{Body, Payload};
use body::internal::FullDataArg;
-use proto::{BodyLength, Conn, MessageHead, RequestHead, RequestLine, ResponseHead};
+use proto::{BodyLength, DecodedLength, Conn, Dispatched, MessageHead, RequestHead, RequestLine, ResponseHead};
use super::Http1Transaction;
use service::Service;
@@ -65,32 +65,34 @@ where
(io, buf, self.dispatch)
}
- /// The "Future" poll function. Runs this dispatcher until the
- /// connection is shutdown, or an error occurs.
- pub fn poll_until_shutdown(&mut self) -> Poll<(), ::Error> {
- self.poll_catch(true)
- }
-
/// Run this dispatcher until HTTP says this connection is done,
/// but don't call `AsyncWrite::shutdown` on the underlying IO.
///
- /// This is useful for HTTP upgrades.
+ /// This is useful for old-style HTTP upgrades, but ignores
+ /// newer-style upgrade API.
pub fn poll_without_shutdown(&mut self) -> Poll<(), ::Error> {
self.poll_catch(false)
+ .map(|x| {
+ x.map(|ds| if let Dispatched::Upgrade(pending) = ds {
+ pending.manual();
+ })
+ })
}
- fn poll_catch(&mut self, should_shutdown: bool) -> Poll<(), ::Error> {
+ fn poll_catch(&mut self, should_shutdown: bool) -> Poll {
self.poll_inner(should_shutdown).or_else(|e| {
// An error means we're shutting down either way.
// We just try to give the error to the user,
// and close the connection with an Ok. If we
// cannot give it to the user, then return the Err.
- self.dispatch.recv_msg(Err(e)).map(Async::Ready)
+ self.dispatch.recv_msg(Err(e))?;
+ Ok(Async::Ready(Dispatched::Shutdown))
})
}
- fn poll_inner(&mut self, should_shutdown: bool) -> Poll<(), ::Error> {
+ fn poll_inner(&mut self, should_shutdown: bool) -> Poll {
T::update_date();
+
loop {
self.poll_read()?;
self.poll_write()?;
@@ -110,11 +112,14 @@ where
}
if self.is_done() {
- if should_shutdown {
+ if let Some(pending) = self.conn.pending_upgrade() {
+ self.conn.take_error()?;
+ return Ok(Async::Ready(Dispatched::Upgrade(pending)));
+ } else if should_shutdown {
try_ready!(self.conn.shutdown().map_err(::Error::new_shutdown));
}
self.conn.take_error()?;
- Ok(Async::Ready(()))
+ Ok(Async::Ready(Dispatched::Shutdown))
} else {
Ok(Async::NotReady)
}
@@ -190,20 +195,18 @@ where
}
// dispatch is ready for a message, try to read one
match self.conn.read_head() {
- Ok(Async::Ready(Some((head, body_len)))) => {
- let body = if let Some(body_len) = body_len {
- let (mut tx, rx) =
- Body::new_channel(if let BodyLength::Known(len) = body_len {
- Some(len)
- } else {
- None
- });
- let _ = tx.poll_ready(); // register this task if rx is dropped
- self.body_tx = Some(tx);
- rx
- } else {
- Body::empty()
+ Ok(Async::Ready(Some((head, body_len, wants_upgrade)))) => {
+ let mut body = match body_len {
+ DecodedLength::ZERO => Body::empty(),
+ other => {
+ let (tx, rx) = Body::new_channel(other.into_opt());
+ self.body_tx = Some(tx);
+ rx
+ },
};
+ if wants_upgrade {
+ body.set_on_upgrade(self.conn.on_upgrade());
+ }
self.dispatch.recv_msg(Ok((head, body)))?;
Ok(Async::Ready(()))
}
@@ -326,7 +329,6 @@ where
}
}
-
impl Future for Dispatcher
where
D: Dispatch, PollBody=Bs, RecvItem=MessageHead>,
@@ -334,12 +336,12 @@ where
T: Http1Transaction,
Bs: Payload,
{
- type Item = ();
+ type Item = Dispatched;
type Error = ::Error;
#[inline]
fn poll(&mut self) -> Poll {
- self.poll_until_shutdown()
+ self.poll_catch(true)
}
}
@@ -519,7 +521,7 @@ mod tests {
use super::*;
use mock::AsyncIo;
- use proto::ClientTransaction;
+ use proto::h1::ClientTransaction;
#[test]
fn client_read_bytes_before_writing_request() {
diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs
index d0a953c7c3..441dec74f4 100644
--- a/src/proto/h1/io.rs
+++ b/src/proto/h1/io.rs
@@ -93,6 +93,12 @@ where
self.read_buf.as_ref()
}
+ #[cfg(test)]
+ #[cfg(feature = "nightly")]
+ pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut {
+ &mut self.read_buf
+ }
+
pub fn headers_buf(&mut self) -> &mut Vec {
let buf = self.write_buf.headers_mut();
&mut buf.bytes
@@ -595,7 +601,7 @@ mod tests {
cached_headers: &mut None,
req_method: &mut None,
};
- assert!(buffered.parse::<::proto::ClientTransaction>(ctx).unwrap().is_not_ready());
+ assert!(buffered.parse::<::proto::h1::ClientTransaction>(ctx).unwrap().is_not_ready());
assert!(buffered.io.blocked());
}
diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs
index 46f0eeab00..3adf260d60 100644
--- a/src/proto/h1/mod.rs
+++ b/src/proto/h1/mod.rs
@@ -1,7 +1,7 @@
use bytes::BytesMut;
use http::{HeaderMap, Method};
-use proto::{MessageHead, BodyLength};
+use proto::{MessageHead, BodyLength, DecodedLength};
pub(crate) use self::conn::Conn;
pub(crate) use self::dispatch::Dispatcher;
@@ -19,12 +19,8 @@ mod io;
mod role;
-pub(crate) type ServerTransaction = self::role::Server;
-//pub type ServerTransaction = self::role::Server;
-//pub type ServerUpgradeTransaction = self::role::Server;
-
-pub(crate) type ClientTransaction = self::role::Client;
-pub(crate) type ClientUpgradeTransaction = self::role::Client;
+pub(crate) type ServerTransaction = role::Server;
+pub(crate) type ClientTransaction = role::Client;
pub(crate) trait Http1Transaction {
type Incoming;
@@ -40,14 +36,16 @@ pub(crate) trait Http1Transaction {
fn update_date() {}
}
+/// Result newtype for Http1Transaction::parse.
pub(crate) type ParseResult = Result>, ::error::Parse>;
#[derive(Debug)]
pub(crate) struct ParsedMessage {
head: MessageHead,
- decode: Decode,
+ decode: DecodedLength,
expect_continue: bool,
keep_alive: bool,
+ wants_upgrade: bool,
}
pub(crate) struct ParseContext<'a> {
@@ -64,12 +62,3 @@ pub(crate) struct Encode<'a, T: 'a> {
title_case_headers: bool,
}
-#[derive(Debug, PartialEq)]
-pub enum Decode {
- /// Decode normally.
- Normal(Decoder),
- /// After this decoder is done, HTTP is done.
- Final(Decoder),
- /// A header block that should be ignored, like unknown 1xx responses.
- Ignore,
-}
diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs
index dd21b623f9..3e30d16645 100644
--- a/src/proto/h1/role.rs
+++ b/src/proto/h1/role.rs
@@ -8,25 +8,19 @@ use httparse;
use error::Parse;
use headers;
-use proto::{BodyLength, MessageHead, RequestLine, RequestHead};
-use proto::h1::{Decode, Decoder, Encode, Encoder, Http1Transaction, ParseResult, ParseContext, ParsedMessage, date};
+use proto::{BodyLength, DecodedLength, MessageHead, RequestLine, RequestHead};
+use proto::h1::{Encode, Encoder, Http1Transaction, ParseResult, ParseContext, ParsedMessage, date};
const MAX_HEADERS: usize = 100;
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
// There are 2 main roles, Client and Server.
-//
-// There is 1 modifier, OnUpgrade, which can wrap Client and Server,
-// to signal that HTTP upgrades are not supported.
-pub(crate) struct Client(T);
+pub(crate) enum Client {}
-pub(crate) struct Server(T);
+pub(crate) enum Server {}
-impl Http1Transaction for Server
-where
- T: OnUpgrade,
-{
+impl Http1Transaction for Server {
type Incoming = RequestLine;
type Outgoing = StatusCode;
@@ -34,31 +28,45 @@ where
if buf.len() == 0 {
return Ok(None);
}
+
+ let mut keep_alive;
+ let is_http_11;
+ let subject;
+ let version;
+ let len;
+ let headers_len;
+
// Unsafe: both headers_indices and headers are using unitialized memory,
// but we *never* read any of it until after httparse has assigned
// values into it. By not zeroing out the stack memory, this saves
// a good ~5% on pipeline benchmarks.
let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() };
- let (len, subject, version, headers_len) = {
+ {
let mut headers: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() };
trace!("Request.parse([Header; {}], [u8; {}])", headers.len(), buf.len());
let mut req = httparse::Request::new(&mut headers);
let bytes = buf.as_ref();
match req.parse(bytes)? {
- httparse::Status::Complete(len) => {
- trace!("Request.parse Complete({})", len);
- let method = Method::from_bytes(req.method.unwrap().as_bytes())?;
- let path = req.path.unwrap().parse()?;
- let subject = RequestLine(method, path);
- let version = if req.version.unwrap() == 1 {
+ httparse::Status::Complete(parsed_len) => {
+ trace!("Request.parse Complete({})", parsed_len);
+ len = parsed_len;
+ subject = RequestLine(
+ Method::from_bytes(req.method.unwrap().as_bytes())?,
+ req.path.unwrap().parse()?
+ );
+ version = if req.version.unwrap() == 1 {
+ keep_alive = true;
+ is_http_11 = true;
Version::HTTP_11
} else {
+ keep_alive = false;
+ is_http_11 = false;
Version::HTTP_10
};
record_header_indices(bytes, &req.headers, &mut headers_indices);
- let headers_len = req.headers.len();
- (len, subject, version, headers_len)
+ headers_len = req.headers.len();
+ //(len, subject, version, headers_len)
}
httparse::Status::Partial => return Ok(None),
}
@@ -76,12 +84,12 @@ where
// 7. (irrelevant to Request)
- let mut decoder = None;
+ let mut decoder = DecodedLength::ZERO;
let mut expect_continue = false;
- let mut keep_alive = version == Version::HTTP_11;
let mut con_len = None;
let mut is_te = false;
let mut is_te_chunked = false;
+ let mut wants_upgrade = subject.0 == Method::CONNECT;
let mut headers = ctx.cached_headers
.take()
@@ -104,16 +112,14 @@ where
// If Transfer-Encoding header is present, and 'chunked' is
// not the final encoding, and this is a Request, then it is
// mal-formed. A server should respond with 400 Bad Request.
- if version == Version::HTTP_10 {
+ if !is_http_11 {
debug!("HTTP/1.0 cannot have Transfer-Encoding header");
return Err(Parse::Header);
}
is_te = true;
if headers::is_chunked_(&value) {
is_te_chunked = true;
- decoder = Some(Decoder::chunked());
- //debug!("request with transfer-encoding header, but not chunked, bad request");
- //return Err(Parse::Header);
+ decoder = DecodedLength::CHUNKED;
}
},
header::CONTENT_LENGTH => {
@@ -135,8 +141,8 @@ where
// we don't need to append this secondary length
continue;
}
+ decoder = DecodedLength::checked_new(len)?;
con_len = Some(len);
- decoder = Some(Decoder::length(len));
},
header::CONNECTION => {
// keep_alive was previously set to default for Version
@@ -152,6 +158,10 @@ where
header::EXPECT => {
expect_continue = value.as_bytes() == b"100-continue";
},
+ header::UPGRADE => {
+ // Upgrades are only allowed with HTTP/1.1
+ wants_upgrade = is_http_11;
+ },
_ => (),
}
@@ -159,15 +169,10 @@ where
headers.append(name, value);
}
- let decoder = if let Some(decoder) = decoder {
- decoder
- } else {
- if is_te && !is_te_chunked {
- debug!("request with transfer-encoding header, but not chunked, bad request");
- return Err(Parse::Header);
- }
- Decoder::length(0)
- };
+ if is_te && !is_te_chunked {
+ debug!("request with transfer-encoding header, but not chunked, bad request");
+ return Err(Parse::Header);
+ }
*ctx.req_method = Some(subject.0.clone());
@@ -177,9 +182,10 @@ where
subject,
headers,
},
- decode: Decode::Normal(decoder),
+ decode: decoder,
expect_continue,
keep_alive,
+ wants_upgrade,
}))
}
@@ -194,7 +200,7 @@ where
let is_upgrade = msg.head.subject == StatusCode::SWITCHING_PROTOCOLS
|| (msg.req_method == &Some(Method::CONNECT) && msg.head.subject.is_success());
let (ret, mut is_last) = if is_upgrade {
- (T::on_encode_upgrade(&mut msg), true)
+ (Ok(()), true)
} else if msg.head.subject.is_informational() {
error!("response with 1xx status code not supported");
*msg.head = MessageHead::default();
@@ -485,7 +491,7 @@ where
}
}
-impl Server<()> {
+impl Server {
fn can_have_body(method: &Option, status: StatusCode) -> bool {
Server::can_chunked(method, status)
}
@@ -508,65 +514,69 @@ impl Server<()> {
}
}
-impl Http1Transaction for Client
-where
- T: OnUpgrade,
-{
+impl Http1Transaction for Client {
type Incoming = StatusCode;
type Outgoing = RequestLine;
fn parse(buf: &mut BytesMut, ctx: ParseContext) -> ParseResult {
- if buf.len() == 0 {
- return Ok(None);
- }
- // Unsafe: see comment in Server Http1Transaction, above.
- let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() };
- let (len, status, version, headers_len) = {
- let mut headers: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() };
- trace!("Response.parse([Header; {}], [u8; {}])", headers.len(), buf.len());
- let mut res = httparse::Response::new(&mut headers);
- let bytes = buf.as_ref();
- match res.parse(bytes)? {
- httparse::Status::Complete(len) => {
- trace!("Response.parse Complete({})", len);
- let status = StatusCode::from_u16(res.code.unwrap())?;
- let version = if res.version.unwrap() == 1 {
- Version::HTTP_11
- } else {
- Version::HTTP_10
- };
- record_header_indices(bytes, &res.headers, &mut headers_indices);
- let headers_len = res.headers.len();
- (len, status, version, headers_len)
- },
- httparse::Status::Partial => return Ok(None),
+ // Loop to skip information status code headers (100 Continue, etc).
+ loop {
+ if buf.len() == 0 {
+ return Ok(None);
}
- };
+ // Unsafe: see comment in Server Http1Transaction, above.
+ let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() };
+ let (len, status, version, headers_len) = {
+ let mut headers: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() };
+ trace!("Response.parse([Header; {}], [u8; {}])", headers.len(), buf.len());
+ let mut res = httparse::Response::new(&mut headers);
+ let bytes = buf.as_ref();
+ match res.parse(bytes)? {
+ httparse::Status::Complete(len) => {
+ trace!("Response.parse Complete({})", len);
+ let status = StatusCode::from_u16(res.code.unwrap())?;
+ let version = if res.version.unwrap() == 1 {
+ Version::HTTP_11
+ } else {
+ Version::HTTP_10
+ };
+ record_header_indices(bytes, &res.headers, &mut headers_indices);
+ let headers_len = res.headers.len();
+ (len, status, version, headers_len)
+ },
+ httparse::Status::Partial => return Ok(None),
+ }
+ };
- let slice = buf.split_to(len).freeze();
+ let slice = buf.split_to(len).freeze();
- let mut headers = ctx.cached_headers
- .take()
- .unwrap_or_else(HeaderMap::new);
+ let mut headers = ctx.cached_headers
+ .take()
+ .unwrap_or_else(HeaderMap::new);
- headers.reserve(headers_len);
- fill_headers(&mut headers, slice, &headers_indices[..headers_len]);
+ headers.reserve(headers_len);
+ fill_headers(&mut headers, slice, &headers_indices[..headers_len]);
- let keep_alive = version == Version::HTTP_11;
+ let keep_alive = version == Version::HTTP_11;
- let head = MessageHead {
- version,
- subject: status,
- headers,
- };
- let decode = Client::::decoder(&head, ctx.req_method)?;
+ let head = MessageHead {
+ version,
+ subject: status,
+ headers,
+ };
+ if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? {
+ return Ok(Some(ParsedMessage {
+ head,
+ decode,
+ expect_continue: false,
+ // a client upgrade means the connection can't be used
+ // again, as it is definitely upgrading.
+ keep_alive: keep_alive && !is_upgrade,
+ wants_upgrade: is_upgrade,
+ }));
+ }
- Ok(Some(ParsedMessage {
- head,
- decode,
- expect_continue: false,
- keep_alive,
- }))
+ }
}
fn encode(msg: Encode, dst: &mut Vec) -> ::Result {
@@ -617,8 +627,11 @@ where
}
}
-impl Client {
- fn decoder(inc: &MessageHead, method: &mut Option) -> Result {
+impl Client {
+ /// Returns Some(length, wants_upgrade) if successful.
+ ///
+ /// Returns None if this message head should be skipped (like a 100 status).
+ fn decoder(inc: &MessageHead, method: &mut Option) -> Result, Parse> {
// According to https://tools.ietf.org/html/rfc7230#section-3.3.3
// 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body.
// 2. Status 2xx to a CONNECT cannot have a body.
@@ -630,23 +643,23 @@ impl Client {
match inc.subject.as_u16() {
101 => {
- return T::on_decode_upgrade().map(Decode::Final);
+ return Ok(Some((DecodedLength::ZERO, true)));
},
100...199 => {
trace!("ignoring informational response: {}", inc.subject.as_u16());
- return Ok(Decode::Ignore);
+ return Ok(None);
},
204 |
- 304 => return Ok(Decode::Normal(Decoder::length(0))),
+ 304 => return Ok(Some((DecodedLength::ZERO, false))),
_ => (),
}
match *method {
Some(Method::HEAD) => {
- return Ok(Decode::Normal(Decoder::length(0)));
+ return Ok(Some((DecodedLength::ZERO, false)));
}
Some(Method::CONNECT) => match inc.subject.as_u16() {
200...299 => {
- return Ok(Decode::Final(Decoder::length(0)));
+ return Ok(Some((DecodedLength::ZERO, true)));
},
_ => {},
},
@@ -665,24 +678,24 @@ impl Client {
debug!("HTTP/1.0 cannot have Transfer-Encoding header");
Err(Parse::Header)
} else if headers::transfer_encoding_is_chunked(&inc.headers) {
- Ok(Decode::Normal(Decoder::chunked()))
+ Ok(Some((DecodedLength::CHUNKED, false)))
} else {
trace!("not chunked, read till eof");
- Ok(Decode::Normal(Decoder::eof()))
+ Ok(Some((DecodedLength::CHUNKED, false)))
}
} else if let Some(len) = headers::content_length_parse_all(&inc.headers) {
- Ok(Decode::Normal(Decoder::length(len)))
+ Ok(Some((DecodedLength::checked_new(len)?, false)))
} else if inc.headers.contains_key(header::CONTENT_LENGTH) {
debug!("illegal Content-Length header");
Err(Parse::Header)
} else {
trace!("neither Transfer-Encoding nor Content-Length");
- Ok(Decode::Normal(Decoder::eof()))
+ Ok(Some((DecodedLength::CLOSE_DELIMITED, false)))
}
}
}
-impl Client<()> {
+impl Client {
fn set_length(head: &mut RequestHead, body: Option) -> Encoder {
if let Some(body) = body {
let can_chunked = head.version == Version::HTTP_11
@@ -830,51 +843,6 @@ fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder {
}
}
-pub(crate) trait OnUpgrade {
- fn on_encode_upgrade(msg: &mut Encode) -> ::Result<()>;
- fn on_decode_upgrade() -> Result;
-}
-
-pub(crate) enum YesUpgrades {}
-
-pub(crate) enum NoUpgrades {}
-
-impl OnUpgrade for YesUpgrades {
- fn on_encode_upgrade(_: &mut Encode) -> ::Result<()> {
- Ok(())
- }
-
- fn on_decode_upgrade() -> Result {
- debug!("101 response received, upgrading");
- // 101 upgrades always have no body
- Ok(Decoder::length(0))
- }
-}
-
-impl OnUpgrade for NoUpgrades {
- fn on_encode_upgrade(msg: &mut Encode) -> ::Result<()> {
- *msg.head = MessageHead::default();
- msg.head.subject = ::StatusCode::INTERNAL_SERVER_ERROR;
- msg.body = None;
-
- if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS {
- error!("response with 101 status code not supported");
- Err(Parse::UpgradeNotSupported.into())
- } else if msg.req_method == &Some(Method::CONNECT) {
- error!("200 response to CONNECT request not supported");
- Err(::Error::new_user_unsupported_request_method())
- } else {
- debug_assert!(false, "upgrade incorrectly detected");
- Err(::Error::new_status())
- }
- }
-
- fn on_decode_upgrade() -> Result {
- debug!("received 101 upgrade response, not supported");
- Err(Parse::UpgradeNotSupported)
- }
-}
-
#[derive(Clone, Copy)]
struct HeaderIndices {
name: (usize, usize),
@@ -978,10 +946,6 @@ mod tests {
use bytes::BytesMut;
use super::*;
- use super::{Server as S, Client as C};
-
- type Server = S;
- type Client = C;
#[test]
fn test_parse_request() {
@@ -1033,8 +997,6 @@ mod tests {
#[test]
fn test_decoder_request() {
- use super::Decoder;
-
fn parse(s: &str) -> ParsedMessage {
let mut bytes = BytesMut::from(s);
Server::parse(&mut bytes, ParseContext {
@@ -1058,39 +1020,39 @@ mod tests {
assert_eq!(parse("\
GET / HTTP/1.1\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::length(0)));
+ ").decode, DecodedLength::ZERO);
assert_eq!(parse("\
POST / HTTP/1.1\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::length(0)));
+ ").decode, DecodedLength::ZERO);
// transfer-encoding: chunked
assert_eq!(parse("\
POST / HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::chunked()));
+ ").decode, DecodedLength::CHUNKED);
assert_eq!(parse("\
POST / HTTP/1.1\r\n\
transfer-encoding: gzip, chunked\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::chunked()));
+ ").decode, DecodedLength::CHUNKED);
assert_eq!(parse("\
POST / HTTP/1.1\r\n\
transfer-encoding: gzip\r\n\
transfer-encoding: chunked\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::chunked()));
+ ").decode, DecodedLength::CHUNKED);
// content-length
assert_eq!(parse("\
POST / HTTP/1.1\r\n\
content-length: 10\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::length(10)));
+ ").decode, DecodedLength::new(10));
// transfer-encoding and content-length = chunked
assert_eq!(parse("\
@@ -1098,14 +1060,14 @@ mod tests {
content-length: 10\r\n\
transfer-encoding: chunked\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::chunked()));
+ ").decode, DecodedLength::CHUNKED);
assert_eq!(parse("\
POST / HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\
content-length: 10\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::chunked()));
+ ").decode, DecodedLength::CHUNKED);
assert_eq!(parse("\
POST / HTTP/1.1\r\n\
@@ -1113,7 +1075,7 @@ mod tests {
content-length: 10\r\n\
transfer-encoding: chunked\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::chunked()));
+ ").decode, DecodedLength::CHUNKED);
// multiple content-lengths of same value are fine
@@ -1122,7 +1084,7 @@ mod tests {
content-length: 10\r\n\
content-length: 10\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::length(10)));
+ ").decode, DecodedLength::new(10));
// multiple content-lengths with different values is an error
@@ -1153,7 +1115,7 @@ mod tests {
POST / HTTP/1.0\r\n\
content-length: 10\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::length(10)));
+ ").decode, DecodedLength::new(10));
// 1.0 doesn't understand chunked, so its an error
@@ -1171,6 +1133,16 @@ mod tests {
parse_with_method(s, Method::GET)
}
+ fn parse_ignores(s: &str) {
+ let mut bytes = BytesMut::from(s);
+ assert!(Client::parse(&mut bytes, ParseContext {
+ cached_headers: &mut None,
+ req_method: &mut Some(Method::GET),
+ })
+ .expect("parse ok")
+ .is_none())
+ }
+
fn parse_with_method(s: &str, m: Method) -> ParsedMessage {
let mut bytes = BytesMut::from(s);
Client::parse(&mut bytes, ParseContext {
@@ -1195,32 +1167,32 @@ mod tests {
assert_eq!(parse("\
HTTP/1.1 200 OK\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::eof()));
+ ").decode, DecodedLength::CLOSE_DELIMITED);
// 204 and 304 never have a body
assert_eq!(parse("\
HTTP/1.1 204 No Content\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::length(0)));
+ ").decode, DecodedLength::ZERO);
assert_eq!(parse("\
HTTP/1.1 304 Not Modified\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::length(0)));
+ ").decode, DecodedLength::ZERO);
// content-length
assert_eq!(parse("\
HTTP/1.1 200 OK\r\n\
content-length: 8\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::length(8)));
+ ").decode, DecodedLength::new(8));
assert_eq!(parse("\
HTTP/1.1 200 OK\r\n\
content-length: 8\r\n\
content-length: 8\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::length(8)));
+ ").decode, DecodedLength::new(8));
parse_err("\
HTTP/1.1 200 OK\r\n\
@@ -1235,7 +1207,7 @@ mod tests {
HTTP/1.1 200 OK\r\n\
transfer-encoding: chunked\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::chunked()));
+ ").decode, DecodedLength::CHUNKED);
// transfer-encoding and content-length = chunked
assert_eq!(parse("\
@@ -1243,7 +1215,7 @@ mod tests {
content-length: 10\r\n\
transfer-encoding: chunked\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::chunked()));
+ ").decode, DecodedLength::CHUNKED);
// HEAD can have content-length, but not body
@@ -1251,44 +1223,54 @@ mod tests {
HTTP/1.1 200 OK\r\n\
content-length: 8\r\n\
\r\n\
- ", Method::HEAD).decode, Decode::Normal(Decoder::length(0)));
+ ", Method::HEAD).decode, DecodedLength::ZERO);
// CONNECT with 200 never has body
- assert_eq!(parse_with_method("\
- HTTP/1.1 200 OK\r\n\
- \r\n\
- ", Method::CONNECT).decode, Decode::Final(Decoder::length(0)));
+ {
+ let msg = parse_with_method("\
+ HTTP/1.1 200 OK\r\n\
+ \r\n\
+ ", Method::CONNECT);
+ assert_eq!(msg.decode, DecodedLength::ZERO);
+ assert!(!msg.keep_alive, "should be upgrade");
+ assert!(msg.wants_upgrade, "should be upgrade");
+ }
// CONNECT receiving non 200 can have a body
assert_eq!(parse_with_method("\
HTTP/1.1 400 Bad Request\r\n\
\r\n\
- ", Method::CONNECT).decode, Decode::Normal(Decoder::eof()));
+ ", Method::CONNECT).decode, DecodedLength::CLOSE_DELIMITED);
// 1xx status codes
- assert_eq!(parse("\
+ parse_ignores("\
HTTP/1.1 100 Continue\r\n\
\r\n\
- ").decode, Decode::Ignore);
+ ");
- assert_eq!(parse("\
+ parse_ignores("\
HTTP/1.1 103 Early Hints\r\n\
\r\n\
- ").decode, Decode::Ignore);
+ ");
// 101 upgrade not supported yet
- parse_err("\
- HTTP/1.1 101 Switching Protocols\r\n\
- \r\n\
- ");
+ {
+ let msg = parse("\
+ HTTP/1.1 101 Switching Protocols\r\n\
+ \r\n\
+ ");
+ assert_eq!(msg.decode, DecodedLength::ZERO);
+ assert!(!msg.keep_alive, "should be last");
+ assert!(msg.wants_upgrade, "should be upgrade");
+ }
// http/1.0
assert_eq!(parse("\
HTTP/1.0 200 OK\r\n\
\r\n\
- ").decode, Decode::Normal(Decoder::eof()));
+ ").decode, DecodedLength::CLOSE_DELIMITED);
// 1.0 doesn't understand chunked
parse_err("\
@@ -1320,28 +1302,11 @@ mod tests {
}
#[test]
- fn test_server_no_upgrades_connect_method() {
- let mut head = MessageHead::default();
-
- let mut vec = Vec::new();
- let err = Server::encode(Encode {
- head: &mut head,
- body: None,
- keep_alive: true,
- req_method: &mut Some(Method::CONNECT),
- title_case_headers: false,
- }, &mut vec).unwrap_err();
-
- assert!(err.is_user());
- assert_eq!(err.kind(), &::error::Kind::UnsupportedRequestMethod);
- }
-
- #[test]
- fn test_server_yes_upgrades_connect_method() {
+ fn test_server_encode_connect_method() {
let mut head = MessageHead::default();
let mut vec = Vec::new();
- let encoder = S::::encode(Encode {
+ let encoder = Server::encode(Encode {
head: &mut head,
body: None,
keep_alive: true,
@@ -1382,10 +1347,12 @@ mod tests {
b.bytes = len as u64;
b.iter(|| {
- let msg = Server::parse(&mut raw, ParseContext {
+ let mut msg = Server::parse(&mut raw, ParseContext {
cached_headers: &mut headers,
req_method: &mut None,
}).unwrap().unwrap();
+ ::test::black_box(&msg);
+ msg.head.headers.clear();
headers = Some(msg.head.headers);
restart(&mut raw, len);
});
@@ -1402,18 +1369,19 @@ mod tests {
#[cfg(feature = "nightly")]
#[bench]
fn bench_parse_short(b: &mut Bencher) {
- let mut raw = BytesMut::from(
- b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n".to_vec()
- );
+ let s = &b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"[..];
+ let mut raw = BytesMut::from(s.to_vec());
let len = raw.len();
let mut headers = Some(HeaderMap::new());
b.bytes = len as u64;
b.iter(|| {
- let msg = Server::parse(&mut raw, ParseContext {
+ let mut msg = Server::parse(&mut raw, ParseContext {
cached_headers: &mut headers,
req_method: &mut None,
}).unwrap().unwrap();
+ ::test::black_box(&msg);
+ msg.head.headers.clear();
headers = Some(msg.head.headers);
restart(&mut raw, len);
});
@@ -1480,3 +1448,4 @@ mod tests {
})
}
}
+
diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs
index 6a1b65e84c..f8824518d2 100644
--- a/src/proto/h2/client.rs
+++ b/src/proto/h2/client.rs
@@ -8,6 +8,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use body::Payload;
use ::common::{Exec, Never};
use headers;
+use ::proto::Dispatched;
use super::{PipeToSendStream, SendBuf};
use ::{Body, Request, Response};
@@ -16,7 +17,7 @@ type ClientRx = ::client::dispatch::Receiver, Response>;
/// other handles to it have been dropped, so that it can shutdown.
type ConnDropRef = mpsc::Sender;
-pub struct Client
+pub(crate) struct Client
where
B: Payload,
{
@@ -54,7 +55,7 @@ where
T: AsyncRead + AsyncWrite + Send + 'static,
B: Payload + 'static,
{
- type Item = ();
+ type Item = Dispatched;
type Error = ::Error;
fn poll(&mut self) -> Poll {
@@ -153,7 +154,7 @@ where
Ok(Async::Ready(None)) |
Err(_) => {
trace!("client::dispatch::Sender dropped");
- return Ok(Async::Ready(()));
+ return Ok(Async::Ready(Dispatched::Shutdown));
}
}
},
diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs
index ad23c0c6ec..07400eadeb 100644
--- a/src/proto/h2/server.rs
+++ b/src/proto/h2/server.rs
@@ -7,6 +7,7 @@ use ::body::Payload;
use ::common::Exec;
use ::headers;
use ::service::Service;
+use ::proto::Dispatched;
use super::{PipeToSendStream, SendBuf};
use ::{Body, Response};
@@ -82,7 +83,7 @@ where
S::Future: Send + 'static,
B: Payload,
{
- type Item = ();
+ type Item = Dispatched;
type Error = ::Error;
fn poll(&mut self) -> Poll {
@@ -95,12 +96,13 @@ where
})
},
State::Serving(ref mut srv) => {
- return srv.poll_server(&mut self.service, &self.exec);
+ try_ready!(srv.poll_server(&mut self.service, &self.exec));
+ return Ok(Async::Ready(Dispatched::Shutdown));
}
State::Closed => {
// graceful_shutdown was called before handshaking finished,
// nothing to do here...
- return Ok(Async::Ready(()));
+ return Ok(Async::Ready(Dispatched::Shutdown));
}
};
self.state = next;
diff --git a/src/proto/mod.rs b/src/proto/mod.rs
index 131fb23209..7cd997f8b6 100644
--- a/src/proto/mod.rs
+++ b/src/proto/mod.rs
@@ -1,12 +1,12 @@
//! Pieces pertaining to the HTTP message protocol.
use http::{HeaderMap, Method, StatusCode, Uri, Version};
-pub(crate) use self::h1::{dispatch, Conn, ClientTransaction, ClientUpgradeTransaction, ServerTransaction};
+pub(crate) use self::h1::{dispatch, Conn, ServerTransaction};
+use self::body_length::DecodedLength;
pub(crate) mod h1;
pub(crate) mod h2;
-
/// An Incoming Message head. Includes request/status line, and headers.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct MessageHead {
@@ -27,34 +27,6 @@ pub struct RequestLine(pub Method, pub Uri);
/// An incoming response message.
pub type ResponseHead = MessageHead;
-/*
-impl MessageHead {
- pub fn should_keep_alive(&self) -> bool {
- should_keep_alive(self.version, &self.headers)
- }
-
- pub fn expecting_continue(&self) -> bool {
- expecting_continue(self.version, &self.headers)
- }
-}
-
-/// Checks if a connection should be kept alive.
-#[inline]
-pub fn should_keep_alive(version: Version, headers: &HeaderMap) -> bool {
- if version == Version::HTTP_10 {
- headers::connection_keep_alive(headers)
- } else {
- !headers::connection_close(headers)
- }
-}
-
-/// Checks if a connection is expecting a `100 Continue` before sending its body.
-#[inline]
-pub fn expecting_continue(version: Version, headers: &HeaderMap) -> bool {
- version == Version::HTTP_11 && headers::expect_continue(headers)
-}
-*/
-
#[derive(Debug)]
pub enum BodyLength {
/// Content-Length
@@ -63,32 +35,72 @@ pub enum BodyLength {
Unknown,
}
-/*
-#[test]
-fn test_should_keep_alive() {
- let mut headers = HeaderMap::new();
-
- assert!(!should_keep_alive(Version::HTTP_10, &headers));
- assert!(should_keep_alive(Version::HTTP_11, &headers));
-
- headers.insert("connection", ::http::header::HeaderValue::from_static("close"));
- assert!(!should_keep_alive(Version::HTTP_10, &headers));
- assert!(!should_keep_alive(Version::HTTP_11, &headers));
-
- headers.insert("connection", ::http::header::HeaderValue::from_static("keep-alive"));
- assert!(should_keep_alive(Version::HTTP_10, &headers));
- assert!(should_keep_alive(Version::HTTP_11, &headers));
+/// Status of when an Disaptcher future completes.
+pub(crate) enum Dispatched {
+ /// Dispatcher completely shutdown connection.
+ Shutdown,
+ /// Dispatcher has pending upgrade, and so did not shutdown.
+ Upgrade(::upgrade::Pending),
}
-#[test]
-fn test_expecting_continue() {
- let mut headers = HeaderMap::new();
-
- assert!(!expecting_continue(Version::HTTP_10, &headers));
- assert!(!expecting_continue(Version::HTTP_11, &headers));
+/// A separate module to encapsulate the invariants of the DecodedLength type.
+mod body_length {
+ use std::fmt;
+
+ #[derive(Clone, Copy, Debug, PartialEq, Eq)]
+ pub(crate) struct DecodedLength(u64);
+
+ const MAX_LEN: u64 = ::std::u64::MAX - 2;
+
+ impl DecodedLength {
+ pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX);
+ pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1);
+ pub(crate) const ZERO: DecodedLength = DecodedLength(0);
+
+ #[cfg(test)]
+ pub(crate) fn new(len: u64) -> Self {
+ debug_assert!(len <= MAX_LEN);
+ DecodedLength(len)
+ }
+
+ /// Takes the length as a content-length without other checks.
+ ///
+ /// Should only be called if previously confirmed this isn't
+ /// CLOSE_DELIMITED or CHUNKED.
+ #[inline]
+ pub(crate) fn danger_len(self) -> u64 {
+ debug_assert!(self.0 < Self::CHUNKED.0);
+ self.0
+ }
+
+ /// Converts to an Option representing a Known or Unknown length.
+ pub(crate) fn into_opt(self) -> Option {
+ match self {
+ DecodedLength::CHUNKED |
+ DecodedLength::CLOSE_DELIMITED => None,
+ DecodedLength(known) => Some(known)
+ }
+ }
+
+ /// Checks the `u64` is within the maximum allowed for content-length.
+ pub(crate) fn checked_new(len: u64) -> Result {
+ if len <= MAX_LEN {
+ Ok(DecodedLength(len))
+ } else {
+ warn!("content-length bigger than maximum: {} > {}", len, MAX_LEN);
+ Err(::error::Parse::TooLarge)
+ }
+ }
+ }
- headers.insert("expect", ::http::header::HeaderValue::from_static("100-continue"));
- assert!(!expecting_continue(Version::HTTP_10, &headers));
- assert!(expecting_continue(Version::HTTP_11, &headers));
+ impl fmt::Display for DecodedLength {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match *self {
+ DecodedLength::CLOSE_DELIMITED => f.write_str("close-delimited"),
+ DecodedLength::CHUNKED => f.write_str("chunked encoding"),
+ DecodedLength::ZERO => f.write_str("empty"),
+ DecodedLength(n) => write!(f, "content-length ({} bytes)", n),
+ }
+ }
+ }
}
-*/
diff --git a/src/server/conn.rs b/src/server/conn.rs
index 45e0d277d3..5a4d24bee7 100644
--- a/src/server/conn.rs
+++ b/src/server/conn.rs
@@ -9,22 +9,26 @@
//! higher-level [Server](super) API.
use std::fmt;
+use std::mem;
#[cfg(feature = "runtime")] use std::net::SocketAddr;
use std::sync::Arc;
#[cfg(feature = "runtime")] use std::time::Duration;
-use super::rewind::Rewind;
use bytes::Bytes;
use futures::{Async, Future, Poll, Stream};
use futures::future::{Either, Executor};
use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature = "runtime")] use tokio_reactor::Handle;
+use body::{Body, Payload};
use common::Exec;
+use common::io::Rewind;
+use error::{Kind, Parse};
use proto;
-use body::{Body, Payload};
use service::{NewService, Service};
-use error::{Kind, Parse};
+use upgrade::Upgraded;
+
+use self::upgrades::UpgradeableConnection;
#[cfg(feature = "runtime")] pub use super::tcp::AddrIncoming;
@@ -109,6 +113,8 @@ where
fallback: bool,
}
+
+
/// Deconstructed parts of a `Connection`.
///
/// This allows taking apart a `Connection` at a later time, in order to
@@ -429,7 +435,7 @@ where
loop {
let polled = match *self.conn.as_mut().unwrap() {
Either::A(ref mut h1) => h1.poll_without_shutdown(),
- Either::B(ref mut h2) => h2.poll(),
+ Either::B(ref mut h2) => return h2.poll().map(|x| x.map(|_| ())),
};
match polled {
Ok(x) => return Ok(x),
@@ -466,6 +472,18 @@ where
debug_assert!(self.conn.is_none());
self.conn = Some(Either::B(h2));
}
+
+ /// Enable this connection to support higher-level HTTP upgrades.
+ ///
+ /// See [the `upgrade` module](::upgrade) for more.
+ pub fn with_upgrades(self) -> UpgradeableConnection
+ where
+ I: Send,
+ {
+ UpgradeableConnection {
+ inner: self,
+ }
+ }
}
impl Future for Connection
@@ -482,7 +500,15 @@ where
fn poll(&mut self) -> Poll {
loop {
match self.conn.poll() {
- Ok(x) => return Ok(x.map(|o| o.unwrap_or_else(|| ()))),
+ Ok(x) => return Ok(x.map(|opt| {
+ if let Some(proto::Dispatched::Upgrade(pending)) = opt {
+ // With no `Send` bound on `I`, we can't try to do
+ // upgrades here. In case a user was trying to use
+ // `Body::on_upgrade` with this API, send a special
+ // error letting them know about that.
+ pending.manual();
+ }
+ })),
Err(e) => {
debug!("error polling connection protocol: {}", e);
match *e.kind() {
@@ -507,7 +533,6 @@ where
.finish()
}
}
-
// ===== impl Serve =====
impl Serve {
@@ -614,7 +639,7 @@ where
let fut = connecting
.map_err(::Error::new_user_new_service)
// flatten basically
- .and_then(|conn| conn)
+ .and_then(|conn| conn.with_upgrades())
.map_err(|err| debug!("conn error: {}", err));
self.serve.protocol.exec.execute(fut);
} else {
@@ -623,3 +648,82 @@ where
}
}
}
+
+mod upgrades {
+ use super::*;
+
+ // A future binding a connection with a Service with Upgrade support.
+ //
+ // This type is unnameable outside the crate, and so basically just an
+ // `impl Future`, without requiring Rust 1.26.
+ #[must_use = "futures do nothing unless polled"]
+ #[allow(missing_debug_implementations)]
+ pub struct UpgradeableConnection
+ where
+ S: Service,
+ {
+ pub(super) inner: Connection,
+ }
+
+ impl UpgradeableConnection
+ where
+ S: Service + 'static,
+ S::Error: Into>,
+ S::Future: Send,
+ I: AsyncRead + AsyncWrite + Send + 'static,
+ B: Payload + 'static,
+ {
+ /// Start a graceful shutdown process for this connection.
+ ///
+ /// This `Connection` should continue to be polled until shutdown
+ /// can finish.
+ pub fn graceful_shutdown(&mut self) {
+ self.inner.graceful_shutdown()
+ }
+ }
+
+ impl Future for UpgradeableConnection
+ where
+ S: Service + 'static,
+ S::Error: Into>,
+ S::Future: Send,
+ I: AsyncRead + AsyncWrite + Send + 'static,
+ B: Payload + 'static,
+ {
+ type Item = ();
+ type Error = ::Error;
+
+ fn poll(&mut self) -> Poll {
+ loop {
+ match self.inner.conn.poll() {
+ Ok(Async::NotReady) => return Ok(Async::NotReady),
+ Ok(Async::Ready(Some(proto::Dispatched::Shutdown))) |
+ Ok(Async::Ready(None)) => {
+ return Ok(Async::Ready(()));
+ },
+ Ok(Async::Ready(Some(proto::Dispatched::Upgrade(pending)))) => {
+ let h1 = match mem::replace(&mut self.inner.conn, None) {
+ Some(Either::A(h1)) => h1,
+ _ => unreachable!("Upgrade expects h1"),
+ };
+
+ let (io, buf, _) = h1.into_inner();
+ pending.fulfill(Upgraded::new(Box::new(io), buf));
+ return Ok(Async::Ready(()));
+ },
+ Err(e) => {
+ debug!("error polling connection protocol: {}", e);
+ match *e.kind() {
+ Kind::Parse(Parse::VersionH2) if self.inner.fallback => {
+ self.inner.upgrade_h2();
+ continue;
+ }
+ _ => return Err(e),
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 33898c587d..832912fc68 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -50,7 +50,6 @@
pub mod conn;
#[cfg(feature = "runtime")] mod tcp;
-mod rewind;
use std::fmt;
#[cfg(feature = "runtime")] use std::net::SocketAddr;
diff --git a/src/upgrade.rs b/src/upgrade.rs
new file mode 100644
index 0000000000..db592b1a42
--- /dev/null
+++ b/src/upgrade.rs
@@ -0,0 +1,254 @@
+//! HTTP Upgrades
+//!
+//! See [this example][example] showing how upgrades work with both
+//! Clients and Servers.
+//!
+//! [example]: https://github.com/hyperium/hyper/master/examples/upgrades.rs
+
+use std::any::TypeId;
+use std::error::Error as StdError;
+use std::fmt;
+use std::io::{self, Read, Write};
+
+use bytes::{Buf, BufMut, Bytes};
+use futures::{Async, Future, Poll};
+use futures::sync::oneshot;
+use tokio_io::{AsyncRead, AsyncWrite};
+
+use common::io::Rewind;
+
+/// An upgraded HTTP connection.
+///
+/// This type holds a trait object internally of the original IO that
+/// was used to speak HTTP before the upgrade. It can be used directly
+/// as a `Read` or `Write` for convenience.
+///
+/// Alternatively, if the exact type is known, this can be deconstructed
+/// into its parts.
+pub struct Upgraded {
+ io: Rewind>,
+}
+
+/// A future for a possible HTTP upgrade.
+///
+/// If no upgrade was available, or it doesn't succeed, yields an `Error`.
+pub struct OnUpgrade {
+ rx: Option>>,
+}
+
+/// The deconstructed parts of an [`Upgraded`](Upgraded) type.
+///
+/// Includes the original IO type, and a read buffer of bytes that the
+/// HTTP state machine may have already read before completing an upgrade.
+#[derive(Debug)]
+pub struct Parts {
+ /// The original IO object used before the upgrade.
+ pub io: T,
+ /// A buffer of bytes that have been read but not processed as HTTP.
+ ///
+ /// For instance, if the `Connection` is used for an HTTP upgrade request,
+ /// it is possible the server sent back the first bytes of the new protocol
+ /// along with the response upgrade.
+ ///
+ /// You will want to check for any existing bytes if you plan to continue
+ /// communicating on the IO object.
+ pub read_buf: Bytes,
+ _inner: (),
+}
+
+pub(crate) struct Pending {
+ tx: oneshot::Sender<::Result>
+}
+
+/// Error cause returned when an upgrade was expected but canceled
+/// for whatever reason.
+///
+/// This likely means the actual `Conn` future wasn't polled and upgraded.
+#[derive(Debug)]
+struct UpgradeExpected(());
+
+pub(crate) fn pending() -> (Pending, OnUpgrade) {
+ let (tx, rx) = oneshot::channel();
+ (
+ Pending {
+ tx,
+ },
+ OnUpgrade {
+ rx: Some(rx),
+ },
+ )
+}
+
+pub(crate) trait Io: AsyncRead + AsyncWrite + 'static {
+ fn __hyper_type_id(&self) -> TypeId {
+ TypeId::of::()
+ }
+}
+
+impl Io + Send {
+ fn __hyper_is(&self) -> bool {
+ let t = TypeId::of::();
+ self.__hyper_type_id() == t
+ }
+
+ fn __hyper_downcast(self: Box) -> Result, Box> {
+ if self.__hyper_is::() {
+ // Taken from `std::error::Error::downcast()`.
+ unsafe {
+ let raw: *mut Io = Box::into_raw(self);
+ Ok(Box::from_raw(raw as *mut T))
+ }
+ } else {
+ Err(self)
+ }
+ }
+}
+
+impl Io for T {}
+
+// ===== impl Upgraded =====
+
+impl Upgraded {
+ pub(crate) fn new(io: Box, read_buf: Bytes) -> Self {
+ Upgraded {
+ io: Rewind::new_buffered(io, read_buf),
+ }
+ }
+
+ /// Tries to downcast the internal trait object to the type passed.
+ ///
+ /// On success, returns the downcasted parts. On error, returns the
+ /// `Upgraded` back.
+ pub fn downcast(self) -> Result, Self> {
+ let (io, buf) = self.io.into_inner();
+ match io.__hyper_downcast() {
+ Ok(t) => Ok(Parts {
+ io: *t,
+ read_buf: buf,
+ _inner: (),
+ }),
+ Err(io) => Err(Upgraded {
+ io: Rewind::new_buffered(io, buf),
+ })
+ }
+ }
+}
+
+impl Read for Upgraded {
+ #[inline]
+ fn read(&mut self, buf: &mut [u8]) -> io::Result {
+ self.io.read(buf)
+ }
+}
+
+impl Write for Upgraded {
+ #[inline]
+ fn write(&mut self, buf: &[u8]) -> io::Result {
+ self.io.write(buf)
+ }
+
+ #[inline]
+ fn flush(&mut self) -> io::Result<()> {
+ self.io.flush()
+ }
+}
+
+impl AsyncRead for Upgraded {
+ #[inline]
+ unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
+ self.io.prepare_uninitialized_buffer(buf)
+ }
+
+ #[inline]
+ fn read_buf(&mut self, buf: &mut B) -> Poll {
+ self.io.read_buf(buf)
+ }
+}
+
+impl AsyncWrite for Upgraded {
+ #[inline]
+ fn shutdown(&mut self) -> Poll<(), io::Error> {
+ AsyncWrite::shutdown(&mut self.io)
+ }
+
+ #[inline]
+ fn write_buf(&mut self, buf: &mut B) -> Poll {
+ self.io.write_buf(buf)
+ }
+}
+
+impl fmt::Debug for Upgraded {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.debug_struct("Upgraded")
+ .finish()
+ }
+}
+
+// ===== impl OnUpgrade =====
+
+impl OnUpgrade {
+ pub(crate) fn none() -> Self {
+ OnUpgrade {
+ rx: None,
+ }
+ }
+
+ pub(crate) fn is_none(&self) -> bool {
+ self.rx.is_none()
+ }
+}
+
+impl Future for OnUpgrade {
+ type Item = Upgraded;
+ type Error = ::Error;
+
+ fn poll(&mut self) -> Poll {
+ match self.rx {
+ Some(ref mut rx) => match rx.poll() {
+ Ok(Async::NotReady) => Ok(Async::NotReady),
+ Ok(Async::Ready(Ok(upgraded))) => Ok(Async::Ready(upgraded)),
+ Ok(Async::Ready(Err(err))) => Err(err),
+ Err(_oneshot_canceled) => Err(
+ ::Error::new_canceled(Some(UpgradeExpected(())))
+ ),
+ },
+ None => Err(::Error::new_user_no_upgrade()),
+ }
+ }
+}
+
+impl fmt::Debug for OnUpgrade {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.debug_struct("OnUpgrade")
+ .finish()
+ }
+}
+
+// ===== impl Pending =====
+
+impl Pending {
+ pub(crate) fn fulfill(self, upgraded: Upgraded) {
+ let _ = self.tx.send(Ok(upgraded));
+ }
+
+ /// Don't fulfill the pending Upgrade, but instead signal that
+ /// upgrades are handled manually.
+ pub(crate) fn manual(self) {
+ let _ = self.tx.send(Err(::Error::new_user_manual_upgrade()));
+ }
+}
+
+// ===== impl UpgradeExpected =====
+
+impl fmt::Display for UpgradeExpected {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str(self.description())
+ }
+}
+
+impl StdError for UpgradeExpected {
+ fn description(&self) -> &str {
+ "upgrade expected but not completed"
+ }
+}
+
diff --git a/tests/client.rs b/tests/client.rs
index 5ea5f5a0a7..119533f4ea 100644
--- a/tests/client.rs
+++ b/tests/client.rs
@@ -596,33 +596,6 @@ test! {
body: None,
}
-
-test! {
- name: client_101_upgrade,
-
- server:
- expected: "\
- GET /upgrade HTTP/1.1\r\n\
- host: {addr}\r\n\
- \r\n\
- ",
- reply: "\
- HTTP/1.1 101 Switching Protocols\r\n\
- Upgrade: websocket\r\n\
- Connection: upgrade\r\n\
- \r\n\
- ",
-
- client:
- request:
- method: GET,
- url: "http://{addr}/upgrade",
- headers: {},
- body: None,
- error: |err| err.to_string() == "unsupported protocol upgrade",
-
-}
-
test! {
name: client_connect_method,
@@ -1277,6 +1250,68 @@ mod dispatch_impl {
res.join(rx).map(|r| r.0).wait().unwrap();
}
+ #[test]
+ fn client_upgrade() {
+ use tokio_io::io::{read_to_end, write_all};
+ let _ = pretty_env_logger::try_init();
+ let server = TcpListener::bind("127.0.0.1:0").unwrap();
+ let addr = server.local_addr().unwrap();
+ let runtime = Runtime::new().unwrap();
+ let handle = runtime.reactor();
+
+ let connector = DebugConnector::new(&handle);
+
+ let client = Client::builder()
+ .executor(runtime.executor())
+ .build(connector);
+
+ let (tx1, rx1) = oneshot::channel();
+ thread::spawn(move || {
+ let mut sock = server.accept().unwrap().0;
+ sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
+ sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
+ let mut buf = [0; 4096];
+ sock.read(&mut buf).expect("read 1");
+ sock.write_all(b"\
+ HTTP/1.1 101 Switching Protocols\r\n\
+ Upgrade: foobar\r\n\
+ \r\n\
+ foobar=ready\
+ ").unwrap();
+ let _ = tx1.send(());
+
+ let n = sock.read(&mut buf).expect("read 2");
+ assert_eq!(&buf[..n], b"foo=bar");
+ sock.write_all(b"bar=foo").expect("write 2");
+ });
+
+ let rx = rx1.expect("thread panicked");
+
+ let req = Request::builder()
+ .method("GET")
+ .uri(&*format!("http://{}/up", addr))
+ .body(Body::empty())
+ .unwrap();
+
+ let res = client.request(req);
+ let res = res.join(rx).map(|r| r.0).wait().unwrap();
+
+ assert_eq!(res.status(), 101);
+ let upgraded = res
+ .into_body()
+ .on_upgrade()
+ .wait()
+ .expect("on_upgrade");
+
+ let parts = upgraded.downcast::().unwrap();
+ assert_eq!(s(&parts.read_buf), "foobar=ready");
+
+ let io = parts.io;
+ let io = write_all(io, b"foo=bar").wait().unwrap().0;
+ let vec = read_to_end(io, vec![]).wait().unwrap().1;
+ assert_eq!(vec, b"bar=foo");
+ }
+
struct DebugConnector {
http: HttpConnector,
diff --git a/tests/server.rs b/tests/server.rs
index 9d4ece809d..7979b65e5f 100644
--- a/tests/server.rs
+++ b/tests/server.rs
@@ -24,7 +24,7 @@ use futures::future::{self, FutureResult, Either};
use futures::sync::oneshot;
use futures_timer::Delay;
use http::header::{HeaderName, HeaderValue};
-use tokio::net::TcpListener;
+use tokio::net::{TcpListener, TcpStream as TkTcpStream};
use tokio::runtime::Runtime;
use tokio::reactor::Handle;
use tokio_io::{AsyncRead, AsyncWrite};
@@ -33,7 +33,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use hyper::{Body, Request, Response, StatusCode};
use hyper::client::Client;
use hyper::server::conn::Http;
-use hyper::service::{service_fn, Service};
+use hyper::service::{service_fn, service_fn_ok, Service};
fn tcp_bind(addr: &SocketAddr, handle: &Handle) -> ::tokio::io::Result {
let std_listener = StdTcpListener::bind(addr).unwrap();
@@ -1270,6 +1270,142 @@ fn http_connect() {
assert_eq!(vec, b"bar=foo");
}
+#[test]
+fn upgrades_new() {
+ use tokio_io::io::{read_to_end, write_all};
+ let _ = pretty_env_logger::try_init();
+ let mut rt = Runtime::new().unwrap();
+ let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap(), &rt.reactor()).unwrap();
+ let addr = listener.local_addr().unwrap();
+ let (read_101_tx, read_101_rx) = oneshot::channel();
+
+ thread::spawn(move || {
+ let mut tcp = connect(&addr);
+ tcp.write_all(b"\
+ GET / HTTP/1.1\r\n\
+ Upgrade: foobar\r\n\
+ Connection: upgrade\r\n\
+ \r\n\
+ eagerly optimistic\
+ ").expect("write 1");
+ let mut buf = [0; 256];
+ tcp.read(&mut buf).expect("read 1");
+
+ let expected = "HTTP/1.1 101 Switching Protocols\r\n";
+ assert_eq!(s(&buf[..expected.len()]), expected);
+ let _ = read_101_tx.send(());
+
+ let n = tcp.read(&mut buf).expect("read 2");
+ assert_eq!(s(&buf[..n]), "foo=bar");
+ tcp.write_all(b"bar=foo").expect("write 2");
+ });
+
+ let (upgrades_tx, upgrades_rx) = mpsc::channel();
+ let svc = service_fn_ok(move |req: Request| {
+ let on_upgrade = req
+ .into_body()
+ .on_upgrade();
+ let _ = upgrades_tx.send(on_upgrade);
+ Response::builder()
+ .status(101)
+ .header("upgrade", "foobar")
+ .body(hyper::Body::empty())
+ .unwrap()
+ });
+
+ let fut = listener.incoming()
+ .into_future()
+ .map_err(|_| -> hyper::Error { unreachable!() })
+ .and_then(move |(item, _incoming)| {
+ let socket = item.unwrap();
+ Http::new()
+ .serve_connection(socket, svc)
+ .with_upgrades()
+ });
+
+ rt.block_on(fut).unwrap();
+ let on_upgrade = upgrades_rx.recv().unwrap();
+
+ // wait so that we don't write until other side saw 101 response
+ rt.block_on(read_101_rx).unwrap();
+
+ let upgraded = rt.block_on(on_upgrade).unwrap();
+ let parts = upgraded.downcast::().unwrap();
+ let io = parts.io;
+ assert_eq!(parts.read_buf, "eagerly optimistic");
+
+ let io = rt.block_on(write_all(io, b"foo=bar")).unwrap().0;
+ let vec = rt.block_on(read_to_end(io, vec![])).unwrap().1;
+ assert_eq!(s(&vec), "bar=foo");
+}
+
+#[test]
+fn http_connect_new() {
+ use tokio_io::io::{read_to_end, write_all};
+ let _ = pretty_env_logger::try_init();
+ let mut rt = Runtime::new().unwrap();
+ let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap(), &rt.reactor()).unwrap();
+ let addr = listener.local_addr().unwrap();
+ let (read_200_tx, read_200_rx) = oneshot::channel();
+
+ thread::spawn(move || {
+ let mut tcp = connect(&addr);
+ tcp.write_all(b"\
+ CONNECT localhost HTTP/1.1\r\n\
+ \r\n\
+ eagerly optimistic\
+ ").expect("write 1");
+ let mut buf = [0; 256];
+ tcp.read(&mut buf).expect("read 1");
+
+ let expected = "HTTP/1.1 200 OK\r\n";
+ assert_eq!(s(&buf[..expected.len()]), expected);
+ let _ = read_200_tx.send(());
+
+ let n = tcp.read(&mut buf).expect("read 2");
+ assert_eq!(s(&buf[..n]), "foo=bar");
+ tcp.write_all(b"bar=foo").expect("write 2");
+ });
+
+ let (upgrades_tx, upgrades_rx) = mpsc::channel();
+ let svc = service_fn_ok(move |req: Request| {
+ let on_upgrade = req
+ .into_body()
+ .on_upgrade();
+ let _ = upgrades_tx.send(on_upgrade);
+ Response::builder()
+ .status(200)
+ .body(hyper::Body::empty())
+ .unwrap()
+ });
+
+ let fut = listener.incoming()
+ .into_future()
+ .map_err(|_| -> hyper::Error { unreachable!() })
+ .and_then(move |(item, _incoming)| {
+ let socket = item.unwrap();
+ Http::new()
+ .serve_connection(socket, svc)
+ .with_upgrades()
+ });
+
+ rt.block_on(fut).unwrap();
+ let on_upgrade = upgrades_rx.recv().unwrap();
+
+ // wait so that we don't write until other side saw 200
+ rt.block_on(read_200_rx).unwrap();
+
+ let upgraded = rt.block_on(on_upgrade).unwrap();
+ let parts = upgraded.downcast::().unwrap();
+ let io = parts.io;
+ assert_eq!(parts.read_buf, "eagerly optimistic");
+
+ let io = rt.block_on(write_all(io, b"foo=bar")).unwrap().0;
+ let vec = rt.block_on(read_to_end(io, vec![])).unwrap().1;
+ assert_eq!(s(&vec), "bar=foo");
+}
+
+
#[test]
fn parse_errors_send_4xx_response() {
let runtime = Runtime::new().unwrap();