From eac52eb69c89d78c1844e9e2b0f0c3413031fc58 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Fri, 26 Apr 2024 16:37:50 -0700 Subject: [PATCH] Fix event stream `:content-type` for struct messages (#3603) Event stream operations with struct shaped messages were using the wrong `:content-type` message header value, which I think wasn't caught before since the supported AWS S3/Transcribe event stream operations don't serialize struct messages. This PR fixes the message content type serialization. ---- _By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice._ --- CHANGELOG.next.toml | 6 +++ .../codegen/core/smithy/protocols/AwsJson.kt | 3 ++ .../smithy/protocols/AwsQueryCompatible.kt | 4 ++ .../smithy/protocols/HttpBindingResolver.kt | 38 +++++++++++++++++-- .../HttpBoundProtocolPayloadGenerator.kt | 16 +++----- .../codegen/core/smithy/protocols/RestJson.kt | 10 ++++- .../codegen/core/smithy/protocols/RestXml.kt | 10 ++++- .../testutil/EventStreamMarshallTestCases.kt | 6 +-- .../core/testutil/EventStreamTestModels.kt | 4 ++ 9 files changed, 79 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index be839d59da..19ee75a757 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -58,3 +58,9 @@ message = "SDK crates now set the `rust-version` property in their Cargo.toml fi references = ["smithy-rs#3601"] meta = { "breaking" = false, "tada" = true, "bug" = false } author = "jdisanti" + +[[smithy-rs]] +message = "Fix event stream `:content-type` message headers for struct messages. Note: this was the `:content-type` header on individual event message frames that was incorrect, not the HTTP `content-type` header for the initial request." +references = ["smithy-rs#3603"] +meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "all" } +author = "jdisanti" diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt index 1b54f4289f..486b443a6a 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsJson.kt @@ -83,6 +83,9 @@ class AwsJsonHttpBindingResolver( "application/x-amz-json-${awsJsonVersion.value}" override fun responseContentType(operationShape: OperationShape): String = requestContentType(operationShape) + + override fun eventStreamMessageContentType(memberShape: MemberShape): String? = + ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/json") } /** diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt index 7b15e81051..4cc4a2fa14 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/AwsQueryCompatible.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols +import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.HttpTrait @@ -38,6 +39,9 @@ class AwsQueryCompatibleHttpBindingResolver( override fun responseContentType(operationShape: OperationShape): String = awsJsonHttpBindingResolver.requestContentType(operationShape) + + override fun eventStreamMessageContentType(memberShape: MemberShape): String? = + awsJsonHttpBindingResolver.eventStreamMessageContentType(memberShape) } class AwsQueryCompatible( diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt index 8eb693c41f..ad36e79190 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBindingResolver.kt @@ -8,8 +8,10 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex +import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.ToShapeId import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.TimestampFormatTrait @@ -98,6 +100,11 @@ interface HttpBindingResolver { * Determines the response content type for given [operationShape]. */ fun responseContentType(operationShape: OperationShape): String? + + /** + * Determines the value of the event stream `:content-type` header based on union member + */ + fun eventStreamMessageContentType(memberShape: MemberShape): String? } /** @@ -108,12 +115,30 @@ data class ProtocolContentTypes( val requestDocument: String? = null, /** Response content type override for when the shape is a Document */ val responseDocument: String? = null, - /** EventStream content type */ + /** EventStream content type initial request/response content-type */ val eventStreamContentType: String? = null, + /** EventStream content type for struct message shapes (for `:content-type`) */ + val eventStreamMessageContentType: String? = null, ) { companion object { /** Create an instance of [ProtocolContentTypes] where all content types are the same */ - fun consistent(type: String) = ProtocolContentTypes(type, type, type) + fun consistent(type: String) = ProtocolContentTypes(type, type, type, type) + + /** + * Returns the event stream message `:content-type` for the given event stream union member shape. + * + * The `protocolContentType` is the content-type to use for non-string/non-blob shapes. + */ + fun eventStreamMemberContentType( + model: Model, + memberShape: MemberShape, + protocolContentType: String?, + ): String? = + when (model.expectShape(memberShape.target)) { + is StringShape -> "text/plain" + is BlobShape -> "application/octet-stream" + else -> protocolContentType + } } } @@ -121,7 +146,7 @@ data class ProtocolContentTypes( * An [HttpBindingResolver] that relies on the HttpTrait data in the Smithy models. */ open class HttpTraitHttpBindingResolver( - model: Model, + private val model: Model, private val contentTypes: ProtocolContentTypes, ) : HttpBindingResolver { private val httpIndex: HttpBindingIndex = HttpBindingIndex.of(model) @@ -158,6 +183,9 @@ open class HttpTraitHttpBindingResolver( contentTypes.eventStreamContentType, ).orNull() + override fun eventStreamMessageContentType(memberShape: MemberShape): String? = + ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, contentTypes.eventStreamMessageContentType) + // Sort the members after extracting them from the map to have a consistent order private fun mappedBindings(bindings: Map): List = bindings.values.map(::HttpBindingDescriptor).sortedBy { it.memberName } @@ -172,6 +200,7 @@ open class StaticHttpBindingResolver( private val httpTrait: HttpTrait, private val requestContentType: String, private val responseContentType: String, + private val eventStreamMessageContentType: String? = null, ) : HttpBindingResolver { private fun bindings(shape: ToShapeId?) = shape?.let { model.expectShape(it.toShapeId()) }?.members() @@ -192,4 +221,7 @@ open class StaticHttpBindingResolver( override fun requestContentType(operationShape: OperationShape): String = requestContentType override fun responseContentType(operationShape: OperationShape): String = responseContentType + + override fun eventStreamMessageContentType(memberShape: MemberShape): String? = + ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, eventStreamMessageContentType) } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt index a8fa69a6c2..1c5c94d800 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/HttpBoundProtocolPayloadGenerator.kt @@ -197,7 +197,6 @@ class HttpBoundProtocolPayloadGenerator( if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) { val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName) writer.serializeViaEventStream( - operationShape, payloadMember, serializerGenerator, shapeName, @@ -206,7 +205,6 @@ class HttpBoundProtocolPayloadGenerator( } else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) { val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName) writer.serializeViaEventStream( - operationShape, payloadMember, serializerGenerator, "output", @@ -239,7 +237,6 @@ class HttpBoundProtocolPayloadGenerator( } private fun RustWriter.serializeViaEventStream( - operationShape: OperationShape, memberShape: MemberShape, serializerGenerator: StructuredDataSerializerGenerator, outerName: String, @@ -248,11 +245,10 @@ class HttpBoundProtocolPayloadGenerator( val memberName = symbolProvider.toMemberName(memberShape) val unionShape = model.expectShape(memberShape.target, UnionShape::class.java) - val contentType = - when (target) { - CodegenTarget.CLIENT -> httpBindingResolver.requestContentType(operationShape) - CodegenTarget.SERVER -> httpBindingResolver.responseContentType(operationShape) - } + val payloadContentType = + httpBindingResolver.eventStreamMessageContentType(memberShape) + ?: throw CodegenException("event streams must set a content type") + val errorMarshallerConstructorFn = EventStreamErrorMarshallerGenerator( model, @@ -261,7 +257,7 @@ class HttpBoundProtocolPayloadGenerator( symbolProvider, unionShape, serializerGenerator, - contentType ?: throw CodegenException("event streams must set a content type"), + payloadContentType, ).render() val marshallerConstructorFn = EventStreamMarshallerGenerator( @@ -271,7 +267,7 @@ class HttpBoundProtocolPayloadGenerator( symbolProvider, unionShape, serializerGenerator, - contentType, + payloadContentType, ).render() // TODO(EventStream): [RPC] RPC protocols need to send an initial message with the diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt index c77c05b6c9..641548fc11 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestJson.kt @@ -74,7 +74,15 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol { ) override val httpBindingResolver: HttpBindingResolver = - RestJsonHttpBindingResolver(codegenContext.model, ProtocolContentTypes("application/json", "application/json", "application/vnd.amazon.eventstream")) + RestJsonHttpBindingResolver( + codegenContext.model, + ProtocolContentTypes( + requestDocument = "application/json", + responseDocument = "application/json", + eventStreamContentType = "application/vnd.amazon.eventstream", + eventStreamMessageContentType = "application/json", + ), + ) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt index 700fe60775..e8dab762a1 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/RestXml.kt @@ -36,7 +36,15 @@ open class RestXml(val codegenContext: CodegenContext) : Protocol { } override val httpBindingResolver: HttpBindingResolver = - HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes("application/xml", "application/xml", "application/vnd.amazon.eventstream")) + HttpTraitHttpBindingResolver( + codegenContext.model, + ProtocolContentTypes( + requestDocument = "application/xml", + responseDocument = "application/xml", + eventStreamContentType = "application/vnd.amazon.eventstream", + eventStreamMessageContentType = "application/xml", + ), + ) override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.DATE_TIME diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt index 65ae5019cb..ba37f3c421 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamMarshallTestCases.kt @@ -111,7 +111,7 @@ object EventStreamMarshallTestCases { let headers = headers_to_map(message.headers()); assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + assert_eq!(&str_header(${testCase.eventStreamMessageContentType.dq()}), *headers.get(":content-type").unwrap()); validate_body( message.payload(), @@ -146,7 +146,7 @@ object EventStreamMarshallTestCases { let headers = headers_to_map(message.headers()); assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + assert_eq!(&str_header(${testCase.eventStreamMessageContentType.dq()}), *headers.get(":content-type").unwrap()); validate_body( message.payload(), @@ -236,7 +236,7 @@ object EventStreamMarshallTestCases { let headers = headers_to_map(message.headers()); assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap()); assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap()); - assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap()); + assert_eq!(&str_header(${testCase.eventStreamMessageContentType.dq()}), *headers.get(":content-type").unwrap()); validate_body( message.payload(), diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt index c0f61e07db..dc37caf714 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/EventStreamTestModels.kt @@ -109,6 +109,7 @@ object EventStreamTestModels { val mediaType: String, val requestContentType: String, val responseContentType: String, + val eventStreamMessageContentType: String, val validTestStruct: String, val validMessageWithNoHeaderPayloadTraits: String, val validTestUnion: String, @@ -130,6 +131,7 @@ object EventStreamTestModels { mediaType = "application/json", requestContentType = "application/vnd.amazon.eventstream", responseContentType = "application/json", + eventStreamMessageContentType = "application/json", validTestStruct = """{"someString":"hello","someInt":5}""", validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", validTestUnion = """{"Foo":"hello"}""", @@ -145,6 +147,7 @@ object EventStreamTestModels { mediaType = "application/x-amz-json-1.1", requestContentType = "application/x-amz-json-1.1", responseContentType = "application/x-amz-json-1.1", + eventStreamMessageContentType = "application/json", validTestStruct = """{"someString":"hello","someInt":5}""", validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""", validTestUnion = """{"Foo":"hello"}""", @@ -160,6 +163,7 @@ object EventStreamTestModels { mediaType = "application/xml", requestContentType = "application/vnd.amazon.eventstream", responseContentType = "application/xml", + eventStreamMessageContentType = "application/xml", validTestStruct = """