Skip to content

Commit

Permalink
Merge pull request #91 from palantir/tokio-1
Browse files Browse the repository at this point in the history
Upgrade to tokio 1.0
  • Loading branch information
sfackler authored Feb 12, 2021
2 parents 78b8737 + ffbdec3 commit dfd30a9
Show file tree
Hide file tree
Showing 55 changed files with 285 additions and 300 deletions.
5 changes: 5 additions & 0 deletions changelog/@unreleased/pr-91.v2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
type: break
break:
description: Upgraded to tokio 1.0.
links:
- https://github.com/palantir/conjure-rust-runtime/pull/91
2 changes: 1 addition & 1 deletion conjure-runtime-config/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "conjure-runtime-config"
version = "0.3.1"
version = "0.4.0"
authors = ["Steven Fackler <sfackler@palantir.com>"]
edition = "2018"
license = "Apache-2.0"
Expand Down
31 changes: 18 additions & 13 deletions conjure-runtime/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "conjure-runtime"
version = "0.3.1"
version = "0.4.0"
authors = ["Steven Fackler <sfackler@palantir.com>"]
edition = "2018"
license = "Apache-2.0"
Expand All @@ -10,20 +10,20 @@ readme = "../README.md"

[dependencies]
arc-swap = "1.0"
async-compression = { version = "0.3", default-features = false, features = ["zlib", "gzip", "stream"] }
async-compression = { version = "0.3", default-features = false, features = ["gzip", "tokio"] }
async-trait = "0.1"
base64 = "0.13"
bytes = "0.5"
bytes = "1.0"
conjure-error = "0.7"
conjure-http = "0.7"
conjure-object = "0.7"
conjure-serde = "0.7"
futures = "0.3"
http = "0.2"
http-body = "0.3"
http-body = "0.4"
http-zipkin = "0.3"
hyper = "0.13.4"
hyper-openssl = "0.8.1"
hyper = { version = "0.14", features = ["http1", "http2", "client", "tcp"] }
hyper-openssl = "0.9"
once_cell = "1.0"
openssl = "0.10"
parking_lot = "0.11"
Expand All @@ -37,21 +37,26 @@ serde = "1.0"
serde-value = "0.7"
witchcraft-log = "0.3"
witchcraft-metrics = "0.2"
tokio = { version = "0.2", features = ["io-util", "rt-threaded", "time"] }
tokio-io-timeout = "0.4"
tower = "0.3"
tokio = { version = "1.0", features = ["io-util", "rt-multi-thread", "time"] }
tokio-io-timeout = "1.0"
tokio-util = { version = "0.6", features = ["codec"] }
tower-layer = "0.3"
tower-service = "0.3"
url = "2.0"
zipkin = "0.4"

conjure-runtime-config = { version = "0.3.1", path = "../conjure-runtime-config" }
conjure-runtime-config = { version = "0.4.0", path = "../conjure-runtime-config" }

[dev-dependencies]
tokio = { version = "0.2", features = ["full"] }
flate2 = "1.0"
futures-test = "0.3"
tokio-openssl = "0.4"
tokio-test = "0.2"
hyper = { version = "0.14", features = ["server"] }
serde_yaml = "0.8"
tokio-openssl = "0.6"
tokio-test = "0.4"
tokio = { version = "1.0", features = ["full"] }
tower-util = "0.3"


# for doc examples
conjure-codegen = { version = "0.7", features = ["example-types"] }
3 changes: 1 addition & 2 deletions conjure-runtime/src/blocking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ mod shim;
fn runtime() -> io::Result<&'static Runtime> {
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
RUNTIME.get_or_try_init(|| {
runtime::Builder::new()
runtime::Builder::new_multi_thread()
.enable_all()
.threaded_scheduler()
.thread_name("conjure-runtime")
.build()
})
Expand Down
15 changes: 8 additions & 7 deletions conjure-runtime/src/blocking/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ where
/// Compared to the `Read` implementation, this method avoids some copies of the body data when working with an API
/// that already consumes `Bytes` objects.
pub fn read_bytes(&mut self) -> io::Result<Option<Bytes>> {
runtime()?.enter(|| executor::block_on(self.0.as_mut().read_bytes()))
let _guard = runtime()?.enter();
executor::block_on(self.0.as_mut().read_bytes())
}
}

Expand All @@ -70,7 +71,8 @@ where
B::Error: Into<Box<dyn error::Error + Sync + Send>>,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
runtime()?.enter(|| executor::block_on(self.0.read(buf)))
let _guard = runtime()?.enter();
executor::block_on(self.0.read(buf))
}
}

