From 651538a3f256afa3100721f3083d77d742aabdbc Mon Sep 17 00:00:00 2001 From: Fahad Zubair Date: Thu, 17 Oct 2024 11:41:36 +0100 Subject: [PATCH] Allow unnamed enum implementation to override `FromStr` and `FromForStr` methods --- .../smithy/generators/ClientEnumGenerator.kt | 31 ++++++ .../core/smithy/generators/EnumGenerator.kt | 36 ++---- .../smithy/generators/EnumGeneratorTest.kt | 10 ++ .../smithy/generators/ServerEnumGenerator.kt | 104 ++++++++++++------ .../codegen/server/smithy/ConstraintsTest.kt | 86 +++++++++++++++ 5 files changed, 208 insertions(+), 59 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt index 77a5731b62..f45c19d74b 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/ClientEnumGenerator.kt @@ -79,6 +79,37 @@ data class InfallibleEnumType( ) } + override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + rustTemplate( + """ + impl #{From} for ${context.enumName} where T: #{AsRef} { + fn from(s: T) -> Self { + ${context.enumName}(s.as_ref().to_owned()) + } + } + """, + *preludeScope, + ) + } + + override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + // Add an infallible FromStr implementation for uniformity + rustTemplate( + """ + impl ::std::str::FromStr for ${context.enumName} { + type Err = ::std::convert::Infallible; + + fn from_str(s: &str) -> #{Result}::Err> { + #{Ok}(${context.enumName}::from(s)) + } + } + """, + *preludeScope, + ) + } + override fun additionalEnumImpls(context: EnumGeneratorContext): Writable = writable { // `try_parse` isn't needed for unnamed enums diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt index 1f05ad7aab..d1eac1488e 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGenerator.kt @@ -59,6 +59,12 @@ abstract class EnumType { /** Returns a writable that implements `FromStr` for the enum */ abstract fun implFromStr(context: EnumGeneratorContext): Writable + /** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the unnamed enum */ + abstract fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable + + /** Returns a writable that implements `FromStr` for the unnamed enum */ + abstract fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable + /** Optionally adds additional documentation to the `enum` docs */ open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {} @@ -237,32 +243,10 @@ open class EnumGenerator( rust("&self.0") }, ) - - // Add an infallible FromStr implementation for uniformity - rustTemplate( - """ - impl ::std::str::FromStr for ${context.enumName} { - type Err = ::std::convert::Infallible; - - fn from_str(s: &str) -> #{Result}::Err> { - #{Ok}(${context.enumName}::from(s)) - } - } - """, - *preludeScope, - ) - - rustTemplate( - """ - impl #{From} for ${context.enumName} where T: #{AsRef} { - fn from(s: T) -> Self { - ${context.enumName}(s.as_ref().to_owned()) - } - } - - """, - *preludeScope, - ) + // impl From for Blah { ... } + enumType.implFromForStrForUnnamedEnum(context)(this) + // impl FromStr for Blah { ... } + enumType.implFromStrForUnnamedEnum(context)(this) } private fun RustWriter.renderEnum() { diff --git a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt index 0e2b10788f..0528c2e364 100644 --- a/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt +++ b/codegen-core/src/test/kotlin/software/amazon/smithy/rust/codegen/core/smithy/generators/EnumGeneratorTest.kt @@ -494,6 +494,16 @@ class EnumGeneratorTest { // intentional no-op } + override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + // intentional no-op + } + + override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + writable { + // intentional no-op + } + override fun additionalEnumMembers(context: EnumGeneratorContext): Writable = writable { rust("// additional enum members") diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt index 5bc2218ad1..a3cc269692 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerEnumGenerator.kt @@ -5,10 +5,9 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter import software.amazon.smithy.rust.codegen.core.rustlang.Writable import software.amazon.smithy.rust.codegen.core.rustlang.rust -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock -import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType @@ -39,16 +38,14 @@ open class ConstrainedEnum( } private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape) private val constraintViolationName = constraintViolationSymbol.name - private val codegenScope = - arrayOf( - "String" to RuntimeType.String, - ) - override fun implFromForStr(context: EnumGeneratorContext): Writable = - writable { - withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) { - rustTemplate( - """ + private fun generateConstraintViolation( + context: EnumGeneratorContext, + generateTryFromStrAndString: RustWriter.(EnumGeneratorContext) -> Unit, + ) = writable { + withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) { + rustTemplate( + """ ##[derive(Debug, PartialEq)] pub struct $constraintViolationName(pub(crate) #{String}); @@ -60,47 +57,86 @@ open class ConstrainedEnum( impl #{Error} for $constraintViolationName {} """, - *codegenScope, - "Error" to RuntimeType.StdError, - "Display" to RuntimeType.Display, - ) + *preludeScope, + "Error" to RuntimeType.StdError, + "Display" to RuntimeType.Display, + ) - if (shape.isReachableFromOperationInput()) { - rustTemplate( - """ + if (shape.isReachableFromOperationInput()) { + rustTemplate( + """ impl $constraintViolationName { #{EnumShapeConstraintViolationImplBlock:W} } """, - "EnumShapeConstraintViolationImplBlock" to - validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock( - context.enumTrait, - ), - ) - } + "EnumShapeConstraintViolationImplBlock" to + validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock( + context.enumTrait, + ), + ) } - rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) { - rust("type Error = #T;", constraintViolationSymbol) - rustBlockTemplate("fn try_from(s: &str) -> #{Result}>::Error>", *preludeScope) { - rustBlock("match s") { - context.sortedMembers.forEach { member -> - rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),") + } + + generateTryFromStrAndString(context) + } + + override fun implFromForStr(context: EnumGeneratorContext): Writable = + generateConstraintViolation(context) { + rustTemplate( + """ + impl #{TryFrom}<&str> for ${context.enumName} { + type Error = #{ConstraintViolation}; + fn try_from(s: &str) -> #{Result}>::Error> { + match s { + #{MatchArms} + _ => Err(#{ConstraintViolation}(s.to_owned())) } - rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol) } } - } + impl #{TryFrom}<#{String}> for ${context.enumName} { + type Error = #{ConstraintViolation}; + fn try_from(s: #{String}) -> #{Result}>::Error> { + s.as_str().try_into() + } + } + """, + *preludeScope, + "ConstraintViolation" to constraintViolationSymbol, + "MatchArms" to + writable { + context.sortedMembers.forEach { member -> + rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),") + } + }, + ) + } + + override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable = + generateConstraintViolation(context) { rustTemplate( """ + impl #{TryFrom}<&str> for ${context.enumName} { + type Error = #{ConstraintViolation}; + fn try_from(s: &str) -> #{Result}>::Error> { + s.to_owned().try_into() + } + } impl #{TryFrom}<#{String}> for ${context.enumName} { type Error = #{ConstraintViolation}; fn try_from(s: #{String}) -> #{Result}>::Error> { - s.as_str().try_into() + match s.as_str() { + #{Values} => Ok(Self(s)), + _ => Err(#{ConstraintViolation}(s)) + } } } """, *preludeScope, "ConstraintViolation" to constraintViolationSymbol, + "Values" to + writable { + rust(context.sortedMembers.joinToString(" | ") { it.value.dq() }) + }, ) } @@ -118,6 +154,8 @@ open class ConstrainedEnum( "ConstraintViolation" to constraintViolationSymbol, ) } + + override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext) = implFromStr(context) } class ServerEnumGenerator( diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt index 31123ef811..1fb71f8315 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ConstraintsTest.kt @@ -25,8 +25,10 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.AbstractTrait import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.protocol.traits.Rpcv2CborTrait +import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel import software.amazon.smithy.rust.codegen.core.util.lookup +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider import java.io.File @@ -219,4 +221,88 @@ class ConstraintsTest { structWithInnerDefault.canReachConstrainedShape(model, symbolProvider) shouldBe false primitiveBoolean.isDirectlyConstrained(symbolProvider) shouldBe false } + + @Test + fun `unnamed enum should have ConstraintViolation`() { + val model = + """ + namespace test + use aws.protocols#restJson1 + use smithy.framework#ValidationException + + @restJson1 + service SampleService { + operations: [SampleOp] + } + + @http(uri: "/dailySummary", method: "POST") + operation SampleOp { + input := { + day: WeeklySummary + } + errors: [ValidationException] + } + + structure WeeklySummary { + day: DayOfWeek, + } + + @enum([ + { value: "MONDAY" }, + { value: "TUESDAY" } + ]) + string DayOfWeek + """.asSmithyModel(smithyVersion = "2") + + // Simply compiling the crate is sufficient as a test. + serverIntegrationTest( + model, + IntegrationTestParams( + service = "test#SampleService", + ), + ) { _, _ -> + } + } + + @Test + fun `named enum should have ConstraintViolation`() { + val model = + """ + namespace test + use aws.protocols#restJson1 + use smithy.framework#ValidationException + + @restJson1 + service SampleService { + operations: [SampleOp] + } + + @http(uri: "/dailySummary", method: "POST") + operation SampleOp { + input := { + day: WeeklySummary + } + errors: [ValidationException] + } + + structure WeeklySummary { + day: DayOfWeek, + } + + @enum([ + { value: "MONDAY", name: "MONDAY" }, + { value: "TUESDAY", name: "TUESDAY" } + ]) + string DayOfWeek + """.asSmithyModel(smithyVersion = "2") + + // Simply compiling the crate is sufficient as a test. + serverIntegrationTest( + model, + IntegrationTestParams( + service = "test#SampleService", + ), + ) { _, _ -> + } + } }