diff --git a/apispec-model/src/main/scala/sttp/apispec/validation/SchemaComparator.scala b/apispec-model/src/main/scala/sttp/apispec/validation/SchemaComparator.scala index d931c5b..deefb55 100644 --- a/apispec-model/src/main/scala/sttp/apispec/validation/SchemaComparator.scala +++ b/apispec-model/src/main/scala/sttp/apispec/validation/SchemaComparator.scala @@ -6,10 +6,6 @@ import scala.annotation.tailrec import scala.collection.immutable.ListMap import scala.collection.mutable -object SchemaComparator { - final val RefPrefix = "#/components/schemas/" -} - /** * Utility for comparing schemas for compatibility. * See [[compare]] for more details. @@ -17,15 +13,18 @@ object SchemaComparator { * Since this class contains a cache of comparison results, * it is meant to be reused between multiple schema comparisons. * - * @param writerNamedSchemas named schemas which may be referred to by the writer schema - * @param readerNamedSchemas named schemas which may be referred to by the reader schema + * @param writerSchemaResolver can resolve named schemas which may be referred to by the writer schema + * @param readerSchemaResolver can resolve named schemas which may be referred to by the reader schema */ class SchemaComparator( - writerNamedSchemas: Map[String, Schema], - readerNamedSchemas: Map[String, Schema] + writerSchemaResolver: SchemaResolver, + readerSchemaResolver: SchemaResolver ) { - import SchemaComparator._ + def this( + writerNamedSchemas: Map[String, Schema], + readerNamedSchemas: Map[String, Schema] + ) = this(SchemaResolver(writerNamedSchemas), SchemaResolver(readerNamedSchemas)) private val issuesCache = new mutable.HashMap[(Schema, Schema), List[SchemaCompatibilityIssue]] private val identicalityCache = new mutable.HashMap[(Schema, Schema), Boolean] @@ -65,54 +64,13 @@ class SchemaComparator( * @return a list of incompatibilities between the schemas */ def compare(writerSchema: SchemaLike, readerSchema: SchemaLike): List[SchemaCompatibilityIssue] = { - val normalizedWriterSchema = normalize(writerSchema, writerNamedSchemas) - val normalizedReaderSchema = normalize(readerSchema, readerNamedSchemas) + val normalizedWriterSchema = writerSchemaResolver.resolveAndNormalize(writerSchema) + val normalizedReaderSchema = readerSchemaResolver.resolveAndNormalize(readerSchema) computeCached(issuesCache, (normalizedWriterSchema, normalizedReaderSchema), Nil) { compareNormalized(normalizedWriterSchema, normalizedReaderSchema) } } - // translate AnySchema to Schema, remove annotations and resolve references - @tailrec private def normalize(schema: SchemaLike, named: Map[String, Schema]): Schema = schema match { - case AnySchema.Anything => Schema.Empty - case AnySchema.Nothing => Schema.Nothing - case s: Schema => deannotate(s) match { - case s@ReferenceSchema(LocalRef(name)) => - def noSchema: Nothing = - throw new NoSuchElementException(s"could not resolve schema reference ${s.$ref.get}") - - normalize(named.getOrElse(name, noSchema), named) - case s => s - } - } - - private object LocalRef { - def unapply(ref: String): Option[String] = - if (ref.startsWith(RefPrefix)) Some(ref.stripPrefix(RefPrefix)) - else None - } - - /** Matches a schema that is a pure reference to one of the component schemas */ - private object ReferenceSchema { - def unapply(schema: Schema): Option[String] = - schema.$ref.filter(ref => schema == Schema($ref = Some(ref))) - } - - // strip fields which do not affect schema comparison - private def deannotate(schema: Schema): Schema = - schema.copy( - $comment = None, - title = None, - description = None, - default = None, - deprecated = None, - readOnly = None, - writeOnly = None, - examples = None, - externalDocs = None, - extensions = ListMap.empty - ) - private def compareNormalized(writerSchema: Schema, readerSchema: Schema): List[SchemaCompatibilityIssue] = if (writerSchema == Schema.Nothing || readerSchema == Schema.Empty) { Nil @@ -145,8 +103,8 @@ class SchemaComparator( propIssues } else if (isDiscriminatedUnionSchema(writerSchema) && isDiscriminatedUnionSchema(readerSchema)) { - val writerMapping = discriminatorMapping(writerSchema) - val readerMapping = discriminatorMapping(readerSchema) + val writerMapping = writerSchemaResolver.discriminatorMapping(writerSchema) + val readerMapping = readerSchemaResolver.discriminatorMapping(readerSchema) val variantIssues: List[SchemaCompatibilityIssue] = (writerMapping.keySet intersect readerMapping.keySet).toList.flatMap { tag => @@ -234,7 +192,7 @@ class SchemaComparator( private def identical(writerSchema: Schema, readerSchema: Schema): Boolean = (writerSchema eq readerSchema) || computeCached(identicalityCache, (writerSchema, readerSchema), true) { def identicalSubschema(writerSubschema: SchemaLike, readerSubschema: SchemaLike): Boolean = - identical(normalize(writerSubschema, writerNamedSchemas), normalize(readerSubschema, readerNamedSchemas)) + identical(writerSchemaResolver.resolveAndNormalize(writerSubschema), readerSchemaResolver.resolveAndNormalize(readerSubschema)) def identicalSubschemaMap[K](writerSubschemas: ListMap[K, SchemaLike], readerSubschemas: ListMap[K, SchemaLike]): Boolean = (writerSubschemas.keySet ++ readerSubschemas.keySet).forall { key => @@ -358,24 +316,6 @@ class SchemaComparator( private def isDiscriminatedUnionSchema(s: Schema): Boolean = s.discriminator.nonEmpty && isUnionSchema(s) - /** - * Returns a mapping from discriminator values to schemas associated with them. - * All the schemas in the result are references (local or non-local). - */ - private def discriminatorMapping(schema: Schema): ListMap[String, Schema] = { - // schema reference -> overridden discriminator value - val explicitDiscValueByRef = schema.discriminator.flatMap(_.mapping).getOrElse(ListMap.empty).map(_.swap) - // assuming that schema is valid and every reference in disc.mapping is also an element of oneOf/anyOf - ListMap.empty ++ (schema.oneOf ++ schema.anyOf).collect { - case s@ReferenceSchema(ref) => - val discValue = explicitDiscValueByRef.getOrElse(ref, ref match { - case LocalRef(name) => name - case _ => throw new NoSuchElementException(s"no discriminator value specified for non-local reference $ref") - }) - discValue -> s - } - } - private def getTypes(schema: Schema): Option[List[SchemaType]] = schema match { case Schema.Empty => Some(SchemaType.Values) case Schema.Nothing => Some(Nil) diff --git a/apispec-model/src/main/scala/sttp/apispec/validation/SchemaResolver.scala b/apispec-model/src/main/scala/sttp/apispec/validation/SchemaResolver.scala new file mode 100644 index 0000000..f071170 --- /dev/null +++ b/apispec-model/src/main/scala/sttp/apispec/validation/SchemaResolver.scala @@ -0,0 +1,76 @@ +package sttp.apispec.validation + +import sttp.apispec.{AnySchema, Schema, SchemaLike} + +import scala.annotation.tailrec +import scala.collection.immutable.ListMap + +class SchemaResolver(schemas: Map[String, Schema]) { + + def discriminatorMapping(schema: Schema): ListMap[String, Schema] = { + // schema reference -> overridden discriminator value + val explicitDiscValueByRef = schema.discriminator.flatMap(_.mapping).getOrElse(ListMap.empty).map(_.swap) + // assuming that schema is valid and every reference in disc.mapping is also an element of oneOf/anyOf + ListMap.empty ++ (schema.oneOf ++ schema.anyOf).collect { case s @ ReferenceSchema(ref) => + val discValue = explicitDiscValueByRef.getOrElse( + ref, + ref match { + case SchemaResolver.Reference(name) => name + case _ => throw new NoSuchElementException(s"no discriminator value specified for non-local reference $ref") + } + ) + discValue -> s + } + } + + @tailrec final def resolveAndNormalize(schema: SchemaLike): Schema = schema match { + case AnySchema.Anything => Schema.Empty + case AnySchema.Nothing => Schema.Nothing + case s @ ReferenceSchema(SchemaResolver.Reference(name)) => + resolveAndNormalize( + schemas.getOrElse(name, throw new NoSuchElementException(s"could not resolve schema reference ${s.$ref.get}")) + ) + case s: Schema => normalize(s) + } + + private object ReferenceSchema { + def unapply(schema: Schema): Option[String] = + schema.$ref.filter(ref => schema == Schema($ref = Some(ref))) + } + + private def normalize(schema: Schema): Schema = + schema.copy( + $comment = None, + $defs = None, + $schema = None, + title = None, + description = None, + default = None, + deprecated = None, + readOnly = None, + writeOnly = None, + examples = None, + externalDocs = None, + extensions = ListMap.empty + ) +} + +object SchemaResolver { + val ComponentsRefPrefix = "#/components/schemas/" + + val DefsRefPrefix = "#/$defs/" + + private val Reference = new References(ComponentsRefPrefix, DefsRefPrefix) + + private class References(prefix: String*) { + def unapply(ref: String): Option[String] = prefix.flatMap { p => + Option(ref).filter(_.startsWith(p)).map(_.stripPrefix(p)) + }.headOption + } + + def apply(schemas: Map[String, Schema]): SchemaResolver = new SchemaResolver(schemas) + + def apply(schema: Schema): SchemaResolver = new SchemaResolver( + schema.$defs.getOrElse(Map.empty).collect { case (name, s: Schema) => name -> s } + ) +} diff --git a/apispec-model/src/test/scala/sttp/apispec/validation/ComponentsSchemaComparatorTest.scala b/apispec-model/src/test/scala/sttp/apispec/validation/ComponentsSchemaComparatorTest.scala new file mode 100644 index 0000000..0ba8b17 --- /dev/null +++ b/apispec-model/src/test/scala/sttp/apispec/validation/ComponentsSchemaComparatorTest.scala @@ -0,0 +1,3 @@ +package sttp.apispec.validation + +class ComponentsSchemaComparatorTest extends SchemaComparatorTest(SchemaResolver.ComponentsRefPrefix) diff --git a/apispec-model/src/test/scala/sttp/apispec/validation/DefsSchemaComparatorTest.scala b/apispec-model/src/test/scala/sttp/apispec/validation/DefsSchemaComparatorTest.scala new file mode 100644 index 0000000..6357389 --- /dev/null +++ b/apispec-model/src/test/scala/sttp/apispec/validation/DefsSchemaComparatorTest.scala @@ -0,0 +1,3 @@ +package sttp.apispec.validation + +class DefsSchemaComparatorTest extends SchemaComparatorTest(SchemaResolver.DefsRefPrefix) diff --git a/apispec-model/src/test/scala/sttp/apispec/validation/SchemaComparatorTest.scala b/apispec-model/src/test/scala/sttp/apispec/validation/SchemaComparatorTest.scala index 912eac0..8c0394d 100644 --- a/apispec-model/src/test/scala/sttp/apispec/validation/SchemaComparatorTest.scala +++ b/apispec-model/src/test/scala/sttp/apispec/validation/SchemaComparatorTest.scala @@ -2,11 +2,11 @@ package sttp.apispec.validation import org.scalatest.funsuite.AnyFunSuite import sttp.apispec._ -import sttp.apispec.validation.SchemaComparator.RefPrefix import scala.collection.immutable.ListMap -class SchemaComparatorTest extends AnyFunSuite { +abstract class SchemaComparatorTest(referencePrefix: String) extends AnyFunSuite { + private val stringSchema = Schema(SchemaType.String) private val integerSchema = Schema(SchemaType.Integer) private val numberSchema = Schema(SchemaType.Number) @@ -108,10 +108,11 @@ class SchemaComparatorTest extends AnyFunSuite { ) private def ref(name: String): Schema = - Schema.referenceTo(SchemaComparator.RefPrefix, name) + Schema.referenceTo(referencePrefix, name) private def compare(writerSchema: Schema, readerSchema: Schema): List[SchemaCompatibilityIssue] = - new SchemaComparator(writerSchemas, readerSchemas).compare(writerSchema, readerSchema) + new SchemaComparator(writerSchemas, readerSchemas) + .compare(writerSchema, readerSchema) test("ignoring annotations") { assert(compare( @@ -481,11 +482,11 @@ class SchemaComparatorTest extends AnyFunSuite { assert(compare( Schema( oneOf = List(ref("Foo"), ref("Bar")), - discriminator = Some(Discriminator("type", Some(ListMap("WFoo" -> s"${RefPrefix}Foo")))) + discriminator = Some(Discriminator("type", Some(ListMap("WFoo" -> s"${referencePrefix}Foo")))) ), Schema( oneOf = List(ref("Foo"), ref("Bar"), ref("Baz")), - discriminator = Some(Discriminator("type", Some(ListMap("RBar" -> s"${RefPrefix}Bar")))) + discriminator = Some(Discriminator("type", Some(ListMap("RBar" -> s"${referencePrefix}Bar")))) ), ) == List( UnsupportedDiscriminatorValues(List("WFoo", "Bar")) @@ -496,7 +497,7 @@ class SchemaComparatorTest extends AnyFunSuite { assert(compare( Schema( oneOf = List(ref("Foo"), ref("Bar")), - discriminator = Some(Discriminator("type", Some(ListMap("Baz" -> s"${RefPrefix}Bar")))) + discriminator = Some(Discriminator("type", Some(ListMap("Baz" -> s"${referencePrefix}Bar")))) ), Schema( oneOf = List(ref("Foo"), ref("Baz")),