Skip to content

Commit

Permalink
Fix event stream :content-type for struct messages (smithy-lang#3603)
Browse files Browse the repository at this point in the history
Event stream operations with struct shaped messages were using the wrong
`:content-type` message header value, which I think wasn't caught before
since the supported AWS S3/Transcribe event stream operations don't
serialize struct messages. This PR fixes the message content type
serialization.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
jdisanti authored and Darwin Boersma committed May 5, 2024
1 parent 7736a4a commit dff0e6c
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 20 deletions.
30 changes: 28 additions & 2 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ let result = ec2_client.wait_until_instance_status_ok()
.await;
```
"""
references = ["aws-sdk-rust#400", "smithy-rs#3595", "smithy-rs#3593", "smithy-rs#3585", "smithy-rs#3571", "smithy-rs#3569"]
references = [
"aws-sdk-rust#400",
"smithy-rs#3595",
"smithy-rs#3593",
"smithy-rs#3585",
"smithy-rs#3571",
"smithy-rs#3569",
]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"

Expand All @@ -49,7 +56,14 @@ let result = client.wait_until_thing()
.await;
```
"""
references = ["smithy-rs#119", "smithy-rs#3595", "smithy-rs#3593", "smithy-rs#3585", "smithy-rs#3571", "smithy-rs#3569"]
references = [
"smithy-rs#119",
"smithy-rs#3595",
"smithy-rs#3593",
"smithy-rs#3585",
"smithy-rs#3571",
"smithy-rs#3569",
]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client" }
author = "jdisanti"

Expand All @@ -59,6 +73,18 @@ references = ["smithy-rs#3601"]
meta = { "breaking" = false, "tada" = true, "bug" = false }
author = "jdisanti"

[[smithy-rs]]
message = "Fix event stream `:content-type` message headers for struct messages. Note: this was the `:content-type` header on individual event message frames that was incorrect, not the HTTP `content-type` header for the initial request."
references = ["smithy-rs#3603"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "all" }
author = "jdisanti"

[[smithy-rs]]
message = "Fix event stream `:content-type` message headers for struct messages. Note: this was the `:content-type` header on individual event message frames that was incorrect, not the HTTP `content-type` header for the initial request."
references = ["smithy-rs#3603"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "all" }
author = "jdisanti"

[[smithy-rs]]
message = "Added support for BigInt and BigDecimal to Smithy model generation."
references = ["smithy-rs#312"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class AwsJsonHttpBindingResolver(
"application/x-amz-json-${awsJsonVersion.value}"

override fun responseContentType(operationShape: OperationShape): String = requestContentType(operationShape)

override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, "application/json")
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.core.smithy.protocols

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
Expand Down Expand Up @@ -38,6 +39,9 @@ class AwsQueryCompatibleHttpBindingResolver(

override fun responseContentType(operationShape: OperationShape): String =
awsJsonHttpBindingResolver.requestContentType(operationShape)

override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
awsJsonHttpBindingResolver.eventStreamMessageContentType(memberShape)
}

class AwsQueryCompatible(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ package software.amazon.smithy.rust.codegen.core.smithy.protocols
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.ToShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
Expand Down Expand Up @@ -98,6 +100,11 @@ interface HttpBindingResolver {
* Determines the response content type for given [operationShape].
*/
fun responseContentType(operationShape: OperationShape): String?

/**
* Determines the value of the event stream `:content-type` header based on union member
*/
fun eventStreamMessageContentType(memberShape: MemberShape): String?
}

/**
Expand All @@ -108,20 +115,38 @@ data class ProtocolContentTypes(
val requestDocument: String? = null,
/** Response content type override for when the shape is a Document */
val responseDocument: String? = null,
/** EventStream content type */
/** EventStream content type initial request/response content-type */
val eventStreamContentType: String? = null,
/** EventStream content type for struct message shapes (for `:content-type`) */
val eventStreamMessageContentType: String? = null,
) {
companion object {
/** Create an instance of [ProtocolContentTypes] where all content types are the same */
fun consistent(type: String) = ProtocolContentTypes(type, type, type)
fun consistent(type: String) = ProtocolContentTypes(type, type, type, type)

/**
* Returns the event stream message `:content-type` for the given event stream union member shape.
*
* The `protocolContentType` is the content-type to use for non-string/non-blob shapes.
*/
fun eventStreamMemberContentType(
model: Model,
memberShape: MemberShape,
protocolContentType: String?,
): String? =
when (model.expectShape(memberShape.target)) {
is StringShape -> "text/plain"
is BlobShape -> "application/octet-stream"
else -> protocolContentType
}
}
}

/**
* An [HttpBindingResolver] that relies on the HttpTrait data in the Smithy models.
*/
open class HttpTraitHttpBindingResolver(
model: Model,
private val model: Model,
private val contentTypes: ProtocolContentTypes,
) : HttpBindingResolver {
private val httpIndex: HttpBindingIndex = HttpBindingIndex.of(model)
Expand Down Expand Up @@ -158,6 +183,9 @@ open class HttpTraitHttpBindingResolver(
contentTypes.eventStreamContentType,
).orNull()

override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, contentTypes.eventStreamMessageContentType)

// Sort the members after extracting them from the map to have a consistent order
private fun mappedBindings(bindings: Map<String, HttpBinding>): List<HttpBindingDescriptor> =
bindings.values.map(::HttpBindingDescriptor).sortedBy { it.memberName }
Expand All @@ -172,6 +200,7 @@ open class StaticHttpBindingResolver(
private val httpTrait: HttpTrait,
private val requestContentType: String,
private val responseContentType: String,
private val eventStreamMessageContentType: String? = null,
) : HttpBindingResolver {
private fun bindings(shape: ToShapeId?) =
shape?.let { model.expectShape(it.toShapeId()) }?.members()
Expand All @@ -192,4 +221,7 @@ open class StaticHttpBindingResolver(
override fun requestContentType(operationShape: OperationShape): String = requestContentType

override fun responseContentType(operationShape: OperationShape): String = responseContentType

override fun eventStreamMessageContentType(memberShape: MemberShape): String? =
ProtocolContentTypes.eventStreamMemberContentType(model, memberShape, eventStreamMessageContentType)
}
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ class HttpBoundProtocolPayloadGenerator(
if (operationShape.isInputEventStream(model) && target == CodegenTarget.CLIENT) {
val payloadMember = operationShape.inputShape(model).expectMember(payloadMemberName)
writer.serializeViaEventStream(
operationShape,
payloadMember,
serializerGenerator,
shapeName,
Expand All @@ -206,7 +205,6 @@ class HttpBoundProtocolPayloadGenerator(
} else if (operationShape.isOutputEventStream(model) && target == CodegenTarget.SERVER) {
val payloadMember = operationShape.outputShape(model).expectMember(payloadMemberName)
writer.serializeViaEventStream(
operationShape,
payloadMember,
serializerGenerator,
"output",
Expand Down Expand Up @@ -239,7 +237,6 @@ class HttpBoundProtocolPayloadGenerator(
}

private fun RustWriter.serializeViaEventStream(
operationShape: OperationShape,
memberShape: MemberShape,
serializerGenerator: StructuredDataSerializerGenerator,
outerName: String,
Expand All @@ -248,11 +245,10 @@ class HttpBoundProtocolPayloadGenerator(
val memberName = symbolProvider.toMemberName(memberShape)
val unionShape = model.expectShape(memberShape.target, UnionShape::class.java)

val contentType =
when (target) {
CodegenTarget.CLIENT -> httpBindingResolver.requestContentType(operationShape)
CodegenTarget.SERVER -> httpBindingResolver.responseContentType(operationShape)
}
val payloadContentType =
httpBindingResolver.eventStreamMessageContentType(memberShape)
?: throw CodegenException("event streams must set a content type")

val errorMarshallerConstructorFn =
EventStreamErrorMarshallerGenerator(
model,
Expand All @@ -261,7 +257,7 @@ class HttpBoundProtocolPayloadGenerator(
symbolProvider,
unionShape,
serializerGenerator,
contentType ?: throw CodegenException("event streams must set a content type"),
payloadContentType,
).render()
val marshallerConstructorFn =
EventStreamMarshallerGenerator(
Expand All @@ -271,7 +267,7 @@ class HttpBoundProtocolPayloadGenerator(
symbolProvider,
unionShape,
serializerGenerator,
contentType,
payloadContentType,
).render()

// TODO(EventStream): [RPC] RPC protocols need to send an initial message with the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,15 @@ open class RestJson(val codegenContext: CodegenContext) : Protocol {
)

override val httpBindingResolver: HttpBindingResolver =
RestJsonHttpBindingResolver(codegenContext.model, ProtocolContentTypes("application/json", "application/json", "application/vnd.amazon.eventstream"))
RestJsonHttpBindingResolver(
codegenContext.model,
ProtocolContentTypes(
requestDocument = "application/json",
responseDocument = "application/json",
eventStreamContentType = "application/vnd.amazon.eventstream",
eventStreamMessageContentType = "application/json",
),
)

override val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,15 @@ open class RestXml(val codegenContext: CodegenContext) : Protocol {
}

override val httpBindingResolver: HttpBindingResolver =
HttpTraitHttpBindingResolver(codegenContext.model, ProtocolContentTypes("application/xml", "application/xml", "application/vnd.amazon.eventstream"))
HttpTraitHttpBindingResolver(
codegenContext.model,
ProtocolContentTypes(
requestDocument = "application/xml",
responseDocument = "application/xml",
eventStreamContentType = "application/vnd.amazon.eventstream",
eventStreamMessageContentType = "application/xml",
),
)

override val defaultTimestampFormat: TimestampFormatTrait.Format =
TimestampFormatTrait.Format.DATE_TIME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ object EventStreamMarshallTestCases {
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithStruct"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
assert_eq!(&str_header(${testCase.eventStreamMessageContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
Expand Down Expand Up @@ -146,7 +146,7 @@ object EventStreamMarshallTestCases {
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithUnion"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
assert_eq!(&str_header(${testCase.eventStreamMessageContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
Expand Down Expand Up @@ -236,7 +236,7 @@ object EventStreamMarshallTestCases {
let headers = headers_to_map(message.headers());
assert_eq!(&str_header("event"), *headers.get(":message-type").unwrap());
assert_eq!(&str_header("MessageWithNoHeaderPayloadTraits"), *headers.get(":event-type").unwrap());
assert_eq!(&str_header(${testCase.requestContentType.dq()}), *headers.get(":content-type").unwrap());
assert_eq!(&str_header(${testCase.eventStreamMessageContentType.dq()}), *headers.get(":content-type").unwrap());
validate_body(
message.payload(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ object EventStreamTestModels {
val mediaType: String,
val requestContentType: String,
val responseContentType: String,
val eventStreamMessageContentType: String,
val validTestStruct: String,
val validMessageWithNoHeaderPayloadTraits: String,
val validTestUnion: String,
Expand All @@ -130,6 +131,7 @@ object EventStreamTestModels {
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
Expand All @@ -145,6 +147,7 @@ object EventStreamTestModels {
mediaType = "application/x-amz-json-1.1",
requestContentType = "application/x-amz-json-1.1",
responseContentType = "application/x-amz-json-1.1",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
Expand All @@ -160,6 +163,7 @@ object EventStreamTestModels {
mediaType = "application/xml",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/xml",
eventStreamMessageContentType = "application/xml",
validTestStruct =
"""
<TestStruct>
Expand Down

0 comments on commit dff0e6c

Please sign in to comment.