diff --git a/src/main/scala/scalapb/json4s/JsonFormat.scala b/src/main/scala/scalapb/json4s/JsonFormat.scala index 8723049..dd4ad60 100644 --- a/src/main/scala/scalapb/json4s/JsonFormat.scala +++ b/src/main/scala/scalapb/json4s/JsonFormat.scala @@ -461,6 +461,7 @@ class Printer private (config: Printer.PrinterConfig) { object Parser { private final case class ParserConfig( isIgnoringUnknownFields: Boolean, + isIgnoringOverlappingOneofFields: Boolean, mapEntriesAsKeyValuePairs: Boolean, formatRegistry: FormatRegistry, typeRegistry: TypeRegistry @@ -472,6 +473,7 @@ class Parser private (config: Parser.ParserConfig) { this( Parser.ParserConfig( isIgnoringUnknownFields = false, + isIgnoringOverlappingOneofFields = false, mapEntriesAsKeyValuePairs = false, JsonFormat.DefaultRegistry, TypeRegistry.empty @@ -490,6 +492,7 @@ class Parser private (config: Parser.ParserConfig) { this( Parser.ParserConfig( isIgnoringUnknownFields = false, + isIgnoringOverlappingOneofFields = false, mapEntriesAsKeyValuePairs = false, formatRegistry, typeRegistry @@ -499,6 +502,9 @@ class Parser private (config: Parser.ParserConfig) { def ignoringUnknownFields: Parser = new Parser(config.copy(isIgnoringUnknownFields = true)) + def ignoringOverlappingOneofFields: Parser = + new Parser(config.copy(isIgnoringOverlappingOneofFields = true)) + def mapEntriesAsKeyValuePairs: Parser = new Parser(config.copy(mapEntriesAsKeyValuePairs = true)) @@ -593,12 +599,23 @@ class Parser private (config: Parser.ParserConfig) { case None => value match { case JObject(fields) => + val usedOneofs = mutable.Set[OneofDescriptor]() val fieldMap = JsonFormat.MemorizedFieldNameMap(cmp.scalaDescriptor) val valueMapBuilder = Map.newBuilder[FieldDescriptor, PValue] fields.foreach { case (name: String, jValue: JValue) => if (fieldMap.contains(name)) { if (jValue != JNull) { val fd = fieldMap(name) + fd.containingOneof.foreach(o => + if ( + !config.isIgnoringOverlappingOneofFields && !usedOneofs + .add(o) + ) { + throw new JsonFormatException( + s"Overlapping field '${name}' in oneof" + ) + } + ) valueMapBuilder += (fd -> parseValue(fd, jValue)) } } else if ( diff --git a/src/test/scala/scalapb/json4s/JsonFormatSpec.scala b/src/test/scala/scalapb/json4s/JsonFormatSpec.scala index 529578a..41a3fd4 100644 --- a/src/test/scala/scalapb/json4s/JsonFormatSpec.scala +++ b/src/test/scala/scalapb/json4s/JsonFormatSpec.scala @@ -7,6 +7,7 @@ import com.google.protobuf.any.{Any => PBAny} import com.google.protobuf.util.JsonFormat.{TypeRegistry => JavaTypeRegistry} import com.google.protobuf.util.{JsonFormat => JavaJsonFormat} import jsontest.custom_collection.{Guitar, Studio} +import jsontest.oneof.OneOf import jsontest.test._ import jsontest.test3._ import org.json4s.JsonDSL._ @@ -674,6 +675,20 @@ class JsonFormatSpec JsonFormat.parser.fromJsonString[TestAllTypes](javaJson) must be(obj) } + "oneofs" should "fail for overlapping keys if failOnOverlappingOneofKeys" in new DefaultParserContext { + val extraKey = """{"primitive": "", "wrapper": ""}""" + assertFails(extraKey, OneOf) + } + + "oneofs" should "not fail for overlapping keys if ignoreOverlappingOneofKeys" in { + val extraKey = """{"primitive": "", "wrapper": ""}""" + val scalaParser = new Parser().ignoringOverlappingOneofFields + val parsedScala = scalaParser.fromJsonString[OneOf](extraKey)(OneOf) + parsedScala must be( + OneOf(field = OneOf.Field.Primitive("")) + ) + } + "TestProto" should "be TestJsonWithMapEntriesAsKeyValuePairs when converted to Proto with mapEntriesAsKeyValuePairs setting" in { JsonFormat.printer.mapEntriesAsKeyValuePairs.toJson(TestProto) must be( parse(TestJsonWithMapEntriesAsKeyValuePairs)