Skip to content

Commit

Permalink
Merge pull request #179 from seveneves/feature-#172/allow-comparson-j…
Browse files Browse the repository at this point in the history
…son-schema

Support comparison of Schemas generated by `TapirSchemaToJsonSchema`
  • Loading branch information
adamw authored Aug 7, 2024
2 parents fb2d238 + 02d61a1 commit 08eaa08
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,25 @@ import scala.annotation.tailrec
import scala.collection.immutable.ListMap
import scala.collection.mutable

object SchemaComparator {
final val RefPrefix = "#/components/schemas/"
}

/**
* Utility for comparing schemas for compatibility.
* See [[compare]] for more details.
*
* Since this class contains a cache of comparison results,
* it is meant to be reused between multiple schema comparisons.
*
* @param writerNamedSchemas named schemas which may be referred to by the writer schema
* @param readerNamedSchemas named schemas which may be referred to by the reader schema
* @param writerSchemaResolver can resolve named schemas which may be referred to by the writer schema
* @param readerSchemaResolver can resolve named schemas which may be referred to by the reader schema
*/
class SchemaComparator(
writerNamedSchemas: Map[String, Schema],
readerNamedSchemas: Map[String, Schema]
writerSchemaResolver: SchemaResolver,
readerSchemaResolver: SchemaResolver
) {

import SchemaComparator._
def this(
writerNamedSchemas: Map[String, Schema],
readerNamedSchemas: Map[String, Schema]
) = this(SchemaResolver(writerNamedSchemas), SchemaResolver(readerNamedSchemas))

private val issuesCache = new mutable.HashMap[(Schema, Schema), List[SchemaCompatibilityIssue]]
private val identicalityCache = new mutable.HashMap[(Schema, Schema), Boolean]
Expand Down Expand Up @@ -65,54 +64,13 @@ class SchemaComparator(
* @return a list of incompatibilities between the schemas
*/
def compare(writerSchema: SchemaLike, readerSchema: SchemaLike): List[SchemaCompatibilityIssue] = {
val normalizedWriterSchema = normalize(writerSchema, writerNamedSchemas)
val normalizedReaderSchema = normalize(readerSchema, readerNamedSchemas)
val normalizedWriterSchema = writerSchemaResolver.resolveAndNormalize(writerSchema)
val normalizedReaderSchema = readerSchemaResolver.resolveAndNormalize(readerSchema)
computeCached(issuesCache, (normalizedWriterSchema, normalizedReaderSchema), Nil) {
compareNormalized(normalizedWriterSchema, normalizedReaderSchema)
}
}

// translate AnySchema to Schema, remove annotations and resolve references
@tailrec private def normalize(schema: SchemaLike, named: Map[String, Schema]): Schema = schema match {
case AnySchema.Anything => Schema.Empty
case AnySchema.Nothing => Schema.Nothing
case s: Schema => deannotate(s) match {
case s@ReferenceSchema(LocalRef(name)) =>
def noSchema: Nothing =
throw new NoSuchElementException(s"could not resolve schema reference ${s.$ref.get}")

normalize(named.getOrElse(name, noSchema), named)
case s => s
}
}

private object LocalRef {
def unapply(ref: String): Option[String] =
if (ref.startsWith(RefPrefix)) Some(ref.stripPrefix(RefPrefix))
else None
}

/** Matches a schema that is a pure reference to one of the component schemas */
private object ReferenceSchema {
def unapply(schema: Schema): Option[String] =
schema.$ref.filter(ref => schema == Schema($ref = Some(ref)))
}

// strip fields which do not affect schema comparison
private def deannotate(schema: Schema): Schema =
schema.copy(
$comment = None,
title = None,
description = None,
default = None,
deprecated = None,
readOnly = None,
writeOnly = None,
examples = None,
externalDocs = None,
extensions = ListMap.empty
)

private def compareNormalized(writerSchema: Schema, readerSchema: Schema): List[SchemaCompatibilityIssue] =
if (writerSchema == Schema.Nothing || readerSchema == Schema.Empty) {
Nil
Expand Down Expand Up @@ -145,8 +103,8 @@ class SchemaComparator(
propIssues

} else if (isDiscriminatedUnionSchema(writerSchema) && isDiscriminatedUnionSchema(readerSchema)) {
val writerMapping = discriminatorMapping(writerSchema)
val readerMapping = discriminatorMapping(readerSchema)
val writerMapping = writerSchemaResolver.discriminatorMapping(writerSchema)
val readerMapping = readerSchemaResolver.discriminatorMapping(readerSchema)

val variantIssues: List[SchemaCompatibilityIssue] =
(writerMapping.keySet intersect readerMapping.keySet).toList.flatMap { tag =>
Expand Down Expand Up @@ -234,7 +192,7 @@ class SchemaComparator(
private def identical(writerSchema: Schema, readerSchema: Schema): Boolean =
(writerSchema eq readerSchema) || computeCached(identicalityCache, (writerSchema, readerSchema), true) {
def identicalSubschema(writerSubschema: SchemaLike, readerSubschema: SchemaLike): Boolean =
identical(normalize(writerSubschema, writerNamedSchemas), normalize(readerSubschema, readerNamedSchemas))
identical(writerSchemaResolver.resolveAndNormalize(writerSubschema), readerSchemaResolver.resolveAndNormalize(readerSubschema))

def identicalSubschemaMap[K](writerSubschemas: ListMap[K, SchemaLike], readerSubschemas: ListMap[K, SchemaLike]): Boolean =
(writerSubschemas.keySet ++ readerSubschemas.keySet).forall { key =>
Expand Down Expand Up @@ -358,24 +316,6 @@ class SchemaComparator(
private def isDiscriminatedUnionSchema(s: Schema): Boolean =
s.discriminator.nonEmpty && isUnionSchema(s)

/**
* Returns a mapping from discriminator values to schemas associated with them.
* All the schemas in the result are references (local or non-local).
*/
private def discriminatorMapping(schema: Schema): ListMap[String, Schema] = {
// schema reference -> overridden discriminator value
val explicitDiscValueByRef = schema.discriminator.flatMap(_.mapping).getOrElse(ListMap.empty).map(_.swap)
// assuming that schema is valid and every reference in disc.mapping is also an element of oneOf/anyOf
ListMap.empty ++ (schema.oneOf ++ schema.anyOf).collect {
case s@ReferenceSchema(ref) =>
val discValue = explicitDiscValueByRef.getOrElse(ref, ref match {
case LocalRef(name) => name
case _ => throw new NoSuchElementException(s"no discriminator value specified for non-local reference $ref")
})
discValue -> s
}
}

private def getTypes(schema: Schema): Option[List[SchemaType]] = schema match {
case Schema.Empty => Some(SchemaType.Values)
case Schema.Nothing => Some(Nil)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package sttp.apispec.validation

import sttp.apispec.{AnySchema, Schema, SchemaLike}

import scala.annotation.tailrec
import scala.collection.immutable.ListMap

class SchemaResolver(schemas: Map[String, Schema]) {

def discriminatorMapping(schema: Schema): ListMap[String, Schema] = {
// schema reference -> overridden discriminator value
val explicitDiscValueByRef = schema.discriminator.flatMap(_.mapping).getOrElse(ListMap.empty).map(_.swap)
// assuming that schema is valid and every reference in disc.mapping is also an element of oneOf/anyOf
ListMap.empty ++ (schema.oneOf ++ schema.anyOf).collect { case s @ ReferenceSchema(ref) =>
val discValue = explicitDiscValueByRef.getOrElse(
ref,
ref match {
case SchemaResolver.Reference(name) => name
case _ => throw new NoSuchElementException(s"no discriminator value specified for non-local reference $ref")
}
)
discValue -> s
}
}

@tailrec final def resolveAndNormalize(schema: SchemaLike): Schema = schema match {
case AnySchema.Anything => Schema.Empty
case AnySchema.Nothing => Schema.Nothing
case s @ ReferenceSchema(SchemaResolver.Reference(name)) =>
resolveAndNormalize(
schemas.getOrElse(name, throw new NoSuchElementException(s"could not resolve schema reference ${s.$ref.get}"))
)
case s: Schema => normalize(s)
}

private object ReferenceSchema {
def unapply(schema: Schema): Option[String] =
schema.$ref.filter(ref => schema == Schema($ref = Some(ref)))
}

private def normalize(schema: Schema): Schema =
schema.copy(
$comment = None,
$defs = None,
$schema = None,
title = None,
description = None,
default = None,
deprecated = None,
readOnly = None,
writeOnly = None,
examples = None,
externalDocs = None,
extensions = ListMap.empty
)
}

object SchemaResolver {
val ComponentsRefPrefix = "#/components/schemas/"

val DefsRefPrefix = "#/$defs/"

private val Reference = new References(ComponentsRefPrefix, DefsRefPrefix)

private class References(prefix: String*) {
def unapply(ref: String): Option[String] = prefix.flatMap { p =>
Option(ref).filter(_.startsWith(p)).map(_.stripPrefix(p))
}.headOption
}

def apply(schemas: Map[String, Schema]): SchemaResolver = new SchemaResolver(schemas)

def apply(schema: Schema): SchemaResolver = new SchemaResolver(
schema.$defs.getOrElse(Map.empty).collect { case (name, s: Schema) => name -> s }
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package sttp.apispec.validation

class ComponentsSchemaComparatorTest extends SchemaComparatorTest(SchemaResolver.ComponentsRefPrefix)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package sttp.apispec.validation

class DefsSchemaComparatorTest extends SchemaComparatorTest(SchemaResolver.DefsRefPrefix)
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package sttp.apispec.validation

import org.scalatest.funsuite.AnyFunSuite
import sttp.apispec._
import sttp.apispec.validation.SchemaComparator.RefPrefix

import scala.collection.immutable.ListMap

class SchemaComparatorTest extends AnyFunSuite {
abstract class SchemaComparatorTest(referencePrefix: String) extends AnyFunSuite {

private val stringSchema = Schema(SchemaType.String)
private val integerSchema = Schema(SchemaType.Integer)
private val numberSchema = Schema(SchemaType.Number)
Expand Down Expand Up @@ -108,10 +108,11 @@ class SchemaComparatorTest extends AnyFunSuite {
)

private def ref(name: String): Schema =
Schema.referenceTo(SchemaComparator.RefPrefix, name)
Schema.referenceTo(referencePrefix, name)

private def compare(writerSchema: Schema, readerSchema: Schema): List[SchemaCompatibilityIssue] =
new SchemaComparator(writerSchemas, readerSchemas).compare(writerSchema, readerSchema)
new SchemaComparator(writerSchemas, readerSchemas)
.compare(writerSchema, readerSchema)

test("ignoring annotations") {
assert(compare(
Expand Down Expand Up @@ -481,11 +482,11 @@ class SchemaComparatorTest extends AnyFunSuite {
assert(compare(
Schema(
oneOf = List(ref("Foo"), ref("Bar")),
discriminator = Some(Discriminator("type", Some(ListMap("WFoo" -> s"${RefPrefix}Foo"))))
discriminator = Some(Discriminator("type", Some(ListMap("WFoo" -> s"${referencePrefix}Foo"))))
),
Schema(
oneOf = List(ref("Foo"), ref("Bar"), ref("Baz")),
discriminator = Some(Discriminator("type", Some(ListMap("RBar" -> s"${RefPrefix}Bar"))))
discriminator = Some(Discriminator("type", Some(ListMap("RBar" -> s"${referencePrefix}Bar"))))
),
) == List(
UnsupportedDiscriminatorValues(List("WFoo", "Bar"))
Expand All @@ -496,7 +497,7 @@ class SchemaComparatorTest extends AnyFunSuite {
assert(compare(
Schema(
oneOf = List(ref("Foo"), ref("Bar")),
discriminator = Some(Discriminator("type", Some(ListMap("Baz" -> s"${RefPrefix}Bar"))))
discriminator = Some(Discriminator("type", Some(ListMap("Baz" -> s"${referencePrefix}Bar"))))
),
Schema(
oneOf = List(ref("Foo"), ref("Baz")),
Expand Down

0 comments on commit 08eaa08

Please sign in to comment.