Skip to content

Commit

Permalink
Distinct DTO parameter names
Browse files Browse the repository at this point in the history
Resolves #1642
  • Loading branch information
blast-hardcheese committed Dec 26, 2023
1 parent e374f5f commit 2bd94c4
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 16 deletions.
1 change: 1 addition & 0 deletions modules/sample/src/main/resources/issues/issue222.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ definitions:
allOf:
- "$ref": "#/definitions/RequestFields"
- type: object
required: [same]
properties:
id:
type: string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassPare
import dev.guardrail.generators.scala.circe.CirceProtocolGenerator.WithValidations
import dev.guardrail.generators.scala.{ CirceModelGenerator, ScalaLanguage }
import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader }
import dev.guardrail.generators.syntax._
import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName }
import dev.guardrail.terms.framework.FrameworkTerms
import dev.guardrail.terms.protocol.PropertyRequirement
Expand Down Expand Up @@ -1353,7 +1354,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
} else {
None
}
params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value))
params <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams).map(_.filterNot(param => discriminatorNames.contains(param.term.name.value)))

terms = params.map(_.term)

Expand Down Expand Up @@ -1382,6 +1383,28 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa

} yield code

private def finalizeParams(params: List[ProtocolParameter[ScalaLanguage]]): Target[List[ProtocolParameter[ScalaLanguage]]] =
for {
reduced <- params
.groupBy(_.name)
.toList
.traverse { case (k, xs) =>
implicit val ord: Ordering[PropertyRequirement] = {
case (PropertyRequirement.Required, _) => -1
case (_, PropertyRequirement.Required) => 1
case _ => 0
}
xs.distinctBy(_.term.syntax).sortBy(_.propertyRequirement) match {
case Nil => Target.raiseUserError(s"Unexpectedly empty parameter group: ${xs}")
case x :: Nil => Target.pure((k, x))
case xs @ (x :: _) if xs.distinctBy(_.baseType.syntax).length == 1 => Target.pure((k, x))
case xs @ (x :: rest) => Target.raiseUserError(s"Type conflicts for ${x.name.value}: ${xs.flatMap(_.term.decltpe.map(_.syntax)).mkString(", ")}")
}
}
.map(_.toMap)
names = params.map(_.name).distinct
} yield names.flatMap(n => reduced.get(n))

private def encodeModel(
clsName: String,
dtoPackage: List[String],
Expand All @@ -1390,9 +1413,9 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
) =
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams = parents.reverse.flatMap(_.params) ++ selfParams
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
(discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value))
readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList
typeName = Type.Name(clsName)
Expand Down Expand Up @@ -1457,9 +1480,9 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
)(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = {
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams = parents.reverse.flatMap(_.params) ++ selfParams
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
params = allParams.filterNot(param => discriminatorNames.contains(param.name.value))
needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull)
paramCount = params.length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassPare
import dev.guardrail.generators.scala.circe.CirceProtocolGenerator.WithValidations
import dev.guardrail.generators.scala.{ CirceModelGenerator, CirceRefinedModelGenerator, ScalaLanguage }
import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader }
import dev.guardrail.generators.syntax._
import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName }
import dev.guardrail.terms.framework.FrameworkTerms
import dev.guardrail.terms.protocol._
Expand Down Expand Up @@ -1425,7 +1426,7 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
} else {
None
}
params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value))
params <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams).map(_.filterNot(param => discriminatorNames.contains(param.term.name.value)))

terms = params.map(_.term)

Expand Down Expand Up @@ -1454,6 +1455,28 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,

} yield code

