@@ -193,6 +193,11 @@ private[sql] class ProtobufDeserializer(
      case (INT, ShortType) =>
        (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short])

+     case (INT, LongType) =>
+       (updater, ordinal, value) =>
+         updater.setLong(
+           ordinal,
+           Integer.toUnsignedLong(value.asInstanceOf[Int]))
Member:
Hm, I wouldn't add this. It would be problematic when Spark has unsigned types. For the same reason, Parquet also doesn't support unsigned physical types for Spark.

Contributor Author (@justaparth, Nov 13, 2023):

> It would be problematic when Spark has unsigned types. For the same reason, Parquet also doesn't support unsigned physical types for Spark.

Hey, I'm not sure I follow; do you mind explaining what you mean by this?

My goal here is to add an option allowing unsigned 32- and 64-bit integers coming from protobuf to be represented in a type that can contain them without overflow. I mention this in the description, but I actually modeled my code on how the Parquet code is written today, which I believe does the same thing by default:

https://github.com/justaparth/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala#L243-L270

https://github.com/justaparth/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala#L345-L351
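
For context, a minimal plain-Scala sketch (not code from this PR or the linked files) of the two conversions this option relies on:

  // uint32: the wire value 4294967295 arrives in an Int as the bit pattern -1;
  // Integer.toUnsignedLong recovers the unsigned value in a Long.
  val raw32: Int = -1
  val upcast32: Long = Integer.toUnsignedLong(raw32) // 4294967295L

  // uint64: no wider primitive exists, so the unsigned decimal string is the
  // lossless form; Decimal(20, 0) can hold the full range up to 2^64 - 1.
  val raw64: Long = -1L
  val upcast64: String = java.lang.Long.toUnsignedString(raw64) // "18446744073709551615"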

Member:
Oh, it was added here in #31921. My memory was out of date.

Reviewer:
I am less certain about the overall Spark context, but for from_protobuf() this looks fine so far (since it is only enabled with an option).

      case (
          MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
          ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
@@ -201,6 +206,13 @@ private[sql] class ProtobufDeserializer(
      case (LONG, LongType) =>
        (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long])

+     case (LONG, DecimalType.LongDecimal) =>
+       (updater, ordinal, value) =>
+         updater.setDecimal(
+           ordinal,
+           Decimal.fromString(
+             UTF8String.fromString(java.lang.Long.toUnsignedString(value.asInstanceOf[Long]))))
+
      case (FLOAT, FloatType) =>
        (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float])

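As a quick sanity check of the Decimal path above (a sketch assuming Spark's Decimal and UTF8String, which the deserializer already imports), the maximum uint64 value round-trips losslessly:

  import org.apache.spark.sql.types.Decimal
  import org.apache.spark.unsafe.types.UTF8String

  // All 64 bits set: as a signed Long this is -1L, but the unsigned string
  // parses into the Decimal value 18446744073709551615 (precision 20, scale 0).
  val d: Decimal = Decimal.fromString(
    UTF8String.fromString(java.lang.Long.toUnsignedString(-1L)))
  // d.toString == "18446744073709551615"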
@@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf

import scala.jdk.CollectionConverters._

-import com.google.protobuf.{Duration, DynamicMessage, Timestamp}
+import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._

@@ -91,8 +91,17 @@ private[sql] class ProtobufSerializer(
        (getter, ordinal) => {
          getter.getInt(ordinal)
        }
+     case (LongType, INT) if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT32 =>
+       (getter, ordinal) => {
+         getter.getLong(ordinal).toInt
+       }
      case (LongType, LONG) =>
        (getter, ordinal) => getter.getLong(ordinal)
+     case (DecimalType(), LONG)
+         if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT64 =>
+       (getter, ordinal) => {
+         getter.getDecimal(ordinal, 20, 0).toUnscaledLong
+       }
      case (FloatType, FLOAT) =>
        (getter, ordinal) => getter.getFloat(ordinal)
      case (DoubleType, DOUBLE) =>
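The two new serializer cases narrow the upcast Catalyst value back to the original two's-complement bit pattern before writing. A rough sketch (assuming Spark's Decimal and values produced by the option above) of what they compute:

  import org.apache.spark.sql.types.Decimal

  // LongType backing a uint32 field: truncation to Int restores the 32-bit pattern.
  val wire32: Int = 4294967295L.toInt // -1, re-read by protobuf as uint32 4294967295

  // Decimal(20, 0) backing a uint64 field: the unscaled value truncated to a
  // Long restores the 64-bit pattern.
  val wire64: Long = Decimal("18446744073709551615").toUnscaledLong // -1L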
@@ -168,6 +168,18 @@ private[sql] class ProtobufOptions(
  // instead of string, so use caution if changing existing parsing logic.
  val enumsAsInts: Boolean =
    parameters.getOrElse("enums.as.ints", false.toString).toBoolean
+
+  // Protobuf supports the unsigned integer types uint32 and uint64. By default this
+  // library represents them as the signed IntegerType and LongType respectively. Very
+  // large unsigned values then overflow and show up as negative numbers (2^31 and
+  // above for uint32, 2^63 and above for uint64).
+  //
+  // Enabling this option upcasts unsigned integers into a larger type, i.e. LongType
+  // for uint32 and Decimal(20, 0) for uint64, so the representation can hold large
+  // unsigned values without overflow.
+  val upcastUnsignedInts: Boolean =
+    parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean
}

private[sql] object ProtobufOptions {
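A hypothetical end-to-end use of the new option (the message name and descriptor path are placeholders, not from this PR):

  import scala.jdk.CollectionConverters._
  import org.apache.spark.sql.functions.col
  import org.apache.spark.sql.protobuf.functions.from_protobuf

  val parsed = df.select(
    from_protobuf(
      col("raw_proto"),
      "SimpleMessage",            // placeholder message name
      "/path/to/descriptor.desc", // placeholder descriptor file path
      Map("upcast.unsigned.ints" -> "true").asJava
    ).as("proto"))
  // With the flag on, proto.uint32_value is LongType and
  // proto.uint64_value is DecimalType(20, 0).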
@@ -19,6 +19,7 @@ package org.apache.spark.sql.protobuf.utils
import scala.jdk.CollectionConverters._

import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import com.google.protobuf.WireFormat

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
@@ -67,9 +68,22 @@ object SchemaConverters extends Logging {
      existingRecordNames: Map[String, Int],
      protobufOptions: ProtobufOptions): Option[StructField] = {
    import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+
    val dataType = fd.getJavaType match {
-     case INT => Some(IntegerType)
-     case LONG => Some(LongType)
+     // When the protobuf type is unsigned and upcastUnsignedInts has been set,
+     // use a larger type (LongType for uint32, Decimal(20, 0) for uint64).
+     case INT =>
+       if (fd.getLiteType == WireFormat.FieldType.UINT32 && protobufOptions.upcastUnsignedInts) {
+         Some(LongType)
+       } else {
+         Some(IntegerType)
+       }
+     case LONG => if (fd.getLiteType == WireFormat.FieldType.UINT64
+         && protobufOptions.upcastUnsignedInts) {
+       Some(DecimalType.LongDecimal)
+     } else {
+       Some(LongType)
+     }
      case FLOAT => Some(FloatType)
      case DOUBLE => Some(DoubleType)
      case BOOLEAN => Some(BooleanType)
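Condensed into one function, the mapping this hunk implements looks like the sketch below (upcastAwareType is a hypothetical helper, not in the PR; only the integer cases are shown):

  import com.google.protobuf.Descriptors.FieldDescriptor.JavaType
  import com.google.protobuf.WireFormat
  import org.apache.spark.sql.types._

  def upcastAwareType(
      javaType: JavaType,
      liteType: WireFormat.FieldType,
      upcastUnsignedInts: Boolean): DataType = (javaType, liteType) match {
    case (JavaType.INT, WireFormat.FieldType.UINT32) if upcastUnsignedInts => LongType
    case (JavaType.INT, _) => IntegerType
    case (JavaType.LONG, WireFormat.FieldType.UINT64) if upcastUnsignedInts =>
      DecimalType(20, 0) // same as DecimalType.LongDecimal
    case (JavaType.LONG, _) => LongType
    case other => throw new IllegalArgumentException(s"non-integer case elided: $other")
  }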
@@ -1600,6 +1600,52 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
    }
  }

+  test("test unsigned integer types") {
+    // Test that we correctly handle unsigned integer parsing.
+    // We're using Integer/Long's `MIN_VALUE` as it has a 1 in the sign bit.
+    val sample = spark.range(1).select(
+      lit(
+        SimpleMessage
+          .newBuilder()
+          .setUint32Value(Integer.MIN_VALUE)
+          .setUint64Value(Long.MinValue)
+          .build()
+          .toByteArray
+      ).as("raw_proto"))
+
+    val expectedWithoutFlag = spark.range(1).select(
+      lit(Integer.MIN_VALUE).as("uint32_value"),
+      lit(Long.MinValue).as("uint64_value")
+    )
+
+    val expectedWithFlag = spark.range(1).select(
+      lit(Integer.toUnsignedLong(Integer.MIN_VALUE).longValue).as("uint32_value"),
+      lit(BigDecimal(java.lang.Long.toUnsignedString(Long.MinValue))).as("uint64_value")
+    )
+
+    checkWithFileAndClassName("SimpleMessage") { case (name, descFilePathOpt) =>
+      List(
+        Map.empty[String, String],
+        Map("upcast.unsigned.ints" -> "false")).foreach(opts => {
+        checkAnswer(
+          sample.select(
+            from_protobuf_wrapper($"raw_proto", name, descFilePathOpt, opts).as("proto"))
+            .select("proto.uint32_value", "proto.uint64_value"),
+          expectedWithoutFlag)
+      })
+
+      checkAnswer(
+        sample.select(
+          from_protobuf_wrapper(
+            $"raw_proto",
+            name,
+            descFilePathOpt,
+            Map("upcast.unsigned.ints" -> "true")).as("proto"))
+          .select("proto.uint32_value", "proto.uint64_value"),
+        expectedWithFlag)
+    }
+  }
+

  def testFromProtobufWithOptions(
      df: DataFrame,