@@ -193,6 +193,12 @@ private[sql] class ProtobufDeserializer(
case (INT, ShortType) =>
(updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short])

// uint32 arrives as a signed Int on the wire; widen it to Long using the unsigned
// interpretation so that large values do not become negative.
case (INT, LongType) =>
  (updater, ordinal, value) => updater.setLong(
    ordinal,
    UTF8String.fromString(
      java.lang.Integer.toUnsignedString(value.asInstanceOf[Int])).toLongExact)

case (
MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
@@ -201,6 +207,13 @@ private[sql] class ProtobufDeserializer(
case (LONG, LongType) =>
(updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long])

// uint64 arrives as a signed Long on the wire; convert it to Decimal(20, 0) using the
// unsigned interpretation so that large values do not become negative.
case (LONG, DecimalType.LongDecimal) =>
  (updater, ordinal, value) =>
    updater.setDecimal(
      ordinal,
      Decimal.fromString(
        UTF8String.fromString(java.lang.Long.toUnsignedString(value.asInstanceOf[Long]))))

case (FLOAT, FloatType) =>
(updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float])

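As a point of reference, a minimal standalone sketch (not part of the diff; variable names are illustrative) of the widening these two new cases perform. The cases above go through an unsigned string and UTF8String, which should yield the same values as the direct conversions below:

// Sketch: negative signed wire values map to their unsigned interpretation.
val uint32Wire: Int = -1                                             // wire bits for uint32 value 4294967295
val widened: Long = java.lang.Integer.toUnsignedLong(uint32Wire)     // 4294967295L, fits in LongType

val uint64Wire: Long = -1L                                           // wire bits for uint64 value 18446744073709551615
val asUnsigned: String = java.lang.Long.toUnsignedString(uint64Wire) // "18446744073709551615", needs Decimal(20, 0)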
@@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf

import scala.collection.JavaConverters._

import com.google.protobuf.{Duration, DynamicMessage, Timestamp}
import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._

@@ -91,8 +91,22 @@ private[sql] class ProtobufSerializer(
(getter, ordinal) => {
getter.getInt(ordinal)
}

// uint32 is represented as LongType in Catalyst, so narrow it back to Int here.
case (LongType, INT) if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT32 =>
  (getter, ordinal) => {
    getter.getLong(ordinal).toInt
  }

case (LongType, LONG) =>
(getter, ordinal) => getter.getLong(ordinal)

// uint64 is represented as Decimal(20, 0) in Catalyst, so convert it back to Long here.
case (DecimalType(), LONG)
    if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT64 =>
  (getter, ordinal) => {
    getter.getDecimal(ordinal, 20, 0).toUnscaledLong // TODO: make precision/scale a constant
  }
case (FloatType, FLOAT) =>
(getter, ordinal) => getter.getFloat(ordinal)
case (DoubleType, DOUBLE) =>
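A standalone sketch (not part of the diff; names are illustrative) of why narrowing back to the signed Java types preserves the wire value: the low 32 or 64 bits of the unsigned number are exactly the signed bit pattern that protobuf re-encodes as the original unsigned value.

// Sketch: Catalyst value -> signed Java value handed back to protobuf.
val uint32AsLong: Long = 4294967295L        // LongType column value for a uint32 field
val wireInt: Int = uint32AsLong.toInt       // -1, same 32 bits, re-encoded on the wire as 4294967295

val uint64Logical = BigInt("18446744073709551615") // logical uint64 value held in Decimal(20, 0)
val wireLong: Long = uint64Logical.longValue       // -1L, same 64 bits, re-encoded on the wire as 18446744073709551615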
@@ -168,6 +168,17 @@ private[sql] class ProtobufOptions(
// instead of string, so use caution if changing existing parsing logic.
val enumsAsInts: Boolean =
parameters.getOrElse("enums.as.ints", false.toString).toBoolean

// Protobuf supports the unsigned integer types uint32 and uint64. By default this library
// deserializes them as LongType and Decimal(20, 0) respectively, so that large unsigned
// values fit into the resulting Catalyst type.
//
// Previously, this library mapped uint32 and uint64 to IntegerType and LongType
// respectively. Large unsigned values would wrap around to negative numbers, since
// Integer and Long are signed. Set this flag to true to preserve that older behavior.
val unsignedAsSignedPrimitive: Boolean =
parameters.getOrElse("unsigned.as.signed.primitive", false.toString).toBoolean
}

private[sql] object ProtobufOptions {
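A hedged usage sketch of the new option (the DataFrame, column name, and descriptor path are placeholders; this uses the from_protobuf overload that takes a java.util.Map of options):

import scala.collection.JavaConverters._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

// Opt back into the legacy signed mapping: uint32 -> IntegerType, uint64 -> LongType.
val parsed = df.select(
  from_protobuf(
    col("raw_proto"),
    "SimpleMessage",
    "/path/to/descriptor.desc",
    Map("unsigned.as.signed.primitive" -> "true").asJava
  ).as("proto"))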
@@ -19,6 +19,7 @@ package org.apache.spark.sql.protobuf.utils
import scala.collection.JavaConverters._

import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.WireFormat

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
@@ -67,9 +68,26 @@ object SchemaConverters extends Logging {
existingRecordNames: Map[String, Int],
protobufOptions: ProtobufOptions): Option[StructField] = {
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._

val unsignedAsPrimitive = protobufOptions.unsignedAsSignedPrimitive

val dataType = fd.getJavaType match {
case INT => Some(IntegerType) // previous mapping, removed by this change
case LONG => Some(LongType) // previous mapping, removed by this change

// Convert uint32 to LongType so that large values do not overflow the signed
// integer type.
case INT => if (fd.getLiteType == WireFormat.FieldType.UINT32 && !unsignedAsPrimitive) {
  Some(LongType)
} else {
  Some(IntegerType)
}

// Convert uint64 to Decimal(20,0) so that large values do not overflow
// Long, which is signed.
case LONG => if (fd.getLiteType == WireFormat.FieldType.UINT64 && !unsignedAsPrimitive) {
  Some(DecimalType.LongDecimal)
} else {
  Some(LongType)
}
case FLOAT => Some(FloatType)
case DOUBLE => Some(DoubleType)
case BOOLEAN => Some(BooleanType)
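For reference, an illustrative sketch of the Catalyst schema this produces under the default setting (field names follow the SimpleMessage fields used in the test suite; nullability is elided):

import org.apache.spark.sql.types._

// Default mapping: uint32 -> LongType, uint64 -> DecimalType(20, 0).
// With unsigned.as.signed.primitive=true they map to IntegerType and LongType instead.
val unsignedFieldsSchema = StructType(Seq(
  StructField("uint32_value", LongType),
  StructField("uint64_value", DecimalType(20, 0))))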
@@ -1572,6 +1572,46 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}
}

test("test unsigned integer types") {
// The Java types backing uint32 and uint64 are signed Int and Long respectively.
// Check that we convert them to the unsigned-aware Catalyst types correctly.
val sample = spark.range(1).select(
lit(
SimpleMessage
.newBuilder()
.setUint32Value(Integer.MIN_VALUE)
.setUint64Value(Long.MinValue)
.build()
.toByteArray
).as("raw_proto"))

val expected = spark.range(1).select(
lit(Integer.toUnsignedLong(Integer.MIN_VALUE).longValue).as("uint32_value"),
// Review note from the author: MIN_VALUE is used because its sign bit is set, so
// interpreted as unsigned it becomes a large positive value (2^31 for uint32,
// 2^63 for uint64) that cannot be represented by the signed type.

lit(BigDecimal(java.lang.Long.toUnsignedString(Long.MinValue))).as("uint64_value")
)

val expectedWithLegacy = spark.range(1).select(
lit(Integer.MIN_VALUE).as("uint32_value"),
lit(Long.MinValue).as("uint64_value")
)

checkWithFileAndClassName("SimpleMessage") { case (name, descFilePathOpt) =>
checkAnswer(
sample.select(
from_protobuf_wrapper($"raw_proto", name, descFilePathOpt).as("proto"))
.select("proto.uint32_value", "proto.uint64_value"),
expected)
checkAnswer(
sample.select(
from_protobuf_wrapper(
$"raw_proto",
name,
descFilePathOpt,
Map("unsigned.as.signed.primitive" -> "true")).as("proto"))
.select("proto.uint32_value", "proto.uint64_value"),
expectedWithLegacy)
}
}
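The test above exercises the read path; a hedged sketch of how the write path would be invoked with to_protobuf, which is what the serializer changes support (the descriptor path is a placeholder, and sample refers to the DataFrame built above):

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}

// Round trip: decode with the unsigned-aware schema, then encode back to bytes.
val decoded = sample.select(
  from_protobuf(col("raw_proto"), "SimpleMessage", "/path/to/descriptor.desc").as("proto"))
val reencoded = decoded.select(
  to_protobuf(col("proto"), "SimpleMessage", "/path/to/descriptor.desc").as("raw_proto"))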

def testFromProtobufWithOptions(
df: DataFrame,