Skip to content

Commit

Permalink
support for comparing untagged union schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
ghik committed Apr 24, 2024
1 parent 078f9a1 commit d2f97cc
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 98 deletions.
4 changes: 3 additions & 1 deletion apispec-model/src/main/scala/sttp/apispec/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ object Schema {
Schema($ref = Some(s"$prefix${$ref}"))
}

sealed abstract class SchemaType(val value: String)
sealed abstract class SchemaType(val value: String) {
override def toString: String = value
}
object SchemaType {
case object Boolean extends SchemaType("boolean")
case object Object extends SchemaType("object")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SchemaComparator(
case AnySchema.Anything => Schema.Empty
case AnySchema.Nothing => Schema.Nothing
case s: Schema => deannotate(s) match {
case s@LocalRefSchema(name) =>
case s@ReferenceSchema(LocalRef(name)) =>
def noSchema: Nothing =
throw new NoSuchElementException(s"could not resolve schema reference ${s.$ref.get}")

Expand All @@ -86,13 +86,9 @@ class SchemaComparator(
}

/** Matches a schema that is a pure reference to one of the component schemas */
private object LocalRefSchema {
private object ReferenceSchema {
def unapply(schema: Schema): Option[String] =
schema.$ref
.filter(ref => schema == Schema($ref = Some(ref)))
.collect {
case LocalRef(name) => name
}
schema.$ref.filter(ref => schema == Schema($ref = Some(ref)))
}

// strip fields which do not affect schema comparison
Expand Down Expand Up @@ -122,10 +118,6 @@ class SchemaComparator(
checkStringLengthBounds(writerSchema, readerSchema).toList ++
checkPattern(writerSchema, readerSchema).toList

} else if (isNullableSchema(readerSchema)) {
val ws = if (isNullableSchema(writerSchema)) writerSchema.anyOf.head else writerSchema
compare(ws, readerSchema.anyOf.head)

} else if (isProductSchema(writerSchema) && isProductSchema(readerSchema)) {
// Even though the default value for `additionalProperties` is an empty schema that accepts everything,
// we assume a Nothing schema for the writer because it's extremely unlikely in practice that a value for a
Expand All @@ -145,24 +137,17 @@ class SchemaComparator(
checkDependentRequired(writerSchema, readerSchema) ++
propIssues

} else if (isCoproductSchema(writerSchema) && isCoproductSchema(readerSchema) &&
// if readerSchema does not have a discriminator, we fall back to GeneralSchemaMismatch
// TODO: support comparison of untagged unions
readerSchema.discriminator.nonEmpty
) {
} else if (isDiscriminatedUnionSchema(writerSchema) && isDiscriminatedUnionSchema(readerSchema)) {
val writerMapping = discriminatorMapping(writerSchema)
val readerMapping = discriminatorMapping(readerSchema)

val variantIssues: List[SchemaCompatibilityIssue] = (writerMapping, readerMapping) match {
case (Some(wm), Some(rm)) =>
(wm.keySet intersect rm.keySet).toList.flatMap { tag =>
compare(wm(tag), rm(tag)) match {
case Nil => None
case issues => Some(IncompatibleDiscriminatorCase(tag, issues))
}
val variantIssues: List[SchemaCompatibilityIssue] =
(writerMapping.keySet intersect readerMapping.keySet).toList.flatMap { tag =>
compare(writerMapping(tag), readerMapping(tag)) match {
case Nil => None
case issues => Some(IncompatibleDiscriminatorCase(tag, issues))
}
case _ => Nil
}
}

checkType(writerSchema, readerSchema).toList ++
checkDiscriminatorProp(writerSchema, readerSchema).toList ++
Expand Down Expand Up @@ -198,6 +183,29 @@ class SchemaComparator(
checkPropertyNames(writerSchema, readerSchema).toList ++
checkMinMaxProperties(writerSchema, readerSchema).toList

} else if (isUnionSchema(writerSchema)) {
val variants = writerSchema.oneOf ++ writerSchema.anyOf
variants.zipWithIndex.flatMap {
case (variant, idx) => compare(variant, readerSchema) match {
case Nil => None
case issues => Some(IncompatibleUnionVariant(idx, issues))
}
}
} else if (isUnionSchema(readerSchema)) {
val variants = readerSchema.oneOf ++ readerSchema.anyOf

@tailrec def alternatives(
variants: List[SchemaLike],
acc: List[List[SchemaCompatibilityIssue]]
): Option[AlternativeIssues] = variants match {
case Nil => Some(AlternativeIssues(acc.reverse))
case variant :: tail => compare(writerSchema, variant) match {
case Nil => None
case issues => alternatives(tail, issues :: acc)
}
}

alternatives(variants, Nil).toList
} else if (readerSchema == Schema.Nothing) {
List(NoValuesAllowed(writerSchema))

Expand All @@ -209,16 +217,6 @@ class SchemaComparator(
List(GeneralSchemaMismatch(writerSchema, readerSchema))
}

/**
* Matches a schema that uses `anyOf` to make another, base schema nullable.
* Most of the time, the base schema is a reference.
*/
private def isNullableSchema(s: Schema): Boolean =
s == Schema(anyOf = s.anyOf) && (s.anyOf match {
case List(_, Schema.Null) => true
case _ => false
})

/** Checks if schema is for a _primitive_ value, i.e. a string, boolean, number or null */
private def isPrimitiveSchema(s: Schema): Boolean =
s.`type`.exists { types =>
Expand Down Expand Up @@ -282,32 +280,34 @@ class SchemaComparator(
dependentRequired = s.dependentRequired,
)

// coproduct schema is a schema with `oneOf` or `anyOf` of pure references, with an optional discriminator object
private def isCoproductSchema(s: Schema): Boolean =
private def isUnionSchema(s: Schema): Boolean =
// exactly one of `oneOf` and `anyOf` should be non-empty
(s.oneOf.nonEmpty != s.anyOf.nonEmpty) &&
(s.oneOf ++ s.anyOf).forall {
case LocalRefSchema(_) => true
case _ => false
} && s == Schema(
(s.oneOf.nonEmpty != s.anyOf.nonEmpty) && s == Schema(
oneOf = s.oneOf,
anyOf = s.anyOf,
discriminator = s.discriminator,
)

private def discriminatorMapping(schema: Schema): Option[ListMap[String, Schema]] =
schema.discriminator.map { disc =>
// schema name -> overridden discriminator value
val reverseMapping = disc.mapping.getOrElse(ListMap.empty).collect {
case (discValue, LocalRef(name)) => name -> discValue
}
// assuming that schema is valid and every reference in disc.mapping is also an element of oneOf/anyOf
ListMap.empty ++ (schema.oneOf ++ schema.anyOf).collect {
case s@LocalRefSchema(name) =>
val discValue = reverseMapping.getOrElse(name, name)
discValue -> s
}
private def isDiscriminatedUnionSchema(s: Schema): Boolean =
s.discriminator.nonEmpty && isUnionSchema(s)

/**
* Returns a mapping from discriminator values to schemas associated with them.
* All the schemas in the result are references (local or non-local).
*/
private def discriminatorMapping(schema: Schema): ListMap[String, Schema] = {
// schema reference -> overridden discriminator value
val explicitDiscValueByRef = schema.discriminator.flatMap(_.mapping).getOrElse(ListMap.empty).map(_.swap)
// assuming that schema is valid and every reference in disc.mapping is also an element of oneOf/anyOf
ListMap.empty ++ (schema.oneOf ++ schema.anyOf).collect {
case s@ReferenceSchema(ref) =>
val discValue = explicitDiscValueByRef.getOrElse(ref, ref match {
case LocalRef(name) => name
case _ => throw new NoSuchElementException(s"no discriminator value specified for non-local reference $ref")
})
discValue -> s
}
}

private def getTypes(schema: Schema): Option[List[SchemaType]] = schema match {
case Schema.Empty => Some(SchemaType.Values)
Expand Down Expand Up @@ -482,29 +482,21 @@ class SchemaComparator(
private def checkDiscriminatorProp(
writerSchema: Schema,
readerSchema: Schema
): Option[DiscriminatorPropertyMismatch] = {
val writerDiscProp = writerSchema.discriminator.map(_.propertyName)
val readerDiscProp = readerSchema.discriminator.map(_.propertyName)
(writerDiscProp, readerDiscProp) match {
case (None, Some(readerProp)) =>
Some(DiscriminatorPropertyMismatch(None, readerProp))
case (Some(writerProp), Some(readerProp)) if writerProp != readerProp =>
Some(DiscriminatorPropertyMismatch(writerDiscProp, readerProp))
case _ =>
None
}
}
): Option[DiscriminatorPropertyMismatch] = for {
writerDiscProp <- writerSchema.discriminator.map(_.propertyName)
readerDiscProp <- readerSchema.discriminator.map(_.propertyName)
if writerDiscProp != readerDiscProp
} yield DiscriminatorPropertyMismatch(writerDiscProp, readerDiscProp)

private def checkDiscriminatorValues(
writerMapping: Option[ListMap[String, SchemaLike]],
readerMapping: Option[ListMap[String, SchemaLike]],
): Option[UnsupportedDiscriminatorValues] =
for {
wm <- writerMapping
rm <- readerMapping
unsupportedValues = wm.keySet -- rm.keySet
if unsupportedValues.nonEmpty
} yield UnsupportedDiscriminatorValues(unsupportedValues.toList)
writerMapping: ListMap[String, SchemaLike],
readerMapping: ListMap[String, SchemaLike],
): Option[UnsupportedDiscriminatorValues] = {
val unsupportedValues = writerMapping.keySet -- readerMapping.keySet
if (unsupportedValues.nonEmpty)
Some(UnsupportedDiscriminatorValues(unsupportedValues.toList))
else None
}

private def checkPropertyNames(
writerSchema: Schema,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
package sttp.apispec.validation

import sttp.apispec.{ExampleSingleValue, ExampleValue, Pattern, Schema, SchemaType}
import sttp.apispec.{ExampleValue, Pattern, Schema, SchemaType}

sealed abstract class SchemaCompatibilityIssue extends Product {
sealed abstract class SchemaCompatibilityIssue {
def description: String

override def toString: String = s"$productPrefix($description)"
override def toString: String = description

protected def pluralize(what: String, coll: Seq[Any]): String =
if (coll.lengthCompare(1) == 0) s"$what ${coll.head}"
else s"${what}s ${coll.mkString(", ")}"

protected def form(coll: Seq[Any], singular: String, plural: String): String =
if (coll.size == 1) singular else plural

protected def issuesRepr(issues: List[SchemaCompatibilityIssue]): String =
issues.iterator
.map(i => s"- ${i.description.replace("\n", "\n ")}") // indent
.mkString("\n")
}

/**
Expand Down Expand Up @@ -140,13 +145,11 @@ case class MissingDependentRequiredProperties(
}

case class DiscriminatorPropertyMismatch(
writerDiscriminator: Option[String],
writerDiscriminator: String,
readerDiscriminator: String
) extends SchemaCompatibilityIssue {
def description: String = {
val writerDiscriminatorRepr = writerDiscriminator.fold("")(wd => s", as opposed to $wd")
s"target schema requires discriminator property $readerDiscriminator$writerDiscriminatorRepr"
}
def description: String =
s"target schema discriminator property $readerDiscriminator, as opposed to $writerDiscriminator"
}

case class UnsupportedDiscriminatorValues(
Expand All @@ -161,54 +164,74 @@ case class UnsupportedDiscriminatorValues(
*/
sealed abstract class SubschemaCompatibilityIssue extends SchemaCompatibilityIssue {
def subschemaIssues: List[SchemaCompatibilityIssue]

protected def issuesRepr: String =
subschemaIssues.iterator
.map(i => s"- ${i.description.replace("\n", "\n ")}")
.mkString("\n")
}

case class IncompatibleProperty(
property: String,
subschemaIssues: List[SchemaCompatibilityIssue]
) extends SubschemaCompatibilityIssue {
def description: String =
s"incompatible schema for property $property:\n$issuesRepr"
s"incompatible schema for property $property:\n${issuesRepr(subschemaIssues)}"
}

case class IncompatibleDiscriminatorCase(
discriminatorValue: String,
subschemaIssues: List[SchemaCompatibilityIssue]
) extends SubschemaCompatibilityIssue {
def description: String =
s"incompatible schema for discriminator value $discriminatorValue:\n$issuesRepr"
s"incompatible schema for discriminator value $discriminatorValue:\n${issuesRepr(subschemaIssues)}"
}

case class IncompatibleAdditionalProperties(
subschemaIssues: List[SchemaCompatibilityIssue]
) extends SubschemaCompatibilityIssue {
def description: String =
s"incompatible schema for additional properties:\n$issuesRepr"
s"incompatible schema for additional properties:\n${issuesRepr(subschemaIssues)}"
}

case class IncompatiblePropertyNames(
subschemaIssues: List[SchemaCompatibilityIssue]
) extends SubschemaCompatibilityIssue {
override def description: String =
s"incompatible schema for property names:\n$issuesRepr"
s"incompatible schema for property names:\n${issuesRepr(subschemaIssues)}"
}

case class IncompatibleItems(
subschemaIssues: List[SchemaCompatibilityIssue]
) extends SubschemaCompatibilityIssue {
def description: String =
s"incompatible schema for items:\n$issuesRepr"
s"incompatible schema for items:\n${issuesRepr(subschemaIssues)}"
}

case class IncompatiblePrefixItem(
index: Int,
subschemaIssues: List[SchemaCompatibilityIssue]
) extends SubschemaCompatibilityIssue {
def description: String =
s"incompatible schema for prefix item at index $index:\n$issuesRepr"
s"incompatible schema for prefix item at index $index:\n${issuesRepr(subschemaIssues)}"
}

case class IncompatibleUnionVariant(
index: Int,
subschemaIssues: List[SchemaCompatibilityIssue]
) extends SubschemaCompatibilityIssue {
def description: String =
s"incompatible anyOf/oneOf variant at index $index:\n${issuesRepr(subschemaIssues)}"
}

/**
* An issue raised when a schema is not compatible with any of the alternatives in a target union schema.
*
* @param alternatives a list of non-empty lists of issues, where each list corresponds to one of the alternatives
*/
case class AlternativeIssues(
alternatives: List[List[SchemaCompatibilityIssue]]
) extends SchemaCompatibilityIssue {
override def description: String = {
val alternativesReprs = alternatives.zipWithIndex.map {
case (issues, idx) =>
s"for alternative $idx:\n${issuesRepr(issues)}"
}
s"schema is not compatible with any of the alternatives in oneOf/anyOf:\n${alternativesReprs.mkString("\n")}"
}
}
Loading

0 comments on commit d2f97cc

Please sign in to comment.