diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
index a46baf5137995..45f3419edf9c6 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
@@ -193,6 +193,11 @@ private[sql] class ProtobufDeserializer(
       case (INT, ShortType) =>
         (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short])
+      case (INT, LongType) =>
+        (updater, ordinal, value) =>
+          updater.setLong(
+            ordinal,
+            Integer.toUnsignedLong(value.asInstanceOf[Int]))
       case (
             MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
             ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
@@ -201,6 +206,13 @@ private[sql] class ProtobufDeserializer(
       case (LONG, LongType) =>
         (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long])
+      case (LONG, DecimalType.LongDecimal) =>
+        (updater, ordinal, value) =>
+          updater.setDecimal(
+            ordinal,
+            Decimal.fromString(
+              UTF8String.fromString(java.lang.Long.toUnsignedString(value.asInstanceOf[Long]))))
+
       case (FLOAT, FloatType) =>
         (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float])
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
index 4684934a56583..432f948a90290 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf
 
 import scala.jdk.CollectionConverters._
 
-import com.google.protobuf.{Duration, DynamicMessage, Timestamp}
+import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat}
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
 import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
 
@@ -91,8 +91,17 @@ private[sql] class ProtobufSerializer(
       (getter, ordinal) => {
         getter.getInt(ordinal)
       }
+      case (LongType, INT) if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT32 =>
+        (getter, ordinal) => {
+          getter.getLong(ordinal).toInt
+        }
       case (LongType, LONG) =>
         (getter, ordinal) => getter.getLong(ordinal)
+      case (DecimalType(), LONG)
+          if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT64 =>
+        (getter, ordinal) => {
+          getter.getDecimal(ordinal, 20, 0).toUnscaledLong
+        }
       case (FloatType, FLOAT) =>
         (getter, ordinal) => getter.getFloat(ordinal)
       case (DoubleType, DOUBLE) =>
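Both hunks above lean on the same bit-level fact: a uint32/uint64 arrives in the signed Java type with the identical bit pattern, and widening (or, for uint64, printing via toUnsignedString) recovers the intended unsigned value. A minimal sketch of that round trip using only JDK calls — not code from this patch:

object UnsignedRoundTripSketch extends App {
  // uint32 values >= 2^31 arrive as negative Ints; widening recovers them.
  val uint32AsInt: Int = Integer.MIN_VALUE // wire value 2147483648
  val widened: Long = Integer.toUnsignedLong(uint32AsInt)
  assert(widened == 2147483648L)
  // Serializer direction: narrowing back keeps the low 32 bits intact.
  assert(widened.toInt == uint32AsInt)

  // uint64 values >= 2^63 arrive as negative Longs; only a decimal string
  // (or a 20-digit decimal) carries the unsigned value losslessly.
  val uint64AsLong: Long = java.lang.Long.MIN_VALUE // wire value 2^63
  val unsignedStr: String = java.lang.Long.toUnsignedString(uint64AsLong)
  assert(unsignedStr == "9223372036854775808")
  // Taking the low 64 bits restores the original Long, mirroring the
  // Decimal -> unscaled-long path in the serializer hunk above.
  assert(BigInt(unsignedStr).toLong == uint64AsLong)
}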
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
index dfdef1f0ec357..f08dfabb606bc 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
@@ -168,6 +168,18 @@ private[sql] class ProtobufOptions(
   // instead of string, so use caution if changing existing parsing logic.
   val enumsAsInts: Boolean = parameters.getOrElse("enums.as.ints", false.toString).toBoolean
 
+
+  // Protobuf supports the unsigned integer types uint32 and uint64. By default
+  // this library deserializes them as the signed IntegerType and LongType
+  // respectively. Very large unsigned values overflow those signed types, so
+  // values above 2^31 for uint32 and above 2^63 for uint64 come back as
+  // negative numbers.
+  //
+  // Enabling this option upcasts unsigned integers into a larger type,
+  // i.e. LongType for uint32 and Decimal(20, 0) for uint64, so their
+  // representation can hold large unsigned values without overflow.
+  val upcastUnsignedInts: Boolean =
+    parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean
 }
 
 private[sql] object ProtobufOptions {
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index aa3ac998a746b..083d1dac081df 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.protobuf.utils
 import scala.jdk.CollectionConverters._
 
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import com.google.protobuf.WireFormat
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.Logging
@@ -67,9 +68,22 @@ object SchemaConverters extends Logging {
       existingRecordNames: Map[String, Int],
       protobufOptions: ProtobufOptions): Option[StructField] = {
     import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+
     val dataType = fd.getJavaType match {
-      case INT => Some(IntegerType)
-      case LONG => Some(LongType)
+      // When the protobuf type is unsigned and upcastUnsignedInts is set, use a
+      // larger type: LongType for uint32 and Decimal(20, 0) for uint64.
+      case INT =>
+        if (fd.getLiteType == WireFormat.FieldType.UINT32 && protobufOptions.upcastUnsignedInts) {
+          Some(LongType)
+        } else {
+          Some(IntegerType)
+        }
+      case LONG =>
+        if (fd.getLiteType == WireFormat.FieldType.UINT64 && protobufOptions.upcastUnsignedInts) {
+          Some(DecimalType.LongDecimal)
+        } else {
+          Some(LongType)
+        }
       case FLOAT => Some(FloatType)
       case DOUBLE => Some(DoubleType)
       case BOOLEAN => Some(BooleanType)
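A note on the schema choice above (commentary, not part of the patch): uint64's maximum value, 2^64 - 1 = 18446744073709551615, has exactly 20 decimal digits, which is what DecimalType.LongDecimal, i.e. Decimal(20, 0), can hold, while uint32's maximum fits comfortably in a signed 64-bit long. A quick self-contained check:

object UnsignedRangeCheck extends App {
  val maxUint32 = BigInt(2).pow(32) - 1
  val maxUint64 = BigInt(2).pow(64) - 1
  // uint32 fits in LongType (signed 64-bit), so LongType suffices there.
  assert(maxUint32 <= BigInt(Long.MaxValue))
  // uint64 needs 20 decimal digits, matching Decimal(20, 0).
  assert(maxUint64.toString == "18446744073709551615")
  assert(maxUint64.toString.length == 20)
}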
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 9c3975979848d..67f6568107e64 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -1600,6 +1600,52 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
+  test("test unsigned integer types") {
+    // Test that we correctly handle unsigned integer parsing. We use
+    // Integer.MIN_VALUE and Long.MinValue since each has a 1 in its sign bit.
+    val sample = spark.range(1).select(
+      lit(
+        SimpleMessage
+          .newBuilder()
+          .setUint32Value(Integer.MIN_VALUE)
+          .setUint64Value(Long.MinValue)
+          .build()
+          .toByteArray
+      ).as("raw_proto"))
+
+    val expectedWithoutFlag = spark.range(1).select(
+      lit(Integer.MIN_VALUE).as("uint32_value"),
+      lit(Long.MinValue).as("uint64_value")
+    )
+
+    val expectedWithFlag = spark.range(1).select(
+      lit(Integer.toUnsignedLong(Integer.MIN_VALUE).longValue).as("uint32_value"),
+      lit(BigDecimal(java.lang.Long.toUnsignedString(Long.MinValue))).as("uint64_value")
+    )
+
+    checkWithFileAndClassName("SimpleMessage") { case (name, descFilePathOpt) =>
+      List(
+        Map.empty[String, String],
+        Map("upcast.unsigned.ints" -> "false")).foreach(opts => {
+        checkAnswer(
+          sample.select(
+            from_protobuf_wrapper($"raw_proto", name, descFilePathOpt, opts).as("proto"))
+            .select("proto.uint32_value", "proto.uint64_value"),
+          expectedWithoutFlag)
+      })
+
+      checkAnswer(
+        sample.select(
+          from_protobuf_wrapper(
+            $"raw_proto",
+            name,
+            descFilePathOpt,
+            Map("upcast.unsigned.ints" -> "true")).as("proto"))
+          .select("proto.uint32_value", "proto.uint64_value"),
+        expectedWithFlag)
+    }
+  }
+
   def testFromProtobufWithOptions(
       df: DataFrame,
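Finally, a hedged usage sketch of the new option from the DataFrame side. The message name "Event", the descriptor path, and the "raw" column are hypothetical stand-ins, not part of this patch; only the from_protobuf function and the "upcast.unsigned.ints" key come from the changes above:

import scala.jdk.CollectionConverters._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.protobuf.functions.from_protobuf

def decodeEvents(df: DataFrame): (DataFrame, DataFrame) = {
  // Default mapping: uint32 -> IntegerType, uint64 -> LongType, so values
  // above 2^31 / 2^63 surface as negative numbers.
  val signed = df.select(
    from_protobuf(df("raw"), "Event", "/tmp/events.desc").as("event"))

  // With the flag: uint32 -> LongType, uint64 -> Decimal(20, 0), preserving
  // the full unsigned range.
  val upcast = df.select(
    from_protobuf(df("raw"), "Event", "/tmp/events.desc",
      Map("upcast.unsigned.ints" -> "true").asJava).as("event"))

  (signed, upcast)
}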