private def finalizeParams(params: List[ProtocolParameter[ScalaLanguage]]): Target[List[ProtocolParameter[ScalaLanguage]]] =
for {
reduced <- params
.groupBy(_.name)
.toList
.traverse { case (k, xs) =>
implicit val ord: Ordering[PropertyRequirement] = {
case (PropertyRequirement.Required, _) => -1
case (_, PropertyRequirement.Required) => 1
case _ => 0
}
xs.distinctBy(_.term.syntax).sortBy(_.propertyRequirement) match {
case Nil => Target.raiseUserError(s"Unexpectedly empty parameter group: ${xs}")
case x :: Nil => Target.pure((k, x))
case xs @ (x :: _) if xs.distinctBy(_.baseType.syntax).length == 1 => Target.pure((k, x))
case xs @ (x :: rest) => Target.raiseUserError(s"Type conflicts for ${x.name.value}: ${xs.flatMap(_.term.decltpe.map(_.syntax)).mkString(", ")}")
}
}
.map(_.toMap)
names = params.map(_.name).distinct
} yield names.flatMap(n => reduced.get(n))

private def encodeModel(
clsName: String,
dtoPackage: List[String],
Expand All @@ -1462,9 +1485,9 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
) =
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams = parents.reverse.flatMap(_.params) ++ selfParams
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
(discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value))
readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList
typeName = Type.Name(clsName)
Expand Down Expand Up @@ -1529,9 +1552,9 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator,
)(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = {
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams = parents.reverse.flatMap(_.params) ++ selfParams
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
params = allParams.filterNot(param => discriminatorNames.contains(param.name.value))
needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull)
paramCount = params.length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import dev.guardrail.generators.ProtocolGenerator.{ WrapEnumSchema, wrapNumberEn
import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassParent }
import dev.guardrail.generators.scala.{ JacksonModelGenerator, ScalaLanguage }
import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader }
import dev.guardrail.generators.syntax._
import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName }
import dev.guardrail.terms.framework.FrameworkTerms
import dev.guardrail.terms.protocol.PropertyRequirement.{ Optional, RequiredNullable }
Expand Down Expand Up @@ -1240,10 +1241,11 @@ class JacksonProtocolGenerator private extends ProtocolTerms[ScalaLanguage, Targ
NonEmptyList.ofInitLast(supportPackage :+ "Presence", "OptionNonMissingDeserializer")
)
allTerms = selfParams ++ parents.flatMap(_.params)

discriminatorNames = parents.flatMap(_.discriminators).map(_.propertyName).toSet
params <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams).map(_.filterNot(param => discriminatorNames.contains(param.term.name.value)))
renderedClass = { // TODO: This logic should be reflowed. The scope and rebindings is due to a refactor where
// code from another dependent class was just copied in here wholesale.
val discriminatorNames = parents.flatMap(_.discriminators).map(_.propertyName).toSet
val params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value))
val terms = params.map(_.term)

val toStringMethod = if (params.exists(_.dataRedaction != DataVisible)) {
Expand Down Expand Up @@ -1309,6 +1311,28 @@ class JacksonProtocolGenerator private extends ProtocolTerms[ScalaLanguage, Targ
} yield renderedClass
}

private def finalizeParams(params: List[ProtocolParameter[ScalaLanguage]]): Target[List[ProtocolParameter[ScalaLanguage]]] =
for {
reduced <- params
.groupBy(_.name)
.toList
.traverse { case (k, xs) =>
implicit val ord: Ordering[PropertyRequirement] = {
case (PropertyRequirement.Required, _) => -1
case (_, PropertyRequirement.Required) => 1
case _ => 0
}
xs.distinctBy(_.term.syntax).sortBy(_.propertyRequirement) match {
case Nil => Target.raiseUserError(s"Unexpectedly empty parameter group: ${xs}")
case x :: Nil => Target.pure((k, x))
case xs @ (x :: _) if xs.distinctBy(_.baseType.syntax).length == 1 => Target.pure((k, x))
case xs @ (x :: rest) => Target.raiseUserError(s"Type conflicts for ${x.name.value}: ${xs.flatMap(_.term.decltpe.map(_.syntax)).mkString(", ")}")
}
}
.map(_.toMap)
names = params.map(_.name).distinct
} yield names.flatMap(n => reduced.get(n))

private def encodeModel(
clsName: String,
dtoPackage: List[String],
Expand Down

0 comments on commit 2bd94c4

Please sign in to comment.