Skip to content

Commit

Permalink
added optionalField annotation (#410)
Browse files Browse the repository at this point in the history
* added optionalField annotation with implementations for Json and Thrift codecs

* formatted imports

* more formatting

* formatted
  • Loading branch information
devsprint authored Nov 5, 2022
1 parent 2513315 commit ebed2bd
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package zio.schema
import scala.annotation.Annotation

import zio.Chunk
import zio.schema.annotation.fieldName
import zio.schema.annotation.{ fieldName, optionalField }
import zio.test._

object DeriveSchemaSpec extends ZIOSpecDefault {
Expand Down Expand Up @@ -243,6 +243,12 @@ object DeriveSchemaSpec extends ZIOSpecDefault {
implicit val schema: Schema[ContainsSchema] = DeriveSchema.gen[ContainsSchema]
}

case class OptionalField(@optionalField name: String, age: Int)

object OptionalField {
implicit val schema: Schema[OptionalField] = DeriveSchema.gen[OptionalField]
}

override def spec: Spec[Environment, Any] = suite("DeriveSchemaSpec")(
suite("Derivation")(
test("correctly derives case class 0") {
Expand Down Expand Up @@ -416,6 +422,30 @@ object DeriveSchemaSpec extends ZIOSpecDefault {
)
}
assert(derived)(hasSameSchema(expected))
},
test("correctly derives optional fields when optional annotation is present") {
val derived: Schema[OptionalField] = Schema[OptionalField]
val expected: Schema[OptionalField] = {
Schema.CaseClass2(
TypeId.parse("zio.schema.DeriveSchemaSpec.OptionalField"),
field1 = Schema.Field(
"name",
Schema.Primitive(StandardType.StringType),
Chunk(optionalField()),
get = _.name,
set = (a, b: String) => a.copy(name = b)
),
field2 = Schema.Field(
"age",
Schema.Primitive(StandardType.IntType),
Chunk.empty,
get = _.age,
set = (a, b: Int) => a.copy(age = b)
),
OptionalField.apply
)
}
assert(derived)(hasSameSchema(expected))
}
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import zio.json.{
JsonFieldEncoder
}
import zio.schema._
import zio.schema.annotation.optionalField
import zio.schema.codec.BinaryCodec._
import zio.stream.ZPipeline
import zio.{ Chunk, ChunkBuilder, NonEmptyChunk, ZIO }
Expand Down Expand Up @@ -847,8 +848,14 @@ object JsonCodec extends BinaryCodec {

var i = 0
while (i < len) {
if (buffer(i) == null)
buffer(i) = schemaDecoder(schemas(i)).unsafeDecodeMissing(spans(i) :: trace)
if (buffer(i) == null) {
val optionalAnnotation = fields(i).annotations.collectFirst { case a: optionalField => a }
if (optionalAnnotation.isDefined)
buffer(i) = schemas(i).defaultValue.toOption.get
else
buffer(i) = schemaDecoder(schemas(i)).unsafeDecodeMissing(spans(i) :: trace)

}
i += 1
}
buffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import zio.json.JsonDecoder.JsonError
import zio.json.{ DeriveJsonEncoder, JsonEncoder }
import zio.schema.CaseSet._
import zio.schema._
import zio.schema.annotation.optionalField
import zio.schema.codec.JsonCodec.JsonEncoder.charSequenceToByteChunk
import zio.stream.ZStream
import zio.test.Assertion._
Expand Down Expand Up @@ -179,6 +180,13 @@ object JsonCodecSpec extends ZIOSpecDefault {
suite("case class")(
test("case object") {
assertDecodes(schemaObject, Singleton, charSequenceToByteChunk("{}"))
},
test("optional") {
assertDecodes(
optionalSearchRequestSchema,
OptionalSearchRequest("test", 0, 10, Schema[String].defaultValue.getOrElse("")),
charSequenceToByteChunk("""{"query":"test","pageNumber":0,"resultPerPage":10}""")
)
}
)
)
Expand Down Expand Up @@ -450,6 +458,11 @@ object JsonCodecSpec extends ZIOSpecDefault {
assertEncodesThenDecodes(searchRequestSchema, value)
}
},
test("optional") {
check(optionalSearchRequestGen) { value =>
assertEncodesThenDecodes(optionalSearchRequestSchema, value)
}
},
test("object") {
assertEncodesThenDecodes(schemaObject, Singleton)
}
Expand Down Expand Up @@ -759,6 +772,17 @@ object JsonCodecSpec extends ZIOSpecDefault {
implicit val encoder: JsonEncoder[SearchRequest] = DeriveJsonEncoder.gen[SearchRequest]
}

case class OptionalSearchRequest(
query: String,
pageNumber: Int,
resultPerPage: Int,
@optionalField nextPage: String
)

object OptionalSearchRequest {
implicit val encoder: JsonEncoder[OptionalSearchRequest] = DeriveJsonEncoder.gen[OptionalSearchRequest]
}

private val searchRequestGen: Gen[Sized, SearchRequest] =
for {
query <- Gen.string
Expand All @@ -767,8 +791,18 @@ object JsonCodecSpec extends ZIOSpecDefault {
nextPage <- Gen.option(Gen.asciiString)
} yield SearchRequest(query, pageNumber, results, nextPage)

private val optionalSearchRequestGen: Gen[Sized, OptionalSearchRequest] =
for {
query <- Gen.string
pageNumber <- Gen.int(Int.MinValue, Int.MaxValue)
results <- Gen.int(Int.MinValue, Int.MaxValue)
nextPage <- Gen.asciiString
} yield OptionalSearchRequest(query, pageNumber, results, nextPage)

val searchRequestSchema: Schema[SearchRequest] = DeriveSchema.gen[SearchRequest]

val optionalSearchRequestSchema: Schema[OptionalSearchRequest] = DeriveSchema.gen[OptionalSearchRequest]

val recordSchema: Schema[ListMap[String, _]] = Schema.record(
TypeId.Structural,
Schema.Field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import scala.util.{ Failure, Success, Try }
import org.apache.thrift.protocol._

import zio.schema._
import zio.schema.annotation.optionalField
import zio.schema.codec.BinaryCodec._
import zio.schema.codec.ThriftCodec.Thrift.{
bigDecimalStructure,
Expand Down Expand Up @@ -863,8 +864,8 @@ object ThriftCodec extends BinaryCodec {
def addFields(values: ListMap[Short, Any], index: Int): Result[Array[Any]] =
if (index >= fields.size) Right(buffer)
else {
val Schema.Field(label, schema, _, _, _, _) = fields(index)
val rawValue = values.get((index + 1).toShort)
val Schema.Field(label, schema, annotations, _, _, _) = fields(index)
val rawValue = values.get((index + 1).toShort)
rawValue match {
case Some(value) =>
buffer.update(index, value)
Expand All @@ -874,7 +875,12 @@ object ThriftCodec extends BinaryCodec {
case Some(value) =>
buffer.update(index, value)
addFields(values, index + 1)
case None => fail(path :+ label, "Missing value")
case None =>
val optionalFieldAnnotation = annotations.collectFirst({ case a: optionalField => a })
if (optionalFieldAnnotation.isDefined) {
buffer.update(index, schema.defaultValue.toOption.get)
addFields(values, index + 1)
} else fail(path :+ label, "Missing value")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.thrift.TSerializable
import org.apache.thrift.protocol.{ TBinaryProtocol, TField, TType }

import zio.schema.CaseSet.caseOf
import zio.schema.annotation.optionalField
import zio.schema.codec.{ generated => g }
import zio.schema.{ CaseSet, DeriveSchema, DynamicValue, DynamicValueGen, Schema, SchemaGen, StandardType, TypeId }
import zio.stream.{ ZSink, ZStream }
Expand Down Expand Up @@ -771,6 +772,16 @@ object ThriftCodecSpec extends ZIOSpecDefault {
encoded <- write(new g.EnumValue(g.Color.BLUE))
ed <- decodeNS(schemaEnumValue, encoded)
} yield assert(ed)(equalTo(EnumValue(2)))
},
test("decode case class with optionalField") {
for {
bytes <- writeManually { p =>
p.writeFieldBegin(new TField("name", TType.STRING, 1))
p.writeString("Dan")
p.writeFieldStop()
}
d <- decodeNS(PersonWithOptionalField.schema, bytes)
} yield assert(d)(equalTo(PersonWithOptionalField("Dan", 0)))
}
),
suite("Should fail to decode")(
Expand Down Expand Up @@ -1018,6 +1029,12 @@ object ThriftCodecSpec extends ZIOSpecDefault {

lazy val sequenceOfSumSchema: Schema[SequenceOfSum] = DeriveSchema.gen[SequenceOfSum]

case class PersonWithOptionalField(name: String, @optionalField age: Int)

object PersonWithOptionalField {
implicit val schema: Schema[PersonWithOptionalField] = DeriveSchema.gen[PersonWithOptionalField]
}

def toHex(chunk: Chunk[Byte]): String =
chunk.toArray.map("%02X".format(_)).mkString

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package zio.schema.annotation

final case class optionalField() extends scala.annotation.StaticAnnotation

0 comments on commit ebed2bd

Please sign in to comment.