diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala index 45f3419edf9c6..fa6567ae2aa59 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf import java.util.concurrent.TimeUnit -import com.google.protobuf.{ByteString, DynamicMessage, Message, TypeRegistry} +import com.google.protobuf.{BoolValue, ByteString, BytesValue, DoubleValue, DynamicMessage, FloatValue, Int32Value, Int64Value, Message, StringValue, TypeRegistry, UInt32Value, UInt64Value} import com.google.protobuf.Descriptors._ import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ import com.google.protobuf.util.JsonFormat @@ -259,12 +259,109 @@ private[sql] class ProtobufDeserializer( updater.setLong(ordinal, micros + TimeUnit.NANOSECONDS.toMicros(nanoSeconds)) case (MESSAGE, StringType) - if protoType.getMessageType.getFullName == "google.protobuf.Any" => + if protoType.getMessageType.getFullName == "google.protobuf.Any" => (updater, ordinal, value) => // Convert 'Any' protobuf message to JSON string. val jsonStr = jsonPrinter.print(value.asInstanceOf[DynamicMessage]) updater.set(ordinal, UTF8String.fromString(jsonStr)) + // Handle well known wrapper types. We unpack the value field when the desired + // output type is a primitive (determined by the option in [[ProtobufOptions]]) + case (MESSAGE, BooleanType) + if protoType.getMessageType.getFullName == BoolValue.getDescriptor.getFullName => + (updater, ordinal, value) => + val dm = value.asInstanceOf[DynamicMessage] + val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0)) + if (unwrapped == null) { + updater.setNullAt(ordinal) + } else { + updater.setBoolean(ordinal, unwrapped.asInstanceOf[Boolean]) + } + case (MESSAGE, IntegerType) + if (protoType.getMessageType.getFullName == Int32Value.getDescriptor.getFullName + || protoType.getMessageType.getFullName == UInt32Value.getDescriptor.getFullName) => + (updater, ordinal, value) => + val dm = value.asInstanceOf[DynamicMessage] + val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0)) + if (unwrapped == null) { + updater.setNullAt(ordinal) + } else { + updater.setInt(ordinal, unwrapped.asInstanceOf[Int]) + } + case (MESSAGE, LongType) + if (protoType.getMessageType.getFullName == UInt32Value.getDescriptor.getFullName) => + (updater, ordinal, value) => + val dm = value.asInstanceOf[DynamicMessage] + val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0)) + if (unwrapped == null) { + updater.setNullAt(ordinal) + } else { + updater.setLong(ordinal, Integer.toUnsignedLong(unwrapped.asInstanceOf[Int])) + } + case (MESSAGE, LongType) + if (protoType.getMessageType.getFullName == Int64Value.getDescriptor.getFullName + || protoType.getMessageType.getFullName == UInt64Value.getDescriptor.getFullName) => + (updater, ordinal, value) => + val dm = value.asInstanceOf[DynamicMessage] + val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0)) + if (unwrapped == null) { + updater.setNullAt(ordinal) + } else { + updater.setLong(ordinal, unwrapped.asInstanceOf[Long]) + } + case (MESSAGE, DecimalType.LongDecimal) + if (protoType.getMessageType.getFullName == UInt64Value.getDescriptor.getFullName) => + (updater, ordinal, value) => + val dm = value.asInstanceOf[DynamicMessage] + val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0)) + if (unwrapped == null) { + updater.setNullAt(ordinal) + } else { + val dec = Decimal.fromString( + UTF8String.fromString(java.lang.Long.toUnsignedString(unwrapped.asInstanceOf[Long]))) + updater.setDecimal(ordinal, dec) + } + case (MESSAGE, StringType) + if protoType.getMessageType.getFullName == StringValue.getDescriptor.getFullName => + (updater, ordinal, value) => + val dm = value.asInstanceOf[DynamicMessage] + val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0)) + if (unwrapped == null) { + updater.setNullAt(ordinal) + } else { + updater.set(ordinal, UTF8String.fromString(unwrapped.asInstanceOf[String])) + } + case (MESSAGE, BinaryType) + if protoType.getMessageType.getFullName == BytesValue.getDescriptor.getFullName => + (updater, ordinal, value) => + val dm = value.asInstanceOf[DynamicMessage] + val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0)) + if (unwrapped == null) { + updater.setNullAt(ordinal) + } else { + updater.set(ordinal, unwrapped.asInstanceOf[ByteString].toByteArray) + } + case (MESSAGE, FloatType) + if protoType.getMessageType.getFullName == FloatValue.getDescriptor.getFullName => + (updater, ordinal, value) => + val dm = value.asInstanceOf[DynamicMessage] + val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0)) + if (unwrapped == null) { + updater.setNullAt(ordinal) + } else { + updater.setFloat(ordinal, unwrapped.asInstanceOf[Float]) + } + case (MESSAGE, DoubleType) + if protoType.getMessageType.getFullName == DoubleValue.getDescriptor.getFullName => + (updater, ordinal, value) => + val dm = value.asInstanceOf[DynamicMessage] + val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0)) + if (unwrapped == null) { + updater.setNullAt(ordinal) + } else { + updater.setDouble(ordinal, unwrapped.asInstanceOf[Double]) + } + case (MESSAGE, st: StructType) => val writeRecord = getRecordWriter( protoType.getMessageType, diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala index 432f948a90290..1c64e70755d5c 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf import scala.jdk.CollectionConverters._ -import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat} +import com.google.protobuf.{BoolValue, ByteString, BytesValue, DoubleValue, Duration, DynamicMessage, FloatValue, Int32Value, Int64Value, StringValue, Timestamp, UInt32Value, UInt64Value, WireFormat} import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ @@ -173,6 +173,52 @@ private[sql] class ProtobufSerializer( java.util.Arrays.asList(result: _*) } + // Handle serializing primitives back into well known wrapper types. + case (BooleanType, MESSAGE) + if fieldDescriptor.getMessageType.getFullName == BoolValue.getDescriptor.getFullName => + (getter, ordinal) => + BoolValue.of(getter.getBoolean(ordinal)) + + case (IntegerType, MESSAGE) + if fieldDescriptor.getMessageType.getFullName == Int32Value.getDescriptor.getFullName => + (getter, ordinal) => + Int32Value.of(getter.getInt(ordinal)) + + case (IntegerType, MESSAGE) + if fieldDescriptor.getMessageType.getFullName == UInt32Value.getDescriptor.getFullName => + (getter, ordinal) => + UInt32Value.of(getter.getInt(ordinal)) + + case (LongType, MESSAGE) + if fieldDescriptor.getMessageType.getFullName == Int64Value.getDescriptor.getFullName => + (getter, ordinal) => + Int64Value.of(getter.getLong(ordinal)) + + case (LongType, MESSAGE) + if fieldDescriptor.getMessageType.getFullName == UInt64Value.getDescriptor.getFullName => + (getter, ordinal) => + UInt64Value.of(getter.getLong(ordinal)) + + case (StringType, MESSAGE) + if fieldDescriptor.getMessageType.getFullName == StringValue.getDescriptor.getFullName => + (getter, ordinal) => + StringValue.of(getter.getUTF8String(ordinal).toString) + + case (BinaryType, MESSAGE) + if fieldDescriptor.getMessageType.getFullName == BytesValue.getDescriptor.getFullName => + (getter, ordinal) => + BytesValue.of(ByteString.copyFrom(getter.getBinary(ordinal))) + + case (FloatType, MESSAGE) + if fieldDescriptor.getMessageType.getFullName == FloatValue.getDescriptor.getFullName => + (getter, ordinal) => + FloatValue.of(getter.getFloat(ordinal)) + + case (DoubleType, MESSAGE) + if fieldDescriptor.getMessageType.getFullName == DoubleValue.getDescriptor.getFullName => + (getter, ordinal) => + DoubleValue.of(getter.getDouble(ordinal)) + case (st: StructType, MESSAGE) => val structConverter = newStructConverter(st, fieldDescriptor.getMessageType, catalystPath, protoPath) diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala index f08dfabb606bc..5f8c42df365a8 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala @@ -180,6 +180,33 @@ private[sql] class ProtobufOptions( // can contain large unsigned values without overflow. val upcastUnsignedInts: Boolean = parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean + + // Whether to unwrap the struct representation for well known primitve wrapper types when + // deserializing. By default, the wrapper types for primitives (i.e. google.protobuf.Int32Value, + // google.protobuf.Int64Value, etc.) will get deserialized as structs. We allow the option to + // deserialize them as their respective primitives. + // https://protobuf.dev/reference/protobuf/google.protobuf/ + // + // For example, given a message like: + // ``` + // syntax = "proto3"; + // message Example { + // google.protobuf.Int32Value int_val = 1; + // } + // ``` + // + // The message Example(Int32Value(1)) would be deserialized by default as + // {int_val: {value: 5}} + // + // However, with this option set, it would be deserialized as + // {int_val: 5} + // + // NOTE: With `emit.default.values`, we won't fill in the default primitive value during + // this unwrapping; this behavior preserves as much information as possible. + // Concretely, the behavior with emit defaults and this option set is: + // nil => nil, Int32Value(0) => 0, Int32Value(100) => 100. + val unwrapWellKnownTypes: Boolean = + parameters.getOrElse("unwrap.primitive.wrapper.types", false.toString).toBoolean } private[sql] object ProtobufOptions { diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala index 083d1dac081df..b35aa153aaa19 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.protobuf.utils import scala.jdk.CollectionConverters._ +import com.google.protobuf.{BoolValue, BytesValue, DoubleValue, FloatValue, Int32Value, Int64Value, StringValue, UInt32Value, UInt64Value} import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} import com.google.protobuf.WireFormat @@ -103,8 +104,46 @@ object SchemaConverters extends Logging { fd.getMessageType.getFields.get(1).getName.equals("nanos")) => Some(TimestampType) case MESSAGE if protobufOptions.convertAnyFieldsToJson && - fd.getMessageType.getFullName == "google.protobuf.Any" => + fd.getMessageType.getFullName == "google.protobuf.Any" => Some(StringType) // Any protobuf will be parsed and converted to json string. + + // Unwrap well known primitive wrapper types if the option has been set. + case MESSAGE if fd.getMessageType.getFullName == BoolValue.getDescriptor.getFullName + && protobufOptions.unwrapWellKnownTypes => + Some(BooleanType) + case MESSAGE if fd.getMessageType.getFullName == Int32Value.getDescriptor.getFullName + && protobufOptions.unwrapWellKnownTypes => + Some(IntegerType) + case MESSAGE if fd.getMessageType.getFullName == UInt32Value.getDescriptor.getFullName + && protobufOptions.unwrapWellKnownTypes => + if (protobufOptions.upcastUnsignedInts) { + Some(LongType) + } else { + Some(IntegerType) + } + case MESSAGE if fd.getMessageType.getFullName == Int64Value.getDescriptor.getFullName + && protobufOptions.unwrapWellKnownTypes => + Some(LongType) + case MESSAGE if fd.getMessageType.getFullName == UInt64Value.getDescriptor.getFullName + && protobufOptions.unwrapWellKnownTypes => + if (protobufOptions.upcastUnsignedInts) { + Some(DecimalType.LongDecimal) + } else { + Some(LongType) + } + case MESSAGE if fd.getMessageType.getFullName == StringValue.getDescriptor.getFullName + && protobufOptions.unwrapWellKnownTypes => + Some(StringType) + case MESSAGE if fd.getMessageType.getFullName == BytesValue.getDescriptor.getFullName + && protobufOptions.unwrapWellKnownTypes => + Some(BinaryType) + case MESSAGE if fd.getMessageType.getFullName == FloatValue.getDescriptor.getFullName + && protobufOptions.unwrapWellKnownTypes => + Some(FloatType) + case MESSAGE if fd.getMessageType.getFullName == DoubleValue.getDescriptor.getFullName + && protobufOptions.unwrapWellKnownTypes => + Some(DoubleType) + case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry => var keyType: Option[DataType] = None var valueType: Option[DataType] = None diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto index 78b75d08fb56a..a643e91158eb7 100644 --- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto +++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto @@ -27,6 +27,7 @@ import "timestamp.proto"; import "duration.proto"; import "basicmessage.proto"; import "google/protobuf/any.proto"; +import "google/protobuf/wrappers.proto"; option java_outer_classname = "SimpleMessageProtos"; @@ -324,3 +325,19 @@ message Proto3AllTypes { } map map = 13; } + +message WellKnownWrapperTypes { + google.protobuf.BoolValue bool_val = 1; + google.protobuf.Int32Value int32_val = 2; + google.protobuf.UInt32Value uint32_val = 3; + google.protobuf.Int64Value int64_val = 4; + google.protobuf.UInt64Value uint64_val = 5; + google.protobuf.StringValue string_val = 6; + google.protobuf.BytesValue bytes_val = 7; + google.protobuf.FloatValue float_val = 8; + google.protobuf.DoubleValue double_val = 9; + + // Sample repeated and map types + repeated google.protobuf.Int32Value int32_list = 10; + map wkt_map = 11; +} diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index 67f6568107e64..5e9e737151fe4 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -21,12 +21,12 @@ import java.time.Duration import scala.jdk.CollectionConverters._ -import com.google.protobuf.{Any => AnyProto, ByteString, DynamicMessage} +import com.google.protobuf.{Any => AnyProto, BoolValue, ByteString, BytesValue, DoubleValue, DynamicMessage, FloatValue, Int32Value, Int64Value, StringValue, UInt32Value, UInt64Value} import org.json4s.StringInput import org.json4s.jackson.JsonMethods import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row} -import org.apache.spark.sql.functions.{lit, struct, typedLit} +import org.apache.spark.sql.functions.{array, lit, map, struct, typedLit} import org.apache.spark.sql.protobuf.protos.Proto2Messages.Proto2AllTypes import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos._ import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum @@ -1647,6 +1647,281 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } + test("well known types deserialization and round trip") { + val message = spark.range(1).select( + lit(WellKnownWrapperTypes + .newBuilder() + .setBoolVal(BoolValue.of(true)) + .setInt32Val(Int32Value.of(100)) + .setUint32Val(UInt32Value.of(200)) + .setInt64Val(Int64Value.of(300)) + .setUint64Val(UInt64Value.of(400)) + .setStringVal(StringValue.of("string")) + .setBytesVal(BytesValue.of(ByteString.copyFromUtf8("bytes"))) + .setFloatVal(FloatValue.of(1.23f)) + .setDoubleVal(DoubleValue.of(4.56)) + .addInt32List(Int32Value.of(1)) + .addInt32List(Int32Value.of(2)) + .putWktMap(1, StringValue.of("mapval")) + .build().toByteArray + ).as("raw_proto")) + + // By default, well known wrapper types should come out as structs. + val expectedWithoutFlag = spark.range(1).select( + struct( + struct(lit(true) as ("value")).as("bool_val"), + struct(lit(100).as("value")).as("int32_val"), + struct(lit(200).as("value")).as("uint32_val"), + struct(lit(300).as("value")).as("int64_val"), + struct(lit(400).as("value")).as("uint64_val"), + struct(lit("string").as("value")).as("string_val"), + struct(lit("bytes".getBytes).as("value")).as("bytes_val"), + struct(lit(1.23f).as("value")).as("float_val"), + struct(lit(4.56).as("value")).as("double_val"), + array(struct(lit(1).as("value")), struct(lit(2).as("value"))).as("int32_list"), + map(lit(1), struct(lit("mapval").as("value"))).as("wkt_map") + ).as("proto") + ) + + // With the flag set, ensure that well known wrapper types get deserialized as primitives. + val expectedWithFlag = spark.range(1).select( + struct( + lit(true).as("bool_val"), + lit(100).as("int32_val"), + lit(200).as("uint32_val"), + lit(300).as("int64_val"), + lit(400).as("uint64_val"), + lit("string").as("string_val"), + lit("bytes".getBytes).as("bytes_val"), + lit(1.23f).as("float_val"), + lit(4.56).as("double_val"), + typedLit(List(1, 2)).as("int32_list"), + typedLit(Map(1 -> "mapval")).as("wkt_map") + ).as("proto") + ) + + checkWithFileAndClassName("WellKnownWrapperTypes") { case (name, descFilePathOpt) => + // With the option as false, ensure that deserialization works, and the + // value can be round-tripped. + List(Map.empty[String, String], Map("unwrap.primitive.wrapper.types" -> "false")) + .foreach(opts => { + val parsed = message.select(from_protobuf_wrapper( + $"raw_proto", + name, + descFilePathOpt, + opts).as("parsed")) + checkAnswer(parsed, expectedWithoutFlag) + + // Verify that round-tripping gives us the same parsed representation. + val reserialized = parsed.select( + to_protobuf_wrapper($"parsed", name, descFilePathOpt).as("reserialized")) + val reparsed = reserialized.select( + from_protobuf_wrapper($"reserialized", name, descFilePathOpt, opts).as("reparsed")) + checkAnswer(parsed, reparsed) + }) + + // Without the option not set or set as false, ensure that the deserialization is as + // expected and that round-tripping works. + val opt = Map("unwrap.primitive.wrapper.types" -> "true") + val parsed = message.select(from_protobuf_wrapper( + $"raw_proto", + name, + descFilePathOpt, + opt).as("parsed")) + checkAnswer(parsed, expectedWithFlag) + + val reserialized = parsed.select( + to_protobuf_wrapper($"parsed", name, descFilePathOpt).as("reserialized")) + val reparsed = reserialized.select( + from_protobuf_wrapper($"reserialized", name, descFilePathOpt, opt).as("reparsed")) + checkAnswer(parsed, reparsed) + } + } + + test("test well known wrappers with emit defaults") { + // Test that the emit defaults behavior and unwrap primitives behavior work correctly together. + // We'll go through when a well known wrapper type is not set, set to zero, or set to nonzero + // and show the behavior of deserialization under every combination of the + // "unwrap.primitive.wrapper.types" and "emit.default.values" flags. + + // Setup test data for the three cases of unset, explicitly zero, and non-zero. + val unset = spark.range(1).select( + lit( + WellKnownWrapperTypes.newBuilder().build().toByteArray + ).as("raw_proto")) + + val explicitZero = spark.range(1).select( + lit( + WellKnownWrapperTypes.newBuilder().setInt32Val(Int32Value.of(0)).build().toByteArray + ).as("raw_proto")) + + val explicitNonzero = spark.range(1).select( + lit( + WellKnownWrapperTypes.newBuilder().setInt32Val(Int32Value.of(100)).build().toByteArray + ).as("raw_proto")) + + val expectedEmpty = spark.range(1).select(lit(null).as("int32_val")) + + // For all combinations of unwrap / emitDefaults, check that we get back the expected values. + checkWithFileAndClassName("WellKnownWrapperTypes") { case (name, descFilePathOpt) => + for { + unwrap <- Seq("true", "false") + defaults <- Seq("true", "false") + } { + // For unset values, we'll always get back null. + checkAnswer( + unset.select( + from_protobuf_wrapper( + $"raw_proto", + name, + descFilePathOpt, + Map("unwrap.primitive.wrapper.types" -> unwrap, "emit.default.values" -> defaults) + ).as("proto") + ).select("proto.int32_val"), + expectedEmpty + ) + + // For explicit zero, we should get back null or zero based on emit default values. + val parsedExplicitZero = + explicitZero.select( + from_protobuf_wrapper( + $"raw_proto", + name, + descFilePathOpt, + Map("unwrap.primitive.wrapper.types" -> unwrap, "emit.default.values" -> defaults) + ).as("proto") + ).select("proto.int32_val") + + if (unwrap == "false") { + if (defaults == "false") { + checkAnswer( + parsedExplicitZero, + spark.range(1).select( + struct(lit(null).as("value")) + ).as("int32_val") + ) + } else { + checkAnswer( + parsedExplicitZero, + spark.range(1).select( + struct(lit(0).as("value")) + ).as("int32_val") + ) + } + } else { + if (defaults == "false") { + checkAnswer(parsedExplicitZero, expectedEmpty) + } else { + checkAnswer(parsedExplicitZero, Seq((0)).toDF("int32_val")) + } + } + + // For nonzero, we should get back the number or wrapped version regardless + // of the value of emit defaults. + val parsedNonzero = + explicitNonzero.select( + from_protobuf_wrapper( + $"raw_proto", + name, + descFilePathOpt, + Map("unwrap.primitive.wrapper.types" -> unwrap, "emit.default.values" -> defaults) + ).as("proto") + ).select("proto.int32_val") + + if (unwrap == "true") { + checkAnswer(parsedNonzero, Seq((100)).toDF("int32_val")) + } else { + checkAnswer( + parsedNonzero, + spark.range(1).select( + struct(lit(100).as("value")).as("int32_val") + ) + ) + } + } + } + } + + test("test well known wrappers with upcast ints") { + // Test that the unwrap primitives behavior and upcast uint64 work correctly together. + // We'll check the deserialization behavior under every combination of the + // "unwrap.primitive.wrapper.types" and "emit.default.values" flags. + + // Set up an example df with negative integer/long values, which have 1 in the largest bit. + // When interpreted as unsigned instead, they'd be large numbers. + // Integer.MIN_VALUE + 4 = 1000 <28 0s> 0100 = -2147483644 = 2147483652 + // Long.MinValue + 4 = 1000 <60 0s> 0100 = -9223372036854775804 = 9223372036854775812 + val originalInt = Integer.MIN_VALUE + 4 + val originalLong = Long.MinValue + 4 + val unsignedInt = 2147483652L + val unsignedLong = BigDecimal("9223372036854775812") + + val unsigned = spark.range(1).select( + lit( + WellKnownWrapperTypes.newBuilder() + .setUint32Val(UInt32Value.of(originalInt)) + .setUint64Val(UInt64Value.of(originalLong)) + .build().toByteArray + ).as("raw_proto")) + + + // For every combination of unwrap/upcast, check that we get the correct values back. + checkWithFileAndClassName("WellKnownWrapperTypes") { case (name, descFilePathOpt) => + for { + unwrap <- Seq("true", "false") + upcast <- Seq("true", "false") + } { + val parsed = unsigned.select( + from_protobuf_wrapper( + $"raw_proto", + name, + descFilePathOpt, + Map("unwrap.primitive.wrapper.types" -> unwrap, "upcast.unsigned.ints" -> upcast) + ).as("proto") + ).select("proto.uint32_val", "proto.uint64_val") + + if (unwrap == "false") { + if (upcast == "false") { + // unwrap=false, upcast=false should give back negative numbers in struct format. + checkAnswer( + parsed, + spark.range(1).select( + struct(lit(originalInt).as("value")).as("uint32_value"), + struct(lit(originalLong).as("value")).as("uint64_value") + ).as("int32_val") + ) + } else { + // unwrap=false, upcast=true should give back large positive numbers in struct format. + checkAnswer( + parsed, + spark.range(1).select( + struct(lit(unsignedInt).as("value")).as("uint32_value"), + struct(lit(unsignedLong).as("value")).as("uint64_value") + ).as("int32_val") + ) + } + } + else { + // unwrap=true, upcast=false should give back negative primitives. + if (upcast == "false") { + checkAnswer(parsed, + spark.range(1).select( + lit(originalInt).as("uint32_value"), + lit(originalLong).as("uint64_value") + )) + } else { + // unwrap=true, upcast=true should give back large positive primitives. + checkAnswer(parsed, + spark.range(1).select( + lit(unsignedInt).as("uint32_value"), + lit(unsignedLong).as("uint64_value") + )) + } + } + } + } + } + def testFromProtobufWithOptions( df: DataFrame, expectedDf: DataFrame,