diff --git a/apispec-model/src/main/scala/sttp/apispec/Schema.scala b/apispec-model/src/main/scala/sttp/apispec/Schema.scala index 174ef15..3bb8c4a 100644 --- a/apispec-model/src/main/scala/sttp/apispec/Schema.scala +++ b/apispec-model/src/main/scala/sttp/apispec/Schema.scala @@ -162,7 +162,9 @@ object Schema { Schema($ref = Some(s"$prefix${$ref}")) } -sealed abstract class SchemaType(val value: String) +sealed abstract class SchemaType(val value: String) { + override def toString: String = value +} object SchemaType { case object Boolean extends SchemaType("boolean") case object Object extends SchemaType("object") diff --git a/apispec-model/src/main/scala/sttp/apispec/validation/SchemaComparator.scala b/apispec-model/src/main/scala/sttp/apispec/validation/SchemaComparator.scala index 66ef2c7..506c7eb 100644 --- a/apispec-model/src/main/scala/sttp/apispec/validation/SchemaComparator.scala +++ b/apispec-model/src/main/scala/sttp/apispec/validation/SchemaComparator.scala @@ -70,7 +70,7 @@ class SchemaComparator( case AnySchema.Anything => Schema.Empty case AnySchema.Nothing => Schema.Nothing case s: Schema => deannotate(s) match { - case s@LocalRefSchema(name) => + case s@ReferenceSchema(LocalRef(name)) => def noSchema: Nothing = throw new NoSuchElementException(s"could not resolve schema reference ${s.$ref.get}") @@ -86,13 +86,9 @@ class SchemaComparator( } /** Matches a schema that is a pure reference to one of the component schemas */ - private object LocalRefSchema { + private object ReferenceSchema { def unapply(schema: Schema): Option[String] = - schema.$ref - .filter(ref => schema == Schema($ref = Some(ref))) - .collect { - case LocalRef(name) => name - } + schema.$ref.filter(ref => schema == Schema($ref = Some(ref))) } // strip fields which do not affect schema comparison @@ -122,10 +118,6 @@ class SchemaComparator( checkStringLengthBounds(writerSchema, readerSchema).toList ++ checkPattern(writerSchema, readerSchema).toList - } else if (isNullableSchema(readerSchema)) { - val ws = if (isNullableSchema(writerSchema)) writerSchema.anyOf.head else writerSchema - compare(ws, readerSchema.anyOf.head) - } else if (isProductSchema(writerSchema) && isProductSchema(readerSchema)) { // Even though the default value for `additionalProperties` is an empty schema that accepts everything, // we assume a Nothing schema for the writer because it's extremely unlikely in practice that a value for a @@ -145,24 +137,17 @@ class SchemaComparator( checkDependentRequired(writerSchema, readerSchema) ++ propIssues - } else if (isCoproductSchema(writerSchema) && isCoproductSchema(readerSchema) && - // if readerSchema does not have a discriminator, we fall back to GeneralSchemaMismatch - // TODO: support comparison of untagged unions - readerSchema.discriminator.nonEmpty - ) { + } else if (isDiscriminatedUnionSchema(writerSchema) && isDiscriminatedUnionSchema(readerSchema)) { val writerMapping = discriminatorMapping(writerSchema) val readerMapping = discriminatorMapping(readerSchema) - val variantIssues: List[SchemaCompatibilityIssue] = (writerMapping, readerMapping) match { - case (Some(wm), Some(rm)) => - (wm.keySet intersect rm.keySet).toList.flatMap { tag => - compare(wm(tag), rm(tag)) match { - case Nil => None - case issues => Some(IncompatibleDiscriminatorCase(tag, issues)) - } + val variantIssues: List[SchemaCompatibilityIssue] = + (writerMapping.keySet intersect readerMapping.keySet).toList.flatMap { tag => + compare(writerMapping(tag), readerMapping(tag)) match { + case Nil => None + case issues => Some(IncompatibleDiscriminatorCase(tag, issues)) } - case _ => Nil - } + } checkType(writerSchema, readerSchema).toList ++ checkDiscriminatorProp(writerSchema, readerSchema).toList ++ @@ -198,6 +183,29 @@ class SchemaComparator( checkPropertyNames(writerSchema, readerSchema).toList ++ checkMinMaxProperties(writerSchema, readerSchema).toList + } else if (isUnionSchema(writerSchema)) { + val variants = writerSchema.oneOf ++ writerSchema.anyOf + variants.zipWithIndex.flatMap { + case (variant, idx) => compare(variant, readerSchema) match { + case Nil => None + case issues => Some(IncompatibleUnionVariant(idx, issues)) + } + } + } else if (isUnionSchema(readerSchema)) { + val variants = readerSchema.oneOf ++ readerSchema.anyOf + + @tailrec def alternatives( + variants: List[SchemaLike], + acc: List[List[SchemaCompatibilityIssue]] + ): Option[AlternativeIssues] = variants match { + case Nil => Some(AlternativeIssues(acc.reverse)) + case variant :: tail => compare(writerSchema, variant) match { + case Nil => None + case issues => alternatives(tail, issues :: acc) + } + } + + alternatives(variants, Nil).toList } else if (readerSchema == Schema.Nothing) { List(NoValuesAllowed(writerSchema)) @@ -209,16 +217,6 @@ class SchemaComparator( List(GeneralSchemaMismatch(writerSchema, readerSchema)) } - /** - * Matches a schema that uses `anyOf` to make another, base schema nullable. - * Most of the time, the base schema is a reference. - */ - private def isNullableSchema(s: Schema): Boolean = - s == Schema(anyOf = s.anyOf) && (s.anyOf match { - case List(_, Schema.Null) => true - case _ => false - }) - /** Checks if schema is for a _primitive_ value, i.e. a string, boolean, number or null */ private def isPrimitiveSchema(s: Schema): Boolean = s.`type`.exists { types => @@ -282,32 +280,34 @@ class SchemaComparator( dependentRequired = s.dependentRequired, ) - // coproduct schema is a schema with `oneOf` or `anyOf` of pure references, with an optional discriminator object - private def isCoproductSchema(s: Schema): Boolean = + private def isUnionSchema(s: Schema): Boolean = // exactly one of `oneOf` and `anyOf` should be non-empty - (s.oneOf.nonEmpty != s.anyOf.nonEmpty) && - (s.oneOf ++ s.anyOf).forall { - case LocalRefSchema(_) => true - case _ => false - } && s == Schema( + (s.oneOf.nonEmpty != s.anyOf.nonEmpty) && s == Schema( oneOf = s.oneOf, anyOf = s.anyOf, discriminator = s.discriminator, ) - private def discriminatorMapping(schema: Schema): Option[ListMap[String, Schema]] = - schema.discriminator.map { disc => - // schema name -> overridden discriminator value - val reverseMapping = disc.mapping.getOrElse(ListMap.empty).collect { - case (discValue, LocalRef(name)) => name -> discValue - } - // assuming that schema is valid and every reference in disc.mapping is also an element of oneOf/anyOf - ListMap.empty ++ (schema.oneOf ++ schema.anyOf).collect { - case s@LocalRefSchema(name) => - val discValue = reverseMapping.getOrElse(name, name) - discValue -> s - } + private def isDiscriminatedUnionSchema(s: Schema): Boolean = + s.discriminator.nonEmpty && isUnionSchema(s) + + /** + * Returns a mapping from discriminator values to schemas associated with them. + * All the schemas in the result are references (local or non-local). + */ + private def discriminatorMapping(schema: Schema): ListMap[String, Schema] = { + // schema reference -> overridden discriminator value + val explicitDiscValueByRef = schema.discriminator.flatMap(_.mapping).getOrElse(ListMap.empty).map(_.swap) + // assuming that schema is valid and every reference in disc.mapping is also an element of oneOf/anyOf + ListMap.empty ++ (schema.oneOf ++ schema.anyOf).collect { + case s@ReferenceSchema(ref) => + val discValue = explicitDiscValueByRef.getOrElse(ref, ref match { + case LocalRef(name) => name + case _ => throw new NoSuchElementException(s"no discriminator value specified for non-local reference $ref") + }) + discValue -> s } + } private def getTypes(schema: Schema): Option[List[SchemaType]] = schema match { case Schema.Empty => Some(SchemaType.Values) @@ -482,29 +482,21 @@ class SchemaComparator( private def checkDiscriminatorProp( writerSchema: Schema, readerSchema: Schema - ): Option[DiscriminatorPropertyMismatch] = { - val writerDiscProp = writerSchema.discriminator.map(_.propertyName) - val readerDiscProp = readerSchema.discriminator.map(_.propertyName) - (writerDiscProp, readerDiscProp) match { - case (None, Some(readerProp)) => - Some(DiscriminatorPropertyMismatch(None, readerProp)) - case (Some(writerProp), Some(readerProp)) if writerProp != readerProp => - Some(DiscriminatorPropertyMismatch(writerDiscProp, readerProp)) - case _ => - None - } - } + ): Option[DiscriminatorPropertyMismatch] = for { + writerDiscProp <- writerSchema.discriminator.map(_.propertyName) + readerDiscProp <- readerSchema.discriminator.map(_.propertyName) + if writerDiscProp != readerDiscProp + } yield DiscriminatorPropertyMismatch(writerDiscProp, readerDiscProp) private def checkDiscriminatorValues( - writerMapping: Option[ListMap[String, SchemaLike]], - readerMapping: Option[ListMap[String, SchemaLike]], - ): Option[UnsupportedDiscriminatorValues] = - for { - wm <- writerMapping - rm <- readerMapping - unsupportedValues = wm.keySet -- rm.keySet - if unsupportedValues.nonEmpty - } yield UnsupportedDiscriminatorValues(unsupportedValues.toList) + writerMapping: ListMap[String, SchemaLike], + readerMapping: ListMap[String, SchemaLike], + ): Option[UnsupportedDiscriminatorValues] = { + val unsupportedValues = writerMapping.keySet -- readerMapping.keySet + if (unsupportedValues.nonEmpty) + Some(UnsupportedDiscriminatorValues(unsupportedValues.toList)) + else None + } private def checkPropertyNames( writerSchema: Schema, diff --git a/apispec-model/src/main/scala/sttp/apispec/validation/SchemaCompatibilityIssue.scala b/apispec-model/src/main/scala/sttp/apispec/validation/SchemaCompatibilityIssue.scala index 2f448fc..0b2868f 100644 --- a/apispec-model/src/main/scala/sttp/apispec/validation/SchemaCompatibilityIssue.scala +++ b/apispec-model/src/main/scala/sttp/apispec/validation/SchemaCompatibilityIssue.scala @@ -1,11 +1,11 @@ package sttp.apispec.validation -import sttp.apispec.{ExampleSingleValue, ExampleValue, Pattern, Schema, SchemaType} +import sttp.apispec.{ExampleValue, Pattern, Schema, SchemaType} -sealed abstract class SchemaCompatibilityIssue extends Product { +sealed abstract class SchemaCompatibilityIssue { def description: String - override def toString: String = s"$productPrefix($description)" + override def toString: String = description protected def pluralize(what: String, coll: Seq[Any]): String = if (coll.lengthCompare(1) == 0) s"$what ${coll.head}" @@ -13,6 +13,11 @@ sealed abstract class SchemaCompatibilityIssue extends Product { protected def form(coll: Seq[Any], singular: String, plural: String): String = if (coll.size == 1) singular else plural + + protected def issuesRepr(issues: List[SchemaCompatibilityIssue]): String = + issues.iterator + .map(i => s"- ${i.description.replace("\n", "\n ")}") // indent + .mkString("\n") } /** @@ -140,13 +145,11 @@ case class MissingDependentRequiredProperties( } case class DiscriminatorPropertyMismatch( - writerDiscriminator: Option[String], + writerDiscriminator: String, readerDiscriminator: String ) extends SchemaCompatibilityIssue { - def description: String = { - val writerDiscriminatorRepr = writerDiscriminator.fold("")(wd => s", as opposed to $wd") - s"target schema requires discriminator property $readerDiscriminator$writerDiscriminatorRepr" - } + def description: String = + s"target schema discriminator property $readerDiscriminator, as opposed to $writerDiscriminator" } case class UnsupportedDiscriminatorValues( @@ -161,11 +164,6 @@ case class UnsupportedDiscriminatorValues( */ sealed abstract class SubschemaCompatibilityIssue extends SchemaCompatibilityIssue { def subschemaIssues: List[SchemaCompatibilityIssue] - - protected def issuesRepr: String = - subschemaIssues.iterator - .map(i => s"- ${i.description.replace("\n", "\n ")}") - .mkString("\n") } case class IncompatibleProperty( @@ -173,7 +171,7 @@ case class IncompatibleProperty( subschemaIssues: List[SchemaCompatibilityIssue] ) extends SubschemaCompatibilityIssue { def description: String = - s"incompatible schema for property $property:\n$issuesRepr" + s"incompatible schema for property $property:\n${issuesRepr(subschemaIssues)}" } case class IncompatibleDiscriminatorCase( @@ -181,28 +179,28 @@ case class IncompatibleDiscriminatorCase( subschemaIssues: List[SchemaCompatibilityIssue] ) extends SubschemaCompatibilityIssue { def description: String = - s"incompatible schema for discriminator value $discriminatorValue:\n$issuesRepr" + s"incompatible schema for discriminator value $discriminatorValue:\n${issuesRepr(subschemaIssues)}" } case class IncompatibleAdditionalProperties( subschemaIssues: List[SchemaCompatibilityIssue] ) extends SubschemaCompatibilityIssue { def description: String = - s"incompatible schema for additional properties:\n$issuesRepr" + s"incompatible schema for additional properties:\n${issuesRepr(subschemaIssues)}" } case class IncompatiblePropertyNames( subschemaIssues: List[SchemaCompatibilityIssue] ) extends SubschemaCompatibilityIssue { override def description: String = - s"incompatible schema for property names:\n$issuesRepr" + s"incompatible schema for property names:\n${issuesRepr(subschemaIssues)}" } case class IncompatibleItems( subschemaIssues: List[SchemaCompatibilityIssue] ) extends SubschemaCompatibilityIssue { def description: String = - s"incompatible schema for items:\n$issuesRepr" + s"incompatible schema for items:\n${issuesRepr(subschemaIssues)}" } case class IncompatiblePrefixItem( @@ -210,5 +208,30 @@ case class IncompatiblePrefixItem( subschemaIssues: List[SchemaCompatibilityIssue] ) extends SubschemaCompatibilityIssue { def description: String = - s"incompatible schema for prefix item at index $index:\n$issuesRepr" + s"incompatible schema for prefix item at index $index:\n${issuesRepr(subschemaIssues)}" +} + +case class IncompatibleUnionVariant( + index: Int, + subschemaIssues: List[SchemaCompatibilityIssue] +) extends SubschemaCompatibilityIssue { + def description: String = + s"incompatible anyOf/oneOf variant at index $index:\n${issuesRepr(subschemaIssues)}" +} + +/** + * An issue raised when a schema is not compatible with any of the alternatives in a target union schema. + * + * @param alternatives a list of non-empty lists of issues, where each list corresponds to one of the alternatives + */ +case class AlternativeIssues( + alternatives: List[List[SchemaCompatibilityIssue]] +) extends SchemaCompatibilityIssue { + override def description: String = { + val alternativesReprs = alternatives.zipWithIndex.map { + case (issues, idx) => + s"for alternative $idx:\n${issuesRepr(issues)}" + } + s"schema is not compatible with any of the alternatives in oneOf/anyOf:\n${alternativesReprs.mkString("\n")}" + } } diff --git a/apispec-model/src/test/scala/sttp/apispec/validation/SchemaComparatorTest.scala b/apispec-model/src/test/scala/sttp/apispec/validation/SchemaComparatorTest.scala index 58c77c6..5289b0f 100644 --- a/apispec-model/src/test/scala/sttp/apispec/validation/SchemaComparatorTest.scala +++ b/apispec-model/src/test/scala/sttp/apispec/validation/SchemaComparatorTest.scala @@ -190,10 +190,10 @@ class SchemaComparatorTest extends AnyFunSuite { TypeMismatch(List(SchemaType.Null), List(SchemaType.String)) )) assert(compare(opaqueSchema.nullable, opaqueSchema) == List( - GeneralSchemaMismatch(opaqueSchema.nullable, opaqueSchema) //TODO better issue? + IncompatibleUnionVariant(1, List(GeneralSchemaMismatch(Schema.Null, opaqueSchema))) )) assert(compare(ref("String").nullable, ref("String")) == List( - GeneralSchemaMismatch(ref("String").nullable, stringSchema) //TODO better issue? + IncompatibleUnionVariant(1, List(TypeMismatch(List(SchemaType.Null), List(SchemaType.String)))) )) } @@ -455,7 +455,7 @@ class SchemaComparatorTest extends AnyFunSuite { discriminator = Some(Discriminator("type", None)) ), ) == List( - DiscriminatorPropertyMismatch(Some("kind"), "type") + DiscriminatorPropertyMismatch("kind", "type") )) } @@ -628,4 +628,27 @@ class SchemaComparatorTest extends AnyFunSuite { )) ) } + + test("compatible untagged union schemas") { + assert(compare( + Schema(anyOf = List(stringSchema, integerSchema)), + Schema(anyOf = List(stringSchema, numberSchema, booleanSchema)), + ) == Nil) + } + + test("incompatible untagged union schemas") { + assert(compare( + Schema(anyOf = List(stringSchema, numberSchema, booleanSchema)), + Schema(anyOf = List(stringSchema, integerSchema)), + ) == List( + IncompatibleUnionVariant(1, List(AlternativeIssues(List( + List(TypeMismatch(List(SchemaType.Number), List(SchemaType.String))), + List(TypeMismatch(List(SchemaType.Number), List(SchemaType.Integer))) + )))), + IncompatibleUnionVariant(2, List(AlternativeIssues(List( + List(TypeMismatch(List(SchemaType.Boolean), List(SchemaType.String))), + List(TypeMismatch(List(SchemaType.Boolean), List(SchemaType.Integer))) + )))) + )) + } }