Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf

import java.util.concurrent.TimeUnit

import com.google.protobuf.{ByteString, DynamicMessage, Message, TypeRegistry}
import com.google.protobuf.{BoolValue, ByteString, BytesValue, DoubleValue, DynamicMessage, FloatValue, Int32Value, Int64Value, Message, StringValue, TypeRegistry, UInt32Value, UInt64Value}
import com.google.protobuf.Descriptors._
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
import com.google.protobuf.util.JsonFormat
Expand Down Expand Up @@ -259,12 +259,109 @@ private[sql] class ProtobufDeserializer(
updater.setLong(ordinal, micros + TimeUnit.NANOSECONDS.toMicros(nanoSeconds))

case (MESSAGE, StringType)
if protoType.getMessageType.getFullName == "google.protobuf.Any" =>
if protoType.getMessageType.getFullName == "google.protobuf.Any" =>
(updater, ordinal, value) =>
// Convert 'Any' protobuf message to JSON string.
val jsonStr = jsonPrinter.print(value.asInstanceOf[DynamicMessage])
updater.set(ordinal, UTF8String.fromString(jsonStr))

// Handle well known wrapper types. We unpack the value field when the desired
// output type is a primitive (determined by the option in [[ProtobufOptions]])
case (MESSAGE, BooleanType)
if protoType.getMessageType.getFullName == BoolValue.getDescriptor.getFullName =>
(updater, ordinal, value) =>
val dm = value.asInstanceOf[DynamicMessage]
val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0))
if (unwrapped == null) {
updater.setNullAt(ordinal)
} else {
updater.setBoolean(ordinal, unwrapped.asInstanceOf[Boolean])
}
case (MESSAGE, IntegerType)
if (protoType.getMessageType.getFullName == Int32Value.getDescriptor.getFullName
|| protoType.getMessageType.getFullName == UInt32Value.getDescriptor.getFullName) =>
(updater, ordinal, value) =>
val dm = value.asInstanceOf[DynamicMessage]
val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0))
if (unwrapped == null) {
updater.setNullAt(ordinal)
} else {
updater.setInt(ordinal, unwrapped.asInstanceOf[Int])
}
case (MESSAGE, LongType)
if (protoType.getMessageType.getFullName == UInt32Value.getDescriptor.getFullName) =>
(updater, ordinal, value) =>
val dm = value.asInstanceOf[DynamicMessage]
val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0))
if (unwrapped == null) {
updater.setNullAt(ordinal)
} else {
updater.setLong(ordinal, Integer.toUnsignedLong(unwrapped.asInstanceOf[Int]))
}
case (MESSAGE, LongType)
if (protoType.getMessageType.getFullName == Int64Value.getDescriptor.getFullName
|| protoType.getMessageType.getFullName == UInt64Value.getDescriptor.getFullName) =>
(updater, ordinal, value) =>
val dm = value.asInstanceOf[DynamicMessage]
val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0))
if (unwrapped == null) {
updater.setNullAt(ordinal)
} else {
updater.setLong(ordinal, unwrapped.asInstanceOf[Long])
}
case (MESSAGE, DecimalType.LongDecimal)
if (protoType.getMessageType.getFullName == UInt64Value.getDescriptor.getFullName) =>
(updater, ordinal, value) =>
val dm = value.asInstanceOf[DynamicMessage]
val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0))
if (unwrapped == null) {
updater.setNullAt(ordinal)
} else {
val dec = Decimal.fromString(
UTF8String.fromString(java.lang.Long.toUnsignedString(unwrapped.asInstanceOf[Long])))
updater.setDecimal(ordinal, dec)
}
case (MESSAGE, StringType)
if protoType.getMessageType.getFullName == StringValue.getDescriptor.getFullName =>
(updater, ordinal, value) =>
val dm = value.asInstanceOf[DynamicMessage]
val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0))
if (unwrapped == null) {
updater.setNullAt(ordinal)
} else {
updater.set(ordinal, UTF8String.fromString(unwrapped.asInstanceOf[String]))
}
case (MESSAGE, BinaryType)
if protoType.getMessageType.getFullName == BytesValue.getDescriptor.getFullName =>
(updater, ordinal, value) =>
val dm = value.asInstanceOf[DynamicMessage]
val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0))
if (unwrapped == null) {
updater.setNullAt(ordinal)
} else {
updater.set(ordinal, unwrapped.asInstanceOf[ByteString].toByteArray)
}
case (MESSAGE, FloatType)
if protoType.getMessageType.getFullName == FloatValue.getDescriptor.getFullName =>
(updater, ordinal, value) =>
val dm = value.asInstanceOf[DynamicMessage]
val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0))
if (unwrapped == null) {
updater.setNullAt(ordinal)
} else {
updater.setFloat(ordinal, unwrapped.asInstanceOf[Float])
}
case (MESSAGE, DoubleType)
if protoType.getMessageType.getFullName == DoubleValue.getDescriptor.getFullName =>
(updater, ordinal, value) =>
val dm = value.asInstanceOf[DynamicMessage]
val unwrapped = getFieldValue(dm, dm.getDescriptorForType.getFields.get(0))
if (unwrapped == null) {
updater.setNullAt(ordinal)
} else {
updater.setDouble(ordinal, unwrapped.asInstanceOf[Double])
}

