Skip to content

Commit

Permalink
Refactor interceptor phases to improve optionality inside interceptors (
Browse files Browse the repository at this point in the history
#2670)

## Motivation and Context
This PR:
- Deletes `TraceProbe`
- Replaces the orchestrator's `Phase` helper with a couple of macros
- Introduces the concept of phases into `InterceptorContext` so that
input/output/error/request/response accessors don't need option wrappers
- Adds `TypeErasedError` so that `orchestrator::Error` can implement
`Error`
- Rewinds the interceptor context in the retry loop

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
jdisanti authored May 3, 2023
1 parent c75807c commit beedd2c
Show file tree
Hide file tree
Showing 18 changed files with 1,108 additions and 444 deletions.
23 changes: 13 additions & 10 deletions aws/rust-runtime/aws-runtime/src/invocation_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit;
use aws_smithy_runtime_api::client::interceptors::error::BoxError;
use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext};
use aws_smithy_runtime_api::config_bag::ConfigBag;
Expand Down Expand Up @@ -37,10 +38,10 @@ impl Default for InvocationIdInterceptor {
impl Interceptor for InvocationIdInterceptor {
fn modify_before_retry_loop(
&self,
context: &mut InterceptorContext,
context: &mut InterceptorContext<BeforeTransmit>,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let headers = context.request_mut()?.headers_mut();
let headers = context.request_mut().headers_mut();
let id = _cfg.get::<InvocationId>().unwrap_or(&self.id);
headers.append(AMZ_SDK_INVOCATION_ID, id.0.clone());
Ok(())
Expand Down Expand Up @@ -72,24 +73,26 @@ impl InvocationId {
mod tests {
use crate::invocation_id::InvocationIdInterceptor;
use aws_smithy_http::body::SdkBody;
use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit;
use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext};
use aws_smithy_runtime_api::config_bag::ConfigBag;
use aws_smithy_runtime_api::type_erasure::TypedBox;
use http::HeaderValue;

fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a HeaderValue {
context
.request()
.unwrap()
.headers()
.get(header_name)
.unwrap()
fn expect_header<'a>(
context: &'a InterceptorContext<BeforeTransmit>,
header_name: &str,
) -> &'a HeaderValue {
context.request().headers().get(header_name).unwrap()
}

#[test]
fn test_id_is_generated_and_set() {
let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase())
.into_serialization_phase();
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
let _ = context.take_input();
let mut context = context.into_before_transmit_phase();

let mut config = ConfigBag::base();
let interceptor = InvocationIdInterceptor::new();
Expand Down
12 changes: 8 additions & 4 deletions aws/rust-runtime/aws-runtime/src/recursion_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit;
use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext};
use aws_smithy_runtime_api::config_bag::ConfigBag;
use aws_types::os_shim_internal::Env;
Expand Down Expand Up @@ -39,10 +40,10 @@ impl RecursionDetectionInterceptor {
impl Interceptor for RecursionDetectionInterceptor {
fn modify_before_signing(
&self,
context: &mut InterceptorContext,
context: &mut InterceptorContext<BeforeTransmit>,
_cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let request = context.request_mut()?;
let request = context.request_mut();
if request.headers().contains_key(TRACE_ID_HEADER) {
return Ok(());
}
Expand Down Expand Up @@ -145,14 +146,17 @@ mod tests {
request = request.header(name, value);
}
let request = request.body(SdkBody::empty()).expect("must be valid");
let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase())
.into_serialization_phase();
context.set_request(request);
let _ = context.take_input();
let mut context = context.into_before_transmit_phase();
let mut config = ConfigBag::base();

RecursionDetectionInterceptor { env }
.modify_before_signing(&mut context, &mut config)
.expect("interceptor must succeed");
let mutated_request = context.request().expect("request is still set");
let mutated_request = context.request();
for name in mutated_request.headers().keys() {
assert_eq!(
mutated_request.headers().get_all(name).iter().count(),
Expand Down
32 changes: 19 additions & 13 deletions aws/rust-runtime/aws-runtime/src/user_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/

use aws_http::user_agent::{ApiMetadata, AwsUserAgent};
use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit;
use aws_smithy_runtime_api::client::interceptors::error::BoxError;
use aws_smithy_runtime_api::client::interceptors::{Interceptor, InterceptorContext};
use aws_smithy_runtime_api::config_bag::ConfigBag;
Expand Down Expand Up @@ -72,7 +73,7 @@ fn header_values(
impl Interceptor for UserAgentInterceptor {
fn modify_before_signing(
&self,
context: &mut InterceptorContext,
context: &mut InterceptorContext<BeforeTransmit>,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
let api_metadata = cfg
Expand All @@ -95,7 +96,7 @@ impl Interceptor for UserAgentInterceptor {
Cow::Owned(ua)
});

let headers = context.request_mut()?.headers_mut();
let headers = context.request_mut().headers_mut();
let (user_agent, x_amz_user_agent) = header_values(&ua)?;
headers.append(USER_AGENT, user_agent);
headers.append(X_AMZ_USER_AGENT, x_amz_user_agent);
Expand All @@ -112,21 +113,30 @@ mod tests {
use aws_smithy_runtime_api::type_erasure::TypedBox;
use aws_smithy_types::error::display::DisplayErrorContext;

fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a str {
fn expect_header<'a>(
context: &'a InterceptorContext<BeforeTransmit>,
header_name: &str,
) -> &'a str {
context
.request()
.unwrap()
.headers()
.get(header_name)
.unwrap()
.to_str()
.unwrap()
}

fn context() -> InterceptorContext<BeforeTransmit> {
let mut context = InterceptorContext::<()>::new(TypedBox::new("doesntmatter").erase())
.into_serialization_phase();
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
let _ = context.take_input();
context.into_before_transmit_phase()
}

#[test]
fn test_overridden_ua() {
let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
let mut context = context();

let mut config = ConfigBag::base();
config.put(AwsUserAgent::for_tests());
Expand All @@ -149,8 +159,7 @@ mod tests {

#[test]
fn test_default_ua() {
let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
let mut context = context();

let api_metadata = ApiMetadata::new("some-service", "some-version");
let mut config = ConfigBag::base();
Expand Down Expand Up @@ -178,8 +187,7 @@ mod tests {

#[test]
fn test_app_name() {
let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());
let mut context = context();

let api_metadata = ApiMetadata::new("some-service", "some-version");
let mut config = ConfigBag::base();
Expand Down Expand Up @@ -207,9 +215,7 @@ mod tests {

#[test]
fn test_api_metadata_missing() {
let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
context.set_request(http::Request::builder().body(SdkBody::empty()).unwrap());

let mut context = context();
let mut config = ConfigBag::base();

let interceptor = UserAgentInterceptor::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class OperationRetryClassifiersFeature(
"RetryClassifiers" to smithyRuntimeApi.resolve("client::retries::RetryClassifiers"),
"OperationError" to codegenContext.symbolProvider.symbolForOperationError(operation),
"SdkError" to RuntimeType.smithyHttp(runtimeConfig).resolve("result::SdkError"),
"ErasedError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypeErasedBox"),
"ErasedError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("type_erasure::TypeErasedError"),
)

override fun section(section: OperationRuntimePluginSection) = when (section) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class EndpointParamsInterceptorGenerator(
"HttpResponse" to orchestrator.resolve("HttpResponse"),
"Interceptor" to interceptors.resolve("Interceptor"),
"InterceptorContext" to interceptors.resolve("InterceptorContext"),
"BeforeSerializationPhase" to interceptors.resolve("context::phase::BeforeSerialization"),
"InterceptorError" to interceptors.resolve("error::InterceptorError"),
"Params" to endpointTypesGenerator.paramsStruct(),
)
Expand All @@ -66,11 +67,10 @@ class EndpointParamsInterceptorGenerator(
impl #{Interceptor} for $interceptorName {
fn read_before_execution(
&self,
context: &#{InterceptorContext},
context: &#{InterceptorContext}<#{BeforeSerializationPhase}>,
cfg: &mut #{ConfigBag},
) -> Result<(), #{BoxError}> {
let _input = context.input()?;
let _input = _input
let _input = context.input()
.downcast_ref::<${operationInput.name}>()
.ok_or("failed to downcast to ${operationInput.name}")?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ class ServiceRuntimePluginGenerator(
"Interceptors" to runtimeApi.resolve("client::interceptors::Interceptors"),
"SharedEndpointResolver" to http.resolve("endpoint::SharedEndpointResolver"),
"StaticAuthOptionResolver" to runtimeApi.resolve("client::auth::option_resolver::StaticAuthOptionResolver"),
"TraceProbe" to runtimeApi.resolve("client::orchestrator::TraceProbe"),
)
}

Expand Down Expand Up @@ -131,18 +130,6 @@ class ServiceRuntimePluginGenerator(
.expect("connection set");
cfg.set_connection(connection);
// TODO(RuntimePlugins): Add the TraceProbe to the config bag
cfg.set_trace_probe({
##[derive(Debug)]
struct StubTraceProbe;
impl #{TraceProbe} for StubTraceProbe {
fn dispatch_events(&self) {
// no-op
}
}
StubTraceProbe
});
#{additional_config}
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ class FluentClientGenerator(
.await
.map_err(|err| {
err.map_service_error(|err| {
#{TypedBox}::<#{OperationError}>::assume_from(err)
#{TypedBox}::<#{OperationError}>::assume_from(err.into())
.expect("correct error type")
.unwrap()
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,10 @@ class ResponseDeserializerGenerator(
pub(crate) fn $fnName<O, E>(result: Result<O, E>) -> Result<#{Output}, #{Error}>
where
O: std::fmt::Debug + Send + Sync + 'static,
E: std::fmt::Debug + Send + Sync + 'static,
E: std::error::Error + std::fmt::Debug + Send + Sync + 'static,
{
result.map(|output| #{TypedBox}::new(output).erase())
.map_err(|error| #{TypedBox}::new(error).erase())
.map_err(|error| #{TypedBox}::new(error).erase_error())
}
""",
*codegenScope,
Expand Down
Loading

0 comments on commit beedd2c

Please sign in to comment.