diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt index 44099cceb1..1816240ef9 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/ClientCodegenVisitor.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator +import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientEnumGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceGenerator import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ClientProtocolGenerator import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientProtocolLoader @@ -29,7 +30,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator -import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.error.OperationErrorGenerator @@ -43,7 +43,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveSha import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamErrors import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.core.util.CommandFailed -import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.hasTrait import software.amazon.smithy.rust.codegen.core.util.isEventStream import software.amazon.smithy.rust.codegen.core.util.letIf @@ -206,9 +205,9 @@ class ClientCodegenVisitor( * Although raw strings require no code generation, enums are actually `EnumTrait` applied to string shapes. */ override fun stringShape(shape: StringShape) { - shape.getTrait()?.also { enum -> + if (shape.hasTrait()) { rustCrate.useShapeWriter(shape) { - EnumGenerator(model, symbolProvider, this, shape, enum).render() + ClientEnumGenerator(codegenContext, shape).render(this) } } } diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt new file mode 100644 index 0000000000..b4c988dc6b --- /dev/null +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt @@ -0,0 +1,170 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators + +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGeneratorContext +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumMemberModel +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumType +import software.amazon.smithy.rust.codegen.core.util.dq + +/** Infallible enums have an `Unknown` variant and can't fail to parse */ +data class InfallibleEnumType( + val unknownVariantModule: RustModule, +) : EnumType() { + companion object { + /** Name of the generated unknown enum member name for enums with named members. */ + const val UnknownVariant = "Unknown" + + /** Name of the opaque struct that is inner data for the generated [UnknownVariant]. */ + const val UnknownVariantValue = "UnknownVariantValue" + } + + override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { + rustTemplate( + """ + impl #{From}<&str> for ${context.enumName} { + fn from(s: &str) -> Self { + match s { + #{matchArms} + } + } + } + """, + "From" to RuntimeType.From, + "matchArms" to writable { + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},") + } + rust( + "other => ${context.enumName}::$UnknownVariant(#T(other.to_owned()))", + unknownVariantValue(context), + ) + }, + ) + } + + override fun implFromStr(context: EnumGeneratorContext): Writable = writable { + rust( + """ + impl std::str::FromStr for ${context.enumName} { + type Err = std::convert::Infallible; + + fn from_str(s: &str) -> std::result::Result { + Ok(${context.enumName}::from(s)) + } + } + """, + ) + } + + override fun additionalDocs(context: EnumGeneratorContext): Writable = writable { + renderForwardCompatibilityNote(context.enumName, context.sortedMembers, UnknownVariant, UnknownVariantValue) + } + + override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable { + docs("`$UnknownVariant` contains new variants that have been added since this code was generated.") + rust("$UnknownVariant(#T)", unknownVariantValue(context)) + } + + override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = writable { + rust("${context.enumName}::$UnknownVariant(value) => value.as_str()") + } + + private fun unknownVariantValue(context: EnumGeneratorContext): RuntimeType { + return RuntimeType.forInlineFun(UnknownVariantValue, RustModule.Types) { + docs( + """ + Opaque struct used as inner data for the `Unknown` variant defined in enums in + the crate + + While this is not intended to be used directly, it is marked as `pub` because it is + part of the enums that are public interface. + """.trimIndent(), + ) + context.enumMeta.render(this) + rust("struct $UnknownVariantValue(pub(crate) String);") + rustBlock("impl $UnknownVariantValue") { + // The generated as_str is not pub as we need to prevent users from calling it on this opaque struct. + rustBlock("pub(crate) fn as_str(&self) -> &str") { + rust("&self.0") + } + } + } + } + + /** + * Generate the rustdoc describing how to write a match expression against a generated enum in a + * forward-compatible way. + */ + private fun RustWriter.renderForwardCompatibilityNote( + enumName: String, sortedMembers: List, + unknownVariant: String, unknownVariantValue: String, + ) { + docs( + """ + When writing a match expression against `$enumName`, it is important to ensure + your code is forward-compatible. That is, if a match arm handles a case for a + feature that is supported by the service but has not been represented as an enum + variant in a current version of SDK, your code should continue to work when you + upgrade SDK to a future version in which the enum does include a variant for that + feature. + """.trimIndent(), + ) + docs("") + docs("Here is an example of how you can make a match expression forward-compatible:") + docs("") + docs("```text") + rust("/// ## let ${enumName.lowercase()} = unimplemented!();") + rust("/// match ${enumName.lowercase()} {") + sortedMembers.mapNotNull { it.name() }.forEach { member -> + rust("/// $enumName::${member.name} => { /* ... */ },") + } + rust("""/// other @ _ if other.as_str() == "NewFeature" => { /* handles a case for `NewFeature` */ },""") + rust("/// _ => { /* ... */ },") + rust("/// }") + docs("```") + docs( + """ + The above code demonstrates that when `${enumName.lowercase()}` represents + `NewFeature`, the execution path will lead to the second last match arm, + even though the enum does not contain a variant `$enumName::NewFeature` + in the current version of SDK. The reason is that the variable `other`, + created by the `@` operator, is bound to + `$enumName::$unknownVariant($unknownVariantValue("NewFeature".to_owned()))` + and calling `as_str` on it yields `"NewFeature"`. + This match expression is forward-compatible when executed with a newer + version of SDK where the variant `$enumName::NewFeature` is defined. + Specifically, when `${enumName.lowercase()}` represents `NewFeature`, + the execution path will hit the second last match arm as before by virtue of + calling `as_str` on `$enumName::NewFeature` also yielding `"NewFeature"`. + """.trimIndent(), + ) + docs("") + docs( + """ + Explicitly matching on the `$unknownVariant` variant should + be avoided for two reasons: + - The inner data `$unknownVariantValue` is opaque, and no further information can be extracted. + - It might inadvertently shadow other intended match arms. + """.trimIndent(), + ) + } +} + +class ClientEnumGenerator(codegenContext: CodegenContext, shape: StringShape) : + EnumGenerator(codegenContext.model, codegenContext.symbolProvider, shape, InfallibleEnumType(RustModule.Types)) diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt new file mode 100644 index 0000000000..c366fbe5ee --- /dev/null +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGeneratorTest.kt @@ -0,0 +1,161 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.client.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext +import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.core.util.lookup + +class ClientEnumGeneratorTest { + @Test + fun `matching on enum should be forward-compatible`() { + fun expectMatchExpressionCompiles(model: Model, shapeId: String, enumToMatchOn: String) { + val shape = model.lookup(shapeId) + val context = testCodegenContext(model) + val project = TestWorkspace.testProject(context.symbolProvider) + project.withModule(RustModule.Model) { + ClientEnumGenerator(context, shape).render(this) + unitTest( + "matching_on_enum_should_be_forward_compatible", + """ + match $enumToMatchOn { + SomeEnum::Variant1 => assert!(false, "expected `Variant3` but got `Variant1`"), + SomeEnum::Variant2 => assert!(false, "expected `Variant3` but got `Variant2`"), + other @ _ if other.as_str() == "Variant3" => assert!(true), + _ => assert!(false, "expected `Variant3` but got `_`"), + } + """.trimIndent(), + ) + } + project.compileAndTest() + } + + val modelV1 = """ + namespace test + + @enum([ + { name: "Variant1", value: "Variant1" }, + { name: "Variant2", value: "Variant2" }, + ]) + string SomeEnum + """.asSmithyModel() + val variant3AsUnknown = """SomeEnum::from("Variant3")""" + expectMatchExpressionCompiles(modelV1, "test#SomeEnum", variant3AsUnknown) + + val modelV2 = """ + namespace test + + @enum([ + { name: "Variant1", value: "Variant1" }, + { name: "Variant2", value: "Variant2" }, + { name: "Variant3", value: "Variant3" }, + ]) + string SomeEnum + """.asSmithyModel() + val variant3AsVariant3 = "SomeEnum::Variant3" + expectMatchExpressionCompiles(modelV2, "test#SomeEnum", variant3AsVariant3) + } + + @Test + fun `impl debug for non-sensitive enum should implement the derived debug trait`() { + val model = """ + namespace test + @enum([ + { name: "Foo", value: "Foo" }, + { name: "Bar", value: "Bar" }, + ]) + string SomeEnum + """.asSmithyModel() + + val shape = model.lookup("test#SomeEnum") + val context = testCodegenContext(model) + val project = TestWorkspace.testProject(context.symbolProvider) + project.withModule(RustModule.Model) { + ClientEnumGenerator(context, shape).render(this) + unitTest( + "impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait", + """ + assert_eq!(format!("{:?}", SomeEnum::Foo), "Foo"); + assert_eq!(format!("{:?}", SomeEnum::Bar), "Bar"); + assert_eq!( + format!("{:?}", SomeEnum::from("Baz")), + "Unknown(UnknownVariantValue(\"Baz\"))" + ); + """, + ) + } + project.compileAndTest() + } + + @Test + fun `it escapes the Unknown variant if the enum has an unknown value in the model`() { + val model = """ + namespace test + @enum([ + { name: "Known", value: "Known" }, + { name: "Unknown", value: "Unknown" }, + { name: "UnknownValue", value: "UnknownValue" }, + ]) + string SomeEnum + """.asSmithyModel() + + val shape = model.lookup("test#SomeEnum") + val context = testCodegenContext(model) + val project = TestWorkspace.testProject(context.symbolProvider) + project.withModule(RustModule.Model) { + ClientEnumGenerator(context, shape).render(this) + unitTest( + "it_escapes_the_unknown_variant_if_the_enum_has_an_unknown_value_in_the_model", + """ + assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue); + assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_); + assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned()))); + """.trimIndent(), + ) + } + project.compileAndTest() + } + + @Test + fun `generated named enums can roundtrip between string and enum value on the unknown variant`() { + val model = """ + namespace test + @enum([ + { value: "t2.nano", name: "T2_NANO" }, + { value: "t2.micro", name: "T2_MICRO" }, + ]) + string InstanceType + """.asSmithyModel() + + val shape = model.lookup("test#InstanceType") + val context = testCodegenContext(model) + val project = TestWorkspace.testProject(context.symbolProvider) + project.withModule(RustModule.Model) { + rust("##![allow(deprecated)]") + ClientEnumGenerator(context, shape).render(this) + unitTest( + "generated_named_enums_roundtrip", + """ + let instance = InstanceType::T2Micro; + assert_eq!(instance.as_str(), "t2.micro"); + assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano); + // round trip unknown variants: + assert_eq!(InstanceType::from("other"), InstanceType::Unknown(crate::types::UnknownVariantValue("other".to_owned()))); + assert_eq!(InstanceType::from("other").as_str(), "other"); + """, + ) + } + project.compileAndTest() + } +} diff --git a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt index f507ba2d4c..39b0041aa2 100644 --- a/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt +++ b/codegen-client/src/test/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientInstantiatorTest.kt @@ -12,13 +12,11 @@ import software.amazon.smithy.rust.codegen.client.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup internal class ClientInstantiatorTest { @@ -55,7 +53,7 @@ internal class ClientInstantiatorTest { val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { - EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + ClientEnumGenerator(codegenContext, shape).render(this) unitTest("generate_named_enums") { withBlock("let result = ", ";") { sut.render(this, shape, data) @@ -74,7 +72,7 @@ internal class ClientInstantiatorTest { val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { - EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + ClientEnumGenerator(codegenContext, shape).render(this) unitTest("generate_unnamed_enums") { withBlock("let result = ", ";") { sut.render(this, shape, data) diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index ec8f3505e9..5e7dc1d2b7 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -12,8 +12,9 @@ import software.amazon.smithy.model.traits.DocumentationTrait import software.amazon.smithy.model.traits.EnumDefinition import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustModule +import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape import software.amazon.smithy.rust.codegen.core.rustlang.docs import software.amazon.smithy.rust.codegen.core.rustlang.documentShape @@ -21,19 +22,54 @@ import software.amazon.smithy.rust.codegen.core.rustlang.escape import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.MaybeRenamed import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata import software.amazon.smithy.rust.codegen.core.util.REDACTION -import software.amazon.smithy.rust.codegen.core.util.doubleQuote import software.amazon.smithy.rust.codegen.core.util.dq +import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.orNull import software.amazon.smithy.rust.codegen.core.util.shouldRedact +data class EnumGeneratorContext( + val enumName: String, + val enumMeta: RustMetadata, + val enumTrait: EnumTrait, + val sortedMembers: List, +) + +/** + * Type of enum to generate + * + * In codegen-core, there are only `Infallible` enums. Server adds additional enum types, which + * is why this class is abstract rather than sealed. + */ +abstract class EnumType { + /** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the enum */ + abstract fun implFromForStr(context: EnumGeneratorContext): Writable + + /** Returns a writable that implements `FromStr` for the enum */ + abstract fun implFromStr(context: EnumGeneratorContext): Writable + + /** Optionally adds additional documentation to the `enum` docs */ + open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {} + + /** Optionally adds additional enum members */ + open fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable {} + + /** Optionally adds match arms to the `as_str` match implementation for named enums */ + open fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = writable {} + + /** Optionally add more attributes to the enum */ + open fun additionalEnumAttributes(context: EnumGeneratorContext): List = emptyList() + + /** Optionally add more impls to the enum */ + open fun additionalEnumImpls(context: EnumGeneratorContext): Writable = writable {} +} + /** Model that wraps [EnumDefinition] to calculate and cache values required to generate the Rust enum source. */ class EnumMemberModel(private val definition: EnumDefinition, private val symbolProvider: RustSymbolProvider) { // Because enum variants always start with an upper case letter, they will never @@ -88,152 +124,134 @@ private fun RustWriter.docWithNote(doc: String?, note: String?) { open class EnumGenerator( private val model: Model, private val symbolProvider: RustSymbolProvider, - private val writer: RustWriter, - protected val shape: StringShape, - protected val enumTrait: EnumTrait, + private val shape: StringShape, + private val enumType: EnumType, ) { - protected val symbol: Symbol = symbolProvider.toSymbol(shape) - protected val enumName: String = symbol.name - protected val meta = symbol.expectRustMetadata() - protected val sortedMembers: List = - enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(it, symbolProvider) } - protected open var target: CodegenTarget = CodegenTarget.CLIENT - companion object { - /** Name of the generated unknown enum member name for enums with named members. */ - const val UnknownVariant = "Unknown" - - /** Name of the opaque struct that is inner data for the generated [UnknownVariant]. */ - const val UnknownVariantValue = "UnknownVariantValue" - /** Name of the function on the enum impl to get a vec of value names */ const val Values = "values" } - open fun render() { + private val enumTrait: EnumTrait = shape.expectTrait() + private val symbol: Symbol = symbolProvider.toSymbol(shape) + private val context = EnumGeneratorContext( + enumName = symbol.name, + enumMeta = symbol.expectRustMetadata(), + enumTrait = enumTrait, + sortedMembers = enumTrait.values.sortedBy { it.value }.map { EnumMemberModel(it, symbolProvider) }, + ) + + fun render(writer: RustWriter) { + enumType.additionalEnumAttributes(context).forEach { attribute -> + attribute.render(writer) + } if (enumTrait.hasNames()) { - // pub enum Blah { V1, V2, .. } - renderEnum() - writer.insertTrailingNewline() - // impl From for Blah { ... } - renderFromForStr() - // impl FromStr for Blah { ... } - renderFromStr() - writer.insertTrailingNewline() - // impl Blah { pub fn as_str(&self) -> &str - implBlock() - writer.rustBlock("impl AsRef for $enumName") { - rustBlock("fn as_ref(&self) -> &str") { - rust("self.as_str()") - } - } + writer.renderNamedEnum() } else { - renderUnnamedEnum() + writer.renderUnnamedEnum() } + enumType.additionalEnumImpls(context)(writer) if (shape.shouldRedact(model)) { - renderDebugImplForSensitiveEnum() + writer.renderDebugImplForSensitiveEnum() } } - private fun renderUnnamedEnum() { - writer.documentShape(shape, model) - writer.deprecatedShape(shape) - meta.render(writer) - writer.write("struct $enumName(String);") - writer.rustBlock("impl $enumName") { - docs("Returns the `&str` value of the enum member.") - rustBlock("pub fn as_str(&self) -> &str") { - rust("&self.0") - } - - docs("Returns all the `&str` representations of the enum members.") - rustBlock("pub const fn $Values() -> &'static [&'static str]") { - withBlock("&[", "]") { - val memberList = sortedMembers.joinToString(", ") { it.value.dq() } - rust(memberList) + private fun RustWriter.renderNamedEnum() { + // pub enum Blah { V1, V2, .. } + renderEnum() + insertTrailingNewline() + // impl From for Blah { ... } + enumType.implFromForStr(context)(this) + // impl FromStr for Blah { ... } + enumType.implFromStr(context)(this) + insertTrailingNewline() + // impl Blah { pub fn as_str(&self) -> &str + implBlock( + asStrImpl = writable { + rustBlock("match self") { + context.sortedMembers.forEach { member -> + rust("""${context.enumName}::${member.derivedName()} => ${member.value.dq()},""") + } + enumType.additionalAsStrMatchArms(context)(this) + } + }, + ) + rust( + """ + impl AsRef for ${context.enumName} { + fn as_ref(&self) -> &str { + self.as_str() } } - } + """, + ) + } - writer.rustBlock("impl #T for $enumName where T: #T", RuntimeType.From, RuntimeType.AsRef) { - rustBlock("fn from(s: T) -> Self") { - rust("$enumName(s.as_ref().to_owned())") + private fun RustWriter.renderUnnamedEnum() { + documentShape(shape, model) + deprecatedShape(shape) + context.enumMeta.render(this) + rust("struct ${context.enumName}(String);") + implBlock( + asStrImpl = writable { + rust("&self.0") + }, + ) + + rustTemplate( + """ + impl #{From} for ${context.enumName} where T: #{AsRef} { + fn from(s: T) -> Self { + ${context.enumName}(s.as_ref().to_owned()) + } } - } + """, + "From" to RuntimeType.From, + "AsRef" to RuntimeType.AsRef, + ) } - private fun renderEnum() { - target.ifClient { - writer.renderForwardCompatibilityNote(enumName, sortedMembers, UnknownVariant, UnknownVariantValue) - } + private fun RustWriter.renderEnum() { + enumType.additionalDocs(context)(this) val renamedWarning = - sortedMembers.mapNotNull { it.name() }.filter { it.renamedFrom != null }.joinToString("\n") { + context.sortedMembers.mapNotNull { it.name() }.filter { it.renamedFrom != null }.joinToString("\n") { val previousName = it.renamedFrom!! - "`$enumName::$previousName` has been renamed to `::${it.name}`." + "`${context.enumName}::$previousName` has been renamed to `::${it.name}`." } - writer.docWithNote( + docWithNote( shape.getTrait()?.value, renamedWarning.ifBlank { null }, ) - writer.deprecatedShape(shape) + deprecatedShape(shape) - meta.render(writer) - writer.rustBlock("enum $enumName") { - sortedMembers.forEach { member -> member.render(writer) } - target.ifClient { - docs("`$UnknownVariant` contains new variants that have been added since this code was generated.") - rust("$UnknownVariant(#T)", unknownVariantValue()) - } + context.enumMeta.render(this) + rustBlock("enum ${context.enumName}") { + context.sortedMembers.forEach { member -> member.render(this) } + enumType.additionalEnumMembers(context)(this) } } - private fun implBlock() { - writer.rustBlock("impl $enumName") { - rust("/// Returns the `&str` value of the enum member.") - rustBlock("pub fn as_str(&self) -> &str") { - rustBlock("match self") { - sortedMembers.forEach { member -> - rust("""$enumName::${member.derivedName()} => ${member.value.dq()},""") - } - - target.ifClient { - rust("$enumName::$UnknownVariant(value) => value.as_str()") - } + private fun RustWriter.implBlock(asStrImpl: Writable) { + rustTemplate( + """ + impl ${context.enumName} { + /// Returns the `&str` value of the enum member. + pub fn as_str(&self) -> &str { + #{asStrImpl:W} } - } - - rust("/// Returns all the `&str` values of the enum members.") - rustBlock("pub const fn $Values() -> &'static [&'static str]") { - withBlock("&[", "]") { - val memberList = sortedMembers.joinToString(", ") { it.value.doubleQuote() } - write(memberList) + /// Returns all the `&str` representations of the enum members. + pub const fn $Values() -> &'static [&'static str] { + &[#{Values:W}] } } - } - } - - private fun unknownVariantValue(): RuntimeType { - return RuntimeType.forInlineFun(UnknownVariantValue, RustModule.Types) { - docs( - """ - Opaque struct used as inner data for the `Unknown` variant defined in enums in - the crate - - While this is not intended to be used directly, it is marked as `pub` because it is - part of the enums that are public interface. - """.trimIndent(), - ) - meta.render(this) - rust("struct $UnknownVariantValue(pub(crate) String);") - rustBlock("impl $UnknownVariantValue") { - // The generated as_str is not pub as we need to prevent users from calling it on this opaque struct. - rustBlock("pub(crate) fn as_str(&self) -> &str") { - rust("&self.0") - } - } - } + """, + "asStrImpl" to asStrImpl, + "Values" to writable { + rust(context.sortedMembers.joinToString(", ") { it.value.dq() }) + }, + ) } /** @@ -241,10 +259,10 @@ open class EnumGenerator( * * It prints the redacted text regardless of the variant it is asked to print. */ - private fun renderDebugImplForSensitiveEnum() { - writer.rustTemplate( + private fun RustWriter.renderDebugImplForSensitiveEnum() { + rustTemplate( """ - impl #{Debug} for $enumName { + impl #{Debug} for ${context.enumName} { fn fmt(&self, f: &mut #{StdFmt}::Formatter<'_>) -> #{StdFmt}::Result { write!(f, $REDACTION) } @@ -254,89 +272,4 @@ open class EnumGenerator( "StdFmt" to RuntimeType.stdFmt, ) } - - protected open fun renderFromForStr() { - writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.From) { - rustBlock("fn from(s: &str) -> Self") { - rustBlock("match s") { - sortedMembers.forEach { member -> - rust("""${member.value.dq()} => $enumName::${member.derivedName()},""") - } - rust("other => $enumName::$UnknownVariant(#T(other.to_owned()))", unknownVariantValue()) - } - } - } - } - - open fun renderFromStr() { - writer.rust( - """ - impl std::str::FromStr for $enumName { - type Err = std::convert::Infallible; - - fn from_str(s: &str) -> std::result::Result { - Ok($enumName::from(s)) - } - } - """, - ) - } -} - -/** - * Generate the rustdoc describing how to write a match expression against a generated enum in a - * forward-compatible way. - */ -private fun RustWriter.renderForwardCompatibilityNote( - enumName: String, sortedMembers: List, - unknownVariant: String, unknownVariantValue: String, -) { - docs( - """ - When writing a match expression against `$enumName`, it is important to ensure - your code is forward-compatible. That is, if a match arm handles a case for a - feature that is supported by the service but has not been represented as an enum - variant in a current version of SDK, your code should continue to work when you - upgrade SDK to a future version in which the enum does include a variant for that - feature. - """.trimIndent(), - ) - docs("") - docs("Here is an example of how you can make a match expression forward-compatible:") - docs("") - docs("```text") - rust("/// ## let ${enumName.lowercase()} = unimplemented!();") - rust("/// match ${enumName.lowercase()} {") - sortedMembers.mapNotNull { it.name() }.forEach { member -> - rust("/// $enumName::${member.name} => { /* ... */ },") - } - rust("""/// other @ _ if other.as_str() == "NewFeature" => { /* handles a case for `NewFeature` */ },""") - rust("/// _ => { /* ... */ },") - rust("/// }") - docs("```") - docs( - """ - The above code demonstrates that when `${enumName.lowercase()}` represents - `NewFeature`, the execution path will lead to the second last match arm, - even though the enum does not contain a variant `$enumName::NewFeature` - in the current version of SDK. The reason is that the variable `other`, - created by the `@` operator, is bound to - `$enumName::$unknownVariant($unknownVariantValue("NewFeature".to_owned()))` - and calling `as_str` on it yields `"NewFeature"`. - This match expression is forward-compatible when executed with a newer - version of SDK where the variant `$enumName::NewFeature` is defined. - Specifically, when `${enumName.lowercase()}` represents `NewFeature`, - the execution path will hit the second last match arm as before by virtue of - calling `as_str` on `$enumName::NewFeature` also yielding `"NewFeature"`. - """.trimIndent(), - ) - docs("") - docs( - """ - Explicitly matching on the `$unknownVariant` variant should - be avoided for two reasons: - - The inner data `$unknownVariantValue` is opaque, and no further information can be extracted. - - It might inadvertently shadow other intended match arms. - """.trimIndent(), - ) } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt index a36fbf0b08..e2fb34b461 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain +import io.kotest.matchers.string.shouldNotContain import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import software.amazon.smithy.model.Model @@ -14,7 +15,10 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest @@ -89,6 +93,15 @@ class EnumGeneratorTest { @Nested inner class EnumGeneratorTests { + fun RustWriter.renderEnum( + model: Model, + provider: RustSymbolProvider, + shape: StringShape, + enumType: EnumType = TestEnumType, + ) { + EnumGenerator(model, provider, shape, enumType).render(this) + } + @Test fun `it generates named enums`() { val model = """ @@ -113,22 +126,17 @@ class EnumGeneratorTest { """.asSmithyModel() val shape = model.lookup("test#InstanceType") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.withModule(RustModule.Model) { rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + renderEnum(model, provider, shape) unitTest( "it_generates_named_enums", """ let instance = InstanceType::T2Micro; assert_eq!(instance.as_str(), "t2.micro"); assert_eq!(InstanceType::from("t2.nano"), InstanceType::T2Nano); - assert_eq!(InstanceType::from("other"), InstanceType::Unknown(crate::types::UnknownVariantValue("other".to_owned()))); - // round trip unknown variants: - assert_eq!(InstanceType::from("other").as_str(), "other"); """, ) val output = toString() @@ -158,12 +166,10 @@ class EnumGeneratorTest { """.asSmithyModel() val shape = model.lookup("test#FooEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + renderEnum(model, provider, shape) unitTest( "named_enums_implement_eq_and_hash", """ @@ -193,13 +199,11 @@ class EnumGeneratorTest { """.asSmithyModel() val shape = model.lookup("test#FooEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.withModule(RustModule.Model) { rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + renderEnum(model, provider, shape) unitTest( "unnamed_enums_implement_eq_and_hash", """ @@ -238,13 +242,11 @@ class EnumGeneratorTest { """.asSmithyModel() val shape = model.lookup("test#FooEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.withModule(RustModule.Model) { rust("##![allow(deprecated)]") - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + renderEnum(model, provider, shape) unitTest( "it_generates_unnamed_enums", """ @@ -257,304 +259,256 @@ class EnumGeneratorTest { } @Test - fun `it escapes the Unknown variant if the enum has an unknown value in the model`() { + fun `it should generate documentation for enums`() { val model = """ namespace test + + /// Some top-level documentation. @enum([ { name: "Known", value: "Known" }, { name: "Unknown", value: "Unknown" }, - { name: "UnknownValue", value: "UnknownValue" }, ]) string SomeEnum """.asSmithyModel() val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "it_escapes_the_unknown_variant_if_the_enum_has_an_unknown_value_in_the_model", + renderEnum(model, provider, shape) + val rendered = toString() + rendered shouldContain """ - assert_eq!(SomeEnum::from("Unknown"), SomeEnum::UnknownValue); - assert_eq!(SomeEnum::from("UnknownValue"), SomeEnum::UnknownValue_); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned()))); - """.trimIndent(), - ) + /// Some top-level documentation. + /// + /// _Note: `SomeEnum::Unknown` has been renamed to `::UnknownValue`._ + """.trimIndent() } project.compileAndTest() } @Test - fun `it should generate documentation for enums`() { + fun `it should generate documentation for unnamed enums`() { val model = """ namespace test /// Some top-level documentation. @enum([ - { name: "Known", value: "Known" }, - { name: "Unknown", value: "Unknown" }, + { value: "One" }, + { value: "Two" }, ]) string SomeEnum """.asSmithyModel() val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + renderEnum(model, provider, shape) val rendered = toString() rendered shouldContain """ /// Some top-level documentation. - /// - /// _Note: `SomeEnum::Unknown` has been renamed to `::UnknownValue`._ """.trimIndent() } project.compileAndTest() } @Test - fun `it should generate documentation for unnamed enums`() { + fun `it handles variants that clash with Rust reserved words`() { val model = """ namespace test + @enum([ + { name: "Known", value: "Known" }, + { name: "Self", value: "other" }, + ]) + string SomeEnum + """.asSmithyModel() - /// Some top-level documentation. + val shape = model.lookup("test#SomeEnum") + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + renderEnum(model, provider, shape) + unitTest( + "it_handles_variants_that_clash_with_rust_reserved_words", + """assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue);""", + ) + } + project.compileAndTest() + } + + @Test + fun `impl debug for non-sensitive enum should implement the derived debug trait`() { + val model = """ + namespace test @enum([ - { value: "One" }, - { value: "Two" }, + { name: "Foo", value: "Foo" }, + { name: "Bar", value: "Bar" }, ]) string SomeEnum """.asSmithyModel() val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - val rendered = toString() - rendered shouldContain + renderEnum(model, provider, shape) + unitTest( + "impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait", """ - /// Some top-level documentation. - """.trimIndent() + assert_eq!(format!("{:?}", SomeEnum::Foo), "Foo"); + assert_eq!(format!("{:?}", SomeEnum::Bar), "Bar"); + """, + ) } project.compileAndTest() } - } - @Test - fun `it handles variants that clash with Rust reserved words`() { - val model = """ - namespace test - @enum([ - { name: "Known", value: "Known" }, - { name: "Self", value: "other" }, - ]) - string SomeEnum - """.asSmithyModel() + @Test + fun `impl debug for sensitive enum should redact text`() { + val model = """ + namespace test + @sensitive + @enum([ + { name: "Foo", value: "Foo" }, + { name: "Bar", value: "Bar" }, + ]) + string SomeEnum + """.asSmithyModel() - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "it_handles_variants_that_clash_with_rust_reserved_words", - """ - assert_eq!(SomeEnum::from("other"), SomeEnum::SelfValue); - assert_eq!(SomeEnum::from("SomethingNew"), SomeEnum::Unknown(crate::types::UnknownVariantValue("SomethingNew".to_owned()))); - """.trimIndent(), - ) + val shape = model.lookup("test#SomeEnum") + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + renderEnum(model, provider, shape) + unitTest( + "impl_debug_for_sensitive_enum_should_redact_text", + """ + assert_eq!(format!("{:?}", SomeEnum::Foo), $REDACTION); + assert_eq!(format!("{:?}", SomeEnum::Bar), $REDACTION); + """, + ) + } + project.compileAndTest() } - project.compileAndTest() - } - @Test - fun `matching on enum should be forward-compatible`() { - fun expectMatchExpressionCompiles(model: Model, shapeId: String, enumToMatchOn: String) { - val shape = model.lookup(shapeId) - val trait = shape.expectTrait() + @Test + fun `impl debug for non-sensitive unnamed enum should implement the derived debug trait`() { + val model = """ + namespace test + @enum([ + { value: "Foo" }, + { value: "Bar" }, + ]) + string SomeEnum + """.asSmithyModel() + + val shape = model.lookup("test#SomeEnum") val provider = testSymbolProvider(model) val project = TestWorkspace.testProject(provider) project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() + renderEnum(model, provider, shape) unitTest( - "matching_on_enum_should_be_forward_compatible", + "impl_debug_for_non_sensitive_unnamed_enum_should_implement_the_derived_debug_trait", """ - match $enumToMatchOn { - SomeEnum::Variant1 => assert!(false, "expected `Variant3` but got `Variant1`"), - SomeEnum::Variant2 => assert!(false, "expected `Variant3` but got `Variant2`"), - other @ _ if other.as_str() == "Variant3" => assert!(true), - _ => assert!(false, "expected `Variant3` but got `_`"), + for variant in SomeEnum::values() { + assert_eq!( + format!("{:?}", SomeEnum(variant.to_string())), + format!("SomeEnum(\"{}\")", variant.to_owned()) + ); } - """.trimIndent(), + """, ) } project.compileAndTest() } - val modelV1 = """ - namespace test - - @enum([ - { name: "Variant1", value: "Variant1" }, - { name: "Variant2", value: "Variant2" }, - ]) - string SomeEnum - """.asSmithyModel() - val variant3AsUnknown = """SomeEnum::from("Variant3")""" - expectMatchExpressionCompiles(modelV1, "test#SomeEnum", variant3AsUnknown) - - val modelV2 = """ - namespace test - - @enum([ - { name: "Variant1", value: "Variant1" }, - { name: "Variant2", value: "Variant2" }, - { name: "Variant3", value: "Variant3" }, - ]) - string SomeEnum - """.asSmithyModel() - val variant3AsVariant3 = "SomeEnum::Variant3" - expectMatchExpressionCompiles(modelV2, "test#SomeEnum", variant3AsVariant3) - } - - @Test - fun `impl debug for non-sensitive enum should implement the derived debug trait`() { - val model = """ - namespace test - @enum([ - { name: "Foo", value: "Foo" }, - { name: "Bar", value: "Bar" }, - ]) - string SomeEnum - """.asSmithyModel() + @Test + fun `impl debug for sensitive unnamed enum should redact text`() { + val model = """ + namespace test + @sensitive + @enum([ + { value: "Foo" }, + { value: "Bar" }, + ]) + string SomeEnum + """.asSmithyModel() - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "impl_debug_for_non_sensitive_enum_should_implement_the_derived_debug_trait", - """ - assert_eq!(format!("{:?}", SomeEnum::Foo), "Foo"); - assert_eq!(format!("{:?}", SomeEnum::Bar), "Bar"); - assert_eq!( - format!("{:?}", SomeEnum::from("Baz")), - "Unknown(UnknownVariantValue(\"Baz\"))" - ); - """, - ) + val shape = model.lookup("test#SomeEnum") + val provider = testSymbolProvider(model) + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + renderEnum(model, provider, shape) + unitTest( + "impl_debug_for_sensitive_unnamed_enum_should_redact_text", + """ + for variant in SomeEnum::values() { + assert_eq!( + format!("{:?}", SomeEnum(variant.to_string())), + $REDACTION + ); + } + """, + ) + } + project.compileAndTest() } - project.compileAndTest() - } - @Test - fun `impl debug for sensitive enum should redact text`() { - val model = """ - namespace test - @sensitive - @enum([ - { name: "Foo", value: "Foo" }, - { name: "Bar", value: "Bar" }, - ]) - string SomeEnum - """.asSmithyModel() - - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "impl_debug_for_sensitive_enum_should_redact_text", - """ - assert_eq!(format!("{:?}", SomeEnum::Foo), $REDACTION); - assert_eq!(format!("{:?}", SomeEnum::Bar), $REDACTION); - """, - ) - } - project.compileAndTest() - } + @Test + fun `it supports other enum types`() { + class CustomizingEnumType : EnumType() { + override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { + // intentional no-op + } - @Test - fun `impl debug for non-sensitive unnamed enum should implement the derived debug trait`() { - val model = """ - namespace test - @enum([ - { value: "Foo" }, - { value: "Bar" }, - ]) - string SomeEnum - """.asSmithyModel() + override fun implFromStr(context: EnumGeneratorContext): Writable = writable { + // intentional no-op + } - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "impl_debug_for_non_sensitive_unnamed_enum_should_implement_the_derived_debug_trait", - """ - for variant in SomeEnum::values() { - assert_eq!( - format!("{:?}", SomeEnum(variant.to_string())), - format!("SomeEnum(\"{}\")", variant.to_owned()) - ); + override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable { + rust("// additional enum members") } - """, - ) - } - project.compileAndTest() - } - @Test - fun `impl debug for sensitive unnamed enum should redact text`() { - val model = """ - namespace test - @sensitive - @enum([ - { value: "Foo" }, - { value: "Bar" }, - ]) - string SomeEnum - """.asSmithyModel() + override fun additionalAsStrMatchArms(context: EnumGeneratorContext): Writable = writable { + rust("// additional as_str match arm") + } - val shape = model.lookup("test#SomeEnum") - val trait = shape.expectTrait() - val provider = testSymbolProvider(model) - val project = TestWorkspace.testProject(provider) - project.withModule(RustModule.Model) { - val generator = EnumGenerator(model, provider, this, shape, trait) - generator.render() - unitTest( - "impl_debug_for_sensitive_unnamed_enum_should_redact_text", - """ - for variant in SomeEnum::values() { - assert_eq!( - format!("{:?}", SomeEnum(variant.to_string())), - $REDACTION - ); + override fun additionalDocs(context: EnumGeneratorContext): Writable = writable { + rust("// additional docs") } - """, - ) + } + + val model = """ + namespace test + @enum([ + { name: "Known", value: "Known" }, + { name: "Self", value: "other" }, + ]) + string SomeEnum + """.asSmithyModel() + val shape = model.lookup("test#SomeEnum") + + val provider = testSymbolProvider(model) + val output = RustWriter.root().apply { + renderEnum(model, provider, shape, CustomizingEnumType()) + }.toString() + + // Since we didn't use the Infallible EnumType, there should be no Unknown variant + output shouldNotContain "Unknown" + output shouldNotContain "unknown" + output shouldNotContain "impl From" + output shouldNotContain "impl FromStr" + output shouldContain "// additional enum members" + output shouldContain "// additional as_str match arm" + output shouldContain "// additional docs" + + val project = TestWorkspace.testProject(provider) + project.withModule(RustModule.Model) { + renderEnum(model, provider, shape, CustomizingEnumType()) + } + project.compileAndTest() } - project.compileAndTest() } } diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt new file mode 100644 index 0000000000..e8ea12c0db --- /dev/null +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/TestEnumType.kt @@ -0,0 +1,49 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.rust.codegen.core.smithy.generators + +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.util.dq + +object TestEnumType : EnumType() { + override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { + rustTemplate( + """ + impl #{From}<&str> for ${context.enumName} { + fn from(s: &str) -> Self { + match s { + #{matchArms} + } + } + } + """, + "From" to RuntimeType.From, + "matchArms" to writable { + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => ${context.enumName}::${member.derivedName()},") + } + rust("_ => panic!()") + }, + ) + } + + override fun implFromStr(context: EnumGeneratorContext): Writable = writable { + rust( + """ + impl std::str::FromStr for ${context.enumName} { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> std::result::Result { + Ok(${context.enumName}::from(s)) + } + } + """, + ) + } +} diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt index 9bd71a7a4e..36bb07b9a4 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/JsonParserGeneratorTest.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver @@ -26,7 +27,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -199,7 +199,7 @@ class JsonParserGeneratorTest { model.lookup("test#EmptyStruct").renderWithModelBuilder(model, symbolProvider, this) UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) } project.withModule(RustModule.public("output")) { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt index 50fb343d6d..d9ce8c5828 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/XmlBindingTraitParserGeneratorTest.kt @@ -12,6 +12,7 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbolFn import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer @@ -24,7 +25,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -200,7 +200,7 @@ internal class XmlBindingTraitParserGeneratorTest { model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) } project.withModule(RustModule.public("output")) { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt index f8ef938bea..fe8160504a 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/AwsQuerySerializerGeneratorTest.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer @@ -22,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup @@ -137,7 +137,7 @@ class AwsQuerySerializerGeneratorTest { model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice"), renderUnknownVariant = generateUnknownVariant).render() val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) } project.withModule(RustModule.public("input")) { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt index b3a21898ee..67ebaff51f 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/Ec2QuerySerializerGeneratorTest.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer @@ -21,7 +22,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup @@ -130,7 +130,7 @@ class Ec2QuerySerializerGeneratorTest { model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) } project.withModule(RustModule.public("input")) { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt index bf3fb604da..5aea8c4c8a 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/JsonSerializerGeneratorTest.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolContentTypes @@ -24,7 +25,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup @@ -146,7 +146,7 @@ class JsonSerializerGeneratorTest { model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) } project.withModule(RustModule.public("input")) { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt index f8f9aafa7b..b2040b8527 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/XmlBindingTraitSerializerGeneratorTest.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.TestEnumType import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpTraitHttpBindingResolver import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolContentTypes @@ -23,7 +24,6 @@ import software.amazon.smithy.rust.codegen.core.testutil.renderWithModelBuilder import software.amazon.smithy.rust.codegen.core.testutil.testCodegenContext import software.amazon.smithy.rust.codegen.core.testutil.testSymbolProvider import software.amazon.smithy.rust.codegen.core.testutil.unitTest -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.lookup @@ -150,7 +150,7 @@ internal class XmlBindingTraitSerializerGeneratorTest { model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, this) UnionGenerator(model, symbolProvider, this, model.lookup("test#Choice")).render() val enum = model.lookup("test#FooEnum") - EnumGenerator(model, symbolProvider, this, enum, enum.expectTrait()).render() + EnumGenerator(model, symbolProvider, enum, TestEnumType).render(this) } project.withModule(RustModule.public("input")) { diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt index d689298ca1..27442578af 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCodegenVisitor.kt @@ -17,7 +17,6 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.rust.codegen.core.rustlang.RustModule -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig @@ -131,8 +130,8 @@ class PythonServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - fun pythonServerEnumGeneratorFactory(codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) = - PythonServerEnumGenerator(codegenContext, writer, shape, validationExceptionConversionGenerator) + fun pythonServerEnumGeneratorFactory(codegenContext: ServerCodegenContext, shape: StringShape) = + PythonServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) stringShape(shape, ::pythonServerEnumGeneratorFactory) } diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt index 55a4a083a1..ac12cc0df3 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerEnumGenerator.kt @@ -7,16 +7,17 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.rust.codegen.core.rustlang.Attribute -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGeneratorContext import software.amazon.smithy.rust.codegen.core.util.dq import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext -import software.amazon.smithy.rust.codegen.server.smithy.generators.ServerEnumGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConstrainedEnum import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator /** @@ -24,30 +25,21 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationEx * This class generates enums definitions, implements the `PyClass` trait and adds * some utility functions like `__str__()` and `__repr__()`. */ -class PythonServerEnumGenerator( +class PythonConstrainedEnum( codegenContext: ServerCodegenContext, - private val writer: RustWriter, shape: StringShape, validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, -) : ServerEnumGenerator(codegenContext, writer, shape, validationExceptionConversionGenerator) { - +) : ConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator) { private val pyO3 = PythonServerCargoDependency.PyO3.toType() - override fun render() { - renderPyClass() - super.render() - renderPyO3Methods() - } - - private fun renderPyClass() { - Attribute(pyO3.resolve("pyclass")).render(writer) - } + override fun additionalEnumAttributes(context: EnumGeneratorContext): List = + listOf(Attribute(pyO3.resolve("pyclass"))) - private fun renderPyO3Methods() { - Attribute(pyO3.resolve("pymethods")).render(writer) - writer.rustTemplate( + override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = writable { + Attribute(pyO3.resolve("pymethods")).render(this) + rustTemplate( """ - impl $enumName { + impl ${context.enumName} { #{name_method:W} ##[getter] pub fn value(&self) -> &str { @@ -61,11 +53,11 @@ class PythonServerEnumGenerator( } } """, - "name_method" to renderPyEnumName(), + "name_method" to pyEnumName(context), ) } - private fun renderPyEnumName(): Writable = + private fun pyEnumName(context: EnumGeneratorContext): Writable = writable { rustBlock( """ @@ -74,11 +66,22 @@ class PythonServerEnumGenerator( """, ) { rustBlock("match self") { - sortedMembers.forEach { member -> + context.sortedMembers.forEach { member -> val memberName = member.name()?.name - rust("""$enumName::$memberName => ${memberName?.dq()},""") + rust("""${context.enumName}::$memberName => ${memberName?.dq()},""") } } } } } + +class PythonServerEnumGenerator( + codegenContext: ServerCodegenContext, + shape: StringShape, + validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, +) : EnumGenerator( + codegenContext.model, + codegenContext.symbolProvider, + shape, + PythonConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator), +) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index def3e4a611..a5a13820e9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -39,6 +39,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.ModelsModule import software.amazon.smithy.rust.codegen.core.smithy.RustCrate import software.amazon.smithy.rust.codegen.core.smithy.SymbolVisitorConfig import software.amazon.smithy.rust.codegen.core.smithy.UnconstrainedModule +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.error.eventStreamErrorSymbol @@ -404,8 +405,8 @@ open class ServerCodegenVisitor( * Although raw strings require no code generation, enums are actually [EnumTrait] applied to string shapes. */ override fun stringShape(shape: StringShape) { - fun serverEnumGeneratorFactory(codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) = - ServerEnumGenerator(codegenContext, writer, shape, validationExceptionConversionGenerator) + fun serverEnumGeneratorFactory(codegenContext: ServerCodegenContext, shape: StringShape) = + ServerEnumGenerator(codegenContext, shape, validationExceptionConversionGenerator) stringShape(shape, ::serverEnumGeneratorFactory) } @@ -424,12 +425,12 @@ open class ServerCodegenVisitor( protected fun stringShape( shape: StringShape, - enumShapeGeneratorFactory: (codegenContext: ServerCodegenContext, writer: RustWriter, shape: StringShape) -> ServerEnumGenerator, + enumShapeGeneratorFactory: (codegenContext: ServerCodegenContext, shape: StringShape) -> EnumGenerator, ) { if (shape.hasTrait()) { logger.info("[rust-server-codegen] Generating an enum $shape") rustCrate.useShapeWriter(shape) { - enumShapeGeneratorFactory(codegenContext, this, shape).render() + enumShapeGeneratorFactory(codegenContext, shape).render(this) ConstrainedTraitForEnumGenerator(model, codegenContext.symbolProvider, this, shape).render() } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt index 82c040f66b..88ca5e4fef 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt @@ -5,28 +5,26 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.StringShape -import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGeneratorContext +import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumType import software.amazon.smithy.rust.codegen.core.smithy.module import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.server.smithy.PubCrateConstraintViolationSymbolProvider import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.traits.isReachableFromOperationInput -open class ServerEnumGenerator( - val codegenContext: ServerCodegenContext, - private val writer: RustWriter, - shape: StringShape, +open class ConstrainedEnum( + codegenContext: ServerCodegenContext, + private val shape: StringShape, private val validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, -) : EnumGenerator(codegenContext.model, codegenContext.symbolProvider, writer, shape, shape.expectTrait()) { - override var target: CodegenTarget = CodegenTarget.SERVER - +) : EnumType() { private val publicConstrainedTypes = codegenContext.settings.codegenConfig.publicConstrainedTypes private val constraintViolationSymbolProvider = with(codegenContext.constraintViolationSymbolProvider) { @@ -42,8 +40,8 @@ open class ServerEnumGenerator( "String" to RuntimeType.String, ) - override fun renderFromForStr() { - writer.withInlineModule(constraintViolationSymbol.module()) { + override fun implFromForStr(context: EnumGeneratorContext): Writable = writable { + withInlineModule(constraintViolationSymbol.module()) { rustTemplate( """ ##[derive(Debug, PartialEq)] @@ -59,25 +57,27 @@ open class ServerEnumGenerator( #{EnumShapeConstraintViolationImplBlock:W} } """, - "EnumShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(enumTrait), + "EnumShapeConstraintViolationImplBlock" to validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock( + context.enumTrait, + ), ) } } - writer.rustBlock("impl #T<&str> for $enumName", RuntimeType.TryFrom) { + rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) { rust("type Error = #T;", constraintViolationSymbol) rustBlock("fn try_from(s: &str) -> Result>::Error>", RuntimeType.TryFrom) { rustBlock("match s") { - sortedMembers.forEach { member -> - rust("${member.value.dq()} => Ok($enumName::${member.derivedName()}),") + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),") } rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol) } } } - writer.rustTemplate( + rustTemplate( """ - impl #{TryFrom}<#{String}> for $enumName { - type Error = #{UnknownVariantSymbol}; + impl #{TryFrom}<#{String}> for ${context.enumName} { + type Error = #{ConstraintViolation}; fn try_from(s: #{String}) -> std::result::Result>::Error> { s.as_str().try_into() } @@ -85,21 +85,32 @@ open class ServerEnumGenerator( """, "String" to RuntimeType.String, "TryFrom" to RuntimeType.TryFrom, - "UnknownVariantSymbol" to constraintViolationSymbol, + "ConstraintViolation" to constraintViolationSymbol, ) } - override fun renderFromStr() { - writer.rustTemplate( + override fun implFromStr(context: EnumGeneratorContext): Writable = writable { + rustTemplate( """ - impl std::str::FromStr for $enumName { - type Err = #{UnknownVariantSymbol}; + impl std::str::FromStr for ${context.enumName} { + type Err = #{ConstraintViolation}; fn from_str(s: &str) -> std::result::Result::Err> { Self::try_from(s) } } """, - "UnknownVariantSymbol" to constraintViolationSymbol, + "ConstraintViolation" to constraintViolationSymbol, ) } } + +class ServerEnumGenerator( + codegenContext: ServerCodegenContext, + shape: StringShape, + validationExceptionConversionGenerator: ValidationExceptionConversionGenerator, +) : EnumGenerator( + codegenContext.model, + codegenContext.symbolProvider, + shape, + enumType = ConstrainedEnum(codegenContext, shape, validationExceptionConversionGenerator), +) diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt index dc18daf855..ef29d6a6e8 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerBuilderDefaultValuesTest.kt @@ -185,10 +185,9 @@ class ServerBuilderDefaultValuesTest { ServerEnumGenerator( codegenContext, - writer, model.lookup("com.test#Language"), SmithyValidationExceptionConversionGenerator(codegenContext), - ).render() + ).render(writer) StructureGenerator(model, symbolProvider, writer, struct).render() } @@ -204,10 +203,9 @@ class ServerBuilderDefaultValuesTest { ServerEnumGenerator( codegenContext, - writer, model.lookup("com.test#Language"), SmithyValidationExceptionConversionGenerator(codegenContext), - ).render() + ).render(writer) StructureGenerator(model, symbolProvider, writer, struct).render() } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt index abef4485aa..bffb82bb40 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGeneratorTest.kt @@ -43,10 +43,9 @@ class ServerEnumGeneratorTest { fun `it generates TryFrom, FromStr and errors for enums`() { ServerEnumGenerator( codegenContext, - writer, shape, SmithyValidationExceptionConversionGenerator(codegenContext), - ).render() + ).render(writer) writer.compileAndTest( """ use std::str::FromStr; @@ -61,10 +60,9 @@ class ServerEnumGeneratorTest { fun `it generates enums without the unknown variant`() { ServerEnumGenerator( codegenContext, - writer, shape, SmithyValidationExceptionConversionGenerator(codegenContext), - ).render() + ).render(writer) writer.compileAndTest( """ // Check no `Unknown` variant. @@ -81,10 +79,9 @@ class ServerEnumGeneratorTest { fun `it generates enums without non_exhaustive`() { ServerEnumGenerator( codegenContext, - writer, shape, SmithyValidationExceptionConversionGenerator(codegenContext), - ).render() + ).render(writer) writer.toString() shouldNotContain "#[non_exhaustive]" } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt index 96ea000cdb..851a44e00b 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerInstantiatorTest.kt @@ -13,7 +13,6 @@ import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.rust.codegen.core.rustlang.RustModule import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.withBlock -import software.amazon.smithy.rust.codegen.core.smithy.generators.EnumGenerator import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer import software.amazon.smithy.rust.codegen.core.testutil.TestWorkspace @@ -21,8 +20,8 @@ import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.testutil.compileAndTest import software.amazon.smithy.rust.codegen.core.testutil.unitTest import software.amazon.smithy.rust.codegen.core.util.dq -import software.amazon.smithy.rust.codegen.core.util.expectTrait import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.customizations.SmithyValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverRenderWithModelBuilder import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext @@ -191,7 +190,11 @@ class ServerInstantiatorTest { val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { - EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + ServerEnumGenerator( + codegenContext, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render(this) unitTest("generate_named_enums") { withBlock("let result = ", ";") { sut.render(this, shape, data) @@ -210,7 +213,11 @@ class ServerInstantiatorTest { val project = TestWorkspace.testProject() project.withModule(RustModule.Model) { - EnumGenerator(model, symbolProvider, this, shape, shape.expectTrait()).render() + ServerEnumGenerator( + codegenContext, + shape, + SmithyValidationExceptionConversionGenerator(codegenContext), + ).render(this) unitTest("generate_unnamed_enums") { withBlock("let result = ", ";") { sut.render(this, shape, data)