diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index f97ce21668..72a3521b06 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -339,3 +339,59 @@ Server SDKs now correctly reject operation inputs that don't set values for `req references = ["smithy-rs#1714", "smithy-rs#1342", "smithy-rs#1860"] meta = { "breaking" = true, "tada" = false, "bug" = true, "target" = "server" } author = "david-perez" + +[[smithy-rs]] +message = """ +Generate enums that guide the users to write match expressions in a forward-compatible way. +Before this change, users could write a match expression against an enum in a non-forward-compatible way: +```rust +match some_enum { + SomeEnum::Variant1 => { /* ... */ }, + SomeEnum::Variant2 => { /* ... */ }, + Unknown(value) if value == "NewVariant" => { /* ... */ }, + _ => { /* ... */ }, +} +``` +This code can handle a case for "NewVariant" with a version of SDK where the enum does not yet include `SomeEnum::NewVariant`, but breaks with another version of SDK where the enum defines `SomeEnum::NewVariant` because the execution will hit a different match arm, i.e. the last one. +After this change, users are guided to write the above match expression as follows: +```rust +match some_enum { + SomeEnum::Variant1 => { /* ... */ }, + SomeEnum::Variant2 => { /* ... */ }, + other @ _ if other.as_str() == "NewVariant" => { /* ... */ }, + _ => { /* ... */ }, +} +``` +This is forward-compatible because the execution will hit the second last match arm regardless of whether the enum defines `SomeEnum::NewVariant` or not. +""" +references = ["smithy-rs#1945"] +meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client"} +author = "ysaito1001" + +[[aws-sdk-rust]] +message = """ +Generate enums that guide the users to write match expressions in a forward-compatible way. +Before this change, users could write a match expression against an enum in a non-forward-compatible way: +```rust +match some_enum { + SomeEnum::Variant1 => { /* ... */ }, + SomeEnum::Variant2 => { /* ... */ }, + Unknown(value) if value == "NewVariant" => { /* ... */ }, + _ => { /* ... */ }, +} +``` +This code can handle a case for "NewVariant" with a version of SDK where the enum does not yet include `SomeEnum::NewVariant`, but breaks with another version of SDK where the enum defines `SomeEnum::NewVariant` because the execution will hit a different match arm, i.e. the last one. +After this change, users are guided to write the above match expression as follows: +```rust +match some_enum { + SomeEnum::Variant1 => { /* ... */ }, + SomeEnum::Variant2 => { /* ... */ }, + other @ _ if other.as_str() == "NewVariant" => { /* ... */ }, + _ => { /* ... */ }, +} +``` +This is forward-compatible because the execution will hit the second last match arm regardless of whether the enum defines `SomeEnum::NewVariant` or not. +""" +references = ["smithy-rs#1945"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "ysaito1001" diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitor.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitor.kt index 95657353d9..6d335929a7 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitor.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/CodegenVisitor.kt @@ -88,6 +88,7 @@ class CodegenVisitor( RustModule.Input, RustModule.Output, RustModule.Config, + RustModule.Types, RustModule.operation(Visibility.PUBLIC), ).associateBy { it.name } rustCrate = RustCrate( diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SmithyTypesPubUseGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SmithyTypesPubUseGenerator.kt index d9b3a2b45a..1bd51e2222 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SmithyTypesPubUseGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customizations/SmithyTypesPubUseGenerator.kt @@ -8,15 +8,12 @@ package software.amazon.smithy.rust.codegen.client.smithy.customizations import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.asType -import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization -import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection +import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.util.hasEventStreamMember import software.amazon.smithy.rust.codegen.core.util.hasStreamingMember @@ -71,22 +68,12 @@ internal fun pubUseTypes(runtimeConfig: RuntimeConfig, model: Model): List pubUseType.shouldExport(model) }.map { it.type } } -class SmithyTypesPubUseGenerator(private val runtimeConfig: RuntimeConfig) : LibRsCustomization() { - override fun section(section: LibRsSection) = - writable { - when (section) { - is LibRsSection.Body -> { - val types = pubUseTypes(runtimeConfig, section.model) - if (types.isNotEmpty()) { - docs("Re-exported types from supporting crates.") - rustBlock("pub mod types") { - types.forEach { type -> rust("pub use #T;", type) } - } - } - } - - else -> { - } - } +/** Adds re-export statements in a separate file for the types module */ +fun pubUseSmithyTypes(runtimeConfig: RuntimeConfig, model: Model, rustCrate: RustCrate) { + rustCrate.withModule(RustModule.Types) { + val types = pubUseTypes(runtimeConfig, model) + if (types.isNotEmpty()) { + types.forEach { type -> rust("pub use #T;", type) } } + } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt index 3ae662bfcd..0567f3a13a 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/customize/RequiredCustomizations.kt @@ -15,10 +15,11 @@ import software.amazon.smithy.rust.codegen.client.smithy.customizations.HttpVers import software.amazon.smithy.rust.codegen.client.smithy.customizations.IdempotencyTokenGenerator import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.customizations.ResiliencyReExportCustomization -import software.amazon.smithy.rust.codegen.client.smithy.customizations.SmithyTypesPubUseGenerator +import software.amazon.smithy.rust.codegen.client.smithy.customizations.pubUseSmithyTypes import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator import software.amazon.smithy.rust.codegen.core.rustlang.Feature +import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.customize.OperationCustomization @@ -55,7 +56,6 @@ class RequiredCustomizations : RustCodegenDecorator, ): List = baseCustomizations + CrateVersionGenerator() + - SmithyTypesPubUseGenerator(codegenContext.runtimeConfig) + AllowLintsGenerator() override fun extras(codegenContext: ClientCodegenContext, rustCrate: RustCrate) { @@ -64,6 +64,8 @@ class RequiredCustomizations : RustCodegenDecorator): Boolean = diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt index 4860b627d2..aa7a0fb8d3 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/rustlang/RustModule.kt @@ -29,6 +29,7 @@ data class RustModule(val name: String, val rustMetadata: RustMetadata, val docu val Model = public("model", documentation = "Data structures used by operation inputs/outputs.") val Input = public("input", documentation = "Input structures for operations.") val Output = public("output", documentation = "Output structures for operations.") + val Types = public("types", documentation = "Data primitives referenced by other data types.") /** * Helper method to generate the `operation` Rust module. diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenTarget.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenTarget.kt index 27ea5dcc68..b1dba3e33f 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenTarget.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/CodegenTarget.kt @@ -9,5 +9,23 @@ package software.amazon.smithy.rust.codegen.core.smithy * Code generation mode: In some situations, codegen has different behavior for client vs. server (eg. required fields) */ enum class CodegenTarget { - CLIENT, SERVER + CLIENT, SERVER; + + /** + * Convenience method to execute thunk if the target is for CLIENT + */ + fun ifClient(thunk: () -> B): B? = if (this == CLIENT) { + thunk() + } else { + null + } + + /** + * Convenience method to execute thunk if the target is for SERVER + */ + fun ifServer(thunk: () -> B): B? = if (this == SERVER) { + thunk() + } else { + null + } } diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt index 8c5a56b577..e8fa8b06e0 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/SymbolMetadataProvider.kt @@ -57,6 +57,7 @@ abstract class SymbolMetadataProvider(private val base: RustSymbolProvider) : Wr is StringShape -> if (shape.hasTrait()) { enumMeta(shape) } else null + else -> null } return baseSymbol.toBuilder().meta(meta).build() @@ -100,11 +101,13 @@ class BaseSymbolMetadataProvider( ) } } + container.isUnionShape || container.isListShape || container.isSetShape || container.isMapShape -> RustMetadata(visibility = Visibility.PUBLIC) + else -> TODO("Unrecognized container type: $container") } } @@ -120,9 +123,10 @@ class BaseSymbolMetadataProvider( override fun enumMeta(stringShape: StringShape): RustMetadata { return containerDefault.withDerives( RuntimeType.std.member("hash::Hash"), - ).withDerives( // enums can be eq because they can only contain strings + ).withDerives( + // enums can be eq because they can only contain ints and strings RuntimeType.std.member("cmp::Eq"), - // enums can be Ord because they can only contain strings + // enums can be Ord because they can only contain ints and strings RuntimeType.std.member("cmp::PartialOrd"), RuntimeType.std.member("cmp::Ord"), ) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index eed5b5462e..0acfe2641d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.model.traits.DocumentationTrait import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape import software.amazon.smithy.rust.codegen.core.rustlang.docs @@ -99,6 +100,9 @@ open class EnumGenerator( /** Name of the generated unknown enum member name for enums with named members. */ const val UnknownVariant = "Unknown" + /** Name of the opaque struct that is inner data for the generated [UnknownVariant]. */ + const val UnknownVariantValue = "UnknownVariantValue" + /** Name of the function on the enum impl to get a vec of value names */ const val Values = "values" } @@ -153,6 +157,10 @@ open class EnumGenerator( } private fun renderEnum() { + target.ifClient { + writer.renderForwardCompatibilityNote(enumName, sortedMembers, UnknownVariant, UnknownVariantValue) + } + val renamedWarning = sortedMembers.mapNotNull { it.name() }.filter { it.renamedFrom != null }.joinToString("\n") { val previousName = it.renamedFrom!! @@ -167,9 +175,9 @@ open class EnumGenerator( meta.render(writer) writer.rustBlock("enum $enumName") { sortedMembers.forEach { member -> member.render(writer) } - if (target == CodegenTarget.CLIENT) { - docs("$UnknownVariant contains new variants that have been added since this code was generated.") - rust("$UnknownVariant(String)") + target.ifClient { + docs("`$UnknownVariant` contains new variants that have been added since this code was generated.") + rust("$UnknownVariant(#T)", unknownVariantValue()) } } } @@ -182,8 +190,9 @@ open class EnumGenerator( sortedMembers.forEach { member -> rust("""$enumName::${member.derivedName()} => ${member.value.dq()},""") } - if (target == CodegenTarget.CLIENT) { - rust("$enumName::$UnknownVariant(s) => s.as_ref()") + + target.ifClient { + rust("$enumName::$UnknownVariant(value) => value.as_str()") } } } @@ -198,6 +207,28 @@ open class EnumGenerator( } } + private fun unknownVariantValue(): RuntimeType { + return RuntimeType.forInlineFun(UnknownVariantValue, RustModule.Types) { + docs( + """ + Opaque struct used as inner data for the `Unknown` variant defined in enums in + the crate + + While this is not intended to be used directly, it is marked as `pub` because it is + part of the enums that are public interface. + """.trimIndent(), + ) + meta.render(this) + rust("struct $UnknownVariantValue(pub(crate) String);") + rustBlock("impl $UnknownVariantValue") { + // The generated as_str is not pub as we need to prevent users from calling it on this opaque struct. + rustBlock("pub(crate) fn as_str(&self) -> &str") { + rust("&self.0") + } + } + } + } + protected open fun renderFromForStr() { writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.From) { rustBlock("fn from(s: &str) -> Self") { @@ -205,7 +236,7 @@ open class EnumGenerator( sortedMembers.forEach { member -> rust("""${member.value.dq()} => $enumName::${member.derivedName()},""") } - rust("other => $enumName::$UnknownVariant(other.to_owned())") + rust("other => $enumName::$UnknownVariant(#T(other.to_owned()))", unknownVariantValue()) } } } @@ -225,3 +256,61 @@ open class EnumGenerator( ) } } + +/** + * Generate the rustdoc describing how to write a match expression against a generated enum in a + * forward-compatible way. + */ +private fun RustWriter.renderForwardCompatibilityNote( + enumName: String, sortedMembers: List, + unknownVariant: String, unknownVariantValue: String, +) { + docs( + """ + When writing a match expression against `$enumName`, it is important to ensure + your code is forward-compatible. That is, if a match arm handles a case for a + feature that is supported by the service but has not been represented as an enum + variant in a current version of SDK, your code should continue to work when you + upgrade SDK to a future version in which the enum does include a variant for that + feature. + """.trimIndent(), + ) + docs("") + docs("Here is an example of how you can make a match expression forward-compatible:") + docs("") + docs("```text") + rust("/// ## let ${enumName.lowercase()} = unimplemented!();") + rust("/// match ${enumName.lowercase()} {") + sortedMembers.mapNotNull { it.name() }.forEach { member -> + rust("/// $enumName::${member.name} => { /* ... */ },") + } + rust("""/// other @ _ if other.as_str() == "NewFeature" => { /* handles a case for `NewFeature` */ },""") + rust("/// _ => { /* ... */ },") + rust("/// }") + docs("```") + docs( + """ + The above code demonstrates that when `${enumName.lowercase()}` represents + `NewFeature`, the execution path will lead to the second last match arm, + even though the enum does not contain a variant `$enumName::NewFeature` + in the current version of SDK. The reason is that the variable `other`, + created by the `@` operator, is bound to + `$enumName::$unknownVariant($unknownVariantValue("NewFeature".to_owned()))` + and calling `as_str` on it yields `"NewFeature"`. + This match expression is forward-compatible when executed with a newer + version of SDK where the variant `$enumName::NewFeature` is defined. + Specifically, when `${enumName.lowercase()}` represents `NewFeature`, + the execution path will hit the second last match arm as before by virtue of + calling `as_str` on `$enumName::NewFeature` also yielding `"NewFeature"`. + """.trimIndent(), + ) + docs("") + docs( + """ + Explicitly matching on the `$unknownVariant` variant should + be avoided for two reasons: + - The inner data `$unknownVariantValue` is opaque, and no further information can be extracted. + - It might inadvertently shadow other intended match arms. + """.trimIndent(), + ) +} diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt index 90679f9773..963a436d08 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt @@ -9,13 +9,17 @@ import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.orNull @@ -106,28 +110,34 @@ class EnumGeneratorTest { @deprecated(since: "1.2.3") string InstanceType """.asSmithyModel() - val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - writer.rust("##![allow(deprecated)]") + val shape = model.lookup("test#InstanceType") - val generator = EnumGenerator(model, provider, writer, shape, shape.expectTrait()) - generator.render() - writer.compileAndTest( - """ - let instance = InstanceType::T2Micro; - assert_eq!(instance.as_str(), "t2.micro"); - assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano); - assert_eq!(InstanceType::from("other"), InstanceType::Unknown("other".to_owned())); - // round trip unknown variants: - assert_eq!(InstanceType::from("other").as_str(), "other"); - """, - ) - val output = writer.toString() - output shouldContain "#[non_exhaustive]" - // on enum variant `T2Micro` - output shouldContain "#[deprecated]" - // on enum itself - output shouldContain "#[deprecated(since = \"1.2.3\")]" + val trait = shape.expectTrait() + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + rust("##![allow(deprecated)]") + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "it_generates_named_enums", + """ + let instance = InstanceType::T2Micro; + assert_eq!(instance.as_str(), "t2.micro"); + assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano); + assert_eq!(InstanceType::from("other"), InstanceType::Unknown(crate::types::UnknownVariantValue("other".to_owned()))); + // round trip unknown variants: + assert_eq!(InstanceType::from("other").as_str(), "other"); + """, + ) + val output = toString() + output.shouldContain("#[non_exhaustive]") + // on enum variant `T2Micro` + output.shouldContain("#[deprecated]") + // on enum itself + output.shouldContain("#[deprecated(since = \"1.2.3\")]") + } + project.compileAndTest() } @Test @@ -145,19 +155,25 @@ class EnumGeneratorTest { }]) string FooEnum """.asSmithyModel() + val shape = model.lookup("test#FooEnum") val trait = shape.expectTrait() - val writer = RustWriter.forModule("model") - val generator = EnumGenerator(model, testSymbolProvider(model), writer, shape, trait) - generator.render() - writer.compileAndTest( - """ - assert_eq!(FooEnum::Foo, FooEnum::Foo); - assert_ne!(FooEnum::Bar, FooEnum::Foo); - let mut hash_of_enums = std::collections::HashSet::new(); - hash_of_enums.insert(FooEnum::Foo); - """.trimIndent(), - ) + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "named_enums_implement_eq_and_hash", + """ + assert_eq!(FooEnum::Foo, FooEnum::Foo); + assert_ne!(FooEnum::Bar, FooEnum::Foo); + let mut hash_of_enums = std::collections::HashSet::new(); + hash_of_enums.insert(FooEnum::Foo); + """.trimIndent(), + ) + } + project.compileAndTest() } @Test @@ -174,20 +190,26 @@ class EnumGeneratorTest { @deprecated string FooEnum """.asSmithyModel() + val shape = model.lookup("test#FooEnum") val trait = shape.expectTrait() - val writer = RustWriter.forModule("model") - writer.rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, testSymbolProvider(model), writer, shape, trait) - generator.render() - writer.compileAndTest( - """ - assert_eq!(FooEnum::from("Foo"), FooEnum::from("Foo")); - assert_ne!(FooEnum::from("Bar"), FooEnum::from("Foo")); - let mut hash_of_enums = std::collections::HashSet::new(); - hash_of_enums.insert(FooEnum::from("Foo")); - """, - ) + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + rust("##![allow(deprecated)]") + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "unnamed_enums_implement_eq_and_hash", + """ + assert_eq!(FooEnum::from("Foo"), FooEnum::from("Foo")); + assert_ne!(FooEnum::from("Bar"), FooEnum::from("Foo")); + let mut hash_of_enums = std::collections::HashSet::new(); + hash_of_enums.insert(FooEnum::from("Foo")); + """.trimIndent(), + ) + } + project.compileAndTest() } @Test @@ -213,19 +235,24 @@ class EnumGeneratorTest { ]) string FooEnum """.asSmithyModel() + val shape = model.lookup("test#FooEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - writer.rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, provider, writer, shape, trait) - generator.render() - writer.compileAndTest( - """ - // Values should be sorted - assert_eq!(FooEnum::${EnumGenerator.Values}(), ["0", "1", "Bar", "Baz", "Foo"]); - """, - ) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + rust("##![allow(deprecated)]") + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "it_generates_unnamed_enums", + """ + // Values should be sorted + assert_eq!(FooEnum::${EnumGenerator.Values}(), ["0", "1", "Bar", "Baz", "Foo"]); + """.trimIndent(), + ) + } + project.compileAndTest() } @Test @@ -240,19 +267,23 @@ class EnumGeneratorTest { string SomeEnum """.asSmithyModel() - val shape: StringShape = model.lookup("test#SomeEnum") + val shape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - EnumGenerator(model, provider, writer, shape, trait).render() - - writer.compileAndTest( - """ - assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue); - assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown("SomethingNew".into())); - """, - ) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "it_escapes_the_unknown_variant_if_the_enum_has_an_unknown_value_in_the_model", + """ + assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue); + assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_); + assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned()))); + """.trimIndent(), + ) + } + project.compileAndTest() } @Test @@ -268,17 +299,22 @@ class EnumGeneratorTest { string SomeEnum """.asSmithyModel() - val shape: StringShape = model.lookup("test#SomeEnum") + val shape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val rendered = RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }.toString() - - rendered shouldContain - """ - /// Some top-level documentation. - /// - /// _Note: `SomeEnum::Unknown` has been renamed to `::UnknownValue`._ - """.trimIndent() + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + val rendered = toString() + rendered shouldContain + """ + /// Some top-level documentation. + /// + /// _Note: `SomeEnum::Unknown` has been renamed to `::UnknownValue`._ + """.trimIndent() + } + project.compileAndTest() } @Test @@ -294,15 +330,20 @@ class EnumGeneratorTest { string SomeEnum """.asSmithyModel() - val shape: StringShape = model.lookup("test#SomeEnum") + val shape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val rendered = RustWriter.forModule("model").also { EnumGenerator(model, provider, it, shape, trait).render() }.toString() - - rendered shouldContain - """ - /// Some top-level documentation. - """.trimIndent() + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + val rendered = toString() + rendered shouldContain + """ + /// Some top-level documentation. + """.trimIndent() + } + project.compileAndTest() } } @@ -317,17 +358,72 @@ class EnumGeneratorTest { string SomeEnum """.asSmithyModel() - val shape: StringShape = model.lookup("test#SomeEnum") + val shape = model.lookup("test#SomeEnum") val trait = shape.expectTrait() val provider = testSymbolProvider(model) - val writer = RustWriter.forModule("model") - EnumGenerator(model, provider, writer, shape, trait).render() + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "it_handles_variants_that_clash_with_rust_reserved_words", + """ + assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue); + assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned()))); + """.trimIndent(), + ) + } + project.compileAndTest() + } + + @Test + fun `matching on enum should be forward-compatible`() { + fun expectMatchExpressionCompiles(model: Model, shapeId: String, enumToMatchOn: String) { + val shape = model.lookup(shapeId) + val trait = shape.expectTrait() + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + val generator = EnumGenerator(model, provider, this, shape, trait) + generator.render() + unitTest( + "matching_on_enum_should_be_forward_compatible", + """ + match $enumToMatchOn { + SomeEnum::Variant1 => assert!(false, "expected `Variant3` but got `Variant1`"), + SomeEnum::Variant2 => assert!(false, "expected `Variant3` but got `Variant2`"), + other @ _ if other.as_str() == "Variant3" => assert!(true), + _ => assert!(false, "expected `Variant3` but got `_`"), + } + """.trimIndent(), + ) + } + project.compileAndTest() + } - writer.compileAndTest( - """ - assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown("SomethingNew".into())); - """, - ) + val modelV1 = """ + namespace test + + @enum([ + { name: "Variant1", value: "Variant1" }, + { name: "Variant2", value: "Variant2" }, + ]) + string SomeEnum + """.asSmithyModel() + val variant3AsUnknown = """SomeEnum::from("Variant3")""" + expectMatchExpressionCompiles(modelV1, "test#SomeEnum", variant3AsUnknown) + + val modelV2 = """ + namespace test + + @enum([ + { name: "Variant1", value: "Variant1" }, + { name: "Variant2", value: "Variant2" }, + { name: "Variant3", value: "Variant3" }, + ]) + string SomeEnum + """.asSmithyModel() + val variant3AsVariant3 = "SomeEnum::Variant3" + expectMatchExpressionCompiles(modelV2, "test#SomeEnum", variant3AsVariant3) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt index a8063cf898..9b2e94c87d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/ServerRequiredCustomizations.kt @@ -7,7 +7,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.customizations import software.amazon.smithy.rust.codegen.client.smithy.customizations.AllowLintsGenerator import software.amazon.smithy.rust.codegen.client.smithy.customizations.CrateVersionGenerator -import software.amazon.smithy.rust.codegen.client.smithy.customizations.SmithyTypesPubUseGenerator +import software.amazon.smithy.rust.codegen.client.smithy.customizations.pubUseSmithyTypes import software.amazon.smithy.rust.codegen.client.smithy.customize.RustCodegenDecorator import software.amazon.smithy.rust.codegen.core.rustlang.Feature import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext @@ -31,11 +31,13 @@ class ServerRequiredCustomizations : RustCodegenDecorator, ): List = - baseCustomizations + CrateVersionGenerator() + SmithyTypesPubUseGenerator(codegenContext.runtimeConfig) + AllowLintsGenerator() + baseCustomizations + CrateVersionGenerator() + AllowLintsGenerator() override fun extras(codegenContext: ServerCodegenContext, rustCrate: RustCrate) { // Add rt-tokio feature for `ByteStream::from_path` rustCrate.mergeFeature(Feature("rt-tokio", true, listOf("aws-smithy-http/rt-tokio"))) + + pubUseSmithyTypes(codegenContext.runtimeConfig, codegenContext.model, rustCrate) } override fun supportsCodegenContext(clazz: Class): Boolean =