diff --git a/.circleci/config.yml b/.circleci/config.yml
index ea5d4d09..f3f9d9d2 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -22,7 +22,7 @@ version: 2
 jobs:
   build:
     docker:
-      - image: rust:1.65.0
+      - image: rust:1.75.0
         environment:
           RUSTFLAGS: -D warnings
     steps:
diff --git a/changelog/@unreleased/pr-176.v2.yml b/changelog/@unreleased/pr-176.v2.yml
new file mode 100644
index 00000000..5292cfe8
--- /dev/null
+++ b/changelog/@unreleased/pr-176.v2.yml
@@ -0,0 +1,5 @@
+type: break
+break:
+  description: The `Service` trait now supports `async fn call`.
+  links:
+    - https://github.com/palantir/conjure-rust-runtime/pull/176
diff --git a/conjure-runtime/src/blocking/client.rs b/conjure-runtime/src/blocking/client.rs
index c1f2e8c8..184e8c3e 100644
--- a/conjure-runtime/src/blocking/client.rs
+++ b/conjure-runtime/src/blocking/client.rs
@@ -76,7 +76,6 @@ impl conjure_http::client::Client for Client
 where
     T: Service, Response = Response> + 'static + Sync + Send,
     T::Error: Into>,
-    T::Future: Send,
     B: http_body::Body + 'static + Send,
     B::Error: Into>,
 {
diff --git a/conjure-runtime/src/client.rs b/conjure-runtime/src/client.rs
index 113236f2..11681cff 100644
--- a/conjure-runtime/src/client.rs
+++ b/conjure-runtime/src/client.rs
@@ -147,7 +147,6 @@ impl AsyncClient for Client
 where
     T: Service, Response = http::Response> + 'static + Sync + Send,
     T::Error: Into>,
-    T::Future: Send,
     B: http_body::Body + 'static + Send,
     B::Error: Into>,
 {
@@ -159,8 +158,6 @@ where
         &self,
         request: Request>,
     ) -> Result, Error> {
-        // split into 2 statements to avoid holding onto the state while awaiting the future
-        let future = self.state.load().service.call(request);
-        future.await
+        self.state.load().service.call(request).await
     }
 }
diff --git a/conjure-runtime/src/raw/default.rs b/conjure-runtime/src/raw/default.rs
index afe5e4e9..e96a22f5 100644
--- a/conjure-runtime/src/raw/default.rs
+++ b/conjure-runtime/src/raw/default.rs
@@ -18,10 +18,9 @@ use crate::service::tls_metrics::{TlsMetricsLayer, TlsMetricsService};
 use crate::Builder;
 use bytes::Bytes;
 use conjure_error::Error;
-use futures::ready;
 use http::{HeaderMap, Request, Response};
 use http_body::{Body, SizeHint};
-use hyper::client::{HttpConnector, ResponseFuture};
+use hyper::client::HttpConnector;
 use hyper::Client;
 use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
 use pin_project::pin_project;
@@ -30,7 +29,6 @@ use rustls_pemfile::Item;
 use std::error;
 use std::fmt;
 use std::fs::File;
-use std::future::Future;
 use std::io::BufReader;
 use std::marker::PhantomPinned;
 use std::path::Path;
@@ -162,13 +160,18 @@ pub struct DefaultRawClient(Client);
 impl Service> for DefaultRawClient {
     type Response = Response;
     type Error = DefaultRawError;
-    type Future = DefaultRawFuture;
 
-    fn call(&self, req: Request) -> Self::Future {
-        DefaultRawFuture {
-            future: self.0.request(req),
-            _p: PhantomPinned,
-        }
+    async fn call(&self, req: Request) -> Result {
+        self.0
+            .request(req)
+            .await
+            .map(|r| {
+                r.map(|inner| DefaultRawBody {
+                    inner,
+                    _p: PhantomPinned,
+                })
+            })
+            .map_err(DefaultRawError)
     }
 }
 
@@ -214,29 +217,6 @@ impl Body for DefaultRawBody {
     }
 }
 
-/// The future type used by `DefaultRawClient`.
-#[pin_project] -pub struct DefaultRawFuture { - #[pin] - future: ResponseFuture, - #[pin] - _p: PhantomPinned, -} - -impl Future for DefaultRawFuture { - type Output = Result, DefaultRawError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let response = ready!(self.project().future.poll(cx)).map_err(DefaultRawError)?; - let response = response.map(|inner| DefaultRawBody { - inner, - _p: PhantomPinned, - }); - - Poll::Ready(Ok(response)) - } -} - /// The error type used by `DefaultRawClient`. #[derive(Debug)] pub struct DefaultRawError(hyper::Error); diff --git a/conjure-runtime/src/raw/mod.rs b/conjure-runtime/src/raw/mod.rs index ac6b6b42..f106c603 100644 --- a/conjure-runtime/src/raw/mod.rs +++ b/conjure-runtime/src/raw/mod.rs @@ -53,11 +53,9 @@ pub trait Service { type Response; /// The error type returned by the service. type Error; - /// The future type returned by the service. - type Future: Future>; /// Asynchronously perform the request. - fn call(&self, req: R) -> Self::Future; + fn call(&self, req: R) -> impl Future> + Send; } impl Service for Arc @@ -66,9 +64,8 @@ where { type Response = T::Response; type Error = T::Error; - type Future = T::Future; - fn call(&self, req: R) -> Self::Future { + fn call(&self, req: R) -> impl Future> { (**self).call(req) } } diff --git a/conjure-runtime/src/service/gzip.rs b/conjure-runtime/src/service/gzip.rs index 2f0a0695..de24a9be 100644 --- a/conjure-runtime/src/service/gzip.rs +++ b/conjure-runtime/src/service/gzip.rs @@ -23,7 +23,6 @@ use http_body::{Body, SizeHint}; use once_cell::sync::Lazy; use pin_project::pin_project; use std::error::Error; -use std::future::Future; use std::io::Write; use std::pin::Pin; use std::task::{Context, Poll}; @@ -50,41 +49,20 @@ pub struct GzipService { impl Service> for GzipService where - S: Service, Response = Response>, + S: Service, Response = Response> + Sync + Send, + B1: Sync + Send, B2: Body, B2::Error: Into>, { type Response = Response>; type Error = S::Error; - type Future = GzipFuture; - fn call(&self, mut req: Request) -> Self::Future { + async fn call(&self, mut req: Request) -> Result { if let Entry::Vacant(e) = req.headers_mut().entry(ACCEPT_ENCODING) { e.insert(GZIP.clone()); } - GzipFuture { - future: self.inner.call(req), - } - } -} - -#[pin_project] -pub struct GzipFuture { - #[pin] - future: F, -} - -impl Future for GzipFuture -where - F: Future, E>>, - B: Body, - B::Error: Into>, -{ - type Output = Result>, E>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let response = ready!(self.project().future.poll(cx))?; + let response = self.inner.call(req).await?; let (mut parts, body) = response.into_parts(); let decoder = match parts.headers.get(CONTENT_ENCODING) { @@ -102,7 +80,7 @@ where done: false, }; - Poll::Ready(Ok(Response::from_parts(parts, body))) + Ok(Response::from_parts(parts, body)) } } diff --git a/conjure-runtime/src/service/http_error.rs b/conjure-runtime/src/service/http_error.rs index 928d46b8..7aa1ea7a 100644 --- a/conjure-runtime/src/service/http_error.rs +++ b/conjure-runtime/src/service/http_error.rs @@ -15,18 +15,14 @@ use crate::errors::{RemoteError, ThrottledError, UnavailableError}; use crate::raw::Service; use crate::service::Layer; use crate::{Builder, ServerQos, ServiceError}; -use bytes::BufMut; +use bytes::Bytes; use conjure_error::Error; use conjure_serde::json; -use futures::ready; use http::header::RETRY_AFTER; use http::{Request, Response, StatusCode}; -use http_body::Body; -use pin_project::pin_project; +use 
http_body::{Body, Limited}; +use hyper::body; use std::error; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; use std::time::Duration; use witchcraft_log::info; @@ -74,150 +70,85 @@ pub struct HttpErrorService { impl Service> for HttpErrorService where - S: Service, Response = Response, Error = Error>, - B2: Body, + S: Service, Response = Response, Error = Error> + Sync + Send, + B1: Sync + Send, + B2: Body + Send, + B2::Data: Send, B2::Error: Into>, { type Response = Response; type Error = Error; - type Future = HttpErrorFuture; - fn call(&self, req: Request) -> Self::Future { - HttpErrorFuture::Call { - future: self.inner.call(req), - server_qos: self.server_qos, - service_error: self.service_error, - } - } -} - -#[pin_project(project = Projection)] -pub enum HttpErrorFuture { - Call { - #[pin] - future: F, - server_qos: ServerQos, - service_error: ServiceError, - }, - ReadingBody { - status: StatusCode, - #[pin] - body: B, - buf: Vec, - service_error: ServiceError, - }, -} - -impl Future for HttpErrorFuture -where - F: Future, Error>>, - B: Body, - B::Error: Into>, -{ - type Output = Result, Error>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let new_state = match self.as_mut().project() { - Projection::Call { - future, - server_qos, - service_error, - } => { - let response = ready!(future.poll(cx))?; - - if response.status().is_success() { - return Poll::Ready(Ok(response)); - } - - match response.status() { - StatusCode::TOO_MANY_REQUESTS => { - let retry_after = response - .headers() - .get(RETRY_AFTER) - .and_then(|h| h.to_str().ok()) - .and_then(|s| s.parse().ok()) - .map(Duration::from_secs); - let error = ThrottledError { retry_after }; - - let e = match server_qos { - ServerQos::AutomaticRetry => Error::internal_safe(error), - ServerQos::Propagate429And503ToCaller => match retry_after { - Some(retry_after) => { - Error::throttle_for_safe(error, retry_after) - } - None => Error::throttle_safe(error), - }, - }; - - return Poll::Ready(Err(e)); - } - StatusCode::SERVICE_UNAVAILABLE => { - let error = UnavailableError(()); + async fn call(&self, req: Request) -> Result { + let response = self.inner.call(req).await?; - let e = match server_qos { - ServerQos::AutomaticRetry => Error::internal_safe(error), - ServerQos::Propagate429And503ToCaller => { - Error::unavailable_safe(error) - } - }; + if response.status().is_success() { + return Ok(response); + } - return Poll::Ready(Err(e)); - } - _ => HttpErrorFuture::ReadingBody { - status: response.status(), - body: response.into_body(), - buf: vec![], - service_error: *service_error, - }, + match response.status() { + StatusCode::TOO_MANY_REQUESTS => { + let retry_after = response + .headers() + .get(RETRY_AFTER) + .and_then(|h| h.to_str().ok()) + .and_then(|s| s.parse().ok()) + .map(Duration::from_secs); + let error = ThrottledError { retry_after }; + + let e = match self.server_qos { + ServerQos::AutomaticRetry => Error::internal_safe(error), + ServerQos::Propagate429And503ToCaller => match retry_after { + Some(retry_after) => Error::throttle_for_safe(error, retry_after), + None => Error::throttle_safe(error), + }, + }; + + Err(e) + } + StatusCode::SERVICE_UNAVAILABLE => { + let error = UnavailableError(()); + + let e = match self.server_qos { + ServerQos::AutomaticRetry => Error::internal_safe(error), + ServerQos::Propagate429And503ToCaller => Error::unavailable_safe(error), + }; + + Err(e) + } + _ => { + let (parts, body) = response.into_parts(); + + let body = match 
body::to_bytes(Limited::new(body, 10 * 1024)).await { + Ok(body) => body, + Err(e) => { + info!("error reading response body", error: Error::internal(e)); + Bytes::new() } - } - Projection::ReadingBody { - status, - mut body, - buf, - service_error, - } => { - loop { - let data = match ready!(body.as_mut().poll_data(cx)) { - Some(Ok(data)) => data, - Some(Err(e)) => { - info!("error reading response body", error: Error::internal(e)); - break; - } - None => break, - }; - - buf.put(data); - // limit how much we read in case something weird is going on - if buf.len() > 10 * 1024 { - break; - } + }; + + let error = RemoteError { + status: parts.status, + error: json::client_from_slice(&body).ok(), + }; + let log_body = error.error.is_none(); + + let mut error = match (&error.error, self.service_error) { + (Some(e), ServiceError::PropagateToCaller) => { + let e = e.clone(); + Error::propagated_service_safe(error, e) } - - let error = RemoteError { - status: *status, - error: json::client_from_slice(buf).ok(), - }; - let log_body = error.error.is_none(); - let mut error = match (&error.error, service_error) { - (Some(e), ServiceError::PropagateToCaller) => { - let e = e.clone(); - Error::propagated_service_safe(error, e) - } - (Some(_), ServiceError::WrapInNewError) | (None, _) => { - Error::internal_safe(error) - } - }; - if log_body { - error = error.with_unsafe_param("body", String::from_utf8_lossy(buf)); + (Some(_), ServiceError::WrapInNewError) | (None, _) => { + Error::internal_safe(error) } + }; - return Poll::Ready(Err(error)); + if log_body { + error = error.with_unsafe_param("body", String::from_utf8_lossy(&body)); } - }; - self.set(new_state); + Err(error) + } } } } diff --git a/conjure-runtime/src/service/map_error.rs b/conjure-runtime/src/service/map_error.rs index 17d8d4c3..b14acea8 100644 --- a/conjure-runtime/src/service/map_error.rs +++ b/conjure-runtime/src/service/map_error.rs @@ -14,12 +14,8 @@ use crate::raw::Service; use crate::service::Layer; use conjure_error::Error; -use pin_project::pin_project; use std::error; use std::fmt; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; #[derive(Debug)] pub struct RawClientError(pub Box); @@ -54,37 +50,17 @@ pub struct MapErrorService { impl Service for MapErrorService where - S: Service, + S: Service + Sync + Send, S::Error: Into>, + R: Sync + Send, { - type Error = Error; type Response = S::Response; - type Future = MapErrorFuture; - - fn call(&self, req: R) -> Self::Future { - MapErrorFuture { - future: self.inner.call(req), - } - } -} - -#[pin_project] -pub struct MapErrorFuture { - #[pin] - future: F, -} - -impl Future for MapErrorFuture -where - F: Future>, - E: Into>, -{ - type Output = Result; + type Error = Error; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project() - .future - .poll(cx) + async fn call(&self, req: R) -> Result { + self.inner + .call(req) + .await .map_err(|e| Error::internal_safe(RawClientError(e.into()))) } } diff --git a/conjure-runtime/src/service/metrics.rs b/conjure-runtime/src/service/metrics.rs index 1723d049..90a651c8 100644 --- a/conjure-runtime/src/service/metrics.rs +++ b/conjure-runtime/src/service/metrics.rs @@ -16,13 +16,8 @@ use crate::service::Layer; use crate::Builder; use conjure_error::Error; use conjure_http::client::Endpoint; -use futures::ready; use http::{Request, Response}; -use pin_project::pin_project; -use std::future::Future; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; use tokio::time::Instant; 
use witchcraft_metrics::{MetricId, MetricRegistry}; @@ -35,17 +30,15 @@ struct Metrics { /// /// Only errors with a cause of `RawClientError` will be treated as IO errors. pub struct MetricsLayer { - metrics: Option>, + metrics: Option, } impl MetricsLayer { pub fn new(service: &str, builder: &Builder) -> MetricsLayer { MetricsLayer { - metrics: builder.get_metrics().map(|m| { - Arc::new(Metrics { - metrics: m.clone(), - service_name: service.to_string(), - }) + metrics: builder.get_metrics().map(|m| Metrics { + metrics: m.clone(), + service_name: service.to_string(), }), } } @@ -64,52 +57,28 @@ impl Layer for MetricsLayer { pub struct MetricsService { inner: S, - metrics: Option>, + metrics: Option, } impl Service> for MetricsService where - S: Service, Response = Response, Error = Error>, + S: Service, Response = Response, Error = Error> + Sync + Send, + B1: Send, { type Response = S::Response; type Error = S::Error; - type Future = MetricsFuture; - fn call(&self, req: Request) -> Self::Future { - MetricsFuture { - endpoint: req - .extensions() - .get::() - .expect("Request extensions missing Endpoint") - .clone(), - future: self.inner.call(req), - start: Instant::now(), - metrics: self.metrics.clone(), - } - } -} - -#[pin_project] -pub struct MetricsFuture { - #[pin] - future: F, - start: Instant, - endpoint: Endpoint, - metrics: Option>, -} - -impl Future for MetricsFuture -where - F: Future, Error>>, -{ - type Output = F::Output; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); + async fn call(&self, req: Request) -> Result { + let endpoint = req + .extensions() + .get::() + .expect("Request extensions missing Endpoint") + .clone(); - let result = ready!(this.future.poll(cx)); + let start = Instant::now(); + let result = self.inner.call(req).await; - if let Some(metrics) = this.metrics { + if let Some(metrics) = &self.metrics { let status = match &result { Ok(_) => "success", Err(_) => "failure", @@ -120,13 +89,13 @@ where .timer( MetricId::new("client.response") .with_tag("channel-name", metrics.service_name.clone()) - .with_tag("service-name", this.endpoint.service()) - .with_tag("endpoint", this.endpoint.name()) + .with_tag("service-name", endpoint.service()) + .with_tag("endpoint", endpoint.name()) .with_tag("status", status), ) - .update(this.start.elapsed()); + .update(start.elapsed()); } - Poll::Ready(result) + result } } diff --git a/conjure-runtime/src/service/mod.rs b/conjure-runtime/src/service/mod.rs index ecfafd77..ed756588 100644 --- a/conjure-runtime/src/service/mod.rs +++ b/conjure-runtime/src/service/mod.rs @@ -118,13 +118,12 @@ pub struct ServiceFn(T); impl Service for ServiceFn where T: Fn(R) -> F, - F: Future>, + F: Future> + Send, { type Response = S; type Error = E; - type Future = F; - fn call(&self, req: R) -> Self::Future { + fn call(&self, req: R) -> impl Future> { (self.0)(req) } } diff --git a/conjure-runtime/src/service/node/limiter/ciad.rs b/conjure-runtime/src/service/node/limiter/ciad.rs index 5d199bf4..4619021f 100644 --- a/conjure-runtime/src/service/node/limiter/ciad.rs +++ b/conjure-runtime/src/service/node/limiter/ciad.rs @@ -15,15 +15,10 @@ use crate::service::map_error::RawClientError; use crate::service::node::limiter::deficit_semaphore::{self, DeficitSemaphore}; use crate::util::atomic_f64::AtomicF64; use conjure_error::Error; -use futures::ready; use http::{Response, StatusCode}; -use pin_project::pin_project; -use std::future::Future; use std::marker::PhantomData; -use std::pin::Pin; use 
std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use std::task::{Context, Poll}; const INITIAL_LIMIT: usize = 20; const BACKOFF_RATIO: f64 = 0.9; @@ -58,10 +53,14 @@ impl CiadConcurrencyLimiter { self.in_flight.load(Ordering::SeqCst) } - pub fn acquire(self: Arc) -> Acquire { - Acquire { - future: self.semaphore.clone().acquire(), + pub async fn acquire(self: Arc) -> Permit { + let permit = self.semaphore.clone().acquire().await; + + Permit { + in_flight_snapshot: self.in_flight.fetch_add(1, Ordering::SeqCst) + 1, + mode: Mode::Ignore, limiter: self, + _permit: permit, } } @@ -144,33 +143,6 @@ impl Behavior for EndpointLevel { } } -#[pin_project] -pub struct Acquire { - #[pin] - future: deficit_semaphore::Acquire, - limiter: Arc>, -} - -impl Future for Acquire -where - B: Behavior, -{ - type Output = Permit; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - let permit = ready!(this.future.poll(cx)); - - Poll::Ready(Permit { - limiter: this.limiter.clone(), - in_flight_snapshot: this.limiter.in_flight.fetch_add(1, Ordering::SeqCst) + 1, - mode: Mode::Ignore, - _permit: permit, - }) - } -} - #[derive(PartialEq, Eq, Debug)] pub enum Mode { Ignore, diff --git a/conjure-runtime/src/service/node/limiter/mod.rs b/conjure-runtime/src/service/node/limiter/mod.rs index 53319bfc..02f3a654 100644 --- a/conjure-runtime/src/service/node/limiter/mod.rs +++ b/conjure-runtime/src/service/node/limiter/mod.rs @@ -14,16 +14,10 @@ use crate::service::node::limiter::ciad::{CiadConcurrencyLimiter, EndpointLevel, HostLevel}; use crate::util::weak_reducing_gauge::Reduce; use conjure_error::Error; -use futures::future::{self, MaybeDone}; -use futures::ready; use http::{Method, Response}; use parking_lot::Mutex; -use pin_project::pin_project; use std::collections::HashMap; -use std::future::Future; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; mod ciad; mod deficit_semaphore; @@ -51,47 +45,23 @@ impl Limiter { &self.host } - pub fn acquire(&self, method: &Method, pattern: &'static str) -> Acquire { - Acquire { - endpoint: future::maybe_done( - self.endpoints - .lock() - .entry(Endpoint { - method: method.clone(), - pattern, - }) - .or_insert_with(CiadConcurrencyLimiter::new) - .clone() - .acquire(), - ), - host: self.host.clone().acquire(), - } - } -} - -#[pin_project] -pub struct Acquire { - #[pin] - endpoint: MaybeDone>, - #[pin] - host: ciad::Acquire, -} - -impl Future for Acquire { - type Output = Permit; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - // acquire the endpoint permit first to avoid contention issues in the balanced limiter where requests to a - // thottled endpoint could "lock out" requests to other endpoints if we take the host permit first. - ready!(this.endpoint.as_mut().poll(cx)); - let host = ready!(this.host.poll(cx)); - - Poll::Ready(Permit { - endpoint: this.endpoint.take_output().unwrap(), - host, - }) + pub async fn acquire(&self, method: &Method, pattern: &'static str) -> Permit { + let endpoint = self + .endpoints + .lock() + .entry(Endpoint { + method: method.clone(), + pattern, + }) + .or_insert_with(CiadConcurrencyLimiter::new) + .clone(); + // acquire the endpoint permit first to avoid contention issues in the balanced limiter + // where requests to a throttled endpoint could "lock out" requests to other endpoints if we + // take the host permit first. 
+ let endpoint = endpoint.acquire().await; + let host = self.host.clone().acquire().await; + + Permit { endpoint, host } } } diff --git a/conjure-runtime/src/service/node/metrics.rs b/conjure-runtime/src/service/node/metrics.rs index e0e5fa18..59b9c4a5 100644 --- a/conjure-runtime/src/service/node/metrics.rs +++ b/conjure-runtime/src/service/node/metrics.rs @@ -14,13 +14,8 @@ use crate::raw::Service; use crate::service::node::Node; use crate::service::Layer; -use futures::ready; use http::{Request, Response}; -use pin_project::pin_project; -use std::future::Future; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; use tokio::time::Instant; /// A layer which updates the host metrics for the node stored in the request's extensions map. @@ -40,53 +35,29 @@ pub struct NodeMetricsService { impl Service> for NodeMetricsService where - S: Service, Response = Response>, + S: Service, Response = Response> + Sync + Send, + B1: Sync + Send, { type Error = S::Error; type Response = S::Response; - type Future = NodeMetricsFuture; - fn call(&self, req: Request) -> Self::Future { + async fn call(&self, req: Request) -> Result { let node = req .extensions() .get::>() .expect("should have a Node extension") .clone(); - NodeMetricsFuture { - inner: self.inner.call(req), - start: Instant::now(), - node, - } - } -} - -#[pin_project] -pub struct NodeMetricsFuture { - #[pin] - inner: F, - start: Instant, - node: Arc, -} - -impl Future for NodeMetricsFuture -where - F: Future, E>>, -{ - type Output = Result, E>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - let result = ready!(this.inner.poll(cx)); + let start = Instant::now(); + let result = self.inner.call(req).await; - if let Some(host_metrics) = &this.node.host_metrics { + if let Some(host_metrics) = &node.host_metrics { match &result { - Ok(response) => host_metrics.update(response.status(), this.start.elapsed()), + Ok(response) => host_metrics.update(response.status(), start.elapsed()), Err(_) => host_metrics.update_io_error(), } } - Poll::Ready(result) + result } } diff --git a/conjure-runtime/src/service/node/mod.rs b/conjure-runtime/src/service/node/mod.rs index 6fd4086a..3caa08e3 100644 --- a/conjure-runtime/src/service/node/mod.rs +++ b/conjure-runtime/src/service/node/mod.rs @@ -20,13 +20,8 @@ use crate::util::weak_reducing_gauge::WeakReducingGauge; use crate::{Builder, ClientQos, HostMetrics}; use conjure_error::Error; use conjure_http::client::Endpoint; -use futures::ready; use http::{Request, Response}; -use pin_project::pin_project; -use std::future::Future; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; use url::Url; use witchcraft_metrics::MetricId; @@ -99,113 +94,74 @@ impl LimitedNode { node } - pub fn acquire(&self, request: &Request) -> Acquire { + pub async fn acquire(&self, request: &Request) -> AcquiredNode { let endpoint = request .extensions() .get::() .expect("Endpoint extension missing from request"); - Acquire { - acquire: self - .limiter - .as_ref() - .map(|l| l.acquire(request.method(), endpoint.path())), + let permit = match &self.limiter { + Some(limiter) => { + let permit = limiter.acquire(request.method(), endpoint.path()).await; + Some(permit) + } + None => None, + }; + + AcquiredNode { node: self.node.clone(), + permit, } } - pub fn wrap(&self, inner: &Arc, request: Request) -> Wrap + pub async fn wrap( + &self, + inner: &S, + request: Request, + ) -> Result where S: Service, Response = Response, Error = Error>, { - // don't 
create the span if client QoS is disabled + // Don't create the span if client QoS is disabled. if self.limiter.is_some() { let span = zipkin::next_span() .with_name("conjure-runtime: acquire-permit") .with_tag("node", &self.node.idx.to_string()); - - Wrap::Acquire { - future: span.detach().bind(self.acquire(&request)), - inner: inner.clone(), - request: Some(request), - } + let permit = span.detach().bind(self.acquire(&request)).await; + permit.wrap(inner, request).await } else { - Wrap::NodeFuture { - future: AcquiredNode { - node: self.node.clone(), - permit: None, - } - .wrap(&**inner, request), + AcquiredNode { + node: self.node.clone(), + permit: None, } + .wrap(inner, request) + .await } } } -#[pin_project] -pub struct Acquire { - node: Arc, - #[pin] - acquire: Option, -} - -impl Future for Acquire { - type Output = AcquiredNode; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let permit = match this.acquire.as_pin_mut().map(|a| a.poll(cx)) { - Some(Poll::Ready(permit)) => Some(permit), - Some(Poll::Pending) => return Poll::Pending, - None => None, - }; - - Poll::Ready(AcquiredNode { - node: this.node.clone(), - permit, - }) - } -} - pub struct AcquiredNode { node: Arc, permit: Option, } impl AcquiredNode { - pub fn wrap(self, inner: &S, mut req: Request) -> NodeFuture + pub async fn wrap( + self, + inner: &S, + mut req: Request, + ) -> Result where S: Service, Response = Response, Error = Error>, { - req.extensions_mut().insert(self.node.clone()); - - NodeFuture { - future: inner.call(req), - permit: self.permit, - } - } -} - -#[pin_project] -pub struct NodeFuture { - #[pin] - future: F, - permit: Option, -} - -impl Future for NodeFuture -where - F: Future, Error>>, -{ - type Output = F::Output; + req.extensions_mut().insert(self.node); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let response = ready!(this.future.poll(cx)); - if let Some(permit) = this.permit { + let response = inner.call(req).await; + if let Some(mut permit) = self.permit { permit.on_response(&response); } - Poll::Ready(response) + response } } @@ -225,48 +181,3 @@ impl Node { }) } } - -#[pin_project(project = WrapProject)] -pub enum Wrap -where - S: Service>, -{ - Acquire { - #[pin] - future: zipkin::Bind, - inner: Arc, - request: Option>, - }, - NodeFuture { - #[pin] - future: NodeFuture, - }, -} - -impl Future for Wrap -where - S: Service, Response = Response, Error = Error>, -{ - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let new_self = match self.as_mut().project() { - WrapProject::Acquire { - future, - inner, - request, - } => { - let acquired = ready!(future.poll(cx)); - let request = request.take().unwrap(); - - Wrap::NodeFuture { - future: acquired.wrap(&**inner, request), - } - } - WrapProject::NodeFuture { future } => return future.poll(cx), - }; - self.set(new_self); - } - } -} diff --git a/conjure-runtime/src/service/node/selector/balanced/mod.rs b/conjure-runtime/src/service/node/selector/balanced/mod.rs index b5006553..bca5d155 100644 --- a/conjure-runtime/src/service/node/selector/balanced/mod.rs +++ b/conjure-runtime/src/service/node/selector/balanced/mod.rs @@ -14,22 +14,18 @@ use crate::raw::Service; use crate::rng::ConjureRng; use crate::service::node::selector::balanced::reservoir::CoarseExponentialDecayReservoir; -use crate::service::node::{Acquire, LimitedNode, NodeFuture}; +use crate::service::node::{AcquiredNode, LimitedNode}; use 
crate::service::Layer; use crate::Builder; use conjure_error::Error; -use futures::ready; use http::{Request, Response}; -use pin_project::{pin_project, pinned_drop}; use rand::seq::SliceRandom; use std::future::Future; use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use witchcraft_log::debug; -use zipkin::{Detached, OpenSpan}; mod reservoir; @@ -45,18 +41,25 @@ pub struct TrackedNode { } impl TrackedNode { - fn new(node: LimitedNode) -> Arc { - Arc::new(TrackedNode { + fn new(node: LimitedNode) -> TrackedNode { + TrackedNode { node, in_flight: AtomicUsize::new(0), recent_failures: CoarseExponentialDecayReservoir::new(FAILURE_MEMORY), - }) + } } - fn acquire(self: Arc, request: &Request) -> AcquiringNode { + fn acquire<'a, 'b, 'c, B>( + &'a self, + request: &'b Request, + ) -> AcquiringNode<'a, impl Future + 'c> + where + 'a: 'c, + 'b: 'c, + { AcquiringNode { - acquire: Box::pin(self.node.acquire(request)), node: self, + acquire: Box::pin(self.node.acquire(request)), } } @@ -71,10 +74,10 @@ impl TrackedNode { } } -pub struct AcquiringNode { - node: Arc, +pub struct AcquiringNode<'a, F> { + node: &'a TrackedNode, // FIXME(#69) ideally we'd just pin the entire Vec - acquire: Pin>, + acquire: Pin>, } struct Score { @@ -100,12 +103,12 @@ impl Entropy for RandEntropy { } struct State { - nodes: Vec>, + nodes: Vec, entropy: T, } pub struct BalancedNodeSelectorLayer { - state: Arc>, + state: State, } impl BalancedNodeSelectorLayer { @@ -120,10 +123,10 @@ where { fn with_entropy(nodes: Vec, entropy: T) -> Self { BalancedNodeSelectorLayer { - state: Arc::new(State { + state: State { nodes: nodes.into_iter().map(TrackedNode::new).collect(), entropy, - }), + }, } } } @@ -134,34 +137,26 @@ impl Layer for BalancedNodeSelectorLayer { fn layer(self, inner: S) -> Self::Service { BalancedNodeSelectorService { state: self.state, - inner: Arc::new(inner), + inner, } } } pub struct BalancedNodeSelectorService { - state: Arc>, - inner: Arc, + state: State, + inner: S, } -impl Service> for BalancedNodeSelectorService +impl BalancedNodeSelectorService where T: Entropy, - S: Service, Response = Response, Error = Error>, { - type Response = S::Response; - type Error = S::Error; - type Future = BalancedNodeSelectorFuture; - - fn call(&self, req: Request) -> Self::Future { - let span = zipkin::next_span().with_name("conjure-runtime: balanced-node-selection"); - - // Dialogue skips nodes that have significantly worse scores than previous ones on each attempt, but to do that - // here we'd need a way to notify tasks on score changes. Rather than adding the complexity of implementing - // that, we just perform the filtering once when first constructing the future. This filtering is intended to - // bypass nodes that are e.g. entirely offline, so it should be fine to just do it once up front for a given - // request. - + async fn acquire<'a, B1>(&'a self, req: &Request) -> (&'a TrackedNode, AcquiredNode) { + // Dialogue skips nodes that have significantly worse scores than previous ones on each + // attempt, but to do that here we'd need a way to notify tasks on score changes. Rather + // than adding the complexity of implementing that, we just perform the filtering once. This + // filtering is intended to bypass nodes that are e.g. entirely offline, so it could be fine + // to just do it once up front for a given request. 
let mut snapshots = self .state .nodes @@ -174,7 +169,7 @@ where snapshots.sort_by_key(|s| s.score.score); let mut nodes = vec![]; - let mut give_up_threshold = usize::max_value(); + let mut give_up_threshold = usize::MAX; for snapshot in snapshots { if snapshot.score.score > give_up_threshold { debug!( @@ -183,7 +178,7 @@ where score: snapshot.score.score, giveUpScore: give_up_threshold, hostIndex: snapshot.node.node.node.idx, - }, + } ); continue; @@ -196,122 +191,92 @@ where .saturating_mul(UNHEALTHY_SCORE_MULTIPLIER); } - nodes.push(snapshot.node.clone().acquire(&req)); + nodes.push(snapshot.node.acquire(req)); } // shuffle so that we don't break ties the same way every request self.state.entropy.shuffle(&mut nodes); - BalancedNodeSelectorFuture::Acquire { - nodes, - service: self.inner.clone(), - request: Some(req), - span: span.detach(), - } + Acquire { nodes }.await } } -#[pin_project(project = Project, PinnedDrop)] -#[allow(clippy::large_enum_variant)] -pub enum BalancedNodeSelectorFuture -where - S: Service>, -{ - Acquire { - nodes: Vec, - service: Arc, - request: Option>, - span: OpenSpan, - }, - Wrap { - #[pin] - future: NodeFuture, - node: Arc, - }, +struct Acquire<'a, F> { + nodes: Vec>, } -#[pinned_drop] -impl PinnedDrop for BalancedNodeSelectorFuture +impl<'a, F> Future for Acquire<'a, F> where - S: Service>, + F: Future, { - fn drop(self: Pin<&mut Self>) { - if let Project::Wrap { node, .. } = self.project() { - node.in_flight.fetch_sub(1, Ordering::SeqCst); + type Output = (&'a TrackedNode, AcquiredNode); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // even though we've filtered above in acquire, we still want to poll the nodes in order of + // score to ensure we pick the best scoring node if multiple are available at the same time. + let mut snapshots = self + .nodes + .iter_mut() + .map(|node| Snapshot { + score: node.node.score(), + node, + }) + .collect::>(); + snapshots.sort_by_key(|n| n.score.score); + + for snapshot in snapshots { + if let Poll::Ready(acquired) = snapshot.node.acquire.as_mut().poll(cx) { + return Poll::Ready((snapshot.node.node, acquired)); + } } + + Poll::Pending } } -impl Future for BalancedNodeSelectorFuture +impl Service> for BalancedNodeSelectorService where - S: Service, Response = Response, Error = Error>, + T: Entropy + Sync + Send, + S: Service, Response = Response, Error = Error> + Sync + Send, + B1: Sync + Send, { - type Output = Result; + type Response = S::Response; + type Error = S::Error; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let new_self = match self.as_mut().project() { - Project::Acquire { - nodes, - service, - request, - span, - } => { - let mut _guard = Some(zipkin::set_current(span.context())); - - // even though we've filtered above in Service::call, we still want to poll the nodes in order of - // score to ensure we pick the best scoring node if multiple are available at the same time. 
- let mut snapshots = nodes - .iter_mut() - .map(|node| Snapshot { - score: node.node.score(), - node, - }) - .collect::>(); - snapshots.sort_by_key(|n| n.score.score); - - match snapshots - .into_iter() - .filter_map(|s| match s.node.acquire.as_mut().poll(cx) { - Poll::Ready(node) => { - // drop the context guard before we create the next future to avoid span nesting - _guard = None; - - s.node.node.in_flight.fetch_add(1, Ordering::SeqCst); - - Some(BalancedNodeSelectorFuture::Wrap { - future: node.wrap(&**service, request.take().unwrap()), - node: s.node.node.clone(), - }) - } - Poll::Pending => None, - }) - .next() - { - Some(f) => f, - None => return Poll::Pending, - } - } - Project::Wrap { future, node } => { - let result = ready!(future.poll(cx)); - - match &result { - // dialogue has a more complex set of conditionals, but this is what it ends up working out to - Ok(response) if response.status().is_server_error() => { - node.recent_failures.update(FAILURE_WEIGHT) - } - Ok(response) if response.status().is_client_error() => { - node.recent_failures.update(FAILURE_WEIGHT / 100.) - } - Ok(_) => {} - Err(_) => node.recent_failures.update(FAILURE_WEIGHT), - } + async fn call(&self, req: Request) -> Result { + let (node, tracked) = zipkin::next_span() + .with_name("conjure-runtime: balanced-node-selection") + .detach() + .bind(self.acquire(&req)) + .await; - return Poll::Ready(result); - } - }; - self.set(new_self); + node.in_flight.fetch_add(1, Ordering::SeqCst); + let _guard = InFlightGuard { node }; + + let result = tracked.wrap(&self.inner, req).await; + + match &result { + Ok(response) if response.status().is_server_error() => { + node.recent_failures.update(FAILURE_WEIGHT); + } + Ok(response) if response.status().is_client_error() => { + node.recent_failures.update(FAILURE_WEIGHT / 100.) + } + Ok(_) => {} + Err(_) => node.recent_failures.update(FAILURE_WEIGHT), } + + result + } +} + +struct InFlightGuard<'a> { + node: &'a TrackedNode, +} + +impl Drop for InFlightGuard<'_> { + fn drop(&mut self) { + self.node.in_flight.fetch_sub(1, Ordering::SeqCst); } } @@ -327,6 +292,7 @@ mod test { use http::StatusCode; use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; use tokio::time; fn request() -> Request<()> { @@ -375,9 +341,13 @@ mod test { } } })); + let service = Arc::new(service); // the first request will be to a, so wait until we know the request has hit the service. - tokio::spawn(service.call(request())); + tokio::spawn({ + let service = service.clone(); + async move { service.call(request()).await } + }); rx.next().await.unwrap(); for _ in 0..100 { diff --git a/conjure-runtime/src/service/node/selector/empty.rs b/conjure-runtime/src/service/node/selector/empty.rs index 6e4c5d7b..023376a4 100644 --- a/conjure-runtime/src/service/node/selector/empty.rs +++ b/conjure-runtime/src/service/node/selector/empty.rs @@ -15,11 +15,8 @@ use crate::raw::Service; use crate::service::Layer; use conjure_error::Error; use http::Request; -use std::future::Future; use std::marker::PhantomData; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; /// A node selector that always returns an error. 
pub struct EmptyNodeSelectorLayer { @@ -52,33 +49,14 @@ pub struct EmptyNodeSelectorService { impl Service> for EmptyNodeSelectorService where - S: Service, Error = Error>, + S: Service, Error = Error> + Sync, + B: Send, { type Response = S::Response; type Error = Error; - type Future = EmptyNodeSelectorFuture; - fn call(&self, _: Request) -> Self::Future { - EmptyNodeSelectorFuture { - service: self.service.clone(), - _p: PhantomData, - } - } -} - -pub struct EmptyNodeSelectorFuture { - service: Arc, - _p: PhantomData<(S, B)>, -} - -impl Future for EmptyNodeSelectorFuture -where - S: Service, Error = Error>, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { - Poll::Ready(Err(Error::internal_safe("service configured with no URIs") - .with_safe_param("service", &*self.service))) + async fn call(&self, _: Request) -> Result { + Err(Error::internal_safe("service configured with no URIs") + .with_safe_param("service", &*self.service)) } } diff --git a/conjure-runtime/src/service/node/selector/mod.rs b/conjure-runtime/src/service/node/selector/mod.rs index 65f68354..ad55423f 100644 --- a/conjure-runtime/src/service/node/selector/mod.rs +++ b/conjure-runtime/src/service/node/selector/mod.rs @@ -13,27 +13,18 @@ // limitations under the License. use crate::raw::Service; use crate::service::node::selector::balanced::{ - BalancedNodeSelectorFuture, BalancedNodeSelectorLayer, BalancedNodeSelectorService, -}; -use crate::service::node::selector::empty::{ - EmptyNodeSelectorFuture, EmptyNodeSelectorLayer, EmptyNodeSelectorService, + BalancedNodeSelectorLayer, BalancedNodeSelectorService, }; +use crate::service::node::selector::empty::{EmptyNodeSelectorLayer, EmptyNodeSelectorService}; use crate::service::node::selector::pin_until_error::{ - FixedNodes, PinUntilErrorNodeSelectorFuture, PinUntilErrorNodeSelectorLayer, - PinUntilErrorNodeSelectorService, ReshufflingNodes, -}; -use crate::service::node::selector::single::{ - SingleNodeSelectorFuture, SingleNodeSelectorLayer, SingleNodeSelectorService, + FixedNodes, PinUntilErrorNodeSelectorLayer, PinUntilErrorNodeSelectorService, ReshufflingNodes, }; +use crate::service::node::selector::single::{SingleNodeSelectorLayer, SingleNodeSelectorService}; use crate::service::node::LimitedNode; use crate::service::Layer; use crate::{Builder, NodeSelectionStrategy}; use conjure_error::Error; use http::{Request, Response}; -use pin_project::pin_project; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; mod balanced; mod empty; @@ -113,50 +104,19 @@ pub enum NodeSelectorService { impl Service> for NodeSelectorService where - S: Service, Response = Response, Error = Error>, + S: Service, Response = Response, Error = Error> + Sync + Send, + B1: Sync + Send, { type Response = S::Response; type Error = S::Error; - type Future = NodeSelectorFuture; - fn call(&self, req: Request) -> Self::Future { + async fn call(&self, req: Request) -> Result { match self { - NodeSelectorService::Empty(s) => NodeSelectorFuture::Empty(s.call(req)), - NodeSelectorService::Single(s) => NodeSelectorFuture::Single(s.call(req)), - NodeSelectorService::PinUntilError(s) => NodeSelectorFuture::PinUntilError(s.call(req)), - NodeSelectorService::PinUntilErrorWithoutReshuffle(s) => { - NodeSelectorFuture::PinUntilErrorWithoutReshuffle(s.call(req)) - } - NodeSelectorService::Balanced(s) => NodeSelectorFuture::Balanced(s.call(req)), - } - } -} - -#[pin_project(project = Projection)] -pub enum NodeSelectorFuture -where - S: Service>, 
-{ - Empty(#[pin] EmptyNodeSelectorFuture), - Single(#[pin] SingleNodeSelectorFuture), - PinUntilError(#[pin] PinUntilErrorNodeSelectorFuture), - PinUntilErrorWithoutReshuffle(#[pin] PinUntilErrorNodeSelectorFuture), - Balanced(#[pin] BalancedNodeSelectorFuture), -} - -impl Future for NodeSelectorFuture -where - S: Service, Response = Response, Error = Error>, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project() { - Projection::Empty(f) => f.poll(cx), - Projection::Single(f) => f.poll(cx), - Projection::PinUntilError(f) => f.poll(cx), - Projection::PinUntilErrorWithoutReshuffle(f) => f.poll(cx), - Projection::Balanced(f) => f.poll(cx), + NodeSelectorService::Empty(s) => s.call(req).await, + NodeSelectorService::Single(s) => s.call(req).await, + NodeSelectorService::PinUntilError(s) => s.call(req).await, + NodeSelectorService::PinUntilErrorWithoutReshuffle(s) => s.call(req).await, + NodeSelectorService::Balanced(s) => s.call(req).await, } } } diff --git a/conjure-runtime/src/service/node/selector/pin_until_error.rs b/conjure-runtime/src/service/node/selector/pin_until_error.rs index 865ca9f1..78be152a 100644 --- a/conjure-runtime/src/service/node/selector/pin_until_error.rs +++ b/conjure-runtime/src/service/node/selector/pin_until_error.rs @@ -13,22 +13,17 @@ // limitations under the License. use crate::raw::Service; use crate::rng::ConjureRng; -use crate::service::node::{LimitedNode, Wrap}; +use crate::service::node::LimitedNode; use crate::service::Layer; use crate::Builder; use arc_swap::ArcSwap; use conjure_error::Error; -use futures::ready; use http::{Request, Response}; -use pin_project::pin_project; use rand::distributions::uniform::SampleUniform; use rand::seq::SliceRandom; use rand::Rng; -use std::future::Future; -use std::pin::Pin; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; -use std::task::{Context, Poll}; use tokio::time::{Duration, Instant}; // we reshuffle nodes every 10 minutes on average, with 30 seconds of jitter to either side @@ -175,15 +170,10 @@ where } } -struct State { - current_pin: AtomicUsize, - nodes: T, -} - /// A node selector layer which pins to a host until a request either fails with a 5xx error or IO error, after which /// it rotates to the next. 
pub struct PinUntilErrorNodeSelectorLayer { - state: Arc>, + nodes: T, } impl PinUntilErrorNodeSelectorLayer @@ -191,12 +181,7 @@ where T: Nodes, { pub fn new(nodes: T) -> PinUntilErrorNodeSelectorLayer { - PinUntilErrorNodeSelectorLayer { - state: Arc::new(State { - current_pin: AtomicUsize::new(0), - nodes, - }), - } + PinUntilErrorNodeSelectorLayer { nodes } } } @@ -205,60 +190,33 @@ impl Layer for PinUntilErrorNodeSelectorLayer { fn layer(self, inner: S) -> Self::Service { PinUntilErrorNodeSelectorService { - state: self.state, - inner: Arc::new(inner), + nodes: self.nodes, + current_pin: AtomicUsize::new(0), + inner, } } } pub struct PinUntilErrorNodeSelectorService { - state: Arc>, - inner: Arc, + nodes: T, + current_pin: AtomicUsize, + inner: S, } impl Service> for PinUntilErrorNodeSelectorService where - T: Nodes, - S: Service, Response = Response, Error = Error>, + T: Nodes + Sync + Send, + S: Service, Response = Response, Error = Error> + Sync + Send, + B1: Sync + Send, { type Response = S::Response; type Error = S::Error; - type Future = PinUntilErrorNodeSelectorFuture; - fn call(&self, req: Request) -> Self::Future { - let pin = self.state.current_pin.load(Ordering::SeqCst); - let node = self.state.nodes.get(pin); + async fn call(&self, req: Request) -> Result { + let pin = self.current_pin.load(Ordering::SeqCst); + let node = self.nodes.get(pin); - PinUntilErrorNodeSelectorFuture { - future: node.wrap(&self.inner, req), - state: self.state.clone(), - pin, - } - } -} - -#[pin_project] -pub struct PinUntilErrorNodeSelectorFuture -where - S: Service>, -{ - #[pin] - future: Wrap, - state: Arc>, - pin: usize, -} - -impl Future for PinUntilErrorNodeSelectorFuture -where - T: Nodes, - S: Service, Response = Response, Error = Error>, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - let result = ready!(this.future.poll(cx)); + let result = node.wrap(&self.inner, req).await; let increment_host = match &result { Ok(response) => response.status().is_server_error(), @@ -266,16 +224,13 @@ where }; if increment_host { - let new_pin = (*this.pin + 1) % this.state.nodes.len(); - let _ = this.state.current_pin.compare_exchange( - *this.pin, - new_pin, - Ordering::SeqCst, - Ordering::SeqCst, - ); + let new_pin = (pin + 1) % self.nodes.len(); + let _ = + self.current_pin + .compare_exchange(pin, new_pin, Ordering::SeqCst, Ordering::SeqCst); } - Poll::Ready(result) + result } } diff --git a/conjure-runtime/src/service/node/selector/single.rs b/conjure-runtime/src/service/node/selector/single.rs index b16bc5bc..b334060a 100644 --- a/conjure-runtime/src/service/node/selector/single.rs +++ b/conjure-runtime/src/service/node/selector/single.rs @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. use crate::raw::Service; -use crate::service::node::{LimitedNode, Wrap}; +use crate::service::node::LimitedNode; use crate::service::Layer; use conjure_error::Error; use http::{Request, Response}; -use std::sync::Arc; /// A node selector layer which always selects a single node. 
/// @@ -37,28 +36,26 @@ impl Layer for SingleNodeSelectorLayer { fn layer(self, inner: S) -> SingleNodeSelectorService { SingleNodeSelectorService { - inner: Arc::new(inner), + inner, node: self.node, } } } pub struct SingleNodeSelectorService { - inner: Arc, + inner: S, node: LimitedNode, } impl Service> for SingleNodeSelectorService where - S: Service, Response = Response, Error = Error>, + S: Service, Response = Response, Error = Error> + Sync + Send, + B1: Sync + Send, { type Response = S::Response; type Error = S::Error; - type Future = SingleNodeSelectorFuture; - fn call(&self, req: Request) -> Self::Future { - self.node.wrap(&self.inner, req) + async fn call(&self, req: Request) -> Result { + self.node.wrap(&self.inner, req).await } } - -pub type SingleNodeSelectorFuture = Wrap; diff --git a/conjure-runtime/src/service/node/uri.rs b/conjure-runtime/src/service/node/uri.rs index 7d7dafd4..d71d8c76 100644 --- a/conjure-runtime/src/service/node/uri.rs +++ b/conjure-runtime/src/service/node/uri.rs @@ -15,6 +15,7 @@ use crate::raw::Service; use crate::service::node::Node; use crate::service::Layer; use http::Request; +use std::future::Future; use std::sync::Arc; /// A layer which converts an origin-form URI to an absolute-form by joining with a node's base URI stored in the @@ -39,9 +40,11 @@ where { type Error = S::Error; type Response = S::Response; - type Future = S::Future; - fn call(&self, mut req: Request) -> Self::Future { + fn call( + &self, + mut req: Request, + ) -> impl Future> { // we expect the request's URI to be in origin-form debug_assert!(req.uri().scheme().is_none()); debug_assert!(req.uri().authority().is_none()); diff --git a/conjure-runtime/src/service/proxy/mod.rs b/conjure-runtime/src/service/proxy/mod.rs index 3e52710b..c1482227 100644 --- a/conjure-runtime/src/service/proxy/mod.rs +++ b/conjure-runtime/src/service/proxy/mod.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::config; pub use crate::service::proxy::connector::{ProxyConnectorLayer, ProxyConnectorService}; -pub use crate::service::proxy::request::{ProxyLayer, ProxyService}; +pub use crate::service::proxy::request::ProxyLayer; use base64::display::Base64Display; use base64::engine::general_purpose::STANDARD; use conjure_error::Error; diff --git a/conjure-runtime/src/service/proxy/request.rs b/conjure-runtime/src/service/proxy/request.rs index f7572a6c..dcd4ab63 100644 --- a/conjure-runtime/src/service/proxy/request.rs +++ b/conjure-runtime/src/service/proxy/request.rs @@ -17,6 +17,7 @@ use crate::service::Layer; use http::header::PROXY_AUTHORIZATION; use http::uri::Scheme; use http::Request; +use std::future::Future; /// A layer which adjusts an HTTP request as necessary to respect proxy settings. 
/// @@ -58,9 +59,8 @@ where { type Response = S::Response; type Error = S::Error; - type Future = S::Future; - fn call(&self, mut req: Request) -> Self::Future { + fn call(&self, mut req: Request) -> impl Future> { match &self.config { ProxyConfig::Http(config) => { if req.uri().scheme() == Some(&Scheme::HTTP) { diff --git a/conjure-runtime/src/service/response_body.rs b/conjure-runtime/src/service/response_body.rs index 32eee0bf..624c1e47 100644 --- a/conjure-runtime/src/service/response_body.rs +++ b/conjure-runtime/src/service/response_body.rs @@ -17,11 +17,6 @@ use crate::{BaseBody, ResponseBody}; use bytes::Bytes; use http::Response; use http_body::Body; -use pin_project::pin_project; -use std::error; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; /// A layer which wraps the response body in the conjure-runtime public `ResponseBody` type. pub struct ResponseBodyLayer; @@ -40,41 +35,14 @@ pub struct ResponseBodyService { impl Service for ResponseBodyService where - S: Service>>, + S: Service>> + Sync + Send, + R: Send, B: Body, - B::Error: Into>, { type Response = Response>; - type Error = S::Error; - type Future = ResponseBodyFuture; - - fn call(&self, req: R) -> Self::Future { - ResponseBodyFuture { - future: self.inner.call(req), - } - } -} - -#[pin_project] -pub struct ResponseBodyFuture { - #[pin] - future: F, -} - -impl Future for ResponseBodyFuture -where - F: Future>, E>>, - B: Body, - B::Error: Into>, -{ - type Output = Result>, E>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project() - .future - .poll(cx) - .map_ok(|res| res.map(ResponseBody::new)) + async fn call(&self, req: R) -> Result { + self.inner.call(req).await.map(|r| r.map(ResponseBody::new)) } } diff --git a/conjure-runtime/src/service/retry.rs b/conjure-runtime/src/service/retry.rs index aff46bf5..c5a2c32e 100644 --- a/conjure-runtime/src/service/retry.rs +++ b/conjure-runtime/src/service/retry.rs @@ -22,13 +22,13 @@ use crate::{BodyWriter, Builder, Idempotency}; use async_trait::async_trait; use conjure_error::{Error, ErrorKind}; use conjure_http::client::{AsyncRequestBody, AsyncWriteBody, Endpoint}; -use futures::future::{self, BoxFuture}; +use futures::future; use http::request::Parts; use http::{Request, Response, StatusCode}; use rand::Rng; use std::error; +use std::future::Future; use std::pin::Pin; -use std::sync::Arc; use tokio::time::{self, Duration}; use witchcraft_log::info; @@ -47,7 +47,7 @@ pub struct RetryLayer { idempotency: Idempotency, max_num_retries: u32, backoff_slot_size: Duration, - rng: Arc, + rng: ConjureRng, } impl RetryLayer { @@ -60,7 +60,7 @@ impl RetryLayer { builder.get_max_num_retries() }, backoff_slot_size: builder.get_backoff_slot_size(), - rng: Arc::new(ConjureRng::new(builder)), + rng: ConjureRng::new(builder), } } } @@ -70,7 +70,7 @@ impl Layer for RetryLayer { fn layer(self, inner: S) -> Self::Service { RetryService { - inner: Arc::new(inner), + inner, idempotency: self.idempotency, max_num_retries: self.max_num_retries, backoff_slot_size: self.backoff_slot_size, @@ -80,25 +80,26 @@ impl Layer for RetryLayer { } pub struct RetryService { - inner: Arc, + inner: S, idempotency: Idempotency, max_num_retries: u32, backoff_slot_size: Duration, - rng: Arc, + rng: ConjureRng, } impl<'a, S, B> Service>> for RetryService where S: Service, Response = Response, Error = Error> + 'a + Sync + Send, S::Response: Send, - S::Future: Send, B: 'static, { type Response = S::Response; type Error = S::Error; - type Future = BoxFuture<'a, 
Result>; - fn call(&self, req: Request>) -> Self::Future { + fn call( + &self, + req: Request>, + ) -> impl Future> { let idempotent = match self.idempotency { Idempotency::Always => true, Idempotency::ByMethod => req.method().is_idempotent(), @@ -106,28 +107,22 @@ where }; let state = State { - inner: self.inner.clone(), + service: self, idempotent, - max_num_retries: self.max_num_retries, - backoff_slot_size: self.backoff_slot_size, - rng: self.rng.clone(), attempt: 0, }; - Box::pin(state.call(req)) + state.call(req) } } -struct State { - inner: Arc, +struct State<'a, S> { + service: &'a RetryService, idempotent: bool, - max_num_retries: u32, - backoff_slot_size: Duration, - rng: Arc, attempt: u32, } -impl State +impl State<'_, S> where S: Service, Response = Response, Error = Error>, { @@ -242,7 +237,7 @@ where let req = Request::from_parts(parts, body); let (body_result, response_result) = - future::join(writer.write(), self.inner.call(req)).await; + future::join(writer.write(), self.service.inner.call(req)).await; match (body_result, response_result) { (Ok(()), Ok(response)) => Ok(response), @@ -280,7 +275,7 @@ where retry_after: Option, ) -> Result<(), Error> { self.attempt += 1; - if self.attempt > self.max_num_retries { + if self.attempt > self.service.max_num_retries { info!("exceeded retry limits"); return Err(error); } @@ -297,12 +292,13 @@ where Some(backoff) => backoff, None => { let scale = 1 << (self.attempt - 1); - let max = self.backoff_slot_size * scale; + let max = self.service.backoff_slot_size * scale; // gen_range panics when min == max if max == Duration::from_secs(0) { Duration::from_secs(0) } else { - self.rng + self.service + .rng .with(|rng| rng.gen_range(Duration::from_secs(0)..max)) } } diff --git a/conjure-runtime/src/service/root_span.rs b/conjure-runtime/src/service/root_span.rs index c7fcc955..e122d963 100644 --- a/conjure-runtime/src/service/root_span.rs +++ b/conjure-runtime/src/service/root_span.rs @@ -15,6 +15,7 @@ use crate::service::{Layer, Service}; use crate::util::spans::{self, HttpSpanFuture}; use conjure_error::Error; use http::{Request, Response}; +use std::future::Future; /// A layer which manages the root level request span. pub struct RootSpanLayer; @@ -37,9 +38,8 @@ where { type Response = S::Response; type Error = S::Error; - type Future = HttpSpanFuture; - fn call(&self, req: Request) -> Self::Future { + fn call(&self, req: Request) -> impl Future> { let mut span = zipkin::next_span() .with_name("conjure-runtime: request") .detach(); diff --git a/conjure-runtime/src/service/trace_propagation.rs b/conjure-runtime/src/service/trace_propagation.rs index cb3e19ba..42a32da8 100644 --- a/conjure-runtime/src/service/trace_propagation.rs +++ b/conjure-runtime/src/service/trace_propagation.rs @@ -14,6 +14,7 @@ use crate::raw::Service; use crate::service::Layer; use http::Request; +use std::future::Future; /// A request layer which injects Zipkin tracing information into an outgoing request's headers. 
/// @@ -38,9 +39,11 @@ where { type Response = S::Response; type Error = S::Error; - type Future = S::Future; - fn call(&self, mut req: Request) -> Self::Future { + fn call( + &self, + mut req: Request, + ) -> impl Future> { if let Some(context) = zipkin::current() { http_zipkin::set_trace_context(context, req.headers_mut()); } diff --git a/conjure-runtime/src/service/user_agent.rs b/conjure-runtime/src/service/user_agent.rs index 7e6c8daf..6ec94630 100644 --- a/conjure-runtime/src/service/user_agent.rs +++ b/conjure-runtime/src/service/user_agent.rs @@ -18,6 +18,7 @@ use conjure_http::client::Endpoint; use http::header::USER_AGENT; use http::{HeaderValue, Request}; use std::convert::TryFrom; +use std::future::Future; /// A layer which injects a `User-Agent` header into requests. /// @@ -59,9 +60,11 @@ where { type Response = S::Response; type Error = S::Error; - type Future = S::Future; - fn call(&self, mut req: Request) -> Self::Future { + fn call( + &self, + mut req: Request, + ) -> impl Future> { let endpoint = req .extensions() .get::() diff --git a/conjure-runtime/src/service/wait_for_spans.rs b/conjure-runtime/src/service/wait_for_spans.rs index cb7fde7d..930fecdb 100644 --- a/conjure-runtime/src/service/wait_for_spans.rs +++ b/conjure-runtime/src/service/wait_for_spans.rs @@ -13,14 +13,12 @@ // limitations under the License. use crate::raw::Service; use crate::service::Layer; -use futures::ready; use http::{HeaderMap, Response}; use http_body::{Body, SizeHint}; use pin_project::pin_project; -use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use zipkin::{Bind, Detached, Kind, OpenSpan}; +use zipkin::{Detached, Kind, OpenSpan}; /// A layer which wraps the request future in a `conjure-runtime: wait-for-headers` span, and the response's body in a /// `conjure-runtime: wait-for-body` span. @@ -40,44 +38,26 @@ pub struct WaitForSpansService { impl Service for WaitForSpansService where - S: Service>, + S: Service> + Sync + Send, + R: Send, { type Response = Response>; type Error = S::Error; - type Future = WaitForSpansFuture; - fn call(&self, req: R) -> Self::Future { - WaitForSpansFuture { - future: zipkin::next_span() - .with_name("conjure-runtime: wait-for-headers") - .with_kind(Kind::Client) - .detach() - .bind(self.inner.call(req)), - } - } -} - -#[pin_project] -pub struct WaitForSpansFuture { - #[pin] - future: Bind, -} - -impl Future for WaitForSpansFuture -where - F: Future, E>>, -{ - type Output = Result>, E>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let response = ready!(self.project().future.poll(cx))?; + async fn call(&self, req: R) -> Result { + let response = zipkin::next_span() + .with_name("conjure-runtime: wait-for-headers") + .with_kind(Kind::Client) + .detach() + .bind(self.inner.call(req)) + .await?; - Poll::Ready(Ok(response.map(|body| WaitForSpansBody { + Ok(response.map(|body| WaitForSpansBody { body, _span: zipkin::next_span() .with_name("conjure-runtime: wait-for-body") .detach(), - }))) + })) } }
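
For downstream implementors, the breaking change to the raw `Service` trait means an implementation no longer declares an associated `Future` type. A minimal migration sketch, assuming a hypothetical passthrough wrapper (`PassthroughService`, its field names, and its exact bounds are illustrative and not part of this PR; only the trait path and the `async fn call` shape are taken from conjure-runtime/src/raw/mod.rs above):

use conjure_runtime::raw::Service;

/// Hypothetical middleware that simply forwards requests to an inner raw service.
pub struct PassthroughService<S> {
    inner: S,
}

impl<S, R> Service<R> for PassthroughService<S>
where
    S: Service<R> + Sync + Send,
    R: Send,
{
    type Response = S::Response;
    type Error = S::Error;

    // Previously this impl would also have needed `type Future = ...` backed by a
    // hand-written `poll`; with the new trait an `async fn` body is sufficient.
    async fn call(&self, req: R) -> Result<Self::Response, Self::Error> {
        // Delegate to the wrapped service and await its response.
        self.inner.call(req).await
    }
}

The `Sync + Send` bounds mirror the ones added throughout this diff: they are what lets the compiler prove the returned future is `Send`, as the trait's `impl Future<Output = ...> + Send` signature requires.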