Skip to content

Commit

Permalink
Reflow to be a for comprehension
Browse files Browse the repository at this point in the history
  • Loading branch information
blast-hardcheese committed Dec 26, 2023
1 parent d46e2ae commit 08e05c6
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 168 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1342,107 +1342,111 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
supportPackage: List[String],
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
) = {
val discriminators = parents.flatMap(_.discriminators)
val discriminatorNames = discriminators.map(_.propertyName).toSet
val parentOpt = if (parents.exists(s => s.discriminators.nonEmpty)) {
parents.headOption
} else {
None
}
val params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value))
) =
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
parentOpt =
if (parents.exists(s => s.discriminators.nonEmpty)) {
parents.headOption
} else {
None
}
params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value))

val terms = params.map(_.term)
terms = params.map(_.term)

val toStringMethod = if (params.exists(_.dataRedaction != DataVisible)) {
def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match {
case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()"
case _ => Lit.String("[redacted]")
}
toStringMethod =
if (params.exists(_.dataRedaction != DataVisible)) {
def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match {
case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()"
case _ => Lit.String("[redacted]")
}

val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(",")))
val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(",")))

List[Defn.Def](
q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${clsName}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}"
)
} else {
List.empty[Defn.Def]
}
List[Defn.Def](
q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${clsName}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}"
)
} else {
List.empty[Defn.Def]
}

val code = parentOpt
.fold(q"""case class ${Type.Name(clsName)}(..${terms}) { ..$toStringMethod }""")(parent =>
q"""case class ${Type.Name(clsName)}(..${terms}) extends ..${init"${Type.Name(parent.clsName)}(...$Nil)" :: parent.interfaces.map(a =>
init"${Type.Name(a)}(...$Nil)"
)} { ..$toStringMethod }"""
)
code = parentOpt
.fold(q"""case class ${Type.Name(clsName)}(..${terms}) { ..$toStringMethod }""")(parent =>
q"""case class ${Type.Name(clsName)}(..${terms}) extends ..${init"${Type.Name(parent.clsName)}(...$Nil)" :: parent.interfaces.map(a =>
init"${Type.Name(a)}(...$Nil)"
)} { ..$toStringMethod }"""
)

Target.pure(code)
}
} yield code

private def encodeModel(
clsName: String,
dtoPackage: List[String],
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
) = {
val discriminators = parents.flatMap(_.discriminators)
val discriminatorNames = discriminators.map(_.propertyName).toSet
val allParams = parents.reverse.flatMap(_.params) ++ selfParams
val (discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value))
val readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList
val typeName = Type.Name(clsName)
val encVal = {
def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) =
q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))"""

def encodeRequired(param: ProtocolParameter[ScalaLanguage]) =
q"""(${Lit.String(param.name.value)}, a.${Term.Name(param.term.name.value)}.asJson)"""

def encodeOptional(param: ProtocolParameter[ScalaLanguage]) = {
val name = Lit.String(param.name.value)
q"a.${Term.Name(param.term.name.value)}.fold(ifAbsent = None, ifPresent = value => Some($name -> value.asJson))"
}
) =
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams = parents.reverse.flatMap(_.params) ++ selfParams
(discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value))
readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList
typeName = Type.Name(clsName)
encVal = {
def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) =
q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))"""

def encodeRequired(param: ProtocolParameter[ScalaLanguage]) =
q"""(${Lit.String(param.name.value)}, a.${Term.Name(param.term.name.value)}.asJson)"""

def encodeOptional(param: ProtocolParameter[ScalaLanguage]) = {
val name = Lit.String(param.name.value)
q"a.${Term.Name(param.term.name.value)}.fold(ifAbsent = None, ifPresent = value => Some($name -> value.asJson))"
}

val (optional, pairs): (List[Term.Apply], List[Term.Tuple]) = params.partitionEither { param =>
val name = Lit.String(param.name.value)
param.propertyRequirement match {
case PropertyRequirement.Required | PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy =>
Right(encodeRequired(param))
case PropertyRequirement.Optional | PropertyRequirement.OptionalNullable =>
Left(encodeOptional(param))
case PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) =>
Left(encodeOptional(param))
case PropertyRequirement.Configured(PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy, _) =>
Right(encodeRequired(param))
case PropertyRequirement.Configured(PropertyRequirement.Optional, _) =>
Left(q"""a.${Term.Name(param.term.name.value)}.map(value => (${Lit.String(param.name.value)}, value.asJson))""")
val (optional, pairs): (List[Term.Apply], List[Term.Tuple]) = params.partitionEither { param =>
val name = Lit.String(param.name.value)
param.propertyRequirement match {
case PropertyRequirement.Required | PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy =>
Right(encodeRequired(param))
case PropertyRequirement.Optional | PropertyRequirement.OptionalNullable =>
Left(encodeOptional(param))
case PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) =>
Left(encodeOptional(param))
case PropertyRequirement.Configured(PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy, _) =>
Right(encodeRequired(param))
case PropertyRequirement.Configured(PropertyRequirement.Optional, _) =>
Left(q"""a.${Term.Name(param.term.name.value)}.map(value => (${Lit.String(param.name.value)}, value.asJson))""")
}
}
}

val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName))
val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})"
val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) =>
q"$acc ++ $field"
}
val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName))
val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})"
val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) =>
q"$acc ++ $field"
}

q"""
q"""
${circeVersion.encoderObjectCompanion}.instance[${Type.Name(clsName)}](a => _root_.io.circe.JsonObject.fromIterable($allFields))
"""
}
val (readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys =>
(
List(q"val readOnlyKeys = _root_.scala.Predef.Set[_root_.scala.Predef.String](..${roKeys.toList.map(Lit.String(_))})"),
encVal => q"$encVal.mapJsonObject(_.filterKeys(key => !(readOnlyKeys contains key)))"
)
}
}
(readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys =>
(
List(q"val readOnlyKeys = _root_.scala.Predef.Set[_root_.scala.Predef.String](..${roKeys.toList.map(Lit.String(_))})"),
encVal => q"$encVal.mapJsonObject(_.filterKeys(key => !(readOnlyKeys contains key)))"
)
}

Target.pure(Option(q"""
} yield Option(q"""
implicit val ${suffixClsName("encode", clsName)}: ${circeVersion.encoderObject}[${Type.Name(clsName)}] = {
..${readOnlyDefn};
${readOnlyFilter(encVal)}
}
"""))
}
""")

private def decodeModel(
clsName: String,
Expand All @@ -1451,13 +1455,14 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
)(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = {
val discriminators = parents.flatMap(_.discriminators)
val discriminatorNames = discriminators.map(_.propertyName).toSet
val allParams = parents.reverse.flatMap(_.params) ++ selfParams
val params = allParams.filterNot(param => discriminatorNames.contains(param.name.value))
val needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull)
val paramCount = params.length
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams = parents.reverse.flatMap(_.params) ++ selfParams
params = allParams.filterNot(param => discriminatorNames.contains(param.name.value))
needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull)
paramCount = params.length
presence <- Lt.selectTerm(NonEmptyList.ofInitLast(supportPackage, "Presence"))
decVal <-
if (paramCount == 0) {
Expand Down
Loading

0 comments on commit 08e05c6

Please sign in to comment.