diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
index 954b3de72ca56..4d1040a62c52f 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
@@ -193,6 +193,12 @@ private[sql] class ProtobufDeserializer(
       case (INT, ShortType) =>
         (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short])
 
+      case (INT, LongType) =>
+        (updater, ordinal, value) => updater.setLong(
+          ordinal,
+          UTF8String.fromString(
+            java.lang.Integer.toUnsignedString(value.asInstanceOf[Int])).toLongExact)
+
       case (
             MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
             ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
@@ -201,6 +207,13 @@
       case (LONG, LongType) =>
         (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long])
 
+      case (LONG, DecimalType.LongDecimal) =>
+        (updater, ordinal, value) =>
+          updater.setDecimal(
+            ordinal,
+            Decimal.fromString(
+              UTF8String.fromString(java.lang.Long.toUnsignedString(value.asInstanceOf[Long]))))
+
       case (FLOAT, FloatType) =>
         (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float])
 
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
index b11284d1f2897..40439a89a227a 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf
 
 import scala.collection.JavaConverters._
 
-import com.google.protobuf.{Duration, DynamicMessage, Timestamp}
+import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat}
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
 import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
 
@@ -91,8 +91,22 @@ private[sql] class ProtobufSerializer(
         (getter, ordinal) => {
           getter.getInt(ordinal)
         }
+
+      // uint32 is represented as Long so convert it back correctly.
+      case (LongType, INT) if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT32 =>
+        (getter, ordinal) => {
+          getter.getLong(ordinal).toInt
+        }
+
       case (LongType, LONG) =>
         (getter, ordinal) => getter.getLong(ordinal)
+
+      // uint64 is represented as Decimal(20, 0) so convert it back correctly here.
+      case (DecimalType(), LONG)
+          if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT64 =>
+        (getter, ordinal) => {
+          getter.getDecimal(ordinal, 20, 0).toUnscaledLong // TODO: use named constants for the precision and scale
+        }
       case (FloatType, FLOAT) =>
         (getter, ordinal) => getter.getFloat(ordinal)
       case (DoubleType, DOUBLE) =>
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
index dfdef1f0ec357..3850a8a51a412 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
@@ -168,6 +168,17 @@ private[sql] class ProtobufOptions(
   // instead of string, so use caution if changing existing parsing logic.
   val enumsAsInts: Boolean =
     parameters.getOrElse("enums.as.ints", false.toString).toBoolean
+
+  // Protobuf supports the unsigned integer types uint32 and uint64. By default this
+  // library deserializes them into LongType and Decimal(20, 0) respectively, so that
+  // large unsigned values fit into the resulting type.
+  //
+  // Previously, this library deserialized uint32 and uint64 into IntegerType and
+  // LongType respectively. Large unsigned values would then overflow into negative
+  // numbers, since Int and Long are signed. If you would like to preserve that older
+  // behavior, set this option to true.
+  val unsignedAsSignedPrimitive: Boolean =
+    parameters.getOrElse("unsigned.as.signed.primitive", false.toString).toBoolean
 }
 
 private[sql] object ProtobufOptions {
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index 33b7ef87744e0..50a0c7e2f9258 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.protobuf.utils
 import scala.collection.JavaConverters._
 
 import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import com.google.protobuf.WireFormat
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.Logging
@@ -67,9 +68,26 @@ object SchemaConverters extends Logging {
       existingRecordNames: Map[String, Int],
       protobufOptions: ProtobufOptions): Option[StructField] = {
     import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+
+    val unsignedAsPrimitive = protobufOptions.unsignedAsSignedPrimitive
+
     val dataType = fd.getJavaType match {
-      case INT => Some(IntegerType)
-      case LONG => Some(LongType)
+
+      // Convert uint32 to LongType so that large values do not overflow the signed
+      // integer range.
+      case INT => if (fd.getLiteType == WireFormat.FieldType.UINT32 && !unsignedAsPrimitive) {
+        Some(LongType)
+      } else {
+        Some(IntegerType)
+      }
+
+      // Convert uint64 to Decimal(20, 0) so that large values do not overflow
+      // Long, which is signed.
+      case LONG => if (fd.getLiteType == WireFormat.FieldType.UINT64 && !unsignedAsPrimitive) {
+        Some(DecimalType.LongDecimal)
+      } else {
+        Some(LongType)
+      }
       case FLOAT => Some(FloatType)
       case DOUBLE => Some(DoubleType)
       case BOOLEAN => Some(BooleanType)
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 7e6cf0a3c9689..bc3243f7a4af0 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -1572,6 +1572,46 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
+  test("test unsigned integer types") {
+    // The Java types backing uint32 and uint64 are the signed Int and Long respectively,
+    // so check that large unsigned values are converted correctly rather than overflowing.
+    val sample = spark.range(1).select(
+      lit(
+        SimpleMessage
+          .newBuilder()
+          .setUint32Value(Integer.MIN_VALUE)
+          .setUint64Value(Long.MinValue)
+          .build()
+          .toByteArray
+      ).as("raw_proto"))
+
+    val expected = spark.range(1).select(
+      lit(Integer.toUnsignedLong(Integer.MIN_VALUE).longValue).as("uint32_value"),
+      lit(BigDecimal(java.lang.Long.toUnsignedString(Long.MinValue))).as("uint64_value")
+    )
+
+    val expectedWithLegacy = spark.range(1).select(
+      lit(Integer.MIN_VALUE).as("uint32_value"),
+      lit(Long.MinValue).as("uint64_value")
+    )
+
+    checkWithFileAndClassName("SimpleMessage") { case (name, descFilePathOpt) =>
+      checkAnswer(
+        sample.select(
+          from_protobuf_wrapper($"raw_proto", name, descFilePathOpt).as("proto"))
+          .select("proto.uint32_value", "proto.uint64_value"),
+        expected)
+      checkAnswer(
+        sample.select(
+          from_protobuf_wrapper(
+            $"raw_proto",
+            name,
+            descFilePathOpt,
+            Map("unsigned.as.signed.primitive" -> "true")).as("proto"))
+          .select("proto.uint32_value", "proto.uint64_value"),
+        expectedWithLegacy)
+    }
+  }
 
   def testFromProtobufWithOptions(
       df: DataFrame,
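Note on the conversion itself: the widening in the patch boils down to reinterpreting the signed JVM values that protobuf-java hands back. A uint32 arrives as a (possibly negative) Int and is widened to a non-negative Long; a uint64 arrives as a (possibly negative) Long and is rendered through its unsigned decimal string into a Decimal(20, 0). The following standalone sketch is not part of the patch and uses plain BigDecimal in place of Catalyst's Decimal, but it illustrates the same arithmetic:

import java.lang.{Integer => JInt, Long => JLong}

object UnsignedWideningSketch {
  // uint32: the wire value 0xFFFFFFFF is handed back as the signed Int -1;
  // Integer.toUnsignedLong recovers 4294967295 as a Long.
  def uint32ToLong(raw: Int): Long = JInt.toUnsignedLong(raw)

  // uint64: the wire value 0xFFFFFFFFFFFFFFFF is handed back as the signed Long -1;
  // its unsigned decimal string has at most 20 digits, hence Decimal(20, 0).
  def uint64ToDecimal(raw: Long): BigDecimal = BigDecimal(JLong.toUnsignedString(raw))

  def main(args: Array[String]): Unit = {
    assert(uint32ToLong(-1) == 4294967295L)
    assert(uint64ToDecimal(-1L) == BigDecimal("18446744073709551615"))
    println("unsigned widening behaves as expected")
  }
}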
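For end users the change surfaces through from_protobuf: with the new default, uint32 fields read back as LongType and uint64 fields as Decimal(20, 0), while passing "unsigned.as.signed.primitive" -> "true" restores the previous signed mapping. The sketch below is a hypothetical usage example, not part of the patch; the descriptor path, message name, input location, and column layout are placeholder assumptions:

import scala.collection.JavaConverters._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

object UnsignedOptionUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("uint-widening").getOrCreate()

    // Hypothetical input: files of serialized SimpleMessage bytes; the binaryFile source
    // exposes the raw bytes in its "content" column.
    val input = spark.read.format("binaryFile").load("/tmp/simple_messages")
      .select(col("content").as("raw_proto"))

    // New default: uint32_value is read as LongType, uint64_value as Decimal(20, 0).
    val widened = input.select(
      from_protobuf(col("raw_proto"), "SimpleMessage", "/tmp/simple_message.desc").as("proto"))

    // Legacy mapping: uint32 stays IntegerType and uint64 stays LongType, so large
    // unsigned values show up as negative numbers.
    val legacy = input.select(
      from_protobuf(
        col("raw_proto"),
        "SimpleMessage",
        "/tmp/simple_message.desc",
        Map("unsigned.as.signed.primitive" -> "true").asJava).as("proto"))

    widened.printSchema()
    legacy.printSchema()
    spark.stop()
  }
}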