diff --git a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecoratorTest.kt b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecoratorTest.kt index 5183d4051a..3627b41584 100644 --- a/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecoratorTest.kt +++ b/aws/sdk-codegen/src/test/kotlin/software/amazon/smithy/rustsdk/SigV4AuthDecoratorTest.kt @@ -5,7 +5,12 @@ package software.amazon.smithy.rustsdk import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.integrationTest +import software.amazon.smithy.rust.codegen.core.testutil.tokioTest class SigV4AuthDecoratorTest { private val modelWithSigV4AuthScheme = @@ -35,6 +40,7 @@ class SigV4AuthDecoratorTest { structure SomeInput { @httpPayload + @required something: Bytestream } @@ -56,10 +62,115 @@ class SigV4AuthDecoratorTest { @unsignedPayload @http(uri: "/", method: "POST") operation SomeOperation { input: SomeInput, output: SomeOutput } - """.asSmithyModel() + """.asSmithyModel(smithyVersion = "2.0") @Test fun unsignedPayloadSetsCorrectHeader() { awsSdkIntegrationTest(modelWithSigV4AuthScheme) { _, _ -> } } + + private val modelWithSigV4aAuthScheme = + """ + namespace test + + use aws.auth#sigv4 + use aws.auth#sigv4a + use aws.api#service + use aws.protocols#restJson1 + use smithy.rules#endpointRuleSet + use aws.auth#unsignedPayload + use smithy.test#httpRequestTests + + @auth([sigv4a,sigv4]) + @sigv4(name: "dontcare") + @sigv4a(name: "dontcare") + @restJson1 + @endpointRuleSet({ + "version": "1.0", + "rules": [ + { + "type": "endpoint", + "conditions": [], + "endpoint": { + "url": "https://example.com", + "properties": { + "authSchemes": [ + { + "name": "sigv4a", + "signingRegionSet": ["*"], + "signingName": "dontcare" + } + ] + } + } + } + ], + "parameters": { + "endpoint": { "required": true, "type": "string", "builtIn": "SDK::Endpoint" }, + } + }) + @service(sdkId: "dontcare") + service TestService { version: "2023-01-01", operations: [SomeOperation] } + + @streaming + blob Bytestream + + structure SomeInput { + @httpPayload + @required + something: Bytestream + } + + structure SomeOutput { something: String } + + @http(uri: "/", method: "POST") + operation SomeOperation { input: SomeInput, output: SomeOutput } + """.asSmithyModel(smithyVersion = "2.0") + + @Test + fun unsignedPayloadSetsCorrectHeaderForSigV4a() { + awsSdkIntegrationTest(modelWithSigV4aAuthScheme) { clientCodegenContext, rustCrate -> + val moduleUseName = clientCodegenContext.moduleUseName() + val rc = clientCodegenContext.runtimeConfig + + rustCrate.integrationTest("sigv4a") { + Attribute.featureGate("test-util").render(this) + tokioTest("test_sigv4a_signing") { + rustTemplate( + """ + let http_client = #{StaticReplayClient}::new(vec![#{ReplayEvent}::new( + #{Request}::builder() + .header("authorization", "AWS4-ECDSA-P256-SHA256 Credential=ANOTREAL/20090213/dontcare/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-region-set;x-amz-user-agent, Signature=3045022100b95d1c054ff04b676d12f0c893348606844d67ccf595981f0ca4968fae2eddfd022073e66edc0ad1da05b08392fccefa3ad69f8ec9393461033412fa05c55b749e9d") + .uri("https://example.com") + .body(#{SdkBody}::from("Hello, world!")) + .unwrap(), + #{Response}::builder().status(200).body(#{SdkBody}::empty()).unwrap(), + )]); + let config = $moduleUseName::Config::builder() + .http_client(http_client.clone()) + .endpoint_url("https://example.com") + .behavior_version_latest() + .with_test_defaults() + .build(); + let client = $moduleUseName::Client::from_conf(config); + let _ = client.some_operation().something(#{ByteStream}::from_static(b"Hello, world!")).send().await; + + http_client.assert_requests_match(&["authorization"]); + let auth_header = http_client.actual_requests().next().unwrap().headers().get(http::header::AUTHORIZATION).unwrap(); + assert!(auth_header.contains("AWS4-ECDSA-P256-SHA256")); + """, + "ByteStream" to RuntimeType.byteStream(rc), + "Credentials" to AwsRuntimeType.awsCredentialTypesTestUtil(rc).resolve("Credentials"), + "Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"), + "ReplayEvent" to RuntimeType.smithyRuntimeTestUtil(rc).resolve("ReplayEvent"), + "Request" to RuntimeType.HttpRequest, + "Response" to RuntimeType.HttpResponse, + "SdkBody" to RuntimeType.sdkBody(rc), + "StaticReplayClient" to RuntimeType.smithyRuntimeTestUtil(rc).resolve("StaticReplayClient"), + "tracing_subscriber" to RuntimeType.TracingSubscriber, + ) + } + } + } + } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt index 61771ad61d..7ad456b02a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/RuntimeType.kt @@ -279,14 +279,12 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val Bytes = CargoDependency.Bytes.toType().resolve("Bytes") val Http = CargoDependency.Http.toType() val HttpBody = CargoDependency.HttpBody.toType() - val HttpHeaderMap = Http.resolve("HeaderMap") val HttpRequest = Http.resolve("Request") val HttpRequestBuilder = Http.resolve("request::Builder") val HttpResponse = Http.resolve("Response") val HttpResponseBuilder = Http.resolve("response::Builder") val Hyper = CargoDependency.Hyper.toType() val LazyStatic = CargoDependency.LazyStatic.toType() - val Md5 = CargoDependency.Md5.toType() val OnceCell = CargoDependency.OnceCell.toType() val PercentEncoding = CargoDependency.PercentEncoding.toType() val PrettyAssertions = CargoDependency.PrettyAssertions.toType() @@ -294,12 +292,12 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) val Serde = CargoDependency.Serde.toType() val SerdeDeserialize = Serde.resolve("Deserialize") val SerdeSerialize = Serde.resolve("Serialize") - val RegexLite = CargoDependency.RegexLite.toType() val Tokio = CargoDependency.Tokio.toType() val TokioStream = CargoDependency.TokioStream.toType() val Tower = CargoDependency.Tower.toType() val Tracing = CargoDependency.Tracing.toType() val TracingTest = CargoDependency.TracingTest.toType() + val TracingSubscriber = CargoDependency.TracingSubscriber.toType() // codegen types val ConstrainedTrait = RuntimeType("crate::constrained::Constrained", InlineDependency.constrained()) @@ -312,8 +310,6 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) fun smithyChecksums(runtimeConfig: RuntimeConfig) = CargoDependency.smithyChecksums(runtimeConfig).toType() - fun smithyCompression(runtimeConfig: RuntimeConfig) = CargoDependency.smithyCompression(runtimeConfig).toType() - fun smithyEventStream(runtimeConfig: RuntimeConfig) = CargoDependency.smithyEventStream(runtimeConfig).toType() fun smithyHttp(runtimeConfig: RuntimeConfig) = CargoDependency.smithyHttp(runtimeConfig).toType() @@ -444,7 +440,7 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) forInlineDependency(InlineDependency.awsQueryCompatibleErrors(runtimeConfig)) fun defaultAuthPlugin(runtimeConfig: RuntimeConfig) = - RuntimeType.forInlineDependency(InlineDependency.defaultAuthPlugin(runtimeConfig)) + forInlineDependency(InlineDependency.defaultAuthPlugin(runtimeConfig)) .resolve("DefaultAuthOptionsPlugin") fun labelFormat( @@ -502,9 +498,11 @@ data class RuntimeType(val path: String, val dependency: RustDependency? = null) return smithyTypes(runtimeConfig).resolve("date_time::Format::$timestampFormat") } + fun smithyRuntimeTestUtil(runtimeConfig: RuntimeConfig) = + CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType().resolve("client::http::test_util") + fun captureRequest(runtimeConfig: RuntimeConfig) = - CargoDependency.smithyRuntimeTestUtil(runtimeConfig).toType() - .resolve("client::http::test_util::capture_request") + smithyRuntimeTestUtil(runtimeConfig).resolve("capture_request") fun forInlineDependency(inlineDependency: InlineDependency) = RuntimeType("crate::${inlineDependency.name}", inlineDependency)