Skip to content

Commit

Permalink
Allow unnamed enum implementation to override FromStr and `FromForS…
Browse files Browse the repository at this point in the history
…tr` methods
  • Loading branch information
Fahad Zubair committed Oct 17, 2024
1 parent ef07c88 commit 651538a
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,37 @@ data class InfallibleEnumType(
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}
""",
*preludeScope,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;
fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)
}

override fun additionalEnumImpls(context: EnumGeneratorContext): Writable =
writable {
// `try_parse` isn't needed for unnamed enums
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ abstract class EnumType {
/** Returns a writable that implements `FromStr` for the enum */
abstract fun implFromStr(context: EnumGeneratorContext): Writable

/** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the unnamed enum */
abstract fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable

/** Returns a writable that implements `FromStr` for the unnamed enum */
abstract fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable

/** Optionally adds additional documentation to the `enum` docs */
open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {}

Expand Down Expand Up @@ -237,32 +243,10 @@ open class EnumGenerator(
rust("&self.0")
},
)

// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;
fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)

rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}
""",
*preludeScope,
)
// impl From<str> for Blah { ... }
enumType.implFromForStrForUnnamedEnum(context)(this)
// impl FromStr for Blah { ... }
enumType.implFromStrForUnnamedEnum(context)(this)
}

private fun RustWriter.renderEnum() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ class EnumGeneratorTest {
// intentional no-op
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// intentional no-op
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// intentional no-op
}

override fun additionalEnumMembers(context: EnumGeneratorContext): Writable =
writable {
rust("// additional enum members")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
Expand Down Expand Up @@ -39,16 +38,14 @@ open class ConstrainedEnum(
}
private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
private val constraintViolationName = constraintViolationSymbol.name
private val codegenScope =
arrayOf(
"String" to RuntimeType.String,
)

override fun implFromForStr(context: EnumGeneratorContext): Writable =
writable {
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
rustTemplate(
"""
private fun generateConstraintViolation(
context: EnumGeneratorContext,
generateTryFromStrAndString: RustWriter.(EnumGeneratorContext) -> Unit,
) = writable {
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
rustTemplate(
"""
##[derive(Debug, PartialEq)]
pub struct $constraintViolationName(pub(crate) #{String});
Expand All @@ -60,47 +57,86 @@ open class ConstrainedEnum(
impl #{Error} for $constraintViolationName {}
""",
*codegenScope,
"Error" to RuntimeType.StdError,
"Display" to RuntimeType.Display,
)
*preludeScope,
"Error" to RuntimeType.StdError,
"Display" to RuntimeType.Display,
)

if (shape.isReachableFromOperationInput()) {
rustTemplate(
"""
if (shape.isReachableFromOperationInput()) {
rustTemplate(
"""
impl $constraintViolationName {
#{EnumShapeConstraintViolationImplBlock:W}
}
""",
"EnumShapeConstraintViolationImplBlock" to
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
context.enumTrait,
),
)
}
"EnumShapeConstraintViolationImplBlock" to
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
context.enumTrait,
),
)
}
rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) {
rust("type Error = #T;", constraintViolationSymbol)
rustBlockTemplate("fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error>", *preludeScope) {
rustBlock("match s") {
context.sortedMembers.forEach { member ->
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
}

generateTryFromStrAndString(context)
}

override fun implFromForStr(context: EnumGeneratorContext): Writable =
generateConstraintViolation(context) {
rustTemplate(
"""
impl #{TryFrom}<&str> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
match s {
#{MatchArms}
_ => Err(#{ConstraintViolation}(s.to_owned()))
}
rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol)
}
}
}
impl #{TryFrom}<#{String}> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
s.as_str().try_into()
}
}
""",
*preludeScope,
"ConstraintViolation" to constraintViolationSymbol,
"MatchArms" to
writable {
context.sortedMembers.forEach { member ->
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
}
},
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
generateConstraintViolation(context) {
rustTemplate(
"""
impl #{TryFrom}<&str> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
s.to_owned().try_into()
}
}
impl #{TryFrom}<#{String}> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
s.as_str().try_into()
match s.as_str() {
#{Values} => Ok(Self(s)),
_ => Err(#{ConstraintViolation}(s))
}
}
}
""",
*preludeScope,
"ConstraintViolation" to constraintViolationSymbol,
"Values" to
writable {
rust(context.sortedMembers.joinToString(" | ") { it.value.dq() })
},
)
}

Expand All @@ -118,6 +154,8 @@ open class ConstrainedEnum(
"ConstraintViolation" to constraintViolationSymbol,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext) = implFromStr(context)
}

class ServerEnumGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.AbstractTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
import java.io.File

Expand Down Expand Up @@ -219,4 +221,88 @@ class ConstraintsTest {
structWithInnerDefault.canReachConstrainedShape(model, symbolProvider) shouldBe false
primitiveBoolean.isDirectlyConstrained(symbolProvider) shouldBe false
}

@Test
fun `unnamed enum should have ConstraintViolation`() {
val model =
"""
namespace test
use aws.protocols#restJson1
use smithy.framework#ValidationException
@restJson1
service SampleService {
operations: [SampleOp]
}
@http(uri: "/dailySummary", method: "POST")
operation SampleOp {
input := {
day: WeeklySummary
}
errors: [ValidationException]
}
structure WeeklySummary {
day: DayOfWeek,
}
@enum([
{ value: "MONDAY" },
{ value: "TUESDAY" }
])
string DayOfWeek
""".asSmithyModel(smithyVersion = "2")

// Simply compiling the crate is sufficient as a test.
serverIntegrationTest(
model,
IntegrationTestParams(
service = "test#SampleService",
),
) { _, _ ->
}
}

@Test
fun `named enum should have ConstraintViolation`() {
val model =
"""
namespace test
use aws.protocols#restJson1
use smithy.framework#ValidationException
@restJson1
service SampleService {
operations: [SampleOp]
}
@http(uri: "/dailySummary", method: "POST")
operation SampleOp {
input := {
day: WeeklySummary
}
errors: [ValidationException]
}
structure WeeklySummary {
day: DayOfWeek,
}
@enum([
{ value: "MONDAY", name: "MONDAY" },
{ value: "TUESDAY", name: "TUESDAY" }
])
string DayOfWeek
""".asSmithyModel(smithyVersion = "2")

// Simply compiling the crate is sufficient as a test.
serverIntegrationTest(
model,
IntegrationTestParams(
service = "test#SampleService",
),
) { _, _ ->
}
}
}

0 comments on commit 651538a

Please sign in to comment.