case (MESSAGE, st: StructType) =>
val writeRecord = getRecordWriter(
protoType.getMessageType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf

import scala.jdk.CollectionConverters._

import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat}
import com.google.protobuf.{BoolValue, ByteString, BytesValue, DoubleValue, Duration, DynamicMessage, FloatValue, Int32Value, Int64Value, StringValue, Timestamp, UInt32Value, UInt64Value, WireFormat}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._

Expand Down Expand Up @@ -173,6 +173,52 @@ private[sql] class ProtobufSerializer(
java.util.Arrays.asList(result: _*)
}

// Handle serializing primitives back into well known wrapper types.
case (BooleanType, MESSAGE)
if fieldDescriptor.getMessageType.getFullName == BoolValue.getDescriptor.getFullName =>
(getter, ordinal) =>
BoolValue.of(getter.getBoolean(ordinal))

case (IntegerType, MESSAGE)
if fieldDescriptor.getMessageType.getFullName == Int32Value.getDescriptor.getFullName =>
(getter, ordinal) =>
Int32Value.of(getter.getInt(ordinal))

case (IntegerType, MESSAGE)
if fieldDescriptor.getMessageType.getFullName == UInt32Value.getDescriptor.getFullName =>
(getter, ordinal) =>
UInt32Value.of(getter.getInt(ordinal))

case (LongType, MESSAGE)
if fieldDescriptor.getMessageType.getFullName == Int64Value.getDescriptor.getFullName =>
(getter, ordinal) =>
Int64Value.of(getter.getLong(ordinal))

case (LongType, MESSAGE)
if fieldDescriptor.getMessageType.getFullName == UInt64Value.getDescriptor.getFullName =>
(getter, ordinal) =>
UInt64Value.of(getter.getLong(ordinal))

case (StringType, MESSAGE)
if fieldDescriptor.getMessageType.getFullName == StringValue.getDescriptor.getFullName =>
(getter, ordinal) =>
StringValue.of(getter.getUTF8String(ordinal).toString)

case (BinaryType, MESSAGE)
if fieldDescriptor.getMessageType.getFullName == BytesValue.getDescriptor.getFullName =>
(getter, ordinal) =>
BytesValue.of(ByteString.copyFrom(getter.getBinary(ordinal)))

case (FloatType, MESSAGE)
if fieldDescriptor.getMessageType.getFullName == FloatValue.getDescriptor.getFullName =>
(getter, ordinal) =>
FloatValue.of(getter.getFloat(ordinal))

case (DoubleType, MESSAGE)
if fieldDescriptor.getMessageType.getFullName == DoubleValue.getDescriptor.getFullName =>
(getter, ordinal) =>
DoubleValue.of(getter.getDouble(ordinal))

case (st: StructType, MESSAGE) =>
val structConverter =
newStructConverter(st, fieldDescriptor.getMessageType, catalystPath, protoPath)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,33 @@ private[sql] class ProtobufOptions(
// can contain large unsigned values without overflow.
val upcastUnsignedInts: Boolean =
parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean

// Whether to unwrap the struct representation for well known primitve wrapper types when
// deserializing. By default, the wrapper types for primitives (i.e. google.protobuf.Int32Value,
// google.protobuf.Int64Value, etc.) will get deserialized as structs. We allow the option to
// deserialize them as their respective primitives.
Comment on lines +186 to +187
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you update this to clarify that emit.default.values does not apply here? I.e. an Int32Value value field would be null if unset, even if emit.default.values is set to true.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure! just added a comment with an example

// https://protobuf.dev/reference/protobuf/google.protobuf/
//
// For example, given a message like:
// ```
// syntax = "proto3";
// message Example {
// google.protobuf.Int32Value int_val = 1;
// }
// ```
//
// The message Example(Int32Value(1)) would be deserialized by default as
// {int_val: {value: 5}}
//
// However, with this option set, it would be deserialized as
// {int_val: 5}
//
// NOTE: With `emit.default.values`, we won't fill in the default primitive value during
// this unwrapping; this behavior preserves as much information as possible.
// Concretely, the behavior with emit defaults and this option set is:
// nil => nil, Int32Value(0) => 0, Int32Value(100) => 100.
val unwrapWellKnownTypes: Boolean =
parameters.getOrElse("unwrap.primitive.wrapper.types", false.toString).toBoolean
}

