Skip to content

Commit

Permalink
Enforce constraints for unnamed enums (#3884)
Browse files Browse the repository at this point in the history
### Enforces Constraints for Unnamed Enums

This PR addresses the issue where, on the server side, unnamed enums
were incorrectly treated as infallible during deserialization, allowing
any string value to be converted without validation. The solution
introduces a `ConstraintViolation` and `TryFrom` implementation for
unnamed enums, ensuring that deserialized values conform to the enum
variants defined in the Smithy model.

The following is an example of an unnamed enum:

```smithy
@enum([
    { value: "MONDAY" },
    { value: "TUESDAY" }
])
string UnnamedDayOfWeek
```

On the server side the following type is generated for the Smithy shape:

```rust
pub struct UnnamedDayOfWeek(String);

impl ::std::convert::TryFrom<::std::string::String> for UnnamedDayOfWeek {
    type Error = crate::model::unnamed_day_of_week::ConstraintViolation;

    fn try_from(
        s: ::std::string::String,
    ) -> ::std::result::Result<Self, <Self as ::std::convert::TryFrom<::std::string::String>>::Error>
    {
        match s.as_str() {
            "MONDAY" | "TUESDAY" => Ok(Self(s)),
            _ => Err(crate::model::unnamed_day_of_week::ConstraintViolation(s)),
        }
    }
}
```

This change prevents invalid values from being deserialized into unnamed
enums and raises appropriate constraint violations when necessary.

There is one difference between the Rust code generated for
`TryFrom<String>` for named enums versus unnamed enums. The
implementation for unnamed enums passes the ownership of the `String`
parameter to the generated structure, and the implementation for
`TryFrom<&str>` delegates to `TryFrom<String>`.

```rust
impl ::std::convert::TryFrom<::std::string::String> for UnnamedDayOfWeek {
    type Error = crate::model::unnamed_day_of_week::ConstraintViolation;
    fn try_from(
        s: ::std::string::String,
    ) -> ::std::result::Result<Self, <Self as ::std::convert::TryFrom<::std::string::String>>::Error>
    {
        match s.as_str() {
            "MONDAY" | "TUESDAY" => Ok(Self(s)),
            _ => Err(crate::model::unnamed_day_of_week::ConstraintViolation(s)),
        }
    }
}

impl ::std::convert::TryFrom<&str> for UnnamedDayOfWeek {
    type Error = crate::model::unnamed_day_of_week::ConstraintViolation;
    fn try_from(
        s: &str,
    ) -> ::std::result::Result<Self, <Self as ::std::convert::TryFrom<&str>>::Error> {
        s.to_owned().try_into()
    }
}
```

On the client side, the behaviour is unchanged, and the client does not
validate for backward compatibility reasons. An [existing
test](https://github.com/smithy-lang/smithy-rs/pull/3884/files#diff-021ec60146cfe231105d21a7389f2dffcd546595964fbb3f0684ebf068325e48R82)
has been modified to ensure this.

```rust
#[test]
fn generate_unnamed_enums() {
    let result = "t2.nano"
        .parse::<crate::types::UnnamedEnum>()
        .expect("static value validated to member");
    assert_eq!(result, UnnamedEnum("t2.nano".to_owned()));
    let result = "not-a-valid-variant"
        .parse::<crate::types::UnnamedEnum>()
        .expect("static value validated to member");
    assert_eq!(result, UnnamedEnum("not-a-valid-variant".to_owned()));
}
```

Fixes issue #3880

---------

Co-authored-by: Fahad Zubair <fahadzub@amazon.com>
  • Loading branch information
drganjoo and Fahad Zubair authored Nov 14, 2024
1 parent 8cf9ebd commit df77d5f
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 62 deletions.
18 changes: 18 additions & 0 deletions .changelog/4329788.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
applies_to: ["server"]
authors: ["drganjoo"]
references: ["smithy-rs#3880"]
breaking: true
new_feature: false
bug_fix: true
---
Unnamed enums now validate assigned values and will raise a `ConstraintViolation` if an unknown variant is set.

The following is an example of an unnamed enum:
```smithy
@enum([
{ value: "MONDAY" },
{ value: "TUESDAY" }
])
string UnnamedDayOfWeek
```
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,37 @@ data class InfallibleEnumType(
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}
""",
*preludeScope,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;
fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)
}

override fun additionalEnumImpls(context: EnumGeneratorContext): Writable =
writable {
// `try_parse` isn't needed for unnamed enums
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ internal class ClientInstantiatorTest {
val shape = model.lookup<StringShape>("com.test#UnnamedEnum")
val sut = ClientInstantiator(codegenContext)
val data = Node.parse("t2.nano".dq())
// The client SDK should accept unknown variants as valid.
val notValidVariant = Node.parse("not-a-valid-variant".dq())

val project = TestWorkspace.testProject(symbolProvider)
project.moduleFor(shape) {
Expand All @@ -77,7 +79,11 @@ internal class ClientInstantiatorTest {
withBlock("let result = ", ";") {
sut.render(this, shape, data)
}
rust("""assert_eq!(result, UnnamedEnum("t2.nano".to_owned()));""")
rust("""assert_eq!(result, UnnamedEnum("$data".to_owned()));""")
withBlock("let result = ", ";") {
sut.render(this, shape, notValidVariant)
}
rust("""assert_eq!(result, UnnamedEnum("$notValidVariant".to_owned()));""")
}
}
project.compileAndTest()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ abstract class EnumType {
/** Returns a writable that implements `FromStr` for the enum */
abstract fun implFromStr(context: EnumGeneratorContext): Writable

/** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the unnamed enum */
abstract fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable

/** Returns a writable that implements `FromStr` for the unnamed enum */
abstract fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable

/** Optionally adds additional documentation to the `enum` docs */
open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {}

Expand Down Expand Up @@ -237,32 +243,10 @@ open class EnumGenerator(
rust("&self.0")
},
)

// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;
fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)

rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}
""",
*preludeScope,
)
// impl From<str> for Blah { ... }
enumType.implFromForStrForUnnamedEnum(context)(this)
// impl FromStr for Blah { ... }
enumType.implFromStrForUnnamedEnum(context)(this)
}

