diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/rustlang/Writable.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/rustlang/Writable.kt index 1e092b09f3..c453fe0653 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/rustlang/Writable.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/rustlang/Writable.kt @@ -26,6 +26,13 @@ fun Writable.isEmpty(): Boolean { return writer.toString() == RustWriter.root().toString() } +operator fun Writable.plus(other: Writable): Writable { + val first = this + return writable { + rustTemplate("#{First:W}#{Second:W}", "First" to first, "Second" to other) + } +} + /** * Helper allowing a `Iterable` to be joined together using a `String` separator. */ diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsJson.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsJson.kt index 59d5dec2c0..44a0cfc176 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsJson.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsJson.kt @@ -129,7 +129,7 @@ class AwsJsonSerializerGenerator( } open class AwsJson( - private val coreCodegenContext: CoreCodegenContext, + val coreCodegenContext: CoreCodegenContext, private val awsJsonVersion: AwsJsonVersion, ) : Protocol { private val runtimeConfig = coreCodegenContext.runtimeConfig @@ -143,6 +143,8 @@ open class AwsJson( ) private val jsonDeserModule = RustModule.private("json_deser") + val version: AwsJsonVersion get() = awsJsonVersion + override val httpBindingResolver: HttpBindingResolver = AwsJsonHttpBindingResolver(coreCodegenContext.model, awsJsonVersion) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJson.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJson.kt index 3cddbf7640..629ee1b204 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJson.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJson.kt @@ -87,7 +87,7 @@ class RestJsonHttpBindingResolver( } } -class RestJson(private val coreCodegenContext: CoreCodegenContext) : Protocol { +open class RestJson(val coreCodegenContext: CoreCodegenContext) : Protocol { private val runtimeConfig = coreCodegenContext.runtimeConfig private val errorScope = arrayOf( "Bytes" to RuntimeType.Bytes, diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXml.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXml.kt index 94c76c2c90..2b99fe09ab 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXml.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXml.kt @@ -53,7 +53,7 @@ class RestXmlFactory( } } -open class RestXml(private val coreCodegenContext: CoreCodegenContext) : Protocol { +open class RestXml(val coreCodegenContext: CoreCodegenContext) : Protocol { private val restXml = coreCodegenContext.serviceShape.expectTrait() private val runtimeConfig = coreCodegenContext.runtimeConfig private val errorScope = arrayOf( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt index cfb09d376f..fa386b8887 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerRuntimeType.kt @@ -39,6 +39,8 @@ object ServerRuntimeType { fun ResponseRejection(runtimeConfig: RuntimeConfig) = RuntimeType("ResponseRejection", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::rejection") - fun Protocol(runtimeConfig: RuntimeConfig) = - RuntimeType("Protocol", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::protocols") + fun Protocol(name: String, runtimeConfig: RuntimeConfig) = + RuntimeType(name, ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::protocols") + + fun Protocol(runtimeConfig: RuntimeConfig) = Protocol("Protocol", runtimeConfig) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt new file mode 100644 index 0000000000..786f611798 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerOperationGenerator.kt @@ -0,0 +1,66 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.client.rustlang.Writable +import software.amazon.smithy.rust.codegen.client.rustlang.asType +import software.amazon.smithy.rust.codegen.client.rustlang.documentShape +import software.amazon.smithy.rust.codegen.client.rustlang.rust +import software.amazon.smithy.rust.codegen.client.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.client.rustlang.writable +import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext +import software.amazon.smithy.rust.codegen.client.util.toPascalCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency + +class ServerOperationGenerator( + coreCodegenContext: CoreCodegenContext, + private val operation: OperationShape, +) { + private val runtimeConfig = coreCodegenContext.runtimeConfig + private val codegenScope = + arrayOf( + "SmithyHttpServer" to + ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(), + ) + private val symbolProvider = coreCodegenContext.symbolProvider + private val model = coreCodegenContext.model + + private val operationName = symbolProvider.toSymbol(operation).name.toPascalCase() + private val operationId = operation.id + + /** Returns `std::convert::Infallible` if the model provides no errors. */ + private fun operationError(): Writable = writable { + if (operation.errors.isEmpty()) { + rust("std::convert::Infallible") + } else { + rust("crate::error::${operationName}Error") + } + } + + fun render(writer: RustWriter) { + writer.documentShape(operation, model) + + writer.rustTemplate( + """ + pub struct $operationName; + + impl #{SmithyHttpServer}::operation::OperationShape for $operationName { + const NAME: &'static str = "${operationId.toString().replace("#", "##")}"; + + type Input = crate::input::${operationName}Input; + type Output = crate::output::${operationName}Output; + type Error = #{Error:W}; + } + """, + "Error" to operationError(), + *codegenScope, + ) + // Adds newline to end of render + writer.rust("") + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index f8eb568ecf..2a28b8fca3 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -7,14 +7,18 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.knowledge.TopDownIndex import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.client.rustlang.Attribute +import software.amazon.smithy.rust.codegen.client.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.client.rustlang.RustModule import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.client.rustlang.Visibility import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.DefaultPublicModules import software.amazon.smithy.rust.codegen.client.smithy.RustCrate import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolSupport import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator /** @@ -63,6 +67,36 @@ open class ServerServiceGenerator( ) { writer -> renderOperationRegistry(writer, operations) } + + // TODO(https://github.com/awslabs/smithy-rs/issues/1707): Remove, this is temporary. + rustCrate.withModule( + RustModule( + "operation_shape", + RustMetadata( + visibility = Visibility.PUBLIC, + additionalAttributes = listOf( + Attribute.DocHidden, + ), + ), + null, + ), + ) { writer -> + for (operation in operations) { + ServerOperationGenerator(coreCodegenContext, operation).render(writer) + } + } + + // TODO(https://github.com/awslabs/smithy-rs/issues/1707): Remove, this is temporary. + rustCrate.withModule( + RustModule("service", RustMetadata(visibility = Visibility.PUBLIC, additionalAttributes = listOf(Attribute.DocHidden)), null), + ) { writer -> + val serverProtocol = ServerProtocol.fromCoreProtocol(protocol) + ServerServiceGeneratorV2( + coreCodegenContext, + serverProtocol, + ).render(writer) + } + renderExtras(operations) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt new file mode 100644 index 0000000000..7e93e79caa --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGeneratorV2.kt @@ -0,0 +1,341 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import software.amazon.smithy.model.knowledge.TopDownIndex +import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.client.rustlang.RustReservedWords +import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.client.rustlang.Writable +import software.amazon.smithy.rust.codegen.client.rustlang.asType +import software.amazon.smithy.rust.codegen.client.rustlang.documentShape +import software.amazon.smithy.rust.codegen.client.rustlang.join +import software.amazon.smithy.rust.codegen.client.rustlang.rust +import software.amazon.smithy.rust.codegen.client.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.client.rustlang.writable +import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext +import software.amazon.smithy.rust.codegen.client.util.toPascalCase +import software.amazon.smithy.rust.codegen.client.util.toSnakeCase +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol + +class ServerServiceGeneratorV2( + coreCodegenContext: CoreCodegenContext, + private val protocol: ServerProtocol, +) { + private val runtimeConfig = coreCodegenContext.runtimeConfig + private val codegenScope = + arrayOf( + "Bytes" to CargoDependency.Bytes.asType(), + "Http" to CargoDependency.Http.asType(), + "HttpBody" to CargoDependency.HttpBody.asType(), + "SmithyHttpServer" to + ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(), + "Tower" to CargoDependency.Tower.asType(), + ) + private val model = coreCodegenContext.model + private val symbolProvider = coreCodegenContext.symbolProvider + + private val service = coreCodegenContext.serviceShape + private val serviceName = service.id.name.toPascalCase() + private val builderName = "${serviceName}Builder" + + /** Calculate all `operationShape`s contained within the `ServiceShape`. */ + private val index = TopDownIndex.of(coreCodegenContext.model) + private val operations = index.getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } + + /** The sequence of builder generics: `Op1`, ..., `OpN`. */ + private val builderOps = (1..operations.size).map { "Op$it" } + + /** The sequence of extension types: `Ext1`, ..., `ExtN`. */ + private val extensionTypes = (1..operations.size).map { "Exts$it" } + + /** The sequence of field names for the builder. */ + private val builderFieldNames = operations.map { RustReservedWords.escapeIfNeeded(symbolProvider.toSymbol(it).name.toSnakeCase()) } + + /** The sequence of operation struct names. */ + private val operationStructNames = operations.map { symbolProvider.toSymbol(it).name.toPascalCase() } + + /** A `Writable` block of "field: Type" for the builder. */ + private val builderFields = builderFieldNames.zip(builderOps).map { (name, type) -> "$name: $type" } + + /** A `Writable` block containing all the `Handler` and `Operation` setters for the builder. */ + private fun builderSetters(): Writable = writable { + for ((index, pair) in builderFieldNames.zip(operationStructNames).withIndex()) { + val (fieldName, structName) = pair + + // The new generics for the operation setter, using `NewOp` where appropriate. + val replacedOpGenerics = builderOps.withIndex().map { (innerIndex, item) -> + if (innerIndex == index) { + "NewOp" + } else { + item + } + } + + // The new generics for the operation setter, using `NewOp` where appropriate. + val replacedExtGenerics = extensionTypes.withIndex().map { (innerIndex, item) -> + if (innerIndex == index) { + "NewExts" + } else { + item + } + } + + // The new generics for the handler setter, using `NewOp` where appropriate. + val replacedOpServiceGenerics = builderOps.withIndex().map { (innerIndex, item) -> + if (innerIndex == index) writable { + rustTemplate( + """ + #{SmithyHttpServer}::operation::Operation<#{SmithyHttpServer}::operation::IntoService> + """, + *codegenScope, + ) + } else { + writable(item) + } + } + + // The assignment of fields, using value where appropriate. + val switchedFields = builderFieldNames.withIndex().map { (innerIndex, item) -> + if (index == innerIndex) { + "$item: value" + } else { + "$item: self.$item" + } + } + + rustTemplate( + """ + /// Sets the [`$structName`](crate::operation_shape::$structName) operation. + /// + /// This should be a closure satisfying the [`Handler`](#{SmithyHttpServer}::operation::Handler) trait. + /// See the [operation module documentation](#{SmithyHttpServer}::operation) for more information. + pub fn $fieldName(self, value: H) -> $builderName<#{HandlerSetterGenerics:W}> + where + H: #{SmithyHttpServer}::operation::Handler + { + use #{SmithyHttpServer}::operation::OperationShapeExt; + self.${fieldName}_operation(crate::operation_shape::$structName::from_handler(value)) + } + + /// Sets the [`$structName`](crate::operation_shape::$structName) operation. + /// + /// This should be an [`Operation`](#{SmithyHttpServer}::operation::Operation) created from + /// [`$structName`](crate::operation_shape::$structName) using either + /// [`OperationShape::from_handler`](#{SmithyHttpServer}::operation::OperationShapeExt::from_handler) or + /// [`OperationShape::from_service`](#{SmithyHttpServer}::operation::OperationShapeExt::from_service). + pub fn ${fieldName}_operation(self, value: NewOp) -> $builderName<${(replacedOpGenerics + replacedExtGenerics).joinToString(", ")}> + { + $builderName { + ${switchedFields.joinToString(", ")}, + _exts: std::marker::PhantomData + } + } + """, + "Protocol" to protocol.markerStruct(), + "HandlerSetterGenerics" to (replacedOpServiceGenerics + (replacedExtGenerics.map { writable(it) })).join(", "), + *codegenScope, + ) + + // Adds newline between setters. + rust("") + } + } + + /** Returns the constraints required for the `build` method. */ + private val buildConstraints = operations.zip(builderOps).zip(extensionTypes).map { (first, exts) -> + val (operation, type) = first + // TODO(https://github.com/awslabs/smithy-rs/issues/1713#issue-1365169734): The `Error = Infallible` is an + // excess requirement to stay at parity with existing builder. + writable { + rustTemplate( + """ + $type: #{SmithyHttpServer}::operation::Upgradable< + #{Marker}, + crate::operation_shape::${symbolProvider.toSymbol(operation).name.toPascalCase()}, + $exts, + B, + >, + $type::Service: Clone + Send + 'static, + <$type::Service as #{Tower}::Service<#{Http}::Request>>::Future: Send + 'static, + + $type::Service: #{Tower}::Service<#{Http}::Request, Error = std::convert::Infallible> + """, + "Marker" to protocol.markerStruct(), + *codegenScope, + ) + } + } + + /** Returns a `Writable` containing the builder struct definition and its implementations. */ + private fun builder(): Writable = writable { + val extensionTypesDefault = extensionTypes.map { "$it = ()" } + val structGenerics = (builderOps + extensionTypesDefault).joinToString(", ") + val builderGenerics = (builderOps + extensionTypes).joinToString(", ") + + // Generate router construction block. + val router = protocol + .routerConstruction( + builderFieldNames + .map { + writable { rustTemplate("self.$it.upgrade()") } + } + .asIterable(), + ) + rustTemplate( + """ + /// The service builder for [`$serviceName`]. + /// + /// Constructed via [`$serviceName::builder`]. + pub struct $builderName<$structGenerics> { + ${builderFields.joinToString(", ")}, + ##[allow(unused_parens)] + _exts: std::marker::PhantomData<(${extensionTypes.joinToString(", ")})> + } + + impl<$builderGenerics> $builderName<$builderGenerics> { + #{Setters:W} + } + + impl<$builderGenerics> $builderName<$builderGenerics> { + /// Constructs a [`$serviceName`] from the arguments provided to the builder. + pub fn build(self) -> $serviceName<#{SmithyHttpServer}::routing::Route> + where + #{BuildConstraints:W} + { + let router = #{Router:W}; + $serviceName { + router: #{SmithyHttpServer}::routing::routers::RoutingService::new(router), + } + } + } + """, + "Setters" to builderSetters(), + "BuildConstraints" to buildConstraints.join(", "), + "Router" to router, + *codegenScope, + ) + } + + /** A `Writable` comma delimited sequence of `MissingOperation`. */ + private val notSetGenerics = (1..operations.size).map { + writable { rustTemplate("#{SmithyHttpServer}::operation::MissingOperation", *codegenScope) } + } + + /** Returns a `Writable` comma delimited sequence of `builder_field: MissingOperation`. */ + private val notSetFields = builderFieldNames.map { + writable { + rustTemplate( + "$it: #{SmithyHttpServer}::operation::MissingOperation", + *codegenScope, + ) + } + } + + /** A `Writable` comma delimited sequence of `DummyOperation`. */ + private val internalFailureGenerics = (1..operations.size).map { writable { rustTemplate("#{SmithyHttpServer}::operation::FailOnMissingOperation", *codegenScope) } } + + /** A `Writable` comma delimited sequence of `builder_field: DummyOperation`. */ + private val internalFailureFields = builderFieldNames.map { + writable { + rustTemplate( + "$it: #{SmithyHttpServer}::operation::FailOnMissingOperation", + *codegenScope, + ) + } + } + + /** Returns a `Writable` containing the service struct definition and its implementations. */ + private fun struct(): Writable = writable { + documentShape(service, model) + + rustTemplate( + """ + ##[derive(Clone)] + pub struct $serviceName { + router: #{SmithyHttpServer}::routing::routers::RoutingService<#{Router}, #{Protocol}>, + } + + impl $serviceName<()> { + /// Constructs a builder for [`$serviceName`]. + pub fn builder() -> $builderName<#{NotSetGenerics:W}> { + $builderName { + #{NotSetFields:W}, + _exts: std::marker::PhantomData + } + } + + /// Constructs an unchecked builder for [`$serviceName`]. + /// + /// This will not enforce that all operations are set, however if an unset operation is used at runtime + /// it will return status code 500 and log an error. + pub fn unchecked_builder() -> $builderName<#{InternalFailureGenerics:W}> { + $builderName { + #{InternalFailureFields:W}, + _exts: std::marker::PhantomData + } + } + } + + impl $serviceName { + /// Converts [`$serviceName`] into a [`MakeService`](tower::make::MakeService). + pub fn into_make_service(self) -> #{SmithyHttpServer}::routing::IntoMakeService { + #{SmithyHttpServer}::routing::IntoMakeService::new(self) + } + + /// Applies a layer uniformly to all routes. + pub fn layer(self, layer: &L) -> $serviceName + where + L: #{Tower}::Layer + { + $serviceName { + router: self.router.map(|s| s.layer(layer)) + } + } + } + + impl #{Tower}::Service<#{Http}::Request> for $serviceName + where + S: #{Tower}::Service<#{Http}::Request, Response = #{Http}::Response> + Clone, + RespB: #{HttpBody}::Body + Send + 'static, + RespB::Error: Into> + { + type Response = #{Http}::Response<#{SmithyHttpServer}::body::BoxBody>; + type Error = S::Error; + type Future = #{SmithyHttpServer}::routing::routers::RoutingFuture; + + fn poll_ready(&mut self, cx: &mut std::task::Context) -> std::task::Poll> { + self.router.poll_ready(cx) + } + + fn call(&mut self, request: #{Http}::Request) -> Self::Future { + self.router.call(request) + } + } + """, + "InternalFailureGenerics" to internalFailureGenerics.join(", "), + "InternalFailureFields" to internalFailureFields.join(", "), + "NotSetGenerics" to notSetGenerics.join(", "), + "NotSetFields" to notSetFields.join(", "), + "Router" to protocol.routerType(), + "Protocol" to protocol.markerStruct(), + *codegenScope, + ) + } + + fun render(writer: RustWriter) { + writer.rustTemplate( + """ + #{Builder:W} + + #{Struct:W} + """, + "Builder" to builder(), + "Struct" to struct(), + ) + } +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt new file mode 100644 index 0000000000..bf879d8f34 --- /dev/null +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -0,0 +1,191 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol + +import software.amazon.smithy.model.knowledge.TopDownIndex +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.rust.codegen.client.rustlang.Writable +import software.amazon.smithy.rust.codegen.client.rustlang.asType +import software.amazon.smithy.rust.codegen.client.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.client.rustlang.writable +import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext +import software.amazon.smithy.rust.codegen.client.smithy.RuntimeConfig +import software.amazon.smithy.rust.codegen.client.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.client.smithy.protocols.AwsJson +import software.amazon.smithy.rust.codegen.client.smithy.protocols.AwsJsonVersion +import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol +import software.amazon.smithy.rust.codegen.client.smithy.protocols.RestJson +import software.amazon.smithy.rust.codegen.client.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType + +private fun allOperations(coreCodegenContext: CoreCodegenContext): List { + val index = TopDownIndex.of(coreCodegenContext.model) + return index.getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } +} + +interface ServerProtocol : Protocol { + /** Returns the Rust marker struct enjoying `OperationShape`. */ + fun markerStruct(): RuntimeType + + /** Returns the Rust router type. */ + fun routerType(): RuntimeType + + /** + * Returns the construction of the `routerType` given a `ServiceShape`, a collection of operation values + * (`self.operation_name`, ...), and the `Model`. + */ + fun routerConstruction(operationValues: Iterable): Writable + + companion object { + /** Upgrades the core protocol to a `ServerProtocol`. */ + fun fromCoreProtocol(protocol: Protocol): ServerProtocol = when (protocol) { + is AwsJson -> ServerAwsJsonProtocol.fromCoreProtocol(protocol) + is RestJson -> ServerRestJsonProtocol.fromCoreProtocol(protocol) + is RestXml -> ServerRestXmlProtocol.fromCoreProtocol(protocol) + else -> throw IllegalStateException("unsupported protocol") + } + } +} + +class ServerAwsJsonProtocol( + coreCodegenContext: CoreCodegenContext, + awsJsonVersion: AwsJsonVersion, +) : AwsJson(coreCodegenContext, awsJsonVersion), ServerProtocol { + private val runtimeConfig = coreCodegenContext.runtimeConfig + private val codegenScope = arrayOf( + "SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(), + ) + private val symbolProvider = coreCodegenContext.symbolProvider + private val service = coreCodegenContext.serviceShape + + companion object { + fun fromCoreProtocol(awsJson: AwsJson): ServerAwsJsonProtocol = ServerAwsJsonProtocol(awsJson.coreCodegenContext, awsJson.version) + } + + override fun markerStruct(): RuntimeType { + val name = when (version) { + is AwsJsonVersion.Json10 -> { + "AwsJson10" + } + is AwsJsonVersion.Json11 -> { + "AwsJson11" + } + } + return ServerRuntimeType.Protocol(name, runtimeConfig) + } + + override fun routerType() = RuntimeType("AwsJsonRouter", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::routing::routers::aws_json") + + override fun routerConstruction(operationValues: Iterable): Writable = writable { + val allOperationShapes = allOperations(coreCodegenContext) + + // TODO(https://github.com/awslabs/smithy-rs/issues/1724#issue-1367509999): This causes a panic: "symbol + // visitor should not be invoked in service shapes" + // val serviceName = symbolProvider.toSymbol(service).name + val serviceName = service.id.name + val pairs = writable { + for ((operation, operationValue) in allOperationShapes.zip(operationValues)) { + val operationName = symbolProvider.toSymbol(operation).name + rustTemplate( + """ + ( + String::from("$serviceName.$operationName"), + #{SmithyHttpServer}::routing::Route::new(#{OperationValue:W}) + ), + """, + "OperationValue" to operationValue, + *codegenScope, + ) + } + } + rustTemplate( + """ + #{Router}::from_iter([#{Pairs:W}]) + """, + "Router" to routerType(), + "Pairs" to pairs, + ) + } +} + +private fun restRouterType(runtimeConfig: RuntimeConfig) = RuntimeType("RestRouter", ServerCargoDependency.SmithyHttpServer(runtimeConfig), "${runtimeConfig.crateSrcPrefix}_http_server::routing::routers::rest") + +private fun restRouterConstruction( + protocol: ServerProtocol, + operationValues: Iterable, + coreCodegenContext: CoreCodegenContext, +): Writable = writable { + val operations = allOperations(coreCodegenContext) + + // TODO(https://github.com/awslabs/smithy-rs/issues/1724#issue-1367509999): This causes a panic: "symbol visitor + // should not be invoked in service shapes" + // val serviceName = symbolProvider.toSymbol(service).name + val serviceName = coreCodegenContext.serviceShape.id.name + val pairs = writable { + for ((operationShape, operationValue) in operations.zip(operationValues)) { + val operationName = coreCodegenContext.symbolProvider.toSymbol(operationShape).name + val key = protocol.serverRouterRequestSpec( + operationShape, + operationName, + serviceName, + ServerCargoDependency.SmithyHttpServer(coreCodegenContext.runtimeConfig).asType().member("routing::request_spec"), + ) + rustTemplate( + """ + ( + #{Key:W}, + #{SmithyHttpServer}::routing::Route::new(#{OperationValue:W}) + ), + """, + "Key" to key, + "OperationValue" to operationValue, + "SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(coreCodegenContext.runtimeConfig).asType(), + ) + } + } + rustTemplate( + """ + #{Router}::from_iter([#{Pairs:W}]) + """, + "Router" to protocol.routerType(), + "Pairs" to pairs, + ) +} + +class ServerRestJsonProtocol( + coreCodegenContext: CoreCodegenContext, +) : RestJson(coreCodegenContext), ServerProtocol { + val runtimeConfig = coreCodegenContext.runtimeConfig + + companion object { + fun fromCoreProtocol(restJson: RestJson): ServerRestJsonProtocol = ServerRestJsonProtocol(restJson.coreCodegenContext) + } + + override fun markerStruct() = ServerRuntimeType.Protocol("AwsRestJson1", runtimeConfig) + + override fun routerType() = restRouterType(runtimeConfig) + + override fun routerConstruction(operationValues: Iterable): Writable = restRouterConstruction(this, operationValues, coreCodegenContext) +} + +class ServerRestXmlProtocol( + coreCodegenContext: CoreCodegenContext, +) : RestXml(coreCodegenContext), ServerProtocol { + val runtimeConfig = coreCodegenContext.runtimeConfig + + companion object { + fun fromCoreProtocol(restXml: RestXml): ServerRestXmlProtocol { + return ServerRestXmlProtocol(restXml.coreCodegenContext) + } + } + + override fun markerStruct() = ServerRuntimeType.Protocol("AwsRestXml", runtimeConfig) + + override fun routerType() = restRouterType(runtimeConfig) + + override fun routerConstruction(operationValues: Iterable): Writable = restRouterConstruction(this, operationValues, coreCodegenContext) +} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index ebf52a91e0..7bb0e3d89c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -52,6 +52,7 @@ import software.amazon.smithy.rust.codegen.client.util.inputShape import software.amazon.smithy.rust.codegen.client.util.isStreaming import software.amazon.smithy.rust.codegen.client.util.orNull import software.amazon.smithy.rust.codegen.client.util.outputShape +import software.amazon.smithy.rust.codegen.client.util.toPascalCase import software.amazon.smithy.rust.codegen.client.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType @@ -75,6 +76,7 @@ class ServerProtocolTestGenerator( private val symbolProvider = coreCodegenContext.symbolProvider private val operationIndex = OperationIndex.of(coreCodegenContext.model) + private val serviceName = coreCodegenContext.serviceShape.id.name.toPascalCase() private val operations = TopDownIndex.of(coreCodegenContext.model).getContainedOperations(coreCodegenContext.serviceShape).sortedBy { it.id } private val operationInputOutputTypes = operations.associateWith { @@ -104,6 +106,7 @@ class ServerProtocolTestGenerator( "SmithyHttp" to CargoDependency.SmithyHttp(coreCodegenContext.runtimeConfig).asType(), "Http" to CargoDependency.Http.asType(), "Hyper" to CargoDependency.Hyper.asType(), + "Tokio" to ServerCargoDependency.TokioDev.asType(), "Tower" to CargoDependency.Tower.asType(), "SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(coreCodegenContext.runtimeConfig).asType(), "AssertEq" to CargoDependency.PrettyAssertions.asType().member("assert_eq!"), @@ -357,6 +360,8 @@ class ServerProtocolTestGenerator( rust("/* test case disabled for this protocol (not yet supported) */") return } + + // Test against original `OperationRegistryBuilder`. with(httpRequestTestCase) { renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) } @@ -364,6 +369,14 @@ class ServerProtocolTestGenerator( checkRequest(operationShape, operationSymbol, httpRequestTestCase, this) } + // Test against new service builder. + with(httpRequestTestCase) { + renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull()) + } + if (protocolSupport.requestBodyDeserialization) { + checkRequest2(operationShape, operationSymbol, httpRequestTestCase, this) + } + // Explicitly warn if the test case defined parameters that we aren't doing anything with with(httpRequestTestCase) { if (authScheme.isPresent) { @@ -495,11 +508,34 @@ class ServerProtocolTestGenerator( } } - private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { + /** Returns the body of the request test. */ + private fun checkRequestHandler(operationShape: OperationShape, httpRequestTestCase: HttpRequestTestCase) = writable { val inputShape = operationShape.inputShape(coreCodegenContext.model) val outputShape = operationShape.outputShape(coreCodegenContext.model) + // Construct expected request. + withBlock("let expected = ", ";") { + instantiator.render(this, inputShape, httpRequestTestCase.params) + } + + checkRequestParams(inputShape, this) + + // Construct a dummy response. + withBlock("let response = ", ";") { + instantiator.render(this, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) + } + + if (operationShape.errors.isEmpty()) { + write("response") + } else { + write("Ok(response)") + } + } + + /** Checks the request using the `OperationRegistryBuilder`. */ + private fun checkRequest(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { val (inputT, outputT) = operationInputOutputTypes[operationShape]!! + rustWriter.withBlock( """ super::$PROTOCOL_TEST_HELPER_MODULE_NAME::build_router_and_make_request( @@ -509,29 +545,40 @@ class ServerProtocolTestGenerator( builder.${operationShape.toName()}((|input| Box::pin(async move { """, - "})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await", + "})) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME::Fun<$inputT, $outputT>)}).await;", ) { - // Construct expected request. - rustWriter.withBlock("let expected = ", ";") { - instantiator.render(this, inputShape, httpRequestTestCase.params) - } - - checkRequestParams(inputShape, rustWriter) - - // Construct a dummy response. - rustWriter.withBlock("let response = ", ";") { - instantiator.render(this, outputShape, Node.objectNode(), Instantiator.defaultContext().copy(defaultsForRequiredFields = true)) - } - - if (operationShape.errors.isEmpty()) { - rustWriter.write("response") - } else { - rustWriter.write("Ok(response)") - } + checkRequestHandler(operationShape, httpRequestTestCase)() } } + /** Checks the request using the new service builder. */ + private fun checkRequest2(operationShape: OperationShape, operationSymbol: Symbol, httpRequestTestCase: HttpRequestTestCase, rustWriter: RustWriter) { + val (inputT, _) = operationInputOutputTypes[operationShape]!! + val operationName = RustReservedWords.escapeIfNeeded(operationSymbol.name.toSnakeCase()) + rustWriter.rustTemplate( + """ + let (sender, mut receiver) = #{Tokio}::sync::mpsc::channel(1); + let service = crate::service::$serviceName::unchecked_builder() + .$operationName(move |input: $inputT| { + let sender = sender.clone(); + async move { + let result = { #{Body:W} }; + sender.send(()).await.expect("receiver dropped early"); + result + } + }) + .build::<#{Hyper}::body::Body>(); + let http_response = #{Tower}::ServiceExt::oneshot(service, http_request) + .await + .expect("unable to make an HTTP request"); + assert!(receiver.recv().await.is_some()) + """, + "Body" to checkRequestHandler(operationShape, httpRequestTestCase), + *codegenScope, + ) + } + private fun checkRequestParams(inputShape: StructureShape, rustWriter: RustWriter) { if (inputShape.hasStreamingMember(model)) { // A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt index 151e826685..a272ae7828 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt @@ -95,7 +95,7 @@ class ServerAwsJsonSerializerGenerator( ) : StructuredDataSerializerGenerator by jsonSerializerGenerator class ServerAwsJson( - private val coreCodegenContext: CoreCodegenContext, + coreCodegenContext: CoreCodegenContext, private val awsJsonVersion: AwsJsonVersion, ) : AwsJson(coreCodegenContext, awsJsonVersion) { override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 4a885b5d0f..c232e6d647 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -74,6 +74,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerRequestBindingGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.http.ServerResponseBindingGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import java.util.logging.Logger /* @@ -119,6 +120,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( private val operationDeserModule = RustModule.private("operation_deser") private val operationSerModule = RustModule.private("operation_ser") private val typeConversionGenerator = TypeConversionGenerator(model, symbolProvider, runtimeConfig) + private val serverProtocol = ServerProtocol.fromCoreProtocol(protocol) private val codegenScope = arrayOf( "AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(), @@ -207,9 +209,31 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } } + + impl #{SmithyHttpServer}::request::FromRequest<#{Marker}, B> for #{I} + where + B: #{SmithyHttpServer}::body::HttpBody + Send, + B: 'static, + ${streamingBodyTraitBounds(operationShape)} + B::Data: Send, + #{RequestRejection} : From<::Error> + { + type Rejection = #{RuntimeError}; + type Future = std::pin::Pin> + Send>>; + + fn from_request(request: #{http}::Request) -> Self::Future { + let fut = async move { + let mut request_parts = #{SmithyHttpServer}::request::RequestParts::new(request); + $inputName::from_request(&mut request_parts).await.map(|x| x.0) + }; + Box::pin(fut) + } + } + """.trimIndent(), *codegenScope, "I" to inputSymbol, + "Marker" to serverProtocol.markerStruct(), "parse_request" to serverParseRequest(operationShape), "verify_response_content_type" to verifyResponseContentType, ) @@ -265,10 +289,23 @@ private class ServerHttpBoundProtocolTraitImplGenerator( $intoResponseImpl } } + + impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{O} { + fn into_response(self) -> #{SmithyHttpServer}::response::Response { + $outputName::Output(self).into_response() + } + } + + impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{E} { + fn into_response(self) -> #{SmithyHttpServer}::response::Response { + $outputName::Error(self).into_response() + } + } """.trimIndent(), *codegenScope, "O" to outputSymbol, "E" to errorSymbol, + "Marker" to serverProtocol.markerStruct(), "serialize_response" to serverSerializeResponse(operationShape), "serialize_error" to serverSerializeError(operationShape), ) @@ -297,9 +334,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator( $intoResponseImpl } } + + impl #{SmithyHttpServer}::response::IntoResponse<#{Marker}> for #{O} { + fn into_response(self) -> #{SmithyHttpServer}::response::Response { + $outputName(self).into_response() + } + } """.trimIndent(), *codegenScope, "O" to outputSymbol, + "Marker" to serverProtocol.markerStruct(), "serialize_response" to serverSerializeResponse(operationShape), ) } diff --git a/design/src/rfcs/rfc0020_service_builder.md b/design/src/rfcs/rfc0020_service_builder.md index 9dad6cfe97..05b6e43bc6 100644 --- a/design/src/rfcs/rfc0020_service_builder.md +++ b/design/src/rfcs/rfc0020_service_builder.md @@ -860,5 +860,6 @@ A toy implementation of the combined proposal is presented in [this PR](https:// - - [x] Add middleware primitives and error types to `rust-runtime/aws-smithy-http-server`. - -- [ ] Add code generation which outputs new service builder. +- [x] Add code generation which outputs new service builder. + - - [ ] Deprecate `OperationRegistryBuilder`, `OperationRegistry` and `Router`. diff --git a/rust-runtime/aws-smithy-http-server/src/operation/mod.rs b/rust-runtime/aws-smithy-http-server/src/operation/mod.rs index 70199714c2..5fec3a2f15 100644 --- a/rust-runtime/aws-smithy-http-server/src/operation/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/operation/mod.rs @@ -168,11 +168,12 @@ //! - The intention of `PollError` is to signal that the underlying service is no longer able to take requests, so //! should be discarded. See [`Service::poll_ready`](tower::Service::poll_ready). //! -//! The [`UpgradeLayer`] and it's [`Layer::Service`] [`Upgrade`] are both parameterized by a protocol. This allows -//! for upgrading to `Service` to be protocol dependent. +//! The [`UpgradeLayer`] and it's [`Layer::Service`](tower::Layer::Service) [`Upgrade`] are both parameterized by a +//! protocol. This allows for upgrading to `Service` to be +//! protocol dependent. //! -//! The [`Operation::upgrade`] will apply [`UpgradeLayer`] to `S` then apply the [`Layer`] `L`. The service builder -//! provided to the user will perform this composition on `build`. +//! The [`Operation::upgrade`] will apply [`UpgradeLayer`] to `S` then apply the [`Layer`](tower::Layer) `L`. The +//! service builder provided to the user will perform this composition on `build`. //! //! [Smithy operation]: https://awslabs.github.io/smithy/2.0/spec/service-types.html#operation @@ -181,17 +182,14 @@ mod operation_service; mod shape; mod upgrade; -use tower::{ - layer::util::{Identity, Stack}, - Layer, -}; +use tower::layer::util::{Identity, Stack}; pub use handler::*; pub use operation_service::*; pub use shape::*; pub use upgrade::*; -/// A Smithy operation, represented by a [`Service`](tower::Service) `S` and a [`Layer`] `L`. +/// A Smithy operation, represented by a [`Service`](tower::Service) `S` and a [`Layer`](tower::Layer) `L`. /// /// The `L` is held and applied lazily during [`Operation::upgrade`]. pub struct Operation { @@ -199,8 +197,6 @@ pub struct Operation { layer: L, } -type StackedUpgradeService = , L> as Layer>::Service; - impl Operation { /// Applies a [`Layer`] to the operation _after_ it has been upgraded via [`Operation::upgrade`]. pub fn layer(self, layer: NewL) -> Operation> { @@ -209,20 +205,6 @@ impl Operation { layer: Stack::new(self.layer, layer), } } - - /// Takes the [`Operation`], containing the inner [`Service`](tower::Service) `S`, the HTTP [`Layer`] `L` and - /// composes them together using [`UpgradeLayer`] for a specific protocol and [`OperationShape`]. - /// - /// The composition is made explicit in the method constraints and return type. - pub fn upgrade(self) -> StackedUpgradeService - where - UpgradeLayer: Layer, - L: Layer< as Layer>::Service>, - { - let Self { inner, layer } = self; - let layer = Stack::new(UpgradeLayer::new(), layer); - layer.layer(inner) - } } impl Operation> { @@ -253,9 +235,6 @@ impl Operation> { } } -/// A marker struct indicating an [`Operation`] has not been set in a builder. -pub struct OperationNotSet; - /// The operation [`Service`](tower::Service) has two classes of failure modes - those specified by the Smithy model /// and those associated with [`Service::poll_ready`](tower::Service::poll_ready). pub enum OperationError { diff --git a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs index 17dd7b29ae..3d4a617ccf 100644 --- a/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs +++ b/rust-runtime/aws-smithy-http-server/src/operation/upgrade.rs @@ -4,7 +4,8 @@ */ use std::{ - future::Future, + convert::Infallible, + future::{Future, Ready}, marker::PhantomData, pin::Pin, task::{Context, Poll}, @@ -12,14 +13,17 @@ use std::{ use futures_util::ready; use pin_project_lite::pin_project; -use tower::{Layer, Service}; +use tower::{layer::util::Stack, Layer, Service}; +use tracing::error; use crate::{ + body::BoxBody, request::{FromParts, FromRequest}, response::IntoResponse, + runtime_error::InternalFailureException, }; -use super::{OperationError, OperationShape}; +use super::{Operation, OperationError, OperationShape}; /// A [`Layer`] responsible for taking an operation [`Service`], accepting and returning Smithy /// types and converting it into a [`Service`] taking and returning [`http`] types. @@ -203,11 +207,108 @@ where } fn call(&mut self, req: http::Request) -> Self::Future { + let clone = self.inner.clone(); + let service = std::mem::replace(&mut self.inner, clone); UpgradeFuture { - service: self.inner.clone(), + service, inner: Inner::FromRequest { inner: <(Op::Input, Exts) as FromRequest>::from_request(req), }, } } } + +/// Provides an interface to convert a representation of an operation to a HTTP [`Service`](tower::Service) with +/// canonical associated types. +pub trait Upgradable { + type Service: Service, Response = http::Response>; + + /// Performs an upgrade from a representation of an operation to a HTTP [`Service`](tower::Service). + fn upgrade(self) -> Self::Service; +} + +impl Upgradable for Operation +where + // `Op` is used to specify the operation shape + Op: OperationShape, + + // Smithy input must convert from a HTTP request + Op::Input: FromRequest, + // Smithy output must convert into a HTTP response + Op::Output: IntoResponse

, + // Smithy error must convert into a HTTP response + Op::Error: IntoResponse

, + + // Must be able to convert extensions + Exts: FromParts

, + + // The signature of the inner service is correct + S: Service<(Op::Input, Exts), Response = Op::Output, Error = OperationError> + Clone, + + // Layer applies correctly to `Upgrade` + L: Layer>, + + // The signature of the output is correct + L::Service: Service, Response = http::Response>, +{ + type Service = L::Service; + + /// Takes the [`Operation`](Operation), applies [`UpgradeLayer`] to + /// the modified `S`, then finally applies the modified `L`. + /// + /// The composition is made explicit in the method constraints and return type. + fn upgrade(self) -> Self::Service { + let layer = Stack::new(UpgradeLayer::new(), self.layer); + layer.layer(self.inner) + } +} + +/// A marker struct indicating an [`Operation`] has not been set in a builder. +/// +/// This does _not_ implement [`Upgradable`] purposely. +pub struct MissingOperation; + +/// A marker struct indicating an [`Operation`] has not been set in a builder. +/// +/// This _does_ implement [`Upgradable`] but produces a [`Service`] which always returns an internal failure message. +pub struct FailOnMissingOperation; + +impl Upgradable for FailOnMissingOperation +where + InternalFailureException: IntoResponse

, +{ + type Service = MissingFailure

; + + fn upgrade(self) -> Self::Service { + MissingFailure { _protocol: PhantomData } + } +} + +/// A [`Service`] which always returns an internal failure message and logs an error. +pub struct MissingFailure

{ + _protocol: PhantomData

, +} + +impl

Clone for MissingFailure

{ + fn clone(&self) -> Self { + MissingFailure { _protocol: PhantomData } + } +} + +impl Service for MissingFailure

+where + InternalFailureException: IntoResponse

, +{ + type Response = http::Response; + type Error = Infallible; + type Future = Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _request: R) -> Self::Future { + error!("the operation has not been set"); + std::future::ready(Ok(InternalFailureException.into_response())) + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/request.rs b/rust-runtime/aws-smithy-http-server/src/request.rs index 2775a68cf5..36b880e228 100644 --- a/rust-runtime/aws-smithy-http-server/src/request.rs +++ b/rust-runtime/aws-smithy-http-server/src/request.rs @@ -32,7 +32,10 @@ * DEALINGS IN THE SOFTWARE. */ -use std::future::{ready, Future, Ready}; +use std::{ + convert::Infallible, + future::{ready, Future, Ready}, +}; use futures_util::{ future::{try_join, MapErr, MapOk, TryJoin}, @@ -117,6 +120,25 @@ pub trait FromParts: Sized { fn from_parts(parts: &mut Parts) -> Result; } +impl

FromParts

for () { + type Rejection = Infallible; + + fn from_parts(_parts: &mut Parts) -> Result { + Ok(()) + } +} + +impl FromParts

for (T,) +where + T: FromParts

, +{ + type Rejection = T::Rejection; + + fn from_parts(parts: &mut Parts) -> Result { + Ok((T::from_parts(parts)?,)) + } +} + impl FromParts

for (T1, T2) where T1: FromParts

, diff --git a/rust-runtime/aws-smithy-http-server/src/response.rs b/rust-runtime/aws-smithy-http-server/src/response.rs index 4c7bab848d..f35c10ed3b 100644 --- a/rust-runtime/aws-smithy-http-server/src/response.rs +++ b/rust-runtime/aws-smithy-http-server/src/response.rs @@ -42,3 +42,9 @@ pub trait IntoResponse { /// Performs a conversion into a [`http::Response`]. fn into_response(self) -> http::Response; } + +impl

IntoResponse

for std::convert::Infallible { + fn into_response(self) -> http::Response { + match self {} + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs index 0c65e4aff9..3e8a67c9c1 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/mod.rs @@ -7,6 +7,11 @@ //! //! [Smithy specification]: https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html +use std::{ + convert::Infallible, + task::{Context, Poll}, +}; + use self::request_spec::RequestSpec; use self::routers::{aws_json::AwsJsonRouter, rest::RestRouter, RoutingService}; use crate::body::{boxed, Body, BoxBody, HttpBody}; @@ -14,10 +19,6 @@ use crate::error::BoxError; use crate::protocols::{AwsJson10, AwsJson11, AwsRestJson1, AwsRestXml}; use http::{Request, Response}; -use std::{ - convert::Infallible, - task::{Context, Poll}, -}; use tower::layer::Layer; use tower::{Service, ServiceBuilder}; use tower_http::map_response_body::MapResponseBodyLayer; @@ -30,7 +31,8 @@ mod lambda_handler; pub mod request_spec; mod route; -mod routers; +#[doc(hidden)] +pub mod routers; mod tiny_map; pub use self::lambda_handler::LambdaHandler; diff --git a/rust-runtime/aws-smithy-http-server/src/routing/route.rs b/rust-runtime/aws-smithy-http-server/src/routing/route.rs index 29862c1a6c..8964c9629f 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/route.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/route.rs @@ -52,7 +52,7 @@ pub struct Route { } impl Route { - pub(super) fn new(svc: T) -> Self + pub fn new(svc: T) -> Self where T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, T::Future: Send + 'static, diff --git a/rust-runtime/aws-smithy-http-server/src/routing/routers/mod.rs b/rust-runtime/aws-smithy-http-server/src/routing/routers/mod.rs index a6d0d2fc07..e70fe1abdf 100644 --- a/rust-runtime/aws-smithy-http-server/src/routing/routers/mod.rs +++ b/rust-runtime/aws-smithy-http-server/src/routing/routers/mod.rs @@ -12,12 +12,21 @@ use std::{ task::{Context, Poll}, }; -use futures_util::future::Either; +use bytes::Bytes; +use futures_util::{ + future::{Either, MapOk}, + TryFutureExt, +}; use http::Response; +use http_body::Body as HttpBody; use tower::{util::Oneshot, Service, ServiceExt}; use tracing::debug; -use crate::{body::BoxBody, response::IntoResponse}; +use crate::{ + body::{boxed, BoxBody}, + error::BoxError, + response::IntoResponse, +}; pub mod aws_json; pub mod rest; @@ -94,8 +103,8 @@ impl RoutingService { } type EitherOneshotReady = Either< - Oneshot>, - Ready>>::Response, >>::Error>>, + MapOk>, fn(>>::Response) -> http::Response>, + Ready, >>::Error>>, >; pin_project_lite::pin_project! { @@ -110,14 +119,19 @@ where S: Service>, { /// Creates a [`RoutingFuture`] from [`ServiceExt::oneshot`]. - pub(super) fn from_oneshot(future: Oneshot>) -> Self { + pub(super) fn from_oneshot(future: Oneshot>) -> Self + where + S: Service, Response = http::Response>, + RespB: HttpBody + Send + 'static, + RespB::Error: Into, + { Self { - inner: Either::Left(future), + inner: Either::Left(future.map_ok(|x| x.map(boxed))), } } /// Creates a [`RoutingFuture`] from [`Service::Response`]. - pub(super) fn from_response(response: S::Response) -> Self { + pub(super) fn from_response(response: http::Response) -> Self { Self { inner: Either::Right(ready(Ok(response))), } @@ -128,18 +142,20 @@ impl Future for RoutingFuture where S: Service>, { - type Output = ::Output; + type Output = Result, S::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().inner.poll(cx) } } -impl Service> for RoutingService +impl Service> for RoutingService where R: Router, - R::Service: Service, Response = http::Response> + Clone, + R::Service: Service, Response = http::Response> + Clone, R::Error: IntoResponse

+ Error, + RespB: HttpBody + Send + 'static, + RespB::Error: Into, { type Response = Response; type Error = >>::Error; diff --git a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs index 78e94bc9fb..3151a25b36 100644 --- a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs @@ -21,7 +21,10 @@ //! and converts into the corresponding `RuntimeError`, and then it uses the its //! [`RuntimeError::into_response`] method to render and send a response. -use crate::{protocols::Protocol, response::Response}; +use crate::{ + protocols::{AwsJson10, AwsJson11, AwsRestJson1, AwsRestXml, Protocol}, + response::{IntoResponse, Response}, +}; #[derive(Debug)] pub enum RuntimeErrorKind { @@ -48,13 +51,52 @@ impl RuntimeErrorKind { } } +pub struct InternalFailureException; + +impl IntoResponse for InternalFailureException { + fn into_response(self) -> http::Response { + RuntimeError::internal_failure_from_protocol(Protocol::AwsJson10).into_response() + } +} + +impl IntoResponse for InternalFailureException { + fn into_response(self) -> http::Response { + RuntimeError::internal_failure_from_protocol(Protocol::AwsJson11).into_response() + } +} + +impl IntoResponse for InternalFailureException { + fn into_response(self) -> http::Response { + RuntimeError::internal_failure_from_protocol(Protocol::RestJson1).into_response() + } +} + +impl IntoResponse for InternalFailureException { + fn into_response(self) -> http::Response { + RuntimeError::internal_failure_from_protocol(Protocol::RestXml).into_response() + } +} + #[derive(Debug)] pub struct RuntimeError { pub protocol: Protocol, pub kind: RuntimeErrorKind, } +impl

IntoResponse

for RuntimeError { + fn into_response(self) -> http::Response { + self.into_response() + } +} + impl RuntimeError { + pub fn internal_failure_from_protocol(protocol: Protocol) -> Self { + RuntimeError { + protocol, + kind: RuntimeErrorKind::InternalFailure(crate::Error::new(String::new())), + } + } + pub fn into_response(self) -> Response { let status_code = match self.kind { RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST,