From 08e05c6389d4ed630f49aa9458f2d37dc5ec2f37 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 25 Dec 2023 22:55:30 -0800 Subject: [PATCH] Reflow to be a for comprehension --- .../scala/circe/CirceProtocolGenerator.scala | 173 +++++++++--------- .../circe/CirceRefinedProtocolGenerator.scala | 173 +++++++++--------- 2 files changed, 178 insertions(+), 168 deletions(-) diff --git a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala index ed7d3ea81..eb88e252b 100644 --- a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala +++ b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceProtocolGenerator.scala @@ -1342,107 +1342,111 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa supportPackage: List[String], selfParams: List[ProtocolParameter[ScalaLanguage]], parents: List[SuperClass[ScalaLanguage]] = Nil - ) = { - val discriminators = parents.flatMap(_.discriminators) - val discriminatorNames = discriminators.map(_.propertyName).toSet - val parentOpt = if (parents.exists(s => s.discriminators.nonEmpty)) { - parents.headOption - } else { - None - } - val params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value)) + ) = + for { + () <- Target.pure(()) + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + parentOpt = + if (parents.exists(s => s.discriminators.nonEmpty)) { + parents.headOption + } else { + None + } + params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value)) - val terms = params.map(_.term) + terms = params.map(_.term) - val toStringMethod = if (params.exists(_.dataRedaction != DataVisible)) { - def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match { - case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()" - case _ => Lit.String("[redacted]") - } + toStringMethod = + if (params.exists(_.dataRedaction != DataVisible)) { + def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match { + case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()" + case _ => Lit.String("[redacted]") + } - val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(","))) + val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(","))) - List[Defn.Def]( - q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${clsName}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}" - ) - } else { - List.empty[Defn.Def] - } + List[Defn.Def]( + q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${clsName}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}" + ) + } else { + List.empty[Defn.Def] + } - val code = parentOpt - .fold(q"""case class ${Type.Name(clsName)}(..${terms}) { ..$toStringMethod }""")(parent => - q"""case class ${Type.Name(clsName)}(..${terms}) extends ..${init"${Type.Name(parent.clsName)}(...$Nil)" :: parent.interfaces.map(a => - init"${Type.Name(a)}(...$Nil)" - )} { ..$toStringMethod }""" - ) + code = parentOpt + .fold(q"""case class ${Type.Name(clsName)}(..${terms}) { ..$toStringMethod }""")(parent => + q"""case class ${Type.Name(clsName)}(..${terms}) extends ..${init"${Type.Name(parent.clsName)}(...$Nil)" :: parent.interfaces.map(a => + init"${Type.Name(a)}(...$Nil)" + )} { ..$toStringMethod }""" + ) - Target.pure(code) - } + } yield code private def encodeModel( clsName: String, dtoPackage: List[String], selfParams: List[ProtocolParameter[ScalaLanguage]], parents: List[SuperClass[ScalaLanguage]] = Nil - ) = { - val discriminators = parents.flatMap(_.discriminators) - val discriminatorNames = discriminators.map(_.propertyName).toSet - val allParams = parents.reverse.flatMap(_.params) ++ selfParams - val (discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value)) - val readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList - val typeName = Type.Name(clsName) - val encVal = { - def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) = - q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))""" - - def encodeRequired(param: ProtocolParameter[ScalaLanguage]) = - q"""(${Lit.String(param.name.value)}, a.${Term.Name(param.term.name.value)}.asJson)""" - - def encodeOptional(param: ProtocolParameter[ScalaLanguage]) = { - val name = Lit.String(param.name.value) - q"a.${Term.Name(param.term.name.value)}.fold(ifAbsent = None, ifPresent = value => Some($name -> value.asJson))" - } + ) = + for { + () <- Target.pure(()) + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + allParams = parents.reverse.flatMap(_.params) ++ selfParams + (discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value)) + readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList + typeName = Type.Name(clsName) + encVal = { + def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) = + q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))""" + + def encodeRequired(param: ProtocolParameter[ScalaLanguage]) = + q"""(${Lit.String(param.name.value)}, a.${Term.Name(param.term.name.value)}.asJson)""" + + def encodeOptional(param: ProtocolParameter[ScalaLanguage]) = { + val name = Lit.String(param.name.value) + q"a.${Term.Name(param.term.name.value)}.fold(ifAbsent = None, ifPresent = value => Some($name -> value.asJson))" + } - val (optional, pairs): (List[Term.Apply], List[Term.Tuple]) = params.partitionEither { param => - val name = Lit.String(param.name.value) - param.propertyRequirement match { - case PropertyRequirement.Required | PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy => - Right(encodeRequired(param)) - case PropertyRequirement.Optional | PropertyRequirement.OptionalNullable => - Left(encodeOptional(param)) - case PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) => - Left(encodeOptional(param)) - case PropertyRequirement.Configured(PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy, _) => - Right(encodeRequired(param)) - case PropertyRequirement.Configured(PropertyRequirement.Optional, _) => - Left(q"""a.${Term.Name(param.term.name.value)}.map(value => (${Lit.String(param.name.value)}, value.asJson))""") + val (optional, pairs): (List[Term.Apply], List[Term.Tuple]) = params.partitionEither { param => + val name = Lit.String(param.name.value) + param.propertyRequirement match { + case PropertyRequirement.Required | PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy => + Right(encodeRequired(param)) + case PropertyRequirement.Optional | PropertyRequirement.OptionalNullable => + Left(encodeOptional(param)) + case PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) => + Left(encodeOptional(param)) + case PropertyRequirement.Configured(PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy, _) => + Right(encodeRequired(param)) + case PropertyRequirement.Configured(PropertyRequirement.Optional, _) => + Left(q"""a.${Term.Name(param.term.name.value)}.map(value => (${Lit.String(param.name.value)}, value.asJson))""") + } } - } - val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName)) - val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})" - val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) => - q"$acc ++ $field" - } + val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName)) + val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})" + val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) => + q"$acc ++ $field" + } - q""" + q""" ${circeVersion.encoderObjectCompanion}.instance[${Type.Name(clsName)}](a => _root_.io.circe.JsonObject.fromIterable($allFields)) """ - } - val (readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys => - ( - List(q"val readOnlyKeys = _root_.scala.Predef.Set[_root_.scala.Predef.String](..${roKeys.toList.map(Lit.String(_))})"), - encVal => q"$encVal.mapJsonObject(_.filterKeys(key => !(readOnlyKeys contains key)))" - ) - } + } + (readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys => + ( + List(q"val readOnlyKeys = _root_.scala.Predef.Set[_root_.scala.Predef.String](..${roKeys.toList.map(Lit.String(_))})"), + encVal => q"$encVal.mapJsonObject(_.filterKeys(key => !(readOnlyKeys contains key)))" + ) + } - Target.pure(Option(q""" + } yield Option(q""" implicit val ${suffixClsName("encode", clsName)}: ${circeVersion.encoderObject}[${Type.Name(clsName)}] = { ..${readOnlyDefn}; ${readOnlyFilter(encVal)} } - """)) - } + """) private def decodeModel( clsName: String, @@ -1451,13 +1455,14 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa selfParams: List[ProtocolParameter[ScalaLanguage]], parents: List[SuperClass[ScalaLanguage]] = Nil )(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = { - val discriminators = parents.flatMap(_.discriminators) - val discriminatorNames = discriminators.map(_.propertyName).toSet - val allParams = parents.reverse.flatMap(_.params) ++ selfParams - val params = allParams.filterNot(param => discriminatorNames.contains(param.name.value)) - val needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull) - val paramCount = params.length for { + () <- Target.pure(()) + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + allParams = parents.reverse.flatMap(_.params) ++ selfParams + params = allParams.filterNot(param => discriminatorNames.contains(param.name.value)) + needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull) + paramCount = params.length presence <- Lt.selectTerm(NonEmptyList.ofInitLast(supportPackage, "Presence")) decVal <- if (paramCount == 0) { diff --git a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala index 10622f624..0994fae3d 100644 --- a/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala +++ b/modules/scala-support/src/main/scala/dev/guardrail/generators/scala/circe/CirceRefinedProtocolGenerator.scala @@ -1414,107 +1414,111 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator, supportPackage: List[String], selfParams: List[ProtocolParameter[ScalaLanguage]], parents: List[SuperClass[ScalaLanguage]] = Nil - ) = { - val discriminators = parents.flatMap(_.discriminators) - val discriminatorNames = discriminators.map(_.propertyName).toSet - val parentOpt = if (parents.exists(s => s.discriminators.nonEmpty)) { - parents.headOption - } else { - None - } - val params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value)) + ) = + for { + () <- Target.pure(()) + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + parentOpt = + if (parents.exists(s => s.discriminators.nonEmpty)) { + parents.headOption + } else { + None + } + params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value)) - val terms = params.map(_.term) + terms = params.map(_.term) - val toStringMethod = if (params.exists(_.dataRedaction != DataVisible)) { - def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match { - case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()" - case _ => Lit.String("[redacted]") - } + toStringMethod = + if (params.exists(_.dataRedaction != DataVisible)) { + def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match { + case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()" + case _ => Lit.String("[redacted]") + } - val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(","))) + val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(","))) - List[Defn.Def]( - q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${clsName}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}" - ) - } else { - List.empty[Defn.Def] - } + List[Defn.Def]( + q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${clsName}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}" + ) + } else { + List.empty[Defn.Def] + } - val code = parentOpt - .fold(q"""case class ${Type.Name(clsName)}(..${terms}) { ..$toStringMethod }""")(parent => - q"""case class ${Type.Name(clsName)}(..${terms}) extends ..${init"${Type.Name(parent.clsName)}(...$Nil)" :: parent.interfaces.map(a => - init"${Type.Name(a)}(...$Nil)" - )} { ..$toStringMethod }""" - ) + code = parentOpt + .fold(q"""case class ${Type.Name(clsName)}(..${terms}) { ..$toStringMethod }""")(parent => + q"""case class ${Type.Name(clsName)}(..${terms}) extends ..${init"${Type.Name(parent.clsName)}(...$Nil)" :: parent.interfaces.map(a => + init"${Type.Name(a)}(...$Nil)" + )} { ..$toStringMethod }""" + ) - Target.pure(code) - } + } yield code private def encodeModel( clsName: String, dtoPackage: List[String], selfParams: List[ProtocolParameter[ScalaLanguage]], parents: List[SuperClass[ScalaLanguage]] = Nil - ) = { - val discriminators = parents.flatMap(_.discriminators) - val discriminatorNames = discriminators.map(_.propertyName).toSet - val allParams = parents.reverse.flatMap(_.params) ++ selfParams - val (discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value)) - val readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList - val typeName = Type.Name(clsName) - val encVal = { - def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) = - q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))""" - - def encodeRequired(param: ProtocolParameter[ScalaLanguage]) = - q"""(${Lit.String(param.name.value)}, a.${Term.Name(param.term.name.value)}.asJson)""" - - def encodeOptional(param: ProtocolParameter[ScalaLanguage]) = { - val name = Lit.String(param.name.value) - q"a.${Term.Name(param.term.name.value)}.fold(ifAbsent = None, ifPresent = value => Some($name -> value.asJson))" - } + ) = + for { + () <- Target.pure(()) + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + allParams = parents.reverse.flatMap(_.params) ++ selfParams + (discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value)) + readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList + typeName = Type.Name(clsName) + encVal = { + def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) = + q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))""" + + def encodeRequired(param: ProtocolParameter[ScalaLanguage]) = + q"""(${Lit.String(param.name.value)}, a.${Term.Name(param.term.name.value)}.asJson)""" + + def encodeOptional(param: ProtocolParameter[ScalaLanguage]) = { + val name = Lit.String(param.name.value) + q"a.${Term.Name(param.term.name.value)}.fold(ifAbsent = None, ifPresent = value => Some($name -> value.asJson))" + } - val (optional, pairs): (List[Term.Apply], List[Term.Tuple]) = params.partitionEither { param => - val name = Lit.String(param.name.value) - param.propertyRequirement match { - case PropertyRequirement.Required | PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy => - Right(encodeRequired(param)) - case PropertyRequirement.Optional | PropertyRequirement.OptionalNullable => - Left(encodeOptional(param)) - case PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) => - Left(encodeOptional(param)) - case PropertyRequirement.Configured(PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy, _) => - Right(encodeRequired(param)) - case PropertyRequirement.Configured(PropertyRequirement.Optional, _) => - Left(q"""a.${Term.Name(param.term.name.value)}.map(value => (${Lit.String(param.name.value)}, value.asJson))""") + val (optional, pairs): (List[Term.Apply], List[Term.Tuple]) = params.partitionEither { param => + val name = Lit.String(param.name.value) + param.propertyRequirement match { + case PropertyRequirement.Required | PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy => + Right(encodeRequired(param)) + case PropertyRequirement.Optional | PropertyRequirement.OptionalNullable => + Left(encodeOptional(param)) + case PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) => + Left(encodeOptional(param)) + case PropertyRequirement.Configured(PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy, _) => + Right(encodeRequired(param)) + case PropertyRequirement.Configured(PropertyRequirement.Optional, _) => + Left(q"""a.${Term.Name(param.term.name.value)}.map(value => (${Lit.String(param.name.value)}, value.asJson))""") + } } - } - val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName)) - val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})" - val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) => - q"$acc ++ $field" - } + val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName)) + val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})" + val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) => + q"$acc ++ $field" + } - q""" + q""" ${circeVersion.encoderObjectCompanion}.instance[${Type.Name(clsName)}](a => _root_.io.circe.JsonObject.fromIterable($allFields)) """ - } - val (readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys => - ( - List(q"val readOnlyKeys = _root_.scala.Predef.Set[_root_.scala.Predef.String](..${roKeys.toList.map(Lit.String(_))})"), - encVal => q"$encVal.mapJsonObject(_.filterKeys(key => !(readOnlyKeys contains key)))" - ) - } + } + (readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys => + ( + List(q"val readOnlyKeys = _root_.scala.Predef.Set[_root_.scala.Predef.String](..${roKeys.toList.map(Lit.String(_))})"), + encVal => q"$encVal.mapJsonObject(_.filterKeys(key => !(readOnlyKeys contains key)))" + ) + } - Target.pure(Option(q""" + } yield Option(q""" implicit val ${suffixClsName("encode", clsName)}: ${circeVersion.encoderObject}[${Type.Name(clsName)}] = { ..${readOnlyDefn}; ${readOnlyFilter(encVal)} } - """)) - } + """) private def decodeModel( clsName: String, @@ -1523,13 +1527,14 @@ class CirceRefinedProtocolGenerator private (circeVersion: CirceModelGenerator, selfParams: List[ProtocolParameter[ScalaLanguage]], parents: List[SuperClass[ScalaLanguage]] = Nil )(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = { - val discriminators = parents.flatMap(_.discriminators) - val discriminatorNames = discriminators.map(_.propertyName).toSet - val allParams = parents.reverse.flatMap(_.params) ++ selfParams - val params = allParams.filterNot(param => discriminatorNames.contains(param.name.value)) - val needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull) - val paramCount = params.length for { + () <- Target.pure(()) + discriminators = parents.flatMap(_.discriminators) + discriminatorNames = discriminators.map(_.propertyName).toSet + allParams = parents.reverse.flatMap(_.params) ++ selfParams + params = allParams.filterNot(param => discriminatorNames.contains(param.name.value)) + needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull) + paramCount = params.length presence <- Lt.selectTerm(NonEmptyList.ofInitLast(supportPackage, "Presence")) decVal <- if (paramCount == 0) {