private fun RustWriter.renderEnum() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ class EnumGeneratorTest {
// intentional no-op
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// intentional no-op
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// intentional no-op
}

override fun additionalEnumMembers(context: EnumGeneratorContext): Writable =
writable {
rust("// additional enum members")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.util.dq

object TestEnumType : EnumType() {
Expand Down Expand Up @@ -49,4 +50,35 @@ object TestEnumType : EnumType() {
""",
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}
""",
*preludeScope,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;
fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
Expand Down Expand Up @@ -39,16 +38,14 @@ open class ConstrainedEnum(
}
private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
private val constraintViolationName = constraintViolationSymbol.name
private val codegenScope =
arrayOf(
"String" to RuntimeType.String,
)

override fun implFromForStr(context: EnumGeneratorContext): Writable =
writable {
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
rustTemplate(
"""
private fun generateConstraintViolation(
context: EnumGeneratorContext,
generateTryFromStrAndString: RustWriter.(EnumGeneratorContext) -> Unit,
) = writable {
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
rustTemplate(
"""
##[derive(Debug, PartialEq)]
pub struct $constraintViolationName(pub(crate) #{String});
Expand All @@ -60,47 +57,86 @@ open class ConstrainedEnum(
impl #{Error} for $constraintViolationName {}
""",
*codegenScope,
"Error" to RuntimeType.StdError,
"Display" to RuntimeType.Display,
)
*preludeScope,
"Error" to RuntimeType.StdError,
"Display" to RuntimeType.Display,
)

if (shape.isReachableFromOperationInput()) {
rustTemplate(
"""
if (shape.isReachableFromOperationInput()) {
rustTemplate(
"""
impl $constraintViolationName {
#{EnumShapeConstraintViolationImplBlock:W}
}
""",
"EnumShapeConstraintViolationImplBlock" to
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
context.enumTrait,
),
)
}
"EnumShapeConstraintViolationImplBlock" to
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
context.enumTrait,
),
)
}
rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) {
rust("type Error = #T;", constraintViolationSymbol)
rustBlockTemplate("fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error>", *preludeScope) {
rustBlock("match s") {
context.sortedMembers.forEach { member ->
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
}

generateTryFromStrAndString(context)
}

override fun implFromForStr(context: EnumGeneratorContext): Writable =
generateConstraintViolation(context) {
rustTemplate(
"""
impl #{TryFrom}<&str> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
match s {
#{MatchArms}
_ => Err(#{ConstraintViolation}(s.to_owned()))
}
rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol)
}
}
}
impl #{TryFrom}<#{String}> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
s.as_str().try_into()
}
}
""",
*preludeScope,
"ConstraintViolation" to constraintViolationSymbol,
"MatchArms" to
writable {
context.sortedMembers.forEach { member ->
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
}
},
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
generateConstraintViolation(context) {
rustTemplate(
"""
impl #{TryFrom}<&str> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
s.to_owned().try_into()
}
}
impl #{TryFrom}<#{String}> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
s.as_str().try_into()
match s.as_str() {
#{Values} => Ok(Self(s)),
_ => Err(#{ConstraintViolation}(s))
}
}
}
""",
*preludeScope,
"ConstraintViolation" to constraintViolationSymbol,
"Values" to
writable {
rust(context.sortedMembers.joinToString(" | ") { it.value.dq() })
},
)
}

Expand All @@ -118,6 +154,8 @@ open class ConstrainedEnum(
"ConstraintViolation" to constraintViolationSymbol,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext) = implFromStr(context)
}

class ServerEnumGenerator(
Expand Down
Loading

0 comments on commit df77d5f

Please sign in to comment.