private[sql] object ProtobufOptions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.protobuf.utils

import scala.jdk.CollectionConverters._

import com.google.protobuf.{BoolValue, BytesValue, DoubleValue, FloatValue, Int32Value, Int64Value, StringValue, UInt32Value, UInt64Value}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.WireFormat

Expand Down Expand Up @@ -103,8 +104,46 @@ object SchemaConverters extends Logging {
fd.getMessageType.getFields.get(1).getName.equals("nanos")) =>
Some(TimestampType)
case MESSAGE if protobufOptions.convertAnyFieldsToJson &&
fd.getMessageType.getFullName == "google.protobuf.Any" =>
fd.getMessageType.getFullName == "google.protobuf.Any" =>
Some(StringType) // Any protobuf will be parsed and converted to json string.

// Unwrap well known primitive wrapper types if the option has been set.
case MESSAGE if fd.getMessageType.getFullName == BoolValue.getDescriptor.getFullName
&& protobufOptions.unwrapWellKnownTypes =>
Some(BooleanType)
case MESSAGE if fd.getMessageType.getFullName == Int32Value.getDescriptor.getFullName
&& protobufOptions.unwrapWellKnownTypes =>
Some(IntegerType)
case MESSAGE if fd.getMessageType.getFullName == UInt32Value.getDescriptor.getFullName
&& protobufOptions.unwrapWellKnownTypes =>
if (protobufOptions.upcastUnsignedInts) {
Some(LongType)
} else {
Some(IntegerType)
}
case MESSAGE if fd.getMessageType.getFullName == Int64Value.getDescriptor.getFullName
&& protobufOptions.unwrapWellKnownTypes =>
Some(LongType)
case MESSAGE if fd.getMessageType.getFullName == UInt64Value.getDescriptor.getFullName
&& protobufOptions.unwrapWellKnownTypes =>
if (protobufOptions.upcastUnsignedInts) {
Some(DecimalType.LongDecimal)
} else {
Some(LongType)
}
case MESSAGE if fd.getMessageType.getFullName == StringValue.getDescriptor.getFullName
&& protobufOptions.unwrapWellKnownTypes =>
Some(StringType)
case MESSAGE if fd.getMessageType.getFullName == BytesValue.getDescriptor.getFullName
&& protobufOptions.unwrapWellKnownTypes =>
Some(BinaryType)
case MESSAGE if fd.getMessageType.getFullName == FloatValue.getDescriptor.getFullName
&& protobufOptions.unwrapWellKnownTypes =>
Some(FloatType)
case MESSAGE if fd.getMessageType.getFullName == DoubleValue.getDescriptor.getFullName
&& protobufOptions.unwrapWellKnownTypes =>
Some(DoubleType)

case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry =>
var keyType: Option[DataType] = None
var valueType: Option[DataType] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import "timestamp.proto";
import "duration.proto";
import "basicmessage.proto";
import "google/protobuf/any.proto";
import "google/protobuf/wrappers.proto";

option java_outer_classname = "SimpleMessageProtos";

Expand Down Expand Up @@ -324,3 +325,19 @@ message Proto3AllTypes {
}
map<string, string> map = 13;
}

message WellKnownWrapperTypes {
google.protobuf.BoolValue bool_val = 1;
google.protobuf.Int32Value int32_val = 2;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add test where int32_val is not set. What should the Spark struct contain:

  • When emit.default.values is false (default)
  • When emit.default.values is true
    Please comment on the expected behavior.

Another similar test with int32_val.value set to 0 (default value).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, great suggestion. I added that test!

basically the internal value field should operate the same as any primitive, and when unwrapping it should unwrap that value correctly. I added a test that confirms every configuration of emit.default.values and unwrap.primitive.wrapper.types

google.protobuf.UInt32Value uint32_val = 3;
google.protobuf.Int64Value int64_val = 4;
google.protobuf.UInt64Value uint64_val = 5;
google.protobuf.StringValue string_val = 6;
google.protobuf.BytesValue bytes_val = 7;
google.protobuf.FloatValue float_val = 8;
google.protobuf.DoubleValue double_val = 9;

// Sample repeated and map types
repeated google.protobuf.Int32Value int32_list = 10;
map<int32, google.protobuf.StringValue> wkt_map = 11;
}
Loading