Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve broken protocol test generation #3726

Merged
merged 2 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.BrokenTest
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.FailingTest
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ProtocolTestGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.ServiceShapeId.AWS_JSON_10
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCase
import software.amazon.smithy.rust.codegen.core.smithy.generators.protocol.TestCaseKind
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
Expand Down Expand Up @@ -70,9 +70,9 @@ class ClientProtocolTestGenerator(
private val ExpectFail =
setOf<FailingTest>(
// Failing because we don't serialize default values if they match the default.
FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse", TestCaseKind.Request),
FailingTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults", TestCaseKind.Request),
FailingTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput", TestCaseKind.Request),
FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse"),
FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults"),
FailingTest.RequestTest(AWS_JSON_10, "AwsJson10ClientPopulatesDefaultValuesInInput"),
)
}

Expand All @@ -84,6 +84,8 @@ class ClientProtocolTestGenerator(
get() = emptySet()
override val disabledTests: Set<String>
get() = emptySet()
override val brokenTests: Set<BrokenTest>
get() = emptySet()

override val logger: Logger = Logger.getLogger(javaClass.name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.orNull
Expand All @@ -51,9 +52,17 @@ abstract class ProtocolTestGenerator {
/**
* We expect these tests to fail due to shortcomings in our implementation.
* They will _fail_ if they pass, so we will discover and remove them if we fix them by accident.
**/
*/
abstract val expectFail: Set<FailingTest>

/**
* We expect these tests to fail because their definitions are broken.
* We map from a failing test to a "hotfix" function that can mutate the test in-memory and return a fixed version of it.
* The tests will _fail_ if they pass, so we will discover and remove the hotfix if we're updating to a newer
* version of Smithy where the test was fixed upstream.
*/
abstract val brokenTests: Set<BrokenTest>

/** Only generate these tests; useful to temporarily set and shorten development cycles */
abstract val runOnly: Set<String>

Expand All @@ -63,18 +72,23 @@ abstract class ProtocolTestGenerator {
*/
abstract val disabledTests: Set<String>

private val serviceShapeId: ShapeId
get() = codegenContext.serviceShape.id

/** The Rust module in which we should generate the protocol tests for [operationShape]. */
private fun protocolTestsModule(): RustModule.LeafModule {
val operationName = codegenContext.symbolProvider.toSymbol(operationShape).name
val testModuleName = "${operationName.toSnakeCase()}_test"
val additionalAttributes =
listOf(Attribute(allow("unreachable_code", "unused_variables")))
val additionalAttributes = listOf(Attribute(allow("unreachable_code", "unused_variables")))
return RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes)
}

/** The entry point to render the protocol tests, invoked by the code generators. */
fun render(writer: RustWriter) {
val allTests = allMatchingTestCases().fixBroken()
val allTests =
allMatchingTestCases().flatMap {
fixBrokenTestCase(it)
}
if (allTests.isEmpty()) {
return
}
Expand All @@ -84,15 +98,65 @@ abstract class ProtocolTestGenerator {
}
}

/** Implementors should describe how to render the test cases. **/
abstract fun RustWriter.renderAllTestCases(allTests: List<TestCase>)

/**
* This function applies a "fix function" to each broken test before we synthesize it.
* Broken tests are those whose definitions in the `awslabs/smithy` repository are wrong.
* We try to contribute fixes upstream to pare down this function to the identity function.
* This function applies a "hotfix function" to a broken test case before we synthesize it.
* Broken tests are those whose definitions in the `smithy-lang/smithy` repository are wrong.
* We try to contribute fixes upstream to pare down the list of broken tests.
* If the test is broken, we synthesize it in two versions: the original broken test with a `#[should_panic]`
* attribute, so get alerted if the test now passes, and the fixed version, which should pass.
*/
open fun List<TestCase>.fixBroken(): List<TestCase> = this
private fun fixBrokenTestCase(it: TestCase): List<TestCase> =
if (!it.isBroken()) {
listOf(it)
} else {
assert(it.expectFail())

val brokenTest = it.findInBroken()!!
var fixed = brokenTest.fixIt(it)

val intro = "The hotfix function for broken test case ${it.kind} ${it.id}"
val moreInfo =
"""This test case was identified to be broken in at least these Smithy versions: [${brokenTest.inAtLeast.joinToString()}].
|We are tracking things here: [${brokenTest.trackedIn.joinToString()}].
""".trimMargin()

// Something must change...
if (it == fixed) {
PANIC(
"""$intro did not make any modifications. It is likely that the test case was
|fixed upstream, and you're now updating the Smithy version; in this case, remove the hotfix
|function, as the test is no longer broken.
|$moreInfo
""".trimMargin(),
)
}

// ... but the hotfix function is not allowed to change the test case kind...
if (it.kind != fixed.kind) {
PANIC(
"""$intro changed the test case kind. This is not allowed.
|$moreInfo
""".trimMargin(),
)
}

// ... nor its id.
if (it.id != fixed.id) {
PANIC(
"""$intro changed the test case id. This is not allowed.
|$moreInfo
""".trimMargin(),
)
}

// The latter is because we're going to generate the fixed version with an identifiable suffix.
fixed = fixed.suffixIdWith("_hotfixed")

listOf(it, fixed)
}

/** Implementors should describe how to render the test cases. **/
abstract fun RustWriter.renderAllTestCases(allTests: List<TestCase>)

/** Filter out test cases that are disabled or don't match the service protocol. */
private fun List<TestCase>.filterMatching(): List<TestCase> =
Expand All @@ -103,11 +167,25 @@ abstract class ProtocolTestGenerator {
this.filter { testCase -> runOnly.contains(testCase.id) }
}

/** Do we expect this [testCase] to fail? */
private fun expectFail(testCase: TestCase): Boolean =
expectFail.find {
it.id == testCase.id && it.kind == testCase.kind && it.service == codegenContext.serviceShape.id.toString()
} != null
private fun TestCase.toFailingTest(): FailingTest =
when (this) {
is TestCase.MalformedRequestTest -> FailingTest.MalformedRequestTest(serviceShapeId.toString(), this.id)
is TestCase.RequestTest -> FailingTest.RequestTest(serviceShapeId.toString(), this.id)
is TestCase.ResponseTest -> FailingTest.ResponseTest(serviceShapeId.toString(), this.id)
}

/** Do we expect this test case to fail? */
private fun TestCase.expectFail(): Boolean = this.isBroken() || expectFail.contains(this.toFailingTest())

/** Is this test case broken? */
private fun TestCase.isBroken(): Boolean = this.findInBroken() != null

private fun TestCase.findInBroken(): BrokenTest? =
brokenTests.find { brokenTest ->
(this is TestCase.RequestTest && brokenTest is BrokenTest.RequestTest && this.id == brokenTest.id) ||
(this is TestCase.ResponseTest && brokenTest is BrokenTest.ResponseTest && this.id == brokenTest.id) ||
(this is TestCase.MalformedRequestTest && brokenTest is BrokenTest.MalformedRequestTest && this.id == brokenTest.id)
Comment on lines +185 to +187
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe we can factor out this.id == brokenTest.id since it appears in each operand of the OR operator?

}

fun requestTestCases(): List<TestCase> {
val requestTests =
Expand Down Expand Up @@ -160,6 +238,7 @@ abstract class ProtocolTestGenerator {
block: Writable,
) {
if (testCase.documentation != null) {
testModuleWriter.rust("")
testModuleWriter.docs(testCase.documentation!!, templating = false)
}
testModuleWriter.docs("Test ID: ${testCase.id}")
Expand All @@ -171,7 +250,7 @@ abstract class ProtocolTestGenerator {
Attribute.TokioTest.render(testModuleWriter)
Attribute.TracedTest.render(testModuleWriter)

if (expectFail(testCase)) {
if (testCase.expectFail()) {
shouldPanic().render(testModuleWriter)
}
val fnNameSuffix =
Expand Down Expand Up @@ -281,6 +360,51 @@ abstract class ProtocolTestGenerator {
}
}

sealed class BrokenTest(
open val serviceShapeId: String,
open val id: String,
/** A non-exhaustive set of Smithy versions where the test was found to be broken. */
open val inAtLeast: Set<String>,
/**
* GitHub URLs related to the test brokenness, like a GitHub issue in Smithy where we reported the test was broken,
* or a PR where we fixed it.
**/
open val trackedIn: Set<String>,
) {
data class RequestTest(
override val serviceShapeId: String,
override val id: String,
override val inAtLeast: Set<String>,
override val trackedIn: Set<String>,
val howToFixItFn: (TestCase.RequestTest) -> TestCase.RequestTest,
) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn)

data class ResponseTest(
override val serviceShapeId: String,
override val id: String,
override val inAtLeast: Set<String>,
override val trackedIn: Set<String>,
val howToFixItFn: (TestCase.ResponseTest) -> TestCase.ResponseTest,
) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn)

data class MalformedRequestTest(
override val serviceShapeId: String,
override val id: String,
override val inAtLeast: Set<String>,
override val trackedIn: Set<String>,
val howToFixItFn: (TestCase.MalformedRequestTest) -> TestCase.MalformedRequestTest,
) : BrokenTest(serviceShapeId, id, inAtLeast, trackedIn)

fun fixIt(testToFix: TestCase): TestCase {
check(testToFix.id == this.id)
return when (this) {
is MalformedRequestTest -> howToFixItFn(testToFix as TestCase.MalformedRequestTest)
is RequestTest -> howToFixItFn(testToFix as TestCase.RequestTest)
is ResponseTest -> howToFixItFn(testToFix as TestCase.ResponseTest)
}
}
}

/**
* Service shape IDs in common protocol test suites defined upstream.
*/
Expand All @@ -291,7 +415,16 @@ object ServiceShapeId {
const val REST_JSON_VALIDATION = "aws.protocoltests.restjson.validation#RestJsonValidation"
}

data class FailingTest(val service: String, val id: String, val kind: TestCaseKind)
sealed class FailingTest(open val serviceShapeId: String, open val id: String) {
data class RequestTest(override val serviceShapeId: String, override val id: String) :
FailingTest(serviceShapeId, id)

data class ResponseTest(override val serviceShapeId: String, override val id: String) :
FailingTest(serviceShapeId, id)

data class MalformedRequestTest(override val serviceShapeId: String, override val id: String) :
FailingTest(serviceShapeId, id)
}

sealed class TestCaseKind {
data object Request : TestCaseKind()
Expand All @@ -302,11 +435,60 @@ sealed class TestCaseKind {
}

sealed class TestCase {
data class RequestTest(val testCase: HttpRequestTestCase) : TestCase()
/*
* The properties of these data classes don't implement `equals()` usefully in Smithy, so we delegate to `equals()`
* of their `Node` representations.
*/

data class RequestTest(val testCase: HttpRequestTestCase) : TestCase() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is RequestTest) return false
return testCase.toNode().equals(other.testCase.toNode())
}

override fun hashCode(): Int = testCase.hashCode()
}

data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is ResponseTest) return false
return testCase.toNode().equals(other.testCase.toNode())
}

override fun hashCode(): Int = testCase.hashCode()
}

