diff --git a/modules/sample/src/main/resources/issues/issue222.yaml b/modules/sample/src/main/resources/issues/issue222.yaml index f538ac44c..39f357b4f 100644 --- a/modules/sample/src/main/resources/issues/issue222.yaml +++ b/modules/sample/src/main/resources/issues/issue222.yaml @@ -15,6 +15,7 @@ definitions: allOf: - "$ref": "#/definitions/RequestFields" - type: object + required: [same] properties: id: type: string diff --git a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala index eb88e252b..0594e1be1 100644 --- a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala +++ b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala @@ -19,6 +19,7 @@ import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassPare import dev.guardrail.generators.scala.circe.CirceProtocolGenerator.WithValidations import dev.guardrail.generators.scala.{ CirceModelGenerator, ScalaLanguage } import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader } +import dev.guardrail.generators.syntax._ import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName } import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol.PropertyRequirement @@ -1353,7 +1354,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa } else { None } - params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value)) + params <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams).map(_.filterNot(param => discriminatorNames.contains(param.term.name.value))) terms = params.map(_.term) @@ -1382,6 +1383,28 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa } yield code + private def finalizeParams(params: List[ProtocolParameter[ScalaLanguage]]): Target[List[ProtocolParameter[ScalaLanguage]]] = + for { + reduced <- params + .groupBy(_.name) + .toList + .traverse { case (k, xs) => + implicit val ord: Ordering[PropertyRequirement] = { + case (PropertyRequirement.Required, _) => -1 + case (_, PropertyRequirement.Required) => 1 + case _ => 0 + } + xs.distinctBy(_.term.syntax).sortBy(_.propertyRequirement) match { + case Nil => Target.raiseUserError(s"Unexpectedly empty parameter group: ${xs}") + case x :: Nil => Target.pure((k, x)) + case xs @ (x :: _) if xs.distinctBy(_.baseType.syntax).length == 1 => Target.pure((k, x)) + case xs @ (x :: rest) => Target.raiseUserError(s"Type conflicts for ${x.name.value}: ${xs.flatMap(_.term.decltpe.map(_.syntax)).mkString(", ")}") + } + } + .map(_.toMap) + names = params.map(_.name).distinct + } yield names.flatMap(n => reduced.get(n)) + private def encodeModel( clsName: String, dtoPackage: List[String], @@ -1390,9 +1413,9 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa ) = for { () <- Target.pure(()) - discriminators = parents.flatMap(_.discriminators) - discriminatorNames = discriminators.map(_.propertyName).toSet - allParams = parents.reverse.flatMap(_.params) ++ selfParams + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams) (discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value)) readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList typeName = Type.Name(clsName) @@ -1457,9 +1480,9 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa )(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = { for { () <- Target.pure(()) - discriminators = parents.flatMap(_.discriminators) - discriminatorNames = discriminators.map(_.propertyName).toSet - allParams = parents.reverse.flatMap(_.params) ++ selfParams + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams) params = allParams.filterNot(param => discriminatorNames.contains(param.name.value)) needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull) paramCount = params.length diff --git a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala index 0994fae3d..0153df8d9 100644 --- a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala +++ b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala @@ -20,6 +20,7 @@ import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassPare import dev.guardrail.generators.scala.circe.CirceProtocolGenerator.WithValidations import dev.guardrail.generators.scala.{ CirceModelGenerator, CirceRefinedModelGenerator, ScalaLanguage } import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader } +import dev.guardrail.generators.syntax._ import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName } import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol._ @@ -1425,7 +1426,7 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator, } else { None } - params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value)) + params <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams).map(_.filterNot(param => discriminatorNames.contains(param.term.name.value))) terms = params.map(_.term) @@ -1454,6 +1455,28 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator, } yield code + private def finalizeParams(params: List[ProtocolParameter[ScalaLanguage]]): Target[List[ProtocolParameter[ScalaLanguage]]] = + for { + reduced <- params + .groupBy(_.name) + .toList + .traverse { case (k, xs) => + implicit val ord: Ordering[PropertyRequirement] = { + case (PropertyRequirement.Required, _) => -1 + case (_, PropertyRequirement.Required) => 1 + case _ => 0 + } + xs.distinctBy(_.term.syntax).sortBy(_.propertyRequirement) match { + case Nil => Target.raiseUserError(s"Unexpectedly empty parameter group: ${xs}") + case x :: Nil => Target.pure((k, x)) + case xs @ (x :: _) if xs.distinctBy(_.baseType.syntax).length == 1 => Target.pure((k, x)) + case xs @ (x :: rest) => Target.raiseUserError(s"Type conflicts for ${x.name.value}: ${xs.flatMap(_.term.decltpe.map(_.syntax)).mkString(", ")}") + } + } + .map(_.toMap) + names = params.map(_.name).distinct + } yield names.flatMap(n => reduced.get(n)) + private def encodeModel( clsName: String, dtoPackage: List[String], @@ -1462,9 +1485,9 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator, ) = for { () <- Target.pure(()) - discriminators = parents.flatMap(_.discriminators) - discriminatorNames = discriminators.map(_.propertyName).toSet - allParams = parents.reverse.flatMap(_.params) ++ selfParams + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams) (discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value)) readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList typeName = Type.Name(clsName) @@ -1529,9 +1552,9 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator, )(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = { for { () <- Target.pure(()) - discriminators = parents.flatMap(_.discriminators) - discriminatorNames = discriminators.map(_.propertyName).toSet - allParams = parents.reverse.flatMap(_.params) ++ selfParams + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams) params = allParams.filterNot(param => discriminatorNames.contains(param.name.value)) needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull) paramCount = params.length diff --git a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/jackson/JacksonProtocolGenerator.scala b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/jackson/JacksonProtocolGenerator.scala index 23079b8c9..3f05f2e95 100644 --- a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/jackson/JacksonProtocolGenerator.scala +++ b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/jackson/JacksonProtocolGenerator.scala @@ -17,6 +17,7 @@ import dev.guardrail.generators.ProtocolGenerator.{ WrapEnumSchema, wrapNumberEn import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassParent } import dev.guardrail.generators.scala.{ JacksonModelGenerator, ScalaLanguage } import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader } +import dev.guardrail.generators.syntax._ import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName } import dev.guardrail.terms.framework.FrameworkTerms import dev.guardrail.terms.protocol.PropertyRequirement.{ Optional, RequiredNullable } @@ -1240,11 +1241,12 @@ class JacksonProtocolGenerator private extends ProtocolTerms[ScalaLanguage, Targ NonEmptyList.ofInitLast(supportPackage :+ "Presence", "OptionNonMissingDeserializer") ) allTerms = selfParams ++ parents.flatMap(_.params) + + discriminatorNames = parents.flatMap(_.discriminators).map(_.propertyName).toSet + params <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams).map(_.filterNot(param => discriminatorNames.contains(param.term.name.value))) renderedClass = { // TODO: This logic should be reflowed. The scope and rebindings is due to a refactor where // code from another dependent class was just copied in here wholesale. - val discriminatorNames = parents.flatMap(_.discriminators).map(_.propertyName).toSet - val params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value)) - val terms = params.map(_.term) + val terms = params.map(_.term) val toStringMethod = if (params.exists(_.dataRedaction != DataVisible)) { def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match { @@ -1309,6 +1311,28 @@ class JacksonProtocolGenerator private extends ProtocolTerms[ScalaLanguage, Targ } yield renderedClass } + private def finalizeParams(params: List[ProtocolParameter[ScalaLanguage]]): Target[List[ProtocolParameter[ScalaLanguage]]] = + for { + reduced <- params + .groupBy(_.name) + .toList + .traverse { case (k, xs) => + implicit val ord: Ordering[PropertyRequirement] = { + case (PropertyRequirement.Required, _) => -1 + case (_, PropertyRequirement.Required) => 1 + case _ => 0 + } + xs.distinctBy(_.term.syntax).sortBy(_.propertyRequirement) match { + case Nil => Target.raiseUserError(s"Unexpectedly empty parameter group: ${xs}") + case x :: Nil => Target.pure((k, x)) + case xs @ (x :: _) if xs.distinctBy(_.baseType.syntax).length == 1 => Target.pure((k, x)) + case xs @ (x :: rest) => Target.raiseUserError(s"Type conflicts for ${x.name.value}: ${xs.flatMap(_.term.decltpe.map(_.syntax)).mkString(", ")}") + } + } + .map(_.toMap) + names = params.map(_.name).distinct + } yield names.flatMap(n => reduced.get(n)) + private def encodeModel( clsName: String, dtoPackage: List[String],