diff --git a/Cargo.lock b/Cargo.lock index b09757a093..1d174fb477 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -703,6 +703,7 @@ dependencies = [ "linkerd-proxy-tcp", "linkerd-proxy-transport", "linkerd-reconnect", + "linkerd-retry", "linkerd-service-profiles", "linkerd-stack", "linkerd-stack-metrics", @@ -1019,13 +1020,11 @@ dependencies = [ "http-body", "hyper", "linkerd-error", - "linkerd-http-box", - "linkerd-stack", "linkerd-tracing", "parking_lot", "pin-project", + "thiserror", "tokio", - "tower", "tracing", ] @@ -1298,6 +1297,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "linkerd-retry" +version = "0.1.0" +dependencies = [ + "linkerd-error", + "linkerd-stack", + "pin-project", + "tower", + "tracing", +] + [[package]] name = "linkerd-service-profiles" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 448efad877..adae4608f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ members = [ "linkerd/proxy/tcp", "linkerd/proxy/transport", "linkerd/reconnect", + "linkerd/retry", "linkerd/service-profiles", "linkerd/signal", "linkerd/stack", diff --git a/linkerd/app/core/Cargo.toml b/linkerd/app/core/Cargo.toml index e15009b004..87733d3bfd 100644 --- a/linkerd/app/core/Cargo.toml +++ b/linkerd/app/core/Cargo.toml @@ -52,6 +52,7 @@ linkerd-proxy-tap = { path = "../../proxy/tap" } linkerd-proxy-tcp = { path = "../../proxy/tcp" } linkerd-proxy-transport = { path = "../../proxy/transport" } linkerd-reconnect = { path = "../../reconnect" } +linkerd-retry = { path = "../../retry" } linkerd-timeout = { path = "../../timeout" } linkerd-tracing = { path = "../../tracing" } linkerd-service-profiles = { path = "../../service-profiles" } diff --git a/linkerd/app/core/src/retry.rs b/linkerd/app/core/src/retry.rs index dfb8ecbe5e..2db752234b 100644 --- a/linkerd/app/core/src/retry.rs +++ b/linkerd/app/core/src/retry.rs @@ -4,48 +4,50 @@ use super::http_metrics::retries::Handle; use super::metrics::HttpRouteRetry; use crate::profiles; use futures::future; 
+use linkerd_error::Error; use linkerd_http_classify::{Classify, ClassifyEos, ClassifyResponse}; -use linkerd_stack::Param; +use linkerd_http_retry::ReplayBody; +use linkerd_retry as retry; +use linkerd_stack::{layer, Either, Param}; use std::sync::Arc; -use tower::retry::budget::Budget; -pub use linkerd_http_retry::*; - -pub fn layer(metrics: HttpRouteRetry) -> NewRetryLayer { - NewRetryLayer::new(NewRetry::new(metrics)) +pub fn layer( + metrics: HttpRouteRetry, +) -> impl layer::Layer> + Clone { + retry::NewRetry::<_, N>::layer(NewRetryPolicy::new(metrics)) } #[derive(Clone, Debug)] -pub struct NewRetry { +pub struct NewRetryPolicy { metrics: HttpRouteRetry, } #[derive(Clone, Debug)] -pub struct Retry { +pub struct RetryPolicy { metrics: Handle, - budget: Arc, + budget: Arc, response_classes: profiles::http::ResponseClasses, } /// Allow buffering requests up to 64 kb const MAX_BUFFERED_BYTES: usize = 64 * 1024; -// === impl NewRetry === +// === impl NewRetryPolicy === -impl NewRetry { +impl NewRetryPolicy { pub fn new(metrics: HttpRouteRetry) -> Self { Self { metrics } } } -impl NewPolicy for NewRetry { - type Policy = Retry; +impl retry::NewPolicy for NewRetryPolicy { + type Policy = RetryPolicy; fn new_policy(&self, route: &Route) -> Option { let retries = route.route.retries().cloned()?; let metrics = self.metrics.get_handle(route.param()); - Some(Retry { + Some(RetryPolicy { metrics, budget: retries.budget().clone(), response_classes: route.route.response_classes().clone(), @@ -55,7 +57,38 @@ impl NewPolicy for NewRetry { // === impl Retry === -impl Policy, http::Response, E> for Retry +impl RetryPolicy { + fn can_retry(&self, req: &http::Request) -> bool { + let content_length = |req: &http::Request<_>| { + req.headers() + .get(http::header::CONTENT_LENGTH) + .and_then(|value| value.to_str().ok()?.parse::().ok()) + }; + + // Requests without bodies can always be retried, as we will not need to + // buffer the body. 
If the request *does* have a body, retry it if and + // only if the request contains a `content-length` header and the + // content length is <= 64 kb. + let has_body = !req.body().is_end_stream(); + if has_body && content_length(&req).unwrap_or(usize::MAX) > MAX_BUFFERED_BYTES { + tracing::trace!( + req.has_body = has_body, + req.content_length = ?content_length(&req), + "not retryable", + ); + return false; + } + + tracing::trace!( + req.has_body = has_body, + req.content_length = ?content_length(&req), + "retryable", + ); + true + } +} + +impl retry::Policy, http::Response, E> for RetryPolicy where A: http_body::Body + Clone, { @@ -109,33 +142,20 @@ where } } -impl CanRetry for Retry { - fn can_retry(&self, req: &http::Request) -> bool { - let content_length = |req: &http::Request<_>| { - req.headers() - .get(http::header::CONTENT_LENGTH) - .and_then(|value| value.to_str().ok()?.parse::().ok()) - }; +impl retry::PrepareRequest, http::Response, E> for RetryPolicy +where + A: http_body::Body + Unpin, + A::Error: Into, +{ + type RetryRequest = http::Request>; - // Requests without bodies can always be retried, as we will not need to - // buffer the body. If the request *does* have a body, retry it if and - // only if the request contains a `content-length` header and the - // content length is >= 64 kb. 
- let has_body = !req.body().is_end_stream(); - if has_body && content_length(&req).unwrap_or(usize::MAX) > MAX_BUFFERED_BYTES { - tracing::trace!( - req.has_body = has_body, - req.content_length = ?content_length(&req), - "not retryable", - ); - return false; + fn prepare_request( + &self, + req: http::Request, + ) -> Either> { + if self.can_retry(&req) { + return Either::A(req.map(|body| ReplayBody::new(body, MAX_BUFFERED_BYTES))); } - - tracing::trace!( - req.has_body = has_body, - req.content_length = ?content_length(&req), - "retryable", - ); - true + Either::B(req) } } diff --git a/linkerd/app/inbound/fuzz/Cargo.lock b/linkerd/app/inbound/fuzz/Cargo.lock index 9e58c0c5ec..70006e5481 100644 --- a/linkerd/app/inbound/fuzz/Cargo.lock +++ b/linkerd/app/inbound/fuzz/Cargo.lock @@ -610,6 +610,7 @@ dependencies = [ "linkerd-proxy-tcp", "linkerd-proxy-transport", "linkerd-reconnect", + "linkerd-retry", "linkerd-service-profiles", "linkerd-stack", "linkerd-stack-metrics", @@ -863,11 +864,9 @@ dependencies = [ "http", "http-body", "linkerd-error", - "linkerd-http-box", - "linkerd-stack", "parking_lot", "pin-project", - "tower", + "thiserror", "tracing", ] @@ -1127,6 +1126,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "linkerd-retry" +version = "0.1.0" +dependencies = [ + "linkerd-error", + "linkerd-stack", + "pin-project", + "tower", + "tracing", +] + [[package]] name = "linkerd-service-profiles" version = "0.1.0" diff --git a/linkerd/app/outbound/src/http/logical.rs b/linkerd/app/outbound/src/http/logical.rs index 3381e2478f..2ec32225ad 100644 --- a/linkerd/app/outbound/src/http/logical.rs +++ b/linkerd/app/outbound/src/http/logical.rs @@ -120,6 +120,12 @@ impl Outbound { .http_route_actual .to_layer::(), ) + // Depending on whether or not the request can be retried, + // it may have one of two `Body` types. 
This layer unifies + // any `Body` type into `BoxBody` so that the rest of the + // stack doesn't have to implement `Service` for requests + // with both body types. + .push_on_response(http::BoxRequest::erased()) // Sets an optional retry policy. .push(retry::layer(rt.metrics.http_route_retry.clone())) // Sets an optional request timeout. @@ -134,7 +140,7 @@ impl Outbound { )) // Strips headers that may be set by this proxy and add an outbound // canonical-dst-header. The response body is boxed unify the profile - // stack's response type. withthat of to endpoint stack. + // stack's response type with that of the endpoint stack. .push(http::NewHeaderFromTarget::::layer()) .push_on_response(svc::layers().push(http::BoxResponse::layer())) .instrument(|l: &Logical| debug_span!("logical", dst = %l.logical_addr)) diff --git a/linkerd/http-box/src/erase_request.rs b/linkerd/http-box/src/erase_request.rs new file mode 100644 index 0000000000..1c4181ee44 --- /dev/null +++ b/linkerd/http-box/src/erase_request.rs @@ -0,0 +1,84 @@ +//! A middleware that boxes HTTP request bodies. + +use crate::BoxBody; +use linkerd_error::Error; +use linkerd_stack::{layer, Proxy}; +use std::task::{Context, Poll}; + +/// Boxes request bodies, erasing the original type. +/// +/// This is *very* similar to the [`BoxRequest`] middleware. However, that +/// middleware is generic over a specific body type that is erased. A given +/// instance of `BoxRequest` can only erase the type of one particular `Body` +/// type, while this middleware will erase bodies of *any* type. +/// +/// An astute reader may ask, why not simply replace `BoxRequest` with this +/// middleware, if it is a more flexible superset of the same behavior? The +/// answer is that in many cases, the use of this more flexible middleware +/// renders request body types uninferrable. 
If all `BoxRequest`s in the stack +/// are replaced with `EraseRequest`, suddenly a great deal of +/// `check_new_service` and `check_service` checks will require explicit +/// annotations for the pre-erasure body type. This is not great. +/// +/// Instead, this type is implemented separately and should be used only when a +/// stack must be able to implement `Service>` for *multiple +/// distinct values of `B`*. +#[derive(Debug)] +pub struct EraseRequest(S); + +impl EraseRequest { + pub fn new(inner: S) -> Self { + Self(inner) + } + + pub fn layer() -> impl layer::Layer + Clone + Copy { + layer::mk(Self::new) + } +} + +impl Clone for EraseRequest { + fn clone(&self) -> Self { + EraseRequest(self.0.clone()) + } +} + +impl tower::Service> for EraseRequest +where + B: http_body::Body + Send + 'static, + B::Data: Send + 'static, + B::Error: Into, + S: tower::Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + #[inline] + fn call(&mut self, req: http::Request) -> Self::Future { + self.0.call(req.map(BoxBody::new)) + } +} + +impl Proxy, S> for EraseRequest

+where + B: http_body::Body + Send + 'static, + B::Data: Send + 'static, + B::Error: Into, + S: tower::Service, + P: Proxy, S>, +{ + type Request = P::Request; + type Response = P::Response; + type Error = P::Error; + type Future = P::Future; + + #[inline] + fn proxy(&self, inner: &mut S, req: http::Request) -> Self::Future { + self.0.proxy(inner, req.map(BoxBody::new)) + } +} diff --git a/linkerd/http-box/src/lib.rs b/linkerd/http-box/src/lib.rs index 49a8ca7078..062e37bf3f 100644 --- a/linkerd/http-box/src/lib.rs +++ b/linkerd/http-box/src/lib.rs @@ -3,11 +3,13 @@ #![allow(clippy::inconsistent_struct_constructor)] mod body; +mod erase_request; mod request; mod response; pub use self::{ body::{BoxBody, Data}, + erase_request::EraseRequest, request::BoxRequest, response::BoxResponse, }; diff --git a/linkerd/http-box/src/request.rs b/linkerd/http-box/src/request.rs index 30b4340301..714f2a206f 100644 --- a/linkerd/http-box/src/request.rs +++ b/linkerd/http-box/src/request.rs @@ -1,6 +1,6 @@ //! A middleware that boxes HTTP request bodies. -use crate::BoxBody; +use crate::{erase_request::EraseRequest, BoxBody}; use linkerd_error::Error; use linkerd_stack::layer; use std::{ @@ -17,6 +17,13 @@ impl BoxRequest { } } +impl BoxRequest { + /// Constructs a boxing layer that erases the inner request type with [`EraseRequest`]. 
+ pub fn erased() -> impl layer::Layer> + Clone + Copy { + EraseRequest::layer() + } +} + impl Clone for BoxRequest { fn clone(&self) -> Self { BoxRequest(self.0.clone(), self.1) @@ -34,10 +41,12 @@ where type Error = S::Error; type Future = S::Future; + #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.poll_ready(cx) } + #[inline] fn call(&mut self, req: http::Request) -> Self::Future { self.0.call(req.map(BoxBody::new)) } diff --git a/linkerd/http-box/src/response.rs b/linkerd/http-box/src/response.rs index a3ad21f840..73e88e608a 100644 --- a/linkerd/http-box/src/response.rs +++ b/linkerd/http-box/src/response.rs @@ -26,10 +26,12 @@ where type Error = S::Error; type Future = future::MapOk Self::Response>; + #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.poll_ready(cx) } + #[inline] fn call(&mut self, req: Req) -> Self::Future { self.0.call(req).map_ok(|rsp| rsp.map(BoxBody::new)) } diff --git a/linkerd/http-retry/Cargo.toml b/linkerd/http-retry/Cargo.toml index 655fcf2fae..d635ab0d67 100644 --- a/linkerd/http-retry/Cargo.toml +++ b/linkerd/http-retry/Cargo.toml @@ -12,12 +12,10 @@ futures = { version = "0.3", default-features = false } http-body = "0.4" http = "0.2" linkerd-error = { path = "../error" } -linkerd-http-box = { path = "../http-box" } -linkerd-stack = { path = "../stack" } pin-project = "1" parking_lot = "0.11" -tower = { version = "0.4.7", default-features = false, features = ["retry", "util"] } tracing = "0.1.23" +thiserror = "1" [dev-dependencies] hyper = "0.14" diff --git a/linkerd/http-retry/src/lib.rs b/linkerd/http-retry/src/lib.rs index 74616ccceb..df16deead1 100644 --- a/linkerd/http-retry/src/lib.rs +++ b/linkerd/http-retry/src/lib.rs @@ -3,160 +3,939 @@ #![allow(clippy::inconsistent_struct_constructor)] #![allow(clippy::type_complexity)] +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use http::HeaderMap; +use http_body::Body; use linkerd_error::Error; -use linkerd_http_box::BoxBody; -use 
linkerd_stack::{NewService, Proxy, ProxyService}; -use pin_project::pin_project; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -pub use tower::retry::{budget::Budget, Policy}; -use tower::util::{Oneshot, ServiceExt}; -use tracing::trace; - -pub mod replay; -pub use self::replay::ReplayBody; - -/// A strategy for obtaining per-target retry polices. -pub trait NewPolicy { - type Policy; - - fn new_policy(&self, target: &T) -> Option; -} +use parking_lot::Mutex; +use std::{collections::VecDeque, io::IoSlice, pin::Pin, sync::Arc, task::Context, task::Poll}; +use thiserror::Error; -/// A layer that applies per-target retry polcies. +/// Wraps an HTTP body type and lazily buffers data as it is read from the inner +/// body. +/// +/// When this body is dropped, if a clone exists, any buffered data is shared +/// with its cloned. The first clone to be polled will take ownership over the +/// data until it is dropped. When *that* clone is dropped, the buffered data +/// --- including any new data read from the body by the clone, if the body has +/// not yet completed --- will be shared with any remaining clones. /// -/// Composes `NewService`s that produce a `Proxy`. -#[derive(Clone, Debug)] -pub struct NewRetryLayer

{ - new_policy: P, +/// The buffered data can then be used to retry the request if the original +/// request fails. +#[derive(Debug)] +pub struct ReplayBody { + /// Buffered state owned by this body if it is actively being polled. If + /// this body has been polled and no other body owned the state, this will + /// be `Some`. + state: Option>, + + /// Copy of the state shared across all clones. When the active clone is + /// dropped, it moves its state back into the shared state to be taken by the + /// next clone to be polled. + shared: Arc>, + + /// Should this clone replay the buffered body from the shared state before + /// polling the initial body? + replay_body: bool, + + /// Should this clone replay trailers from the shared state? + replay_trailers: bool, +} + +#[derive(Debug, Error)] +#[error("replay body discarded after reaching maximum buffered bytes limit")] +pub struct Capped; + +/// Data returned by `ReplayBody`'s `http_body::Body` implementation is either +/// `Bytes` returned by the initial body, or a list of all `Bytes` chunks +/// returned by the initial body (when replaying it). +#[derive(Debug)] +pub enum Data { + Initial(Bytes), + Replay(BufList), } -#[derive(Clone, Debug)] -pub struct NewRetry { - new_policy: P, - inner: N, +/// Body data composed of multiple `Bytes` chunks. +#[derive(Clone, Debug, Default)] +pub struct BufList { + bufs: VecDeque, } -#[derive(Clone, Debug)] -pub struct Retry { - policy: Option

, - inner: S, +#[derive(Debug)] +struct SharedState { + body: Mutex>>, + /// Did the initial body return `true` from `is_end_stream` before it was + /// ever polled? If so, always return `true`; the body is completely empty. + /// + /// We store this separately so that clones of a totally empty body can + /// always return `true` from `is_end_stream` even when they don't own the + /// shared state. + was_empty: bool, } -pub trait CanRetry { - /// Returns `true` if a request can be retried. - fn can_retry(&self, req: &http::Request) -> bool; +#[derive(Debug)] +struct BodyState { + buf: BufList, + trailers: Option, + rest: Option, + is_completed: bool, + + /// Maxiumum number of bytes to buffer. + max_bytes: usize, } -#[pin_project(project = ResponseFutureProj)] -pub enum ResponseFuture +// === impl ReplayBody === + +impl ReplayBody { + /// Wraps an initial `Body` in a `ReplayBody`. + /// + /// In order to prevent unbounded buffering, this takes a maximum number of + /// bytes to buffer as a second parameter. If more than than that number of + /// bytes would be buffered, the buffered data is discarded and any + /// subsequent clones of this body will fail. However, the *currently + /// active* clone of the body is allowed to continue without erroring. It + /// will simply stop buffering any additional data for retries. + pub fn new(body: B, max_bytes: usize) -> Self { + let was_empty = body.is_end_stream(); + Self { + state: Some(BodyState { + buf: Default::default(), + trailers: None, + rest: Some(body), + is_completed: false, + max_bytes: max_bytes + 1, + }), + shared: Arc::new(SharedState { + body: Mutex::new(None), + was_empty, + }), + // The initial `ReplayBody` has nothing to replay + replay_body: false, + replay_trailers: false, + } + } + + /// Mutably borrows the body state if this clone currently owns it, + /// or else tries to acquire it from the shared state. 
+ /// + /// # Panics + /// + /// This panics if another clone has currently acquired the state, based on + /// the assumption that a retry body will not be polled until the previous + /// request has been dropped. + fn acquire_state<'a>( + state: &'a mut Option>, + shared: &Mutex>>, + ) -> &'a mut BodyState { + state.get_or_insert_with(|| shared.lock().take().expect("missing body state")) + } +} + +impl Body for ReplayBody where - R: tower::retry::Policy>, P::Response, Error> + Clone, - P: Proxy, S> + Clone, - S: tower::Service + Clone, - S::Error: Into, + B: Body + Unpin, + B::Error: Into, { - Disabled(#[pin] P::Future), - Retry( - #[pin] - Oneshot< - tower::retry::Retry< - R, - tower::util::MapRequest< - ProxyService, - fn(http::Request>) -> http::Request, - >, - >, - http::Request>, - >, - ), -} + type Data = Data; + type Error = Error; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.get_mut(); + let state = Self::acquire_state(&mut this.state, &this.shared.body); + // Move these out to avoid mutable borrow issues in the `map` closure + // when polling the inner body. + tracing::trace!( + replay_body = this.replay_body, + buf.has_remaining = state.buf.has_remaining(), + body.is_completed = state.is_completed, + body.max_bytes_remaining = state.max_bytes, + "Replay::poll_data" + ); + + // If we haven't replayed the buffer yet, and its not empty, return the + // buffered data first. + if this.replay_body { + if state.buf.has_remaining() { + tracing::trace!("replaying body"); + // Don't return the buffered data again on the next poll. + this.replay_body = false; + return Poll::Ready(Some(Ok(Data::Replay(state.buf.clone())))); + } + + if state.is_capped() { + tracing::trace!("cannot replay buffered body, maximum buffer length reached"); + return Poll::Ready(Some(Err(Capped.into()))); + } + } + + // If the inner body has previously ended, don't poll it again. 
+ // + // NOTE(eliza): we would expect the inner body to just happily return + // `None` multiple times here, but `hyper::Body::channel` (which we use + // in the tests) will panic if it is polled after returning `None`, so + // we have to special-case this. :/ + if state.is_completed { + return Poll::Ready(None); + } -// === impl NewRetryLayer === + // If there's more data in the initial body, poll that... + if let Some(rest) = state.rest.as_mut() { + tracing::trace!("Polling initial body"); + let opt = futures::ready!(Pin::new(rest).poll_data(cx)); + + // If the body has ended, remember that so that future clones will + // not try polling it again --- some `Body` types will panic if they + // are polled after returning `None`. + if opt.is_none() { + tracing::trace!("Initial body completed"); + state.is_completed = true; + } + return Poll::Ready(opt.map(|ok| { + ok.map(|mut data| { + // If we have buffered the maximum number of bytes, allow + // *this* body to continue, but don't buffer any more. + let length = data.remaining(); + state.max_bytes = state.max_bytes.saturating_sub(length); + if state.is_capped() { + // If there's data in the buffer, discard it now, since + // we won't allow any clones to have a complete body. + if state.buf.has_remaining() { + tracing::debug!( + buf.size = state.buf.remaining(), + "buffered maximum capacity, discarding buffer" + ); + state.buf = Default::default(); + } + return Data::Initial(data.copy_to_bytes(length)); + } + + if state.is_capped() { + return Data::Initial(data.copy_to_bytes(length)); + } + + // Buffer and return the bytes + Data::Initial(state.buf.push_chunk(data)) + }) + .map_err(Into::into) + })); + } + + // Otherwise, guess we're done! 
+ Poll::Ready(None) + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + let this = self.get_mut(); + let state = Self::acquire_state(&mut this.state, &this.shared.body); + tracing::trace!( + replay_trailers = this.replay_trailers, + "Replay::poll_trailers" + ); + + if this.replay_trailers { + this.replay_trailers = false; + if let Some(ref trailers) = state.trailers { + tracing::trace!("Replaying trailers"); + return Poll::Ready(Ok(Some(trailers.clone()))); + } + } + + if let Some(rest) = state.rest.as_mut() { + // If the inner body has previously ended, don't poll it again. + if !rest.is_end_stream() { + let res = futures::ready!(Pin::new(rest).poll_trailers(cx)).map(|tlrs| { + if state.trailers.is_none() { + state.trailers = tlrs.clone(); + } + tlrs + }); + return Poll::Ready(res.map_err(Into::into)); + } + } + + Poll::Ready(Ok(None)) + } + + fn is_end_stream(&self) -> bool { + // if the initial body was EOS as soon as it was wrapped, then we are + // empty. + if self.shared.was_empty { + return true; + } + + let is_inner_eos = self + .state + .as_ref() + .and_then(|state| state.rest.as_ref().map(Body::is_end_stream)) + .unwrap_or(false); + + // if this body has data or trailers remaining to play back, it + // is not EOS + !self.replay_body && !self.replay_trailers + // if we have replayed everything, the initial body may + // still have data remaining, so ask it + && is_inner_eos + } + + fn size_hint(&self) -> http_body::SizeHint { + let mut hint = http_body::SizeHint::default(); + if let Some(ref state) = self.state { + let rem = state.buf.remaining() as u64; + + // Have we read the entire body? If so, the size is exactly the size + // of the buffer. + if state.is_completed { + return http_body::SizeHint::with_exact(rem); + } + + // Otherwise, the size is the size of the current buffer plus the + // size hint returned by the inner body. 
+ let (rest_lower, rest_upper) = state + .rest + .as_ref() + .map(|rest| { + let hint = rest.size_hint(); + (hint.lower(), hint.upper().unwrap_or(0)) + }) + .unwrap_or_default(); + hint.set_lower(rem + rest_lower); + hint.set_upper(rem + rest_upper); + } -impl

NewRetryLayer

{ - pub fn new(new_policy: P) -> Self { - Self { new_policy } + hint } } -impl tower::layer::Layer for NewRetryLayer

{ - type Service = NewRetry; +impl Clone for ReplayBody { + fn clone(&self) -> Self { + Self { + state: None, + shared: self.shared.clone(), + // The clone should try to replay from the shared state before + // reading any additional data from the initial body. + replay_body: true, + replay_trailers: true, + } + } +} - fn layer(&self, inner: N) -> Self::Service { - Self::Service { - inner, - new_policy: self.new_policy.clone(), +impl Drop for ReplayBody { + fn drop(&mut self) { + // If this clone owned the shared state, put it back.`s + if let Some(state) = self.state.take() { + *self.shared.body.lock() = Some(state); } } } -// === impl NewRetry === +// === impl Data === -impl NewService for NewRetry -where - N: NewService, - P: NewPolicy, -{ - type Service = Retry; +impl Buf for Data { + #[inline] + fn remaining(&self) -> usize { + match self { + Data::Initial(buf) => buf.remaining(), + Data::Replay(bufs) => bufs.remaining(), + } + } - fn new_service(&mut self, target: T) -> Self::Service { - // Determine if there is a retry policy for the given target. 
- let policy = self.new_policy.new_policy(&target); + #[inline] + fn chunk(&self) -> &[u8] { + match self { + Data::Initial(buf) => buf.chunk(), + Data::Replay(bufs) => bufs.chunk(), + } + } - let inner = self.inner.new_service(target); - Retry { policy, inner } + #[inline] + fn chunks_vectored<'iovs>(&'iovs self, iovs: &mut [IoSlice<'iovs>]) -> usize { + match self { + Data::Initial(buf) => buf.chunks_vectored(iovs), + Data::Replay(bufs) => bufs.chunks_vectored(iovs), + } + } + + #[inline] + fn advance(&mut self, amt: usize) { + match self { + Data::Initial(buf) => buf.advance(amt), + Data::Replay(bufs) => bufs.advance(amt), + } + } + + #[inline] + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + match self { + Data::Initial(buf) => buf.copy_to_bytes(len), + Data::Replay(bufs) => bufs.copy_to_bytes(len), + } } } -// === impl Retry === +// === impl BufList === -impl Proxy, S> for Retry -where - R: tower::retry::Policy>, P::Response, Error> + CanRetry + Clone, - P: Proxy, S> + Clone, - S: tower::Service + Clone, - S::Error: Into, - B: http_body::Body + Unpin + Send + 'static, - B::Data: Send, - B::Error: Into, -{ - type Request = P::Request; - type Response = P::Response; - type Error = Error; - type Future = ResponseFuture; - - fn proxy(&self, svc: &mut S, req: http::Request) -> Self::Future { - trace!(retryable = %self.policy.is_some()); - - if let Some(policy) = self.policy.as_ref() { - if policy.can_retry(&req) { - let inner = self.inner.clone().wrap_service(svc.clone()).map_request( - (|req: http::Request>| req.map(BoxBody::new)) as fn(_) -> _, - ); - let retry = tower::retry::Retry::new(policy.clone(), inner); - return ResponseFuture::Retry(retry.oneshot(req.map(ReplayBody::new))); +impl BufList { + fn push_chunk(&mut self, mut data: impl Buf) -> Bytes { + let len = data.remaining(); + // `data` is (almost) certainly a `Bytes`, so `copy_to_bytes` should + // internally be a cheap refcount bump almost all of the time. 
+ // But, if it isn't, this will copy it to a `Bytes` that we can + // now clone. + let bytes = data.copy_to_bytes(len); + // Buffer a clone of the bytes read on this poll. + self.bufs.push_back(bytes.clone()); + // Return the bytes + bytes + } +} + +impl Buf for BufList { + fn remaining(&self) -> usize { + self.bufs.iter().map(Buf::remaining).sum() + } + + fn chunk(&self) -> &[u8] { + self.bufs.front().map(Buf::chunk).unwrap_or(&[]) + } + + fn chunks_vectored<'iovs>(&'iovs self, iovs: &mut [IoSlice<'iovs>]) -> usize { + // Are there more than zero iovecs to write to? + if iovs.is_empty() { + return 0; + } + + // Loop over the buffers in the replay buffer list, and try to fill as + // many iovecs as we can from each buffer. + let mut filled = 0; + for buf in &self.bufs { + filled += buf.chunks_vectored(&mut iovs[filled..]); + if filled == iovs.len() { + return filled; + } + } + + filled + } + + fn advance(&mut self, mut amt: usize) { + while amt > 0 { + let rem = self.bufs[0].remaining(); + // If the amount to advance by is less than the first buffer in + // the buffer list, advance that buffer's cursor by `amt`, + // and we're done. + if rem > amt { + self.bufs[0].advance(amt); + return; + } + + // Otherwise, advance the first buffer to its end, and + // continue. + self.bufs[0].advance(rem); + amt -= rem; + + self.bufs.pop_front(); + } + } + + fn copy_to_bytes(&mut self, len: usize) -> Bytes { + // If the length of the requested `Bytes` is <= the length of the front + // buffer, we can just use its `copy_to_bytes` implementation (which is + // just a reference count bump). + match self.bufs.front_mut() { + Some(first) if len <= first.remaining() => { + let buf = first.copy_to_bytes(len); + // If we consumed the first buffer, also advance our "cursor" by + // popping it. 
+ if first.remaining() == 0 { + self.bufs.pop_front(); + } + + buf + } + _ => { + assert!(len <= self.remaining(), "`len` greater than remaining"); + let mut buf = BytesMut::with_capacity(len); + buf.put(self.take(len)); + buf.freeze() } } + } +} - ResponseFuture::Disabled(self.inner.proxy(svc, req.map(BoxBody::new))) +// === impl BodyState === + +impl BodyState { + #[inline] + fn is_capped(&self) -> bool { + self.max_bytes == 0 } } -impl Future for ResponseFuture -where - R: tower::retry::Policy>, P::Response, Error> + Clone, - P: Proxy, S> + Clone, - S: tower::Service + Clone, - S::Error: Into, -{ - type Output = Result; +#[cfg(test)] +mod tests { + use super::*; + use http::{HeaderMap, HeaderValue}; + + #[tokio::test] + async fn replays_one_chunk() { + let Test { + mut tx, + initial, + replay, + _trace, + } = Test::new(); + tx.send_data("hello world").await; + drop(tx); + + let initial = body_to_string(initial).await; + assert_eq!(initial, "hello world"); + + let replay = body_to_string(replay).await; + assert_eq!(replay, "hello world"); + } + + #[tokio::test] + async fn replays_several_chunks() { + let Test { + mut tx, + initial, + replay, + _trace, + } = Test::new(); + + tokio::spawn(async move { + tx.send_data("hello").await; + tx.send_data(" world").await; + tx.send_data(", have lots").await; + tx.send_data(" of fun!").await; + }); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project() { - ResponseFutureProj::Disabled(f) => f.poll(cx).map_err(Into::into), - ResponseFutureProj::Retry(f) => f.poll(cx).map_err(Into::into), + let initial = body_to_string(initial).await; + assert_eq!(initial, "hello world, have lots of fun!"); + + let replay = body_to_string(replay).await; + assert_eq!(replay, "hello world, have lots of fun!"); + } + + #[tokio::test] + async fn replays_trailers() { + let Test { + mut tx, + mut initial, + mut replay, + _trace, + } = Test::new(); + + let mut tlrs = HeaderMap::new(); + tlrs.insert("x-hello", 
HeaderValue::from_str("world").unwrap()); + tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); + + tx.send_data("hello world").await; + tx.send_trailers(tlrs.clone()).await; + drop(tx); + + while initial.data().await.is_some() { + // do nothing + } + let initial_tlrs = initial.trailers().await.expect("trailers should not error"); + assert_eq!(initial_tlrs.as_ref(), Some(&tlrs)); + + // drop the initial body to send the data to the replay + drop(initial); + + while replay.data().await.is_some() { + // do nothing } + let replay_tlrs = replay.trailers().await.expect("trailers should not error"); + assert_eq!(replay_tlrs.as_ref(), Some(&tlrs)); + } + + #[tokio::test] + async fn trailers_only() { + let Test { + mut tx, + mut initial, + mut replay, + _trace, + } = Test::new(); + + let mut tlrs = HeaderMap::new(); + tlrs.insert("x-hello", HeaderValue::from_str("world").unwrap()); + tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); + + tx.send_trailers(tlrs.clone()).await; + + drop(tx); + + assert!(dbg!(initial.data().await).is_none(), "no data in body"); + let initial_tlrs = initial.trailers().await.expect("trailers should not error"); + assert_eq!(initial_tlrs.as_ref(), Some(&tlrs)); + + // drop the initial body to send the data to the replay + drop(initial); + + assert!(dbg!(replay.data().await).is_none(), "no data in body"); + let replay_tlrs = replay.trailers().await.expect("trailers should not error"); + assert_eq!(replay_tlrs.as_ref(), Some(&tlrs)); + } + + #[tokio::test(flavor = "current_thread")] + async fn switches_with_body_remaining() { + // This simulates a case where the server returns an error _before_ the + // entire body has been read. 
+ let Test { + mut tx, + mut initial, + mut replay, + _trace, + } = Test::new(); + + tx.send_data("hello").await; + assert_eq!(chunk(&mut initial).await.unwrap(), "hello"); + + tx.send_data(" world").await; + assert_eq!(chunk(&mut initial).await.unwrap(), " world"); + + // drop the initial body to send the data to the replay + drop(initial); + tracing::info!("dropped initial body"); + + tokio::spawn(async move { + tx.send_data(", have lots of fun").await; + tx.send_trailers(HeaderMap::new()).await; + }); + + assert_eq!( + body_to_string(&mut replay).await, + "hello world, have lots of fun" + ); + } + + #[tokio::test(flavor = "current_thread")] + async fn multiple_replays() { + let Test { + mut tx, + mut initial, + mut replay, + _trace, + } = Test::new(); + + let mut tlrs = HeaderMap::new(); + tlrs.insert("x-hello", HeaderValue::from_str("world").unwrap()); + tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); + + let tlrs2 = tlrs.clone(); + tokio::spawn(async move { + tx.send_data("hello").await; + tx.send_data(" world").await; + tx.send_trailers(tlrs2).await; + }); + + assert_eq!(body_to_string(&mut initial).await, "hello world"); + + let initial_tlrs = initial.trailers().await.expect("trailers should not error"); + assert_eq!(initial_tlrs.as_ref(), Some(&tlrs)); + + // drop the initial body to send the data to the replay + drop(initial); + + let mut replay2 = replay.clone(); + assert_eq!(body_to_string(&mut replay).await, "hello world"); + + let replay_tlrs = replay.trailers().await.expect("trailers should not error"); + assert_eq!(replay_tlrs.as_ref(), Some(&tlrs)); + + // drop the first replay body to send the data to the second replay + drop(replay); + + assert_eq!(body_to_string(&mut replay2).await, "hello world"); + + let replay2_tlrs = replay2.trailers().await.expect("trailers should not error"); + assert_eq!(replay2_tlrs.as_ref(), Some(&tlrs)); + } + + #[tokio::test(flavor = "current_thread")] + async fn multiple_incomplete_replays() { + let Test { + mut tx, 
+ mut initial, + mut replay, + _trace, + } = Test::new(); + + let mut tlrs = HeaderMap::new(); + tlrs.insert("x-hello", HeaderValue::from_str("world").unwrap()); + tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); + + tx.send_data("hello").await; + assert_eq!(chunk(&mut initial).await.unwrap(), "hello"); + + // drop the initial body to send the data to the replay + drop(initial); + tracing::info!("dropped initial body"); + + let mut replay2 = replay.clone(); + + tx.send_data(" world").await; + assert_eq!(chunk(&mut replay).await.unwrap(), "hello"); + assert_eq!(chunk(&mut replay).await.unwrap(), " world"); + + // drop the replay body to send the data to the second replay + drop(replay); + tracing::info!("dropped first replay body"); + + let tlrs2 = tlrs.clone(); + tokio::spawn(async move { + tx.send_data(", have lots").await; + tx.send_data(" of fun!").await; + tx.send_trailers(tlrs2).await; + }); + + assert_eq!( + body_to_string(&mut replay2).await, + "hello world, have lots of fun!" 
+ ); + + let replay2_tlrs = replay2.trailers().await.expect("trailers should not error"); + assert_eq!(replay2_tlrs.as_ref(), Some(&tlrs)); + } + + #[tokio::test(flavor = "current_thread")] + async fn drop_clone_early() { + let Test { + mut tx, + mut initial, + mut replay, + _trace, + } = Test::new(); + + let mut tlrs = HeaderMap::new(); + tlrs.insert("x-hello", HeaderValue::from_str("world").unwrap()); + tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); + + let tlrs2 = tlrs.clone(); + tokio::spawn(async move { + tx.send_data("hello").await; + tx.send_data(" world").await; + tx.send_trailers(tlrs2).await; + }); + + assert_eq!(body_to_string(&mut initial).await, "hello world"); + + let initial_tlrs = initial.trailers().await.expect("trailers should not error"); + assert_eq!(initial_tlrs.as_ref(), Some(&tlrs)); + + // drop the initial body to send the data to the replay + drop(initial); + + // clone the body again and then drop it + let replay2 = replay.clone(); + drop(replay2); + + assert_eq!(body_to_string(&mut replay).await, "hello world"); + let replay_tlrs = replay.trailers().await.expect("trailers should not error"); + assert_eq!(replay_tlrs.as_ref(), Some(&tlrs)); + } + + // This test is specifically for behavior across clones, so the clippy lint + // is wrong here. + #[allow(clippy::redundant_clone)] + #[test] + fn empty_body_is_always_eos() { + // If the initial body was empty, every clone should always return + // `true` from `is_end_stream`. + let initial = ReplayBody::new(hyper::Body::empty(), 64 * 1024); + assert!(initial.is_end_stream()); + + let replay = initial.clone(); + assert!(replay.is_end_stream()); + + let replay2 = replay.clone(); + assert!(replay2.is_end_stream()); + } + + #[tokio::test(flavor = "current_thread")] + async fn eos_only_when_fully_replayed() { + // Test that each clone of a body is not EOS until the data has been + // fully replayed. 
+ let mut initial = ReplayBody::new(hyper::Body::from("hello world"), 64 * 1024); + let mut replay = initial.clone(); + + body_to_string(&mut initial).await; + assert!(!replay.is_end_stream()); + + initial.trailers().await.expect("trailers should not error"); + assert!(initial.is_end_stream()); + assert!(!replay.is_end_stream()); + + // drop the initial body to send the data to the replay + drop(initial); + + assert!(!replay.is_end_stream()); + + body_to_string(&mut replay).await; + assert!(!replay.is_end_stream()); + + replay.trailers().await.expect("trailers should not error"); + assert!(replay.is_end_stream()); + + // Even if we clone a body _after_ it has been driven to EOS, the clone + // must not be EOS. + let mut replay2 = replay.clone(); + assert!(!replay2.is_end_stream()); + + // drop the first replay body to send the data to the second replay + drop(replay); + + body_to_string(&mut replay2).await; + assert!(!replay2.is_end_stream()); + + replay2.trailers().await.expect("trailers should not error"); + assert!(replay2.is_end_stream()); + } + + #[tokio::test(flavor = "current_thread")] + async fn caps_buffer() { + // Test that, when the initial body is longer than the preconfigured + // cap, we allow the request to continue, but stop buffering. The + // initial body will complete, but the replay will immediately fail. 
+ let _trace = linkerd_tracing::test::with_default_filter("linkerd_http_retry=trace"); + + let (mut tx, body) = hyper::Body::channel(); + let mut initial = ReplayBody::new(body, 8); + let mut replay = initial.clone(); + + // Send enough data to reach the cap + tx.send_data(Bytes::from("aaaaaaaa")).await.unwrap(); + assert_eq!(chunk(&mut initial).await, Some("aaaaaaaa".to_string())); + + // Further chunks are still forwarded on the initial body + tx.send_data(Bytes::from("bbbbbbbb")).await.unwrap(); + assert_eq!(chunk(&mut initial).await, Some("bbbbbbbb".to_string())); + + drop(initial); + + // The request's replay should error, since we discarded the buffer when + // we hit the cap. + let err = replay + .data() + .await + .expect("replay must yield Some(Err(..)) when capped") + .expect_err("replay must error when cappped"); + assert!(err.is::()) + } + + #[tokio::test(flavor = "current_thread")] + async fn caps_across_replays() { + // Test that, when the initial body is longer than the preconfigured + // cap, we allow the request to continue, but stop buffering. + let _trace = linkerd_tracing::test::with_default_filter("linkerd_http_retry=debug"); + + let (mut tx, body) = hyper::Body::channel(); + let mut initial = ReplayBody::new(body, 8); + let mut replay = initial.clone(); + + // Send enough data to reach the cap + tx.send_data(Bytes::from("aaaaaaaa")).await.unwrap(); + assert_eq!(chunk(&mut initial).await, Some("aaaaaaaa".to_string())); + drop(initial); + + let mut replay2 = replay.clone(); + + // The replay will reach the cap, but it should still return data from + // the original body. + tx.send_data(Bytes::from("bbbbbbbb")).await.unwrap(); + assert_eq!(chunk(&mut replay).await, Some("aaaaaaaa".to_string())); + assert_eq!(chunk(&mut replay).await, Some("bbbbbbbb".to_string())); + drop(replay); + + // The second replay will fail, though, because the buffer was discarded. 
+ let err = replay2 + .data() + .await + .expect("replay must yield Some(Err(..)) when capped") + .expect_err("replay must error when cappped"); + assert!(err.is::()) + } + + struct Test { + tx: Tx, + initial: ReplayBody, + replay: ReplayBody, + _trace: tracing::subscriber::DefaultGuard, + } + + struct Tx(hyper::body::Sender); + + impl Test { + fn new() -> Self { + let (tx, body) = hyper::Body::channel(); + let initial = ReplayBody::new(body, 64 * 1024); + let replay = initial.clone(); + Self { + tx: Tx(tx), + initial, + replay, + _trace: linkerd_tracing::test::with_default_filter("linkerd_http_retry=debug"), + } + } + } + + impl Tx { + #[tracing::instrument(skip(self))] + async fn send_data(&mut self, data: impl Into + std::fmt::Debug) { + let data = data.into(); + tracing::trace!("sending data..."); + self.0.send_data(data).await.expect("rx is not dropped"); + tracing::info!("sent data"); + } + + #[tracing::instrument(skip(self))] + async fn send_trailers(&mut self, trailers: HeaderMap) { + tracing::trace!("sending trailers..."); + self.0 + .send_trailers(trailers) + .await + .expect("rx is not dropped"); + tracing::info!("sent trailers"); + } + } + + async fn chunk(body: &mut T) -> Option + where + T: http_body::Body + Unpin, + { + tracing::trace!("waiting for a body chunk..."); + let chunk = body + .data() + .await + .map(|res| res.map_err(|_| ()).unwrap()) + .map(string); + tracing::info!(?chunk); + chunk + } + + async fn body_to_string(mut body: T) -> String + where + T: http_body::Body + Unpin, + T::Error: std::fmt::Debug, + { + let mut s = String::new(); + while let Some(chunk) = chunk(&mut body).await { + s.push_str(&chunk[..]); + } + tracing::info!(body = ?s, "no more data"); + s + } + + fn string(mut data: impl Buf) -> String { + let bytes = data.copy_to_bytes(data.remaining()); + String::from_utf8(bytes.to_vec()).unwrap() } } diff --git a/linkerd/http-retry/src/replay.rs b/linkerd/http-retry/src/replay.rs deleted file mode 100644 index 
6e3362b6a3..0000000000 --- a/linkerd/http-retry/src/replay.rs +++ /dev/null @@ -1,807 +0,0 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use http::HeaderMap; -use http_body::Body; -use parking_lot::Mutex; -use std::{collections::VecDeque, io::IoSlice, pin::Pin, sync::Arc, task::Context, task::Poll}; - -/// Wraps an HTTP body type and lazily buffers data as it is read from the inner -/// body. -/// -/// When this body is dropped, if a clone exists, any buffered data is shared -/// with its cloned. The first clone to be polled will take ownership over the -/// data until it is dropped. When *that* clone is dropped, the buffered data -/// --- including any new data read from the body by the clone, if the body has -/// not yet completed --- will be shared with any remaining clones. -/// -/// The buffered data can then be used to retry the request if the original -/// request fails. -#[derive(Debug)] -pub struct ReplayBody { - /// Buffered state owned by this body if it is actively being polled. If - /// this body has been polled and no other body owned the state, this will - /// be `Some`. - state: Option>, - - /// Copy of the state shared across all clones. When the active clone is - /// dropped, it moves its state back into the shared state to be taken by the - /// next clone to be polled. - shared: Arc>, - - /// Should this clone replay the buffered body from the shared state before - /// polling the initial body? - replay_body: bool, - - /// Should this clone replay trailers from the shared state? - replay_trailers: bool, -} - -/// Data returned by `ReplayBody`'s `http_body::Body` implementation is either -/// `Bytes` returned by the initial body, or a list of all `Bytes` chunks -/// returned by the initial body (when replaying it). -#[derive(Debug)] -pub enum Data { - Initial(Bytes), - Replay(BufList), -} - -/// Body data composed of multiple `Bytes` chunks. 
-#[derive(Clone, Debug, Default)] -pub struct BufList { - bufs: VecDeque, -} - -#[derive(Debug)] -struct SharedState { - body: Mutex>>, - /// Did the initial body return `true` from `is_end_stream` before it was - /// ever polled? If so, always return `true`; the body is completely empty. - /// - /// We store this separately so that clones of a totally empty body can - /// always return `true` from `is_end_stream` even when they don't own the - /// shared state. - was_empty: bool, -} - -#[derive(Debug)] -struct BodyState { - buf: BufList, - trailers: Option, - rest: Option, - is_completed: bool, -} - -// === impl ReplayBody === - -impl ReplayBody { - /// Wraps an initial `Body` in a `ReplayBody`. - pub fn new(body: B) -> Self { - let was_empty = body.is_end_stream(); - Self { - state: Some(BodyState { - buf: Default::default(), - trailers: None, - rest: Some(body), - is_completed: false, - }), - shared: Arc::new(SharedState { - body: Mutex::new(None), - was_empty, - }), - // The initial `ReplayBody` has nothing to replay - replay_body: false, - replay_trailers: false, - } - } - - /// Mutably borrows the body state if this clone currently owns it, - /// or else tries to acquire it from the shared state. - /// - /// # Panics - /// - /// This panics if another clone has currently acquired the state, based on - /// the assumption that a retry body will not be polled until the previous - /// request has been dropped. 
- fn acquire_state<'a>( - state: &'a mut Option>, - shared: &Mutex>>, - ) -> &'a mut BodyState { - state.get_or_insert_with(|| shared.lock().take().expect("missing body state")) - } -} - -impl Body for ReplayBody { - type Data = Data; - type Error = B::Error; - - fn poll_data( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let this = self.get_mut(); - let state = Self::acquire_state(&mut this.state, &this.shared.body); - tracing::trace!( - replay_body = this.replay_body, - buf.has_remaining = state.buf.has_remaining(), - body.is_completed = state.is_completed, - "Replay::poll_data" - ); - - // If we haven't replayed the buffer yet, and its not empty, return the - // buffered data first. - if this.replay_body && state.buf.has_remaining() { - tracing::trace!("replaying body"); - // Don't return the buffered data again on the next poll. - this.replay_body = false; - return Poll::Ready(Some(Ok(Data::Replay(state.buf.clone())))); - } - - // If the inner body has previously ended, don't poll it again. - // - // NOTE(eliza): we would expect the inner body to just happily return - // `None` multiple times here, but `hyper::Body::channel` (which we use - // in the tests) will panic if it is polled after returning `None`, so - // we have to special-case this. :/ - if state.is_completed { - return Poll::Ready(None); - } - - // If there's more data in the initial body, poll that... - if let Some(rest) = state.rest.as_mut() { - tracing::trace!("Polling initial body"); - let opt = futures::ready!(Pin::new(rest).poll_data(cx)); - - // If the body has ended, remember that so that future clones will - // not try polling it again --- some `Body` types will panic if they - // are polled after returning `None`. - if opt.is_none() { - tracing::trace!("Initial body completed"); - state.is_completed = true; - } - return Poll::Ready( - opt.map(|ok| ok.map(|data| Data::Initial(state.buf.push_chunk(data)))), - ); - } - - // Otherwise, guess we're done! 
- Poll::Ready(None) - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let this = self.get_mut(); - let state = Self::acquire_state(&mut this.state, &this.shared.body); - tracing::trace!( - replay_trailers = this.replay_trailers, - "Replay::poll_trailers" - ); - - if this.replay_trailers { - this.replay_trailers = false; - if let Some(ref trailers) = state.trailers { - tracing::trace!("Replaying trailers"); - return Poll::Ready(Ok(Some(trailers.clone()))); - } - } - - if let Some(rest) = state.rest.as_mut() { - // If the inner body has previously ended, don't poll it again. - if !rest.is_end_stream() { - let res = futures::ready!(Pin::new(rest).poll_trailers(cx)).map(|tlrs| { - if state.trailers.is_none() { - state.trailers = tlrs.clone(); - } - tlrs - }); - return Poll::Ready(res); - } - } - - Poll::Ready(Ok(None)) - } - - fn is_end_stream(&self) -> bool { - // if the initial body was EOS as soon as it was wrapped, then we are - // empty. - if self.shared.was_empty { - return true; - } - - let is_inner_eos = self - .state - .as_ref() - .and_then(|state| state.rest.as_ref().map(Body::is_end_stream)) - .unwrap_or(false); - - // if this body has data or trailers remaining to play back, it - // is not EOS - !self.replay_body && !self.replay_trailers - // if we have replayed everything, the initial body may - // still have data remaining, so ask it - && is_inner_eos - } - - fn size_hint(&self) -> http_body::SizeHint { - let mut hint = http_body::SizeHint::default(); - if let Some(ref state) = self.state { - let rem = state.buf.remaining() as u64; - - // Have we read the entire body? If so, the size is exactly the size - // of the buffer. - if state.is_completed { - return http_body::SizeHint::with_exact(rem); - } - - // Otherwise, the size is the size of the current buffer plus the - // size hint returned by the inner body. 
- let (rest_lower, rest_upper) = state - .rest - .as_ref() - .map(|rest| { - let hint = rest.size_hint(); - (hint.lower(), hint.upper().unwrap_or(0)) - }) - .unwrap_or_default(); - hint.set_lower(rem + rest_lower); - hint.set_upper(rem + rest_upper); - } - - hint - } -} - -impl Clone for ReplayBody { - fn clone(&self) -> Self { - Self { - state: None, - shared: self.shared.clone(), - // The clone should try to replay from the shared state before - // reading any additional data from the initial body. - replay_body: true, - replay_trailers: true, - } - } -} - -impl Drop for ReplayBody { - fn drop(&mut self) { - // If this clone owned the shared state, put it back.`s - if let Some(state) = self.state.take() { - *self.shared.body.lock() = Some(state); - } - } -} - -// === impl Data === - -impl Buf for Data { - #[inline] - fn remaining(&self) -> usize { - match self { - Data::Initial(buf) => buf.remaining(), - Data::Replay(bufs) => bufs.remaining(), - } - } - - #[inline] - fn chunk(&self) -> &[u8] { - match self { - Data::Initial(buf) => buf.chunk(), - Data::Replay(bufs) => bufs.chunk(), - } - } - - #[inline] - fn chunks_vectored<'iovs>(&'iovs self, iovs: &mut [IoSlice<'iovs>]) -> usize { - match self { - Data::Initial(buf) => buf.chunks_vectored(iovs), - Data::Replay(bufs) => bufs.chunks_vectored(iovs), - } - } - - #[inline] - fn advance(&mut self, amt: usize) { - match self { - Data::Initial(buf) => buf.advance(amt), - Data::Replay(bufs) => bufs.advance(amt), - } - } - - #[inline] - fn copy_to_bytes(&mut self, len: usize) -> Bytes { - match self { - Data::Initial(buf) => buf.copy_to_bytes(len), - Data::Replay(bufs) => bufs.copy_to_bytes(len), - } - } -} - -// === impl BufList === - -impl BufList { - fn push_chunk(&mut self, mut data: impl Buf) -> Bytes { - let len = data.remaining(); - // `data` is (almost) certainly a `Bytes`, so `copy_to_bytes` should - // internally be a cheap refcount bump almost all of the time. 
- // But, if it isn't, this will copy it to a `Bytes` that we can - // now clone. - let bytes = data.copy_to_bytes(len); - // Buffer a clone of the bytes read on this poll. - self.bufs.push_back(bytes.clone()); - // Return the bytes - bytes - } -} - -impl Buf for BufList { - fn remaining(&self) -> usize { - self.bufs.iter().map(Buf::remaining).sum() - } - - fn chunk(&self) -> &[u8] { - self.bufs.front().map(Buf::chunk).unwrap_or(&[]) - } - - fn chunks_vectored<'iovs>(&'iovs self, iovs: &mut [IoSlice<'iovs>]) -> usize { - // Are there more than zero iovecs to write to? - if iovs.is_empty() { - return 0; - } - - // Loop over the buffers in the replay buffer list, and try to fill as - // many iovecs as we can from each buffer. - let mut filled = 0; - for buf in &self.bufs { - filled += buf.chunks_vectored(&mut iovs[filled..]); - if filled == iovs.len() { - return filled; - } - } - - filled - } - - fn advance(&mut self, mut amt: usize) { - while amt > 0 { - let rem = self.bufs[0].remaining(); - // If the amount to advance by is less than the first buffer in - // the buffer list, advance that buffer's cursor by `amt`, - // and we're done. - if rem > amt { - self.bufs[0].advance(amt); - return; - } - - // Otherwise, advance the first buffer to its end, and - // continue. - self.bufs[0].advance(rem); - amt -= rem; - - self.bufs.pop_front(); - } - } - - fn copy_to_bytes(&mut self, len: usize) -> Bytes { - // If the length of the requested `Bytes` is <= the length of the front - // buffer, we can just use its `copy_to_bytes` implementation (which is - // just a reference count bump). - match self.bufs.front_mut() { - Some(first) if len <= first.remaining() => { - let buf = first.copy_to_bytes(len); - // if we consumed the first buffer, also advance our "cursor" by - // popping it. 
- if first.remaining() == 0 { - self.bufs.pop_front(); - } - - buf - } - _ => { - assert!(len <= self.remaining(), "`len` greater than remaining"); - let mut buf = BytesMut::with_capacity(len); - buf.put(self.take(len)); - buf.freeze() - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use http::{HeaderMap, HeaderValue}; - - #[tokio::test] - async fn replays_one_chunk() { - let Test { - mut tx, - initial, - replay, - _trace, - } = Test::new(); - tx.send_data("hello world").await; - drop(tx); - - let initial = body_to_string(initial).await; - assert_eq!(initial, "hello world"); - - let replay = body_to_string(replay).await; - assert_eq!(replay, "hello world"); - } - - #[tokio::test] - async fn replays_several_chunks() { - let Test { - mut tx, - initial, - replay, - _trace, - } = Test::new(); - - tokio::spawn(async move { - tx.send_data("hello").await; - tx.send_data(" world").await; - tx.send_data(", have lots").await; - tx.send_data(" of fun!").await; - }); - - let initial = body_to_string(initial).await; - assert_eq!(initial, "hello world, have lots of fun!"); - - let replay = body_to_string(replay).await; - assert_eq!(replay, "hello world, have lots of fun!"); - } - - #[tokio::test] - async fn replays_trailers() { - let Test { - mut tx, - mut initial, - mut replay, - _trace, - } = Test::new(); - - let mut tlrs = HeaderMap::new(); - tlrs.insert("x-hello", HeaderValue::from_str("world").unwrap()); - tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); - - tx.send_data("hello world").await; - tx.send_trailers(tlrs.clone()).await; - drop(tx); - - while initial.data().await.is_some() { - // do nothing - } - let initial_tlrs = initial.trailers().await.expect("trailers should not error"); - assert_eq!(initial_tlrs.as_ref(), Some(&tlrs)); - - // drop the initial body to send the data to the replay - drop(initial); - - while replay.data().await.is_some() { - // do nothing - } - let replay_tlrs = replay.trailers().await.expect("trailers should not 
error"); - assert_eq!(replay_tlrs.as_ref(), Some(&tlrs)); - } - - #[tokio::test] - async fn trailers_only() { - let Test { - mut tx, - mut initial, - mut replay, - _trace, - } = Test::new(); - - let mut tlrs = HeaderMap::new(); - tlrs.insert("x-hello", HeaderValue::from_str("world").unwrap()); - tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); - - tx.send_trailers(tlrs.clone()).await; - - drop(tx); - - assert!(dbg!(initial.data().await).is_none(), "no data in body"); - let initial_tlrs = initial.trailers().await.expect("trailers should not error"); - assert_eq!(initial_tlrs.as_ref(), Some(&tlrs)); - - // drop the initial body to send the data to the replay - drop(initial); - - assert!(dbg!(replay.data().await).is_none(), "no data in body"); - let replay_tlrs = replay.trailers().await.expect("trailers should not error"); - assert_eq!(replay_tlrs.as_ref(), Some(&tlrs)); - } - - #[tokio::test(flavor = "current_thread")] - async fn switches_with_body_remaining() { - // This simulates a case where the server returns an error _before_ the - // entire body has been read. 
- let Test { - mut tx, - mut initial, - mut replay, - _trace, - } = Test::new(); - - tx.send_data("hello").await; - assert_eq!(chunk(&mut initial).await.unwrap(), "hello"); - - tx.send_data(" world").await; - assert_eq!(chunk(&mut initial).await.unwrap(), " world"); - - // drop the initial body to send the data to the replay - drop(initial); - tracing::info!("dropped initial body"); - - tokio::spawn(async move { - tx.send_data(", have lots of fun").await; - tx.send_trailers(HeaderMap::new()).await; - }); - - assert_eq!( - body_to_string(&mut replay).await, - "hello world, have lots of fun" - ); - } - - #[tokio::test(flavor = "current_thread")] - async fn multiple_replays() { - let Test { - mut tx, - mut initial, - mut replay, - _trace, - } = Test::new(); - - let mut tlrs = HeaderMap::new(); - tlrs.insert("x-hello", HeaderValue::from_str("world").unwrap()); - tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); - - let tlrs2 = tlrs.clone(); - tokio::spawn(async move { - tx.send_data("hello").await; - tx.send_data(" world").await; - tx.send_trailers(tlrs2).await; - }); - - assert_eq!(body_to_string(&mut initial).await, "hello world"); - - let initial_tlrs = initial.trailers().await.expect("trailers should not error"); - assert_eq!(initial_tlrs.as_ref(), Some(&tlrs)); - - // drop the initial body to send the data to the replay - drop(initial); - - let mut replay2 = replay.clone(); - assert_eq!(body_to_string(&mut replay).await, "hello world"); - - let replay_tlrs = replay.trailers().await.expect("trailers should not error"); - assert_eq!(replay_tlrs.as_ref(), Some(&tlrs)); - - // drop the initial body to send the data to the replay - drop(replay); - - assert_eq!(body_to_string(&mut replay2).await, "hello world"); - - let replay2_tlrs = replay2.trailers().await.expect("trailers should not error"); - assert_eq!(replay2_tlrs.as_ref(), Some(&tlrs)); - } - - #[tokio::test(flavor = "current_thread")] - async fn multiple_incomplete_replays() { - let Test { - mut tx, 
- mut initial, - mut replay, - _trace, - } = Test::new(); - - let mut tlrs = HeaderMap::new(); - tlrs.insert("x-hello", HeaderValue::from_str("world").unwrap()); - tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); - - tx.send_data("hello").await; - assert_eq!(chunk(&mut initial).await.unwrap(), "hello"); - - // drop the initial body to send the data to the replay - drop(initial); - tracing::info!("dropped initial body"); - - let mut replay2 = replay.clone(); - - tx.send_data(" world").await; - assert_eq!(chunk(&mut replay).await.unwrap(), "hello"); - assert_eq!(chunk(&mut replay).await.unwrap(), " world"); - - // drop the replay body to send the data to the second replay - drop(replay); - tracing::info!("dropped first replay body"); - - let tlrs2 = tlrs.clone(); - tokio::spawn(async move { - tx.send_data(", have lots").await; - tx.send_data(" of fun!").await; - tx.send_trailers(tlrs2).await; - }); - - assert_eq!( - body_to_string(&mut replay2).await, - "hello world, have lots of fun!" 
- ); - - let replay2_tlrs = replay2.trailers().await.expect("trailers should not error"); - assert_eq!(replay2_tlrs.as_ref(), Some(&tlrs)); - } - - #[tokio::test(flavor = "current_thread")] - async fn drop_clone_early() { - let Test { - mut tx, - mut initial, - mut replay, - _trace, - } = Test::new(); - - let mut tlrs = HeaderMap::new(); - tlrs.insert("x-hello", HeaderValue::from_str("world").unwrap()); - tlrs.insert("x-foo", HeaderValue::from_str("bar").unwrap()); - - let tlrs2 = tlrs.clone(); - tokio::spawn(async move { - tx.send_data("hello").await; - tx.send_data(" world").await; - tx.send_trailers(tlrs2).await; - }); - - assert_eq!(body_to_string(&mut initial).await, "hello world"); - - let initial_tlrs = initial.trailers().await.expect("trailers should not error"); - assert_eq!(initial_tlrs.as_ref(), Some(&tlrs)); - - // drop the initial body to send the data to the replay - drop(initial); - - // clone the body again and then drop it - let replay2 = replay.clone(); - drop(replay2); - - assert_eq!(body_to_string(&mut replay).await, "hello world"); - let replay_tlrs = replay.trailers().await.expect("trailers should not error"); - assert_eq!(replay_tlrs.as_ref(), Some(&tlrs)); - } - - // This test is specifically for behavior across clones, so the clippy lint - // is wrong here. - #[allow(clippy::redundant_clone)] - #[test] - fn empty_body_is_always_eos() { - // If the initial body was empty, every clone should always return - // `true` from `is_end_stream`. - let initial = ReplayBody::new(hyper::Body::empty()); - assert!(initial.is_end_stream()); - - let replay = initial.clone(); - assert!(replay.is_end_stream()); - - let replay2 = replay.clone(); - assert!(replay2.is_end_stream()); - } - - #[tokio::test(flavor = "current_thread")] - async fn eos_only_when_fully_replayed() { - // Test that each clone of a body is not EOS until the data has been - // fully replayed. 
- let mut initial = ReplayBody::new(hyper::Body::from("hello world")); - let mut replay = initial.clone(); - - body_to_string(&mut initial).await; - assert!(!replay.is_end_stream()); - - initial.trailers().await.expect("trailers should not error"); - assert!(initial.is_end_stream()); - assert!(!replay.is_end_stream()); - - // drop the initial body to send the data to the replay - drop(initial); - - assert!(!replay.is_end_stream()); - - body_to_string(&mut replay).await; - assert!(!replay.is_end_stream()); - - replay.trailers().await.expect("trailers should not error"); - assert!(replay.is_end_stream()); - - // Even if we clone a body _after_ it has been driven to EOS, the clone - // must not be EOS. - let mut replay2 = replay.clone(); - assert!(!replay2.is_end_stream()); - - // drop the initial body to send the data to the replay - drop(replay); - - body_to_string(&mut replay2).await; - assert!(!replay2.is_end_stream()); - - replay2.trailers().await.expect("trailers should not error"); - assert!(replay2.is_end_stream()); - } - - struct Test { - tx: Tx, - initial: ReplayBody, - replay: ReplayBody, - _trace: tracing::subscriber::DefaultGuard, - } - - struct Tx(hyper::body::Sender); - - impl Test { - fn new() -> Self { - let (tx, body) = hyper::Body::channel(); - let initial = ReplayBody::new(body); - let replay = initial.clone(); - Self { - tx: Tx(tx), - initial, - replay, - _trace: linkerd_tracing::test::with_default_filter("linkerd_http_retry=debug"), - } - } - } - - impl Tx { - #[tracing::instrument(skip(self))] - async fn send_data(&mut self, data: impl Into + std::fmt::Debug) { - let data = data.into(); - tracing::trace!("sending data..."); - self.0.send_data(data).await.expect("rx is not dropped"); - tracing::info!("sent data"); - } - - #[tracing::instrument(skip(self))] - async fn send_trailers(&mut self, trailers: HeaderMap) { - tracing::trace!("sending trailers..."); - self.0 - .send_trailers(trailers) - .await - .expect("rx is not dropped"); - 
tracing::info!("sent trailers"); - } - } - - async fn chunk(body: &mut T) -> Option - where - T: http_body::Body + Unpin, - { - tracing::trace!("waiting for a body chunk..."); - let chunk = body - .data() - .await - .map(|res| res.map_err(|_| ()).unwrap()) - .map(string); - tracing::info!(?chunk); - chunk - } - - async fn body_to_string(mut body: T) -> String - where - T: http_body::Body + Unpin, - T::Error: std::fmt::Debug, - { - let mut s = String::new(); - while let Some(chunk) = chunk(&mut body).await { - s.push_str(&chunk[..]); - } - tracing::info!(body = ?s, "no more data"); - s - } - - fn string(mut data: impl Buf) -> String { - let bytes = data.copy_to_bytes(data.remaining()); - String::from_utf8(bytes.to_vec()).unwrap() - } -} diff --git a/linkerd/retry/Cargo.toml b/linkerd/retry/Cargo.toml new file mode 100644 index 0000000000..faefaa840f --- /dev/null +++ b/linkerd/retry/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "linkerd-retry" +version = "0.1.0" +authors = ["Linkerd Developers "] +license = "Apache-2.0" +edition = "2018" +publish = false + +[dependencies] +linkerd-error = { path = "../error" } +linkerd-stack = { path = "../stack" } +pin-project = "1" +tower = { version = "0.4.7", default-features = false, features = ["retry"] } +tracing = "0.1.23" diff --git a/linkerd/retry/src/lib.rs b/linkerd/retry/src/lib.rs new file mode 100644 index 0000000000..54adb17098 --- /dev/null +++ b/linkerd/retry/src/lib.rs @@ -0,0 +1,146 @@ +#![deny(warnings, rust_2018_idioms)] +#![forbid(unsafe_code)] +#![allow(clippy::inconsistent_struct_constructor)] + +use linkerd_error::Error; +use linkerd_stack::{layer, Either, NewService, Oneshot, Proxy, ProxyService, ServiceExt}; +use pin_project::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +pub use tower::retry::{budget::Budget, Policy}; +use tracing::trace; + +/// A strategy for obtaining per-target retry policies. 
+pub trait NewPolicy { + type Policy; + + fn new_policy(&self, target: &T) -> Option; +} + +/// An extension to [`tower::retry::Policy`] that adds a method to prepare a +/// request to be retried, possibly changing its type. +pub trait PrepareRequest: tower::retry::Policy { + /// A request type that can be retried. + /// + /// This *may* be the same as the `Req` type parameter, but it can also be a + /// different type, if retries can only be attempted for a specific request type. + type RetryRequest; + + /// Prepare an initial request for a potential retry. + /// + /// If the request is retryable, this should return `Either::A`. Otherwise, + /// if this returns `Either::B`, the request will not be retried if it + /// fails. + /// + /// If retrying requires a specific request type other than the input type + /// to this policy, this function may transform the request into a request + /// of that type. + fn prepare_request(&self, req: Req) -> Either; +} + +/// Applies per-target retry policies. +#[derive(Clone, Debug)] +pub struct NewRetry { + new_policy: P, + inner: N, +} + +#[derive(Clone, Debug)] +pub struct Retry { + policy: Option