data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is MalformedRequestTest) return false
return this.protocol == other.protocol && this.id == other.id && this.documentation == other.documentation &&
this.testCase.request.toNode()
.equals(other.testCase.request.toNode()) &&
this.testCase.response.toNode()
.equals(other.testCase.response.toNode())
}

override fun hashCode(): Int = testCase.hashCode()
}

fun suffixIdWith(suffix: String): TestCase =
when (this) {
is RequestTest -> RequestTest(this.testCase.suffixIdWith(suffix))
is MalformedRequestTest -> MalformedRequestTest(this.testCase.suffixIdWith(suffix))
is ResponseTest -> ResponseTest(this.testCase.suffixIdWith(suffix), this.targetShape)
}

private fun HttpRequestTestCase.suffixIdWith(suffix: String): HttpRequestTestCase =
this.toBuilder().id(this.id + suffix).build()

data class ResponseTest(val testCase: HttpResponseTestCase, val targetShape: StructureShape) : TestCase()
private fun HttpResponseTestCase.suffixIdWith(suffix: String): HttpResponseTestCase =
this.toBuilder().id(this.id + suffix).build()

data class MalformedRequestTest(val testCase: HttpMalformedRequestTestCase) : TestCase()
private fun HttpMalformedRequestTestCase.suffixIdWith(suffix: String): HttpMalformedRequestTestCase =
this.toBuilder().id(this.id + suffix).build()

/*
* `HttpRequestTestCase` and `HttpResponseTestCase` both implement `HttpMessageTestCase`, but
Expand Down
Loading