diff --git a/Cargo.toml b/Cargo.toml index dac533117a..81b49cbfdd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ iovec = "0.1" log = "0.4" net2 = { version = "0.2.32", optional = true } time = "0.1" -tokio = { version = "0.1.5", optional = true } +tokio = { version = "0.1.7", optional = true } tokio-executor = { version = "0.1.0", optional = true } tokio-io = "0.1" tokio-reactor = { version = "0.1", optional = true } @@ -101,6 +101,12 @@ name = "send_file" path = "examples/send_file.rs" required-features = ["runtime"] +[[example]] +name = "upgrades" +path = "examples/upgrades.rs" +required-features = ["runtime"] + + [[example]] name = "web_api" path = "examples/web_api.rs" diff --git a/examples/README.md b/examples/README.md index 90a425235f..b877ef84a8 100644 --- a/examples/README.md +++ b/examples/README.md @@ -16,4 +16,6 @@ Run examples with `cargo run --example example_name`. * [`send_file`](send_file.rs) - A server that sends back content of files using tokio_fs to read the files asynchronously. +* [`upgrades`](upgrades.rs) - A server and client demonstrating how to do HTTP upgrades (such as WebSockets or `CONNECT` tunneling). + * [`web_api`](web_api.rs) - A server consisting in a service that returns incoming POST request's content in the response in uppercase and a service that call that call the first service and includes the first service response in its own response. diff --git a/examples/upgrades.rs b/examples/upgrades.rs new file mode 100644 index 0000000000..3cee4940a9 --- /dev/null +++ b/examples/upgrades.rs @@ -0,0 +1,127 @@ +// Note: `hyper::upgrade` docs link to this upgrade. +extern crate futures; +extern crate hyper; +extern crate tokio; + +use std::str; + +use futures::sync::oneshot; + +use hyper::{Body, Client, Request, Response, Server, StatusCode}; +use hyper::header::{UPGRADE, HeaderValue}; +use hyper::rt::{self, Future}; +use hyper::service::service_fn_ok; + +/// Our server HTTP handler to initiate HTTP upgrades. +fn server_upgrade(req: Request) -> Response { + let mut res = Response::new(Body::empty()); + + // Send a 400 to any request that doesn't have + // an `Upgrade` header. + if !req.headers().contains_key(UPGRADE) { + *res.status_mut() = StatusCode::BAD_REQUEST; + return res; + } + + // Setup a future that will eventually receive the upgraded + // connection and talk a new protocol, and spawn the future + // into the runtime. + // + // Note: This can't possibly be fulfilled until the 101 response + // is returned below, so it's better to spawn this future instead + // waiting for it to complete to then return a response. + let on_upgrade = req + .into_body() + .on_upgrade() + .map_err(|err| eprintln!("upgrade error: {}", err)) + .and_then(|upgraded| { + // We have an upgraded connection that we can read and + // write on directly. + // + // Since we completely control this example, we know exactly + // how many bytes the client will write, so just read exact... + tokio::io::read_exact(upgraded, vec![0; 7]) + .and_then(|(upgraded, vec)| { + println!("server[foobar] recv: {:?}", str::from_utf8(&vec)); + + // And now write back the server 'foobar' protocol's + // response... + tokio::io::write_all(upgraded, b"bar=foo") + }) + .map(|_| println!("server[foobar] sent")) + .map_err(|e| eprintln!("server foobar io error: {}", e)) + }); + + rt::spawn(on_upgrade); + + + // Now return a 101 Response saying we agree to the upgrade to some + // made-up 'foobar' protocol. + *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; + res.headers_mut().insert(UPGRADE, HeaderValue::from_static("foobar")); + res +} + +fn main() { + // For this example, we just make a server and our own client to talk to + // it, so the exact port isn't important. Instead, let the OS give us an + // unused port. + let addr = ([127, 0, 0, 1], 0).into(); + + let server = Server::bind(&addr) + .serve(|| service_fn_ok(server_upgrade)); + + // We need the assigned address for the client to send it messages. + let addr = server.local_addr(); + + + // For this example, a oneshot is used to signal that after 1 request, + // the server should be shutdown. + let (tx, rx) = oneshot::channel(); + + let server = server + .map_err(|e| eprintln!("server error: {}", e)) + .select2(rx) + .then(|_| Ok(())); + + rt::run(rt::lazy(move || { + rt::spawn(server); + + let req = Request::builder() + .uri(format!("http://{}/", addr)) + .header(UPGRADE, "foobar") + .body(Body::empty()) + .unwrap(); + + Client::new() + .request(req) + .and_then(|res| { + if res.status() != StatusCode::SWITCHING_PROTOCOLS { + panic!("Our server didn't upgrade: {}", res.status()); + } + + res + .into_body() + .on_upgrade() + }) + .map_err(|e| eprintln!("client error: {}", e)) + .and_then(|upgraded| { + // We've gotten an upgraded connection that we can read + // and write directly on. Let's start out 'foobar' protocol. + tokio::io::write_all(upgraded, b"foo=bar") + .and_then(|(upgraded, _)| { + println!("client[foobar] sent"); + tokio::io::read_to_end(upgraded, Vec::new()) + }) + .map(|(_upgraded, vec)| { + println!("client[foobar] recv: {:?}", str::from_utf8(&vec)); + + + // Complete the oneshot so that the server stops + // listening and the process can close down. + let _ = tx.send(()); + }) + .map_err(|e| eprintln!("client foobar io error: {}", e)) + }) + })); +} diff --git a/src/body/body.rs b/src/body/body.rs index 7be740bc9f..09f36b9b45 100644 --- a/src/body/body.rs +++ b/src/body/body.rs @@ -10,6 +10,7 @@ use http::HeaderMap; use common::Never; use super::{Chunk, Payload}; use super::internal::{FullDataArg, FullDataRet}; +use upgrade::OnUpgrade; type BodySender = mpsc::Sender>; @@ -21,15 +22,9 @@ type BodySender = mpsc::Sender>; #[must_use = "streams do nothing unless polled"] pub struct Body { kind: Kind, - /// Allow the client to pass a future to delay the `Body` from returning - /// EOF. This allows the `Client` to try to put the idle connection - /// back into the pool before the body is "finished". - /// - /// The reason for this is so that creating a new request after finishing - /// streaming the body of a response could sometimes result in creating - /// a brand new connection, since the pool didn't know about the idle - /// connection yet. - delayed_eof: Option, + /// Keep the extra bits in an `Option>`, so that + /// Body stays small in the common case (no extras needed). + extra: Option>, } enum Kind { @@ -43,6 +38,19 @@ enum Kind { Wrapped(Box> + Send>), } +struct Extra { + /// Allow the client to pass a future to delay the `Body` from returning + /// EOF. This allows the `Client` to try to put the idle connection + /// back into the pool before the body is "finished". + /// + /// The reason for this is so that creating a new request after finishing + /// streaming the body of a response could sometimes result in creating + /// a brand new connection, since the pool didn't know about the idle + /// connection yet. + delayed_eof: Option, + on_upgrade: OnUpgrade, +} + type DelayEofUntil = oneshot::Receiver; enum DelayEof { @@ -89,7 +97,6 @@ impl Body { Self::new_channel(None) } - #[inline] pub(crate) fn new_channel(content_length: Option) -> (Sender, Body) { let (tx, rx) = mpsc::channel(0); let (abort_tx, abort_rx) = oneshot::channel(); @@ -139,10 +146,20 @@ impl Body { Body::new(Kind::Wrapped(Box::new(mapped))) } + /// Converts this `Body` into a `Future` of a pending HTTP upgrade. + /// + /// See [the `upgrade` module](::upgrade) for more. + pub fn on_upgrade(self) -> OnUpgrade { + self + .extra + .map(|ex| ex.on_upgrade) + .unwrap_or_else(OnUpgrade::none) + } + fn new(kind: Kind) -> Body { Body { kind: kind, - delayed_eof: None, + extra: None, } } @@ -150,23 +167,46 @@ impl Body { Body::new(Kind::H2(recv)) } + pub(crate) fn set_on_upgrade(&mut self, upgrade: OnUpgrade) { + debug_assert!(!upgrade.is_none(), "set_on_upgrade with empty upgrade"); + let extra = self.extra_mut(); + debug_assert!(extra.on_upgrade.is_none(), "set_on_upgrade twice"); + extra.on_upgrade = upgrade; + } + pub(crate) fn delayed_eof(&mut self, fut: DelayEofUntil) { - self.delayed_eof = Some(DelayEof::NotEof(fut)); + self.extra_mut().delayed_eof = Some(DelayEof::NotEof(fut)); + } + + fn take_delayed_eof(&mut self) -> Option { + self + .extra + .as_mut() + .and_then(|extra| extra.delayed_eof.take()) + } + + fn extra_mut(&mut self) -> &mut Extra { + self + .extra + .get_or_insert_with(|| Box::new(Extra { + delayed_eof: None, + on_upgrade: OnUpgrade::none(), + })) } fn poll_eof(&mut self) -> Poll, ::Error> { - match self.delayed_eof.take() { + match self.take_delayed_eof() { Some(DelayEof::NotEof(mut delay)) => { match self.poll_inner() { ok @ Ok(Async::Ready(Some(..))) | ok @ Ok(Async::NotReady) => { - self.delayed_eof = Some(DelayEof::NotEof(delay)); + self.extra_mut().delayed_eof = Some(DelayEof::NotEof(delay)); ok }, Ok(Async::Ready(None)) => match delay.poll() { Ok(Async::Ready(never)) => match never {}, Ok(Async::NotReady) => { - self.delayed_eof = Some(DelayEof::Eof(delay)); + self.extra_mut().delayed_eof = Some(DelayEof::Eof(delay)); Ok(Async::NotReady) }, Err(_done) => { @@ -180,7 +220,7 @@ impl Body { match delay.poll() { Ok(Async::Ready(never)) => match never {}, Ok(Async::NotReady) => { - self.delayed_eof = Some(DelayEof::Eof(delay)); + self.extra_mut().delayed_eof = Some(DelayEof::Eof(delay)); Ok(Async::NotReady) }, Err(_done) => { diff --git a/src/client/conn.rs b/src/client/conn.rs index e35d435c6e..e136fda350 100644 --- a/src/client/conn.rs +++ b/src/client/conn.rs @@ -9,6 +9,7 @@ //! higher-level [Client](super) API. use std::fmt; use std::marker::PhantomData; +use std::mem; use bytes::Bytes; use futures::{Async, Future, Poll}; @@ -17,9 +18,21 @@ use tokio_io::{AsyncRead, AsyncWrite}; use body::Payload; use common::Exec; +use upgrade::Upgraded; use proto; use super::dispatch; -use {Body, Request, Response, StatusCode}; +use {Body, Request, Response}; + +type Http1Dispatcher = proto::dispatch::Dispatcher< + proto::dispatch::Client, + B, + T, + R, +>; +type ConnEither = Either< + Http1Dispatcher, + proto::h2::Client, +>; /// Returns a `Handshake` future over some IO. /// @@ -48,15 +61,7 @@ where T: AsyncRead + AsyncWrite + Send + 'static, B: Payload + 'static, { - inner: Either< - proto::dispatch::Dispatcher< - proto::dispatch::Client, - B, - T, - proto::ClientUpgradeTransaction, - >, - proto::h2::Client, - >, + inner: Option>, } @@ -76,7 +81,9 @@ pub struct Builder { /// If successful, yields a `(SendRequest, Connection)` pair. #[must_use = "futures do nothing unless polled"] pub struct Handshake { - inner: HandshakeInner, + builder: Builder, + io: Option, + _marker: PhantomData, } /// A future returned by `SendRequest::send_request`. @@ -112,27 +119,18 @@ pub struct Parts { // ========== internal client api /// A `Future` for when `SendRequest::poll_ready()` is ready. +#[must_use = "futures do nothing unless polled"] pub(super) struct WhenReady { tx: Option>, } // A `SendRequest` that can be cloned to send HTTP2 requests. // private for now, probably not a great idea of a type... +#[must_use = "futures do nothing unless polled"] pub(super) struct Http2SendRequest { dispatch: dispatch::UnboundedSender, Response>, } -#[must_use = "futures do nothing unless polled"] -pub(super) struct HandshakeNoUpgrades { - inner: HandshakeInner, -} - -struct HandshakeInner { - builder: Builder, - io: Option, - _marker: PhantomData<(B, R)>, -} - // ===== impl SendRequest impl SendRequest @@ -354,7 +352,7 @@ where /// /// Only works for HTTP/1 connections. HTTP/2 connections will panic. pub fn into_parts(self) -> Parts { - let (io, read_buf, _) = match self.inner { + let (io, read_buf, _) = match self.inner.expect("already upgraded") { Either::A(h1) => h1.into_inner(), Either::B(_h2) => { panic!("http2 cannot into_inner"); @@ -376,12 +374,12 @@ where /// but it is not desired to actally shutdown the IO object. Instead you /// would take it back using `into_parts`. pub fn poll_without_shutdown(&mut self) -> Poll<(), ::Error> { - match self.inner { - Either::A(ref mut h1) => { + match self.inner.as_mut().expect("already upgraded") { + &mut Either::A(ref mut h1) => { h1.poll_without_shutdown() }, - Either::B(ref mut h2) => { - h2.poll() + &mut Either::B(ref mut h2) => { + h2.poll().map(|x| x.map(|_| ())) } } } @@ -396,7 +394,22 @@ where type Error = ::Error; fn poll(&mut self) -> Poll { - self.inner.poll() + match try_ready!(self.inner.poll()) { + Some(proto::Dispatched::Shutdown) | + None => { + Ok(Async::Ready(())) + }, + Some(proto::Dispatched::Upgrade(pending)) => { + let h1 = match mem::replace(&mut self.inner, None) { + Some(Either::A(h1)) => h1, + _ => unreachable!("Upgrade expects h1"), + }; + + let (io, buf, _) = h1.into_inner(); + pending.fulfill(Upgraded::new(Box::new(io), buf)); + Ok(Async::Ready(())) + } + } } } @@ -456,25 +469,9 @@ impl Builder { B: Payload + 'static, { Handshake { - inner: HandshakeInner { - builder: self.clone(), - io: Some(io), - _marker: PhantomData, - } - } - } - - pub(super) fn handshake_no_upgrades(&self, io: T) -> HandshakeNoUpgrades - where - T: AsyncRead + AsyncWrite + Send + 'static, - B: Payload + 'static, - { - HandshakeNoUpgrades { - inner: HandshakeInner { - builder: self.clone(), - io: Some(io), - _marker: PhantomData, - } + builder: self.clone(), + io: Some(io), + _marker: PhantomData, } } } @@ -489,64 +486,6 @@ where type Item = (SendRequest, Connection); type Error = ::Error; - fn poll(&mut self) -> Poll { - self.inner.poll() - .map(|async| { - async.map(|(tx, dispatch)| { - (tx, Connection { inner: dispatch }) - }) - }) - } -} - -impl fmt::Debug for Handshake { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Handshake") - .finish() - } -} - -impl Future for HandshakeNoUpgrades -where - T: AsyncRead + AsyncWrite + Send + 'static, - B: Payload + 'static, -{ - type Item = (SendRequest, Either< - proto::h1::Dispatcher< - proto::h1::dispatch::Client, - B, - T, - proto::ClientTransaction, - >, - proto::h2::Client, - >); - type Error = ::Error; - - fn poll(&mut self) -> Poll { - self.inner.poll() - } -} - -impl Future for HandshakeInner -where - T: AsyncRead + AsyncWrite + Send + 'static, - B: Payload, - R: proto::h1::Http1Transaction< - Incoming=StatusCode, - Outgoing=proto::RequestLine, - >, -{ - type Item = (SendRequest, Either< - proto::h1::Dispatcher< - proto::h1::dispatch::Client, - B, - T, - R, - >, - proto::h2::Client, - >); - type Error = ::Error; - fn poll(&mut self) -> Poll { let io = self.io.take().expect("polled more than once"); let (tx, rx) = dispatch::channel(); @@ -570,11 +509,20 @@ where SendRequest { dispatch: tx, }, - either, + Connection { + inner: Some(either), + }, ))) } } +impl fmt::Debug for Handshake { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Handshake") + .finish() + } +} + // ===== impl ResponseFuture impl Future for ResponseFuture { diff --git a/src/client/mod.rs b/src/client/mod.rs index 7cb7c0da56..d822ec9f77 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -268,7 +268,7 @@ where C: Connect + Sync + 'static, .h1_writev(h1_writev) .h1_title_case_headers(h1_title_case_headers) .http2_only(pool_key.1 == Ver::Http2) - .handshake_no_upgrades(io) + .handshake(io) .and_then(move |(tx, conn)| { executor.execute(conn.map_err(|e| { debug!("client connection error: {}", e) diff --git a/src/common/io/mod.rs b/src/common/io/mod.rs new file mode 100644 index 0000000000..2e6d506153 --- /dev/null +++ b/src/common/io/mod.rs @@ -0,0 +1,3 @@ +mod rewind; + +pub(crate) use self::rewind::Rewind; diff --git a/src/server/rewind.rs b/src/common/io/rewind.rs similarity index 92% rename from src/server/rewind.rs rename to src/common/io/rewind.rs index 6d2bf90eda..797dad9a74 100644 --- a/src/server/rewind.rs +++ b/src/common/io/rewind.rs @@ -1,26 +1,40 @@ +use std::cmp; +use std::io::{self, Read, Write}; + use bytes::{Buf, BufMut, Bytes, IntoBuf}; use futures::{Async, Poll}; -use std::io::{self, Read, Write}; -use std::cmp; use tokio_io::{AsyncRead, AsyncWrite}; +/// Combine a buffer with an IO, rewinding reads to use the buffer. #[derive(Debug)] -pub struct Rewind { +pub(crate) struct Rewind { pre: Option, inner: T, } impl Rewind { - pub(super) fn new(tcp: T) -> Rewind { + pub(crate) fn new(io: T) -> Self { Rewind { pre: None, - inner: tcp, + inner: io, + } + } + + pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { + Rewind { + pre: Some(buf), + inner: io, } } - pub fn rewind(&mut self, bs: Bytes) { + + pub(crate) fn rewind(&mut self, bs: Bytes) { debug_assert!(self.pre.is_none()); self.pre = Some(bs); } + + pub(crate) fn into_inner(self) -> (T, Bytes) { + (self.inner, self.pre.unwrap_or_else(Bytes::new)) + } } impl Read for Rewind diff --git a/src/common/mod.rs b/src/common/mod.rs index 124b2100e8..1bfa980fc9 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,5 +1,6 @@ mod buf; mod exec; +pub(crate) mod io; mod never; pub(crate) use self::buf::StaticBuf; diff --git a/src/error.rs b/src/error.rs index ab65e75523..337bea0c5b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -61,6 +61,12 @@ pub(crate) enum Kind { UnsupportedVersion, /// User tried to create a CONNECT Request with the Client. UnsupportedRequestMethod, + + /// User tried polling for an upgrade that doesn't exist. + NoUpgrade, + + /// User polled for an upgrade, but low-level API is not using upgrades. + ManualUpgrade, } #[derive(Debug, PartialEq)] @@ -72,9 +78,6 @@ pub(crate) enum Parse { Header, TooLarge, Status, - - /// A protocol upgrade was encountered, but not yet supported in hyper. - UpgradeNotSupported, } /* @@ -110,7 +113,8 @@ impl Error { Kind::Service | Kind::Closed | Kind::UnsupportedVersion | - Kind::UnsupportedRequestMethod => true, + Kind::UnsupportedRequestMethod | + Kind::NoUpgrade => true, _ => false, } } @@ -216,6 +220,14 @@ impl Error { Error::new(Kind::UnsupportedRequestMethod, None) } + pub(crate) fn new_user_no_upgrade() -> Error { + Error::new(Kind::NoUpgrade, None) + } + + pub(crate) fn new_user_manual_upgrade() -> Error { + Error::new(Kind::ManualUpgrade, None) + } + pub(crate) fn new_user_new_service>(cause: E) -> Error { Error::new(Kind::NewService, Some(cause.into())) } @@ -266,7 +278,6 @@ impl StdError for Error { Kind::Parse(Parse::Header) => "invalid Header provided", Kind::Parse(Parse::TooLarge) => "message head is too large", Kind::Parse(Parse::Status) => "invalid Status provided", - Kind::Parse(Parse::UpgradeNotSupported) => "unsupported protocol upgrade", Kind::Incomplete => "message is incomplete", Kind::MismatchedResponse => "response received without matching request", Kind::Closed => "connection closed", @@ -284,6 +295,8 @@ impl StdError for Error { Kind::Http2 => "http2 general error", Kind::UnsupportedVersion => "request has unsupported HTTP version", Kind::UnsupportedRequestMethod => "request has unsupported HTTP method", + Kind::NoUpgrade => "no upgrade available", + Kind::ManualUpgrade => "upgrade expected but low level API in use", Kind::Io => "an IO error occurred", } diff --git a/src/lib.rs b/src/lib.rs index 968a1f772a..adddcbcb38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,3 +62,4 @@ mod proto; pub mod server; pub mod service; #[cfg(feature = "runtime")] pub mod rt; +pub mod upgrade; diff --git a/src/mock.rs b/src/mock.rs index d8147c2cc9..b3c239c52d 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -79,6 +79,7 @@ pub struct AsyncIo { inner: T, max_read_vecs: usize, num_writes: usize, + panic: bool, park_tasks: bool, task: Option, } @@ -93,6 +94,7 @@ impl AsyncIo { inner: inner, max_read_vecs: READ_VECS_CNT, num_writes: 0, + panic: false, park_tasks: false, task: None, } @@ -110,6 +112,11 @@ impl AsyncIo { self.error = Some(err); } + #[cfg(feature = "nightly")] + pub fn panic(&mut self) { + self.panic = true; + } + pub fn max_read_vecs(&mut self, cnt: usize) { assert!(cnt <= READ_VECS_CNT); self.max_read_vecs = cnt; @@ -185,6 +192,7 @@ impl, T: AsRef<[u8]>> PartialEq for AsyncIo { impl Read for AsyncIo { fn read(&mut self, buf: &mut [u8]) -> io::Result { + assert!(!self.panic, "AsyncIo::read panic"); self.blocked = false; if let Some(err) = self.error.take() { Err(err) @@ -201,6 +209,7 @@ impl Read for AsyncIo { impl Write for AsyncIo { fn write(&mut self, data: &[u8]) -> io::Result { + assert!(!self.panic, "AsyncIo::write panic"); self.num_writes += 1; if let Some(err) = self.error.take() { trace!("AsyncIo::write error"); @@ -233,6 +242,7 @@ impl AsyncWrite for AsyncIo { } fn write_buf(&mut self, buf: &mut B) -> Poll { + assert!(!self.panic, "AsyncIo::write_buf panic"); if self.max_read_vecs == 0 { return self.write_no_vecs(buf); } diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 4c57e6e87e..925db1d8f8 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -8,9 +8,9 @@ use http::{HeaderMap, Method, Version}; use tokio_io::{AsyncRead, AsyncWrite}; use ::Chunk; -use proto::{BodyLength, MessageHead}; +use proto::{BodyLength, DecodedLength, MessageHead}; use super::io::{Buffered}; -use super::{EncodedBuf, Encode, Encoder, Decode, Decoder, Http1Transaction, ParseContext}; +use super::{EncodedBuf, Encode, Encoder, /*Decode,*/ Decoder, Http1Transaction, ParseContext}; const H2_PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; @@ -44,6 +44,7 @@ where I: AsyncRead + AsyncWrite, notify_read: false, reading: Reading::Init, writing: Writing::Init, + upgrade: None, // We assume a modern world where the remote speaks HTTP/1.1. // If they tell us otherwise, we'll downgrade in `read_head`. version: Version::HTTP_11, @@ -72,6 +73,10 @@ where I: AsyncRead + AsyncWrite, self.io.into_inner() } + pub fn pending_upgrade(&mut self) -> Option<::upgrade::Pending> { + self.state.upgrade.take() + } + pub fn is_read_closed(&self) -> bool { self.state.is_read_closed() } @@ -114,80 +119,61 @@ where I: AsyncRead + AsyncWrite, read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE } - pub fn read_head(&mut self) -> Poll, Option)>, ::Error> { + pub fn read_head(&mut self) -> Poll, DecodedLength, bool)>, ::Error> { debug_assert!(self.can_read_head()); trace!("Conn::read_head"); - loop { - let msg = match self.io.parse::(ParseContext { - cached_headers: &mut self.state.cached_headers, - req_method: &mut self.state.method, - }) { - Ok(Async::Ready(msg)) => msg, - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(e) => { - // If we are currently waiting on a message, then an empty - // message should be reported as an error. If not, it is just - // the connection closing gracefully. - let must_error = self.should_error_on_eof(); - self.state.close_read(); - self.io.consume_leading_lines(); - let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty(); - return if was_mid_parse || must_error { - // We check if the buf contains the h2 Preface - debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len()); - self.on_parse_error(e) - .map(|()| Async::NotReady) - } else { - debug!("read eof"); - Ok(Async::Ready(None)) - }; - } - }; + let msg = match self.io.parse::(ParseContext { + cached_headers: &mut self.state.cached_headers, + req_method: &mut self.state.method, + }) { + Ok(Async::Ready(msg)) => msg, + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(e) => return self.on_read_head_error(e), + }; - self.state.version = msg.head.version; - let head = msg.head; - let decoder = match msg.decode { - Decode::Normal(d) => { - d - }, - Decode::Final(d) => { - trace!("final decoder, HTTP ending"); - debug_assert!(d.is_eof()); - self.state.close_read(); - d - }, - Decode::Ignore => { - // likely a 1xx message that we can ignore - continue; - } - }; - debug!("incoming body is {}", decoder); + // Note: don't deconstruct `msg` into local variables, it appears + // the optimizer doesn't remove the extra copies. - self.state.busy(); + debug!("incoming body is {}", msg.decode); + + self.state.busy(); + self.state.keep_alive &= msg.keep_alive; + self.state.version = msg.head.version; + + if msg.decode == DecodedLength::ZERO { + debug_assert!(!msg.expect_continue, "expect-continue needs a body"); + self.state.reading = Reading::KeepAlive; + if !T::should_read_first() { + self.try_keep_alive(); + } + } else { if msg.expect_continue { let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; self.io.headers_buf().extend_from_slice(cont); } - let wants_keep_alive = msg.keep_alive; - self.state.keep_alive &= wants_keep_alive; - - let content_length = decoder.content_length(); + self.state.reading = Reading::Body(Decoder::new(msg.decode)); + }; - if let Reading::Closed = self.state.reading { - // actually want an `if not let ...` - } else { - self.state.reading = if content_length.is_none() { - Reading::KeepAlive - } else { - Reading::Body(decoder) - }; - } - if content_length.is_none() { - self.try_keep_alive(); - } + Ok(Async::Ready(Some((msg.head, msg.decode, msg.wants_upgrade)))) + } - return Ok(Async::Ready(Some((head, content_length)))); + fn on_read_head_error(&mut self, e: ::Error) -> Poll, ::Error> { + // If we are currently waiting on a message, then an empty + // message should be reported as an error. If not, it is just + // the connection closing gracefully. + let must_error = self.should_error_on_eof(); + self.state.close_read(); + self.io.consume_leading_lines(); + let was_mid_parse = e.is_parse() || !self.io.read_buf().is_empty(); + if was_mid_parse || must_error { + // We check if the buf contains the h2 Preface + debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len()); + self.on_parse_error(e) + .map(|()| Async::NotReady) + } else { + debug!("read eof"); + Ok(Async::Ready(None)) } } @@ -612,6 +598,10 @@ where I: AsyncRead + AsyncWrite, } } + pub(super) fn on_upgrade(&mut self) -> ::upgrade::OnUpgrade { + self.state.prepare_upgrade() + } + // Used in h1::dispatch tests #[cfg(test)] pub(super) fn io_mut(&mut self) -> &mut I { @@ -649,6 +639,8 @@ struct State { reading: Reading, /// State of allowed writes writing: Writing, + /// An expected pending HTTP upgrade. + upgrade: Option<::upgrade::Pending>, /// Either HTTP/1.0 or 1.1 connection version: Version, } @@ -697,6 +689,7 @@ impl fmt::Debug for Writing { impl ::std::ops::BitAndAssign for KA { fn bitand_assign(&mut self, enabled: bool) { if !enabled { + trace!("remote disabling keep-alive"); *self = KA::Disabled; } } @@ -821,11 +814,53 @@ impl State { _ => false } } + + fn prepare_upgrade(&mut self) -> ::upgrade::OnUpgrade { + trace!("prepare possible HTTP upgrade"); + debug_assert!(self.upgrade.is_none()); + let (tx, rx) = ::upgrade::pending(); + self.upgrade = Some(tx); + rx + } } #[cfg(test)] //TODO: rewrite these using dispatch mod tests { + + #[cfg(feature = "nightly")] + #[bench] + fn bench_read_head_short(b: &mut ::test::Bencher) { + use super::*; + let s = b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"; + let len = s.len(); + b.bytes = len as u64; + + let mut io = ::mock::AsyncIo::new_buf(Vec::new(), 0); + io.panic(); + let mut conn = Conn::<_, ::Chunk, ::proto::h1::ServerTransaction>::new(io); + *conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]); + conn.state.cached_headers = Some(HeaderMap::with_capacity(2)); + + b.iter(|| { + match conn.read_head().unwrap() { + Async::Ready(Some(x)) => { + ::test::black_box(&x); + let mut headers = x.0.headers; + headers.clear(); + conn.state.cached_headers = Some(headers); + }, + f => panic!("expected Ready(Some(..)): {:?}", f) + } + + + conn.io.read_buf_mut().reserve(1); + unsafe { + conn.io.read_buf_mut().set_len(len); + } + conn.state.reading = Reading::Init; + }); + } /* use futures::{Async, Future, Stream, Sink}; use futures::future; diff --git a/src/proto/h1/decode.rs b/src/proto/h1/decode.rs index 03296878af..b8e2cac7d5 100644 --- a/src/proto/h1/decode.rs +++ b/src/proto/h1/decode.rs @@ -7,7 +7,7 @@ use futures::{Async, Poll}; use bytes::Bytes; use super::io::MemRead; -use super::BodyLength; +use super::{DecodedLength}; use self::Kind::{Length, Chunked, Eof}; @@ -74,6 +74,14 @@ impl Decoder { Decoder { kind: Kind::Eof(false) } } + pub(super) fn new(len: DecodedLength) -> Self { + match len { + DecodedLength::CHUNKED => Decoder::chunked(), + DecodedLength::CLOSE_DELIMITED => Decoder::eof(), + length => Decoder::length(length.danger_len()), + } + } + // methods pub fn is_eof(&self) -> bool { @@ -85,16 +93,6 @@ impl Decoder { } } - pub fn content_length(&self) -> Option { - match self.kind { - Length(0) | - Chunked(ChunkedState::End, _) | - Eof(true) => None, - Length(len) => Some(BodyLength::Known(len)), - _ => Some(BodyLength::Unknown), - } - } - pub fn decode(&mut self, body: &mut R) -> Poll { trace!("decode; state={:?}", self.kind); match self.kind { @@ -152,16 +150,6 @@ impl fmt::Debug for Decoder { } } -impl fmt::Display for Decoder { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self.kind { - Kind::Length(n) => write!(f, "content-length ({} bytes)", n), - Kind::Chunked(..) => f.write_str("chunked encoded"), - Kind::Eof(..) => f.write_str("until end-of-file"), - } - } -} - macro_rules! byte ( ($rdr:ident) => ({ let buf = try_ready!($rdr.read_mem(1)); diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index ff411e81ee..290edfe20d 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -5,7 +5,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use body::{Body, Payload}; use body::internal::FullDataArg; -use proto::{BodyLength, Conn, MessageHead, RequestHead, RequestLine, ResponseHead}; +use proto::{BodyLength, DecodedLength, Conn, Dispatched, MessageHead, RequestHead, RequestLine, ResponseHead}; use super::Http1Transaction; use service::Service; @@ -65,32 +65,34 @@ where (io, buf, self.dispatch) } - /// The "Future" poll function. Runs this dispatcher until the - /// connection is shutdown, or an error occurs. - pub fn poll_until_shutdown(&mut self) -> Poll<(), ::Error> { - self.poll_catch(true) - } - /// Run this dispatcher until HTTP says this connection is done, /// but don't call `AsyncWrite::shutdown` on the underlying IO. /// - /// This is useful for HTTP upgrades. + /// This is useful for old-style HTTP upgrades, but ignores + /// newer-style upgrade API. pub fn poll_without_shutdown(&mut self) -> Poll<(), ::Error> { self.poll_catch(false) + .map(|x| { + x.map(|ds| if let Dispatched::Upgrade(pending) = ds { + pending.manual(); + }) + }) } - fn poll_catch(&mut self, should_shutdown: bool) -> Poll<(), ::Error> { + fn poll_catch(&mut self, should_shutdown: bool) -> Poll { self.poll_inner(should_shutdown).or_else(|e| { // An error means we're shutting down either way. // We just try to give the error to the user, // and close the connection with an Ok. If we // cannot give it to the user, then return the Err. - self.dispatch.recv_msg(Err(e)).map(Async::Ready) + self.dispatch.recv_msg(Err(e))?; + Ok(Async::Ready(Dispatched::Shutdown)) }) } - fn poll_inner(&mut self, should_shutdown: bool) -> Poll<(), ::Error> { + fn poll_inner(&mut self, should_shutdown: bool) -> Poll { T::update_date(); + loop { self.poll_read()?; self.poll_write()?; @@ -110,11 +112,14 @@ where } if self.is_done() { - if should_shutdown { + if let Some(pending) = self.conn.pending_upgrade() { + self.conn.take_error()?; + return Ok(Async::Ready(Dispatched::Upgrade(pending))); + } else if should_shutdown { try_ready!(self.conn.shutdown().map_err(::Error::new_shutdown)); } self.conn.take_error()?; - Ok(Async::Ready(())) + Ok(Async::Ready(Dispatched::Shutdown)) } else { Ok(Async::NotReady) } @@ -190,20 +195,18 @@ where } // dispatch is ready for a message, try to read one match self.conn.read_head() { - Ok(Async::Ready(Some((head, body_len)))) => { - let body = if let Some(body_len) = body_len { - let (mut tx, rx) = - Body::new_channel(if let BodyLength::Known(len) = body_len { - Some(len) - } else { - None - }); - let _ = tx.poll_ready(); // register this task if rx is dropped - self.body_tx = Some(tx); - rx - } else { - Body::empty() + Ok(Async::Ready(Some((head, body_len, wants_upgrade)))) => { + let mut body = match body_len { + DecodedLength::ZERO => Body::empty(), + other => { + let (tx, rx) = Body::new_channel(other.into_opt()); + self.body_tx = Some(tx); + rx + }, }; + if wants_upgrade { + body.set_on_upgrade(self.conn.on_upgrade()); + } self.dispatch.recv_msg(Ok((head, body)))?; Ok(Async::Ready(())) } @@ -326,7 +329,6 @@ where } } - impl Future for Dispatcher where D: Dispatch, PollBody=Bs, RecvItem=MessageHead>, @@ -334,12 +336,12 @@ where T: Http1Transaction, Bs: Payload, { - type Item = (); + type Item = Dispatched; type Error = ::Error; #[inline] fn poll(&mut self) -> Poll { - self.poll_until_shutdown() + self.poll_catch(true) } } @@ -519,7 +521,7 @@ mod tests { use super::*; use mock::AsyncIo; - use proto::ClientTransaction; + use proto::h1::ClientTransaction; #[test] fn client_read_bytes_before_writing_request() { diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index d0a953c7c3..441dec74f4 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -93,6 +93,12 @@ where self.read_buf.as_ref() } + #[cfg(test)] + #[cfg(feature = "nightly")] + pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut { + &mut self.read_buf + } + pub fn headers_buf(&mut self) -> &mut Vec { let buf = self.write_buf.headers_mut(); &mut buf.bytes @@ -595,7 +601,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, }; - assert!(buffered.parse::<::proto::ClientTransaction>(ctx).unwrap().is_not_ready()); + assert!(buffered.parse::<::proto::h1::ClientTransaction>(ctx).unwrap().is_not_ready()); assert!(buffered.io.blocked()); } diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 46f0eeab00..3adf260d60 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -1,7 +1,7 @@ use bytes::BytesMut; use http::{HeaderMap, Method}; -use proto::{MessageHead, BodyLength}; +use proto::{MessageHead, BodyLength, DecodedLength}; pub(crate) use self::conn::Conn; pub(crate) use self::dispatch::Dispatcher; @@ -19,12 +19,8 @@ mod io; mod role; -pub(crate) type ServerTransaction = self::role::Server; -//pub type ServerTransaction = self::role::Server; -//pub type ServerUpgradeTransaction = self::role::Server; - -pub(crate) type ClientTransaction = self::role::Client; -pub(crate) type ClientUpgradeTransaction = self::role::Client; +pub(crate) type ServerTransaction = role::Server; +pub(crate) type ClientTransaction = role::Client; pub(crate) trait Http1Transaction { type Incoming; @@ -40,14 +36,16 @@ pub(crate) trait Http1Transaction { fn update_date() {} } +/// Result newtype for Http1Transaction::parse. pub(crate) type ParseResult = Result>, ::error::Parse>; #[derive(Debug)] pub(crate) struct ParsedMessage { head: MessageHead, - decode: Decode, + decode: DecodedLength, expect_continue: bool, keep_alive: bool, + wants_upgrade: bool, } pub(crate) struct ParseContext<'a> { @@ -64,12 +62,3 @@ pub(crate) struct Encode<'a, T: 'a> { title_case_headers: bool, } -#[derive(Debug, PartialEq)] -pub enum Decode { - /// Decode normally. - Normal(Decoder), - /// After this decoder is done, HTTP is done. - Final(Decoder), - /// A header block that should be ignored, like unknown 1xx responses. - Ignore, -} diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index dd21b623f9..3e30d16645 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -8,25 +8,19 @@ use httparse; use error::Parse; use headers; -use proto::{BodyLength, MessageHead, RequestLine, RequestHead}; -use proto::h1::{Decode, Decoder, Encode, Encoder, Http1Transaction, ParseResult, ParseContext, ParsedMessage, date}; +use proto::{BodyLength, DecodedLength, MessageHead, RequestLine, RequestHead}; +use proto::h1::{Encode, Encoder, Http1Transaction, ParseResult, ParseContext, ParsedMessage, date}; const MAX_HEADERS: usize = 100; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific // There are 2 main roles, Client and Server. -// -// There is 1 modifier, OnUpgrade, which can wrap Client and Server, -// to signal that HTTP upgrades are not supported. -pub(crate) struct Client(T); +pub(crate) enum Client {} -pub(crate) struct Server(T); +pub(crate) enum Server {} -impl Http1Transaction for Server -where - T: OnUpgrade, -{ +impl Http1Transaction for Server { type Incoming = RequestLine; type Outgoing = StatusCode; @@ -34,31 +28,45 @@ where if buf.len() == 0 { return Ok(None); } + + let mut keep_alive; + let is_http_11; + let subject; + let version; + let len; + let headers_len; + // Unsafe: both headers_indices and headers are using unitialized memory, // but we *never* read any of it until after httparse has assigned // values into it. By not zeroing out the stack memory, this saves // a good ~5% on pipeline benchmarks. let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() }; - let (len, subject, version, headers_len) = { + { let mut headers: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() }; trace!("Request.parse([Header; {}], [u8; {}])", headers.len(), buf.len()); let mut req = httparse::Request::new(&mut headers); let bytes = buf.as_ref(); match req.parse(bytes)? { - httparse::Status::Complete(len) => { - trace!("Request.parse Complete({})", len); - let method = Method::from_bytes(req.method.unwrap().as_bytes())?; - let path = req.path.unwrap().parse()?; - let subject = RequestLine(method, path); - let version = if req.version.unwrap() == 1 { + httparse::Status::Complete(parsed_len) => { + trace!("Request.parse Complete({})", parsed_len); + len = parsed_len; + subject = RequestLine( + Method::from_bytes(req.method.unwrap().as_bytes())?, + req.path.unwrap().parse()? + ); + version = if req.version.unwrap() == 1 { + keep_alive = true; + is_http_11 = true; Version::HTTP_11 } else { + keep_alive = false; + is_http_11 = false; Version::HTTP_10 }; record_header_indices(bytes, &req.headers, &mut headers_indices); - let headers_len = req.headers.len(); - (len, subject, version, headers_len) + headers_len = req.headers.len(); + //(len, subject, version, headers_len) } httparse::Status::Partial => return Ok(None), } @@ -76,12 +84,12 @@ where // 7. (irrelevant to Request) - let mut decoder = None; + let mut decoder = DecodedLength::ZERO; let mut expect_continue = false; - let mut keep_alive = version == Version::HTTP_11; let mut con_len = None; let mut is_te = false; let mut is_te_chunked = false; + let mut wants_upgrade = subject.0 == Method::CONNECT; let mut headers = ctx.cached_headers .take() @@ -104,16 +112,14 @@ where // If Transfer-Encoding header is present, and 'chunked' is // not the final encoding, and this is a Request, then it is // mal-formed. A server should respond with 400 Bad Request. - if version == Version::HTTP_10 { + if !is_http_11 { debug!("HTTP/1.0 cannot have Transfer-Encoding header"); return Err(Parse::Header); } is_te = true; if headers::is_chunked_(&value) { is_te_chunked = true; - decoder = Some(Decoder::chunked()); - //debug!("request with transfer-encoding header, but not chunked, bad request"); - //return Err(Parse::Header); + decoder = DecodedLength::CHUNKED; } }, header::CONTENT_LENGTH => { @@ -135,8 +141,8 @@ where // we don't need to append this secondary length continue; } + decoder = DecodedLength::checked_new(len)?; con_len = Some(len); - decoder = Some(Decoder::length(len)); }, header::CONNECTION => { // keep_alive was previously set to default for Version @@ -152,6 +158,10 @@ where header::EXPECT => { expect_continue = value.as_bytes() == b"100-continue"; }, + header::UPGRADE => { + // Upgrades are only allowed with HTTP/1.1 + wants_upgrade = is_http_11; + }, _ => (), } @@ -159,15 +169,10 @@ where headers.append(name, value); } - let decoder = if let Some(decoder) = decoder { - decoder - } else { - if is_te && !is_te_chunked { - debug!("request with transfer-encoding header, but not chunked, bad request"); - return Err(Parse::Header); - } - Decoder::length(0) - }; + if is_te && !is_te_chunked { + debug!("request with transfer-encoding header, but not chunked, bad request"); + return Err(Parse::Header); + } *ctx.req_method = Some(subject.0.clone()); @@ -177,9 +182,10 @@ where subject, headers, }, - decode: Decode::Normal(decoder), + decode: decoder, expect_continue, keep_alive, + wants_upgrade, })) } @@ -194,7 +200,7 @@ where let is_upgrade = msg.head.subject == StatusCode::SWITCHING_PROTOCOLS || (msg.req_method == &Some(Method::CONNECT) && msg.head.subject.is_success()); let (ret, mut is_last) = if is_upgrade { - (T::on_encode_upgrade(&mut msg), true) + (Ok(()), true) } else if msg.head.subject.is_informational() { error!("response with 1xx status code not supported"); *msg.head = MessageHead::default(); @@ -485,7 +491,7 @@ where } } -impl Server<()> { +impl Server { fn can_have_body(method: &Option, status: StatusCode) -> bool { Server::can_chunked(method, status) } @@ -508,65 +514,69 @@ impl Server<()> { } } -impl Http1Transaction for Client -where - T: OnUpgrade, -{ +impl Http1Transaction for Client { type Incoming = StatusCode; type Outgoing = RequestLine; fn parse(buf: &mut BytesMut, ctx: ParseContext) -> ParseResult { - if buf.len() == 0 { - return Ok(None); - } - // Unsafe: see comment in Server Http1Transaction, above. - let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() }; - let (len, status, version, headers_len) = { - let mut headers: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() }; - trace!("Response.parse([Header; {}], [u8; {}])", headers.len(), buf.len()); - let mut res = httparse::Response::new(&mut headers); - let bytes = buf.as_ref(); - match res.parse(bytes)? { - httparse::Status::Complete(len) => { - trace!("Response.parse Complete({})", len); - let status = StatusCode::from_u16(res.code.unwrap())?; - let version = if res.version.unwrap() == 1 { - Version::HTTP_11 - } else { - Version::HTTP_10 - }; - record_header_indices(bytes, &res.headers, &mut headers_indices); - let headers_len = res.headers.len(); - (len, status, version, headers_len) - }, - httparse::Status::Partial => return Ok(None), + // Loop to skip information status code headers (100 Continue, etc). + loop { + if buf.len() == 0 { + return Ok(None); } - }; + // Unsafe: see comment in Server Http1Transaction, above. + let mut headers_indices: [HeaderIndices; MAX_HEADERS] = unsafe { mem::uninitialized() }; + let (len, status, version, headers_len) = { + let mut headers: [httparse::Header; MAX_HEADERS] = unsafe { mem::uninitialized() }; + trace!("Response.parse([Header; {}], [u8; {}])", headers.len(), buf.len()); + let mut res = httparse::Response::new(&mut headers); + let bytes = buf.as_ref(); + match res.parse(bytes)? { + httparse::Status::Complete(len) => { + trace!("Response.parse Complete({})", len); + let status = StatusCode::from_u16(res.code.unwrap())?; + let version = if res.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + record_header_indices(bytes, &res.headers, &mut headers_indices); + let headers_len = res.headers.len(); + (len, status, version, headers_len) + }, + httparse::Status::Partial => return Ok(None), + } + }; - let slice = buf.split_to(len).freeze(); + let slice = buf.split_to(len).freeze(); - let mut headers = ctx.cached_headers - .take() - .unwrap_or_else(HeaderMap::new); + let mut headers = ctx.cached_headers + .take() + .unwrap_or_else(HeaderMap::new); - headers.reserve(headers_len); - fill_headers(&mut headers, slice, &headers_indices[..headers_len]); + headers.reserve(headers_len); + fill_headers(&mut headers, slice, &headers_indices[..headers_len]); - let keep_alive = version == Version::HTTP_11; + let keep_alive = version == Version::HTTP_11; - let head = MessageHead { - version, - subject: status, - headers, - }; - let decode = Client::::decoder(&head, ctx.req_method)?; + let head = MessageHead { + version, + subject: status, + headers, + }; + if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? { + return Ok(Some(ParsedMessage { + head, + decode, + expect_continue: false, + // a client upgrade means the connection can't be used + // again, as it is definitely upgrading. + keep_alive: keep_alive && !is_upgrade, + wants_upgrade: is_upgrade, + })); + } - Ok(Some(ParsedMessage { - head, - decode, - expect_continue: false, - keep_alive, - })) + } } fn encode(msg: Encode, dst: &mut Vec) -> ::Result { @@ -617,8 +627,11 @@ where } } -impl Client { - fn decoder(inc: &MessageHead, method: &mut Option) -> Result { +impl Client { + /// Returns Some(length, wants_upgrade) if successful. + /// + /// Returns None if this message head should be skipped (like a 100 status). + fn decoder(inc: &MessageHead, method: &mut Option) -> Result, Parse> { // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body. // 2. Status 2xx to a CONNECT cannot have a body. @@ -630,23 +643,23 @@ impl Client { match inc.subject.as_u16() { 101 => { - return T::on_decode_upgrade().map(Decode::Final); + return Ok(Some((DecodedLength::ZERO, true))); }, 100...199 => { trace!("ignoring informational response: {}", inc.subject.as_u16()); - return Ok(Decode::Ignore); + return Ok(None); }, 204 | - 304 => return Ok(Decode::Normal(Decoder::length(0))), + 304 => return Ok(Some((DecodedLength::ZERO, false))), _ => (), } match *method { Some(Method::HEAD) => { - return Ok(Decode::Normal(Decoder::length(0))); + return Ok(Some((DecodedLength::ZERO, false))); } Some(Method::CONNECT) => match inc.subject.as_u16() { 200...299 => { - return Ok(Decode::Final(Decoder::length(0))); + return Ok(Some((DecodedLength::ZERO, true))); }, _ => {}, }, @@ -665,24 +678,24 @@ impl Client { debug!("HTTP/1.0 cannot have Transfer-Encoding header"); Err(Parse::Header) } else if headers::transfer_encoding_is_chunked(&inc.headers) { - Ok(Decode::Normal(Decoder::chunked())) + Ok(Some((DecodedLength::CHUNKED, false))) } else { trace!("not chunked, read till eof"); - Ok(Decode::Normal(Decoder::eof())) + Ok(Some((DecodedLength::CHUNKED, false))) } } else if let Some(len) = headers::content_length_parse_all(&inc.headers) { - Ok(Decode::Normal(Decoder::length(len))) + Ok(Some((DecodedLength::checked_new(len)?, false))) } else if inc.headers.contains_key(header::CONTENT_LENGTH) { debug!("illegal Content-Length header"); Err(Parse::Header) } else { trace!("neither Transfer-Encoding nor Content-Length"); - Ok(Decode::Normal(Decoder::eof())) + Ok(Some((DecodedLength::CLOSE_DELIMITED, false))) } } } -impl Client<()> { +impl Client { fn set_length(head: &mut RequestHead, body: Option) -> Encoder { if let Some(body) = body { let can_chunked = head.version == Version::HTTP_11 @@ -830,51 +843,6 @@ fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder { } } -pub(crate) trait OnUpgrade { - fn on_encode_upgrade(msg: &mut Encode) -> ::Result<()>; - fn on_decode_upgrade() -> Result; -} - -pub(crate) enum YesUpgrades {} - -pub(crate) enum NoUpgrades {} - -impl OnUpgrade for YesUpgrades { - fn on_encode_upgrade(_: &mut Encode) -> ::Result<()> { - Ok(()) - } - - fn on_decode_upgrade() -> Result { - debug!("101 response received, upgrading"); - // 101 upgrades always have no body - Ok(Decoder::length(0)) - } -} - -impl OnUpgrade for NoUpgrades { - fn on_encode_upgrade(msg: &mut Encode) -> ::Result<()> { - *msg.head = MessageHead::default(); - msg.head.subject = ::StatusCode::INTERNAL_SERVER_ERROR; - msg.body = None; - - if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS { - error!("response with 101 status code not supported"); - Err(Parse::UpgradeNotSupported.into()) - } else if msg.req_method == &Some(Method::CONNECT) { - error!("200 response to CONNECT request not supported"); - Err(::Error::new_user_unsupported_request_method()) - } else { - debug_assert!(false, "upgrade incorrectly detected"); - Err(::Error::new_status()) - } - } - - fn on_decode_upgrade() -> Result { - debug!("received 101 upgrade response, not supported"); - Err(Parse::UpgradeNotSupported) - } -} - #[derive(Clone, Copy)] struct HeaderIndices { name: (usize, usize), @@ -978,10 +946,6 @@ mod tests { use bytes::BytesMut; use super::*; - use super::{Server as S, Client as C}; - - type Server = S; - type Client = C; #[test] fn test_parse_request() { @@ -1033,8 +997,6 @@ mod tests { #[test] fn test_decoder_request() { - use super::Decoder; - fn parse(s: &str) -> ParsedMessage { let mut bytes = BytesMut::from(s); Server::parse(&mut bytes, ParseContext { @@ -1058,39 +1020,39 @@ mod tests { assert_eq!(parse("\ GET / HTTP/1.1\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::length(0))); + ").decode, DecodedLength::ZERO); assert_eq!(parse("\ POST / HTTP/1.1\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::length(0))); + ").decode, DecodedLength::ZERO); // transfer-encoding: chunked assert_eq!(parse("\ POST / HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::chunked())); + ").decode, DecodedLength::CHUNKED); assert_eq!(parse("\ POST / HTTP/1.1\r\n\ transfer-encoding: gzip, chunked\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::chunked())); + ").decode, DecodedLength::CHUNKED); assert_eq!(parse("\ POST / HTTP/1.1\r\n\ transfer-encoding: gzip\r\n\ transfer-encoding: chunked\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::chunked())); + ").decode, DecodedLength::CHUNKED); // content-length assert_eq!(parse("\ POST / HTTP/1.1\r\n\ content-length: 10\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::length(10))); + ").decode, DecodedLength::new(10)); // transfer-encoding and content-length = chunked assert_eq!(parse("\ @@ -1098,14 +1060,14 @@ mod tests { content-length: 10\r\n\ transfer-encoding: chunked\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::chunked())); + ").decode, DecodedLength::CHUNKED); assert_eq!(parse("\ POST / HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\ content-length: 10\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::chunked())); + ").decode, DecodedLength::CHUNKED); assert_eq!(parse("\ POST / HTTP/1.1\r\n\ @@ -1113,7 +1075,7 @@ mod tests { content-length: 10\r\n\ transfer-encoding: chunked\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::chunked())); + ").decode, DecodedLength::CHUNKED); // multiple content-lengths of same value are fine @@ -1122,7 +1084,7 @@ mod tests { content-length: 10\r\n\ content-length: 10\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::length(10))); + ").decode, DecodedLength::new(10)); // multiple content-lengths with different values is an error @@ -1153,7 +1115,7 @@ mod tests { POST / HTTP/1.0\r\n\ content-length: 10\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::length(10))); + ").decode, DecodedLength::new(10)); // 1.0 doesn't understand chunked, so its an error @@ -1171,6 +1133,16 @@ mod tests { parse_with_method(s, Method::GET) } + fn parse_ignores(s: &str) { + let mut bytes = BytesMut::from(s); + assert!(Client::parse(&mut bytes, ParseContext { + cached_headers: &mut None, + req_method: &mut Some(Method::GET), + }) + .expect("parse ok") + .is_none()) + } + fn parse_with_method(s: &str, m: Method) -> ParsedMessage { let mut bytes = BytesMut::from(s); Client::parse(&mut bytes, ParseContext { @@ -1195,32 +1167,32 @@ mod tests { assert_eq!(parse("\ HTTP/1.1 200 OK\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::eof())); + ").decode, DecodedLength::CLOSE_DELIMITED); // 204 and 304 never have a body assert_eq!(parse("\ HTTP/1.1 204 No Content\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::length(0))); + ").decode, DecodedLength::ZERO); assert_eq!(parse("\ HTTP/1.1 304 Not Modified\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::length(0))); + ").decode, DecodedLength::ZERO); // content-length assert_eq!(parse("\ HTTP/1.1 200 OK\r\n\ content-length: 8\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::length(8))); + ").decode, DecodedLength::new(8)); assert_eq!(parse("\ HTTP/1.1 200 OK\r\n\ content-length: 8\r\n\ content-length: 8\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::length(8))); + ").decode, DecodedLength::new(8)); parse_err("\ HTTP/1.1 200 OK\r\n\ @@ -1235,7 +1207,7 @@ mod tests { HTTP/1.1 200 OK\r\n\ transfer-encoding: chunked\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::chunked())); + ").decode, DecodedLength::CHUNKED); // transfer-encoding and content-length = chunked assert_eq!(parse("\ @@ -1243,7 +1215,7 @@ mod tests { content-length: 10\r\n\ transfer-encoding: chunked\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::chunked())); + ").decode, DecodedLength::CHUNKED); // HEAD can have content-length, but not body @@ -1251,44 +1223,54 @@ mod tests { HTTP/1.1 200 OK\r\n\ content-length: 8\r\n\ \r\n\ - ", Method::HEAD).decode, Decode::Normal(Decoder::length(0))); + ", Method::HEAD).decode, DecodedLength::ZERO); // CONNECT with 200 never has body - assert_eq!(parse_with_method("\ - HTTP/1.1 200 OK\r\n\ - \r\n\ - ", Method::CONNECT).decode, Decode::Final(Decoder::length(0))); + { + let msg = parse_with_method("\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + ", Method::CONNECT); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be upgrade"); + assert!(msg.wants_upgrade, "should be upgrade"); + } // CONNECT receiving non 200 can have a body assert_eq!(parse_with_method("\ HTTP/1.1 400 Bad Request\r\n\ \r\n\ - ", Method::CONNECT).decode, Decode::Normal(Decoder::eof())); + ", Method::CONNECT).decode, DecodedLength::CLOSE_DELIMITED); // 1xx status codes - assert_eq!(parse("\ + parse_ignores("\ HTTP/1.1 100 Continue\r\n\ \r\n\ - ").decode, Decode::Ignore); + "); - assert_eq!(parse("\ + parse_ignores("\ HTTP/1.1 103 Early Hints\r\n\ \r\n\ - ").decode, Decode::Ignore); + "); // 101 upgrade not supported yet - parse_err("\ - HTTP/1.1 101 Switching Protocols\r\n\ - \r\n\ - "); + { + let msg = parse("\ + HTTP/1.1 101 Switching Protocols\r\n\ + \r\n\ + "); + assert_eq!(msg.decode, DecodedLength::ZERO); + assert!(!msg.keep_alive, "should be last"); + assert!(msg.wants_upgrade, "should be upgrade"); + } // http/1.0 assert_eq!(parse("\ HTTP/1.0 200 OK\r\n\ \r\n\ - ").decode, Decode::Normal(Decoder::eof())); + ").decode, DecodedLength::CLOSE_DELIMITED); // 1.0 doesn't understand chunked parse_err("\ @@ -1320,28 +1302,11 @@ mod tests { } #[test] - fn test_server_no_upgrades_connect_method() { - let mut head = MessageHead::default(); - - let mut vec = Vec::new(); - let err = Server::encode(Encode { - head: &mut head, - body: None, - keep_alive: true, - req_method: &mut Some(Method::CONNECT), - title_case_headers: false, - }, &mut vec).unwrap_err(); - - assert!(err.is_user()); - assert_eq!(err.kind(), &::error::Kind::UnsupportedRequestMethod); - } - - #[test] - fn test_server_yes_upgrades_connect_method() { + fn test_server_encode_connect_method() { let mut head = MessageHead::default(); let mut vec = Vec::new(); - let encoder = S::::encode(Encode { + let encoder = Server::encode(Encode { head: &mut head, body: None, keep_alive: true, @@ -1382,10 +1347,12 @@ mod tests { b.bytes = len as u64; b.iter(|| { - let msg = Server::parse(&mut raw, ParseContext { + let mut msg = Server::parse(&mut raw, ParseContext { cached_headers: &mut headers, req_method: &mut None, }).unwrap().unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); headers = Some(msg.head.headers); restart(&mut raw, len); }); @@ -1402,18 +1369,19 @@ mod tests { #[cfg(feature = "nightly")] #[bench] fn bench_parse_short(b: &mut Bencher) { - let mut raw = BytesMut::from( - b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n".to_vec() - ); + let s = &b"GET / HTTP/1.1\r\nHost: localhost:8080\r\n\r\n"[..]; + let mut raw = BytesMut::from(s.to_vec()); let len = raw.len(); let mut headers = Some(HeaderMap::new()); b.bytes = len as u64; b.iter(|| { - let msg = Server::parse(&mut raw, ParseContext { + let mut msg = Server::parse(&mut raw, ParseContext { cached_headers: &mut headers, req_method: &mut None, }).unwrap().unwrap(); + ::test::black_box(&msg); + msg.head.headers.clear(); headers = Some(msg.head.headers); restart(&mut raw, len); }); @@ -1480,3 +1448,4 @@ mod tests { }) } } + diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index 6a1b65e84c..f8824518d2 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -8,6 +8,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use body::Payload; use ::common::{Exec, Never}; use headers; +use ::proto::Dispatched; use super::{PipeToSendStream, SendBuf}; use ::{Body, Request, Response}; @@ -16,7 +17,7 @@ type ClientRx = ::client::dispatch::Receiver, Response>; /// other handles to it have been dropped, so that it can shutdown. type ConnDropRef = mpsc::Sender; -pub struct Client +pub(crate) struct Client where B: Payload, { @@ -54,7 +55,7 @@ where T: AsyncRead + AsyncWrite + Send + 'static, B: Payload + 'static, { - type Item = (); + type Item = Dispatched; type Error = ::Error; fn poll(&mut self) -> Poll { @@ -153,7 +154,7 @@ where Ok(Async::Ready(None)) | Err(_) => { trace!("client::dispatch::Sender dropped"); - return Ok(Async::Ready(())); + return Ok(Async::Ready(Dispatched::Shutdown)); } } }, diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index ad23c0c6ec..07400eadeb 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -7,6 +7,7 @@ use ::body::Payload; use ::common::Exec; use ::headers; use ::service::Service; +use ::proto::Dispatched; use super::{PipeToSendStream, SendBuf}; use ::{Body, Response}; @@ -82,7 +83,7 @@ where S::Future: Send + 'static, B: Payload, { - type Item = (); + type Item = Dispatched; type Error = ::Error; fn poll(&mut self) -> Poll { @@ -95,12 +96,13 @@ where }) }, State::Serving(ref mut srv) => { - return srv.poll_server(&mut self.service, &self.exec); + try_ready!(srv.poll_server(&mut self.service, &self.exec)); + return Ok(Async::Ready(Dispatched::Shutdown)); } State::Closed => { // graceful_shutdown was called before handshaking finished, // nothing to do here... - return Ok(Async::Ready(())); + return Ok(Async::Ready(Dispatched::Shutdown)); } }; self.state = next; diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 131fb23209..7cd997f8b6 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -1,12 +1,12 @@ //! Pieces pertaining to the HTTP message protocol. use http::{HeaderMap, Method, StatusCode, Uri, Version}; -pub(crate) use self::h1::{dispatch, Conn, ClientTransaction, ClientUpgradeTransaction, ServerTransaction}; +pub(crate) use self::h1::{dispatch, Conn, ServerTransaction}; +use self::body_length::DecodedLength; pub(crate) mod h1; pub(crate) mod h2; - /// An Incoming Message head. Includes request/status line, and headers. #[derive(Clone, Debug, Default, PartialEq)] pub struct MessageHead { @@ -27,34 +27,6 @@ pub struct RequestLine(pub Method, pub Uri); /// An incoming response message. pub type ResponseHead = MessageHead; -/* -impl MessageHead { - pub fn should_keep_alive(&self) -> bool { - should_keep_alive(self.version, &self.headers) - } - - pub fn expecting_continue(&self) -> bool { - expecting_continue(self.version, &self.headers) - } -} - -/// Checks if a connection should be kept alive. -#[inline] -pub fn should_keep_alive(version: Version, headers: &HeaderMap) -> bool { - if version == Version::HTTP_10 { - headers::connection_keep_alive(headers) - } else { - !headers::connection_close(headers) - } -} - -/// Checks if a connection is expecting a `100 Continue` before sending its body. -#[inline] -pub fn expecting_continue(version: Version, headers: &HeaderMap) -> bool { - version == Version::HTTP_11 && headers::expect_continue(headers) -} -*/ - #[derive(Debug)] pub enum BodyLength { /// Content-Length @@ -63,32 +35,72 @@ pub enum BodyLength { Unknown, } -/* -#[test] -fn test_should_keep_alive() { - let mut headers = HeaderMap::new(); - - assert!(!should_keep_alive(Version::HTTP_10, &headers)); - assert!(should_keep_alive(Version::HTTP_11, &headers)); - - headers.insert("connection", ::http::header::HeaderValue::from_static("close")); - assert!(!should_keep_alive(Version::HTTP_10, &headers)); - assert!(!should_keep_alive(Version::HTTP_11, &headers)); - - headers.insert("connection", ::http::header::HeaderValue::from_static("keep-alive")); - assert!(should_keep_alive(Version::HTTP_10, &headers)); - assert!(should_keep_alive(Version::HTTP_11, &headers)); +/// Status of when an Disaptcher future completes. +pub(crate) enum Dispatched { + /// Dispatcher completely shutdown connection. + Shutdown, + /// Dispatcher has pending upgrade, and so did not shutdown. + Upgrade(::upgrade::Pending), } -#[test] -fn test_expecting_continue() { - let mut headers = HeaderMap::new(); - - assert!(!expecting_continue(Version::HTTP_10, &headers)); - assert!(!expecting_continue(Version::HTTP_11, &headers)); +/// A separate module to encapsulate the invariants of the DecodedLength type. +mod body_length { + use std::fmt; + + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + pub(crate) struct DecodedLength(u64); + + const MAX_LEN: u64 = ::std::u64::MAX - 2; + + impl DecodedLength { + pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX); + pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); + pub(crate) const ZERO: DecodedLength = DecodedLength(0); + + #[cfg(test)] + pub(crate) fn new(len: u64) -> Self { + debug_assert!(len <= MAX_LEN); + DecodedLength(len) + } + + /// Takes the length as a content-length without other checks. + /// + /// Should only be called if previously confirmed this isn't + /// CLOSE_DELIMITED or CHUNKED. + #[inline] + pub(crate) fn danger_len(self) -> u64 { + debug_assert!(self.0 < Self::CHUNKED.0); + self.0 + } + + /// Converts to an Option representing a Known or Unknown length. + pub(crate) fn into_opt(self) -> Option { + match self { + DecodedLength::CHUNKED | + DecodedLength::CLOSE_DELIMITED => None, + DecodedLength(known) => Some(known) + } + } + + /// Checks the `u64` is within the maximum allowed for content-length. + pub(crate) fn checked_new(len: u64) -> Result { + if len <= MAX_LEN { + Ok(DecodedLength(len)) + } else { + warn!("content-length bigger than maximum: {} > {}", len, MAX_LEN); + Err(::error::Parse::TooLarge) + } + } + } - headers.insert("expect", ::http::header::HeaderValue::from_static("100-continue")); - assert!(!expecting_continue(Version::HTTP_10, &headers)); - assert!(expecting_continue(Version::HTTP_11, &headers)); + impl fmt::Display for DecodedLength { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + DecodedLength::CLOSE_DELIMITED => f.write_str("close-delimited"), + DecodedLength::CHUNKED => f.write_str("chunked encoding"), + DecodedLength::ZERO => f.write_str("empty"), + DecodedLength(n) => write!(f, "content-length ({} bytes)", n), + } + } + } } -*/ diff --git a/src/server/conn.rs b/src/server/conn.rs index 45e0d277d3..5a4d24bee7 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -9,22 +9,26 @@ //! higher-level [Server](super) API. use std::fmt; +use std::mem; #[cfg(feature = "runtime")] use std::net::SocketAddr; use std::sync::Arc; #[cfg(feature = "runtime")] use std::time::Duration; -use super::rewind::Rewind; use bytes::Bytes; use futures::{Async, Future, Poll, Stream}; use futures::future::{Either, Executor}; use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "runtime")] use tokio_reactor::Handle; +use body::{Body, Payload}; use common::Exec; +use common::io::Rewind; +use error::{Kind, Parse}; use proto; -use body::{Body, Payload}; use service::{NewService, Service}; -use error::{Kind, Parse}; +use upgrade::Upgraded; + +use self::upgrades::UpgradeableConnection; #[cfg(feature = "runtime")] pub use super::tcp::AddrIncoming; @@ -109,6 +113,8 @@ where fallback: bool, } + + /// Deconstructed parts of a `Connection`. /// /// This allows taking apart a `Connection` at a later time, in order to @@ -429,7 +435,7 @@ where loop { let polled = match *self.conn.as_mut().unwrap() { Either::A(ref mut h1) => h1.poll_without_shutdown(), - Either::B(ref mut h2) => h2.poll(), + Either::B(ref mut h2) => return h2.poll().map(|x| x.map(|_| ())), }; match polled { Ok(x) => return Ok(x), @@ -466,6 +472,18 @@ where debug_assert!(self.conn.is_none()); self.conn = Some(Either::B(h2)); } + + /// Enable this connection to support higher-level HTTP upgrades. + /// + /// See [the `upgrade` module](::upgrade) for more. + pub fn with_upgrades(self) -> UpgradeableConnection + where + I: Send, + { + UpgradeableConnection { + inner: self, + } + } } impl Future for Connection @@ -482,7 +500,15 @@ where fn poll(&mut self) -> Poll { loop { match self.conn.poll() { - Ok(x) => return Ok(x.map(|o| o.unwrap_or_else(|| ()))), + Ok(x) => return Ok(x.map(|opt| { + if let Some(proto::Dispatched::Upgrade(pending)) = opt { + // With no `Send` bound on `I`, we can't try to do + // upgrades here. In case a user was trying to use + // `Body::on_upgrade` with this API, send a special + // error letting them know about that. + pending.manual(); + } + })), Err(e) => { debug!("error polling connection protocol: {}", e); match *e.kind() { @@ -507,7 +533,6 @@ where .finish() } } - // ===== impl Serve ===== impl Serve { @@ -614,7 +639,7 @@ where let fut = connecting .map_err(::Error::new_user_new_service) // flatten basically - .and_then(|conn| conn) + .and_then(|conn| conn.with_upgrades()) .map_err(|err| debug!("conn error: {}", err)); self.serve.protocol.exec.execute(fut); } else { @@ -623,3 +648,82 @@ where } } } + +mod upgrades { + use super::*; + + // A future binding a connection with a Service with Upgrade support. + // + // This type is unnameable outside the crate, and so basically just an + // `impl Future`, without requiring Rust 1.26. + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct UpgradeableConnection + where + S: Service, + { + pub(super) inner: Connection, + } + + impl UpgradeableConnection + where + S: Service + 'static, + S::Error: Into>, + S::Future: Send, + I: AsyncRead + AsyncWrite + Send + 'static, + B: Payload + 'static, + { + /// Start a graceful shutdown process for this connection. + /// + /// This `Connection` should continue to be polled until shutdown + /// can finish. + pub fn graceful_shutdown(&mut self) { + self.inner.graceful_shutdown() + } + } + + impl Future for UpgradeableConnection + where + S: Service + 'static, + S::Error: Into>, + S::Future: Send, + I: AsyncRead + AsyncWrite + Send + 'static, + B: Payload + 'static, + { + type Item = (); + type Error = ::Error; + + fn poll(&mut self) -> Poll { + loop { + match self.inner.conn.poll() { + Ok(Async::NotReady) => return Ok(Async::NotReady), + Ok(Async::Ready(Some(proto::Dispatched::Shutdown))) | + Ok(Async::Ready(None)) => { + return Ok(Async::Ready(())); + }, + Ok(Async::Ready(Some(proto::Dispatched::Upgrade(pending)))) => { + let h1 = match mem::replace(&mut self.inner.conn, None) { + Some(Either::A(h1)) => h1, + _ => unreachable!("Upgrade expects h1"), + }; + + let (io, buf, _) = h1.into_inner(); + pending.fulfill(Upgraded::new(Box::new(io), buf)); + return Ok(Async::Ready(())); + }, + Err(e) => { + debug!("error polling connection protocol: {}", e); + match *e.kind() { + Kind::Parse(Parse::VersionH2) if self.inner.fallback => { + self.inner.upgrade_h2(); + continue; + } + _ => return Err(e), + } + } + } + } + } + } +} + diff --git a/src/server/mod.rs b/src/server/mod.rs index 33898c587d..832912fc68 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -50,7 +50,6 @@ pub mod conn; #[cfg(feature = "runtime")] mod tcp; -mod rewind; use std::fmt; #[cfg(feature = "runtime")] use std::net::SocketAddr; diff --git a/src/upgrade.rs b/src/upgrade.rs new file mode 100644 index 0000000000..db592b1a42 --- /dev/null +++ b/src/upgrade.rs @@ -0,0 +1,254 @@ +//! HTTP Upgrades +//! +//! See [this example][example] showing how upgrades work with both +//! Clients and Servers. +//! +//! [example]: https://github.com/hyperium/hyper/master/examples/upgrades.rs + +use std::any::TypeId; +use std::error::Error as StdError; +use std::fmt; +use std::io::{self, Read, Write}; + +use bytes::{Buf, BufMut, Bytes}; +use futures::{Async, Future, Poll}; +use futures::sync::oneshot; +use tokio_io::{AsyncRead, AsyncWrite}; + +use common::io::Rewind; + +/// An upgraded HTTP connection. +/// +/// This type holds a trait object internally of the original IO that +/// was used to speak HTTP before the upgrade. It can be used directly +/// as a `Read` or `Write` for convenience. +/// +/// Alternatively, if the exact type is known, this can be deconstructed +/// into its parts. +pub struct Upgraded { + io: Rewind>, +} + +/// A future for a possible HTTP upgrade. +/// +/// If no upgrade was available, or it doesn't succeed, yields an `Error`. +pub struct OnUpgrade { + rx: Option>>, +} + +/// The deconstructed parts of an [`Upgraded`](Upgraded) type. +/// +/// Includes the original IO type, and a read buffer of bytes that the +/// HTTP state machine may have already read before completing an upgrade. +#[derive(Debug)] +pub struct Parts { + /// The original IO object used before the upgrade. + pub io: T, + /// A buffer of bytes that have been read but not processed as HTTP. + /// + /// For instance, if the `Connection` is used for an HTTP upgrade request, + /// it is possible the server sent back the first bytes of the new protocol + /// along with the response upgrade. + /// + /// You will want to check for any existing bytes if you plan to continue + /// communicating on the IO object. + pub read_buf: Bytes, + _inner: (), +} + +pub(crate) struct Pending { + tx: oneshot::Sender<::Result> +} + +/// Error cause returned when an upgrade was expected but canceled +/// for whatever reason. +/// +/// This likely means the actual `Conn` future wasn't polled and upgraded. +#[derive(Debug)] +struct UpgradeExpected(()); + +pub(crate) fn pending() -> (Pending, OnUpgrade) { + let (tx, rx) = oneshot::channel(); + ( + Pending { + tx, + }, + OnUpgrade { + rx: Some(rx), + }, + ) +} + +pub(crate) trait Io: AsyncRead + AsyncWrite + 'static { + fn __hyper_type_id(&self) -> TypeId { + TypeId::of::() + } +} + +impl Io + Send { + fn __hyper_is(&self) -> bool { + let t = TypeId::of::(); + self.__hyper_type_id() == t + } + + fn __hyper_downcast(self: Box) -> Result, Box> { + if self.__hyper_is::() { + // Taken from `std::error::Error::downcast()`. + unsafe { + let raw: *mut Io = Box::into_raw(self); + Ok(Box::from_raw(raw as *mut T)) + } + } else { + Err(self) + } + } +} + +impl Io for T {} + +// ===== impl Upgraded ===== + +impl Upgraded { + pub(crate) fn new(io: Box, read_buf: Bytes) -> Self { + Upgraded { + io: Rewind::new_buffered(io, read_buf), + } + } + + /// Tries to downcast the internal trait object to the type passed. + /// + /// On success, returns the downcasted parts. On error, returns the + /// `Upgraded` back. + pub fn downcast(self) -> Result, Self> { + let (io, buf) = self.io.into_inner(); + match io.__hyper_downcast() { + Ok(t) => Ok(Parts { + io: *t, + read_buf: buf, + _inner: (), + }), + Err(io) => Err(Upgraded { + io: Rewind::new_buffered(io, buf), + }) + } + } +} + +impl Read for Upgraded { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.io.read(buf) + } +} + +impl Write for Upgraded { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.io.write(buf) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + self.io.flush() + } +} + +impl AsyncRead for Upgraded { + #[inline] + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.io.prepare_uninitialized_buffer(buf) + } + + #[inline] + fn read_buf(&mut self, buf: &mut B) -> Poll { + self.io.read_buf(buf) + } +} + +impl AsyncWrite for Upgraded { + #[inline] + fn shutdown(&mut self) -> Poll<(), io::Error> { + AsyncWrite::shutdown(&mut self.io) + } + + #[inline] + fn write_buf(&mut self, buf: &mut B) -> Poll { + self.io.write_buf(buf) + } +} + +impl fmt::Debug for Upgraded { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Upgraded") + .finish() + } +} + +// ===== impl OnUpgrade ===== + +impl OnUpgrade { + pub(crate) fn none() -> Self { + OnUpgrade { + rx: None, + } + } + + pub(crate) fn is_none(&self) -> bool { + self.rx.is_none() + } +} + +impl Future for OnUpgrade { + type Item = Upgraded; + type Error = ::Error; + + fn poll(&mut self) -> Poll { + match self.rx { + Some(ref mut rx) => match rx.poll() { + Ok(Async::NotReady) => Ok(Async::NotReady), + Ok(Async::Ready(Ok(upgraded))) => Ok(Async::Ready(upgraded)), + Ok(Async::Ready(Err(err))) => Err(err), + Err(_oneshot_canceled) => Err( + ::Error::new_canceled(Some(UpgradeExpected(()))) + ), + }, + None => Err(::Error::new_user_no_upgrade()), + } + } +} + +impl fmt::Debug for OnUpgrade { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("OnUpgrade") + .finish() + } +} + +// ===== impl Pending ===== + +impl Pending { + pub(crate) fn fulfill(self, upgraded: Upgraded) { + let _ = self.tx.send(Ok(upgraded)); + } + + /// Don't fulfill the pending Upgrade, but instead signal that + /// upgrades are handled manually. + pub(crate) fn manual(self) { + let _ = self.tx.send(Err(::Error::new_user_manual_upgrade())); + } +} + +// ===== impl UpgradeExpected ===== + +impl fmt::Display for UpgradeExpected { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.description()) + } +} + +impl StdError for UpgradeExpected { + fn description(&self) -> &str { + "upgrade expected but not completed" + } +} + diff --git a/tests/client.rs b/tests/client.rs index 5ea5f5a0a7..119533f4ea 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -596,33 +596,6 @@ test! { body: None, } - -test! { - name: client_101_upgrade, - - server: - expected: "\ - GET /upgrade HTTP/1.1\r\n\ - host: {addr}\r\n\ - \r\n\ - ", - reply: "\ - HTTP/1.1 101 Switching Protocols\r\n\ - Upgrade: websocket\r\n\ - Connection: upgrade\r\n\ - \r\n\ - ", - - client: - request: - method: GET, - url: "http://{addr}/upgrade", - headers: {}, - body: None, - error: |err| err.to_string() == "unsupported protocol upgrade", - -} - test! { name: client_connect_method, @@ -1277,6 +1250,68 @@ mod dispatch_impl { res.join(rx).map(|r| r.0).wait().unwrap(); } + #[test] + fn client_upgrade() { + use tokio_io::io::{read_to_end, write_all}; + let _ = pretty_env_logger::try_init(); + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let runtime = Runtime::new().unwrap(); + let handle = runtime.reactor(); + + let connector = DebugConnector::new(&handle); + + let client = Client::builder() + .executor(runtime.executor()) + .build(connector); + + let (tx1, rx1) = oneshot::channel(); + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + sock.write_all(b"\ + HTTP/1.1 101 Switching Protocols\r\n\ + Upgrade: foobar\r\n\ + \r\n\ + foobar=ready\ + ").unwrap(); + let _ = tx1.send(()); + + let n = sock.read(&mut buf).expect("read 2"); + assert_eq!(&buf[..n], b"foo=bar"); + sock.write_all(b"bar=foo").expect("write 2"); + }); + + let rx = rx1.expect("thread panicked"); + + let req = Request::builder() + .method("GET") + .uri(&*format!("http://{}/up", addr)) + .body(Body::empty()) + .unwrap(); + + let res = client.request(req); + let res = res.join(rx).map(|r| r.0).wait().unwrap(); + + assert_eq!(res.status(), 101); + let upgraded = res + .into_body() + .on_upgrade() + .wait() + .expect("on_upgrade"); + + let parts = upgraded.downcast::().unwrap(); + assert_eq!(s(&parts.read_buf), "foobar=ready"); + + let io = parts.io; + let io = write_all(io, b"foo=bar").wait().unwrap().0; + let vec = read_to_end(io, vec![]).wait().unwrap().1; + assert_eq!(vec, b"bar=foo"); + } + struct DebugConnector { http: HttpConnector, diff --git a/tests/server.rs b/tests/server.rs index 9d4ece809d..7979b65e5f 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -24,7 +24,7 @@ use futures::future::{self, FutureResult, Either}; use futures::sync::oneshot; use futures_timer::Delay; use http::header::{HeaderName, HeaderValue}; -use tokio::net::TcpListener; +use tokio::net::{TcpListener, TcpStream as TkTcpStream}; use tokio::runtime::Runtime; use tokio::reactor::Handle; use tokio_io::{AsyncRead, AsyncWrite}; @@ -33,7 +33,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use hyper::{Body, Request, Response, StatusCode}; use hyper::client::Client; use hyper::server::conn::Http; -use hyper::service::{service_fn, Service}; +use hyper::service::{service_fn, service_fn_ok, Service}; fn tcp_bind(addr: &SocketAddr, handle: &Handle) -> ::tokio::io::Result { let std_listener = StdTcpListener::bind(addr).unwrap(); @@ -1270,6 +1270,142 @@ fn http_connect() { assert_eq!(vec, b"bar=foo"); } +#[test] +fn upgrades_new() { + use tokio_io::io::{read_to_end, write_all}; + let _ = pretty_env_logger::try_init(); + let mut rt = Runtime::new().unwrap(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap(), &rt.reactor()).unwrap(); + let addr = listener.local_addr().unwrap(); + let (read_101_tx, read_101_rx) = oneshot::channel(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + tcp.write_all(b"\ + GET / HTTP/1.1\r\n\ + Upgrade: foobar\r\n\ + Connection: upgrade\r\n\ + \r\n\ + eagerly optimistic\ + ").expect("write 1"); + let mut buf = [0; 256]; + tcp.read(&mut buf).expect("read 1"); + + let expected = "HTTP/1.1 101 Switching Protocols\r\n"; + assert_eq!(s(&buf[..expected.len()]), expected); + let _ = read_101_tx.send(()); + + let n = tcp.read(&mut buf).expect("read 2"); + assert_eq!(s(&buf[..n]), "foo=bar"); + tcp.write_all(b"bar=foo").expect("write 2"); + }); + + let (upgrades_tx, upgrades_rx) = mpsc::channel(); + let svc = service_fn_ok(move |req: Request| { + let on_upgrade = req + .into_body() + .on_upgrade(); + let _ = upgrades_tx.send(on_upgrade); + Response::builder() + .status(101) + .header("upgrade", "foobar") + .body(hyper::Body::empty()) + .unwrap() + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| -> hyper::Error { unreachable!() }) + .and_then(move |(item, _incoming)| { + let socket = item.unwrap(); + Http::new() + .serve_connection(socket, svc) + .with_upgrades() + }); + + rt.block_on(fut).unwrap(); + let on_upgrade = upgrades_rx.recv().unwrap(); + + // wait so that we don't write until other side saw 101 response + rt.block_on(read_101_rx).unwrap(); + + let upgraded = rt.block_on(on_upgrade).unwrap(); + let parts = upgraded.downcast::().unwrap(); + let io = parts.io; + assert_eq!(parts.read_buf, "eagerly optimistic"); + + let io = rt.block_on(write_all(io, b"foo=bar")).unwrap().0; + let vec = rt.block_on(read_to_end(io, vec![])).unwrap().1; + assert_eq!(s(&vec), "bar=foo"); +} + +#[test] +fn http_connect_new() { + use tokio_io::io::{read_to_end, write_all}; + let _ = pretty_env_logger::try_init(); + let mut rt = Runtime::new().unwrap(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap(), &rt.reactor()).unwrap(); + let addr = listener.local_addr().unwrap(); + let (read_200_tx, read_200_rx) = oneshot::channel(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + tcp.write_all(b"\ + CONNECT localhost HTTP/1.1\r\n\ + \r\n\ + eagerly optimistic\ + ").expect("write 1"); + let mut buf = [0; 256]; + tcp.read(&mut buf).expect("read 1"); + + let expected = "HTTP/1.1 200 OK\r\n"; + assert_eq!(s(&buf[..expected.len()]), expected); + let _ = read_200_tx.send(()); + + let n = tcp.read(&mut buf).expect("read 2"); + assert_eq!(s(&buf[..n]), "foo=bar"); + tcp.write_all(b"bar=foo").expect("write 2"); + }); + + let (upgrades_tx, upgrades_rx) = mpsc::channel(); + let svc = service_fn_ok(move |req: Request| { + let on_upgrade = req + .into_body() + .on_upgrade(); + let _ = upgrades_tx.send(on_upgrade); + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap() + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| -> hyper::Error { unreachable!() }) + .and_then(move |(item, _incoming)| { + let socket = item.unwrap(); + Http::new() + .serve_connection(socket, svc) + .with_upgrades() + }); + + rt.block_on(fut).unwrap(); + let on_upgrade = upgrades_rx.recv().unwrap(); + + // wait so that we don't write until other side saw 200 + rt.block_on(read_200_rx).unwrap(); + + let upgraded = rt.block_on(on_upgrade).unwrap(); + let parts = upgraded.downcast::().unwrap(); + let io = parts.io; + assert_eq!(parts.read_buf, "eagerly optimistic"); + + let io = rt.block_on(write_all(io, b"foo=bar")).unwrap().0; + let vec = rt.block_on(read_to_end(io, vec![])).unwrap().1; + assert_eq!(s(&vec), "bar=foo"); +} + + #[test] fn parse_errors_send_4xx_response() { let runtime = Runtime::new().unwrap();