Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compliancetests fixes improvements #680

Merged
merged 16 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,11 @@ lazy val complianceTests = projectMatrix
Dependencies.Pprint.core.value
)
},
moduleName := {
if (virtualAxes.value.contains(CatsEffect2Axis))
moduleName.value + "-ce2"
else moduleName.value
},
Test / smithySpecs := Seq(
(ThisBuild / baseDirectory).value / "sampleSpecs" / "test.smithy"
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,22 @@ private[internals] object assert {
private def isJson(bodyMediaType: Option[String]) =
bodyMediaType.exists(_.equalsIgnoreCase("application/json"))

private def jsonEql(a: String, b: String): ComplianceResult = {
(parse(a), parse(b)) match {
case (Right(a), Right(b)) if a == b => success
case (Left(a), Left(b)) => fail(s"Both JSONs are invalid: $a, $b")
case (Left(a), _) => fail(s"First JSON is invalid: $a")
case (_, Left(b)) => fail(s"Second JSON is invalid: $b")
case (Right(a), Right(b)) => fail(s"JSONs are not equal: $a, $b")
private def jsonEql(expected: String, actual: String): ComplianceResult = {
(expected.isEmpty, actual.isEmpty) match {
case (true, true) => success
case (true, false) => fail(s"Expected empty body, but got $actual")
case (false, true) => fail(s"Expected $expected, but got empty body")
case (false, false) =>
(parse(expected), parse(actual)) match {
case (Right(a), Right(b)) if a == b => success
case (Left(a), Left(b)) => fail(s"Both JSONs are invalid: $a, $b")
case (Left(a), _) =>
fail(s"Expected JSON is invalid: $expected \n Error $a ")
case (_, Left(b)) =>
fail(s"Actual JSON is invalid: $actual \n Error $b")
case (Right(a), Right(b)) =>
fail(s"JSONs are not equal: expected json: $a \n actual json: $b")
}
}
}

Expand All @@ -51,18 +60,6 @@ private[internals] object assert {
}
}

def bodyEql[A](
expected: A,
actual: A,
bodyMediaType: Option[String]
): ComplianceResult = {
if (isJson(bodyMediaType)) {
jsonEql(expected.toString, actual.toString)
} else {
eql(expected, actual)
}
}

def bodyEql(
expected: String,
actual: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@
package smithy4s.compliancetests
package internals

import java.nio.charset.StandardCharsets

import cats.implicits._
import org.http4s.headers.`Content-Type`
import org.http4s.HttpApp
import org.http4s.Request
import org.http4s.Response
import org.http4s.Status
import org.http4s.Uri
import org.typelevel.ci.CIString
import smithy.test._
import smithy4s.compliancetests.ComplianceTest.ComplianceResult
import smithy4s.http.CodecAPI
Expand All @@ -37,7 +34,7 @@ import smithy4s.Service
import scala.concurrent.duration._
import smithy4s.http.HttpMediaType
import org.http4s.MediaType
import org.http4s.Header
import org.http4s.Headers

private[compliancetests] class ClientHttpComplianceTestCase[
F[_],
Expand Down Expand Up @@ -70,23 +67,12 @@ private[compliancetests] class ClientHttpComplianceTestCase[
.withPath(
Uri.Path.unsafeFromString(testCase.uri)
)
.withQueryParams(
testCase.queryParams.combineAll.map {
_.split("=", 2) match {
case Array(k, v) =>
(
k,
Uri.decode(
toDecode = v,
charset = StandardCharsets.UTF_8,
plusIsSpace = true
)
)
}
}.toMap
.withMultiValueQueryParams(
parseQueryParams(testCase.queryParams)
)

val uriAssert = assert.eql(expectedUri, request.uri)
val uriAssert =
assert.eql(expectedUri.renderString, request.uri.renderString)
val methodAssert = assert.eql(
testCase.method.toLowerCase(),
request.method.name.toLowerCase()
Expand All @@ -107,7 +93,7 @@ private[compliancetests] class ClientHttpComplianceTestCase[
): ComplianceTest[F] = {
type R[I_, E_, O_, SE_, SO_] = F[O_]

val revisedSchema = mapAllTimestampsToEpoch(endpoint.input)
val revisedSchema = mapAllTimestampsToEpoch(endpoint.input.clearHints)
val inputFromDocument = Document.Decoder.fromSchema(revisedSchema)
ComplianceTest[F](
name = endpoint.id.toString + "(client|request): " + testCase.id,
Expand Down Expand Up @@ -165,7 +151,7 @@ private[compliancetests] class ClientHttpComplianceTestCase[
ComplianceTest[F](
name = endpoint.id.toString + "(client|response): " + testCase.id,
run = {
val revisedSchema = mapAllTimestampsToEpoch(endpoint.output)
val revisedSchema = mapAllTimestampsToEpoch(endpoint.output.clearHints)
val buildResult: Either[Document => F[Throwable], Document => F[O]] = {
errorSchema
.toLeft {
Expand Down Expand Up @@ -200,21 +186,15 @@ private[compliancetests] class ClientHttpComplianceTestCase[
.through(utf8Encode)
}
.getOrElse(fs2.Stream.empty)
val headers: Seq[Header.ToRaw] =
testCase.headers.toList
.flatMap(_.toList)
.map { case (key, value) =>
Header.Raw(CIString(key), value)
}
.map(Header.ToRaw.rawToRaw)
.toSeq

val headers = Headers(
`Content-Type`(MediaType.unsafeParse(mediaType.value))
) ++ parseHeaders(testCase.headers)

req.body.compile.drain.as(
Response[F](status)
.withBodyStream(body)
.putHeaders(headers: _*)
.putHeaders(
`Content-Type`(MediaType.unsafeParse(mediaType.value))
)
.withHeaders(headers)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
package smithy4s.compliancetests
package internals

import java.nio.charset.StandardCharsets

import cats.implicits._
import org.http4s._
import org.http4s.headers.`Content-Type`
Expand Down Expand Up @@ -52,14 +50,9 @@ private[compliancetests] class ServerHttpComplianceTestCase[
testCase: HttpRequestTestCase
): Request[F] = {
val expectedHeaders =
List(
testCase.headers.map(h =>
Headers(h.toList.map(a => a: Header.ToRaw): _*)
),
testCase.bodyMediaType.map(mt =>
Headers(`Content-Type`(MediaType.unsafeParse(mt)))
)
).foldMap(_.combineAll)
testCase.bodyMediaType
.map(mt => Headers(`Content-Type`(MediaType.unsafeParse(mt))))
.getOrElse(Headers.empty) ++ parseHeaders(testCase.headers)

val expectedMethod = Method
.fromString(testCase.method)
Expand All @@ -69,20 +62,8 @@ private[compliancetests] class ServerHttpComplianceTestCase[
.withPath(
Uri.Path.unsafeFromString(testCase.uri).addEndsWithSlash
)
.withQueryParams(
testCase.queryParams.combineAll.map {
_.split("=", 2) match {
case Array(k, v) =>
(
k,
Uri.decode(
toDecode = v,
charset = StandardCharsets.UTF_8,
plusIsSpace = true
)
)
}
}.toMap
.withMultiValueQueryParams(
parseQueryParams(testCase.queryParams)
)

val body =
Expand All @@ -103,7 +84,7 @@ private[compliancetests] class ServerHttpComplianceTestCase[
testCase: HttpRequestTestCase
): ComplianceTest[F] = {

val revisedSchema = mapAllTimestampsToEpoch(endpoint.input)
val revisedSchema = mapAllTimestampsToEpoch(endpoint.input.clearHints)
val inputFromDocument = Document.Decoder.fromSchema(revisedSchema)
ComplianceTest[F](
name = endpoint.id.toString + "(server|request): " + testCase.id,
Expand Down Expand Up @@ -159,7 +140,7 @@ private[compliancetests] class ServerHttpComplianceTestCase[
errorSchema
.toLeft {
val outputDecoder = Document.Decoder.fromSchema(
mapAllTimestampsToEpoch(endpoint.output)
mapAllTimestampsToEpoch(endpoint.output.clearHints)
)
(doc: Document) =>
outputDecoder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,79 @@
* limitations under the License.
*/

package smithy4s.compliancetests
package smithy4s
package compliancetests

import smithy4s.Hints
import smithy4s.schema.Schema
import org.http4s.{Header, Headers, Uri}
import cats.implicits._

import java.nio.charset.StandardCharsets
import scala.collection.immutable.ListMap

package object internals {

// Due to AWS's usage of integer as the canonical representation of a Timestamp in smithy , we need to provide the decoder with instructions to use a Long instead.
// therefore the timestamp type is switched to type epochSeconds: Long
// This is just a workaround thats limited to testing scenarios
def mapAllTimestampsToEpoch[A](schema: Schema[A]): Schema[A] = {
private[compliancetests] def mapAllTimestampsToEpoch[A](
schema: Schema[A]
): Schema[A] = {
schema.transformHintsTransitively(h =>
h.++(Hints(smithy.api.TimestampFormat.EPOCH_SECONDS.widen))
)
}

private[compliancetests] implicit class SchemaOps[A](val schema: Schema[A])
extends AnyVal {
def clearHints: Schema[A] =
schema.transformHintsTransitively(_ => Hints.empty)
}

private def splitQuery(queryString: String): (String, String) = {
queryString.split("=", 2) match {
daddykotex marked this conversation as resolved.
Show resolved Hide resolved
case Array(k, v) =>
(
k,
Uri.decode(
toDecode = v,
charset = StandardCharsets.UTF_8,
plusIsSpace = true
)
)
case Array(k) => (k, "")
}
}

private[compliancetests] def parseQueryParams(
queryParams: Option[List[String]]
): ListMap[String, List[String]] = {
queryParams.combineAll
.map(splitQuery)
.foldLeft[ListMap[String, List[String]]](ListMap.empty) {
case (acc, (k, v)) =>
daddykotex marked this conversation as resolved.
Show resolved Hide resolved
acc.get(k) match {
case Some(value) => acc + (k -> (value :+ v))
case None => acc + (k -> List(v))
}
}
}

private[compliancetests] def parseHeaders(
maybeHeaders: Option[Map[String, String]]
): Headers =
maybeHeaders.fold(Headers.empty)(h =>
Headers(h.toList.flatMap(parseSingleHeader).map(a => a: Header.ToRaw): _*)
)

// regex for comma not between quotes as quotes can be used to escape commas in headers
private val commaNotBetweenQuotes = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)"

private def parseSingleHeader(
kv: (String, String)
): List[(String, String)] = {
kv match {
case (k, v) => v.split(commaNotBetweenQuotes).toList.map((k, _))
}

}
}