Skip to content

Commit

Permalink
Ignore transient fields of generic records (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jan 28, 2024
1 parent 6a0e914 commit fbb9f4f
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import scala.util.Try
import zio.Console._
import zio._
import zio.schema.CaseSet._
import zio.schema.annotation.transientField
import zio.schema.meta.MetaSchema
import zio.schema.{ CaseSet, DeriveSchema, DynamicValue, DynamicValueGen, Schema, SchemaGen, StandardType, TypeId }
import zio.stream.{ ZSink, ZStream }
Expand Down Expand Up @@ -124,7 +125,12 @@ object ProtobufCodecSpec extends ZIOSpecDefault {
test("records with arity greater than 22") {
for {
ed <- encodeAndDecodeNS(schemaHighArityRecord, HighArity())
} yield assert(ed)(equalTo(HighArity()))
} yield assertTrue(ed == HighArity())
},
test("records with arity greater than 22 and transient field") {
for {
ed <- encodeAndDecodeNS(schemaHighArityRecordTransient, HighArityTransient(f24 = 10))
} yield assertTrue(ed == HighArityTransient())
},
test("integer") {
check(Gen.int) { value =>
Expand Down Expand Up @@ -1037,9 +1043,38 @@ object ProtobufCodecSpec extends ZIOSpecDefault {
f23: Int = 23,
f24: Int = 24
)
case class HighArityTransient(
f1: Int = 1,
f2: Int = 2,
f3: Int = 3,
f4: Int = 4,
f5: Int = 5,
f6: Int = 6,
f7: Int = 7,
f8: Int = 8,
f9: Int = 9,
f10: Int = 10,
f11: Int = 11,
f12: Int = 12,
f13: Int = 13,
f14: Int = 14,
f15: Int = 15,
f16: Int = 16,
f17: Int = 17,
f18: Int = 18,
f19: Int = 19,
f20: Int = 20,
f21: Int = 21,
f22: Int = 22,
f23: Int = 23,
@transientField
f24: Int = 24
)

lazy val schemaHighArityRecord: Schema[HighArity] = DeriveSchema.gen[HighArity]

lazy val schemaHighArityRecordTransient: Schema[HighArityTransient] = DeriveSchema.gen[HighArityTransient]

lazy val schemaOneOf: Schema[OneOf] = DeriveSchema.gen[OneOf]

case class MyRecord(age: Int)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,43 +143,38 @@ trait MutableSchemaBasedValueProcessor[Target, Context] {
}

def fields(s: Schema.Record[_], record: Any, fs: Schema.Field[_, _]*): Unit = {
val nonTransientFields = fs.filter {
case Schema.Field(_, _, annotations, _, _, _)
if annotations.collectFirst { case a: transientField => a }.isDefined =>
false
case _ => true
}
val values = ChunkBuilder.make[Target](nonTransientFields.size)

def processNext(index: Int, remaining: List[Schema.Field[_, _]]): Unit =
remaining match {
case next :: _ =>
currentSchema = next.schema
currentValue = next.asInstanceOf[Schema.Field[Any, Any]].get(record)
pushContext(contextForRecordField(contextStack.head, index, next))
push(processField(index, remaining, _))
case Nil =>
finishWith(
processRecord(
contextStack.head,
s,
nonTransientFields.map(_.name).zip(values.result()).foldLeft(ListMap.empty[String, Target]) {
case (lm, pair) =>
lm.updated(pair._1, pair._2)
}
)
val nonTransientFields = fs.filterNot(_.annotations.exists(_.isInstanceOf[transientField]))
val values = ChunkBuilder.make[Target](nonTransientFields.size)

def processNext(index: Int, remaining: Seq[Schema.Field[_, _]]): Unit =
if (remaining.isEmpty) {
finishWith(
processRecord(
contextStack.head,
s,
nonTransientFields.map(_.name).zip(values.result()).foldLeft(ListMap.empty[String, Target]) {
case (lm, pair) =>
lm.updated(pair._1, pair._2)
}
)
)
} else {
val next = remaining.head
currentSchema = next.schema
currentValue = next.asInstanceOf[Schema.Field[Any, Any]].get(record)
pushContext(contextForRecordField(contextStack.head, index, next))
push(processField(index, remaining, _))
}

def processField(index: Int, currentStructure: List[Schema.Field[_, _]], fieldResult: Target): Unit = {
def processField(index: Int, currentStructure: Seq[Schema.Field[_, _]], fieldResult: Target): Unit = {
contextStack = contextStack.tail
values += fieldResult
val remaining = currentStructure.tail
processNext(index + 1, remaining)
}

startProcessingRecord(contextStack.head, s)
processNext(0, nonTransientFields.toList)
processNext(0, nonTransientFields)
}

def enumCases(s: Schema.Enum[_], cs: Schema.Case[_, _]*): Unit = {
Expand Down Expand Up @@ -223,33 +218,33 @@ trait MutableSchemaBasedValueProcessor[Target, Context] {
finishWith(processPrimitive(currentContext, currentValue, p.asInstanceOf[StandardType[Any]]))

case s @ Schema.GenericRecord(_, structure, _) =>
val map = currentValue.asInstanceOf[ListMap[String, _]]
val structureChunk = structure.toChunk
val values = ChunkBuilder.make[Target](structureChunk.size)

def processNext(index: Int, remaining: List[Schema.Field[ListMap[String, _], _]]): Unit =
remaining match {
case next :: _ =>
currentSchema = next.schema
currentValue = map(next.name)
pushContext(contextForRecordField(currentContext, index, next))
push(processField(index, remaining, _))
case Nil =>
finishWith(
processRecord(
currentContext,
s,
structureChunk.map(_.name).zip(values.result()).foldLeft(ListMap.empty[String, Target]) {
case (lm, pair) =>
lm.updated(pair._1, pair._2)
}
)
val map = currentValue.asInstanceOf[ListMap[String, _]]
val nonTransientFields = structure.toChunk.filterNot(_.annotations.exists(_.isInstanceOf[transientField]))
val values = ChunkBuilder.make[Target](nonTransientFields.size)

def processNext(index: Int, remaining: Seq[Schema.Field[ListMap[String, _], _]]): Unit =
if (remaining.isEmpty) {
finishWith(
processRecord(
currentContext,
s,
nonTransientFields.map(_.name).zip(values.result()).foldLeft(ListMap.empty[String, Target]) {
case (lm, pair) =>
lm.updated(pair._1, pair._2)
}
)
)
} else {
val next = remaining.head
currentSchema = next.schema
currentValue = map(next.name)
pushContext(contextForRecordField(currentContext, index, next))
push(processField(index, remaining, _))
}

def processField(
index: Int,
currentStructure: List[Schema.Field[ListMap[String, _], _]],
currentStructure: Seq[Schema.Field[ListMap[String, _], _]],
fieldResult: Target
): Unit = {
contextStack = contextStack.tail
Expand All @@ -259,7 +254,7 @@ trait MutableSchemaBasedValueProcessor[Target, Context] {
}

startProcessingRecord(currentContext, s)
processNext(0, structureChunk.toList)
processNext(0, nonTransientFields)

case s @ Schema.Enum1(_, case1, _) =>
enumCases(s, case1)
Expand Down

0 comments on commit fbb9f4f

Please sign in to comment.