, + inner: S, +} + +#[pin_project(project = ResponseFutureProj)] +pub enum ResponseFuture +where + R: tower::retry::Policy + Clone, + P: Proxy + Clone, + S: tower::Service + Clone, + S::Error: Into, +{ + Disabled(#[pin] P::Future), + Retry(#[pin] Oneshot>, Req>), +} + +// === impl NewRetry === + +impl NewRetry { + pub fn layer(new_policy: P) -> impl layer::Layer + Clone { + layer::mk(move |inner| Self { + inner, + new_policy: new_policy.clone(), + }) + } +} + +impl NewService for NewRetry +where + N: NewService, + P: NewPolicy, +{ + type Service = Retry; + + fn new_service(&mut self, target: T) -> Self::Service { + // Determine if there is a retry policy for the given target. + let policy = self.new_policy.new_policy(&target); + + let inner = self.inner.new_service(target); + Retry { policy, inner } + } +} + +// === impl Retry === + +impl Proxy for Retry +where + R: PrepareRequest + Clone, + P: Proxy + + Proxy + + Clone, + S: tower::Service + Clone, + S::Error: Into, +{ + type Request = PReq; + type Response = PRsp; + type Error = Error; + type Future = ResponseFuture; + + fn proxy(&self, svc: &mut S, req: Req) -> Self::Future { + trace!(retryable = %self.policy.is_some()); + + if let Some(policy) = self.policy.as_ref() { + return match policy.prepare_request(req) { + Either::A(retry_req) => { + let inner = + Proxy::::wrap_service(self.inner.clone(), svc.clone()); + let retry = tower::retry::Retry::new(policy.clone(), inner); + ResponseFuture::Retry(retry.oneshot(retry_req)) + } + Either::B(req) => ResponseFuture::Disabled(self.inner.proxy(svc, req)), + }; + } + + ResponseFuture::Disabled(self.inner.proxy(svc, req)) + } +} + +impl Future for ResponseFuture +where + R: tower::retry::Policy + Clone, + P: Proxy + Clone, + S: tower::Service + Clone, + S::Error: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + ResponseFutureProj::Disabled(f) => f.poll(cx).map_err(Into::into), + 
ResponseFutureProj::Retry(f) => f.poll(cx).map_err(Into::into), + } + } +} diff --git a/linkerd/stack/src/lib.rs b/linkerd/stack/src/lib.rs index 04c234d98c..3c3aa30322 100644 --- a/linkerd/stack/src/lib.rs +++ b/linkerd/stack/src/lib.rs @@ -41,7 +41,7 @@ pub use self::{ unwrap_or::UnwrapOr, }; pub use tower::{ - util::{future_service, FutureService, MapErr, MapErrLayer, ServiceExt}, + util::{future_service, FutureService, MapErr, MapErrLayer, Oneshot, ServiceExt}, Service, };