From 1ecc8dfaecce605eb62d2e7334fb1ea30e8c23a2 Mon Sep 17 00:00:00 2001 From: Yoonjae Jeon Date: Tue, 15 Oct 2024 21:56:30 +0900 Subject: [PATCH] sqaush wip make scala2 works Delete some files Add more tests more tests scala3 writer fix polish clean up Add helper function scala3 macro readers fix bug test pass polish polish wip fix remove unused Use ListBuffer to preserve orders polish Keep existing macros for bincompat Keep formatting --- .../upickle/implicits/MacroImplicits.scala | 4 +- .../upickle/implicits/internal/Macros.scala | 7 +- .../upickle/implicits/internal/Macros2.scala | 577 ++++++++++++++++++ .../src-3/upickle/implicits/Readers.scala | 59 +- .../src-3/upickle/implicits/Writers.scala | 4 +- .../src-3/upickle/implicits/macros.scala | 380 +++++++++--- .../upickle/implicits/ObjectContexts.scala | 13 +- .../implicits/src/upickle/implicits/key.scala | 12 + upickle/test/src/upickle/FailureTests.scala | 8 + upickle/test/src/upickle/MacroTests.scala | 102 +++- 10 files changed, 1036 insertions(+), 130 deletions(-) create mode 100644 upickle/implicits/src-2/upickle/implicits/internal/Macros2.scala diff --git a/upickle/implicits/src-2/upickle/implicits/MacroImplicits.scala b/upickle/implicits/src-2/upickle/implicits/MacroImplicits.scala index 92ac78f93..c83d11ed4 100644 --- a/upickle/implicits/src-2/upickle/implicits/MacroImplicits.scala +++ b/upickle/implicits/src-2/upickle/implicits/MacroImplicits.scala @@ -43,7 +43,7 @@ trait MacroImplicits extends MacrosCommon { this: upickle.core.Types => def macroW[T]: Writer[T] = macro MacroImplicits.applyW[T] def macroRW[T]: ReadWriter[T] = macro MacroImplicits.applyRW[ReadWriter[T]] - def macroR0[T, M[_]]: Reader[T] = macro internal.Macros.macroRImpl[T, M] - def macroW0[T, M[_]]: Writer[T] = macro internal.Macros.macroWImpl[T, M] + def macroR0[T, M[_]]: Reader[T] = macro internal.Macros2.macroRImpl[T, M] + def macroW0[T, M[_]]: Writer[T] = macro internal.Macros2.macroWImpl[T, M] } diff --git 
a/upickle/implicits/src-2/upickle/implicits/internal/Macros.scala b/upickle/implicits/src-2/upickle/implicits/internal/Macros.scala index 52350838f..3236716c3 100644 --- a/upickle/implicits/src-2/upickle/implicits/internal/Macros.scala +++ b/upickle/implicits/src-2/upickle/implicits/internal/Macros.scala @@ -10,6 +10,11 @@ import upickle.implicits.{MacrosCommon, key} import language.higherKinds import language.existentials +/** + * This file is deprecated and remained here for binary compatibility. + * Please use upickle/implicits/src-2/upickle/implicits/internal/Macros2.scala instead. + */ + /** * Implementation of macros used by uPickle to serialize and deserialize * case classes automatically. You probably shouldn't need to use these @@ -177,7 +182,7 @@ object Macros { t.substituteTypes(typeParams, concrete) } else { - val TypeRef(pref, sym, _) = typeOf[Seq[Int]] + val TypeRef(pref, sym, args) = typeOf[Seq[Int]] import compat._ TypeRef(pref, sym, t.asInstanceOf[TypeRef].args) } diff --git a/upickle/implicits/src-2/upickle/implicits/internal/Macros2.scala b/upickle/implicits/src-2/upickle/implicits/internal/Macros2.scala new file mode 100644 index 000000000..0682adb92 --- /dev/null +++ b/upickle/implicits/src-2/upickle/implicits/internal/Macros2.scala @@ -0,0 +1,577 @@ +package upickle.implicits.internal + +import scala.annotation.{nowarn, StaticAnnotation} +import scala.language.experimental.macros +import compat._ + +import acyclic.file +import upickle.core.Annotator +import upickle.implicits.{MacrosCommon, flatten, key} +import language.higherKinds +import language.existentials + +/** + * Implementation of macros used by uPickle to serialize and deserialize + * case classes automatically. You probably shouldn't need to use these + * directly, since they are called implicitly when trying to read/write + * types you don't have a Reader/Writer in scope for. 
+ */ +@nowarn("cat=deprecation") +object Macros2 { + + trait DeriveDefaults[M[_]] { + val c: scala.reflect.macros.blackbox.Context + private def fail(tpe: c.Type, s: String) = c.abort(c.enclosingPosition, s) + + import c.universe._ + private def companionTree(tpe: c.Type): Tree = { + val companionSymbol = tpe.typeSymbol.companionSymbol + + if (companionSymbol == NoSymbol && tpe.typeSymbol.isClass) { + val clsSymbol = tpe.typeSymbol.asClass + val msg = "[error] The companion symbol could not be determined for " + + s"[[${clsSymbol.name}]]. This may be due to a bug in scalac (SI-7567) " + + "that arises when a case class within a function is upickle. As a " + + "workaround, move the declaration to the module-level." + fail(tpe, msg) + } else { + val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val pre = tpe.asInstanceOf[symTab.Type].prefix.asInstanceOf[Type] + c.universe.treeBuild.mkAttributedRef(pre, companionSymbol) + } + + } + + /** + * If a super-type is generic, find all the subtypes, but at the same time + * fill in all the generic type parameters that are based on the super-type's + * concrete type + */ + private def fleshedOutSubtypes(tpe: Type) = { + for{ + subtypeSym <- tpe.typeSymbol.asClass.knownDirectSubclasses.filter(!_.toString.contains("")) + if subtypeSym.isType + st = subtypeSym.asType.toType + baseClsArgs = st.baseType(tpe.typeSymbol).asInstanceOf[TypeRef].args + } yield { + tpe match{ + case ExistentialType(_, TypeRef(pre, sym, args)) => + st.substituteTypes(baseClsArgs.map(_.typeSymbol), args) + case ExistentialType(_, _) => st + case TypeRef(pre, sym, args) => + st.substituteTypes(baseClsArgs.map(_.typeSymbol), args) + } + } + } + + private def deriveObject(tpe: c.Type) = { + val mod = tpe.typeSymbol.asClass.module + val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val pre = tpe.asInstanceOf[symTab.Type].prefix.asInstanceOf[Type] + val mod2 = c.universe.treeBuild.mkAttributedRef(pre, mod) + + 
annotate(tpe)(wrapObject(mod2)) + + } + + private[upickle] def mergeTrait(tagKey: Option[String], subtrees: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree + + private[upickle] def derive(tpe: c.Type) = { + if (tpe.typeSymbol.asClass.isTrait || (tpe.typeSymbol.asClass.isAbstractClass && !tpe.typeSymbol.isJava)) { + val derived = deriveTrait(tpe) + derived + } + else if (tpe.typeSymbol.isModuleClass) deriveObject(tpe) + else deriveClass(tpe) + } + + private def deriveTrait(tpe: c.Type): c.universe.Tree = { + val clsSymbol = tpe.typeSymbol.asClass + + if (!clsSymbol.isSealed) { + fail(tpe, s"[error] The referenced trait [[${clsSymbol.name}]] must be sealed.") + }else if (clsSymbol.knownDirectSubclasses.filter(!_.toString.contains("")).isEmpty) { + val msg = + s"The referenced trait [[${clsSymbol.name}]] does not have any sub-classes. This may " + + "happen due to a limitation of scalac (SI-7046). To work around this, " + + "try manually specifying the sealed trait picklers as described in " + + "https://com-lihaoyi.github.io/upickle/#ManualSealedTraitPicklers" + fail(tpe, msg) + }else{ + val tagKey = customKey(clsSymbol) + val subTypes = fleshedOutSubtypes(tpe).toSeq.sortBy(_.typeSymbol.fullName) + // println("deriveTrait") + val subDerives = subTypes.map(subCls => q"implicitly[${typeclassFor(subCls)}]") + // println(Console.GREEN + "subDerives " + Console.RESET + subDrivess) + val merged = mergeTrait(tagKey, subDerives, subTypes, tpe) + merged + } + } + + private[upickle] def typeclass: c.WeakTypeTag[M[_]] + + private def typeclassFor(t: Type) = { + // println("typeclassFor " + weakTypeOf[M[_]](typeclass)) + + weakTypeOf[M[_]](typeclass) match { + case TypeRef(a, b, _) => + import compat._ + TypeRef(a, b, List(t)) + case ExistentialType(_, TypeRef(a, b, _)) => + import compat._ + TypeRef(a, b, List(t)) + case x => + println("Dunno Wad Dis Typeclazz Is " + x) + println(x) + println(x.getClass) + ??? 
+ } + } + + sealed trait Flatten + + object Flatten { + case class Class(companion: Tree, fields: List[Field], varArgs: Boolean) extends Flatten + case object Map extends Flatten + case object None extends Flatten + } + + case class Field( + name: String, + mappedName: String, + tpe: Type, + symbol: Symbol, + defaultValue: Option[Tree], + flatten: Flatten, + ) { + lazy val allFields: List[Field] = { + def loop(field: Field): List[Field] = + field.flatten match { + case Flatten.Class(_, fields, _) => fields.flatMap(loop) + case Flatten.Map => List(field) + case Flatten.None => List(field) + } + loop(this) + } + } + + private def getFields(tpe: c.Type): (c.Tree, List[Field], Boolean) = { + def applyTypeArguments(t: c.Type ): c.Type = { + val typeParams = tpe.typeSymbol.asClass.typeParams + val typeArguments = tpe.normalize.asInstanceOf[TypeRef].args + if (t.typeSymbol != definitions.RepeatedParamClass) { + t.substituteTypes(typeParams, typeArguments) + } else { + val TypeRef(pref, sym, _) = typeOf[Seq[Int]] + internal.typeRef(pref, sym, t.asInstanceOf[TypeRef].args) + } + } + + val companion = companionTree(tpe) + //tickle the companion members -- Not doing this leads to unexpected runtime behavior + //I wonder if there is an SI related to this? 
+ companion.tpe.members.foreach(_ => ()) + tpe.members.find(x => x.isMethod && x.asMethod.isPrimaryConstructor) match { + case None => fail(tpe, "Can't find primary constructor of " + tpe) + case Some(primaryConstructor) => + val params = primaryConstructor.asMethod.paramLists.flatten + val varArgs = params.lastOption.exists(_.typeSignature.typeSymbol == definitions.RepeatedParamClass) + val fields = params.zipWithIndex.map { case (param, i) => + val name = param.name.decodedName.toString + val mappedName = customKey(param).getOrElse(name) + val tpeOfField = applyTypeArguments(param.typeSignature) + val defaultValue = if (param.asTerm.isParamWithDefault) + Some(q"$companion.${TermName("apply$default$" + (i + 1))}") + else + None + val flatten = param.annotations.find(_.tree.tpe =:= typeOf[flatten]) match { + case Some(_) => + if (tpeOfField.typeSymbol == typeOf[collection.immutable.Map[_, _]].typeSymbol) Flatten.Map + else if (tpeOfField.typeSymbol.isClass && tpeOfField.typeSymbol.asClass.isCaseClass) { + val (nestedCompanion, fields, nestedVarArgs) = getFields(tpeOfField) + Flatten.Class(nestedCompanion, fields, nestedVarArgs) + } + else fail(tpeOfField, + s"""Invalid type for flattening: $tpeOfField. 
+ | Flatten only works on case classes and Maps""".stripMargin) + case None => + Flatten.None + } + Field(param.name.toString, mappedName, tpeOfField, param, defaultValue, flatten) + } + (companion, fields, varArgs) + } + } + + private def deriveClass(tpe: c.Type) = { + val (companion, fields, varArgs) = getFields(tpe) + // According to @retronym, this is necessary in order to force the + // default argument `apply$default$n` methods to be synthesized + companion.tpe.member(TermName("apply")).info + + val allFields = fields.flatMap(_.allFields) + validateFlattenAnnotation(allFields) + + val derive = + // Otherwise, reading and writing are kinda identical + wrapCaseN( + companion, + fields, + varArgs, + targetType = tpe, + ) + + annotate(tpe)(derive) + } + + private def validateFlattenAnnotation(fields: List[Field]): Unit = { + if (fields.count(_.flatten == Flatten.Map) > 1) { + fail(NoType, "Only one Map can be annotated with @upickle.implicits.flatten in the same level") + } + if (fields.map(_.mappedName).distinct.length != fields.length) { + fail(NoType, "There are multiple fields with the same key") + } + if (fields.exists(field => field.flatten == Flatten.Map && !(field.tpe <:< typeOf[Map[String, _]]))) { + fail(NoType, "The key type of a Map annotated with @flatten must be String.") + } + } + + /** If there is a sealed base class, annotate the derived tree in the JSON + * representation with a class label. 
+ */ + private def annotate(tpe: c.Type)(derived: c.universe.Tree) = { + val sealedParents = tpe.baseClasses.filter(_.asClass.isSealed) + + if (sealedParents.isEmpty) derived + else { + val tagKey = MacrosCommon.tagKeyFromParents( + tpe.typeSymbol.name.toString, + sealedParents, + customKey, + (_: c.Symbol).name.toString, + fail(tpe, _), + ) + + val sealedClassSymbol: Option[Symbol] = sealedParents.find(_ == tpe.typeSymbol) + val segments = + sealedClassSymbol.toList.map(_.fullName.split('.')) ++ + sealedParents + .flatMap(_.asClass.knownDirectSubclasses) + .map(_.fullName.split('.')) + + + // -1 because even if there is only one subclass, and so no name segments + // are needed to differentiate between them, we want to keep at least + // the rightmost name segment + val identicalSegmentCount = Range(0, segments.map(_.length).max - 1) + .takeWhile(i => segments.map(_.lift(i)).distinct.size == 1) + .length + + val tagValue = customKey(tpe.typeSymbol) + .getOrElse(TypeName(tpe.typeSymbol.fullName).decodedName.toString) + + val shortTagValue = customKey(tpe.typeSymbol) + .getOrElse( + TypeName( + tpe.typeSymbol.fullName.split('.').drop(identicalSegmentCount).mkString(".") + ).decodedName.toString + ) + + val tagKeyExpr = tagKey match { + case Some(v) => q"$v" + case None => q"${c.prefix}.tagName" + } + q"${c.prefix}.annotate($derived, $tagKeyExpr, $tagValue, $shortTagValue)" + } + } + + private def customKey(sym: c.Symbol): Option[String] = { + sym.annotations + .find(_.tpe == typeOf[key]) + .flatMap(_.scalaArgs.headOption) + .map{case Literal(Constant(s)) => s.toString} + } + + private[upickle] def serializeDefaults(sym: c.Symbol): Option[Boolean] = { + sym.annotations + .find(_.tpe == typeOf[upickle.implicits.serializeDefaults]) + .flatMap(_.scalaArgs.headOption) + .map{case Literal(Constant(s)) => s.asInstanceOf[Boolean]} + } + + private[upickle] def wrapObject(obj: Tree): Tree + + private[upickle] def wrapCaseN(companion: Tree, + fields: List[Field], + varargs: 
Boolean, + targetType: c.Type): Tree + } + + abstract class Reading[M[_]] extends DeriveDefaults[M] { + val c: scala.reflect.macros.blackbox.Context + import c.universe._ + def wrapObject(t: c.Tree) = q"new ${c.prefix}.SingletonReader($t)" + + def wrapCaseN(companion: c.universe.Tree, fields: List[Field], varargs: Boolean, targetType: c.Type): c.universe.Tree = { + val allowUnknownKeysAnnotation = targetType.typeSymbol + .annotations + .find(_.tree.tpe == typeOf[upickle.implicits.allowUnknownKeys]) + .flatMap(_.tree.children.tail.headOption) + .map { case Literal(Constant(b)) => b.asInstanceOf[Boolean] } + + val allFields = fields.flatMap(_.allFields).toArray.filter(_.flatten != Flatten.Map) + val (hasFlattenOnMap, valueTypeOfMap) = fields.flatMap(_.allFields).find(_.flatten == Flatten.Map) match { + case Some(f) => + val TypeRef(_, _, _ :: valueType :: Nil) = f.tpe + (true, valueType) + case None => (false, NoType) + } + val numberOfFields = allFields.length + val (localReaders, aggregates) = allFields.zipWithIndex.map { case (_, idx) => + (TermName(s"localReader$idx"), TermName(s"aggregated$idx")) + }.unzip + + val fieldToId = allFields.zipWithIndex.toMap + def constructClass(companion: c.universe.Tree, fields: List[Field], varargs: Boolean): c.universe.Tree = + q""" + $companion.apply( + ..${ + fields.map { field => + field.flatten match { + case Flatten.Class(c, f, v) => constructClass(c, f, v) + case Flatten.Map => + val termName = TermName(s"aggregatedMap") + q"$termName.toMap" + case Flatten.None => + val idx = fieldToId(field) + val termName = TermName(s"aggregated$idx") + if (field == fields.last && varargs) q"$termName:_*" + else q"$termName" + } + } + } + ) + """ + + q""" + ..${ + for (i <- allFields.indices) + yield q"private[this] lazy val ${localReaders(i)} = implicitly[${c.prefix}.Reader[${allFields(i).tpe}]]" + } + ..${ + if (hasFlattenOnMap) + List( + q"private[this] lazy val localReaderMap = implicitly[${c.prefix}.Reader[$valueTypeOfMap]]", + ) + 
else Nil + } + new ${c.prefix}.CaseClassReader[$targetType] { + override def visitObject(length: Int, jsonableKeys: Boolean, index: Int) = new ${if (numberOfFields <= 64) tq"_root_.upickle.implicits.CaseObjectContext[$targetType]" else tq"_root_.upickle.implicits.HugeCaseObjectContext[$targetType]"}(${numberOfFields}) { + ..${ + for (i <- allFields.indices) + yield q"private[this] var ${aggregates(i)}: ${allFields(i).tpe} = _" + } + ..${ + if (hasFlattenOnMap) + List( + q"private[this] lazy val aggregatedMap: scala.collection.mutable.ListBuffer[(String, $valueTypeOfMap)] = scala.collection.mutable.ListBuffer.empty", + ) + else Nil + } + + def storeAggregatedValue(currentIndex: Int, v: Any): Unit = currentIndex match { + case ..${ + for (i <- aggregates.indices) + yield cq"$i => ${aggregates(i)} = v.asInstanceOf[${allFields(i).tpe}]" + } + case ..${ + if (hasFlattenOnMap) + List(cq"-1 => aggregatedMap += currentKey -> v.asInstanceOf[$valueTypeOfMap]") + else Nil + } + case _ => throw new java.lang.IndexOutOfBoundsException(currentIndex.toString) + } + + def visitKeyValue(s: Any) = { + storeToMap = false + currentKey = ${c.prefix}.objectAttributeKeyReadMap(s.toString).toString + currentIndex = currentKey match { + case ..${ + for (i <- allFields.indices) + yield cq"${allFields(i).mappedName} => $i" + } + case _ => + ${ + (allowUnknownKeysAnnotation, hasFlattenOnMap) match { + case (_, true) => q"storeToMap = true; -1" + case (None, false) => + q""" + if (${ c.prefix }.allowUnknownKeys) -1 + else throw new _root_.upickle.core.Abort("Unknown Key: " + s.toString) + """ + case (Some(false), false) => q"""throw new _root_.upickle.core.Abort(" Unknown Key: " + s.toString)""" + case (Some(true), false) => q"-1" + } + } + } + } + + def visitEnd(index: Int) = { + ..${ + for(i <- allFields.indices if allFields(i).defaultValue.isDefined) + yield q"this.storeValueIfNotFound($i, ${allFields(i).defaultValue.get})" + } + + // Special-case 64 because java bit shifting ignores any 
RHS values above 63 + // https://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.19 + if (${ + if (numberOfFields <= 64) q"this.checkErrorMissingKeys(${if (numberOfFields == 64) -1 else (1L << numberOfFields) - 1})" + else q"this.checkErrorMissingKeys(${numberOfFields})" + }) { + this.errorMissingKeys(${numberOfFields}, ${allFields.map(_.mappedName)}) + } + + ${constructClass(companion, fields, varargs)} + } + + def subVisitor: _root_.upickle.core.Visitor[_, _] = currentIndex match { + case -1 => + ${ + if (hasFlattenOnMap) + q"localReaderMap" + else + q"_root_.upickle.core.NoOpVisitor" + } + case ..${ + for (i <- allFields.indices) + yield cq"$i => ${localReaders(i)} " + } + case _ => throw new java.lang.IndexOutOfBoundsException(currentIndex.toString) + } + } + } + """ + } + + override def mergeTrait(tagKey: Option[String], subtrees: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree = { + val tagKeyExpr = tagKey match { + case Some(v) => q"$v" + case None => q"${c.prefix}.tagName" + } + q"${c.prefix}.Reader.merge[$targetType]($tagKeyExpr, ..$subtrees)" + } + } + + abstract class Writing[M[_]] extends DeriveDefaults[M] { + val c: scala.reflect.macros.blackbox.Context + import c.universe._ + def wrapObject(obj: c.Tree) = q"new ${c.prefix}.SingletonWriter($obj)" + + def internal = q"${c.prefix}.Internal" + + def wrapCaseN(companion: c.universe.Tree, fields: List[Field], varargs: Boolean, targetType: c.Type): c.universe.Tree = { + def serDfltVals(field: Field) = { + val b: Option[Boolean] = serializeDefaults(field.symbol).orElse(serializeDefaults(targetType.typeSymbol)) + b match { + case Some(b) => q"${b}" + case None => q"${c.prefix}.serializeDefaults" + } + } + + def write(field: Field, outer: c.universe.Tree): List[c.universe.Tree] = { + val select = Select(outer, TermName(field.name)) + field.flatten match { + case Flatten.Class(_, fields, _) => + fields.flatMap(write(_, select)) + case Flatten.Map => + val TypeRef(_, _, _ :: valueType :: 
Nil) = field.tpe + q""" + $select.foreach { case (key, value) => + this.writeSnippetMappedName[R, $valueType]( + ctx, + key.toString, + implicitly[${c.prefix}.Writer[$valueType]], + value + ) + } + """ :: Nil + case Flatten.None => + val snippet = + q""" + this.writeSnippetMappedName[R, ${field.tpe}]( + ctx, + ${c.prefix}.objectAttributeKeyWriteMap(${field.mappedName}), + implicitly[${c.prefix}.Writer[${field.tpe}]], + $select + ) + """ + val default = if (field.defaultValue.isEmpty) snippet + else q"""if (${serDfltVals(field)} || $select != ${field.defaultValue.get}) $snippet""" + default :: Nil + } + } + + def getLength(field: Field, outer: c.universe.Tree): List[c.universe.Tree] = { + val select = Select(outer, TermName(field.name)) + field.flatten match { + case Flatten.Class(_, fields, _) => fields.flatMap(getLength(_, select)) + case Flatten.Map => q"${select}.size" :: Nil + case Flatten.None => + ( + if (field.defaultValue.isEmpty) q"1" + else q"""if (${serDfltVals(field)} || ${select} != ${field.defaultValue}.get) 1 else 0""" + ) :: Nil + } + } + + q""" + new ${c.prefix}.CaseClassWriter[$targetType]{ + def length(v: $targetType) = { + ${ + fields.flatMap(getLength(_, q"v")) + .foldLeft[Tree](q"0") { case (prev, next) => q"$prev + $next" } + } + } + override def write0[R](out: _root_.upickle.core.Visitor[_, R], v: $targetType): R = { + if (v == null) out.visitNull(-1) + else { + val ctx = out.visitObject(length(v), true, -1) + ..${fields.flatMap(write(_, q"v"))} + ctx.visitEnd(-1) + } + } + def writeToObject[R](ctx: _root_.upickle.core.ObjVisitor[_, R], + v: $targetType): Unit = { + ..${fields.flatMap(write(_, q"v"))} + } + } + """ + } + + override def mergeTrait(tagKey: Option[String], subtree: Seq[Tree], subtypes: Seq[Type], targetType: c.Type): Tree = { + q"${c.prefix}.Writer.merge[$targetType](..$subtree)" + } + } + def macroRImpl[T, R[_]](c0: scala.reflect.macros.blackbox.Context) + (implicit e1: c0.WeakTypeTag[T], e2: c0.WeakTypeTag[R[_]]): 
c0.Expr[R[T]] = { + import c0.universe._ + val res = new Reading[R]{ + val c: c0.type = c0 + def typeclass = e2 + }.derive(e1.tpe) +// println(c0.universe.showCode(res)) + c0.Expr[R[T]](res) + } + + def macroWImpl[T, W[_]](c0: scala.reflect.macros.blackbox.Context) + (implicit e1: c0.WeakTypeTag[T], e2: c0.WeakTypeTag[W[_]]): c0.Expr[W[T]] = { + import c0.universe._ + val res = new Writing[W]{ + val c: c0.type = c0 + def typeclass = e2 + }.derive(e1.tpe) +// println(c0.universe.showCode(res)) + c0.Expr[W[T]](res) + } +} + diff --git a/upickle/implicits/src-3/upickle/implicits/Readers.scala b/upickle/implicits/src-3/upickle/implicits/Readers.scala index 7540cbbe8..aa93f4067 100644 --- a/upickle/implicits/src-3/upickle/implicits/Readers.scala +++ b/upickle/implicits/src-3/upickle/implicits/Readers.scala @@ -5,6 +5,7 @@ import deriving.Mirror import scala.util.NotGiven import upickle.core.{Annotator, ObjVisitor, Visitor, Abort, CurrentlyDeriving} import upickle.implicits.BaseCaseObjectContext +import scala.collection.mutable trait ReadersVersionSpecific extends MacrosCommon @@ -15,28 +16,46 @@ trait ReadersVersionSpecific abstract class CaseClassReader3[T](paramCount: Int, missingKeyCount: Long, allowUnknownKeys: Boolean, - construct: Array[Any] => T) extends CaseClassReader[T] { + construct: (Array[Any], scala.collection.mutable.Map[String, Any]) => T) extends CaseClassReader[T] { - def visitors0: Product - lazy val visitors = visitors0 - def fromProduct(p: Product): T + def visitors0: (AnyRef, Array[AnyRef]) + lazy val (visitorMap, visitors) = visitors0 + lazy val hasFlattenOnMap = visitorMap ne null def keyToIndex(x: String): Int def allKeysArray: Array[String] def storeDefaults(x: upickle.implicits.BaseCaseObjectContext): Unit trait ObjectContext extends ObjVisitor[Any, T] with BaseCaseObjectContext{ private val params = new Array[Any](paramCount) - - def storeAggregatedValue(currentIndex: Int, v: Any): Unit = params(currentIndex) = v + private val map = 
scala.collection.mutable.Map.empty[String, Any] + + def storeAggregatedValue(currentIndex: Int, v: Any): Unit = + if (currentIndex == -1) { + if (storeToMap) { + map(currentKey) = v + } + } else { + params(currentIndex) = v + } def subVisitor: Visitor[_, _] = - if (currentIndex == -1) upickle.core.NoOpVisitor - else visitors.productElement(currentIndex).asInstanceOf[Visitor[_, _]] + if (currentIndex == -1) { + if (hasFlattenOnMap) visitorMap.asInstanceOf[Visitor[_, _]] + else upickle.core.NoOpVisitor + } + else { + visitors(currentIndex).asInstanceOf[Visitor[_, _]] + } def visitKeyValue(v: Any): Unit = - val k = objectAttributeKeyReadMap(v.toString).toString - currentIndex = keyToIndex(k) - if (currentIndex == -1 && !allowUnknownKeys) { - throw new upickle.core.Abort("Unknown Key: " + k.toString) + storeToMap = false + currentKey = objectAttributeKeyReadMap(v.toString).toString + currentIndex = keyToIndex(currentKey) + if (currentIndex == -1) { + if (hasFlattenOnMap) { + storeToMap = true + } else if (!allowUnknownKeys) { + throw new upickle.core.Abort("Unknown Key: " + currentKey.toString) + } } def visitEnd(index: Int): T = @@ -47,7 +66,7 @@ trait ReadersVersionSpecific if (this.checkErrorMissingKeys(missingKeyCount)) this.errorMissingKeys(paramCount, allKeysArray) - construct(params) + construct(params, map) } override def visitObject(length: Int, jsonableKeys: Boolean, @@ -58,16 +77,18 @@ trait ReadersVersionSpecific inline def macroR[T](using m: Mirror.Of[T]): Reader[T] = inline m match { case m: Mirror.ProductOf[T] => + macros.validateFlattenAnnotation[T]() + val paramCount = macros.paramsCount[T] val reader = new CaseClassReader3[T]( - macros.paramsCount[T], - macros.checkErrorMissingKeysCount[T](), + paramCount, + if (paramCount <= 64) if (paramCount == 64) -1 else (1L << paramCount) - 1 + else paramCount, macros.extractIgnoreUnknownKeys[T]().headOption.getOrElse(this.allowUnknownKeys), - params => macros.applyConstructor[T](params) + (params: Array[Any], 
map :scala.collection.mutable.Map[String ,Any]) => macros.applyConstructor[T](params, map) ){ - override def visitors0 = compiletime.summonAll[Tuple.Map[m.MirroredElemTypes, Reader]] - override def fromProduct(p: Product): T = m.fromProduct(p) + override def visitors0 = macros.allReaders[T, Reader] override def keyToIndex(x: String): Int = macros.keyToIndex[T](x) - override def allKeysArray = macros.fieldLabels[T].map(_._2).toArray + override def allKeysArray = macros.allFieldsMappedName[T].toArray override def storeDefaults(x: upickle.implicits.BaseCaseObjectContext): Unit = macros.storeDefaults[T](x) } diff --git a/upickle/implicits/src-3/upickle/implicits/Writers.scala b/upickle/implicits/src-3/upickle/implicits/Writers.scala index 9cdae09d1..db2791446 100644 --- a/upickle/implicits/src-3/upickle/implicits/Writers.scala +++ b/upickle/implicits/src-3/upickle/implicits/Writers.scala @@ -23,7 +23,7 @@ trait WritersVersionSpecific if (v == null) out.visitNull(-1) else { val ctx = out.visitObject(length(v), true, -1) - macros.writeSnippets[R, T, Tuple.Map[m.MirroredElemTypes, Writer]]( + macros.writeSnippets[R, T, Writer]( outerThis, this, v, @@ -34,7 +34,7 @@ trait WritersVersionSpecific } def writeToObject[R](ctx: _root_.upickle.core.ObjVisitor[_, R], v: T): Unit = - macros.writeSnippets[R, T, Tuple.Map[m.MirroredElemTypes, Writer]]( + macros.writeSnippets[R, T, Writer]( outerThis, this, v, diff --git a/upickle/implicits/src-3/upickle/implicits/macros.scala b/upickle/implicits/src-3/upickle/implicits/macros.scala index 50d50ee9d..0a65eba9c 100644 --- a/upickle/implicits/src-3/upickle/implicits/macros.scala +++ b/upickle/implicits/src-3/upickle/implicits/macros.scala @@ -3,11 +3,13 @@ package upickle.implicits.macros import scala.quoted.{ given, _ } import deriving._, compiletime._ import upickle.implicits.{MacrosCommon, ReadersVersionSpecific} -type IsInt[A <: Int] = A def getDefaultParamsImpl0[T](using Quotes, Type[T]): Map[String, Expr[AnyRef]] = import 
quotes.reflect._ - val unwrapped = TypeRepr.of[T] match{case AppliedType(p, v) => p case t => t} + val unwrapped = TypeRepr.of[T] match { + case AppliedType(p, v) => p + case t => t + } val sym = unwrapped.typeSymbol if (!sym.isClassDef) Map.empty @@ -60,27 +62,109 @@ def extractIgnoreUnknownKeysImpl[T](using Quotes, Type[T]): Expr[List[Boolean]] .toList ) +def extractFlatten[A](using Quotes)(sym: quotes.reflect.Symbol): Boolean = + import quotes.reflect._ + sym + .annotations + .exists(_.tpe =:= TypeRepr.of[upickle.implicits.flatten]) + inline def paramsCount[T]: Int = ${paramsCountImpl[T]} def paramsCountImpl[T](using Quotes, Type[T]) = { - Expr(fieldLabelsImpl0[T].size) + import quotes.reflect._ + val fields = allFields[T] + val count = fields.filter {case (_, _, _, _, flattenMap) => !flattenMap}.length + Expr(count) +} + +inline def allReaders[T, R[_]]: (AnyRef, Array[AnyRef]) = ${allReadersImpl[T, R]} +def allReadersImpl[T, R[_]](using Quotes, Type[T], Type[R]): Expr[(AnyRef, Array[AnyRef])] = { + import quotes.reflect._ + val fields = allFields[T] + val (readerMap, readers) = fields.partitionMap { case (_, _, tpe, _, isFlattenMap) => + if (isFlattenMap) { + val valueTpe = tpe.typeArgs(1) + val readerTpe = TypeRepr.of[R].appliedTo(valueTpe) + val reader = readerTpe.asType match { + case '[t] => '{summonInline[t].asInstanceOf[AnyRef]} + } + Left(reader) + } + else { + val readerTpe = TypeRepr.of[R].appliedTo(tpe) + val reader = readerTpe.asType match { + case '[t] => '{summonInline[t].asInstanceOf[AnyRef]} + } + Right(reader) + } + } + Expr.ofTuple( + ( + readerMap.headOption.getOrElse('{null}.asInstanceOf[Expr[AnyRef]]), + '{${Expr.ofList(readers)}.toArray}, + ) + ) +} + +inline def allFieldsMappedName[T]: List[String] = ${allFieldsMappedNameImpl[T]} +def allFieldsMappedNameImpl[T](using Quotes, Type[T]): Expr[List[String]] = { + import quotes.reflect._ + Expr(allFields[T].map { case (_, label, _, _, _) => label }) } inline def storeDefaults[T](inline x: 
upickle.implicits.BaseCaseObjectContext): Unit = ${storeDefaultsImpl[T]('x)} def storeDefaultsImpl[T](x: Expr[upickle.implicits.BaseCaseObjectContext])(using Quotes, Type[T]) = { import quotes.reflect.* - - val statements = fieldLabelsImpl0[T] + val statements = allFields[T] + .filter(!_._5) .zipWithIndex - .map { case ((rawLabel, label), i) => - val defaults = getDefaultParamsImpl0[T] - if (defaults.contains(label)) '{${x}.storeValueIfNotFound(${Expr(i)}, ${defaults(label)})} - else '{} + .map { case ((_, _, _, default, _), i) => + default match { + case Some(defaultValue) => '{${x}.storeValueIfNotFound(${Expr(i)}, ${defaultValue})} + case None => '{} + } } Expr.block(statements, '{}) } -inline def fieldLabels[T]: List[(String, String)] = ${fieldLabelsImpl[T]} +def allFields[T](using Quotes, Type[T]): List[(quotes.reflect.Symbol, String, quotes.reflect.TypeRepr, Option[Expr[Any]], Boolean)] = { + import quotes.reflect._ + + def loop(field: Symbol, label: String, classTypeRepr: TypeRepr, defaults: Map[String, Expr[Object]]): List[(Symbol, String, TypeRepr, Option[Expr[Any]], Boolean)] = { + val flatten = extractFlatten(field) + val substitutedTypeRepr = substituteTypeArgs(classTypeRepr, subsitituted = classTypeRepr.memberType(field)) + val typeSymbol = substitutedTypeRepr.typeSymbol + if (flatten) { + if (isMap(substitutedTypeRepr)) { + (field, label, substitutedTypeRepr, defaults.get(label), true) :: Nil + } + else if (isCaseClass(typeSymbol)) { + typeSymbol.typeRef.dealias.asType match { + case '[t] => + fieldLabelsImpl0[t] + .flatMap { case (rawLabel, label) => + val newDefaults = getDefaultParamsImpl0[t] + val newClassTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, newClassTypeRepr, newDefaults) + } + case _ => + report.errorAndAbort(s"Unsupported type $typeSymbol for flattening") + } + } else report.errorAndAbort(s"${typeSymbol} is not a case class or a immutable.Map") + } + else { + (field, label, substitutedTypeRepr, defaults.get(label), false) :: Nil + } 
+ } + + fieldLabelsImpl0[T] + .flatMap{ (rawLabel, label) => + val defaults = getDefaultParamsImpl0[T] + val classTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, classTypeRepr, defaults) + } +} + def fieldLabelsImpl0[T](using Quotes, Type[T]): List[(quotes.reflect.Symbol, String)] = import quotes.reflect._ val fields: List[Symbol] = TypeRepr.of[T].typeSymbol @@ -96,16 +180,14 @@ def fieldLabelsImpl0[T](using Quotes, Type[T]): List[(quotes.reflect.Symbol, Str case None => (sym, sym.name) } -def fieldLabelsImpl[T](using Quotes, Type[T]): Expr[List[(String, String)]] = - Expr.ofList(fieldLabelsImpl0[T].map((a, b) => Expr((a.name, b)))) - inline def keyToIndex[T](inline x: String): Int = ${keyToIndexImpl[T]('x)} def keyToIndexImpl[T](x: Expr[String])(using Quotes, Type[T]): Expr[Int] = { import quotes.reflect.* + val fields = allFields[T].filter { case (_, _, _, _, isFlattenMap) => !isFlattenMap } val z = Match( x.asTerm, - fieldLabelsImpl0[T].map(_._2).zipWithIndex.map{(f, i) => - CaseDef(Literal(StringConstant(f)), None, Literal(IntConstant(i))) + fields.zipWithIndex.map{case ((_, label, _, _, _), i) => + CaseDef(Literal(StringConstant(label)), None, Literal(IntConstant(i))) } ++ Seq( CaseDef(Wildcard(), None, Literal(IntConstant(-1))) ) @@ -126,71 +208,138 @@ def serDfltVals(using quotes: Quotes)(thisOuter: Expr[upickle.core.Types with up case None => '{ ${ thisOuter }.serializeDefaults } } } + def writeLengthImpl[T](thisOuter: Expr[upickle.core.Types with upickle.implicits.MacrosCommon], v: Expr[T]) (using quotes: Quotes, t: Type[T]): Expr[Int] = import quotes.reflect.* + def loop(field: Symbol, label: String, classTypeRepr: TypeRepr, select: Select, defaults: Map[String, Expr[Object]]): List[Expr[Int]] = + val flatten = extractFlatten(field) + if (flatten) { + val subsitituted = substituteTypeArgs(classTypeRepr, subsitituted = classTypeRepr.memberType(field)) + val typeSymbol = subsitituted.typeSymbol + if (isMap(subsitituted)) { + List( + 
'{${select.asExprOf[Map[_, _]]}.size} + ) + } + else if (isCaseClass(typeSymbol)) { + typeSymbol.typeRef.dealias.asType match { + case '[t] => + fieldLabelsImpl0[t] + .flatMap { case (rawLabel, label) => + val newDefaults = getDefaultParamsImpl0[t] + val newSelect = Select.unique(select, rawLabel.name) + val newClassTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, newClassTypeRepr, newSelect, newDefaults) + } + case _ => + report.errorAndAbort("Unsupported type for flattening") + } + } else report.errorAndAbort(s"${typeSymbol} is not a case class or a immutable.Map") + } + else if (!defaults.contains(label)) List('{1}) + else { + val serDflt = serDfltVals(thisOuter, field, classTypeRepr.typeSymbol) + List( + '{if (${serDflt} || ${select.asExprOf[Any]} != ${defaults(label)}) 1 else 0} + ) + } + fieldLabelsImpl0[T] - .map{(rawLabel, label) => + .flatMap { (rawLabel, label) => val defaults = getDefaultParamsImpl0[T] - val select = Select.unique(v.asTerm, rawLabel.name).asExprOf[Any] - - if (!defaults.contains(label)) '{1} - else { - val serDflt = serDfltVals(thisOuter, rawLabel, TypeRepr.of[T].typeSymbol) - '{if (${serDflt} || ${select} != ${defaults(label)}) 1 else 0} - } + val select = Select.unique(v.asTerm, rawLabel.name) + val classTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, classTypeRepr, select, defaults) } .foldLeft('{0}) { case (prev, next) => '{$prev + $next} } -inline def checkErrorMissingKeysCount[T](): Long = - ${checkErrorMissingKeysCountImpl[T]()} - -def checkErrorMissingKeysCountImpl[T]()(using Quotes, Type[T]): Expr[Long] = - import quotes.reflect.* - val paramCount = fieldLabelsImpl0[T].size - if (paramCount <= 64) if (paramCount == 64) Expr(-1) else Expr((1L << paramCount) - 1) - else Expr(paramCount) - -inline def writeSnippets[R, T, WS <: Tuple](inline thisOuter: upickle.core.Types with upickle.implicits.MacrosCommon, +inline def writeSnippets[R, T, W[_]](inline thisOuter: upickle.core.Types with upickle.implicits.MacrosCommon, inline self: 
upickle.implicits.CaseClassReadWriters#CaseClassWriter[T], inline v: T, inline ctx: _root_.upickle.core.ObjVisitor[_, R]): Unit = - ${writeSnippetsImpl[R, T, WS]('thisOuter, 'self, 'v, 'ctx)} + ${writeSnippetsImpl[R, T, W]('thisOuter, 'self, 'v, 'ctx)} -def writeSnippetsImpl[R, T, WS <: Tuple](thisOuter: Expr[upickle.core.Types with upickle.implicits.MacrosCommon], +def writeSnippetsImpl[R, T, W[_]](thisOuter: Expr[upickle.core.Types with upickle.implicits.MacrosCommon], self: Expr[upickle.implicits.CaseClassReadWriters#CaseClassWriter[T]], v: Expr[T], ctx: Expr[_root_.upickle.core.ObjVisitor[_, R]]) - (using Quotes, Type[T], Type[R], Type[WS]): Expr[Unit] = + (using Quotes, Type[T], Type[R], Type[W]): Expr[Unit] = import quotes.reflect.* - Expr.block( - for (((rawLabel, label), i) <- fieldLabelsImpl0[T].zipWithIndex) yield { - - val tpe0 = TypeRepr.of[T].memberType(rawLabel).asType - tpe0 match - case '[tpe] => - val defaults = getDefaultParamsImpl0[T] - Literal(IntConstant(i)).tpe.asType match - case '[IsInt[index]] => - val select = Select.unique(v.asTerm, rawLabel.name).asExprOf[Any] - val snippet = '{ - ${self}.writeSnippetMappedName[R, tpe]( - ${ctx}, - ${thisOuter}.objectAttributeKeyWriteMap(${Expr(label)}), - summonInline[Tuple.Elem[WS, index]], - ${select}, - ) + def loop(field: Symbol, label: String, classTypeRepr: TypeRepr, select: Select, defaults: Map[String, Expr[Object]]): List[Expr[Any]] = + val flatten = extractFlatten(field) + val fieldTypeRepr = substituteTypeArgs(classTypeRepr, subsitituted = classTypeRepr.memberType(field)) + val typeSymbol = fieldTypeRepr.typeSymbol + if (flatten) { + if (isMap(fieldTypeRepr)) { + val (keyTpe0, valueTpe0) = fieldTypeRepr.typeArgs match { + case key :: value :: Nil => (key, value) + case _ => report.errorAndAbort(s"Unsupported type ${typeSymbol} for flattening", v.asTerm.pos) } - if (!defaults.contains(label)) snippet - else { - val serDflt = serDfltVals(thisOuter, rawLabel, TypeRepr.of[T].typeSymbol) - '{if 
($serDflt || ${select} != ${defaults(label)}) $snippet} + val writerTpe0 = TypeRepr.of[W].appliedTo(valueTpe0) + (keyTpe0.asType, valueTpe0.asType, writerTpe0.asType) match { + case ('[keyTpe], '[valueTpe], '[writerTpe])=> + val snippet = '{ + ${select.asExprOf[Map[keyTpe, valueTpe]]}.foreach { (k, v) => + ${self}.writeSnippetMappedName[R, valueTpe]( + ${ctx}, + k.toString, + summonInline[writerTpe], + v, + ) + } + } + List(snippet) + } + } + else if (isCaseClass(typeSymbol)) { + typeSymbol.typeRef.dealias.asType match { + case '[t] => + fieldLabelsImpl0[t] + .flatMap { case (rawLabel, label) => + val newDefaults = getDefaultParamsImpl0[t] + val newSelect = Select.unique(select, rawLabel.name) + val newClassTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, newClassTypeRepr, newSelect, newDefaults) + } + case _ => + report.errorAndAbort("Unsupported type for flattening", v) } + } else report.errorAndAbort(s"${typeSymbol} is not a case class or a immutable.Map", v.asTerm.pos) + } + else { + val tpe0 = fieldTypeRepr + val writerTpe0 = TypeRepr.of[W].appliedTo(tpe0) + (tpe0.asType, writerTpe0.asType) match + case ('[tpe], '[writerTpe]) => + val snippet = '{ + ${self}.writeSnippetMappedName[R, tpe]( + ${ctx}, + ${thisOuter}.objectAttributeKeyWriteMap(${Expr(label)}), + summonInline[writerTpe], + ${select.asExprOf[Any]}, + ) + } + List( + if (!defaults.contains(label)) snippet + else { + val serDflt = serDfltVals(thisOuter, field, classTypeRepr.typeSymbol) + '{if ($serDflt || ${select.asExprOf[Any]} != ${defaults(label)}) $snippet} + } + ) + } - }, + Expr.block( + fieldLabelsImpl0[T] + .flatMap { (rawLabel, label) => + val defaults = getDefaultParamsImpl0[T] + val select = Select.unique(v.asTerm, rawLabel.name) + val classTypeRepr = TypeRepr.of[T] + loop(rawLabel, label, classTypeRepr, select, defaults) + }, '{()} ) @@ -221,11 +370,19 @@ def tagKeyImpl[T](using Quotes, Type[T])(thisOuter: Expr[upickle.core.Types with case None => '{${thisOuter}.tagName} } -inline def 
applyConstructor[T](params: Array[Any]): T = ${ applyConstructorImpl[T]('params) } -def applyConstructorImpl[T](using quotes: Quotes, t0: Type[T])(params: Expr[Array[Any]]): Expr[T] = +def substituteTypeArgs(using Quotes)(tpe: quotes.reflect.TypeRepr, subsitituted: quotes.reflect.TypeRepr): quotes.reflect.TypeRepr = { + import quotes.reflect._ + val constructorSym = tpe.typeSymbol.primaryConstructor + val constructorParamSymss = constructorSym.paramSymss + + val tparams0 = constructorParamSymss.flatten.filter(_.isType) + subsitituted.substituteTypes(tparams0 ,tpe.typeArgs) +} + +inline def applyConstructor[T](params: Array[Any], map: scala.collection.mutable.Map[String, Any]): T = ${ applyConstructorImpl[T]('params, 'map) } +def applyConstructorImpl[T](using quotes: Quotes, t0: Type[T])(params: Expr[Array[Any]], map: Expr[scala.collection.mutable.Map[String, Any]]): Expr[T] = import quotes.reflect._ - def apply(typeApply: Option[List[TypeRepr]]) = { - val tpe = TypeRepr.of[T] + def apply(tpe: TypeRepr, typeArgs: List[TypeRepr], offset: Int): (Term, Int) = { val companion: Symbol = tpe.classSymbol.get.companionModule val constructorSym = tpe.typeSymbol.primaryConstructor val constructorParamSymss = constructorSym.paramSymss @@ -233,39 +390,64 @@ def applyConstructorImpl[T](using quotes: Quotes, t0: Type[T])(params: Expr[Arra val (tparams0, params0) = constructorParamSymss.flatten.partition(_.isType) val constructorTpe = tpe.memberType(constructorSym).widen - val rhs = params0.zipWithIndex.map { - case (sym0, i) => - val lhs = '{$params(${ Expr(i) })} + val (rhs, nextOffset) = params0.foldLeft((List.empty[Term], offset)) { case ((terms, i), sym0) => val tpe0 = constructorTpe.memberType(sym0) - - typeApply.map(tps => tpe0.substituteTypes(tparams0, tps)).getOrElse(tpe0) match { - case AnnotatedType(AppliedType(base, Seq(arg)), x) - if x.tpe =:= defn.RepeatedAnnot.typeRef => - arg.asType match { - case '[t] => - Typed( - lhs.asTerm, - TypeTree.of(using 
AppliedType(defn.RepeatedParamClass.typeRef, List(arg)).asType) - ) + val appliedTpe = tpe0.substituteTypes(tparams0, typeArgs) + val typeSymbol = appliedTpe.typeSymbol + val flatten = extractFlatten(sym0) + if (flatten) { + if (isMap(appliedTpe)) { + val keyTpe0 = appliedTpe.typeArgs.head + val valueTpe0 = appliedTpe.typeArgs(1) + (keyTpe0.asType, valueTpe0.asType) match { + case ('[keyTpe], '[valueTpe]) => + val typedMap = '{${map}.asInstanceOf[collection.mutable.Map[keyTpe, valueTpe]]}.asTerm + val term = Select.unique(typedMap, "toMap") + (term :: terms, i) } - case tpe => - tpe.asType match { - case '[t] => '{ $lhs.asInstanceOf[t] }.asTerm + } + else if (isCaseClass(typeSymbol)) { + typeSymbol.typeRef.dealias.asType match { + case '[t] => + val newTpe = TypeRepr.of[t] + val (term, nextOffset) = newTpe match { + case t: AppliedType => apply(newTpe, t.args, i) + case t: TypeRef => apply(newTpe, List.empty, i) + case t: TermRef => (Ref(t.classSymbol.get.companionModule), i) + } + (term :: terms, nextOffset) + case _ => + report.errorAndAbort(s"Unsupported type $typeSymbol for flattening") } + } else report.errorAndAbort(s"${typeSymbol} is not a case class or a immutable.Map") + } + else { + val lhs = '{$params(${ Expr(i) })} + val term = appliedTpe match { + case AnnotatedType(AppliedType(base, Seq(arg)), x) if x.tpe =:= defn.RepeatedAnnot.typeRef => + arg.asType match { + case '[t] => + Typed( + lhs.asTerm, + TypeTree.of(using AppliedType(defn.RepeatedParamClass.typeRef, List(arg)).asType) + ) + } + case tpe => + tpe.asType match { + case '[t] => '{ $lhs.asInstanceOf[t] }.asTerm + } + } + (term :: terms, i + 1) } - } - typeApply match{ - case None => Select.overloaded(Ref(companion), "apply", Nil, rhs).asExprOf[T] - case Some(args) => - Select.overloaded(Ref(companion), "apply", args, rhs).asExprOf[T] - } + (Select.overloaded(Ref(companion), "apply", typeArgs, rhs.reverse), nextOffset) } - TypeRepr.of[T] match{ - case t: AppliedType => apply(Some(t.args)) - case 
t: TypeRef => apply(None) + val tpe = TypeRepr.of[T] + tpe match{ + case t: AppliedType => apply(tpe, t.args, 0)._1.asExprOf[T] + case t: TypeRef => apply(tpe, List.empty, 0)._1.asExprOf[T] case t: TermRef => '{${Ref(t.classSymbol.get.companionModule).asExprOf[Any]}.asInstanceOf[T]} } @@ -389,3 +571,27 @@ def defineEnumVisitorsImpl[T0, T <: Tuple](prefix: Expr[Any], macroX: String)(us Block(allDefs.map(_._1), Ident(allDefs.head._2.termRef)).asExprOf[T0] +inline def validateFlattenAnnotation[T](): Unit = ${ validateFlattenAnnotationImpl[T] } +def validateFlattenAnnotationImpl[T](using Quotes, Type[T]): Expr[Unit] = + import quotes.reflect._ + val fields = allFields[T] + if (fields.count(_._5) > 1) { + report.errorAndAbort("Only one Map can be annotated with @upickle.implicits.flatten in the same level") + } + if (fields.map(_._2).distinct.length != fields.length) { + report.errorAndAbort("There are multiple fields with the same key") + } + if (fields.exists {case (_, _, tpe, _, isFlattenMap) => isFlattenMap && !(tpe.typeArgs.head.dealias =:= TypeRepr.of[String].dealias)}) { + report.errorAndAbort("The key type of a Map annotated with @flatten must be String.") + } + '{()} + +private def isMap(using Quotes)(tpe: quotes.reflect.TypeRepr): Boolean = { + import quotes.reflect._ + tpe.typeSymbol == TypeRepr.of[collection.immutable.Map[_, _]].typeSymbol +} + +private def isCaseClass(using Quotes)(typeSymbol: quotes.reflect.Symbol): Boolean = { + import quotes.reflect._ + typeSymbol.isClassDef && typeSymbol.flags.is(Flags.Case) +} diff --git a/upickle/implicits/src/upickle/implicits/ObjectContexts.scala b/upickle/implicits/src/upickle/implicits/ObjectContexts.scala index cac1c17e6..49f33f225 100644 --- a/upickle/implicits/src/upickle/implicits/ObjectContexts.scala +++ b/upickle/implicits/src/upickle/implicits/ObjectContexts.scala @@ -4,6 +4,9 @@ import upickle.core.ObjVisitor trait BaseCaseObjectContext { + var currentKey = "" + var storeToMap = false + def 
storeAggregatedValue(currentIndex: Int, v: Any): Unit def visitKey(index: Int) = _root_.upickle.core.StringVisitor @@ -21,10 +24,13 @@ abstract class CaseObjectContext[V](fieldCount: Int) extends ObjVisitor[Any, V] var found = 0L def visitValue(v: Any, index: Int): Unit = { - if (currentIndex != -1 && ((found & (1L << currentIndex)) == 0)) { + if ((currentIndex != -1) && ((found & (1L << currentIndex)) == 0)) { storeAggregatedValue(currentIndex, v) found |= (1L << currentIndex) } + else if (storeToMap) { + storeAggregatedValue(currentIndex, v) + } } def storeValueIfNotFound(i: Int, v: Any) = { @@ -53,10 +59,13 @@ abstract class HugeCaseObjectContext[V](fieldCount: Int) extends ObjVisitor[Any, var found = new Array[Long](fieldCount / 64 + 1) def visitValue(v: Any, index: Int): Unit = { - if (currentIndex != -1 && ((found(currentIndex / 64) & (1L << currentIndex)) == 0)) { + if ((currentIndex != -1) && ((found(currentIndex / 64) & (1L << currentIndex)) == 0)) { storeAggregatedValue(currentIndex, v) found(currentIndex / 64) |= (1L << currentIndex) } + else if (storeToMap) { + storeAggregatedValue(currentIndex, v) + } } def storeValueIfNotFound(i: Int, v: Any) = { diff --git a/upickle/implicits/src/upickle/implicits/key.scala b/upickle/implicits/src/upickle/implicits/key.scala index 36c013d71..2be787810 100644 --- a/upickle/implicits/src/upickle/implicits/key.scala +++ b/upickle/implicits/src/upickle/implicits/key.scala @@ -30,4 +30,16 @@ class serializeDefaults(s: Boolean) extends StaticAnnotation */ class allowUnknownKeys(b: Boolean) extends StaticAnnotation + +/** + * An annotation that, when applied to a field in a case class, flattens the fields of the + * annotated `case class` or `Map` into the parent case class during serialization. + * This means the fields will appear at the same level as the parent case class's fields + * rather than nested under the field name. 
During deserialization, these fields are + * grouped back into the annotated `case class` or `Map`. + * + * **Limitations**: + * - Only works with `Map` types that are subtypes of `Map[String, _]`. + * - Cannot flatten more than one `Map` instance at the same level. + */ class flatten extends StaticAnnotation diff --git a/upickle/test/src/upickle/FailureTests.scala b/upickle/test/src/upickle/FailureTests.scala index 109781a16..9a70f0d76 100644 --- a/upickle/test/src/upickle/FailureTests.scala +++ b/upickle/test/src/upickle/FailureTests.scala @@ -37,6 +37,11 @@ object WrongTag { } +case class FlattenTwoMaps(@upickle.implicits.flatten map1: Map[String, String], @upickle.implicits.flatten map2: Map[String, String]) +case class ConflictingKeys(i: Int, @upickle.implicits.flatten cm: ConflictingMessage) +case class ConflictingMessage(i: Int) +case class MapWithNoneStringKey(@upickle.implicits.flatten map: Map[ConflictingMessage, String]) + object TaggedCustomSerializer{ sealed trait BooleanOrInt @@ -265,6 +270,9 @@ object FailureTests extends TestSuite { // compileError("""read[Array[Object]]("")""").msg // Make sure this doesn't hang the compiler =/ compileError("implicitly[upickle.default.Reader[Nothing]]") + compileError("upickle.default.macroRW[FlattenTwoMaps]") + compileError("upickle.default.macroRW[ConflictingKeys]") + compileError("upickle.default.macroRW[MapWithNoneStringKey]") } test("expWholeNumbers"){ upickle.default.read[Byte]("0e0") ==> 0.toByte diff --git a/upickle/test/src/upickle/MacroTests.scala b/upickle/test/src/upickle/MacroTests.scala index 2c9736317..9217e15ac 100644 --- a/upickle/test/src/upickle/MacroTests.scala +++ b/upickle/test/src/upickle/MacroTests.scala @@ -145,25 +145,63 @@ object TagName{ implicit val fooRw: TagNamePickler.ReadWriter[Foo] = TagNamePickler.macroRW } -case class Pagination(limit: Int, offset: Int, total: Int) +object Flatten { + case class FlattenTest(i: Int, s: String, @upickle.implicits.flatten n: Nested, 
@upickle.implicits.flatten n2: Nested2) -object Pagination { - implicit val rw: RW[Pagination] = upickle.default.macroRW -} + object FlattenTest { + implicit val rw: RW[FlattenTest] = upickle.default.macroRW + } -case class Users(Ids: List[Int], @upickle.implicits.flatten pagination: Pagination) + case class Nested(d: Double, @upickle.implicits.flatten m: Map[String, Int]) -object Users { - implicit val rw: RW[Users] = upickle.default.macroRW -} + object Nested { + implicit val rw: RW[Nested] = upickle.default.macroRW + } + + case class Nested2(name: String) + + object Nested2 { + implicit val rw: RW[Nested2] = upickle.default.macroRW + } + + case class FlattenTestWithType[T](i: Int, @upickle.implicits.flatten t: T) + + object FlattenTestWithType { + // implicit def rw[T: RW]: RW[FlattenTestWithType[T]] = upickle.default.macroRW + implicit val rw: RW[FlattenTestWithType[Nested]] = upickle.default.macroRW + } + + case class InnerMost(a: String, b: Int) + + object InnerMost { + implicit val rw: RW[InnerMost] = upickle.default.macroRW + } + + case class Inner(@upickle.implicits.flatten innerMost: InnerMost, c: Boolean) + + object Inner { + implicit val rw: RW[Inner] = upickle.default.macroRW + } -case class PackageManifest( - name: String, - @upickle.implicits.flatten otherStuff: Map[String, ujson.Value] - ) + case class Outer(d: Double, @upickle.implicits.flatten inner: Inner) -object PackageManifest { - implicit val rw: RW[PackageManifest] = upickle.default.macroRW + object Outer { + implicit val rw: RW[Outer] = upickle.default.macroRW + } + + case class HasMap(@upickle.implicits.flatten map: Map[String, String], i: Int) + object HasMap { + implicit val rw: RW[HasMap] = upickle.default.macroRW + } + + case class FlattenWithDefault(i: Int, @upickle.implicits.flatten n: NestedWithDefault) + object FlattenWithDefault { + implicit val rw: RW[FlattenWithDefault] = upickle.default.macroRW + } + case class NestedWithDefault(k: Int = 100, l: String) + object 
NestedWithDefault { + implicit val rw: RW[NestedWithDefault] = upickle.default.macroRW + } } object MacroTests extends TestSuite { @@ -172,7 +210,7 @@ object MacroTests extends TestSuite { // case class A_(objects: Option[C_]); case class C_(nodes: Option[C_]) // implicitly[Reader[A_]] -// implicitly[upickle.old.Writer[upickle.MixedIn.Obj.ClsB]code] +// implicitly[upickle.old.Writer[upickle.MixedIn.Obj.ClsB]] // println(write(ADTs.ADTc(1, "lol", (1.1, 1.2)))) // implicitly[upickle.old.Writer[ADTs.ADTc]] @@ -904,9 +942,39 @@ object MacroTests extends TestSuite { } test("flatten"){ - val a = Users(List(1, 2, 3), Pagination(10, 20, 30)) - upickle.default.write[Users](a) ==> """{"Ids":[1,2,3],"limit":10,"offset":20,"total":30}""" + import Flatten._ + val a = FlattenTest(10, "test", Nested(3.0, Map("one" -> 1, "two" -> 2)), Nested2("hello")) + rw(a, """{"i":10,"s":"test","d":3,"one":1,"two":2,"name":"hello"}""") + } + + test("flattenTypeParam"){ + import Flatten._ + val a = FlattenTestWithType[Nested](10, Nested(5.0, Map("one" -> 1, "two" -> 2))) + rw(a, """{"i":10,"d":5,"one":1,"two":2}""") } + test("nestedFlatten") { + import Flatten._ + val value = Outer(1.1, Inner(InnerMost("test", 42), true)) + rw(value, """{"d":1.1,"a":"test","b":42,"c":true}""") + } + + test("flattenWithMap") { + import Flatten._ + val value = HasMap(Map("key1" -> "value1", "key2" -> "value2"), 10) + rw(value, """{"key1":"value1","key2":"value2","i":10}""") + } + + test("flattenEmptyMap") { + import Flatten._ + val value = HasMap(Map.empty, 10) + rw(value, """{"i":10}""") + } + + test("flattenWithDefaults") { + import Flatten._ + val value = FlattenWithDefault(10, NestedWithDefault(l = "default")) + rw(value, """{"i":10,"l":"default"}""") + } } }