Skip to content

Commit

Permalink
Read scala default values, if no fieldDefaultValue annotation is found
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed May 20, 2023
1 parent 62ace4d commit 8f34167
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package zio.schema

import zio.Chunk

import scala.annotation.nowarn
import scala.reflect.macros.whitebox

import zio.Chunk

object DeriveSchema {
import scala.language.experimental.macros

Expand Down Expand Up @@ -222,26 +222,57 @@ object DeriveSchema {

val typeAnnotations: List[Tree] = collectTypeAnnotations(tpe)

val defaultConstructorValues =
tpe.typeSymbol.asClass.primaryConstructor.asMethod.paramLists.head
.map(_.asTerm)
.zipWithIndex
.flatMap {
case (symbol, i) =>
if (symbol.isParamWithDefault) {
val defaultInit = tpe.companion.member(TermName(s"$$lessinit$$greater$$default$$${i + 1}"))
val defaultApply = tpe.companion.member(TermName(s"apply$$default$$${i + 1}"))
Some(i -> defaultInit)
.filter(_ => defaultInit != NoSymbol)
.orElse(Some(i -> defaultApply))
.filter(_ => defaultApply != NoSymbol)
} else None
}
.toMap

@nowarn
val fieldAnnotations: List[List[Tree]] = //List.fill(arity)(Nil)
tpe.typeSymbol.asClass.primaryConstructor.asMethod.paramLists.headOption.map { symbols =>
symbols
.map(_.annotations.collect {
case annotation if !(annotation.tree.tpe <:< JavaAnnotationTpe) =>
annotation.tree match {
case q"new $annConstructor(..$annotationArgs)" =>
q"new ${annConstructor.tpe.typeSymbol}(..$annotationArgs)"
case q"new $annConstructor()" =>
q"new ${annConstructor.tpe.typeSymbol}()"
case tree =>
c.warning(c.enclosingPosition, s"Unhandled annotation tree $tree")
EmptyTree
symbols.zipWithIndex.map {
case (symbol, i) =>
val annotations = symbol.annotations.collect {
case annotation if !(annotation.tree.tpe <:< JavaAnnotationTpe) =>
annotation.tree match {
case q"new $annConstructor(..$annotationArgs)" =>
q"new ${annConstructor.tpe.typeSymbol}(..$annotationArgs)"
case q"new $annConstructor()" =>
q"new ${annConstructor.tpe.typeSymbol}()"
case tree =>
c.warning(c.enclosingPosition, s"Unhandled annotation tree $tree")
EmptyTree
}
case annotation =>
c.warning(c.enclosingPosition, s"Unhandled annotation ${annotation.tree}")
EmptyTree
}
val hasDefaultAnnotation =
annotations.exists {
case q"new _root_.zio.schema.annotation.fieldDefaultValue(..$args)" => true
case _ => false
}
case annotation =>
c.warning(c.enclosingPosition, s"Unhandled annotation ${annotation.tree}")
EmptyTree
})
.filter(_ != EmptyTree)
if (hasDefaultAnnotation || defaultConstructorValues.get(i).isEmpty) {
annotations
} else {
annotations :+
q"new _root_.zio.schema.annotation.fieldDefaultValue[${symbol.typeSignature}](${defaultConstructorValues(i)})"

}

}.filter(_ != EmptyTree)
}.getOrElse(Nil)

@nowarn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,40 @@ private case class DeriveSchema()(using val ctx: Quotes) extends ReflectionUtils
field.name -> field.annotations.filter(filterAnnotation).map(_.asExpr)
}

private def fromConstructor(from: Symbol): scala.collection.Map[String, List[Expr[Any]]] =
private def defaultValues(from: Symbol): Predef.Map[String, Expr[Any]] =
(1 to from.primaryConstructor.paramSymss.size).toList.map(
i =>
from
.companionClass
.declaredMethod(s"$$lessinit$$greater$$default$$$i")
.headOption
.orElse(
from
.companionClass
.declaredMethod(s"$$apply$$default$$$i")
.headOption
)
.map(Select(Ref(from.companionModule), _).asExpr)
).zip(from.primaryConstructor.paramSymss.flatten.map(_.name)).collect{case (Some(expr), name) => name -> expr}.toMap

private def fromConstructor(from: Symbol): scala.collection.Map[String, List[Expr[Any]]] = {
val defaults = defaultValues(from)
from.primaryConstructor.paramSymss.flatten.map { field =>
field.name -> field.annotations
.filter(filterAnnotation)
.map(_.asExpr.asInstanceOf[Expr[Any]])
field.name -> {
val annos = field.annotations
.filter(filterAnnotation)
.map(_.asExpr.asInstanceOf[Expr[Any]])
val hasDefaultAnnotation =
field.annotations.exists(_.tpe <:< TypeRepr.of[zio.schema.annotation.fieldDefaultValue[_]])
if (hasDefaultAnnotation || defaults.get(field.name).isEmpty) {
annos
} else {
annos :+ '{zio.schema.annotation.fieldDefaultValue(${defaults(field.name)})}.asExprOf[Any]
}
}
}.toMap

}

def deriveEnum[T: Type](mirror: Mirror, stack: Stack)(using Quotes) = {
val selfRefSymbol = Symbol.newVal(Symbol.spliceOwner, s"derivedSchema${stack.size}", TypeRepr.of[Schema[T]], Flags.Lazy, Symbol.noSymbol)
val selfRef = Ref(selfRefSymbol)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package zio.schema

import scala.annotation.nowarn
import scala.reflect.ClassTag

import zio.schema.Deriver.WrappedF
import zio.schema.Schema.Field
import zio.schema.annotation.fieldDefaultValue
import zio.test.{ Spec, TestEnvironment, ZIOSpecDefault, assertTrue }
import zio.{ Chunk, Scope }

import scala.annotation.nowarn
import scala.reflect.ClassTag

object DeriveSpec extends ZIOSpecDefault with VersionSpecificDeriveSpec {
override def spec: Spec[TestEnvironment with Scope, Any] =
suite("Derive")(
Expand Down Expand Up @@ -161,6 +162,30 @@ object DeriveSpec extends ZIOSpecDefault with VersionSpecificDeriveSpec {
assertTrue(refEquals)
}
),
suite("default field values")(
test("use case class default values") {
val capturedSchema = Derive.derive[CapturedSchema, RecordWithDefaultValue](schemaCapturer)
val annotations = capturedSchema.schema
.asInstanceOf[Schema.Record[RecordWithDefaultValue]]
.fields(0)
.annotations
assertTrue(
annotations
.exists(a => a.isInstanceOf[fieldDefaultValue[_]] && a.asInstanceOf[fieldDefaultValue[Int]].value == 42)
)
},
test("prefer field annotations over case class default values") {
val capturedSchema = Derive.derive[CapturedSchema, RecordWithDefaultValue](schemaCapturer)
val annotations = capturedSchema.schema
.asInstanceOf[Schema.Record[RecordWithDefaultValue]]
.fields(1)
.annotations
assertTrue(
annotations
.exists(a => a.isInstanceOf[fieldDefaultValue[_]] && a.asInstanceOf[fieldDefaultValue[Int]].value == 52)
)
},
),
versionSpecificSuite
)

Expand Down Expand Up @@ -273,6 +298,12 @@ object DeriveSpec extends ZIOSpecDefault with VersionSpecificDeriveSpec {
implicit val schema: Schema[RecordWithBigTuple] = DeriveSchema.gen[RecordWithBigTuple]
}

case class RecordWithDefaultValue(int: Int = 42, @fieldDefaultValue(52) int2: Int = 42)

object RecordWithDefaultValue {
implicit val schema: Schema[RecordWithDefaultValue] = DeriveSchema.gen[RecordWithDefaultValue]
}

sealed trait Enum1

object Enum1 {
Expand Down

0 comments on commit 8f34167

Please sign in to comment.