Skip to content

Scala 3. Derivation for enums. Better derivation for sealed traits #283

Merged
merged 3 commits into from
May 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,12 @@ newlines.penalizeSingleSelectMultiArgList = false

binPack.parentConstructors = true
includeCurlyBraceInSelectChains = false
trailingCommas = always
trailingCommas = always
fileOverride {
"glob:**/modules/core/src/test/scala-3/**" {
runner.dialect = scala3
}
"glob:**/modules/core/src/main/scala-3/**" {
runner.dialect = scala3
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package ru.tinkoff.phobos.decoding

import ru.tinkoff.phobos.configured.ElementCodecConfig
import ru.tinkoff.phobos.derivation.decoder
import ru.tinkoff.phobos.derivation.LazySummon
import scala.deriving.Mirror

private[decoding] trait DerivedElement {
inline def derived[T]: ElementDecoder[T] =
decoder.deriveElementDecoder[T](ElementCodecConfig.default)

inline given [T](using mirror: Mirror.Of[T]): LazySummon[ElementDecoder, T] = new:
def instance = decoder.deriveElementDecoder[T](ElementCodecConfig.default)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package ru.tinkoff.phobos.derivation

/** Defining givens of such type in companion objects of ElementEncoder and ElementDecoder allows to summon instances of
* these typeclasses for every child of a sum type (sealed trait or enum), e.g. like this:
* {{{
* summonAll[Tuple.Map[m.MirroredElemTypes, [t] =>> LazySummon[TC, t]]]
* }}}
* while safeguards against automatical derivation for all types without explicit `derives` clause or
* `deriveElementEncoder`/`deriveElementDecoder` calls.
*/
trait LazySummon[TC[_], A]:
def instance: TC[A]
165 changes: 102 additions & 63 deletions modules/core/src/main/scala-3/ru/tinkoff/phobos/derivation/common.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import ru.tinkoff.phobos.syntax.*
import scala.quoted.*
import scala.compiletime.*
import scala.annotation.nowarn
import scala.deriving.Mirror
import scala.reflect.TypeTest

@nowarn("msg=Use errorAndAbort")
object common {
Expand All @@ -21,23 +23,32 @@ object common {
}

private[derivation] final class ProductTypeField(using val quotes: Quotes)(
val localName: String,
val xmlName: Expr[String], // Name of element or attribute
val namespaceUri: Expr[Option[String]],
val typeRepr: quotes.reflect.TypeRepr,
val category: FieldCategory,
val localName: String,
val xmlName: Expr[String], // Name of element or attribute
val namespaceUri: Expr[Option[String]],
val typeRepr: quotes.reflect.TypeRepr,
val category: FieldCategory,
)

private[derivation] final class SumTypeChild(using val quotes: Quotes)(
val xmlName: Expr[String], // Value of discriminator
val typeRepr: quotes.reflect.TypeRepr,
private[derivation] final class SumTypeChild[TC[_], Base](
val xmlName: String,
val lazyTC: LazySummon[TC, Base],
val typeTest: TypeTest[Base, ?],
)

private[derivation] def extractProductTypeFields[T: Type](config: Expr[ElementCodecConfig])(using Quotes): List[ProductTypeField] = {
extension [TC[_], Base](children: List[SumTypeChild[TC, Base]])
def byInstance[T](i: Base): Option[SumTypeChild[TC, Base]] =
children.find(_.typeTest.unapply(i).isDefined)

def byXmlName(n: String): Option[SumTypeChild[TC, Base]] = children.find(_.xmlName == n)

private[derivation] def extractProductTypeFields[T: Type](
config: Expr[ElementCodecConfig],
)(using Quotes): List[ProductTypeField] = {
import quotes.reflect.*

val classTypeRepr = TypeRepr.of[T]
val classSymbol = classTypeRepr.typeSymbol
val classSymbol = classTypeRepr.typeSymbol

// Extracting first non-type parameter list. Size of this parameter list must be equal to size of .caseFields
val constructorFields = classSymbol.primaryConstructor.paramSymss.filterNot(_.exists(_.isType)).head
Expand All @@ -47,7 +58,11 @@ object common {
val fieldXmlName = extractFieldXmlName(config, classSymbol, fieldSymbol, fieldAnnotations, fieldCategory)
val fieldNamespace = extractFeildNamespace(config, classSymbol, fieldSymbol, fieldAnnotations, fieldCategory)
ProductTypeField(using quotes)(
fieldSymbol.name, fieldXmlName, fieldNamespace, classTypeRepr.memberType(fieldSymbol), fieldCategory
fieldSymbol.name,
fieldXmlName,
fieldNamespace,
classTypeRepr.memberType(fieldSymbol),
fieldCategory,
)
}
val textCount = fields.count(_.category == FieldCategory.text)
Expand All @@ -57,69 +72,87 @@ object common {
s"""
|Product type cannot have more than one field with @text annotation.
|Product type '${classSymbol.name}' has $textCount
|""".stripMargin
|""".stripMargin,
)
if (defaultCount > 1)
report.throwError(
s"""
|Product type cannot have more than one field with @default annotation.
|Product type '${classSymbol.name}' has $defaultCount
|""".stripMargin
|""".stripMargin,
)
fields
}

private[derivation] def extractSumTypeChildren[T: Type](config: Expr[ElementCodecConfig])(using Quotes): List[SumTypeChild] = {
inline def extractSumTypeChild[TC[_], T](
inline config: ElementCodecConfig,
)(using m: Mirror.SumOf[T]): List[SumTypeChild[TC, T]] = {
type Children = m.MirroredElemTypes
val typeTests = summonAll[Tuple.Map[Children, [t] =>> TypeTest[T, t]]].toList.map(_.asInstanceOf[TypeTest[T, ?]])
val lazyTCs =
summonAll[Tuple.Map[Children, [t] =>> LazySummon[TC, t]]].toList.map(_.asInstanceOf[LazySummon[TC, T]])
val xmlNames = extractSumXmlNames[T](config)

typeTests.zip(lazyTCs).zip(xmlNames).map { case ((typeTest, lazyTC), xmlName) =>
new SumTypeChild(xmlName, lazyTC, typeTest)
}
}

private[derivation] inline def extractSumXmlNames[T](inline config: ElementCodecConfig): List[String] =
${ extractSumXmlNamesImpl[T]('config) }

private[derivation] def extractSumXmlNamesImpl[T: Type](
config: Expr[ElementCodecConfig],
)(using q: Quotes): Expr[List[String]] = {
import quotes.reflect.*
val traitTypeRepr = TypeRepr.of[T]
val traitSymbol = traitTypeRepr.typeSymbol

traitSymbol.children.map { childSymbol =>
val xmlName = extractChildXmlName(config, traitSymbol, childSymbol)
SumTypeChild(using quotes)(xmlName, TypeIdent(childSymbol).tpe)
}
val names = Varargs(traitSymbol.children.map { childInfosymbol =>
extractChildXmlName(using q)(config, traitSymbol, childInfosymbol)
})
'{ List($names: _*) }
}

private def extractFieldCategory(using Quotes)(
classSymbol: quotes.reflect.Symbol,
fieldSymbol: quotes.reflect.Symbol,
fieldAnnotations: List[Expr[Any]]
classSymbol: quotes.reflect.Symbol,
fieldSymbol: quotes.reflect.Symbol,
fieldAnnotations: List[Expr[Any]],
): FieldCategory = {
import quotes.reflect.*
fieldAnnotations
.collect {
case '{attr()} => FieldCategory.attribute
case '{text()} => FieldCategory.text
case '{default()} => FieldCategory.default
} match {
fieldAnnotations.collect {
case '{ attr() } => FieldCategory.attribute
case '{ text() } => FieldCategory.text
case '{ default() } => FieldCategory.default
} match {
case Nil => FieldCategory.element
case List(category) => category
case categories =>
val categoryAnnotations =
categories.collect {
case FieldCategory.attribute => "@attr"
case FieldCategory.text => "@text"
case FieldCategory.default => "@default"
case FieldCategory.text => "@text"
case FieldCategory.default => "@default"
}.mkString(", ")

report.throwError(
s"""
|Product type field cannot have more than one category annotation (@attr, @text or @default).
|Field '${fieldSymbol.name}' in product type '${classSymbol.name}' has ${categories.size}: $categoryAnnotations
|""".stripMargin
|""".stripMargin,
)
}
}

private def extractFieldXmlName(using Quotes)(
config: Expr[ElementCodecConfig],
classSymbol: quotes.reflect.Symbol,
fieldSymbol: quotes.reflect.Symbol,
fieldAnnotations: List[Expr[Any]],
fieldCategory: FieldCategory,
config: Expr[ElementCodecConfig],
classSymbol: quotes.reflect.Symbol,
fieldSymbol: quotes.reflect.Symbol,
fieldAnnotations: List[Expr[Any]],
fieldCategory: FieldCategory,
): Expr[String] = {
import quotes.reflect.*
(fieldAnnotations.collect {case '{renamed($a)} => a } match {
(fieldAnnotations.collect { case '{ renamed($a) } => a } match {
case Nil => None
case List(name) => Some(name)
case names =>
Expand All @@ -128,65 +161,71 @@ object common {
s"""
|Product type field cannot have more than one @renamed annotation.
|Field '${fieldSymbol.name}' in product type '${classSymbol.name}' has ${names.size}: $renamedAnnotations
|""".stripMargin
|""".stripMargin,
)
}).getOrElse(fieldCategory match {
case FieldCategory.element => '{${config}.transformElementNames(${Expr(fieldSymbol.name)})}
case FieldCategory.attribute => '{${config}.transformAttributeNames(${Expr(fieldSymbol.name)})}
case FieldCategory.element => '{ ${ config }.transformElementNames(${ Expr(fieldSymbol.name) }) }
case FieldCategory.attribute => '{ ${ config }.transformAttributeNames(${ Expr(fieldSymbol.name) }) }
case _ => Expr(fieldSymbol.name)
})
}

private def extractFeildNamespace(using Quotes)(
config: Expr[ElementCodecConfig],
classSymbol: quotes.reflect.Symbol,
fieldSymbol: quotes.reflect.Symbol,
fieldAnnotations: List[Expr[Any]],
fieldCategory: FieldCategory,
config: Expr[ElementCodecConfig],
classSymbol: quotes.reflect.Symbol,
fieldSymbol: quotes.reflect.Symbol,
fieldAnnotations: List[Expr[Any]],
fieldCategory: FieldCategory,
): Expr[Option[String]] = {
import quotes.reflect.*
fieldAnnotations.collect {
case '{xmlns($namespace: b)} => '{Some(summonInline[Namespace[b]].getNamespace)}
fieldAnnotations.collect { case '{ xmlns($namespace: b) } =>
'{ Some(summonInline[Namespace[b]].getNamespace) }
} match {
case Nil => fieldCategory match {
case FieldCategory.element => '{${config}.elementsDefaultNamespace}
case FieldCategory.attribute => '{${config}.attributesDefaultNamespace}
case _ => '{None}
}
case Nil =>
fieldCategory match {
case FieldCategory.element => '{ ${ config }.elementsDefaultNamespace }
case FieldCategory.attribute => '{ ${ config }.attributesDefaultNamespace }
case _ => '{ None }
}
case List(namespace) => namespace
case namespaces =>
val xmlnsAnnotations =
fieldAnnotations
.collect {
case '{xmlns($namespace)} => s"@xmlns(${namespace.asTerm.show})"
}
fieldAnnotations.collect { case '{ xmlns($namespace) } =>
s"@xmlns(${namespace.asTerm.show})"
}
.mkString(", ")
report.throwError(
s"""
|Product type field cannot have more than one @xmlns annotation.
|Field '${fieldSymbol.name}' in product type '${classSymbol.name}' has ${namespaces.size}: $xmlnsAnnotations
|""".stripMargin
|""".stripMargin,
)
}
}

private def extractChildXmlName(using Quotes)(
config: Expr[ElementCodecConfig],
traitSymbol: quotes.reflect.Symbol,
childSymbol: quotes.reflect.Symbol,
config: Expr[ElementCodecConfig],
traitSymbol: quotes.reflect.Symbol,
childInfosymbol: quotes.reflect.Symbol,
): Expr[String] = {
import quotes.reflect.*
childSymbol.annotations.map(_.asExpr).collect { case '{discriminator($a)} => a } match {
case Nil => '{$config.transformConstructorNames(${Expr(childSymbol.name)})}
childInfosymbol.annotations.map(_.asExpr).collect { case '{ discriminator($a) } => a } match {
case Nil => '{ $config.transformConstructorNames(${ Expr(childInfosymbol.name) }) }
case List(name) => name
case names =>
val discriminatorAnnotations = names.map(name => s"@discriminator(${name.show})").mkString(", ")
report.throwError(
s"""
|Sum type child cannot have more than one @discriminator annotation.
|Child '${childSymbol.name}' of sum type '${traitSymbol.name}' has ${names.size}: $discriminatorAnnotations
|""".stripMargin
|Child '${childInfosymbol.name}' of sum type '${traitSymbol.name}' has ${names.size}: $discriminatorAnnotations
|""".stripMargin,
)
}
}

inline def showType[T <: AnyKind]: String = ${ showTypeMacro[T] }

private def showTypeMacro[T <: AnyKind: Type](using q: Quotes): Expr[String] =
import q.reflect.*
Expr(TypeRepr.of[T].dealias.widen.show)
}
Loading