Expand All @@ -81,11 +83,10 @@ where
{
fn fill_buf(&mut self) -> io::Result<&[u8]> {
// lifetime shenanigans mean we can't return the value of poll_fill_buf directly
runtime()?.enter(|| {
executor::block_on(future::poll_fn(|cx| {
self.0.as_mut().poll_fill_buf(cx).map_ok(|_| ())
}))
})?;
let _guard = runtime()?.enter();
executor::block_on(future::poll_fn(|cx| {
self.0.as_mut().poll_fill_buf(cx).map_ok(|_| ())
}))?;
Ok(self.0.buffer())
}

Expand Down
14 changes: 7 additions & 7 deletions conjure-runtime/src/raw/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tower::ServiceBuilder;
use tower_layer::Layer;

// This is pretty arbitrary - I just grabbed it from some Cloudflare blog post.
const TCP_KEEPALIVE: Duration = Duration::from_secs(3 * 60);
Expand Down Expand Up @@ -69,12 +69,12 @@ impl BuildRawClient for DefaultRawClientBuilder {

let proxy = ProxyConfig::from_config(&builder.get_proxy())?;

let connector = ServiceBuilder::new()
.layer(TlsMetricsLayer::new(&service, builder))
.layer(HttpsLayer::with_connector(ssl).map_err(Error::internal_safe)?)
.layer(ProxyConnectorLayer::new(&proxy))
.layer(TimeoutLayer::new(builder))
.service(connector);
let connector = TimeoutLayer::new(builder).layer(connector);
let connector = ProxyConnectorLayer::new(&proxy).layer(connector);
let connector = HttpsLayer::with_connector(ssl)
.map_err(Error::internal_safe)?
.layer(connector);
let connector = TlsMetricsLayer::new(&service, builder).layer(connector);

let client = hyper::Client::builder()
.pool_idle_timeout(HTTP_KEEPALIVE)
Expand Down
23 changes: 10 additions & 13 deletions conjure-runtime/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ use pin_project::pin_project;
use std::error;
use std::io;
use std::marker::PhantomPinned;
use std::mem::{self, MaybeUninit};
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufRead, AsyncRead};
use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};

/// An asynchronous HTTP response.
pub struct Response<B = DefaultRawBody> {
Expand Down Expand Up @@ -111,20 +111,17 @@ where
B: Body<Data = Bytes>,
B::Error: Into<Box<dyn error::Error + Sync + Send>>,
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
false
}

fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let read_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
let nread = usize::min(buf.len(), read_buf.len());
buf[..nread].copy_from_slice(&read_buf[..nread]);
self.consume(nread);
Poll::Ready(Ok(nread))
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let in_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = usize::min(in_buf.len(), buf.remaining());
buf.put_slice(&in_buf[..len]);
self.consume(len);

Poll::Ready(Ok(()))
}
}

Expand Down
70 changes: 55 additions & 15 deletions conjure-runtime/src/service/gzip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.
use crate::raw::Service;
use crate::service::Layer;
use async_compression::stream::GzipDecoder;
use bytes::Bytes;
use async_compression::tokio::bufread::GzipDecoder;
use bytes::{Buf, Bytes, BytesMut};
use futures::{ready, Stream};
use http::header::{Entry, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH};
use http::{HeaderMap, HeaderValue, Request, Response};
Expand All @@ -26,6 +26,8 @@ use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
use tokio_util::codec::{BytesCodec, FramedRead};

static GZIP: Lazy<HeaderValue> = Lazy::new(|| HeaderValue::from_static("gzip"));

Expand Down Expand Up @@ -91,7 +93,13 @@ where
parts.headers.remove(CONTENT_ENCODING);
parts.headers.remove(CONTENT_LENGTH);
DecodedBody::Gzip {
body: GzipDecoder::new(ShimStream { body }),
body: FramedRead::new(
GzipDecoder::new(ShimReader {
body,
buf: Bytes::new(),
}),
BytesCodec::new(),
),
}
}
_ => DecodedBody::Identity { body },
Expand All @@ -109,7 +117,7 @@ pub enum DecodedBody<B> {
},
Gzip {
#[pin]
body: GzipDecoder<ShimStream<B>>,
body: FramedRead<GzipDecoder<ShimReader<B>>, BytesCodec>,
},
}

Expand All @@ -129,7 +137,9 @@ where
Projection::Identity { body } => body
.poll_data(cx)
.map(|o| o.map(|r| r.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))),
Projection::Gzip { body } => body.poll_next(cx),
Projection::Gzip { body } => body
.poll_next(cx)
.map(|o| o.map(|r| r.map(BytesMut::freeze))),
}
}

Expand All @@ -141,7 +151,8 @@ where
Projection::Identity { body } => body
.poll_trailers(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e)),
Projection::Gzip { body } => body
Projection::Gzip { body, .. } => body
.get_pin_mut()
.get_pin_mut()
.project()
.body
Expand All @@ -167,23 +178,52 @@ where
}

#[pin_project]
pub struct ShimStream<T> {
pub struct ShimReader<T> {
#[pin]
body: T,
buf: Bytes,
}

impl<T> AsyncRead for ShimReader<T>
where
T: Body<Data = Bytes>,
T::Error: Into<Box<dyn Error + Sync + Send>>,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let in_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
let len = usize::min(in_buf.len(), buf.remaining());
buf.put_slice(&in_buf[..len]);
self.consume(len);

Poll::Ready(Ok(()))
}
}

impl<T> Stream for ShimStream<T>
impl<T> AsyncBufRead for ShimReader<T>
where
T: Body,
T: Body<Data = Bytes>,
T::Error: Into<Box<dyn Error + Sync + Send>>,
{
type Item = io::Result<T::Data>;
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let mut this = self.project();

while !this.buf.has_remaining() {
*this.buf = match ready!(this.body.as_mut().poll_data(cx)) {
Some(Ok(buf)) => buf,
Some(Err(e)) => return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
None => break,
}
}

Poll::Ready(Ok(&*this.buf))
}

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project()
.body
.poll_data(cx)
.map(|o| o.map(|r| r.map_err(|e| io::Error::new(io::ErrorKind::Other, e))))
fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().buf.advance(amt);
}
}

Expand Down
2 changes: 1 addition & 1 deletion conjure-runtime/src/service/node/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::time::Instant;

/// A layer which updates the host metrics for the node stored in the request's extensions map.
pub struct NodeMetricsLayer;
Expand Down
Loading

0 comments on commit dfd30a9

Please sign in to comment.