From db0df8010cc8791a2425330845a66531139a892c Mon Sep 17 00:00:00 2001 From: susliko <1istoobig@gmail.com> Date: Tue, 16 May 2023 20:04:05 +0300 Subject: [PATCH 1/3] Format scala 3 files --- .scalafmt.conf | 10 +- .../ru/tinkoff/phobos/derivation/common.scala | 121 ++-- .../tinkoff/phobos/derivation/decoder.scala | 598 ++++++++++-------- .../tinkoff/phobos/derivation/encoder.scala | 199 +++--- .../phobos/derivation/semiauto/package.scala | 19 +- .../ru/tinkoff/phobos/DerivationTest.scala | 6 +- 6 files changed, 547 insertions(+), 406 deletions(-) diff --git a/.scalafmt.conf b/.scalafmt.conf index ed328ad..ecd53fc 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -15,4 +15,12 @@ newlines.penalizeSingleSelectMultiArgList = false binPack.parentConstructors = true includeCurlyBraceInSelectChains = false -trailingCommas = always \ No newline at end of file +trailingCommas = always +fileOverride { + "glob:**/modules/core/src/test/scala-3/**" { + runner.dialect = scala3 + } + "glob:**/modules/core/src/main/scala-3/**" { + runner.dialect = scala3 + } +} diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/common.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/common.scala index 128c391..5803a93 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/common.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/common.scala @@ -21,23 +21,25 @@ object common { } private[derivation] final class ProductTypeField(using val quotes: Quotes)( - val localName: String, - val xmlName: Expr[String], // Name of element or attribute - val namespaceUri: Expr[Option[String]], - val typeRepr: quotes.reflect.TypeRepr, - val category: FieldCategory, + val localName: String, + val xmlName: Expr[String], // Name of element or attribute + val namespaceUri: Expr[Option[String]], + val typeRepr: quotes.reflect.TypeRepr, + val category: FieldCategory, ) private[derivation] final class SumTypeChild(using val quotes: Quotes)( - val xmlName: Expr[String], // Value of discriminator - val typeRepr: quotes.reflect.TypeRepr, + val xmlName: Expr[String], // Value of discriminator + val typeRepr: quotes.reflect.TypeRepr, ) - private[derivation] def extractProductTypeFields[T: Type](config: Expr[ElementCodecConfig])(using Quotes): List[ProductTypeField] = { + private[derivation] def extractProductTypeFields[T: Type]( + config: Expr[ElementCodecConfig], + )(using Quotes): List[ProductTypeField] = { import quotes.reflect.* val classTypeRepr = TypeRepr.of[T] - val classSymbol = classTypeRepr.typeSymbol + val classSymbol = classTypeRepr.typeSymbol // Extracting first non-type parameter list. Size of this parameter list must be equal to size of .caseFields val constructorFields = classSymbol.primaryConstructor.paramSymss.filterNot(_.exists(_.isType)).head @@ -47,7 +49,11 @@ object common { val fieldXmlName = extractFieldXmlName(config, classSymbol, fieldSymbol, fieldAnnotations, fieldCategory) val fieldNamespace = extractFeildNamespace(config, classSymbol, fieldSymbol, fieldAnnotations, fieldCategory) ProductTypeField(using quotes)( - fieldSymbol.name, fieldXmlName, fieldNamespace, classTypeRepr.memberType(fieldSymbol), fieldCategory + fieldSymbol.name, + fieldXmlName, + fieldNamespace, + classTypeRepr.memberType(fieldSymbol), + fieldCategory, ) } val textCount = fields.count(_.category == FieldCategory.text) @@ -57,19 +63,21 @@ object common { s""" |Product type cannot have more than one field with @text annotation. |Product type '${classSymbol.name}' has $textCount - |""".stripMargin + |""".stripMargin, ) if (defaultCount > 1) report.throwError( s""" |Product type cannot have more than one field with @default annotation. |Product type '${classSymbol.name}' has $defaultCount - |""".stripMargin + |""".stripMargin, ) fields } - private[derivation] def extractSumTypeChildren[T: Type](config: Expr[ElementCodecConfig])(using Quotes): List[SumTypeChild] = { + private[derivation] def extractSumTypeChildren[T: Type]( + config: Expr[ElementCodecConfig], + )(using Quotes): List[SumTypeChild] = { import quotes.reflect.* val traitTypeRepr = TypeRepr.of[T] val traitSymbol = traitTypeRepr.typeSymbol @@ -81,45 +89,44 @@ object common { } private def extractFieldCategory(using Quotes)( - classSymbol: quotes.reflect.Symbol, - fieldSymbol: quotes.reflect.Symbol, - fieldAnnotations: List[Expr[Any]] + classSymbol: quotes.reflect.Symbol, + fieldSymbol: quotes.reflect.Symbol, + fieldAnnotations: List[Expr[Any]], ): FieldCategory = { import quotes.reflect.* - fieldAnnotations - .collect { - case '{attr()} => FieldCategory.attribute - case '{text()} => FieldCategory.text - case '{default()} => FieldCategory.default - } match { + fieldAnnotations.collect { + case '{ attr() } => FieldCategory.attribute + case '{ text() } => FieldCategory.text + case '{ default() } => FieldCategory.default + } match { case Nil => FieldCategory.element case List(category) => category case categories => val categoryAnnotations = categories.collect { case FieldCategory.attribute => "@attr" - case FieldCategory.text => "@text" - case FieldCategory.default => "@default" + case FieldCategory.text => "@text" + case FieldCategory.default => "@default" }.mkString(", ") report.throwError( s""" |Product type field cannot have more than one category annotation (@attr, @text or @default). |Field '${fieldSymbol.name}' in product type '${classSymbol.name}' has ${categories.size}: $categoryAnnotations - |""".stripMargin + |""".stripMargin, ) } } private def extractFieldXmlName(using Quotes)( - config: Expr[ElementCodecConfig], - classSymbol: quotes.reflect.Symbol, - fieldSymbol: quotes.reflect.Symbol, - fieldAnnotations: List[Expr[Any]], - fieldCategory: FieldCategory, + config: Expr[ElementCodecConfig], + classSymbol: quotes.reflect.Symbol, + fieldSymbol: quotes.reflect.Symbol, + fieldAnnotations: List[Expr[Any]], + fieldCategory: FieldCategory, ): Expr[String] = { import quotes.reflect.* - (fieldAnnotations.collect {case '{renamed($a)} => a } match { + (fieldAnnotations.collect { case '{ renamed($a) } => a } match { case Nil => None case List(name) => Some(name) case names => @@ -128,56 +135,56 @@ object common { s""" |Product type field cannot have more than one @renamed annotation. |Field '${fieldSymbol.name}' in product type '${classSymbol.name}' has ${names.size}: $renamedAnnotations - |""".stripMargin + |""".stripMargin, ) }).getOrElse(fieldCategory match { - case FieldCategory.element => '{${config}.transformElementNames(${Expr(fieldSymbol.name)})} - case FieldCategory.attribute => '{${config}.transformAttributeNames(${Expr(fieldSymbol.name)})} + case FieldCategory.element => '{ ${ config }.transformElementNames(${ Expr(fieldSymbol.name) }) } + case FieldCategory.attribute => '{ ${ config }.transformAttributeNames(${ Expr(fieldSymbol.name) }) } case _ => Expr(fieldSymbol.name) }) } private def extractFeildNamespace(using Quotes)( - config: Expr[ElementCodecConfig], - classSymbol: quotes.reflect.Symbol, - fieldSymbol: quotes.reflect.Symbol, - fieldAnnotations: List[Expr[Any]], - fieldCategory: FieldCategory, + config: Expr[ElementCodecConfig], + classSymbol: quotes.reflect.Symbol, + fieldSymbol: quotes.reflect.Symbol, + fieldAnnotations: List[Expr[Any]], + fieldCategory: FieldCategory, ): Expr[Option[String]] = { import quotes.reflect.* - fieldAnnotations.collect { - case '{xmlns($namespace: b)} => '{Some(summonInline[Namespace[b]].getNamespace)} + fieldAnnotations.collect { case '{ xmlns($namespace: b) } => + '{ Some(summonInline[Namespace[b]].getNamespace) } } match { - case Nil => fieldCategory match { - case FieldCategory.element => '{${config}.elementsDefaultNamespace} - case FieldCategory.attribute => '{${config}.attributesDefaultNamespace} - case _ => '{None} - } + case Nil => + fieldCategory match { + case FieldCategory.element => '{ ${ config }.elementsDefaultNamespace } + case FieldCategory.attribute => '{ ${ config }.attributesDefaultNamespace } + case _ => '{ None } + } case List(namespace) => namespace case namespaces => val xmlnsAnnotations = - fieldAnnotations - .collect { - case '{xmlns($namespace)} => s"@xmlns(${namespace.asTerm.show})" - } + fieldAnnotations.collect { case '{ xmlns($namespace) } => + s"@xmlns(${namespace.asTerm.show})" + } .mkString(", ") report.throwError( s""" |Product type field cannot have more than one @xmlns annotation. |Field '${fieldSymbol.name}' in product type '${classSymbol.name}' has ${namespaces.size}: $xmlnsAnnotations - |""".stripMargin + |""".stripMargin, ) } } private def extractChildXmlName(using Quotes)( - config: Expr[ElementCodecConfig], - traitSymbol: quotes.reflect.Symbol, - childSymbol: quotes.reflect.Symbol, + config: Expr[ElementCodecConfig], + traitSymbol: quotes.reflect.Symbol, + childSymbol: quotes.reflect.Symbol, ): Expr[String] = { import quotes.reflect.* - childSymbol.annotations.map(_.asExpr).collect { case '{discriminator($a)} => a } match { - case Nil => '{$config.transformConstructorNames(${Expr(childSymbol.name)})} + childSymbol.annotations.map(_.asExpr).collect { case '{ discriminator($a) } => a } match { + case Nil => '{ $config.transformConstructorNames(${ Expr(childSymbol.name) }) } case List(name) => name case names => val discriminatorAnnotations = names.map(name => s"@discriminator(${name.show})").mkString(", ") @@ -185,7 +192,7 @@ object common { s""" |Sum type child cannot have more than one @discriminator annotation. |Child '${childSymbol.name}' of sum type '${traitSymbol.name}' has ${names.size}: $discriminatorAnnotations - |""".stripMargin + |""".stripMargin, ) } } diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/decoder.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/decoder.scala index 70c570c..59bf031 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/decoder.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/decoder.scala @@ -20,20 +20,20 @@ import scala.quoted.* object decoder { inline def deriveElementDecoder[T]( - inline config: ElementCodecConfig + inline config: ElementCodecConfig, ): ElementDecoder[T] = - ${deriveElementDecoderImpl('{config})} + ${ deriveElementDecoderImpl('{ config }) } inline def deriveXmlDecoder[T]( - inline localName: String, - inline namespace: Option[String], - inline config: ElementCodecConfig + inline localName: String, + inline namespace: Option[String], + inline config: ElementCodecConfig, ): XmlDecoder[T] = - ${deriveXmlDecoderImpl('{localName}, '{namespace}, '{config})} + ${ deriveXmlDecoderImpl('{ localName }, '{ namespace }, '{ config }) } def deriveElementDecoderImpl[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementDecoder[T]] = { import quotes.reflect.* - val tpe = TypeRepr.of[T] + val tpe = TypeRepr.of[T] val typeSymbol = tpe.typeSymbol if (typeSymbol.flags.is(Flags.Case)) { deriveProduct[T](config) @@ -45,11 +45,11 @@ object decoder { } def deriveXmlDecoderImpl[T: Type]( - localName: Expr[String], - namespace: Expr[Option[String]], - config: Expr[ElementCodecConfig], + localName: Expr[String], + namespace: Expr[Option[String]], + config: Expr[ElementCodecConfig], )(using Quotes): Expr[XmlDecoder[T]] = - '{XmlDecoder.fromElementDecoder[T]($localName, $namespace)(${deriveElementDecoderImpl(config)})} + '{ XmlDecoder.fromElementDecoder[T]($localName, $namespace)(${ deriveElementDecoderImpl(config) }) } // PRODUCT @@ -62,35 +62,41 @@ object decoder { } private def decodeAttributes(using Quotes)( - groups: Map[FieldCategory, List[ProductTypeField]], - c: Expr[Cursor], - currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], + groups: Map[FieldCategory, List[ProductTypeField]], + c: Expr[Cursor], + currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], ): Expr[List[Unit]] = { Expr.ofList( groups.getOrElse(FieldCategory.attribute, Nil).map { field => field.typeRepr.asType match { - case '[t] => '{ - val attribute = summonInline[AttributeDecoder[t]] - .decodeAsAttribute($c, ${field.xmlName}, ${field.namespaceUri}) - ${currentFieldStates}.update(${Expr(field.localName)}, attribute) - } + case '[t] => + '{ + val attribute = summonInline[AttributeDecoder[t]] + .decodeAsAttribute($c, ${ field.xmlName }, ${ field.namespaceUri }) + ${ currentFieldStates }.update(${ Expr(field.localName) }, attribute) + } } - } + }, ) } private def decodeText(using Quotes)( - groups: Map[FieldCategory, List[ProductTypeField]], - c: Expr[Cursor], - currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], + groups: Map[FieldCategory, List[ProductTypeField]], + c: Expr[Cursor], + currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], ): Expr[Unit] = { - groups.get(FieldCategory.text).flatMap(_.headOption) - .fold('{()}){ text => + groups + .get(FieldCategory.text) + .flatMap(_.headOption) + .fold('{ () }) { text => text.typeRepr.asType match { case '[t] => '{ - val res = $currentFieldStates.getOrElse(${Expr(text.localName)}, summonInline[TextDecoder[t]]).asInstanceOf[TextDecoder[t]].decodeAsText($c) - $currentFieldStates.update(${Expr(text.localName)}, res) + val res = $currentFieldStates + .getOrElse(${ Expr(text.localName) }, summonInline[TextDecoder[t]]) + .asInstanceOf[TextDecoder[t]] + .decodeAsText($c) + $currentFieldStates.update(${ Expr(text.localName) }, res) } } } @@ -99,170 +105,199 @@ object decoder { // Used twice. Should be used once? // Replace ` match` with ` switch` ? private def decodeElementCases[T: Type](using Quotes)( - elements: List[ProductTypeField], - go: Expr[DecoderState => ElementDecoder[T]], - c: Expr[Cursor], - currentFieldStates: Expr[mutable.AnyRefMap[String, Any]] + elements: List[ProductTypeField], + go: Expr[DecoderState => ElementDecoder[T]], + c: Expr[Cursor], + currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], ): List[quotes.reflect.CaseDef] = { import quotes.reflect.* - elements.map{ element => + elements.map { element => val symbol = Symbol.newBind(Symbol.spliceOwner, "x", Flags.EmptyFlags, TypeRepr.of[String]) - val eq = symbol.memberMethod("==").head - CaseDef(Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), Some(Apply(Select(Ref(symbol), eq), List(element.xmlName.asTerm))), - (element.typeRepr.asType match { - case '[t] => - '{ - val res = ${currentFieldStates} - .getOrElse(${Expr(element.localName)}, summonInline[ElementDecoder[t]]) - .asInstanceOf[ElementDecoder[t]] - .decodeAsElement($c, ${element.xmlName}, ${element.namespaceUri}.orElse($c.getScopeDefaultNamespace)) - ${currentFieldStates}.update(${Expr(element.localName)}, res) - if (res.isCompleted) { - res.result(${element.xmlName} :: $c.history) match { - case Right(_) => $go(DecoderState.DecodingSelf) - case Left(error) => new ElementDecoder.FailedDecoder[T](error) - } - } else { - $go(DecoderState.DecodingElement(${element.xmlName})) + val eq = symbol.memberMethod("==").head + CaseDef( + Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), + Some(Apply(Select(Ref(symbol), eq), List(element.xmlName.asTerm))), + (element.typeRepr.asType match { + case '[t] => + '{ + val res = ${ currentFieldStates } + .getOrElse(${ Expr(element.localName) }, summonInline[ElementDecoder[t]]) + .asInstanceOf[ElementDecoder[t]] + .decodeAsElement( + $c, + ${ element.xmlName }, + ${ element.namespaceUri }.orElse($c.getScopeDefaultNamespace), + ) + ${ currentFieldStates }.update(${ Expr(element.localName) }, res) + if (res.isCompleted) { + res.result(${ element.xmlName } :: $c.history) match { + case Right(_) => $go(DecoderState.DecodingSelf) + case Left(error) => new ElementDecoder.FailedDecoder[T](error) } + } else { + $go(DecoderState.DecodingElement(${ element.xmlName })) } - }).asTerm - ) - } + } + }).asTerm, + ) } + } private def decodeStartElement[T: Type](using Quotes)( - groups: Map[FieldCategory, List[ProductTypeField]], - go: Expr[DecoderState => ElementDecoder[T]], - c: Expr[Cursor], - currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], + groups: Map[FieldCategory, List[ProductTypeField]], + go: Expr[DecoderState => ElementDecoder[T]], + c: Expr[Cursor], + currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], ) = { import quotes.reflect.* val decodeElements = decodeElementCases[T](groups.getOrElse(FieldCategory.element, Nil), go, c, currentFieldStates) val decodeDefault = - groups.get(FieldCategory.default).flatMap(_.headOption).fold{ - val symbol = Symbol.newBind(Symbol.spliceOwner, "_", Flags.EmptyFlags, TypeRepr.of[String]) - CaseDef(Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), None, ('{ - val state = DecoderState.IgnoringElement($c.getLocalName, Option($c.getNamespaceURI).filter(_.nonEmpty), 0) - $c.next() - $go(state) - }).asTerm) - }{ default => - val symbol = Symbol.newBind(Symbol.spliceOwner, "_", Flags.EmptyFlags, TypeRepr.of[String]) - CaseDef(Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), None, - default.typeRepr.asType match { - case '[t] => - '{ - val name = $c.getLocalName - val namespace = Option($c.getNamespaceURI) - val res = $currentFieldStates - .getOrElse(${Expr(default.localName)}, summonInline[ElementDecoder[t]]) - .asInstanceOf[ElementDecoder[t]] - .decodeAsElement($c, name, namespace.orElse($c.getScopeDefaultNamespace)) - $currentFieldStates.update(${Expr(default.localName)}, res) - if (res.isCompleted) { - res.result(name :: $c.history) match { - case Right(_) => $go(DecoderState.DecodingSelf) - case Left(error) => new ElementDecoder.FailedDecoder[T](error) + groups + .get(FieldCategory.default) + .flatMap(_.headOption) + .fold { + val symbol = Symbol.newBind(Symbol.spliceOwner, "_", Flags.EmptyFlags, TypeRepr.of[String]) + CaseDef( + Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), + None, + ('{ + val state = + DecoderState.IgnoringElement($c.getLocalName, Option($c.getNamespaceURI).filter(_.nonEmpty), 0) + $c.next() + $go(state) + }).asTerm, + ) + } { default => + val symbol = Symbol.newBind(Symbol.spliceOwner, "_", Flags.EmptyFlags, TypeRepr.of[String]) + CaseDef( + Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), + None, + default.typeRepr.asType match { + case '[t] => + '{ + val name = $c.getLocalName + val namespace = Option($c.getNamespaceURI) + val res = $currentFieldStates + .getOrElse(${ Expr(default.localName) }, summonInline[ElementDecoder[t]]) + .asInstanceOf[ElementDecoder[t]] + .decodeAsElement($c, name, namespace.orElse($c.getScopeDefaultNamespace)) + $currentFieldStates.update(${ Expr(default.localName) }, res) + if (res.isCompleted) { + res.result(name :: $c.history) match { + case Right(_) => $go(DecoderState.DecodingSelf) + case Left(error) => new ElementDecoder.FailedDecoder[T](error) + } + } else { + $go(DecoderState.IgnoringElement(name, namespace, 0)) } - } else { - $go(DecoderState.IgnoringElement(name, namespace, 0)) - } - }.asTerm - } - ) - } - Match('{$c.getLocalName}.asTerm, decodeElements :+ decodeDefault).asExprOf[ElementDecoder[T]] + }.asTerm + }, + ) + } + Match('{ $c.getLocalName }.asTerm, decodeElements :+ decodeDefault).asExprOf[ElementDecoder[T]] } private def decodeEndElement[T: Type](using Quotes)( - fields: List[ProductTypeField], - go: Expr[DecoderState => ElementDecoder[T]], - c: Expr[Cursor], - localName: Expr[String], - currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], - config: Expr[ElementCodecConfig], + fields: List[ProductTypeField], + go: Expr[DecoderState => ElementDecoder[T]], + c: Expr[Cursor], + localName: Expr[String], + currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], + config: Expr[ElementCodecConfig], ) = { import quotes.reflect.* - '{$c.getLocalName match { - case name if name == $localName => - ${ - val classTypeRepr = TypeRepr.of[T] - val primaryConstructor = Select(New(TypeTree.of[T]), classTypeRepr.typeSymbol.primaryConstructor) - val primaryConstructorTypeApplied = classTypeRepr match { - case AppliedType(_, params) => TypeApply(primaryConstructor, params.map(Inferred.apply)) - case _ => primaryConstructor - } - fields - .foldLeft[List[Term] => Expr[Either[DecodingError, T]]]{ - terms => '{Right(${Apply(primaryConstructorTypeApplied, terms).asExprOf[T]})} - } { (acc, field) => params => - field.typeRepr.asType match { - case '[t] => - val fSymbol = Symbol.newMethod( + '{ + $c.getLocalName match { + case name if name == $localName => + ${ + val classTypeRepr = TypeRepr.of[T] + val primaryConstructor = Select(New(TypeTree.of[T]), classTypeRepr.typeSymbol.primaryConstructor) + val primaryConstructorTypeApplied = classTypeRepr match { + case AppliedType(_, params) => TypeApply(primaryConstructor, params.map(Inferred.apply)) + case _ => primaryConstructor + } + fields + .foldLeft[List[Term] => Expr[Either[DecodingError, T]]] { terms => + '{ Right(${ Apply(primaryConstructorTypeApplied, terms).asExprOf[T] }) } + } { (acc, field) => params => + field.typeRepr.asType match { + case '[t] => + val fSymbol = Symbol.newMethod( Symbol.spliceOwner, "anonfun", MethodType(List(field.localName))( - _ => List(TypeRepr.of[t]), - _ => TypeRepr.of[Either[DecodingError, T]]) - ) - val f = Block( - List(DefDef(fSymbol, _.headOption.flatMap(_.headOption).map { param => - acc(param.asExprOf[t].asTerm :: params).asTerm.changeOwner(fSymbol) - } - )), - Closure(Ref(fSymbol), Some(TypeRepr.of[t => Either[DecodingError, T]])), - ) - field.category match { - case FieldCategory.element | FieldCategory.default => - '{ - $currentFieldStates - .getOrElse(${Expr(field.localName)}, summonInline[ElementDecoder[t]]) - .asInstanceOf[ElementDecoder[t]] - .result($c.history) - .flatMap { ${f.asExprOf[t => Either[DecodingError, T]]} } - } - case FieldCategory.attribute => - '{ - $currentFieldStates - .getOrElse( - ${Expr(field.localName)}, - Left(DecodingError(s"Attribute '${${field.xmlName}}' is missing or invalid", $c.history, None)) - ) - .asInstanceOf[Either[DecodingError, t]] - .flatMap { ${f.asExprOf[t => Either[DecodingError, T]]} } - } - case FieldCategory.text => - '{ - $currentFieldStates - .getOrElse( - ${Expr(field.localName)}, - summonInline[TextDecoder[t]], - ) - .asInstanceOf[TextDecoder[t]] - .result($c.history) - .flatMap { ${f.asExprOf[t => Either[DecodingError, T]]} } - } - } - } - }(Nil) - } match { - case Right(result) => - $c.next() - $config.scopeDefaultNamespace.foreach(_ => $c.unsetScopeDefaultNamespace()) - new ConstDecoder[T](result) - case Left(error) => - new FailedDecoder[T](error) - } - case _ => - $c.next() - $go(DecoderState.DecodingSelf) - }} + _ => List(TypeRepr.of[t]), + _ => TypeRepr.of[Either[DecodingError, T]], + ), + ) + val f = Block( + List( + DefDef( + fSymbol, + _.headOption.flatMap(_.headOption).map { param => + acc(param.asExprOf[t].asTerm :: params).asTerm.changeOwner(fSymbol) + }, + ), + ), + Closure(Ref(fSymbol), Some(TypeRepr.of[t => Either[DecodingError, T]])), + ) + field.category match { + case FieldCategory.element | FieldCategory.default => + '{ + $currentFieldStates + .getOrElse(${ Expr(field.localName) }, summonInline[ElementDecoder[t]]) + .asInstanceOf[ElementDecoder[t]] + .result($c.history) + .flatMap { ${ f.asExprOf[t => Either[DecodingError, T]] } } + } + case FieldCategory.attribute => + '{ + $currentFieldStates + .getOrElse( + ${ Expr(field.localName) }, + Left( + DecodingError( + s"Attribute '${${ field.xmlName }}' is missing or invalid", + $c.history, + None, + ), + ), + ) + .asInstanceOf[Either[DecodingError, t]] + .flatMap { ${ f.asExprOf[t => Either[DecodingError, T]] } } + } + case FieldCategory.text => + '{ + $currentFieldStates + .getOrElse( + ${ Expr(field.localName) }, + summonInline[TextDecoder[t]], + ) + .asInstanceOf[TextDecoder[t]] + .result($c.history) + .flatMap { ${ f.asExprOf[t => Either[DecodingError, T]] } } + } + } + } + }(Nil) + } match { + case Right(result) => + $c.next() + $config.scopeDefaultNamespace.foreach(_ => $c.unsetScopeDefaultNamespace()) + new ConstDecoder[T](result) + case Left(error) => + new FailedDecoder[T](error) + } + case _ => + $c.next() + $go(DecoderState.DecodingSelf) + } + } } private def decodingElement[T: Type](using Quotes)( - groups: Map[FieldCategory, List[ProductTypeField]], + groups: Map[FieldCategory, List[ProductTypeField]], go: Expr[DecoderState => ElementDecoder[T]], c: Expr[Cursor], currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], @@ -271,59 +306,71 @@ object decoder { import quotes.reflect.* val default = { val symbol = Symbol.newBind(Symbol.spliceOwner, "unknown", Flags.EmptyFlags, TypeRepr.of[String]) - CaseDef(Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), None, '{ - new ElementDecoder.FailedDecoder[T]( - $c.error( - s"Illegal decoder state: DecodingElement(${${Ref(symbol).asExprOf[String]}}). It's a library bug. Please report it" + CaseDef( + Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), + None, + '{ + new ElementDecoder.FailedDecoder[T]( + $c.error( + s"Illegal decoder state: DecodingElement(${${ Ref(symbol).asExprOf[String] }}). It's a library bug. Please report it", + ), ) - ) - }.asTerm) + }.asTerm, + ) } - Match(name.asTerm, decodeElementCases[T](groups.getOrElse(FieldCategory.element, Nil), go, c, currentFieldStates) :+ default).asExprOf[ElementDecoder[T]] + Match( + name.asTerm, + decodeElementCases[T](groups.getOrElse(FieldCategory.element, Nil), go, c, currentFieldStates) :+ default, + ).asExprOf[ElementDecoder[T]] } private def ignoringElement[T: Type](using Quotes)( - groups: Map[FieldCategory, List[ProductTypeField]], - go: Expr[DecoderState => ElementDecoder[T]], - c: Expr[Cursor], - currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], - state: Expr[IgnoringElement], + groups: Map[FieldCategory, List[ProductTypeField]], + go: Expr[DecoderState => ElementDecoder[T]], + c: Expr[Cursor], + currentFieldStates: Expr[mutable.AnyRefMap[String, Any]], + state: Expr[IgnoringElement], ) = { import quotes.reflect.* - groups.getOrElse(FieldCategory.default, Nil).headOption.fold( - '{ - if ($c.isEndElement && $c.getLocalName == $state.name && $c.getNamespaceURI == $state.namespace.getOrElse("")) { - $c.next() - if ($state.depth == 0) { - $go(DecoderState.DecodingSelf) + groups + .getOrElse(FieldCategory.default, Nil) + .headOption + .fold('{ + if ($c.isEndElement && $c.getLocalName == $state.name && $c.getNamespaceURI == $state.namespace.getOrElse("")) { + $c.next() + if ($state.depth == 0) { + $go(DecoderState.DecodingSelf) + } else { + $go($state.copy(depth = $state.depth - 1)) + } + } else if ( + $c.isStartElement && $c.getLocalName == $state.name && $c.getNamespaceURI == $state.namespace.getOrElse("") + ) { + $c.next() + $go($state.copy(depth = $state.depth + 1)) } else { - $go($state.copy(depth = $state.depth - 1)) + $c.next() + $go($state) } - } else if ($c.isStartElement && $c.getLocalName == $state.name && $c.getNamespaceURI == $state.namespace.getOrElse("")) { - $c.next() - $go($state.copy(depth = $state.depth + 1)) - } else { - $c.next() - $go($state) - }}){ default => + }) { default => // Looks similar with code in decodeStartElement default.typeRepr.asType match { case '[t] => - '{ - val res = $currentFieldStates - .getOrElse(${Expr(default.localName)}, summonInline[ElementDecoder[t]]) + '{ + val res = $currentFieldStates + .getOrElse(${ Expr(default.localName) }, summonInline[ElementDecoder[t]]) .asInstanceOf[ElementDecoder[t]] .decodeAsElement($c, $state.name, $state.namespace.orElse($c.getScopeDefaultNamespace)) - $currentFieldStates.update(${Expr(default.localName)}, res) - if (res.isCompleted) { - res.result($state.name :: $c.history) match { - case Right(_) => $go(DecoderState.DecodingSelf) - case Left(error) => new ElementDecoder.FailedDecoder[T](error) - } - } else { - $go($state) - } - } + $currentFieldStates.update(${ Expr(default.localName) }, res) + if (res.isCompleted) { + res.result($state.name :: $c.history) match { + case Right(_) => $go(DecoderState.DecodingSelf) + case Left(error) => new ElementDecoder.FailedDecoder[T](error) + } + } else { + $go($state) + } + } } } } @@ -331,9 +378,9 @@ object decoder { private def deriveProduct[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementDecoder[T]] = { import quotes.reflect.* val classTypeRepr = TypeRepr.of[T] - val classSymbol = classTypeRepr.typeSymbol - val fields = extractProductTypeFields[T](config) - val groups = fields.groupBy(_.category) + val classSymbol = classTypeRepr.typeSymbol + val fields = extractProductTypeFields[T](config) + val groups = fields.groupBy(_.category) '{ // Generate case class instead of untyped map? class TDecoder(state: DecoderState, fieldStates: Map[String, Any]) extends ElementDecoder[T] { @@ -344,38 +391,40 @@ object decoder { if (c.getEventType == AsyncXMLStreamReader.EVENT_INCOMPLETE) { c.next() TDecoder(currentState, Map.from(currentFieldStates)) - } else currentState match { - case DecoderState.New => - if (c.isStartElement) { - val newNamespaceUri = - if (c.getScopeDefaultNamespace == namespaceUri) $config.scopeDefaultNamespace - else $config.scopeDefaultNamespace.orElse(namespaceUri) - $config.scopeDefaultNamespace.foreach(c.setScopeDefaultNamespace) - ElementDecoder.errorIfWrongName[T](c, localName, newNamespaceUri.orElse(c.getScopeDefaultNamespace)) match { - case None => - ${decodeAttributes(groups, 'c, 'currentFieldStates)} - c.next() - go(DecoderState.DecodingSelf) - case Some(error) => error + } else + currentState match { + case DecoderState.New => + if (c.isStartElement) { + val newNamespaceUri = + if (c.getScopeDefaultNamespace == namespaceUri) $config.scopeDefaultNamespace + else $config.scopeDefaultNamespace.orElse(namespaceUri) + $config.scopeDefaultNamespace.foreach(c.setScopeDefaultNamespace) + ElementDecoder + .errorIfWrongName[T](c, localName, newNamespaceUri.orElse(c.getScopeDefaultNamespace)) match { + case None => + ${ decodeAttributes(groups, 'c, 'currentFieldStates) } + c.next() + go(DecoderState.DecodingSelf) + case Some(error) => error + } + } else { + ElementDecoder.FailedDecoder[T](c.error("Illegal state: not START_ELEMENT")) } - } else { - ElementDecoder.FailedDecoder[T](c.error("Illegal state: not START_ELEMENT")) - } - case DecoderState.DecodingSelf => - ${decodeText(groups, 'c, 'currentFieldStates)} - if (c.isStartElement) { - ${decodeStartElement[T](groups, 'go, 'c, 'currentFieldStates)} - } else if (c.isEndElement) { - ${decodeEndElement[T](fields, 'go, 'c, 'localName, 'currentFieldStates, config)} - } else { - c.next() - go(DecoderState.DecodingSelf) - } - case DecoderState.DecodingElement(name) => - ${decodingElement(groups, 'go, 'c, 'currentFieldStates, 'name)} - case state: DecoderState.IgnoringElement => - ${ignoringElement[T](groups, 'go, 'c, 'currentFieldStates, 'state)} - } + case DecoderState.DecodingSelf => + ${ decodeText(groups, 'c, 'currentFieldStates) } + if (c.isStartElement) { + ${ decodeStartElement[T](groups, 'go, 'c, 'currentFieldStates) } + } else if (c.isEndElement) { + ${ decodeEndElement[T](fields, 'go, 'c, 'localName, 'currentFieldStates, config) } + } else { + c.next() + go(DecoderState.DecodingSelf) + } + case DecoderState.DecodingElement(name) => + ${ decodingElement(groups, 'go, 'c, 'currentFieldStates, 'name) } + case state: DecoderState.IgnoringElement => + ${ ignoringElement[T](groups, 'go, 'c, 'currentFieldStates, 'state) } + } } go(state) } @@ -398,38 +447,65 @@ object decoder { val discriminator = if ($config.useElementNameAsDiscriminator) { Right(c.getLocalName) } else { - ElementDecoder.errorIfWrongName[T](c, localName, namespaceUri) + ElementDecoder + .errorIfWrongName[T](c, localName, namespaceUri) .map(Left.apply) - .getOrElse{ - val discriminatorIdx = c.getAttributeIndex($config.discriminatorNamespace.getOrElse(null), $config.discriminatorLocalName) + .getOrElse { + val discriminatorIdx = + c.getAttributeIndex($config.discriminatorNamespace.getOrElse(null), $config.discriminatorLocalName) if (discriminatorIdx > -1) { Right(c.getAttributeValue(discriminatorIdx)) } else { Left( new FailedDecoder[T]( - c.error(s"No type discriminator '${$config.discriminatorNamespace.fold("")(_ + ":")}${$config.discriminatorLocalName}' found") - ) + c.error( + s"No type discriminator '${$config.discriminatorNamespace.fold("")(_ + ":")}${$config.discriminatorLocalName}' found", + ), + ), ) } } } - discriminator.fold(identity, d => ${Match('d.asTerm, extractSumTypeChildren[T](config).map { child => - val symbol = Symbol.newBind(Symbol.spliceOwner, "x", Flags.EmptyFlags, TypeRepr.of[String]) - val eq = symbol.memberMethod("==").head - CaseDef(Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), Some(Apply(Select(Ref(symbol), eq), List(child.xmlName.asTerm))), child.typeRepr.asType match { - case '[t] => '{ - summonInline[ElementDecoder[t]] - .decodeAsElement(c, c.getLocalName, Option(c.getNamespaceURI).filter(_.nonEmpty).orElse(c.getScopeDefaultNamespace)) - .map(_.asInstanceOf[T]) - }.asTerm - }) - } :+ { - val symbol = Symbol.newBind(Symbol.spliceOwner, "unknown", Flags.EmptyFlags, TypeRepr.of[String]) - CaseDef(Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), None, - '{new FailedDecoder[T](c.error(s"Unknown type discriminator value: '${${Ref(symbol).asExprOf[String]}}'"))}.asTerm - ) - }).asExprOf[ElementDecoder[T]] - }) + discriminator.fold( + identity, + d => + ${ + Match( + 'd.asTerm, + extractSumTypeChildren[T](config).map { child => + val symbol = Symbol.newBind(Symbol.spliceOwner, "x", Flags.EmptyFlags, TypeRepr.of[String]) + val eq = symbol.memberMethod("==").head + CaseDef( + Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), + Some(Apply(Select(Ref(symbol), eq), List(child.xmlName.asTerm))), + child.typeRepr.asType match { + case '[t] => + '{ + summonInline[ElementDecoder[t]] + .decodeAsElement( + c, + c.getLocalName, + Option(c.getNamespaceURI).filter(_.nonEmpty).orElse(c.getScopeDefaultNamespace), + ) + .map(_.asInstanceOf[T]) + }.asTerm + }, + ) + } :+ { + val symbol = Symbol.newBind(Symbol.spliceOwner, "unknown", Flags.EmptyFlags, TypeRepr.of[String]) + CaseDef( + Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), + None, + '{ + new FailedDecoder[T]( + c.error(s"Unknown type discriminator value: '${${ Ref(symbol).asExprOf[String] }}'"), + ) + }.asTerm, + ) + }, + ).asExprOf[ElementDecoder[T]] + }, + ) } } diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/encoder.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/encoder.scala index 3095044..8e6dda0 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/encoder.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/encoder.scala @@ -14,22 +14,22 @@ import scala.quoted.* object encoder { inline def deriveElementEncoder[T]( - inline config: ElementCodecConfig + inline config: ElementCodecConfig, ): ElementEncoder[T] = - ${deriveElementEncoderImpl('{config})} + ${ deriveElementEncoderImpl('{ config }) } inline def deriveXmlEncoder[T]( - inline localName: String, - inline namespace: Option[String], - inline preferredNamespacePrefix: Option[String], - inline config: ElementCodecConfig + inline localName: String, + inline namespace: Option[String], + inline preferredNamespacePrefix: Option[String], + inline config: ElementCodecConfig, ): XmlEncoder[T] = - ${deriveXmlEncoderImpl('{localName}, '{namespace}, '{preferredNamespacePrefix}, '{config})} + ${ deriveXmlEncoderImpl('{ localName }, '{ namespace }, '{ preferredNamespacePrefix }, '{ config }) } def deriveElementEncoderImpl[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementEncoder[T]] = { import quotes.reflect.* - val tpe = TypeRepr.of[T] + val tpe = TypeRepr.of[T] val typeSymbol = tpe.typeSymbol if (typeSymbol.flags.is(Flags.Case)) { deriveProduct(config) @@ -41,130 +41,152 @@ object encoder { } def deriveXmlEncoderImpl[T: Type]( - localName: Expr[String], - namespace: Expr[Option[String]], - preferredNamespacePrefix: Expr[Option[String]], - config: Expr[ElementCodecConfig], + localName: Expr[String], + namespace: Expr[Option[String]], + preferredNamespacePrefix: Expr[Option[String]], + config: Expr[ElementCodecConfig], )(using Quotes): Expr[XmlEncoder[T]] = - '{XmlEncoder.fromElementEncoder[T]($localName, $namespace, $preferredNamespacePrefix)(${deriveElementEncoderImpl(config)})} + '{ + XmlEncoder.fromElementEncoder[T]($localName, $namespace, $preferredNamespacePrefix)(${ + deriveElementEncoderImpl(config) + }) + } // PRODUCT private def encodeAttributes[T: Type](using Quotes)( - fields: List[ProductTypeField], - sw: Expr[PhobosStreamWriter], - a: Expr[T] + fields: List[ProductTypeField], + sw: Expr[PhobosStreamWriter], + a: Expr[T], ): Expr[List[Unit]] = { import quotes.reflect.* val classTypeRepr = TypeRepr.of[T] - val classSymbol = classTypeRepr.typeSymbol - Expr.ofList(fields.map{ field => + val classSymbol = classTypeRepr.typeSymbol + Expr.ofList(fields.map { field => field.typeRepr.asType match { - case '[t] => '{ - summonInline[AttributeEncoder[t]].encodeAsAttribute( - ${Select(a.asTerm, classSymbol.declaredField(field.localName)).asExprOf[t]}, - $sw, - ${field.xmlName}, - ${field.namespaceUri} - ) - } + case '[t] => + '{ + summonInline[AttributeEncoder[t]].encodeAsAttribute( + ${ Select(a.asTerm, classSymbol.declaredField(field.localName)).asExprOf[t] }, + $sw, + ${ field.xmlName }, + ${ field.namespaceUri }, + ) + } } }) } private def encodeText[T: Type](using Quotes)( - fields: List[ProductTypeField], - sw: Expr[PhobosStreamWriter], - a: Expr[T] + fields: List[ProductTypeField], + sw: Expr[PhobosStreamWriter], + a: Expr[T], ): Expr[List[Unit]] = { import quotes.reflect.* val classTypeRepr = TypeRepr.of[T] - val classSymbol = classTypeRepr.typeSymbol + val classSymbol = classTypeRepr.typeSymbol Expr.ofList(fields.map { field => field.typeRepr.asType match { - case '[t] => '{ - summonInline[TextEncoder[t]] - .encodeAsText(${Select(a.asTerm, classSymbol.declaredField(field.localName)).asExprOf[t]}, $sw) - } + case '[t] => + '{ + summonInline[TextEncoder[t]] + .encodeAsText(${ Select(a.asTerm, classSymbol.declaredField(field.localName)).asExprOf[t] }, $sw) + } } }) } private def encodeElements[T: Type](using Quotes)( - fields: List[ProductTypeField], - sw: Expr[PhobosStreamWriter], - a: Expr[T] + fields: List[ProductTypeField], + sw: Expr[PhobosStreamWriter], + a: Expr[T], ): Expr[List[Unit]] = { import quotes.reflect.* val classTypeRepr = TypeRepr.of[T] - val classSymbol = classTypeRepr.typeSymbol + val classSymbol = classTypeRepr.typeSymbol Expr.ofList(fields.map { field => field.typeRepr.asType match { - case '[t] => '{ - summonInline[ElementEncoder[t]].encodeAsElement( - ${Select(a.asTerm, classSymbol.declaredField(field.localName)).asExprOf[t]}, - $sw, - ${field.xmlName}, - ${field.namespaceUri}, - ) - } + case '[t] => + '{ + summonInline[ElementEncoder[t]].encodeAsElement( + ${ Select(a.asTerm, classSymbol.declaredField(field.localName)).asExprOf[t] }, + $sw, + ${ field.xmlName }, + ${ field.namespaceUri }, + ) + } } }) } - private def deriveProduct[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementEncoder[T]] = { import quotes.reflect.* val classTypeRepr = TypeRepr.of[T] - val classSymbol = classTypeRepr.typeSymbol - val fields = extractProductTypeFields[T](config) + val classSymbol = classTypeRepr.typeSymbol + val fields = extractProductTypeFields[T](config) val groups = fields.groupBy(_.category) - '{new ElementEncoder[T]{ - def encodeAsElement(a: T, sw: PhobosStreamWriter, localName: String, namespaceUri: Option[String], preferredNamespacePrefix: Option[String]): Unit = { - namespaceUri.fold(sw.writeStartElement(localName))(ns => sw.writeStartElement(preferredNamespacePrefix.orNull, localName, ns)) - $config.scopeDefaultNamespace.foreach { uri => - sw.writeAttribute("xmlns", uri) - } - $config.defineNamespaces.foreach { - case (uri, Some(prefix)) => - if (sw.getNamespaceContext.getPrefix(uri) == null) sw.writeNamespace(prefix, uri) - case (uri, None) => - if (sw.getNamespaceContext.getPrefix(uri) == null) sw.writeNamespace(uri) - } + '{ + new ElementEncoder[T] { + def encodeAsElement( + a: T, + sw: PhobosStreamWriter, + localName: String, + namespaceUri: Option[String], + preferredNamespacePrefix: Option[String], + ): Unit = { + namespaceUri.fold(sw.writeStartElement(localName))(ns => + sw.writeStartElement(preferredNamespacePrefix.orNull, localName, ns), + ) + $config.scopeDefaultNamespace.foreach { uri => + sw.writeAttribute("xmlns", uri) + } + $config.defineNamespaces.foreach { + case (uri, Some(prefix)) => + if (sw.getNamespaceContext.getPrefix(uri) == null) sw.writeNamespace(prefix, uri) + case (uri, None) => + if (sw.getNamespaceContext.getPrefix(uri) == null) sw.writeNamespace(uri) + } - ${encodeAttributes[T](groups.getOrElse(FieldCategory.attribute, Nil), 'sw, 'a)} - ${encodeText[T](groups.getOrElse(FieldCategory.text, Nil), 'sw, 'a)} - ${encodeElements[T]((groups.getOrElse(FieldCategory.element, Nil) ::: groups.getOrElse(FieldCategory.default, Nil)), 'sw, 'a)} + ${ encodeAttributes[T](groups.getOrElse(FieldCategory.attribute, Nil), 'sw, 'a) } + ${ encodeText[T](groups.getOrElse(FieldCategory.text, Nil), 'sw, 'a) } + ${ + encodeElements[T]( + (groups.getOrElse(FieldCategory.element, Nil) ::: groups.getOrElse(FieldCategory.default, Nil)), + 'sw, + 'a, + ) + } - sw.writeEndElement() + sw.writeEndElement() + } } - }} + } } // SUM private def encodeChild[T: Type](using Quotes)( - config: Expr[ElementCodecConfig], - child: SumTypeChild, - childValue: Expr[T], - sw: Expr[PhobosStreamWriter], - localName: Expr[String], - namespaceUri: Expr[Option[String]], - preferredNamespacePrefix: Expr[Option[String]], + config: Expr[ElementCodecConfig], + child: SumTypeChild, + childValue: Expr[T], + sw: Expr[PhobosStreamWriter], + localName: Expr[String], + namespaceUri: Expr[Option[String]], + preferredNamespacePrefix: Expr[Option[String]], ): Expr[Unit] = { import quotes.reflect.* '{ val instance = summonInline[ElementEncoder[T]] if ($config.useElementNameAsDiscriminator) { - instance.encodeAsElement(${childValue}, $sw, ${child.xmlName}, $namespaceUri, $preferredNamespacePrefix) + instance.encodeAsElement(${ childValue }, $sw, ${ child.xmlName }, $namespaceUri, $preferredNamespacePrefix) } else { - $sw.memorizeDiscriminator($config.discriminatorNamespace, $config.discriminatorLocalName, ${child.xmlName}) - instance.encodeAsElement(${childValue}, $sw, $localName, $namespaceUri, $preferredNamespacePrefix) + $sw.memorizeDiscriminator($config.discriminatorNamespace, $config.discriminatorLocalName, ${ child.xmlName }) + instance.encodeAsElement(${ childValue }, $sw, $localName, $namespaceUri, $preferredNamespacePrefix) } } } @@ -174,24 +196,39 @@ object encoder { '{ new ElementEncoder[T] { - def encodeAsElement(a: T, sw: PhobosStreamWriter, localName: String, namespaceUri: Option[String], preferredNamespacePrefix: Option[String]): Unit = { + def encodeAsElement( + a: T, + sw: PhobosStreamWriter, + localName: String, + namespaceUri: Option[String], + preferredNamespacePrefix: Option[String], + ): Unit = { ${ val alternatives = extractSumTypeChildren[T](config).map { child => child.typeRepr.asType match { case '[t] => val childValueSymbol = Symbol.newBind(Symbol.spliceOwner, "child", Flags.EmptyFlags, TypeRepr.of[t]) - val encode = encodeChild(config, child, Ref(childValueSymbol).asExprOf[t], 'sw, 'localName, 'namespaceUri, 'preferredNamespacePrefix) + val encode = encodeChild( + config, + child, + Ref(childValueSymbol).asExprOf[t], + 'sw, + 'localName, + 'namespaceUri, + 'preferredNamespacePrefix, + ) CaseDef(Bind(childValueSymbol, Typed(Ref(childValueSymbol), TypeTree.of[t])), None, encode.asTerm) } } Match( - '{a}.asTerm, + '{ a }.asTerm, // Scala 3.0 reports false positive match may not be exhaustive warning if (util.Properties.versionNumberString.startsWith("3.0")) alternatives :+ { val symbol = Symbol.newBind(Symbol.spliceOwner, "_", Flags.EmptyFlags, TypeRepr.of[T]) - CaseDef(Bind(symbol, Typed(Ref(symbol), TypeTree.of[T])), None, '{()}.asTerm) - } else alternatives + CaseDef(Bind(symbol, Typed(Ref(symbol), TypeTree.of[T])), None, '{ () }.asTerm) + } + else alternatives, ).asExprOf[Unit] } } diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/semiauto/package.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/semiauto/package.scala index 19124ea..1887dd5 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/semiauto/package.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/semiauto/package.scala @@ -17,8 +17,17 @@ package object semiauto { inline def deriveXmlEncoderConfigured[T](localName: String, config: ElementCodecConfig): XmlEncoder[T] = encoder.deriveXmlEncoder[T](localName, None, None, config) inline def deriveXmlEncoder[T, NS: Namespace](localName: String, ns: NS): XmlEncoder[T] = - encoder.deriveXmlEncoder[T](localName, Some(Namespace[NS].getNamespace), Namespace[NS].getPreferredPrefix, ElementCodecConfig.default) - inline def deriveXmlEncoderConfigured[T, NS: Namespace](localName: String, ns: NS, config: ElementCodecConfig): XmlEncoder[T] = + encoder.deriveXmlEncoder[T]( + localName, + Some(Namespace[NS].getNamespace), + Namespace[NS].getPreferredPrefix, + ElementCodecConfig.default, + ) + inline def deriveXmlEncoderConfigured[T, NS: Namespace]( + localName: String, + ns: NS, + config: ElementCodecConfig, + ): XmlEncoder[T] = encoder.deriveXmlEncoder[T](localName, Some(Namespace[NS].getNamespace), Namespace[NS].getPreferredPrefix, config) inline def deriveElementDecoder[T]: ElementDecoder[T] = @@ -31,6 +40,10 @@ package object semiauto { decoder.deriveXmlDecoder[T](localName, None, config) inline def deriveXmlDecoder[T, NS: Namespace](localName: String, ns: NS): XmlDecoder[T] = decoder.deriveXmlDecoder[T](localName, Some(Namespace[NS].getNamespace), ElementCodecConfig.default) - inline def deriveXmlDecoderConfigured[T, NS: Namespace](localName: String, ns: NS, config: ElementCodecConfig): XmlDecoder[T] = + inline def deriveXmlDecoderConfigured[T, NS: Namespace]( + localName: String, + ns: NS, + config: ElementCodecConfig, + ): XmlDecoder[T] = decoder.deriveXmlDecoder[T](localName, Some(Namespace[NS].getNamespace), config) } diff --git a/modules/core/src/test/scala-3/ru/tinkoff/phobos/DerivationTest.scala b/modules/core/src/test/scala-3/ru/tinkoff/phobos/DerivationTest.scala index 8d16c9c..312ecd4 100644 --- a/modules/core/src/test/scala-3/ru/tinkoff/phobos/DerivationTest.scala +++ b/modules/core/src/test/scala-3/ru/tinkoff/phobos/DerivationTest.scala @@ -13,7 +13,7 @@ class DerivationTest extends AnyWordSpec with Matchers { case class Bar(d: String, foo: Foo, e: Char) derives ElementEncoder given XmlEncoder[Bar] = XmlEncoder.fromElementEncoder("bar") - val bar = Bar("d value", Foo(1, "b value", 3.0), 'e') + val bar = Bar("d value", Foo(1, "b value", 3.0), 'e') val string = """ | | d value @@ -37,7 +37,7 @@ class DerivationTest extends AnyWordSpec with Matchers { case class Bar(d: String, foo: Foo, e: Char) derives ElementDecoder given XmlDecoder[Bar] = XmlDecoder.fromElementDecoder("bar") - val bar = Bar("d value", Foo(1, "b value", 3.0), 'e') + val bar = Bar("d value", Foo(1, "b value", 3.0), 'e') val string = """ | | d value @@ -55,4 +55,4 @@ class DerivationTest extends AnyWordSpec with Matchers { } } -} \ No newline at end of file +} From b66fa641047f7e072e6ec07425fb98309eaa5e96 Mon Sep 17 00:00:00 2001 From: susliko <1istoobig@gmail.com> Date: Tue, 16 May 2023 22:59:26 +0300 Subject: [PATCH 2/3] Fix EncodingError message --- .../main/scala/ru/tinkoff/phobos/encoding/EncodingError.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/core/src/main/scala/ru/tinkoff/phobos/encoding/EncodingError.scala b/modules/core/src/main/scala/ru/tinkoff/phobos/encoding/EncodingError.scala index 2710f60..7592908 100644 --- a/modules/core/src/main/scala/ru/tinkoff/phobos/encoding/EncodingError.scala +++ b/modules/core/src/main/scala/ru/tinkoff/phobos/encoding/EncodingError.scala @@ -1,5 +1,5 @@ package ru.tinkoff.phobos.encoding case class EncodingError(text: String, cause: Option[Throwable] = None) extends Exception(text, cause.orNull) { - override def getMessage: String = s"Error while decoding XML: $text" + override def getMessage: String = s"Error while encoding XML: $text" } From 6ee898af562e5fab9bf1240d970bb1de610ba15f Mon Sep 17 00:00:00 2001 From: susliko <1istoobig@gmail.com> Date: Tue, 16 May 2023 22:59:56 +0300 Subject: [PATCH 3/3] Derive codecs for enums and sealed traits --- .../phobos/decoding/DerivedElement.scala | 5 + .../phobos/derivation/LazySummon.scala | 12 ++ .../ru/tinkoff/phobos/derivation/common.scala | 58 ++++-- .../tinkoff/phobos/derivation/decoder.scala | 185 ++++++++---------- .../tinkoff/phobos/derivation/encoder.scala | 136 ++++--------- .../phobos/derivation/semiauto/package.scala | 2 - .../phobos/encoding/DerivedElement.scala | 5 + .../ru/tinkoff/phobos/DerivationTest.scala | 128 +++++++++++- .../phobos/DecoderDerivationTest.scala | 9 +- .../phobos/EncoderDerivationTest.scala | 8 +- 10 files changed, 314 insertions(+), 234 deletions(-) create mode 100644 modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/LazySummon.scala diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/decoding/DerivedElement.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/decoding/DerivedElement.scala index e385725..1f59edf 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/decoding/DerivedElement.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/decoding/DerivedElement.scala @@ -2,8 +2,13 @@ package ru.tinkoff.phobos.decoding import ru.tinkoff.phobos.configured.ElementCodecConfig import ru.tinkoff.phobos.derivation.decoder +import ru.tinkoff.phobos.derivation.LazySummon +import scala.deriving.Mirror private[decoding] trait DerivedElement { inline def derived[T]: ElementDecoder[T] = decoder.deriveElementDecoder[T](ElementCodecConfig.default) + + inline given [T](using mirror: Mirror.Of[T]): LazySummon[ElementDecoder, T] = new: + def instance = decoder.deriveElementDecoder[T](ElementCodecConfig.default) } diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/LazySummon.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/LazySummon.scala new file mode 100644 index 0000000..ea414d6 --- /dev/null +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/LazySummon.scala @@ -0,0 +1,12 @@ +package ru.tinkoff.phobos.derivation + +/** Defining givens of such type in companion objects of ElementEncoder and ElementDecoder allows to summon instances of + * these typeclasses for every child of a sum type (sealed trait or enum), e.g. like this: + * {{{ + * summonAll[Tuple.Map[m.MirroredElemTypes, [t] =>> LazySummon[TC, t]]] + * }}} + * while safeguards against automatical derivation for all types without explicit `derives` clause or + * `deriveElementEncoder`/`deriveElementDecoder` calls. + */ +trait LazySummon[TC[_], A]: + def instance: TC[A] diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/common.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/common.scala index 5803a93..7830330 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/common.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/common.scala @@ -7,6 +7,8 @@ import ru.tinkoff.phobos.syntax.* import scala.quoted.* import scala.compiletime.* import scala.annotation.nowarn +import scala.deriving.Mirror +import scala.reflect.TypeTest @nowarn("msg=Use errorAndAbort") object common { @@ -28,11 +30,18 @@ object common { val category: FieldCategory, ) - private[derivation] final class SumTypeChild(using val quotes: Quotes)( - val xmlName: Expr[String], // Value of discriminator - val typeRepr: quotes.reflect.TypeRepr, + private[derivation] final class SumTypeChild[TC[_], Base]( + val xmlName: String, + val lazyTC: LazySummon[TC, Base], + val typeTest: TypeTest[Base, ?], ) + extension [TC[_], Base](children: List[SumTypeChild[TC, Base]]) + def byInstance[T](i: Base): Option[SumTypeChild[TC, Base]] = + children.find(_.typeTest.unapply(i).isDefined) + + def byXmlName(n: String): Option[SumTypeChild[TC, Base]] = children.find(_.xmlName == n) + private[derivation] def extractProductTypeFields[T: Type]( config: Expr[ElementCodecConfig], )(using Quotes): List[ProductTypeField] = { @@ -75,17 +84,34 @@ object common { fields } - private[derivation] def extractSumTypeChildren[T: Type]( + inline def extractSumTypeChild[TC[_], T]( + inline config: ElementCodecConfig, + )(using m: Mirror.SumOf[T]): List[SumTypeChild[TC, T]] = { + type Children = m.MirroredElemTypes + val typeTests = summonAll[Tuple.Map[Children, [t] =>> TypeTest[T, t]]].toList.map(_.asInstanceOf[TypeTest[T, ?]]) + val lazyTCs = + summonAll[Tuple.Map[Children, [t] =>> LazySummon[TC, t]]].toList.map(_.asInstanceOf[LazySummon[TC, T]]) + val xmlNames = extractSumXmlNames[T](config) + + typeTests.zip(lazyTCs).zip(xmlNames).map { case ((typeTest, lazyTC), xmlName) => + new SumTypeChild(xmlName, lazyTC, typeTest) + } + } + + private[derivation] inline def extractSumXmlNames[T](inline config: ElementCodecConfig): List[String] = + ${ extractSumXmlNamesImpl[T]('config) } + + private[derivation] def extractSumXmlNamesImpl[T: Type]( config: Expr[ElementCodecConfig], - )(using Quotes): List[SumTypeChild] = { + )(using q: Quotes): Expr[List[String]] = { import quotes.reflect.* val traitTypeRepr = TypeRepr.of[T] val traitSymbol = traitTypeRepr.typeSymbol - traitSymbol.children.map { childSymbol => - val xmlName = extractChildXmlName(config, traitSymbol, childSymbol) - SumTypeChild(using quotes)(xmlName, TypeIdent(childSymbol).tpe) - } + val names = Varargs(traitSymbol.children.map { childInfosymbol => + extractChildXmlName(using q)(config, traitSymbol, childInfosymbol) + }) + '{ List($names: _*) } } private def extractFieldCategory(using Quotes)( @@ -180,20 +206,26 @@ object common { private def extractChildXmlName(using Quotes)( config: Expr[ElementCodecConfig], traitSymbol: quotes.reflect.Symbol, - childSymbol: quotes.reflect.Symbol, + childInfosymbol: quotes.reflect.Symbol, ): Expr[String] = { import quotes.reflect.* - childSymbol.annotations.map(_.asExpr).collect { case '{ discriminator($a) } => a } match { - case Nil => '{ $config.transformConstructorNames(${ Expr(childSymbol.name) }) } + childInfosymbol.annotations.map(_.asExpr).collect { case '{ discriminator($a) } => a } match { + case Nil => '{ $config.transformConstructorNames(${ Expr(childInfosymbol.name) }) } case List(name) => name case names => val discriminatorAnnotations = names.map(name => s"@discriminator(${name.show})").mkString(", ") report.throwError( s""" |Sum type child cannot have more than one @discriminator annotation. - |Child '${childSymbol.name}' of sum type '${traitSymbol.name}' has ${names.size}: $discriminatorAnnotations + |Child '${childInfosymbol.name}' of sum type '${traitSymbol.name}' has ${names.size}: $discriminatorAnnotations |""".stripMargin, ) } } + + inline def showType[T <: AnyKind]: String = ${ showTypeMacro[T] } + + private def showTypeMacro[T <: AnyKind: Type](using q: Quotes): Expr[String] = + import q.reflect.* + Expr(TypeRepr.of[T].dealias.widen.show) } diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/decoder.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/decoder.scala index 59bf031..5e0436e 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/decoder.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/decoder.scala @@ -14,6 +14,7 @@ import scala.annotation.tailrec import scala.collection.mutable import scala.compiletime.* import scala.quoted.* +import scala.deriving.Mirror @nowarn("msg=Use errorAndAbort") @nowarn("msg=Use methodMember") @@ -22,34 +23,20 @@ object decoder { inline def deriveElementDecoder[T]( inline config: ElementCodecConfig, ): ElementDecoder[T] = - ${ deriveElementDecoderImpl('{ config }) } + summonFrom { + case _: Mirror.ProductOf[T] => deriveProduct(config) + case _: Mirror.SumOf[T] => + val childInfos = extractSumTypeChild[ElementDecoder, T](config) + deriveSum(config, childInfos) + case _ => error(s"${showType[T]} is not a sum type or product type") + } inline def deriveXmlDecoder[T]( inline localName: String, inline namespace: Option[String], inline config: ElementCodecConfig, ): XmlDecoder[T] = - ${ deriveXmlDecoderImpl('{ localName }, '{ namespace }, '{ config }) } - - def deriveElementDecoderImpl[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementDecoder[T]] = { - import quotes.reflect.* - val tpe = TypeRepr.of[T] - val typeSymbol = tpe.typeSymbol - if (typeSymbol.flags.is(Flags.Case)) { - deriveProduct[T](config) - } else if (typeSymbol.flags.is(Flags.Sealed)) { - deriveSum[T](config) - } else { - report.throwError(s"${typeSymbol} is not a sum type or product type") - } - } - - def deriveXmlDecoderImpl[T: Type]( - localName: Expr[String], - namespace: Expr[Option[String]], - config: Expr[ElementCodecConfig], - )(using Quotes): Expr[XmlDecoder[T]] = - '{ XmlDecoder.fromElementDecoder[T]($localName, $namespace)(${ deriveElementDecoderImpl(config) }) } + XmlDecoder.fromElementDecoder[T](localName, namespace)(deriveElementDecoder(config)) // PRODUCT @@ -210,16 +197,23 @@ object decoder { '{ $c.getLocalName match { case name if name == $localName => - ${ - val classTypeRepr = TypeRepr.of[T] - val primaryConstructor = Select(New(TypeTree.of[T]), classTypeRepr.typeSymbol.primaryConstructor) - val primaryConstructorTypeApplied = classTypeRepr match { - case AppliedType(_, params) => TypeApply(primaryConstructor, params.map(Inferred.apply)) - case _ => primaryConstructor + val decodingResult: Either[DecodingError, T] = ${ + def appliedConstructor(constructorParams: List[Term]): Term = { + val classTypeRepr = TypeRepr.of[T] + val primaryConstructor = Select(New(TypeTree.of[T]), classTypeRepr.typeSymbol.primaryConstructor) + classTypeRepr match { + case AppliedType(_, params) => + Apply(TypeApply(primaryConstructor, params.map(Inferred.apply)), constructorParams) + case TermRef(typeRepr, name) => + Ref(classTypeRepr.termSymbol) + case _ => + Apply(primaryConstructor, constructorParams) + } } + fields .foldLeft[List[Term] => Expr[Either[DecodingError, T]]] { terms => - '{ Right(${ Apply(primaryConstructorTypeApplied, terms).asExprOf[T] }) } + '{ Right(${ appliedConstructor(terms).asExprOf[T] }) } } { (acc, field) => params => field.typeRepr.asType match { case '[t] => @@ -281,14 +275,15 @@ object decoder { } } }(Nil) - } match { - case Right(result) => + } + decodingResult.fold( + new FailedDecoder[T](_), + result => { $c.next() $config.scopeDefaultNamespace.foreach(_ => $c.unsetScopeDefaultNamespace()) new ConstDecoder[T](result) - case Left(error) => - new FailedDecoder[T](error) - } + }, + ) case _ => $c.next() $go(DecoderState.DecodingSelf) @@ -375,7 +370,10 @@ object decoder { } } - private def deriveProduct[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementDecoder[T]] = { + private inline def deriveProduct[T](inline config: ElementCodecConfig): ElementDecoder[T] = + ${ deriveProductImpl[T]('config) } + + private def deriveProductImpl[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementDecoder[T]] = { import quotes.reflect.* val classTypeRepr = TypeRepr.of[T] val classSymbol = classTypeRepr.typeSymbol @@ -436,84 +434,59 @@ object decoder { } } - private def deriveSum[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementDecoder[T]] = { - import quotes.reflect.* - '{ - new ElementDecoder[T] { - def decodeAsElement(c: Cursor, localName: String, namespaceUri: Option[String]): ElementDecoder[T] = { - if (c.getEventType == AsyncXMLStreamReader.EVENT_INCOMPLETE) { - this + private inline def deriveSum[T]( + inline config: ElementCodecConfig, + inline childInfos: List[SumTypeChild[ElementDecoder, T]], + ): ElementDecoder[T] = { + new ElementDecoder[T] { + def decodeAsElement(c: Cursor, localName: String, namespaceUri: Option[String]): ElementDecoder[T] = { + if (c.getEventType == AsyncXMLStreamReader.EVENT_INCOMPLETE) { + this + } else { + val discriminator = if (config.useElementNameAsDiscriminator) { + Right(c.getLocalName) } else { - val discriminator = if ($config.useElementNameAsDiscriminator) { - Right(c.getLocalName) - } else { - ElementDecoder - .errorIfWrongName[T](c, localName, namespaceUri) - .map(Left.apply) - .getOrElse { - val discriminatorIdx = - c.getAttributeIndex($config.discriminatorNamespace.getOrElse(null), $config.discriminatorLocalName) - if (discriminatorIdx > -1) { - Right(c.getAttributeValue(discriminatorIdx)) - } else { - Left( - new FailedDecoder[T]( - c.error( - s"No type discriminator '${$config.discriminatorNamespace.fold("")(_ + ":")}${$config.discriminatorLocalName}' found", - ), + ElementDecoder + .errorIfWrongName[T](c, localName, namespaceUri) + .map(Left.apply) + .getOrElse { + val discriminatorIdx = + c.getAttributeIndex(config.discriminatorNamespace.getOrElse(null), config.discriminatorLocalName) + if (discriminatorIdx > -1) { + Right(c.getAttributeValue(discriminatorIdx)) + } else { + Left( + new FailedDecoder[T]( + c.error( + s"No type discriminator '${config.discriminatorNamespace.fold("")(_ + ":")}${config.discriminatorLocalName}' found", ), - ) - } + ), + ) } - } - discriminator.fold( - identity, - d => - ${ - Match( - 'd.asTerm, - extractSumTypeChildren[T](config).map { child => - val symbol = Symbol.newBind(Symbol.spliceOwner, "x", Flags.EmptyFlags, TypeRepr.of[String]) - val eq = symbol.memberMethod("==").head - CaseDef( - Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), - Some(Apply(Select(Ref(symbol), eq), List(child.xmlName.asTerm))), - child.typeRepr.asType match { - case '[t] => - '{ - summonInline[ElementDecoder[t]] - .decodeAsElement( - c, - c.getLocalName, - Option(c.getNamespaceURI).filter(_.nonEmpty).orElse(c.getScopeDefaultNamespace), - ) - .map(_.asInstanceOf[T]) - }.asTerm - }, - ) - } :+ { - val symbol = Symbol.newBind(Symbol.spliceOwner, "unknown", Flags.EmptyFlags, TypeRepr.of[String]) - CaseDef( - Bind(symbol, Typed(Ref(symbol), TypeTree.of[String])), - None, - '{ - new FailedDecoder[T]( - c.error(s"Unknown type discriminator value: '${${ Ref(symbol).asExprOf[String] }}'"), - ) - }.asTerm, - ) - }, - ).asExprOf[ElementDecoder[T]] - }, - ) + } } + discriminator.fold( + identity, + d => { + childInfos.byXmlName(d) match { + case Some(childInfo) => + childInfo.lazyTC.instance.decodeAsElement( + c, + c.getLocalName, + Option(c.getNamespaceURI).filter(_.nonEmpty).orElse(c.getScopeDefaultNamespace), + ) + case None => + new FailedDecoder[T](c.error(s"Unknown type discriminator value: '$d'")) + } + }, + ) } + } - val isCompleted: Boolean = false + val isCompleted: Boolean = false - def result(history: => List[String]): Either[DecodingError, T] = - Left(ElementDecoder.decodingNotCompleteError(history)) - } + def result(history: => List[String]): Either[DecodingError, T] = + Left(ElementDecoder.decodingNotCompleteError(history)) } } } diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/encoder.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/encoder.scala index 8e6dda0..8897595 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/encoder.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/encoder.scala @@ -9,6 +9,7 @@ import ru.tinkoff.phobos.derivation.common.* import scala.annotation.nowarn import scala.compiletime.* import scala.quoted.* +import scala.deriving.Mirror @nowarn("msg=Use errorAndAbort") object encoder { @@ -16,7 +17,13 @@ object encoder { inline def deriveElementEncoder[T]( inline config: ElementCodecConfig, ): ElementEncoder[T] = - ${ deriveElementEncoderImpl('{ config }) } + summonFrom { + case _: Mirror.ProductOf[T] => deriveProduct(config) + case _: Mirror.SumOf[T] => + val childInfos = extractSumTypeChild[ElementEncoder, T](config) + deriveSum(config, childInfos) + case _ => error(s"${showType[T]} is not a sum type or product type") + } inline def deriveXmlEncoder[T]( inline localName: String, @@ -24,33 +31,7 @@ object encoder { inline preferredNamespacePrefix: Option[String], inline config: ElementCodecConfig, ): XmlEncoder[T] = - ${ deriveXmlEncoderImpl('{ localName }, '{ namespace }, '{ preferredNamespacePrefix }, '{ config }) } - - def deriveElementEncoderImpl[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementEncoder[T]] = { - import quotes.reflect.* - - val tpe = TypeRepr.of[T] - val typeSymbol = tpe.typeSymbol - if (typeSymbol.flags.is(Flags.Case)) { - deriveProduct(config) - } else if (typeSymbol.flags.is(Flags.Sealed)) { - deriveSum(config) - } else { - report.throwError(s"${typeSymbol} is not a sum type or product type") - } - } - - def deriveXmlEncoderImpl[T: Type]( - localName: Expr[String], - namespace: Expr[Option[String]], - preferredNamespacePrefix: Expr[Option[String]], - config: Expr[ElementCodecConfig], - )(using Quotes): Expr[XmlEncoder[T]] = - '{ - XmlEncoder.fromElementEncoder[T]($localName, $namespace, $preferredNamespacePrefix)(${ - deriveElementEncoderImpl(config) - }) - } + XmlEncoder.fromElementEncoder[T](localName, namespace, preferredNamespacePrefix)(deriveElementEncoder(config)) // PRODUCT @@ -121,12 +102,11 @@ object encoder { }) } - private def deriveProduct[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementEncoder[T]] = { - import quotes.reflect.* - val classTypeRepr = TypeRepr.of[T] - val classSymbol = classTypeRepr.typeSymbol - val fields = extractProductTypeFields[T](config) + inline def deriveProduct[T](inline config: ElementCodecConfig): ElementEncoder[T] = + ${ deriveProductImpl[T]('config) } + private def deriveProductImpl[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementEncoder[T]] = { + val fields = extractProductTypeFields[T](config) val groups = fields.groupBy(_.category) '{ @@ -169,69 +149,35 @@ object encoder { // SUM - private def encodeChild[T: Type](using Quotes)( - config: Expr[ElementCodecConfig], - child: SumTypeChild, - childValue: Expr[T], - sw: Expr[PhobosStreamWriter], - localName: Expr[String], - namespaceUri: Expr[Option[String]], - preferredNamespacePrefix: Expr[Option[String]], - ): Expr[Unit] = { - import quotes.reflect.* - - '{ - val instance = summonInline[ElementEncoder[T]] - if ($config.useElementNameAsDiscriminator) { - instance.encodeAsElement(${ childValue }, $sw, ${ child.xmlName }, $namespaceUri, $preferredNamespacePrefix) - } else { - $sw.memorizeDiscriminator($config.discriminatorNamespace, $config.discriminatorLocalName, ${ child.xmlName }) - instance.encodeAsElement(${ childValue }, $sw, $localName, $namespaceUri, $preferredNamespacePrefix) - } - } - } - - private def deriveSum[T: Type](config: Expr[ElementCodecConfig])(using Quotes): Expr[ElementEncoder[T]] = { - import quotes.reflect.* - - '{ - new ElementEncoder[T] { - def encodeAsElement( - a: T, - sw: PhobosStreamWriter, - localName: String, - namespaceUri: Option[String], - preferredNamespacePrefix: Option[String], - ): Unit = { - ${ - val alternatives = - extractSumTypeChildren[T](config).map { child => - child.typeRepr.asType match { - case '[t] => - val childValueSymbol = Symbol.newBind(Symbol.spliceOwner, "child", Flags.EmptyFlags, TypeRepr.of[t]) - val encode = encodeChild( - config, - child, - Ref(childValueSymbol).asExprOf[t], - 'sw, - 'localName, - 'namespaceUri, - 'preferredNamespacePrefix, - ) - CaseDef(Bind(childValueSymbol, Typed(Ref(childValueSymbol), TypeTree.of[t])), None, encode.asTerm) - } - } - Match( - '{ a }.asTerm, - // Scala 3.0 reports false positive match may not be exhaustive warning - if (util.Properties.versionNumberString.startsWith("3.0")) alternatives :+ { - val symbol = Symbol.newBind(Symbol.spliceOwner, "_", Flags.EmptyFlags, TypeRepr.of[T]) - CaseDef(Bind(symbol, Typed(Ref(symbol), TypeTree.of[T])), None, '{ () }.asTerm) - } - else alternatives, - ).asExprOf[Unit] + inline def deriveSum[T]( + inline config: ElementCodecConfig, + inline childInfos: List[SumTypeChild[ElementEncoder, T]], + ): ElementEncoder[T] = { + new ElementEncoder[T] { + def encodeAsElement( + t: T, + sw: PhobosStreamWriter, + localName: String, + namespaceUri: Option[String], + preferredNamespacePrefix: Option[String], + ): Unit = { + val childInfo = childInfos + .byInstance(t) + .getOrElse(throw EncodingError(s"Looks like an error in derivation: no TypeTest was positive for $t")) + val discr = + if (config.useElementNameAsDiscriminator) childInfo.xmlName + else { + sw.memorizeDiscriminator( + config.discriminatorNamespace, + config.discriminatorLocalName, + childInfo.xmlName, + ) + localName } - } + + childInfo.lazyTC.instance + .asInstanceOf[ElementEncoder[T]] + .encodeAsElement(t, sw, discr, namespaceUri, preferredNamespacePrefix) } } } diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/semiauto/package.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/semiauto/package.scala index 1887dd5..8d83aa0 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/semiauto/package.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/semiauto/package.scala @@ -5,8 +5,6 @@ import ru.tinkoff.phobos.configured.ElementCodecConfig import ru.tinkoff.phobos.decoding.{ElementDecoder, XmlDecoder} import ru.tinkoff.phobos.encoding.{ElementEncoder, XmlEncoder} -import scala.deriving.Mirror - package object semiauto { inline def deriveElementEncoder[T]: ElementEncoder[T] = encoder.deriveElementEncoder[T](ElementCodecConfig.default) diff --git a/modules/core/src/main/scala-3/ru/tinkoff/phobos/encoding/DerivedElement.scala b/modules/core/src/main/scala-3/ru/tinkoff/phobos/encoding/DerivedElement.scala index 21c9738..af39885 100644 --- a/modules/core/src/main/scala-3/ru/tinkoff/phobos/encoding/DerivedElement.scala +++ b/modules/core/src/main/scala-3/ru/tinkoff/phobos/encoding/DerivedElement.scala @@ -2,8 +2,13 @@ package ru.tinkoff.phobos.encoding import ru.tinkoff.phobos.configured.ElementCodecConfig import ru.tinkoff.phobos.derivation.encoder +import scala.deriving.Mirror +import ru.tinkoff.phobos.derivation.LazySummon private[encoding] trait DerivedElement { inline def derived[T]: ElementEncoder[T] = encoder.deriveElementEncoder[T](ElementCodecConfig.default) + + inline given [T](using mirror: Mirror.Of[T]): LazySummon[ElementEncoder, T] = new: + def instance = encoder.deriveElementEncoder[T](ElementCodecConfig.default) } diff --git a/modules/core/src/test/scala-3/ru/tinkoff/phobos/DerivationTest.scala b/modules/core/src/test/scala-3/ru/tinkoff/phobos/DerivationTest.scala index 312ecd4..b7c6814 100644 --- a/modules/core/src/test/scala-3/ru/tinkoff/phobos/DerivationTest.scala +++ b/modules/core/src/test/scala-3/ru/tinkoff/phobos/DerivationTest.scala @@ -5,12 +5,24 @@ import org.scalatest.wordspec.AnyWordSpec import ru.tinkoff.phobos.decoding._ import ru.tinkoff.phobos.encoding._ import ru.tinkoff.phobos.testString._ +import ru.tinkoff.phobos.syntax.discriminator +import ru.tinkoff.phobos.syntax.text +import ru.tinkoff.phobos.syntax.attr +import ru.tinkoff.phobos.SealedClasses.Animal.animalDecoder +import ru.tinkoff.phobos.derivation.LazySummon +import scala.reflect.TypeTest +import ru.tinkoff.phobos.configured.ElementCodecConfig +import ru.tinkoff.phobos.derivation.common.extractSumTypeChild +import ru.tinkoff.phobos.derivation.semiauto.deriveXmlEncoder +import ru.tinkoff.phobos.derivation.encoder.deriveElementEncoder +import scala.deriving.Mirror +import scala.annotation.nowarn class DerivationTest extends AnyWordSpec with Matchers { + import DerivationTest.* + "ElementEncoder.derived" should { - "derive simple encoders" in { - case class Foo(a: Int, b: String, c: Double) derives ElementEncoder - case class Bar(d: String, foo: Foo, e: Char) derives ElementEncoder + "derive for products" in { given XmlEncoder[Bar] = XmlEncoder.fromElementEncoder("bar") val bar = Bar("d value", Foo(1, "b value", 3.0), 'e') @@ -29,12 +41,54 @@ class DerivationTest extends AnyWordSpec with Matchers { val encoded = XmlEncoder[Bar].encode(bar) assert(encoded == Right(string)) } + + "derive for sealed traits" in { + given XmlEncoder[Wild] = XmlEncoder.fromElementEncoder("Wild") + assert(XmlEncoder[Wild].encode(Wild.Tiger) == Right(""" + | + | + """.stripMargin.minimized)) + assert(XmlEncoder[Wild].encode(Wild.Wolf("Coyote")) == Right(""" + | + | Coyote + """.stripMargin.minimized)) + } + + "derive for enums" in { + given XmlEncoder[Domestic] = XmlEncoder.fromElementEncoder("Domestic") + assert(XmlEncoder[Domestic].encode(Domestic.Cat) == Right(""" + | + | + """.stripMargin.minimized)) + assert(XmlEncoder[Domestic].encode(Domestic.Dog("Pug")) == Right(""" + | + | Pug + """.stripMargin.minimized)) + } + + "derive for products with sealed traits and enums" in { + given XmlEncoder[Nature] = XmlEncoder.fromElementEncoder("Nature") + + val res = XmlEncoder[Nature].encode(Nature(Wild.Tiger, Domestic.Dog("Pug"))) + assert( + res == Right( + """ + | + | + | + | Pug + | + """.stripMargin.minimized, + ), + ) + } } + import scala.compiletime.* + import scala.deriving.* + "ElementDecoder.derived" should { - "derive simple decoders" in { - case class Foo(a: Int, b: String, c: Double) derives ElementDecoder - case class Bar(d: String, foo: Foo, e: Char) derives ElementDecoder + "derive for products" in { given XmlDecoder[Bar] = XmlDecoder.fromElementDecoder("bar") val bar = Bar("d value", Foo(1, "b value", 3.0), 'e') @@ -53,6 +107,68 @@ class DerivationTest extends AnyWordSpec with Matchers { val decoded = XmlDecoder[Bar].decode(string) assert(decoded == Right(bar)) } + + "derive for sealed traits" in { + given XmlDecoder[Wild] = XmlDecoder.fromElementDecoder("Wild") + + val tigerString = """ + | + """.stripMargin + + val wolfString = """ + | Coyote + """.stripMargin + + assert(XmlDecoder[Wild].decode(tigerString) == Right(Wild.Tiger)) + assert(XmlDecoder[Wild].decode(wolfString) == Right(Wild.Wolf("Coyote"))) + } + + "derive for enums" in { + given XmlDecoder[Domestic] = XmlDecoder.fromElementDecoder("Domestic") + + val catString = """ + | + """.stripMargin + + val dogString = """ + | Pug + """.stripMargin + + assert(XmlDecoder[Domestic].decode(catString) == Right(Domestic.Cat)) + assert(XmlDecoder[Domestic].decode(dogString) == Right(Domestic.Dog("Pug"))) + } + + "derive for products with sealed traits and enums" in { + given XmlDecoder[Nature] = XmlDecoder.fromElementDecoder("Nature") + + val natureString = """ + | + | + | + | Pug + | + """.stripMargin.minimized + + val nature = Nature(Wild.Tiger, Domestic.Dog("Pug")) + assert(XmlDecoder[Nature].decode(natureString) == Right(nature)) + } + } +} + +object DerivationTest { + case class Foo(a: Int, b: String, c: Double) derives ElementEncoder, ElementDecoder + case class Bar(d: String, foo: Foo, e: Char) derives ElementEncoder, ElementDecoder + + sealed trait Wild derives ElementEncoder, ElementDecoder + object Wild { + @discriminator("cat") case object Tiger extends Wild + @discriminator("dog") case class Wolf(@text breed: String) extends Wild + } + + enum Domestic derives ElementEncoder, ElementDecoder { + @discriminator("tiger") case Cat + @discriminator("wolf") case Dog(@text breed: String) } + case class Nature(wild: Wild, domestic: Domestic) derives ElementEncoder, ElementDecoder } diff --git a/modules/core/src/test/scala/ru/tinkoff/phobos/DecoderDerivationTest.scala b/modules/core/src/test/scala/ru/tinkoff/phobos/DecoderDerivationTest.scala index 7972a6d..4cb296e 100644 --- a/modules/core/src/test/scala/ru/tinkoff/phobos/DecoderDerivationTest.scala +++ b/modules/core/src/test/scala/ru/tinkoff/phobos/DecoderDerivationTest.scala @@ -83,13 +83,9 @@ class DecoderDerivationTest extends AnyWordSpec with Matchers { TextDecoder.stringDecoder.map(_ => -42.0) case class Foo(@attr bar: Int, @text baz: Double) - object Foo { - implicit val fooDecoder: ElementDecoder[Foo] = deriveElementDecoder - } + implicit val fooDecoder: ElementDecoder[Foo] = deriveElementDecoder case class Qux(str: String, foo: Foo) - object Qux { - implicit val quxDecoder: XmlDecoder[Qux] = deriveXmlDecoder("qux") - } + implicit val quxDecoder: XmlDecoder[Qux] = deriveXmlDecoder("qux") val qux = Qux("constant", Foo(24, -42.0)) val string = @@ -1700,3 +1696,4 @@ class DecoderDerivationTest extends AnyWordSpec with Matchers { } } } + diff --git a/modules/core/src/test/scala/ru/tinkoff/phobos/EncoderDerivationTest.scala b/modules/core/src/test/scala/ru/tinkoff/phobos/EncoderDerivationTest.scala index 572d1fa..cf864b6 100644 --- a/modules/core/src/test/scala/ru/tinkoff/phobos/EncoderDerivationTest.scala +++ b/modules/core/src/test/scala/ru/tinkoff/phobos/EncoderDerivationTest.scala @@ -75,13 +75,9 @@ class EncoderDerivationTest extends AnyWordSpec with Matchers { TextEncoder.stringEncoder.contramap(_ => "text") case class Foo(@attr bar: Int, @text baz: Double) - object Foo { - implicit val fooEncoder: ElementEncoder[Foo] = deriveElementEncoder - } + implicit val fooEncoder: ElementEncoder[Foo] = deriveElementEncoder case class Qux(str: String, foo: Foo) - object Qux { - implicit val quxEncoder: XmlEncoder[Qux] = deriveXmlEncoder[Qux]("qux") - } + implicit val quxEncoder: XmlEncoder[Qux] = deriveXmlEncoder[Qux]("qux") val qux = Qux("42", Foo(42, 12.2)) val xml = XmlEncoder[Qux].encode(qux)