Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compliancetests fixes improvements #680

Merged
merged 16 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,11 @@ lazy val complianceTests = projectMatrix
Dependencies.Pprint.core.value
)
},
moduleName := {
if (virtualAxes.value.contains(CatsEffect2Axis))
moduleName.value + "-ce2"
else moduleName.value
},
Test / smithySpecs := Seq(
(ThisBuild / baseDirectory).value / "sampleSpecs" / "test.smithy"
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,22 @@ private[internals] object assert {
private def isJson(bodyMediaType: Option[String]) =
bodyMediaType.exists(_.equalsIgnoreCase("application/json"))

private def jsonEql(a: String, b: String): ComplianceResult = {
(parse(a), parse(b)) match {
case (Right(a), Right(b)) if a == b => success
case (Left(a), Left(b)) => fail(s"Both JSONs are invalid: $a, $b")
case (Left(a), _) => fail(s"First JSON is invalid: $a")
case (_, Left(b)) => fail(s"Second JSON is invalid: $b")
case (Right(a), Right(b)) => fail(s"JSONs are not equal: $a, $b")
private def jsonEql(expected: String, actual: String): ComplianceResult = {
(expected.isEmpty, actual.isEmpty) match {
case (true, true) => success
case (true, false) => fail(s"Expected empty body, but got $actual")
case (false, true) => fail(s"Expected $expected, but got empty body")
case (false, false) =>
(parse(expected), parse(actual)) match {
case (Right(a), Right(b)) if a == b => success
case (Left(a), Left(b)) => fail(s"Both JSONs are invalid: $a, $b")
case (Left(a), _) =>
fail(s"Expected JSON is invalid: $expected \n Error $a ")
case (_, Left(b)) =>
fail(s"Actual JSON is invalid: $actual \n Error $b")
case (Right(a), Right(b)) =>
fail(s"JSONs are not equal: expected json: $a \n actual json: $b")
}
}
}

Expand All @@ -51,18 +60,6 @@ private[internals] object assert {
}
}

def bodyEql[A](
expected: A,
actual: A,
bodyMediaType: Option[String]
): ComplianceResult = {
if (isJson(bodyMediaType)) {
jsonEql(expected.toString, actual.toString)
} else {
eql(expected, actual)
}
}

def bodyEql(
expected: String,
actual: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@
package smithy4s.compliancetests
package internals

import java.nio.charset.StandardCharsets

import cats.implicits._
import org.http4s.headers.`Content-Type`
import org.http4s.HttpApp
import org.http4s.Request
import org.http4s.Response
import org.http4s.Status
import org.http4s.Uri
import org.typelevel.ci.CIString
import smithy.test._
import smithy4s.compliancetests.ComplianceTest.ComplianceResult
import smithy4s.http.CodecAPI
Expand All @@ -37,7 +34,7 @@ import smithy4s.Service
import scala.concurrent.duration._
import smithy4s.http.HttpMediaType
import org.http4s.MediaType
import org.http4s.Header
import org.http4s.Headers

private[compliancetests] class ClientHttpComplianceTestCase[
F[_],
Expand Down Expand Up @@ -70,20 +67,8 @@ private[compliancetests] class ClientHttpComplianceTestCase[
.withPath(
Uri.Path.unsafeFromString(testCase.uri)
)
.withQueryParams(
testCase.queryParams.combineAll.map {
_.split("=", 2) match {
case Array(k, v) =>
(
k,
Uri.decode(
toDecode = v,
charset = StandardCharsets.UTF_8,
plusIsSpace = true
)
)
}
}.toMap
.withMultiValueQueryParams(
parseQueryParams(testCase.queryParams)
)

val uriAssert = assert.eql(expectedUri, request.uri)
Expand Down Expand Up @@ -200,21 +185,20 @@ private[compliancetests] class ClientHttpComplianceTestCase[
.through(utf8Encode)
}
.getOrElse(fs2.Stream.empty)
val headers: Seq[Header.ToRaw] =
testCase.headers.toList
.flatMap(_.toList)
.map { case (key, value) =>
Header.Raw(CIString(key), value)
}
.map(Header.ToRaw.rawToRaw)
.toSeq

// val headers = extractHeaders(testCase.headers).toList.flatten

val headers = List(
extractHeaders(testCase.headers),
Some(
Headers(`Content-Type`(MediaType.unsafeParse(mediaType.value)))
)
).foldMap(_.combineAll)

req.body.compile.drain.as(
Response[F](status)
.withBodyStream(body)
.putHeaders(headers: _*)
.putHeaders(
`Content-Type`(MediaType.unsafeParse(mediaType.value))
)
.withHeaders(headers)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
package smithy4s.compliancetests
package internals

import java.nio.charset.StandardCharsets

import cats.implicits._
import org.http4s._
import org.http4s.headers.`Content-Type`
Expand Down Expand Up @@ -53,9 +51,7 @@ private[compliancetests] class ServerHttpComplianceTestCase[
): Request[F] = {
val expectedHeaders =
List(
testCase.headers.map(h =>
Headers(h.toList.map(a => a: Header.ToRaw): _*)
),
extractHeaders(testCase.headers),
testCase.bodyMediaType.map(mt =>
Headers(`Content-Type`(MediaType.unsafeParse(mt)))
)
Expand All @@ -69,20 +65,8 @@ private[compliancetests] class ServerHttpComplianceTestCase[
.withPath(
Uri.Path.unsafeFromString(testCase.uri).addEndsWithSlash
)
.withQueryParams(
testCase.queryParams.combineAll.map {
_.split("=", 2) match {
case Array(k, v) =>
(
k,
Uri.decode(
toDecode = v,
charset = StandardCharsets.UTF_8,
plusIsSpace = true
)
)
}
}.toMap
.withMultiValueQueryParams(
parseQueryParams(testCase.queryParams)
)

val body =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
* limitations under the License.
*/

package smithy4s.compliancetests
package smithy4s
package compliancetests

import smithy4s.Hints
import smithy4s.schema.Schema
import org.http4s.{Header, Headers, Uri}
import cats.implicits._

import java.nio.charset.StandardCharsets
import scala.annotation.tailrec
import scala.collection.immutable.ListMap

package object internals {

Expand All @@ -29,4 +34,62 @@ package object internals {
h.++(Hints(smithy.api.TimestampFormat.EPOCH_SECONDS.widen))
)
}

def splitQuery(queryString: String): (String, String) = {
daddykotex marked this conversation as resolved.
Show resolved Hide resolved
queryString.split("=", 2) match {
daddykotex marked this conversation as resolved.
Show resolved Hide resolved
case Array(k, v) =>
(
k,
Uri.decode(
toDecode = v,
charset = StandardCharsets.UTF_8,
plusIsSpace = true
)
)
case Array(k) => (k, "")
}
}

def parseQueryParams(
queryParams: Option[List[String]]
): ListMap[String, List[String]] = {
queryParams.combineAll
.map(splitQuery)
.foldRight[ListMap[String, List[String]]](ListMap.empty) {
yisraelU marked this conversation as resolved.
Show resolved Hide resolved
case ((k, v), acc) =>
acc.get(k) match {
case Some(value) => acc + (k -> (v :: value))
case None => acc + (k -> List(v))
}
}
}

def extractHeaders(
maybeHeaders: Option[Map[String, String]]
): Option[Headers] =
yisraelU marked this conversation as resolved.
Show resolved Hide resolved
maybeHeaders.map(h =>
Headers(h.toList.flatMap(parseSingleHeader).map(a => a: Header.ToRaw): _*)
)

// regex for comma not between quotes
private val commaNotBetweenQuotes = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)"

private def parseSingleHeader(
kv: (String, String)
): List[(String, String)] = {
val key = kv._1
val values = kv._2.split(commaNotBetweenQuotes).toList
@tailrec
def loop(
yisraelU marked this conversation as resolved.
Show resolved Hide resolved
rest: List[String],
acc: List[(String, String)]
): List[(String, String)] = {
rest match {
case ::(head, next) => loop(next, (key, head) :: acc)
case Nil => acc
}
}
loop(values, List.empty)
}

}