From e6416ae52973c39be345ae8c004e7d64f8900c7a Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Wed, 15 Feb 2023 14:39:48 -0500 Subject: [PATCH] Add Connection Poisoning to aws-smithy-client --- .../rustsdk/AwsFluentClientDecorator.kt | 1 + aws/sdk/build.gradle.kts | 3 + aws/sdk/integration-tests/Cargo.toml | 3 + .../integration-tests/s3/tests/reconnects.rs | 62 ++++ rust-runtime/Cargo.toml | 4 + rust-runtime/aws-smithy-client/Cargo.toml | 8 +- rust-runtime/aws-smithy-client/src/builder.rs | 45 +++ rust-runtime/aws-smithy-client/src/erase.rs | 2 + .../aws-smithy-client/src/hyper_ext.rs | 63 +++- rust-runtime/aws-smithy-client/src/lib.rs | 10 +- rust-runtime/aws-smithy-client/src/poison.rs | 143 ++++++++ rust-runtime/aws-smithy-client/src/retry.rs | 23 +- .../aws-smithy-client/src/test_connection.rs | 321 ++++++++++++++++++ rust-runtime/aws-smithy-client/src/timeout.rs | 2 +- .../aws-smithy-client/tests/e2e_test.rs | 77 +---- .../tests/reconnect_on_transient_error.rs | 230 +++++++++++++ .../tests/test_operation/mod.rs | 84 +++++ .../aws-smithy-http-tower/src/dispatch.rs | 9 +- .../aws-smithy-http/src/connection.rs | 96 ++++++ rust-runtime/aws-smithy-http/src/lib.rs | 1 + rust-runtime/aws-smithy-http/src/result.rs | 50 +++ rust-runtime/aws-smithy-types/src/retry.rs | 37 ++ 22 files changed, 1177 insertions(+), 97 deletions(-) create mode 100644 aws/sdk/integration-tests/s3/tests/reconnects.rs create mode 100644 rust-runtime/aws-smithy-client/src/poison.rs create mode 100644 rust-runtime/aws-smithy-client/tests/reconnect_on_transient_error.rs create mode 100644 rust-runtime/aws-smithy-client/tests/test_operation/mod.rs create mode 100644 rust-runtime/aws-smithy-http/src/connection.rs diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt index e240d642700..180d71e0134 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt @@ -210,6 +210,7 @@ private class AwsFluentClientExtensions(types: Types) { }; let mut builder = builder .middleware(#{DynMiddleware}::new(#{Middleware}::new())) + .reconnect_mode(retry_config.reconnect_mode()) .retry_config(retry_config.into()) .operation_timeout_config(timeout_config.into()); builder.set_sleep_impl(sleep_impl); diff --git a/aws/sdk/build.gradle.kts b/aws/sdk/build.gradle.kts index 31478777048..199b65ba874 100644 --- a/aws/sdk/build.gradle.kts +++ b/aws/sdk/build.gradle.kts @@ -302,6 +302,9 @@ fun generateCargoWorkspace(services: AwsServices): String { |] |members = [${"\n"}${services.allModules.joinToString(",\n") { "| \"$it\"" }} |] + | + |[patch.crates-io] + |hyper = { git = 'https://github.com/hyperium/hyper', branch = "0.14.x" } """.trimMargin() } diff --git a/aws/sdk/integration-tests/Cargo.toml b/aws/sdk/integration-tests/Cargo.toml index a36345cda0f..bc197aeaf46 100644 --- a/aws/sdk/integration-tests/Cargo.toml +++ b/aws/sdk/integration-tests/Cargo.toml @@ -17,3 +17,6 @@ members = [ "transcribestreaming", "using-native-tls-instead-of-rustls", ] + +[patch.crates-io] +hyper = { git = 'https://github.com/hyperium/hyper', branch = "0.14.x" } diff --git a/aws/sdk/integration-tests/s3/tests/reconnects.rs b/aws/sdk/integration-tests/s3/tests/reconnects.rs new file mode 100644 index 00000000000..6a998c50c9e --- /dev/null +++ b/aws/sdk/integration-tests/s3/tests/reconnects.rs @@ -0,0 +1,62 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_credential_types::provider::SharedCredentialsProvider; +use aws_credential_types::Credentials; +use aws_smithy_async::rt::sleep::TokioSleep; +use aws_smithy_client::test_connection::wire_mock::{ + check_matches, ReplayedEvent, WireLevelTestConnection, +}; +use aws_smithy_client::{ev, match_events}; +use aws_smithy_types::retry::RetryConfig; +use aws_types::region::Region; +use aws_types::SdkConfig; +use std::sync::Arc; +use tracing_subscriber::EnvFilter; + +#[tokio::test] +async fn reconnect_on_503() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::new("trace")) + .init(); + let mock = WireLevelTestConnection::spinup(vec![ + ReplayedEvent::status(503), + ReplayedEvent::status(503), + ReplayedEvent::with_body("here-is-your-object"), + ]) + .await; + + let sdk_config = SdkConfig::builder() + .region(Region::from_static("us-east-2")) + .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) + .sleep_impl(Arc::new(TokioSleep::new())) + .endpoint_url(mock.endpoint_url()) + .http_connector(mock.http_connector()) + .retry_config(RetryConfig::standard()) + .build(); + let client = aws_sdk_s3::Client::new(&sdk_config); + let resp = client + .get_object() + .bucket("bucket") + .key("key") + .send() + .await + .expect("succeeds after retries"); + assert_eq!( + resp.body.collect().await.unwrap().to_vec(), + b"here-is-your-object" + ); + match_events!( + ev!(dns), + ev!(connect), + ev!(http(503)), + ev!(dns), + ev!(connect), + ev!(http(503)), + ev!(dns), + ev!(connect), + ev!(http(200)) + )(&mock.events()); +} diff --git a/rust-runtime/Cargo.toml b/rust-runtime/Cargo.toml index 6a53b080e99..8d0b2f81570 100644 --- a/rust-runtime/Cargo.toml +++ b/rust-runtime/Cargo.toml @@ -1,5 +1,6 @@ [workspace] + members = [ "inlineable", "aws-smithy-async", @@ -18,3 +19,6 @@ members = [ "aws-smithy-http-server", "aws-smithy-http-server-python", ] + +[patch.crates-io] +hyper = { git = 'https://github.com/hyperium/hyper', branch = "0.14.x" } diff --git a/rust-runtime/aws-smithy-client/Cargo.toml b/rust-runtime/aws-smithy-client/Cargo.toml index d8f8041bbdb..662f272b0e9 100644 --- a/rust-runtime/aws-smithy-client/Cargo.toml +++ b/rust-runtime/aws-smithy-client/Cargo.toml @@ -9,12 +9,13 @@ repository = "https://github.com/awslabs/smithy-rs" [features] rt-tokio = ["aws-smithy-async/rt-tokio"] -test-util = ["aws-smithy-protocol-test", "serde/derive", "rustls"] +test-util = ["aws-smithy-protocol-test", "serde/derive", "rustls", "hyper/server", "hyper/h2"] native-tls = ["client-hyper", "hyper-tls", "rt-tokio"] rustls = ["client-hyper", "hyper-rustls", "rt-tokio", "lazy_static"] client-hyper = ["hyper"] hyper-webpki-doctest-only = ["hyper-rustls/webpki-roots"] + [dependencies] aws-smithy-async = { path = "../aws-smithy-async" } aws-smithy-http = { path = "../aws-smithy-http" } @@ -25,7 +26,7 @@ bytes = "1" fastrand = "1.4.0" http = "0.2.3" http-body = "0.4.4" -hyper = { version = "0.14.12", features = ["client", "http2", "http1", "tcp"], optional = true } +hyper = { version = "0.14.24", features = ["client", "http2", "http1", "tcp"], optional = true } # cargo does not support optional test dependencies, so to completely disable rustls when # the native-tls feature is enabled, we need to add the webpki-roots feature here. # https://github.com/rust-lang/cargo/issues/1596 @@ -44,6 +45,9 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1.8.4", features = ["full", "test-util"] } tower-test = "0.4.0" +tracing-subscriber = "0.3.16" +tracing-test = "0.2.4" + [package.metadata.docs.rs] all-features = true diff --git a/rust-runtime/aws-smithy-client/src/builder.rs b/rust-runtime/aws-smithy-client/src/builder.rs index 1fe4ba12eb1..7fb86169bc5 100644 --- a/rust-runtime/aws-smithy-client/src/builder.rs +++ b/rust-runtime/aws-smithy-client/src/builder.rs @@ -7,6 +7,7 @@ use crate::{bounds, erase, retry, Client}; use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep}; use aws_smithy_http::body::SdkBody; use aws_smithy_http::result::ConnectorError; +use aws_smithy_types::retry::ReconnectMode; use aws_smithy_types::timeout::{OperationTimeoutConfig, TimeoutConfig}; use std::sync::Arc; @@ -37,6 +38,12 @@ pub struct Builder { retry_policy: MaybeRequiresSleep, operation_timeout_config: Option, sleep_impl: Option>, + reconnect_mode: Option, +} + +/// transitional default: disable this behavior by default +fn default_reconnect_mode() -> ReconnectMode { + ReconnectMode::NoReconnect } impl Default for Builder @@ -55,6 +62,7 @@ where ), operation_timeout_config: None, sleep_impl: default_async_sleep(), + reconnect_mode: Some(default_reconnect_mode()), } } } @@ -173,6 +181,7 @@ impl Builder<(), M, R> { retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } @@ -229,6 +238,7 @@ impl Builder { operation_timeout_config: self.operation_timeout_config, middleware, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } @@ -280,6 +290,7 @@ impl Builder { operation_timeout_config: self.operation_timeout_config, middleware: self.middleware, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } } @@ -347,6 +358,7 @@ impl Builder { retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } @@ -361,9 +373,41 @@ impl Builder { retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } + /// Set the [`ReconnectMode`] for the retry strategy + /// + /// By default, no reconnection occurs. + /// + /// When enabled and a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host. + pub fn reconnect_mode(mut self, reconnect_mode: ReconnectMode) -> Self { + self.set_reconnect_mode(Some(reconnect_mode)); + self + } + + /// Set the [`ReconnectMode`] for the retry strategy + /// + /// By default, no reconnection occurs. + /// + /// When enabled and a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host. + pub fn set_reconnect_mode(&mut self, reconnect_mode: Option) -> &mut Self { + self.reconnect_mode = reconnect_mode; + self + } + + /// Enable reconnection on transient errors + /// + /// By default, when a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host but may increase the load on + /// the server. + pub fn reconnect_on_transient_errors(self) -> Self { + self.reconnect_mode(ReconnectMode::ReconnectOnTransientError) + } + /// Build a Smithy service [`Client`]. pub fn build(self) -> Client { let operation_timeout_config = self @@ -392,6 +436,7 @@ impl Builder { middleware: self.middleware, operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode.unwrap_or(ReconnectMode::NoReconnect), } } } diff --git a/rust-runtime/aws-smithy-client/src/erase.rs b/rust-runtime/aws-smithy-client/src/erase.rs index 2cac5afeaa5..648562192ce 100644 --- a/rust-runtime/aws-smithy-client/src/erase.rs +++ b/rust-runtime/aws-smithy-client/src/erase.rs @@ -61,6 +61,7 @@ where retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } } @@ -101,6 +102,7 @@ where retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } diff --git a/rust-runtime/aws-smithy-client/src/hyper_ext.rs b/rust-runtime/aws-smithy-client/src/hyper_ext.rs index 11a27c1d533..ac0eaf77211 100644 --- a/rust-runtime/aws-smithy-client/src/hyper_ext.rs +++ b/rust-runtime/aws-smithy-client/src/hyper_ext.rs @@ -92,13 +92,22 @@ use crate::never::stream::EmptyStream; use aws_smithy_async::future::timeout::TimedOutError; use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep}; use aws_smithy_http::body::SdkBody; + use aws_smithy_http::result::ConnectorError; use aws_smithy_types::error::display::DisplayErrorContext; use aws_smithy_types::retry::ErrorKind; -use http::Uri; -use hyper::client::connect::{Connected, Connection}; +use http::{Extensions, Uri}; +use hyper::client::connect::{ + capture_connection, CaptureConnection, Connected, Connection, HttpInfo, +}; + use std::error::Error; +use std::fmt::Debug; + use std::sync::Arc; + +use crate::erase::boxclone::BoxFuture; +use aws_smithy_http::connection::{CaptureSmithyConnection, ConnectionMetadata}; use tokio::io::{AsyncRead, AsyncWrite}; use tower::{BoxError, Service}; @@ -108,7 +117,30 @@ use tower::{BoxError, Service}; /// see [the module documentation](crate::hyper_ext). #[derive(Clone, Debug)] #[non_exhaustive] -pub struct Adapter(HttpReadTimeout, SdkBody>>); +pub struct Adapter { + client: HttpReadTimeout, SdkBody>>, +} + +/// Extract a smithy connection from a hyper CaptureConnection +fn extract_smithy_connection(capture_conn: &CaptureConnection) -> Option { + let capture_conn = capture_conn.clone(); + if let Some(conn) = capture_conn.clone().connection_metadata().as_ref() { + let mut extensions = Extensions::new(); + conn.get_extras(&mut extensions); + let http_info = extensions.get::(); + let smithy_connection = ConnectionMetadata::new( + conn.is_proxied(), + http_info.map(|info| info.remote_addr()), + move || match capture_conn.connection_metadata().as_ref() { + Some(conn) => conn.poison(), + None => tracing::trace!("no connection existed to poison"), + }, + ); + Some(smithy_connection) + } else { + None + } +} impl Service> for Adapter where @@ -121,20 +153,22 @@ where type Response = http::Response; type Error = ConnectorError; - #[allow(clippy::type_complexity)] - type Future = std::pin::Pin< - Box> + Send + 'static>, - >; + type Future = BoxFuture; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.0.poll_ready(cx).map_err(downcast_error) + self.client.poll_ready(cx).map_err(downcast_error) } - fn call(&mut self, req: http::Request) -> Self::Future { - let fut = self.0.call(req); + fn call(&mut self, mut req: http::Request) -> Self::Future { + let capture_connection = capture_connection(&mut req); + if let Some(capture_smithy_connection) = req.extensions().get::() { + capture_smithy_connection + .set_connection_retriever(move || extract_smithy_connection(&capture_connection)); + } + let fut = self.client.call(req); Box::pin(async move { Ok(fut.await.map_err(downcast_error)?.map(SdkBody::from)) }) } } @@ -271,7 +305,9 @@ impl Builder { ), None => HttpReadTimeout::no_timeout(base), }; - Adapter(read_timeout) + Adapter { + client: read_timeout, + } } /// Set the async sleep implementation used for timeouts @@ -343,7 +379,6 @@ mod timeout_middleware { use pin_project_lite::pin_project; use tower::BoxError; - use aws_smithy_async::future; use aws_smithy_async::future::timeout::{TimedOutError, Timeout}; use aws_smithy_async::rt::sleep::AsyncSleep; use aws_smithy_async::rt::sleep::Sleep; @@ -493,7 +528,7 @@ mod timeout_middleware { Some((sleep, duration)) => { let sleep = sleep.sleep(*duration); MaybeTimeoutFuture::Timeout { - timeout: future::timeout::Timeout::new(self.inner.call(req), sleep), + timeout: Timeout::new(self.inner.call(req), sleep), error_type: "HTTP connect", duration: *duration, } @@ -522,7 +557,7 @@ mod timeout_middleware { Some((sleep, duration)) => { let sleep = sleep.sleep(*duration); MaybeTimeoutFuture::Timeout { - timeout: future::timeout::Timeout::new(self.inner.call(req), sleep), + timeout: Timeout::new(self.inner.call(req), sleep), error_type: "HTTP read", duration: *duration, } diff --git a/rust-runtime/aws-smithy-client/src/lib.rs b/rust-runtime/aws-smithy-client/src/lib.rs index 6e5e5ba9ee7..479641704da 100644 --- a/rust-runtime/aws-smithy-client/src/lib.rs +++ b/rust-runtime/aws-smithy-client/src/lib.rs @@ -26,6 +26,7 @@ pub mod bounds; pub mod erase; pub mod http_connector; pub mod never; +mod poison; pub mod retry; pub mod timeout; @@ -50,14 +51,17 @@ pub mod hyper_ext; #[doc(hidden)] pub mod static_tests; +use crate::poison::PoisonLayer; use aws_smithy_async::rt::sleep::AsyncSleep; + use aws_smithy_http::operation::Operation; use aws_smithy_http::response::ParseHttpResponse; pub use aws_smithy_http::result::{SdkError, SdkSuccess}; +use aws_smithy_http::retry::ClassifyRetry; use aws_smithy_http_tower::dispatch::DispatchLayer; use aws_smithy_http_tower::parse_response::ParseResponseLayer; use aws_smithy_types::error::display::DisplayErrorContext; -use aws_smithy_types::retry::ProvideErrorKind; +use aws_smithy_types::retry::{ProvideErrorKind, ReconnectMode}; use aws_smithy_types::timeout::OperationTimeoutConfig; use std::sync::Arc; use timeout::ClientTimeoutParams; @@ -93,6 +97,7 @@ pub struct Client< connector: Connector, middleware: Middleware, retry_policy: RetryPolicy, + reconnect_mode: ReconnectMode, operation_timeout_config: OperationTimeoutConfig, sleep_impl: Option>, } @@ -140,6 +145,7 @@ where E: std::error::Error + Send + Sync + 'static, Retry: Send + Sync, R::Policy: bounds::SmithyRetryPolicy, + Retry: ClassifyRetry, SdkError>, bounds::Parsed<>::Service, O, Retry>: Service, Response = SdkSuccess, Error = SdkError> + Clone, { @@ -159,6 +165,7 @@ where E: std::error::Error + Send + Sync + 'static, Retry: Send + Sync, R::Policy: bounds::SmithyRetryPolicy, + Retry: ClassifyRetry, SdkError>, // This bound is not _technically_ inferred by all the previous bounds, but in practice it // is because _we_ know that there is only implementation of Service for Parsed // (ParsedResponseService), and it will apply as long as the bounds on C, M, and R hold, @@ -179,6 +186,7 @@ where self.retry_policy .new_request_policy(self.sleep_impl.clone()), ) + .layer(PoisonLayer::new(self.reconnect_mode)) .layer(TimeoutLayer::new(timeout_params.operation_attempt_timeout)) .layer(ParseResponseLayer::::new()) // These layers can be considered as occurring in order. That is, first invoke the diff --git a/rust-runtime/aws-smithy-client/src/poison.rs b/rust-runtime/aws-smithy-client/src/poison.rs new file mode 100644 index 00000000000..ffbaaf8abc1 --- /dev/null +++ b/rust-runtime/aws-smithy-client/src/poison.rs @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Connection Poisoning +//! +//! The client supports behavior where on transient errors (e.g. timeouts, 503, etc.) it will ensure +//! that the offending connection is not reused. This happens to ensure that in the case where the +//! connection itself is broken (e.g. connected to a bad host) we don't reuse it for other requests. +//! +//! This relies on a series of mechanisms: +//! 1. [`CaptureSmithyConnection`] is a container which exists in the operation property bag. It is +//! inserted by this layer before the request is sent. +//! 2. The [`DispatchLayer`](aws_smithy_http_tower::dispatch::DispatchLayer) copies the field from operation extensions HTTP request extensions. +//! 3. The HTTP layer (e.g. Hyper) sets [`ConnectionMetadata`](aws_smithy_http::connection::ConnectionMetadata) +//! when it is available. +//! 4. When the response comes back, if indicated, this layer invokes +//! [`ConnectionMetadata::poison`](aws_smithy_http::connection::ConnectionMetadata::poison). +//! +//! ### Why isn't this integrated into `retry.rs`? +//! If the request has a streaming body, we won't attempt to retry because [`Operation::try_clone()`] will +//! return `None`. Therefore, we need to handle this inside of the retry loop. + +use std::future::Future; + +use aws_smithy_http::operation::Operation; +use aws_smithy_http::result::{SdkError, SdkSuccess}; +use aws_smithy_http::retry::ClassifyRetry; + +use aws_smithy_http::connection::CaptureSmithyConnection; +use aws_smithy_types::retry::{ErrorKind, ReconnectMode, RetryKind}; +use pin_project_lite::pin_project; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// PoisonLayer that poisons connections depending on the error kind +pub(crate) struct PoisonLayer { + inner: PhantomData, + mode: ReconnectMode, +} + +impl PoisonLayer { + pub(crate) fn new(mode: ReconnectMode) -> Self { + Self { + inner: Default::default(), + mode, + } + } +} + +impl Clone for PoisonLayer { + fn clone(&self) -> Self { + Self { + inner: Default::default(), + mode: self.mode, + } + } +} + +impl tower::Layer for PoisonLayer { + type Service = PoisonService; + + fn layer(&self, inner: S) -> Self::Service { + PoisonService { + inner, + mode: self.mode, + } + } +} + +#[derive(Clone)] +pub(crate) struct PoisonService { + inner: S, + mode: ReconnectMode, +} + +impl tower::Service> for PoisonService +where + R: ClassifyRetry, SdkError>, + S: tower::Service, Response = SdkSuccess, Error = SdkError>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = PoisonServiceFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Operation) -> Self::Future { + let classifier = req.retry_classifier().clone(); + let capture_smithy_connection = CaptureSmithyConnection::new(); + req.properties_mut() + .insert(capture_smithy_connection.clone()); + PoisonServiceFuture { + inner: self.inner.call(req), + conn: capture_smithy_connection, + mode: self.mode, + classifier, + } + } +} + +pin_project! { + pub struct PoisonServiceFuture { + #[pin] + inner: F, + classifier: R, + conn: CaptureSmithyConnection, + mode: ReconnectMode + } +} + +impl Future for PoisonServiceFuture +where + F: Future, SdkError>>, + R: ClassifyRetry, SdkError>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.inner.poll(cx) { + Poll::Ready(resp) => { + let retry_kind = this.classifier.classify_retry(resp.as_ref()); + if this.mode == &ReconnectMode::ReconnectOnTransientError + && retry_kind == RetryKind::Error(ErrorKind::TransientError) + { + if let Some(smithy_conn) = this.conn.get() { + tracing::info!("poisoning connection: {:?}", smithy_conn); + smithy_conn.poison(); + } else { + tracing::trace!("No smithy connection found! The underlying HTTP connection never set a connection."); + } + } + Poll::Ready(resp) + } + Poll::Pending => Poll::Pending, + } + } +} diff --git a/rust-runtime/aws-smithy-client/src/retry.rs b/rust-runtime/aws-smithy-client/src/retry.rs index 7e6ceff10cb..10df4ae5fc7 100644 --- a/rust-runtime/aws-smithy-client/src/retry.rs +++ b/rust-runtime/aws-smithy-client/src/retry.rs @@ -17,14 +17,15 @@ use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::time::Duration; -use crate::{SdkError, SdkSuccess}; +use tracing::Instrument; use aws_smithy_async::rt::sleep::AsyncSleep; + use aws_smithy_http::operation::Operation; use aws_smithy_http::retry::ClassifyRetry; use aws_smithy_types::retry::{ErrorKind, RetryKind}; -use tracing::Instrument; +use crate::{SdkError, SdkSuccess}; /// A policy instantiator. /// @@ -292,9 +293,20 @@ impl RetryHandler { fn should_retry_error(&self, error_kind: &ErrorKind) -> Option<(Self, Duration)> { let quota_used = { if self.local.attempts == self.config.max_attempts { + tracing::trace!( + attempts = self.local.attempts, + max_attempts = self.config.max_attempts, + "not retrying becuase we are out of attempts" + ); return None; } - self.shared.quota_acquire(error_kind, &self.config)? + match self.shared.quota_acquire(error_kind, &self.config) { + Some(quota) => quota, + None => { + tracing::trace!(state = ?self.shared, "not retrying because no quota is available"); + return None; + } + } }; let backoff = calculate_exponential_backoff( // Generate a random base multiplier to create jitter @@ -334,7 +346,9 @@ impl RetryHandler { } fn retry_for(&self, retry_kind: RetryKind) -> Option> { - let (next, dur) = self.should_retry(&retry_kind)?; + let retry = self.should_retry(&retry_kind); + tracing::trace!(retry=?retry, retry_kind = ?retry_kind, "retry action"); + let (next, dur) = retry?; let sleep = match &self.sleep_impl { Some(sleep) => sleep, @@ -377,6 +391,7 @@ where ) -> Option { let classifier = req.retry_classifier(); let retry_kind = classifier.classify_retry(result); + tracing::trace!(retry_kind = ?retry_kind, "retry classification"); self.retry_for(retry_kind) } diff --git a/rust-runtime/aws-smithy-client/src/test_connection.rs b/rust-runtime/aws-smithy-client/src/test_connection.rs index d7b0b15ecef..08acd52b90d 100644 --- a/rust-runtime/aws-smithy-client/src/test_connection.rs +++ b/rust-runtime/aws-smithy-client/src/test_connection.rs @@ -271,6 +271,327 @@ where } } +/// [`wire_mock`] contains utilities for mocking at the socket level +/// +/// Other tools in this module actually operate at the `http::Request` / `http::Response` level. This +/// is useful, but it shortcuts the HTTP implementation (e.g. Hyper). [`wire_mock::WireLevelTestConnection`] binds +/// to an actual socket on the host +/// +/// # Examples +/// ``` +/// use tower::layer::util::Identity; +/// use aws_smithy_client::http_connector::ConnectorSettings; +/// use aws_smithy_client::{match_events, ev}; +/// use aws_smithy_client::test_connection::wire_mock::check_matches; +/// # async fn example() { +/// use aws_smithy_client::test_connection::wire_mock::{ReplayedEvent, WireLevelTestConnection}; +/// // This connection binds to a local address +/// let mock = WireLevelTestConnection::spinup(vec![ +/// ReplayedEvent::status(503), +/// ReplayedEvent::status(200) +/// ]).await; +/// let client = aws_smithy_client::Client::builder() +/// .connector(mock.http_connector().connector(&ConnectorSettings::default(), None).unwrap()) +/// .middleware(Identity::new()) +/// .build(); +/// /* do something with */ +/// // assert that you got the events you expected +/// match_events!(ev!(dns), ev!(connect), ev!(http(200)))(&mock.events()); +/// # } +/// ``` +pub mod wire_mock { + use bytes::Bytes; + use http::{Request, Response}; + use hyper::client::connect::dns::Name; + use hyper::server::conn::AddrStream; + use hyper::service::{make_service_fn, service_fn}; + use hyper::{Body, Server}; + use std::collections::HashSet; + use std::convert::Infallible; + use std::error::Error; + + use hyper::client::HttpConnector as HyperHttpConnector; + use std::iter; + use std::iter::Once; + use std::net::{SocketAddr, TcpListener}; + use std::sync::{Arc, Mutex}; + use std::task::{Context, Poll}; + + use tokio::spawn; + use tower::Service; + + /// An event recorded by [`WireLevelTestConnection`] + #[derive(Debug, Clone)] + pub enum RecordedEvent { + DnsLookup(Name), + NewConnection, + Response(ReplayedEvent), + } + + type Matcher = ( + Box Result<(), Box>>, + &'static str, + ); + + /// This method should only be used by the macro + #[doc(hidden)] + pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) { + let mut events_iter = events.iter(); + let mut matcher_iter = matchers.iter(); + let mut idx = -1; + loop { + idx += 1; + let bail = |err: Box| panic!("failed on event {}:\n {}", idx, err); + match (events_iter.next(), matcher_iter.next()) { + (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail), + (None, None) => return, + (Some(event), None) => { + bail(format!("got {:?} but no more events were expected", event).into()) + } + (None, Some((_expect, msg))) => { + bail(format!("expected {:?} but no more events were expected", msg).into()) + } + } + } + } + + #[macro_export] + macro_rules! matcher { + ($expect:tt) => { + ( + Box::new( + |event: &::aws_smithy_client::test_connection::wire_mock::RecordedEvent| { + if !matches!(event, $expect) { + return Err(format!( + "expected `{}` but got {:?}", + stringify!($expect), + event + ) + .into()); + } + Ok(()) + }, + ), + stringify!($expect), + ) + }; + } + + /// Helper macro to generate a series of test expectations + #[macro_export] + macro_rules! match_events { + ($( $expect:pat),*) => { + |events| { + check_matches(events, &[$( ::aws_smithy_client::matcher!($expect) ),*]); + } + }; + } + + /// Helper to generate match expressions for events + #[macro_export] + macro_rules! ev { + (http($status:expr)) => { + ::aws_smithy_client::test_connection::wire_mock::RecordedEvent::Response( + ReplayedEvent::HttpResponse { + status: $status, + .. + }, + ) + }; + (dns) => { + ::aws_smithy_client::test_connection::wire_mock::RecordedEvent::DnsLookup(_) + }; + (connect) => { + ::aws_smithy_client::test_connection::wire_mock::RecordedEvent::NewConnection + }; + (timeout) => { + ::aws_smithy_client::test_connection::wire_mock::RecordedEvent::Response( + ReplayedEvent::Timeout, + ) + }; + } + + pub use {ev, match_events, matcher}; + + #[derive(Clone, Debug, PartialEq, Eq)] + pub enum ReplayedEvent { + Timeout, + HttpResponse { status: u16, body: Bytes }, + } + + impl ReplayedEvent { + pub fn ok() -> Self { + Self::HttpResponse { + status: 200, + body: Bytes::new(), + } + } + + pub fn with_body(body: &str) -> Self { + Self::HttpResponse { + status: 200, + body: Bytes::copy_from_slice(body.as_ref()), + } + } + + pub fn status(status: u16) -> Self { + Self::HttpResponse { + status, + body: Bytes::new(), + } + } + } + + use crate::erase::boxclone::BoxFuture; + use crate::http_connector::HttpConnector; + use crate::hyper_ext; + use aws_smithy_async::future::never::Never; + use tokio::sync::oneshot; + + #[derive(Debug)] + pub struct WireLevelTestConnection { + event_log: Arc>>, + bind_addr: SocketAddr, + _shutdown_hook: oneshot::Sender<()>, + } + + impl WireLevelTestConnection { + pub async fn spinup(mut response_events: Vec) -> Self { + let listener = TcpListener::bind("0.0.0.0:0").unwrap(); + let (tx, rx) = oneshot::channel(); + let listener_addr = listener.local_addr().unwrap(); + response_events.reverse(); + let response_events = Arc::new(Mutex::new(response_events)); + let handler_events = response_events; + let wire_events = Arc::new(Mutex::new(vec![])); + let wire_log_for_service = wire_events.clone(); + let poisoned_conns: Arc>> = Default::default(); + let make_service = make_service_fn(move |_conn: &AddrStream| { + let poisoned_conns = poisoned_conns.clone(); + let events = handler_events.clone(); + let wire_log = wire_log_for_service.clone(); + let local_addr = _conn.remote_addr(); + tracing::info!("established connection: {:?}", _conn); + wire_log.lock().unwrap().push(RecordedEvent::NewConnection); + async move { + Ok::<_, Infallible>(service_fn(move |_: Request| { + if poisoned_conns.lock().unwrap().contains(&local_addr) { + tracing::error!("poisoned connection {:?} was reused!", &local_addr); + panic!("poisoned connection was reused!"); + } + let next_event = events.clone().lock().unwrap().pop(); + let wire_log = wire_log.clone(); + let poisoned_conns = poisoned_conns.clone(); + if poisoned_conns.lock().unwrap().contains(&local_addr) { + //panic!("poisoned connection was reused! {}", local_addr); + } + async move { + let next_event = next_event + .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log)); + wire_log + .lock() + .unwrap() + .push(RecordedEvent::Response(next_event.clone())); + if next_event == ReplayedEvent::Timeout { + println!("{} is poisoned", local_addr); + poisoned_conns.lock().unwrap().insert(local_addr); + } + tracing::debug!("replying with {:?}", next_event); + let event = generate_response_event(next_event).await; + dbg!(event) + } + })) + } + }); + let server = Server::from_tcp(listener) + .unwrap() + .serve(make_service) + .with_graceful_shutdown(async { + rx.await.ok(); + println!("server shutdown!"); + }); + spawn(async move { server.await }); + Self { + event_log: wire_events, + bind_addr: listener_addr, + _shutdown_hook: tx, + } + } + + pub fn events(&self) -> Vec { + self.event_log.lock().unwrap().clone() + } + + pub fn bind_addr(&self) -> SocketAddr { + self.bind_addr + } + + pub fn dns_resolver(&self) -> LoggingDnsResolver { + let event_log = self.event_log.clone(); + let bind_addr = self.bind_addr; + LoggingDnsResolver { + log: event_log, + socket_addr: bind_addr, + } + } + + pub fn http_connector(&self) -> HttpConnector { + let http_connector = HyperHttpConnector::new_with_resolver(self.dns_resolver()); + hyper_ext::Adapter::builder().build(http_connector).into() + } + + pub fn endpoint_url(&self) -> String { + format!( + "http://this-url-is-converted-to-localhost.com:{}", + self.bind_addr().port() + ) + } + } + + async fn generate_response_event(event: ReplayedEvent) -> Result, Infallible> { + let resp = match event { + ReplayedEvent::HttpResponse { status, body } => http::Response::builder() + .status(status) + .body(hyper::Body::from(body)) + .unwrap(), + ReplayedEvent::Timeout => { + Never::new().await; + unreachable!() + } + }; + Ok::<_, Infallible>(resp) + } + + /// DNS resolver that keeps a log of all lookups + /// + /// Regardless of what hostname is requested, it will always return the same socket address. + #[derive(Clone, Debug)] + pub struct LoggingDnsResolver { + log: Arc>>, + socket_addr: SocketAddr, + } + + impl Service for LoggingDnsResolver { + type Response = Once; + type Error = Infallible; + type Future = BoxFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Name) -> Self::Future { + let sock_addr = self.socket_addr; + let log = self.log.clone(); + Box::pin(async move { + println!("looking up {:?}, replying with {:?}", req, sock_addr); + log.lock().unwrap().push(RecordedEvent::DnsLookup(req)); + Ok(iter::once(sock_addr)) + }) + } + } +} + #[cfg(test)] mod tests { use hyper::service::Service; diff --git a/rust-runtime/aws-smithy-client/src/timeout.rs b/rust-runtime/aws-smithy-client/src/timeout.rs index 4cfc1938f63..85957eb11e1 100644 --- a/rust-runtime/aws-smithy-client/src/timeout.rs +++ b/rust-runtime/aws-smithy-client/src/timeout.rs @@ -208,7 +208,7 @@ where InnerService: tower::Service, Error = SdkError>, { type Response = InnerService::Response; - type Error = aws_smithy_http::result::SdkError; + type Error = SdkError; type Future = TimeoutServiceFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { diff --git a/rust-runtime/aws-smithy-client/tests/e2e_test.rs b/rust-runtime/aws-smithy-client/tests/e2e_test.rs index 99b689d430d..0a8594d6b3d 100644 --- a/rust-runtime/aws-smithy-client/tests/e2e_test.rs +++ b/rust-runtime/aws-smithy-client/tests/e2e_test.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +mod test_operation; use crate::test_operation::{TestOperationParser, TestRetryClassifier}; use aws_smithy_async::rt::sleep::TokioSleep; use aws_smithy_client::test_connection::TestConnection; @@ -15,78 +16,6 @@ use std::sync::Arc; use std::time::Duration; use tower::layer::util::Identity; -mod test_operation { - use aws_smithy_http::operation; - use aws_smithy_http::response::ParseHttpResponse; - use aws_smithy_http::result::SdkError; - use aws_smithy_http::retry::ClassifyRetry; - use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind}; - use bytes::Bytes; - use std::error::Error; - use std::fmt::{self, Debug, Display, Formatter}; - - #[derive(Clone)] - pub(super) struct TestOperationParser; - - #[derive(Debug)] - pub(super) struct OperationError; - - impl Display for OperationError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{:?}", self) - } - } - - impl Error for OperationError {} - - impl ProvideErrorKind for OperationError { - fn retryable_error_kind(&self) -> Option { - Some(ErrorKind::ThrottlingError) - } - - fn code(&self) -> Option<&str> { - None - } - } - - impl ParseHttpResponse for TestOperationParser { - type Output = Result; - - fn parse_unloaded(&self, response: &mut operation::Response) -> Option { - if response.http().status().is_success() { - Some(Ok("Hello!".to_string())) - } else { - Some(Err(OperationError)) - } - } - - fn parse_loaded(&self, _response: &http::Response) -> Self::Output { - Ok("Hello!".to_string()) - } - } - - #[derive(Clone)] - pub(super) struct TestRetryClassifier; - - impl ClassifyRetry> for TestRetryClassifier - where - E: ProvideErrorKind + Debug, - T: Debug, - { - fn classify_retry(&self, err: Result<&T, &SdkError>) -> RetryKind { - let kind = match err { - Err(SdkError::ServiceError(context)) => context.err().retryable_error_kind(), - Ok(_) => return RetryKind::Unnecessary, - _ => panic!("test handler only handles modeled errors got: {:?}", err), - }; - match kind { - Some(kind) => RetryKind::Error(kind), - None => RetryKind::UnretryableFailure, - } - } - } -} - fn test_operation() -> Operation { let req = operation::Request::new( http::Request::builder() @@ -108,14 +37,14 @@ async fn end_to_end_retry_test() { fn ok() -> http::Response<&'static str> { http::Response::builder() .status(200) - .body("response body") + .body("Hello!") .unwrap() } fn err() -> http::Response<&'static str> { http::Response::builder() .status(500) - .body("response body") + .body("This was an error") .unwrap() } // 1 failing response followed by 1 successful response diff --git a/rust-runtime/aws-smithy-client/tests/reconnect_on_transient_error.rs b/rust-runtime/aws-smithy-client/tests/reconnect_on_transient_error.rs new file mode 100644 index 00000000000..71dcd527ca4 --- /dev/null +++ b/rust-runtime/aws-smithy-client/tests/reconnect_on_transient_error.rs @@ -0,0 +1,230 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#![cfg(feature = "test-util")] + +mod test_operation; + +use aws_smithy_async::rt::sleep::TokioSleep; +use aws_smithy_client::test_connection::wire_mock; +use aws_smithy_client::test_connection::wire_mock::{check_matches, RecordedEvent, ReplayedEvent}; +use aws_smithy_client::{hyper_ext, Builder}; +use aws_smithy_client::{match_events, Client}; +use aws_smithy_http::body::SdkBody; +use aws_smithy_http::operation; +use aws_smithy_http::operation::Operation; +use aws_smithy_types::retry::ReconnectMode; +use aws_smithy_types::timeout::{OperationTimeoutConfig, TimeoutConfig}; +use http::Uri; +use http_body::combinators::BoxBody; +use hyper::client::{Builder as HyperBuilder, HttpConnector}; +use std::convert::Infallible; +use std::sync::Arc; +use std::time::Duration; +use test_operation::{TestOperationParser, TestRetryClassifier}; +use tower::layer::util::Identity; +use wire_mock::ev; + +fn end_of_test() -> &'static str { + "end_of_test" +} + +fn test_operation( + uri: Uri, + retryable: bool, +) -> Operation { + let mut req = operation::Request::new( + http::Request::builder() + .uri(uri) + .body(SdkBody::from("request body")) + .unwrap(), + ); + if !retryable { + req = req + .augment(|req, _conf| { + Ok::<_, Infallible>( + req.map(|_| SdkBody::from_dyn(BoxBody::new(SdkBody::from("body")))), + ) + }) + .unwrap(); + } + Operation::new(req, TestOperationParser).with_retry_classifier(TestRetryClassifier) +} + +async fn h1_and_h2(events: Vec, match_clause: impl Fn(&[RecordedEvent])) { + wire_level_test(events.clone(), |_b| {}, |b| b, &match_clause).await; + wire_level_test( + events, + |b| { + b.http2_only(true); + }, + |b| b, + match_clause, + ) + .await; + println!("h2 ok!"); +} + +/// Repeatedly send test operation until `end_of_test` is received +/// +/// When the test is over, match_clause is evaluated +async fn wire_level_test( + events: Vec, + hyper_builder_settings: impl Fn(&mut HyperBuilder), + client_builder_settings: impl Fn(Builder) -> Builder, + match_clause: impl Fn(&[RecordedEvent]), +) { + let connection = wire_mock::WireLevelTestConnection::spinup(events).await; + + let http_connector = HttpConnector::new_with_resolver(connection.dns_resolver()); + let mut hyper_builder = hyper::Client::builder(); + hyper_builder_settings(&mut hyper_builder); + let hyper_adapter = hyper_ext::Adapter::builder() + .hyper_builder(hyper_builder) + .build(http_connector); + let client = client_builder_settings( + Client::builder().reconnect_mode(ReconnectMode::ReconnectOnTransientError), + ) + .connector(hyper_adapter) + .middleware(Identity::new()) + .operation_timeout_config(OperationTimeoutConfig::from( + &TimeoutConfig::builder() + .operation_attempt_timeout(Duration::from_millis(100)) + .build(), + )) + .sleep_impl(Arc::new(TokioSleep::new())) + .build(); + loop { + match client + .call(test_operation( + connection.endpoint_url().parse().unwrap(), + false, + )) + .await + { + Ok(resp) => { + tracing::info!("response: {:?}", resp); + if resp == end_of_test() { + break; + } + } + Err(e) => tracing::info!("error: {:?}", e), + } + } + let events = connection.events(); + match_clause(&events); +} + +#[tokio::test] +async fn non_transient_errors_no_reconect() { + h1_and_h2( + vec![ + ReplayedEvent::status(400), + ReplayedEvent::with_body(end_of_test()), + ], + match_events!(ev!(dns), ev!(connect), ev!(http(400)), ev!(http(200))), + ) + .await +} + +#[tokio::test] +async fn reestablish_dns_on_503() { + h1_and_h2( + vec![ + ReplayedEvent::status(503), + ReplayedEvent::status(503), + ReplayedEvent::status(503), + ReplayedEvent::with_body(end_of_test()), + ], + match_events!( + // first request + ev!(dns), + ev!(connect), + ev!(http(503)), + // second request + ev!(dns), + ev!(connect), + ev!(http(503)), + // third request + ev!(dns), + ev!(connect), + ev!(http(503)), + // all good + ev!(dns), + ev!(connect), + ev!(http(200)) + ), + ) + .await; +} + +#[tokio::test] +async fn connection_shared_on_success() { + h1_and_h2( + vec![ + ReplayedEvent::ok(), + ReplayedEvent::ok(), + ReplayedEvent::status(503), + ReplayedEvent::with_body(end_of_test()), + ], + match_events!( + ev!(dns), + ev!(connect), + ev!(http(200)), + ev!(http(200)), + ev!(http(503)), + ev!(dns), + ev!(connect), + ev!(http(200)) + ), + ) + .await; +} + +#[tokio::test] +async fn no_reconnect_when_disabled() { + use wire_mock::ev; + wire_level_test( + vec![ + ReplayedEvent::status(503), + ReplayedEvent::with_body(end_of_test()), + ], + |_b| {}, + |b| b.reconnect_mode(ReconnectMode::NoReconnect), + match_events!(ev!(dns), ev!(connect), ev!(http(503)), ev!(http(200))), + ) + .await; +} + +#[tokio::test] +async fn connection_reestablished_after_timeout() { + use wire_mock::ev; + h1_and_h2( + vec![ + ReplayedEvent::ok(), + ReplayedEvent::Timeout, + ReplayedEvent::ok(), + ReplayedEvent::Timeout, + ReplayedEvent::with_body(end_of_test()), + ], + match_events!( + // first connection + ev!(dns), + ev!(connect), + ev!(http(200)), + // reuse but got a timeout + ev!(timeout), + // so we reconnect + ev!(dns), + ev!(connect), + ev!(http(200)), + ev!(timeout), + ev!(dns), + ev!(connect), + ev!(http(200)) + ), + ) + .await; +} diff --git a/rust-runtime/aws-smithy-client/tests/test_operation/mod.rs b/rust-runtime/aws-smithy-client/tests/test_operation/mod.rs new file mode 100644 index 00000000000..db193e4bd9b --- /dev/null +++ b/rust-runtime/aws-smithy-client/tests/test_operation/mod.rs @@ -0,0 +1,84 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_http::operation; +use aws_smithy_http::response::ParseHttpResponse; +use aws_smithy_http::result::SdkError; +use aws_smithy_http::retry::ClassifyRetry; +use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind}; +use bytes::Bytes; +use std::error::Error; +use std::fmt::{self, Debug, Display, Formatter}; +use std::str; + +#[derive(Clone)] +pub(super) struct TestOperationParser; + +#[derive(Debug)] +pub(super) struct OperationError(ErrorKind); + +impl Display for OperationError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +impl Error for OperationError {} + +impl ProvideErrorKind for OperationError { + fn retryable_error_kind(&self) -> Option { + Some(self.0) + } + + fn code(&self) -> Option<&str> { + None + } +} + +impl ParseHttpResponse for TestOperationParser { + type Output = Result; + + fn parse_unloaded(&self, response: &mut operation::Response) -> Option { + tracing::debug!("got response: {:?}", response); + match response.http().status() { + s if s.is_success() => None, + s if s.is_client_error() => Some(Err(OperationError(ErrorKind::ServerError))), + s if s.is_server_error() => Some(Err(OperationError(ErrorKind::TransientError))), + _ => panic!("unexpected status: {}", response.http().status()), + } + } + + fn parse_loaded(&self, response: &http::Response) -> Self::Output { + Ok(str::from_utf8(response.body().as_ref()) + .unwrap() + .to_string()) + } +} + +#[derive(Clone)] +pub(super) struct TestRetryClassifier; + +impl ClassifyRetry> for TestRetryClassifier +where + E: ProvideErrorKind + Debug, + T: Debug, +{ + fn classify_retry(&self, err: Result<&T, &SdkError>) -> RetryKind { + tracing::info!("got response: {:?}", err); + let kind = match err { + Err(SdkError::ServiceError(context)) => context.err().retryable_error_kind(), + Err(SdkError::DispatchFailure(err)) if err.is_timeout() => { + Some(ErrorKind::TransientError) + } + Err(SdkError::TimeoutError(_)) => Some(ErrorKind::TransientError), + Ok(_) => return RetryKind::Unnecessary, + _ => panic!("test handler only handles modeled errors got: {:?}", err), + }; + match kind { + Some(kind) => RetryKind::Error(kind), + None => RetryKind::UnretryableFailure, + } + } +} diff --git a/rust-runtime/aws-smithy-http-tower/src/dispatch.rs b/rust-runtime/aws-smithy-http-tower/src/dispatch.rs index 8a1119d61bd..a10693a62b5 100644 --- a/rust-runtime/aws-smithy-http-tower/src/dispatch.rs +++ b/rust-runtime/aws-smithy-http-tower/src/dispatch.rs @@ -5,6 +5,7 @@ use crate::SendOperationError; use aws_smithy_http::body::SdkBody; +use aws_smithy_http::connection::CaptureSmithyConnection; use aws_smithy_http::operation; use aws_smithy_http::result::ConnectorError; use std::future::Future; @@ -41,7 +42,13 @@ where } fn call(&mut self, req: operation::Request) -> Self::Future { - let (req, property_bag) = req.into_parts(); + let (mut req, property_bag) = req.into_parts(); + // copy the smithy connection + if let Some(smithy_conn) = property_bag.acquire().get::() { + req.extensions_mut().insert(smithy_conn.clone()); + } else { + println!("nothing to copy!"); + } let mut inner = self.inner.clone(); let future = async move { trace!(request = ?req, "dispatching request"); diff --git a/rust-runtime/aws-smithy-http/src/connection.rs b/rust-runtime/aws-smithy-http/src/connection.rs new file mode 100644 index 00000000000..b4e64256bf5 --- /dev/null +++ b/rust-runtime/aws-smithy-http/src/connection.rs @@ -0,0 +1,96 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::fmt::{Debug, Formatter}; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; + +#[derive(Clone)] +pub struct ConnectionMetadata { + is_proxied: bool, + remote_addr: Option, + poison_fn: Arc, +} + +impl Debug for ConnectionMetadata { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SmithyConnection") + .field("is_proxied", &self.is_proxied) + .field("remote_addr", &self.remote_addr) + .finish() + } +} + +type LoaderFn = dyn Fn() -> Option + Send + Sync; + +#[derive(Clone, Default)] +pub struct CaptureSmithyConnection { + loader: Arc>>>, +} + +impl CaptureSmithyConnection { + pub fn new() -> Self { + Self { + loader: Default::default(), + } + } + pub fn set_connection_retriever(&self, f: F) + where + F: Fn() -> Option + Send + Sync + 'static, + { + *self.loader.lock().unwrap() = Some(Box::new(f)); + } + + pub fn get(&self) -> Option { + match self.loader.lock().unwrap().as_ref() { + Some(loader) => loader(), + None => { + println!("no loader was set :-/"); + None + } + } + } +} + +impl ConnectionMetadata { + pub fn poison(&self) { + tracing::info!("smithy connection was poisoned"); + (self.poison_fn)() + } +} + +impl ConnectionMetadata { + pub fn new( + is_proxied: bool, + remote_addr: Option, + poison: impl Fn() + Send + Sync + 'static, + ) -> Self { + Self { + is_proxied, + remote_addr, + poison_fn: Arc::new(poison), + } + } + + pub fn remote_addr(&self) -> Option { + self.remote_addr + } +} + +#[cfg(test)] +mod test { + use crate::connection::{CaptureSmithyConnection, ConnectionMetadata}; + + #[test] + fn retrieve_connection_metadata() { + let mut retriever = CaptureSmithyConnection::new(); + let retriever_clone = retriever.clone(); + assert!(retriever.get().is_none()); + retriever.set_connection_retriever(|| Some(ConnectionMetadata::new(true, None, || {}))); + + assert!(retriever.get().is_some()); + assert!(retriever_clone.get().is_some()); + } +} diff --git a/rust-runtime/aws-smithy-http/src/lib.rs b/rust-runtime/aws-smithy-http/src/lib.rs index f777e15c82a..77156efc2f3 100644 --- a/rust-runtime/aws-smithy-http/src/lib.rs +++ b/rust-runtime/aws-smithy-http/src/lib.rs @@ -39,4 +39,5 @@ pub mod event_stream; pub mod byte_stream; +pub mod connection; mod urlencode; diff --git a/rust-runtime/aws-smithy-http/src/result.rs b/rust-runtime/aws-smithy-http/src/result.rs index f11dcc2d36e..a3fdfcc761e 100644 --- a/rust-runtime/aws-smithy-http/src/result.rs +++ b/rust-runtime/aws-smithy-http/src/result.rs @@ -12,6 +12,7 @@ //! `Result` wrapper types for [success](SdkSuccess) and [failure](SdkError) responses. +use crate::connection::ConnectionMetadata; use crate::operation; use aws_smithy_types::error::metadata::{ProvideErrorMetadata, EMPTY_ERROR_METADATA}; use aws_smithy_types::error::ErrorMetadata; @@ -240,6 +241,11 @@ impl DispatchFailure { pub fn is_other(&self) -> Option { self.source.is_other() } + + /// Returns the inner error if it is a connector error + pub fn as_connector_error(&self) -> Option<&ConnectorError> { + Some(&self.source) + } } /// Error context for [`SdkError::ResponseError`] @@ -505,6 +511,22 @@ enum ConnectorErrorKind { pub struct ConnectorError { kind: ConnectorErrorKind, source: BoxError, + connection: ConnectionStatus, +} + +#[non_exhaustive] +#[derive(Debug)] +pub(crate) enum ConnectionStatus { + /// This request was never connected to the remote + /// + /// This indicates the failure was during connection establishment + NeverConnected, + + /// It is unknown whether a connection was established + Unknown, + + /// The request connected to the remote prior to failure + Connected(ConnectionMetadata), } impl Display for ConnectorError { @@ -532,14 +554,28 @@ impl ConnectorError { Self { kind: ConnectorErrorKind::Timeout, source, + connection: ConnectionStatus::Unknown, } } + /// Include connection information along with this error + pub fn with_connection(mut self, info: ConnectionMetadata) -> Self { + self.connection = ConnectionStatus::Connected(info); + self + } + + /// Set the connection status on this error to report that a connection was never established + pub fn never_connected(mut self) -> Self { + self.connection = ConnectionStatus::NeverConnected; + self + } + /// Construct a [`ConnectorError`] from an error caused by the user (e.g. invalid HTTP request) pub fn user(source: BoxError) -> Self { Self { kind: ConnectorErrorKind::User, source, + connection: ConnectionStatus::Unknown, } } @@ -548,6 +584,7 @@ impl ConnectorError { Self { kind: ConnectorErrorKind::Io, source, + connection: ConnectionStatus::Unknown, } } @@ -558,6 +595,7 @@ impl ConnectorError { Self { source, kind: ConnectorErrorKind::Other(kind), + connection: ConnectionStatus::Unknown, } } @@ -583,4 +621,16 @@ impl ConnectorError { _ => None, } } + + /// Returns metadata about the connection + /// + /// If a connection was established and provided by the internal connector, a connection will + /// be returned. + pub fn connection_metadata(&self) -> Option<&ConnectionMetadata> { + match &self.connection { + ConnectionStatus::NeverConnected => None, + ConnectionStatus::Unknown => None, + ConnectionStatus::Connected(conn) => Some(conn), + } + } } diff --git a/rust-runtime/aws-smithy-types/src/retry.rs b/rust-runtime/aws-smithy-types/src/retry.rs index 43be79cae43..276cdbba939 100644 --- a/rust-runtime/aws-smithy-types/src/retry.rs +++ b/rust-runtime/aws-smithy-types/src/retry.rs @@ -143,6 +143,7 @@ pub struct RetryConfigBuilder { mode: Option, max_attempts: Option, initial_backoff: Option, + reconnect_mode: Option, } impl RetryConfigBuilder { @@ -208,6 +209,7 @@ impl RetryConfigBuilder { mode: self.mode.or(other.mode), max_attempts: self.max_attempts.or(other.max_attempts), initial_backoff: self.initial_backoff.or(other.initial_backoff), + reconnect_mode: self.reconnect_mode.or(other.reconnect_mode), } } @@ -219,6 +221,9 @@ impl RetryConfigBuilder { initial_backoff: self .initial_backoff .unwrap_or_else(|| Duration::from_secs(1)), + reconnect_mode: self + .reconnect_mode + .unwrap_or(ReconnectMode::ReconnectOnTransientError), } } } @@ -230,6 +235,20 @@ pub struct RetryConfig { mode: RetryMode, max_attempts: u32, initial_backoff: Duration, + reconnect_mode: ReconnectMode, +} + +/// Mode for connection re-establishment +/// +/// By default, when a transient error is encountered, the connection in use will be poisoned. This +/// behavior can be disabled by setting [`NoReconnect`] instead. +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum ReconnectMode { + /// Reconnect on [`ErrorKind::TransientError`] + ReconnectOnTransientError, + + /// Disable reconnects + NoReconnect, } impl RetryConfig { @@ -239,6 +258,7 @@ impl RetryConfig { mode: RetryMode::Standard, max_attempts: 3, initial_backoff: Duration::from_secs(1), + reconnect_mode: ReconnectMode::ReconnectOnTransientError, } } @@ -260,6 +280,18 @@ impl RetryConfig { self } + /// Set the [`ReconnectMode`] for the retry strategy + /// + /// By default, when a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host but may increase the load on + /// the server. + /// + /// This behavior can be disabled by setting [`NoReconnect`] instead. + pub fn with_reconnect_mode(mut self, reconnect_mode: ReconnectMode) -> Self { + self.reconnect_mode = reconnect_mode; + self + } + /// Set the multiplier used when calculating backoff times as part of an /// [exponential backoff with jitter](https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/) /// strategy. Most services should work fine with the default duration of 1 second, but if you @@ -287,6 +319,11 @@ impl RetryConfig { self.mode } + /// Returns the [`ReconnectMode`] + pub fn reconnect_mode(&self) -> ReconnectMode { + self.reconnect_mode + } + /// Returns the max attempts. pub fn max_attempts(&self) -> u32 { self.max_attempts