Skip to content

Commit

Permalink
./gradlew ktlintFormat
Browse files Browse the repository at this point in the history
  • Loading branch information
david-perez committed Jun 28, 2024
1 parent c31ed44 commit 830885b
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,10 @@ abstract class ProtocolTestGenerator {

/** The entry point to render the protocol tests, invoked by the code generators. */
fun render(writer: RustWriter) {
val allTests = allMatchingTestCases().flatMap {
fixBrokenTestCase(it)
}
val allTests =
allMatchingTestCases().flatMap {
fixBrokenTestCase(it)
}
if (allTests.isEmpty()) {
return
}
Expand All @@ -104,50 +105,55 @@ abstract class ProtocolTestGenerator {
* If the test is broken, we synthesize it in two versions: the original broken test with a `#[should_panic]`
* attribute, so get alerted if the test now passes, and the fixed version, which should pass.
*/
private fun fixBrokenTestCase(it: TestCase): List<TestCase> = if (!it.isBroken()) {
listOf(it)
} else {
assert(it.expectFail())

val brokenTest = it.findInBroken()!!
var fixed = brokenTest.fixIt(it)

val intro = "The hotfix function for broken test case ${it.kind} ${it.id}"
val moreInfo =
"""This test case was identified to be broken in at least these Smithy versions: [${brokenTest.inAtLeast.joinToString()}].
|We are tracking things here: [${brokenTest.trackedIn.joinToString()}].""".trimMargin()

// Something must change...
if (it == fixed) {
PANIC(
"""$intro did not make any modifications. It is likely that the test case was
private fun fixBrokenTestCase(it: TestCase): List<TestCase> =
if (!it.isBroken()) {
listOf(it)
} else {
assert(it.expectFail())

val brokenTest = it.findInBroken()!!
var fixed = brokenTest.fixIt(it)

val intro = "The hotfix function for broken test case ${it.kind} ${it.id}"
val moreInfo =
"""This test case was identified to be broken in at least these Smithy versions: [${brokenTest.inAtLeast.joinToString()}].
|We are tracking things here: [${brokenTest.trackedIn.joinToString()}].
""".trimMargin()

// Something must change...
if (it == fixed) {
PANIC(
"""$intro did not make any modifications. It is likely that the test case was
|fixed upstream, and you're now updating the Smithy version; in this case, remove the hotfix
|function, as the test is no longer broken.
|$moreInfo""".trimMargin(),
)
}
|$moreInfo
""".trimMargin(),
)
}

// ... but the hotfix function is not allowed to change the test case kind...
if (it.kind != fixed.kind) {
PANIC(
"""$intro changed the test case kind. This is not allowed.
|$moreInfo""".trimMargin(),
)
}
// ... but the hotfix function is not allowed to change the test case kind...
if (it.kind != fixed.kind) {
PANIC(
"""$intro changed the test case kind. This is not allowed.
|$moreInfo
""".trimMargin(),
)
}

// ... nor its id.
if (it.id != fixed.id) {
PANIC(
"""$intro changed the test case id. This is not allowed.
|$moreInfo""".trimMargin(),
)
}
// ... nor its id.
if (it.id != fixed.id) {
PANIC(
"""$intro changed the test case id. This is not allowed.
|$moreInfo
""".trimMargin(),
)
}

// The latter is because we're going to generate the fixed version with an identifiable suffix.
fixed = fixed.suffixIdWith("_hotfixed")
// The latter is because we're going to generate the fixed version with an identifiable suffix.
fixed = fixed.suffixIdWith("_hotfixed")

listOf(it, fixed)
}
listOf(it, fixed)
}

/** Implementors should describe how to render the test cases. **/
abstract fun RustWriter.renderAllTestCases(allTests: List<TestCase>)
Expand All @@ -161,23 +167,25 @@ abstract class ProtocolTestGenerator {
this.filter { testCase -> runOnly.contains(testCase.id) }
}

private fun TestCase.toFailingTest(): FailingTest = when (this) {
is TestCase.MalformedRequestTest -> FailingTest.MalformedRequestTest(serviceShapeId.toString(), this.id)
is TestCase.RequestTest -> FailingTest.RequestTest(serviceShapeId.toString(), this.id)
is TestCase.ResponseTest -> FailingTest.ResponseTest(serviceShapeId.toString(), this.id)
}
private fun TestCase.toFailingTest(): FailingTest =
when (this) {
is TestCase.MalformedRequestTest -> FailingTest.MalformedRequestTest(serviceShapeId.toString(), this.id)
is TestCase.RequestTest -> FailingTest.RequestTest(serviceShapeId.toString(), this.id)
is TestCase.ResponseTest -> FailingTest.ResponseTest(serviceShapeId.toString(), this.id)
}

/** Do we expect this test case to fail? */
private fun TestCase.expectFail(): Boolean = this.isBroken() || expectFail.contains(this.toFailingTest())

/** Is this test case broken? */
private fun TestCase.isBroken(): Boolean = this.findInBroken() != null

private fun TestCase.findInBroken(): BrokenTest? = brokenTests.find { brokenTest ->
(this is TestCase.RequestTest && brokenTest is BrokenTest.RequestTest && this.id == brokenTest.id) ||
(this is TestCase.ResponseTest && brokenTest is BrokenTest.ResponseTest && this.id == brokenTest.id) ||
(this is TestCase.MalformedRequestTest && brokenTest is BrokenTest.MalformedRequestTest && this.id == brokenTest.id)
}
private fun TestCase.findInBroken(): BrokenTest? =
brokenTests.find { brokenTest ->
(this is TestCase.RequestTest && brokenTest is BrokenTest.RequestTest && this.id == brokenTest.id) ||
(this is TestCase.ResponseTest && brokenTest is BrokenTest.ResponseTest && this.id == brokenTest.id) ||
(this is TestCase.MalformedRequestTest && brokenTest is BrokenTest.MalformedRequestTest && this.id == brokenTest.id)
}

fun requestTestCases(): List<TestCase> {
val requestTests =
Expand Down Expand Up @@ -355,14 +363,13 @@ abstract class ProtocolTestGenerator {
sealed class BrokenTest(
open val serviceShapeId: String,
open val id: String,

/** A non-exhaustive set of Smithy versions where the test was found to be broken. */
open val inAtLeast: Set<String>,
/**
* GitHub URLs related to the test brokenness, like a GitHub issue in Smithy where we reported the test was broken,
* or a PR where we fixed it.
**/
open val trackedIn: Set<String>
open val trackedIn: Set<String>,
) {
data class RequestTest(
override val serviceShapeId: String,
Expand Down Expand Up @@ -411,8 +418,10 @@ object ServiceShapeId {
sealed class FailingTest(open val serviceShapeId: String, open val id: String) {
data class RequestTest(override val serviceShapeId: String, override val id: String) :
FailingTest(serviceShapeId, id)

data class ResponseTest(override val serviceShapeId: String, override val id: String) :
FailingTest(serviceShapeId, id)

data class MalformedRequestTest(override val serviceShapeId: String, override val id: String) :
FailingTest(serviceShapeId, id)
}
Expand Down Expand Up @@ -455,19 +464,22 @@ sealed class TestCase {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is MalformedRequestTest) return false
return this.protocol == other.protocol && this.id == other.id && this.documentation == other.documentation && this.testCase.request.toNode()
.equals(other.testCase.request.toNode()) && this.testCase.response.toNode()
.equals(other.testCase.response.toNode())
return this.protocol == other.protocol && this.id == other.id && this.documentation == other.documentation &&
this.testCase.request.toNode()
.equals(other.testCase.request.toNode()) &&
this.testCase.response.toNode()
.equals(other.testCase.response.toNode())
}

override fun hashCode(): Int = testCase.hashCode()
}

fun suffixIdWith(suffix: String): TestCase = when (this) {
is RequestTest -> RequestTest(this.testCase.suffixIdWith(suffix))
is MalformedRequestTest -> MalformedRequestTest(this.testCase.suffixIdWith(suffix))
is ResponseTest -> ResponseTest(this.testCase.suffixIdWith(suffix), this.targetShape)
}
fun suffixIdWith(suffix: String): TestCase =
when (this) {
is RequestTest -> RequestTest(this.testCase.suffixIdWith(suffix))
is MalformedRequestTest -> MalformedRequestTest(this.testCase.suffixIdWith(suffix))
is ResponseTest -> ResponseTest(this.testCase.suffixIdWith(suffix), this.targetShape)
}

private fun HttpRequestTestCase.suffixIdWith(suffix: String): HttpRequestTestCase =
this.toBuilder().id(this.id + suffix).build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,13 @@ class ServerProtocolTestGenerator(
"RestJsonMalformedPatternReDOSString",
howToFixItFn = ::fixRestJsonMalformedPatternReDOSString,
inAtLeast = setOf("1.26.2", "1.49.0"),
trackedIn = setOf(
// TODO(https://github.com/awslabs/smithy/issues/1506)
"https://github.com/awslabs/smithy/issues/1506",
// TODO(https://github.com/smithy-lang/smithy/pull/2340)
"https://github.com/smithy-lang/smithy/pull/2340",
),
trackedIn =
setOf(
// TODO(https://github.com/awslabs/smithy/issues/1506)
"https://github.com/awslabs/smithy/issues/1506",
// TODO(https://github.com/smithy-lang/smithy/pull/2340)
"https://github.com/smithy-lang/smithy/pull/2340",
),
),
)

Expand All @@ -188,7 +189,9 @@ class ServerProtocolTestGenerator(
"S3PreservesEmbeddedDotSegmentInUriLabel",
)

private fun fixRestJsonMalformedPatternReDOSString(testCase: TestCase.MalformedRequestTest): TestCase.MalformedRequestTest {
private fun fixRestJsonMalformedPatternReDOSString(
testCase: TestCase.MalformedRequestTest,
): TestCase.MalformedRequestTest {
val brokenResponse = testCase.testCase.response
val brokenBody = brokenResponse.body.get()
val fixedBody =
Expand Down Expand Up @@ -321,7 +324,7 @@ class ServerProtocolTestGenerator(

if (!protocolSupport.responseSerialization || (
!protocolSupport.errorSerialization && shape.hasTrait<ErrorTrait>()
)
)
) {
rust("/* test case disabled for this protocol (not yet supported) */")
return
Expand Down

0 comments on commit 830885b

Please sign in to comment.