diff --git a/connector/proto/pom.xml b/connector/proto/pom.xml
new file mode 100644
index 0000000000000..12cdae3b12202
--- /dev/null
+++ b/connector/proto/pom.xml
@@ -0,0 +1,119 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.12</artifactId>
+    <version>3.4.0-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-proto_2.12</artifactId>
+  <properties>
+    <sbt.project.name>proto</sbt.project.name>
+    <protobuf.version>3.21.1</protobuf.version>
+  </properties>
+  <packaging>jar</packaging>
+  <name>Spark Proto</name>
+  <url>https://spark.apache.org/</url>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.tukaani</groupId>
+      <artifactId>xz</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>${protobuf.version}</version>
+      <scope>compile</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-shade-plugin</artifactId>
+        <configuration>
+          <shadeTestJar>false</shadeTestJar>
+          <artifactSet>
+            <includes>
+              <include>com.google.protobuf:*</include>
+            </includes>
+          </artifactSet>
+          <relocations>
+            <relocation>
+              <pattern>com.google.protobuf</pattern>
+              <shadedPattern>${spark.shade.packageName}.spark-proto.protobuf</shadedPattern>
+              <includes>
+                <include>com.google.protobuf.**</include>
+              </includes>
+            </relocation>
+          </relocations>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/CatalystDataToProto.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/CatalystDataToProto.scala
new file mode 100644
index 0000000000000..1846d360db13d
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/CatalystDataToProto.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto
+
+import com.google.protobuf.DynamicMessage
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.proto.utils.{ProtoUtils, SchemaConverters}
+import org.apache.spark.sql.types.{BinaryType, DataType}
+
+private[proto] case class CatalystDataToProto(
+ child: Expression,
+ descFilePath: Option[String],
+ messageName: Option[String])
+ extends UnaryExpression {
+
+ override def dataType: DataType = BinaryType
+
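+ // Use the descriptor from the given descriptor file when both a file path and a message name
+ // are supplied; otherwise derive a Proto descriptor from the child expression's Catalyst type.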
+ @transient private lazy val protoType = (descFilePath, messageName) match {
+ case (Some(a), Some(b)) => ProtoUtils.buildDescriptor(a, b)
+ case _ => SchemaConverters.toProtoType(child.dataType, child.nullable)
+ }
+
+ @transient private lazy val serializer = new ProtoSerializer(child.dataType, protoType,
+ child.nullable)
+
+ override def nullSafeEval(input: Any): Any = {
+ val dynamicMessage = serializer.serialize(input).asInstanceOf[DynamicMessage]
+ dynamicMessage.toByteArray
+ }
+
+ override def prettyName: String = "to_proto"
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val expr = ctx.addReferenceObj("this", this)
+ defineCodeGen(ctx, ev, input =>
+ s"(byte[]) $expr.nullSafeEval($input)")
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): CatalystDataToProto =
+ copy(child = newChild)
+}
+
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDataToCatalyst.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDataToCatalyst.scala
new file mode 100644
index 0000000000000..42a3f4be890d7
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDataToCatalyst.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto
+
+import java.io.ByteArrayInputStream
+
+import scala.util.control.NonFatal
+
+import com.google.protobuf.DynamicMessage
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression,
+ SpecificInternalRow, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode}
+import org.apache.spark.sql.proto.utils.{ProtoOptions, ProtoUtils, SchemaConverters}
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, StructType}
+
+private[proto] case class ProtoDataToCatalyst(
+ child: Expression,
+ descFilePath: String,
+ messageName: String,
+ options: Map[String, String])
+ extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+
+ override lazy val dataType: DataType = {
+ val dt = SchemaConverters.toSqlType(expectedSchema).dataType
+ parseMode match {
+ // With PermissiveMode, the output Catalyst row might contain columns of null values for
+ // corrupt records, even if some of the columns are not nullable in the user-provided schema.
+ // Therefore we force the schema to be all nullable here.
+ case PermissiveMode => dt.asNullable
+ case _ => dt
+ }
+ }
+
+ override def nullable: Boolean = true
+
+ private lazy val protoOptions = ProtoOptions(options)
+
+ @transient private lazy val descriptor = ProtoUtils.buildDescriptor(descFilePath, messageName)
+
+ @transient private lazy val expectedSchema = protoOptions.schema.getOrElse(descriptor)
+
+ @transient private lazy val deserializer = new ProtoDeserializer(expectedSchema, dataType,
+ protoOptions.datetimeRebaseModeInRead)
+
+ @transient private var result: Any = _
+
+ @transient private lazy val parseMode: ParseMode = {
+ val mode = protoOptions.parseMode
+ if (mode != PermissiveMode && mode != FailFastMode) {
+ throw new AnalysisException(unacceptableModeMessage(mode.name))
+ }
+ mode
+ }
+
+ private def unacceptableModeMessage(name: String): String = {
+ s"from_proto() doesn't support the $name mode. " +
+ s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}."
+ }
+
+ @transient private lazy val nullResultRow: Any = dataType match {
+ case st: StructType =>
+ val resultRow = new SpecificInternalRow(st.map(_.dataType))
+ for (i <- 0 until st.length) {
+ resultRow.setNullAt(i)
+ }
+ resultRow
+
+ case _ =>
+ null
+ }
+
+ private def handleException(e: Throwable): Any = {
+ parseMode match {
+ case PermissiveMode =>
+ nullResultRow
+ case FailFastMode =>
+ throw new SparkException("Malformed records are detected in record parsing. " +
+ s"Current parse Mode: ${FailFastMode.name}. To process malformed records as null " +
+ "result, try setting the option 'mode' as 'PERMISSIVE'.", e)
+ case _ =>
+ throw new AnalysisException(unacceptableModeMessage(parseMode.name))
+ }
+ }
+
+ override def nullSafeEval(input: Any): Any = {
+ val binary = input.asInstanceOf[Array[Byte]]
+ try {
+ result = DynamicMessage.parseFrom(descriptor, new ByteArrayInputStream(binary))
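+ // A message carrying unknown fields is treated as a malformed record: the configured parse
+ // mode decides whether it becomes a null row (PERMISSIVE) or an error (FAILFAST).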
+ val unknownFields = result.asInstanceOf[DynamicMessage].getUnknownFields
+ if (!unknownFields.asMap().isEmpty) {
+ return handleException(new Throwable("UnknownFields encountered"))
+ }
+ val deserialized = deserializer.deserialize(result)
+ assert(deserialized.isDefined,
+ "Proto deserializer cannot return an empty result because filters are not pushed down")
+ deserialized.get
+ } catch {
+ // There could be multiple possible exceptions here, e.g. java.io.IOException,
+ // ProtoRuntimeException, ArrayIndexOutOfBoundsException, etc.
+ // To make it simple, catch all the exceptions here.
+ case NonFatal(e) =>
+ handleException(e)
+ }
+ }
+
+ override def prettyName: String = "from_proto"
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val expr = ctx.addReferenceObj("this", this)
+ nullSafeCodeGen(ctx, ev, eval => {
+ val result = ctx.freshName("result")
+ val dt = CodeGenerator.boxedType(dataType)
+ s"""
+ $dt $result = ($dt) $expr.nullSafeEval($eval);
+ if ($result == null) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = $result;
+ }
+ """
+ })
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): ProtoDataToCatalyst =
+ copy(child = newChild)
+}
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala
new file mode 100644
index 0000000000000..bb243b7795802
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto
+
+import com.google.protobuf.{ByteString, DynamicMessage}
+import com.google.protobuf.Descriptors._
+import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+
+import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.proto.utils.ProtoUtils
+import org.apache.spark.sql.proto.utils.ProtoUtils.ProtoMatchedField
+import org.apache.spark.sql.proto.utils.ProtoUtils.toFieldStr
+import org.apache.spark.sql.proto.utils.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType,
+ DateType, Decimal, DoubleType, FloatType, IntegerType, LongType, NullType,
+ ShortType, StringType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+private[sql] class ProtoDeserializer(
+ rootProtoType: Descriptor,
+ rootCatalystType: DataType,
+ positionalFieldMatch: Boolean,
+ datetimeRebaseSpec: RebaseSpec,
+ filters: StructFilters) {
+
+ def this(
+ rootProtoType: Descriptor,
+ rootCatalystType: DataType,
+ datetimeRebaseMode: String) = {
+ this(
+ rootProtoType,
+ rootCatalystType,
+ positionalFieldMatch = false,
+ RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)),
+ new NoopFilters)
+ }
+
+ private val dateRebaseFunc = DataSourceUtils.createDateRebaseFuncInRead(
+ datetimeRebaseSpec.mode, "Proto")
+
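+ // Top-level converter: writes a DynamicMessage into a reusable SpecificInternalRow and
+ // returns None when the row is skipped by the supplied filters.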
+ private val converter: Any => Option[Any] = try {
+ rootCatalystType match {
+ // A shortcut for empty schema.
+ case st: StructType if st.isEmpty =>
+ (_: Any) => Some(InternalRow.empty)
+
+ case st: StructType =>
+ val resultRow = new SpecificInternalRow(st.map(_.dataType))
+ val fieldUpdater = new RowUpdater(resultRow)
+ val applyFilters = filters.skipRow(resultRow, _)
+ val writer = getRecordWriter(rootProtoType, st, Nil, Nil, applyFilters)
+ (data: Any) => {
+ val record = data.asInstanceOf[DynamicMessage]
+ val skipRow = writer(fieldUpdater, record)
+ if (skipRow) None else Some(resultRow)
+ }
+ }
+ } catch {
+ case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException(
+ s"Cannot convert Proto type ${rootProtoType.getName} " +
+ s"to SQL type ${rootCatalystType.sql}.", ise)
+ }
+
+ def deserialize(data: Any): Option[Any] = converter(data)
+
+ private def newArrayWriter(
+ protoField: FieldDescriptor,
+ protoPath: Seq[String],
+ catalystPath: Seq[String],
+ elementType: DataType,
+ containsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = {
+
+ val protoElementPath = protoPath :+ "element"
+ val elementWriter = newWriter(protoField, elementType,
+ protoElementPath, catalystPath :+ "element")
+ (updater, ordinal, value) =>
+ val collection = value.asInstanceOf[java.util.Collection[Any]]
+ val result = createArrayData(elementType, collection.size())
+ val elementUpdater = new ArrayDataUpdater(result)
+
+ var i = 0
+ val iterator = collection.iterator()
+ while (iterator.hasNext) {
+ val element = iterator.next()
+ if (element == null) {
+ if (!containsNull) {
+ throw new RuntimeException(
+ s"Array value at path ${toFieldStr(protoElementPath)} is not allowed to be null")
+ } else {
+ elementUpdater.setNullAt(i)
+ }
+ } else {
+ elementWriter(elementUpdater, i, element)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, result)
+ }
+
+ /**
+ * Creates a writer to write proto values to Catalyst values at the given ordinal with the given
+ * updater.
+ */
+ private def newWriter(
+ protoType: FieldDescriptor,
+ catalystType: DataType,
+ protoPath: Seq[String],
+ catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
+ val errorPrefix = s"Cannot convert Proto ${toFieldStr(protoPath)} to " +
+ s"SQL ${toFieldStr(catalystPath)} because "
+ val incompatibleMsg = errorPrefix +
+ s"schema is incompatible (protoType = ${protoType} ${protoType.toProto.getLabel} " +
+ s"${protoType.getJavaType} ${protoType.getType}, sqlType = ${catalystType.sql})"
+
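+ // Dispatch on the (Proto Java type, Catalyst type) pair; unsupported combinations fall
+ // through to the IncompatibleSchemaException at the bottom of this match.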
+ (protoType.getJavaType, catalystType) match {
+
+ case (null, NullType) => (updater, ordinal, _) =>
+ updater.setNullAt(ordinal)
+
+ // TODO: we can avoid boxing if future version of proto provide primitive accessors.
+ case (BOOLEAN, BooleanType) => (updater, ordinal, value) =>
+ updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
+
+ case (BOOLEAN, ArrayType(BooleanType, containsNull)) =>
+ newArrayWriter(protoType, protoPath,
+ catalystPath, BooleanType, containsNull)
+
+ case (INT, IntegerType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, value.asInstanceOf[Int])
+
+ case (INT, ArrayType(IntegerType, containsNull)) =>
+ newArrayWriter(protoType, protoPath,
+ catalystPath, IntegerType, containsNull)
+
+ case (INT, DateType) => (updater, ordinal, value) =>
+ updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int]))
+
+ case (LONG, LongType) => (updater, ordinal, value) =>
+ updater.setLong(ordinal, value.asInstanceOf[Long])
+
+ case (LONG, ArrayType(LongType, containsNull)) =>
+ newArrayWriter(protoType, protoPath,
+ catalystPath, LongType, containsNull)
+
+ case (FLOAT, FloatType) => (updater, ordinal, value) =>
+ updater.setFloat(ordinal, value.asInstanceOf[Float])
+
+ case (FLOAT, ArrayType(FloatType, containsNull)) =>
+ newArrayWriter(protoType, protoPath,
+ catalystPath, FloatType, containsNull)
+
+ case (DOUBLE, DoubleType) => (updater, ordinal, value) =>
+ updater.setDouble(ordinal, value.asInstanceOf[Double])
+
+ case (DOUBLE, ArrayType(DoubleType, containsNull)) =>
+ newArrayWriter(protoType, protoPath,
+ catalystPath, DoubleType, containsNull)
+
+ case (STRING, StringType) => (updater, ordinal, value) =>
+ val str = value match {
+ case s: String => UTF8String.fromString(s)
+ }
+ updater.set(ordinal, str)
+
+ case (STRING, ArrayType(StringType, containsNull)) =>
+ newArrayWriter(protoType, protoPath,
+ catalystPath, StringType, containsNull)
+
+ case (BYTE_STRING, BinaryType) => (updater, ordinal, value) =>
+ val byteArray = value match {
+ case s: ByteString => s.toByteArray
+ case _ => throw new Exception("Invalid ByteString format")
+ }
+ updater.set(ordinal, byteArray)
+
+ case (BYTE_STRING, ArrayType(BinaryType, containsNull)) =>
+ newArrayWriter(protoType, protoPath,
+ catalystPath, BinaryType, containsNull)
+
+ case (MESSAGE, st: StructType) =>
+ val writeRecord = getRecordWriter(protoType.getMessageType, st, protoPath,
+ catalystPath, applyFilters = _ => false)
+ (updater, ordinal, value) =>
+ val row = new SpecificInternalRow(st)
+ writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage])
+ updater.set(ordinal, row)
+
+ case (MESSAGE, ArrayType(st: StructType, containsNull)) =>
+ newArrayWriter(protoType, protoPath,
+ catalystPath, st, containsNull)
+
+ case (ENUM, StringType) => (updater, ordinal, value) =>
+ updater.set(ordinal, UTF8String.fromString(value.toString))
+
+ case (ENUM, ArrayType(StringType, containsNull)) =>
+ newArrayWriter(protoType, protoPath,
+ catalystPath, StringType, containsNull)
+
+ case _ => throw new IncompatibleSchemaException(incompatibleMsg)
+ }
+ }
+
+ private def getRecordWriter(
+ protoType: Descriptor,
+ catalystType: StructType,
+ protoPath: Seq[String],
+ catalystPath: Seq[String],
+ applyFilters: Int => Boolean):
+ (CatalystDataUpdater, DynamicMessage) => Boolean = {
+
+ val protoSchemaHelper = new ProtoUtils.ProtoSchemaHelper(
+ protoType, catalystType, protoPath, catalystPath, positionalFieldMatch)
+
+ protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true)
+ // no need to validateNoExtraProtoFields since extra Proto fields are ignored
+
+ val (validFieldIndexes, fieldWriters) = protoSchemaHelper.matchedFields.map {
+ case ProtoMatchedField(catalystField, ordinal, protoField) =>
+ val baseWriter = newWriter(protoField, catalystField.dataType,
+ protoPath :+ protoField.getName, catalystPath :+ catalystField.name)
+ val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
+ if (value == null) {
+ fieldUpdater.setNullAt(ordinal)
+ } else {
+ baseWriter(fieldUpdater, ordinal, value)
+ }
+ }
+ (protoField, fieldWriter)
+ }.toArray.unzip
+
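+ // The returned function writes every matched Proto field into the updater and reports
+ // whether the row should be skipped according to `applyFilters`.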
+ (fieldUpdater, record) => {
+ var i = 0
+ var skipRow = false
+ while (i < validFieldIndexes.length && !skipRow) {
+ fieldWriters(i)(fieldUpdater, record.getField(validFieldIndexes(i)))
+ skipRow = applyFilters(i)
+ i += 1
+ }
+ skipRow
+ }
+ }
+
+ private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
+ case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
+ case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
+ case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
+ case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+ case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+ case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
+ case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
+ case _ => new GenericArrayData(new Array[Any](length))
+ }
+
+ /**
+ * A base interface for updating values inside catalyst data structure like `InternalRow` and
+ * `ArrayData`.
+ */
+ sealed trait CatalystDataUpdater {
+ def set(ordinal: Int, value: Any): Unit
+
+ def setNullAt(ordinal: Int): Unit = set(ordinal, null)
+
+ def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
+
+ def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
+
+ def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
+
+ def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
+
+ def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
+
+ def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
+
+ def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
+
+ def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
+ }
+
+ final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
+ override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)
+
+ override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
+
+ override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
+
+ override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
+
+ override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
+
+ override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
+
+ override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
+
+ override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
+
+ override def setDecimal(ordinal: Int, value: Decimal): Unit =
+ row.setDecimal(ordinal, value, value.precision)
+ }
+
+ final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
+ override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)
+
+ override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
+
+ override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
+
+ override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
+
+ override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
+
+ override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
+
+ override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
+
+ override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
+
+ override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
+ }
+
+}
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoSerializer.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoSerializer.scala
new file mode 100644
index 0000000000000..75fed55f00d6f
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoSerializer.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto
+
+import scala.collection.JavaConverters._
+
+import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+import com.google.protobuf.DynamicMessage
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.proto.utils.ProtoUtils
+import org.apache.spark.sql.proto.utils.ProtoUtils.{toFieldStr, ProtoMatchedField}
+import org.apache.spark.sql.proto.utils.SchemaConverters.{IncompatibleSchemaException, UnsupportedProtoValueException}
+import org.apache.spark.sql.types._
+
+/**
+ * A serializer to serialize data in catalyst format to data in proto format.
+ */
+private[sql] class ProtoSerializer(
+ rootCatalystType: DataType,
+ rootProtoType: Descriptor,
+ nullable: Boolean,
+ positionalFieldMatch: Boolean) extends Logging {
+
+ def this(rootCatalystType: DataType, rootProtoType: Descriptor, nullable: Boolean) = {
+ this(rootCatalystType, rootProtoType, nullable, positionalFieldMatch = false)
+ }
+
+ def serialize(catalystData: Any): Any = {
+ converter.apply(catalystData)
+ }
+
+ private val converter: Any => Any = {
+ val baseConverter = try {
+ rootCatalystType match {
+ case st: StructType =>
+ newStructConverter(st, rootProtoType, Nil, Nil).asInstanceOf[Any => Any]
+ }
+ } catch {
+ case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException(
+ s"Cannot convert SQL type ${rootCatalystType.sql} to Proto type " +
+ s"${rootProtoType.getName}.", ise)
+ }
+ if (nullable) {
+ (data: Any) =>
+ if (data == null) {
+ null
+ } else {
+ baseConverter.apply(data)
+ }
+ } else {
+ baseConverter
+ }
+ }
+
+ private type Converter = (SpecializedGetters, Int) => Any
+
+ private def newConverter(
+ catalystType: DataType,
+ protoType: FieldDescriptor,
+ catalystPath: Seq[String],
+ protoPath: Seq[String]): Converter = {
+ val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " +
+ s"to Proto ${toFieldStr(protoPath)} because "
+ (catalystType, protoType.getJavaType) match {
+ case (NullType, _) =>
+ (getter, ordinal) => null
+ case (BooleanType, BOOLEAN) =>
+ (getter, ordinal) => getter.getBoolean(ordinal)
+ case (ByteType, INT) =>
+ (getter, ordinal) => getter.getByte(ordinal).toInt
+ case (ShortType, INT) =>
+ (getter, ordinal) => getter.getShort(ordinal).toInt
+ case (IntegerType, INT) =>
+ (getter, ordinal) => {
+ getter.getInt(ordinal)
+ }
+ case (LongType, LONG) =>
+ (getter, ordinal) => getter.getLong(ordinal)
+ case (FloatType, FLOAT) =>
+ (getter, ordinal) => getter.getFloat(ordinal)
+ case (DoubleType, DOUBLE) =>
+ (getter, ordinal) => getter.getDouble(ordinal)
+ case (StringType, ENUM) =>
+ val enumSymbols: Set[String] = protoType.getEnumType.getValues.asScala.map(
+ e => e.toString).toSet
+ (getter, ordinal) =>
+ val data = getter.getUTF8String(ordinal).toString
+ if (!enumSymbols.contains(data)) {
+ throw new IncompatibleSchemaException(errorPrefix +
+ s""""$data" cannot be written since it's not defined in enum """ +
+ enumSymbols.mkString("\"", "\", \"", "\""))
+ }
+ protoType.getEnumType.findValueByName(data)
+ case (StringType, STRING) =>
+ (getter, ordinal) => {
+ String.valueOf(getter.getUTF8String(ordinal))
+ }
+
+ case (BinaryType, BYTE_STRING) =>
+ (getter, ordinal) => getter.getBinary(ordinal)
+
+ case (DateType, INT) =>
+ (getter, ordinal) => getter.getInt(ordinal)
+
+ case (TimestampType, LONG) => protoType.getContainingType match {
+ // For backward compatibility, if the Proto type is Long and it is not logical type
+ // (the `null` case), output the timestamp value as with millisecond precision.
+ case null => (getter, ordinal) =>
+ DateTimeUtils.microsToMillis(getter.getLong(ordinal))
+ case other => throw new IncompatibleSchemaException(errorPrefix +
+ s"SQL type ${TimestampType.sql} cannot be converted to Proto logical type $other")
+ }
+
+ case (TimestampNTZType, LONG) => protoType.getContainingType match {
+ // To keep consistent with TimestampType, if the Proto type is Long and it is not
+ // logical type (the `null` case), output the TimestampNTZ as long value
+ // in millisecond precision.
+ case null => (getter, ordinal) =>
+ DateTimeUtils.microsToMillis(getter.getLong(ordinal))
+ case other => throw new IncompatibleSchemaException(errorPrefix +
+ s"SQL type ${TimestampNTZType.sql} cannot be converted to Proto logical type $other")
+ }
+
+ case (ArrayType(et, containsNull), _) =>
+ val elementConverter = newConverter(
+ et, protoType,
+ catalystPath :+ "element", protoPath :+ "element")
+ (getter, ordinal) => {
+ val arrayData = getter.getArray(ordinal)
+ val len = arrayData.numElements()
+ val result = new Array[Any](len)
+ var i = 0
+ while (i < len) {
+ if (containsNull && arrayData.isNullAt(i)) {
+ result(i) = null
+ } else {
+ result(i) = elementConverter(arrayData, i)
+ }
+ i += 1
+ }
+ // proto writer is expecting a Java Collection, so we convert it into
+ // `ArrayList` backed by the specified array without data copying.
+ java.util.Arrays.asList(result: _*)
+ }
+
+ case (st: StructType, MESSAGE) =>
+ val structConverter = newStructConverter(
+ st, protoType.getMessageType, catalystPath, protoPath)
+ val numFields = st.length
+ (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
+
+ case (MapType(kt, vt, valueContainsNull), _) if kt == StringType =>
+ val valueConverter = newConverter(
+ vt, protoType, catalystPath :+ "value", protoPath :+ "value")
+ (getter, ordinal) =>
+ val mapData = getter.getMap(ordinal)
+ val len = mapData.numElements()
+ val result = new java.util.HashMap[String, Any](len)
+ val keyArray = mapData.keyArray()
+ val valueArray = mapData.valueArray()
+ var i = 0
+ while (i < len) {
+ val key = keyArray.getUTF8String(i).toString
+ if (valueContainsNull && valueArray.isNullAt(i)) {
+ result.put(key, null)
+ } else {
+ result.put(key, valueConverter(valueArray, i))
+ }
+ i += 1
+ }
+ result
+
+ case (_: YearMonthIntervalType, INT) =>
+ (getter, ordinal) => getter.getInt(ordinal)
+
+ case (_: DayTimeIntervalType, LONG) =>
+ (getter, ordinal) => getter.getLong(ordinal)
+
+ case _ =>
+ throw new IncompatibleSchemaException(errorPrefix +
+ s"schema is incompatible (sqlType = ${catalystType.sql}, " +
+ s"protoType = ${protoType.getJavaType})")
+ }
+ }
+
+ private def newStructConverter(
+ catalystStruct: StructType,
+ protoStruct: Descriptor,
+ catalystPath: Seq[String],
+ protoPath: Seq[String]): InternalRow => DynamicMessage = {
+
+ val protoSchemaHelper = new ProtoUtils.ProtoSchemaHelper(
+ protoStruct, catalystStruct, protoPath, catalystPath, positionalFieldMatch)
+
+ protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false)
+ protoSchemaHelper.validateNoExtraRequiredProtoFields()
+
+ val (protoIndices, fieldConverters) = protoSchemaHelper.matchedFields.map {
+ case ProtoMatchedField(catalystField, _, protoField) =>
+ val converter = newConverter(catalystField.dataType,
+ protoField,
+ catalystPath :+ catalystField.name, protoPath :+ protoField.getName)
+ (protoField, converter)
+ }.toArray.unzip
+
+ val numFields = catalystStruct.length
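+ // Builds a DynamicMessage field by field; null Catalyst values are rejected because the
+ // Proto message fields cannot be set to null.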
+ row: InternalRow =>
+ val result = DynamicMessage.newBuilder(protoStruct)
+ var i = 0
+ while (i < numFields) {
+ if (row.isNullAt(i)) {
+ throw new UnsupportedProtoValueException(
+ s"Cannot set field ${protoIndices(i).getName} to null; Proto does not allow null values")
+ } else {
+ result.setField(protoIndices(i), fieldConverters(i).apply(row, i))
+ }
+ i += 1
+ }
+ result.build()
+ }
+}
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/functions.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/functions.scala
new file mode 100644
index 0000000000000..21b6ed77949bd
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/functions.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.Column
+
+// scalastyle:off: object.name
+object functions {
+// scalastyle:on: object.name
+
+ /**
+ * Converts a binary column of Proto format into its corresponding Catalyst value.
+ * The specified schema must match the actual schema of the read data, otherwise the behavior
+ * is undefined: it may fail or return an arbitrary result.
+ * To deserialize the data with a compatible and evolved schema, the expected Proto schema can
+ * be set via the option `protoSchema`.
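+ *
+ * A minimal usage sketch (the column name, descriptor path and message name below are
+ * illustrative):
+ * {{{
+ *   import scala.collection.JavaConverters._
+ *   df.select(
+ *     from_proto($"value", "/tmp/person.desc", "Person", Map("mode" -> "PERMISSIVE").asJava)
+ *       as "person")
+ * }}}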
+ *
+ * @param data the binary column.
+ * @param descFilePath the path to the Proto descriptor file.
+ * @param messageName the Proto message name to look for in the descriptor file.
+ * @param options options to control how a Proto record is parsed.
+ * @since 3.4.0
+ */
+ @Experimental
+ def from_proto(data: Column, descFilePath: String, messageName: String,
+ options: java.util.Map[String, String]): Column = {
+ new Column(ProtoDataToCatalyst(data.expr, descFilePath, messageName, options.asScala.toMap))
+ }
+
+ /**
+ * Converts a binary column of Proto format into its corresponding Catalyst value.
+ * The specified schema must match the actual schema of the read data, otherwise the behavior
+ * is undefined: it may fail or return an arbitrary result.
+ * To deserialize the data with a compatible and evolved schema, the expected Proto schema can
+ * be set via the option `protoSchema`.
+ *
+ * @param data the binary column.
+ * @param descFilePath the path to the Proto descriptor file.
+ * @param messageName the Proto message name to look for in the descriptor file.
+ * @since 3.4.0
+ */
+ @Experimental
+ def from_proto(data: Column, descFilePath: String, messageName: String): Column = {
+ new Column(ProtoDataToCatalyst(data.expr, descFilePath, messageName, Map.empty))
+ }
+
+ /**
+ * Converts a column into binary of proto format.
+ *
+ * @param data the data column.
+ * @since 3.4.0
+ */
+ @Experimental
+ def to_proto(data: Column): Column = {
+ new Column(CatalystDataToProto(data.expr, None, None))
+ }
+
+ /**
+ * Converts a column into binary of proto format.
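+ *
+ * A minimal usage sketch (the column names, descriptor path and message name below are
+ * illustrative):
+ * {{{
+ *   df.select(to_proto(struct($"name", $"age"), "/tmp/person.desc", "Person") as "value")
+ * }}}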
+ *
+ * @param data the data column.
+ * @param descFilePath the path to the Proto descriptor file.
+ * @param messageName the Proto message name to look for in the descriptor file.
+ * @since 3.4.0
+ */
+ @Experimental
+ def to_proto(data: Column, descFilePath: String, messageName: String): Column = {
+ new Column(CatalystDataToProto(data.expr, Some(descFilePath), Some(messageName)))
+ }
+}
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/package.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/package.scala
new file mode 100644
index 0000000000000..de88d564062a9
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/package.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+package object proto {
+ protected[proto] object ScalaReflectionLock
+}
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/DynamicSchema.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/DynamicSchema.scala
new file mode 100644
index 0000000000000..962007ecaff80
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/DynamicSchema.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto.utils
+
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.util.control.Breaks.{break, breakable}
+
+import com.google.protobuf.DescriptorProtos.{FileDescriptorProto, FileDescriptorSet}
+import com.google.protobuf.Descriptors.{Descriptor, FileDescriptor}
+
+class DynamicSchema {
+ var fileDescSet: FileDescriptorSet = null
+
+ def this(fileDescSet: FileDescriptorSet) = {
+ this()
+ this.fileDescSet = fileDescSet
+ val fileDescMap: util.Map[String, FileDescriptor] = init(fileDescSet)
+ val msgDupes: util.Set[String] = new util.HashSet[String]()
+ for (fileDesc: FileDescriptor <- fileDescMap.values().asScala) {
+ for (msgType: Descriptor <- fileDesc.getMessageTypes().asScala) {
+ addMessageType(msgType, null, msgDupes)
+ }
+ }
+
+ for (msgName: String <- msgDupes.asScala) {
+ messageMsgDescriptorMapShort.remove(msgName)
+ }
+ }
+
+ val messageDescriptorMapFull: util.Map[String, Descriptor] =
+ new util.HashMap[String, Descriptor]()
+ val messageMsgDescriptorMapShort: util.Map[String, Descriptor] =
+ new util.HashMap[String, Descriptor]()
+
+ def newBuilder(): Builder = {
+ new Builder()
+ }
+
+ def addMessageType(msgType: Descriptor, scope: String, msgDupes: util.Set[String]) : Unit = {
+ val msgTypeNameFull: String = msgType.getFullName()
+ val msgTypeNameShort: String = {
+ if (scope == null) {
+ msgType.getName()
+ } else {
+ scope + "." + msgType.getName()
+ }
+ }
+
+ if (messageDescriptorMapFull.containsKey(msgTypeNameFull)) {
+ throw new IllegalArgumentException("duplicate name: " + msgTypeNameFull)
+ }
+ if (messageMsgDescriptorMapShort.containsKey(msgTypeNameShort)) {
+ msgDupes.add(msgTypeNameShort)
+ }
+ messageDescriptorMapFull.put(msgTypeNameFull, msgType)
+ messageMsgDescriptorMapShort.put(msgTypeNameShort, msgType)
+
+ for (nestedType <- msgType.getNestedTypes.asScala) {
+ addMessageType(nestedType, msgTypeNameShort, msgDupes)
+ }
+ }
+
+ def init(fileDescSet: FileDescriptorSet) : util.Map[String, FileDescriptor] = {
+ // check for dupes
+ val allFdProtoNames: util.Set[String] = new util.HashSet[String]()
+ for (fdProto: FileDescriptorProto <- fileDescSet.getFileList().asScala) {
+ if (allFdProtoNames.contains(fdProto.getName())) {
+ throw new IllegalArgumentException("duplicate name: " + fdProto.getName())
+ }
+ allFdProtoNames.add(fdProto.getName())
+ }
+
+ // build FileDescriptors, resolve dependencies (imports) if any
+ val resolvedFileDescMap: util.Map[String, FileDescriptor] =
+ new util.HashMap[String, FileDescriptor]()
+ while (resolvedFileDescMap.size() < fileDescSet.getFileCount()) {
+ for (fdProto : FileDescriptorProto <- fileDescSet.getFileList().asScala) {
+ breakable {
+ if (resolvedFileDescMap.containsKey(fdProto.getName())) {
+ break
+ }
+
+ val dependencyList: util.List[String] = fdProto.getDependencyList();
+ val resolvedFdList: util.List[FileDescriptor] =
+ new util.ArrayList[FileDescriptor]()
+ for (depName: String <- dependencyList.asScala) {
+ if (!allFdProtoNames.contains(depName)) {
+ throw new IllegalArgumentException("cannot resolve import " + depName + " in " +
+ fdProto.getName())
+ }
+ val fd: FileDescriptor = resolvedFileDescMap.get(depName)
+ if (fd != null) resolvedFdList.add(fd)
+ }
+
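+ // Build this FileDescriptor only once all of its imports are resolved; unresolved entries
+ // are revisited on a later pass of the enclosing while loop.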
+ if (resolvedFdList.size() == dependencyList.size()) {
+ val fds = new Array[FileDescriptor](resolvedFdList.size)
+ val fd: FileDescriptor = FileDescriptor.buildFrom(fdProto, resolvedFdList.toArray(fds))
+ resolvedFileDescMap.put(fdProto.getName(), fd)
+ }
+ }
+ }
+ }
+
+ resolvedFileDescMap
+ }
+
+ override def toString() : String = {
+ val msgTypes: util.Set[String] = getMessageTypes()
+ ("types: " + msgTypes + "\n" + fileDescSet)
+ }
+
+ def toByteArray() : Array[Byte] = {
+ fileDescSet.toByteArray()
+ }
+
+ def getMessageTypes(): util.Set[String] = {
+ new util.TreeSet[String](messageDescriptorMapFull.keySet())
+ }
+
+ def getMessageDescriptor(messageTypeName: String) : Descriptor = {
+ var messageType: Descriptor = messageMsgDescriptorMapShort.get(messageTypeName)
+ if (messageType == null) {
+ messageType = messageDescriptorMapFull.get(messageTypeName)
+ }
+ messageType
+ }
+
+ class Builder {
+ val messageFileDescProtoBuilder: FileDescriptorProto.Builder = FileDescriptorProto.newBuilder()
+ val messageFileDescSetBuilder: FileDescriptorSet.Builder = FileDescriptorSet.newBuilder()
+
+ def build(): DynamicSchema = {
+ val fileDescSetBuilder: FileDescriptorSet.Builder = FileDescriptorSet.newBuilder()
+ fileDescSetBuilder.addFile(messageFileDescProtoBuilder.build())
+ fileDescSetBuilder.mergeFrom(messageFileDescSetBuilder.build())
+ new DynamicSchema(fileDescSetBuilder.build())
+ }
+
+ def setName(name: String) : Builder = {
+ messageFileDescProtoBuilder.setName(name)
+ this
+ }
+
+ def setPackage(name: String): Builder = {
+ messageFileDescProtoBuilder.setPackage(name)
+ this
+ }
+
+ def addMessageDefinition(msgDef: MessageDefinition) : Builder = {
+ messageFileDescProtoBuilder.addMessageType(msgDef.getMessageType())
+ this
+ }
+ }
+}
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/MessageDefinition.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/MessageDefinition.scala
new file mode 100644
index 0000000000000..7692b604a878e
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/MessageDefinition.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto.utils
+
+import com.google.protobuf.DescriptorProtos.{DescriptorProto, FieldDescriptorProto}
+
+class MessageDefinition(val messageType: DescriptorProto = null) {
+
+ def newBuilder(msgTypeName: String): Builder = {
+ new Builder(msgTypeName)
+ }
+
+ def getMessageType(): DescriptorProto = {
+ messageType
+ }
+
+ class Builder(msgTypeName: String) {
+ val messageTypeBuilder: DescriptorProto.Builder = DescriptorProto.newBuilder()
+ messageTypeBuilder.setName(msgTypeName)
+
+ def addField(label: String, typeName: String, name: String, num: Int): Builder = {
+ addField(label, typeName, name, num, null)
+ }
+
+ def addField(label: String, typeName: String, name: String, num: Int,
+ defaultVal: String): Builder = {
+ val protoLabel: FieldDescriptorProto.Label = protoLabelMap.getOrElse(label, null)
+ if (protoLabel == null) {
+ throw new IllegalArgumentException("Illegal label: " + label)
+ }
+ addField(protoLabel, typeName, name, num, defaultVal)
+ this
+ }
+
+ def addMessageDefinition(msgDef: MessageDefinition): Builder = {
+ messageTypeBuilder.addNestedType(msgDef.getMessageType())
+ this
+ }
+
+ def build(): MessageDefinition = {
+ new MessageDefinition(messageTypeBuilder.build())
+ }
+
+ def addField(label: FieldDescriptorProto.Label, typeName: String, name: String, num: Int,
+ defaultVal: String): DescriptorProto.Builder = {
+ val fieldBuilder: FieldDescriptorProto.Builder = FieldDescriptorProto.newBuilder()
+ fieldBuilder.setLabel(label)
+ val primType: FieldDescriptorProto.Type = protoTypeMap.getOrElse(typeName, null)
+ if (primType != null) {
+ fieldBuilder.setType(primType)
+ } else {
+ fieldBuilder.setTypeName(typeName)
+ }
+
+ fieldBuilder.setName(name).setNumber(num)
+ if (defaultVal != null) fieldBuilder.setDefaultValue(defaultVal)
+ messageTypeBuilder.addField(fieldBuilder.build())
+ }
+ }
+
+ private val protoLabelMap: Map[String, FieldDescriptorProto.Label] =
+ Map("optional" -> FieldDescriptorProto.Label.LABEL_OPTIONAL,
+ "required" -> FieldDescriptorProto.Label.LABEL_REQUIRED,
+ "repeated" -> FieldDescriptorProto.Label.LABEL_REPEATED
+ )
+
+ private val protoTypeMap: Map[String, FieldDescriptorProto.Type] =
+ Map("double" -> FieldDescriptorProto.Type.TYPE_DOUBLE,
+ "float" -> FieldDescriptorProto.Type.TYPE_FLOAT,
+ "int32" -> FieldDescriptorProto.Type.TYPE_INT32,
+ "int64" -> FieldDescriptorProto.Type.TYPE_INT64,
+ "uint32" -> FieldDescriptorProto.Type.TYPE_UINT32,
+ "uint64" -> FieldDescriptorProto.Type.TYPE_UINT64,
+ "sint32" -> FieldDescriptorProto.Type.TYPE_SINT32,
+ "sint64" -> FieldDescriptorProto.Type.TYPE_SINT64,
+ "fixed32" -> FieldDescriptorProto.Type.TYPE_FIXED32,
+ "fixed64" -> FieldDescriptorProto.Type.TYPE_FIXED64,
+ "sfixed32" -> FieldDescriptorProto.Type.TYPE_SFIXED32,
+ "sfixed64" -> FieldDescriptorProto.Type.TYPE_SFIXED64,
+ "bool" -> FieldDescriptorProto.Type.TYPE_BOOL,
+ "string" -> FieldDescriptorProto.Type.TYPE_STRING,
+ "bytes" -> FieldDescriptorProto.Type.TYPE_BYTES)
+}
+
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoOptions.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoOptions.scala
new file mode 100644
index 0000000000000..a2d32a0a07fa2
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoOptions.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto.utils
+
+import java.io.FileInputStream
+import java.net.URI
+
+import com.google.protobuf.DescriptorProtos
+import com.google.protobuf.Descriptors.Descriptor
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.FileSourceOptions
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Options for the Proto reader and writer, stored in a case-insensitive manner.
+ */
+private[sql] class ProtoOptions(
+ @transient val parameters: CaseInsensitiveMap[String],
+ @transient val conf: Configuration)
+ extends FileSourceOptions(parameters) with Logging {
+
+ def this(parameters: Map[String, String], conf: Configuration) = {
+ this(CaseInsensitiveMap(parameters), conf)
+ }
+
+ /**
+ * Optional Proto descriptor provided by a user, either as a local descriptor file (option
+ * `protoSchema`) or via a URL (option `protoSchemaUrl`).
+ *
+ * When reading Proto, this option can be set to an evolved schema, which is compatible with but
+ * different from the actual Proto schema. The deserialization schema will be consistent with
+ * the evolved schema. For example, if we set an evolved schema containing one additional
+ * column with a default value, the reading result in Spark will contain the new column too.
+ *
+ * When writing Proto, this option can be set if the expected output Proto schema doesn't match
+ * the schema converted by Spark. For example, the expected schema of one column is of "enum"
+ * type, instead of "string" type in the default converted schema.
+ */
+ val schema: Option[Descriptor] = {
+ parameters.get("protoSchema").map(a => DescriptorProtos.DescriptorProto.parseFrom(
+ new FileInputStream(a)).getDescriptorForType).orElse({
+ val protoUrlSchema = parameters.get("protoSchemaUrl").map(url => {
+ log.debug("loading proto schema from url: " + url)
+ val fs = FileSystem.get(new URI(url), conf)
+ val in = fs.open(new Path(url))
+ try {
+ DescriptorProtos.DescriptorProto.parseFrom(in).getDescriptorForType
+ } finally {
+ in.close()
+ }
+ })
+ protoUrlSchema
+ })
+ }
+
+ val parseMode: ParseMode =
+ parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)
+
+ /**
+ * The rebasing mode for DATE, TIMESTAMP_MICROS and TIMESTAMP_MILLIS values in reads.
+ */
+ val datetimeRebaseModeInRead: String = parameters
+ .get(ProtoOptions.DATETIME_REBASE_MODE)
+ .getOrElse(SQLConf.get.getConf(SQLConf.PROTO_REBASE_MODE_IN_READ))
+}
+
+private[sql] object ProtoOptions {
+ def apply(parameters: Map[String, String]): ProtoOptions = {
+ val hadoopConf = SparkSession
+ .getActiveSession
+ .map(_.sessionState.newHadoopConf())
+ .getOrElse(new Configuration())
+ new ProtoOptions(CaseInsensitiveMap(parameters), hadoopConf)
+ }
+
+ val ignoreExtensionKey = "ignoreExtension"
+
+ // The option controls rebasing of the DATE and TIMESTAMP values between
+ // Julian and Proleptic Gregorian calendars. It impacts the behaviour of the Proto
+ // datasource similarly to the SQL config `spark.sql.proto.datetimeRebaseModeInRead`,
+ // and can be set to the same values: `EXCEPTION`, `LEGACY` or `CORRECTED`.
+ val DATETIME_REBASE_MODE = "datetimeRebaseMode"
+}
+
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoUtils.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoUtils.scala
new file mode 100644
index 0000000000000..29d3a9b6a33a0
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoUtils.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.proto.utils
+
+import java.io.{BufferedInputStream, FileInputStream, FileNotFoundException, IOException}
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
+import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException}
+import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import org.apache.hadoop.fs.FileStatus
+
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.FileSourceOptions
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.proto.utils.ProtoOptions.ignoreExtensionKey
+import org.apache.spark.sql.proto.utils.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+private[sql] object ProtoUtils extends Logging {
+
+ def inferSchema(
+ spark: SparkSession,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
+ val conf = spark.sessionState.newHadoopConfWithOptions(options)
+ val parsedOptions = new ProtoOptions(options, conf)
+
+ if (parsedOptions.parameters.contains(ignoreExtensionKey)) {
+ logWarning(s"Option $ignoreExtensionKey is deprecated. Please use the " +
+ "general data source option pathGlobFilter for filtering file names.")
+ }
+ // The user can specify an optional Proto descriptor as the schema.
+ val protoSchema = parsedOptions.schema
+ .getOrElse {
+ inferProtoSchemaFromFiles(files,
+ new FileSourceOptions(CaseInsensitiveMap(options)).ignoreCorruptFiles)
+ }
+
+ SchemaConverters.toSqlType(protoSchema).dataType match {
+ case t: StructType => Some(t)
+ case _ => throw new RuntimeException(
+ s"""Proto schema cannot be converted to a Spark SQL StructType:
+ |
+ |${protoSchema.toString()}
+ |""".stripMargin)
+ }
+ }
+
+ private def inferProtoSchemaFromFiles(
+ files: Seq[FileStatus],
+ ignoreCorruptFiles: Boolean): Descriptor = {
+ // Schema evolution is not supported yet. Here we only pick the first readable sample file to
+ // figure out the schema of the whole dataset.
+ val protoReader = files.iterator.map { f =>
+ val path = f.getPath
+ if (!path.getName.endsWith(".pb")) {
+ None
+ } else {
+ Utils.tryWithResource {
+ new FileInputStream(path.toString)
+ } { in =>
+ try {
+ Some(DescriptorProtos.DescriptorProto.parseFrom(in).getDescriptorForType)
+ } catch {
+ case e: IOException =>
+ if (ignoreCorruptFiles) {
+ logWarning(s"Skipped the corrupted file: $path", e)
+ None
+ } else {
+ throw new SparkException(s"Could not read file: $path", e)
+ }
+ }
+ }
+ }
+ }.collectFirst {
+ case Some(reader) => reader
+ }
+
+ protoReader match {
+ case Some(reader) =>
+ reader.getContainingType
+ case None =>
+ throw new FileNotFoundException(
+ "No Proto descriptor files found. If files don't have the .pb extension, set ignoreExtension to true")
+ }
+ }
+
+ def supportsDataType(dataType: DataType): Boolean = dataType match {
+ case _: AtomicType => true
+
+ case st: StructType => st.forall { f => supportsDataType(f.dataType) }
+
+ case ArrayType(elementType, _) => supportsDataType(elementType)
+
+ case MapType(keyType, valueType, _) =>
+ supportsDataType(keyType) && supportsDataType(valueType)
+
+ case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+
+ case _: NullType => true
+
+ case _ => false
+ }
+
+ /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Proto field. */
+ private[sql] case class ProtoMatchedField(
+ catalystField: StructField,
+ catalystPosition: Int,
+ protoField: FieldDescriptor)
+
+ /**
+ * Helper class to perform field lookup/matching on Proto schemas.
+ *
+ * This will match `protoSchema` against `catalystSchema`, attempting to find a matching field in
+ * the Proto schema for each field in the Catalyst schema and vice-versa, respecting settings for
+ * case sensitivity. The match results can be accessed using the getter methods.
+ *
+ * @param protoSchema The schema in which to search for fields. Must be of type RECORD.
+ * @param catalystSchema The Catalyst schema to use for matching.
+ * @param protoPath The seq of parent field names leading to `protoSchema`.
+ * @param catalystPath The seq of parent field names leading to `catalystSchema`.
+ * @param positionalFieldMatch If true, perform field matching in a positional fashion
+ * (structural comparison between schemas, ignoring names);
+ * otherwise, perform field matching using field names.
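+ *
+ * Illustrative usage (the descriptor and schema names below are hypothetical):
+ * {{{
+ *   val helper = new ProtoSchemaHelper(personDescriptor, sqlSchema, Nil, Nil,
+ *     positionalFieldMatch = false)
+ *   helper.validateNoExtraCatalystFields(ignoreNullable = true)
+ *   helper.validateNoExtraRequiredProtoFields()
+ * }}}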
+ */
+ class ProtoSchemaHelper(
+ protoSchema: Descriptor,
+ catalystSchema: StructType,
+ protoPath: Seq[String],
+ catalystPath: Seq[String],
+ positionalFieldMatch: Boolean) {
+ if (protoSchema.getName == null) {
+ throw new IncompatibleSchemaException(
+ s"Attempting to treat ${protoSchema.getName} as a RECORD, " +
+ s"but it was: ${protoSchema.getContainingType}")
+ }
+
+ private[this] val protoFieldArray = protoSchema.getFields.asScala.toArray
+ private[this] val fieldMap = protoSchema.getFields.asScala
+ .groupBy(_.getName.toLowerCase(Locale.ROOT))
+ .mapValues(_.toSeq) // toSeq needed for scala 2.13
+
+ /** The fields which have matching equivalents in both Proto and Catalyst schemas. */
+ val matchedFields: Seq[ProtoMatchedField] = catalystSchema.zipWithIndex.flatMap {
+ case (sqlField, sqlPos) =>
+ getProtoField(sqlField.name, sqlPos).map(ProtoMatchedField(sqlField, sqlPos, _))
+ }
+
+ /**
+ * Validate that there are no Catalyst fields which don't have a matching Proto field, throwing
+ * [[IncompatibleSchemaException]] if such extra fields are found. If `ignoreNullable` is false,
+ * consider nullable Catalyst fields to be eligible to be an extra field; otherwise,
+ * ignore nullable Catalyst fields when checking for extras.
+ */
+ def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit =
+ catalystSchema.zipWithIndex.foreach { case (sqlField, sqlPos) =>
+ if (getProtoField(sqlField.name, sqlPos).isEmpty &&
+ (!ignoreNullable || !sqlField.nullable)) {
+ if (positionalFieldMatch) {
+ throw new IncompatibleSchemaException("Cannot find field at position " +
+ s"$sqlPos of ${toFieldStr(protoPath)} from Proto schema (using positional matching)")
+ } else {
+ throw new IncompatibleSchemaException(
+ s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Proto schema")
+ }
+ }
+ }
+
+ /**
+ * Validate that there are no Proto fields which don't have a matching Catalyst field, throwing
+ * [[IncompatibleSchemaException]] if such extra fields are found. Only required (non-nullable)
+ * fields are checked; nullable fields are ignored.
+ */
+ def validateNoExtraRequiredProtoFields(): Unit = {
+ val extraFields = protoFieldArray.toSet -- matchedFields.map(_.protoField)
+ extraFields.filterNot(isNullable).foreach { extraField =>
+ if (positionalFieldMatch) {
+ throw new IncompatibleSchemaException(s"Found field '${extraField.getName()}'" +
+ s" at position ${extraField.getIndex} of ${toFieldStr(protoPath)} from Proto schema " +
+ s"but there is no match in the SQL schema at ${toFieldStr(catalystPath)} " +
+ s"(using positional matching)")
+ } else {
+ throw new IncompatibleSchemaException(
+ s"Found ${toFieldStr(protoPath :+ extraField.getName())} in Proto schema " +
+ "but there is no match in the SQL schema")
+ }
+ }
+ }
+
+ /**
+ * Extract a single field from the contained proto schema which has the desired field name,
+ * performing the matching with proper case sensitivity according to SQLConf.resolver.
+ *
+ * @param name The name of the field to search for.
+ * @return `Some(match)` if a matching Proto field is found, otherwise `None`.
+ */
+ private[proto] def getFieldByName(name: String): Option[FieldDescriptor] = {
+
+ // get candidates, ignoring case of field name
+ val candidates = fieldMap.getOrElse(name.toLowerCase(Locale.ROOT), Seq.empty)
+
+ // search candidates, taking into account case sensitivity settings
+ candidates.filter(f => SQLConf.get.resolver(f.getName(), name)) match {
+ case Seq(protoField) => Some(protoField)
+ case Seq() => None
+ case matches => throw new IncompatibleSchemaException(s"Searching for '$name' in Proto " +
+ s"schema at ${toFieldStr(protoPath)} gave ${matches.size} matches. Candidates: " +
+ matches.map(_.getName()).mkString("[", ", ", "]")
+ )
+ }
+ }
+
+ /** Get the Proto field corresponding to the provided Catalyst field name/position, if any. */
+ def getProtoField(fieldName: String, catalystPos: Int): Option[FieldDescriptor] = {
+ if (positionalFieldMatch) {
+ protoFieldArray.lift(catalystPos)
+ } else {
+ getFieldByName(fieldName)
+ }
+ }
+ }
+
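+ /**
+ * Look up a message Descriptor by name inside a serialized FileDescriptorSet (the output of
+ * `protoc --descriptor_set_out`). For example (with an illustrative file path),
+ * `buildDescriptor("/tmp/catalyst_types.desc", "Person")` returns the descriptor of the
+ * `Person` message defined in that descriptor set.
+ */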
+ def buildDescriptor(protoFilePath: String, messageName: String): Descriptor = {
+ val fileDescriptor: Descriptors.FileDescriptor = parseFileDescriptor(protoFilePath)
+ fileDescriptor.getMessageTypes.asScala
+ .find(_.getName == messageName)
+ .getOrElse(throw new RuntimeException(
+ s"Unable to locate Message '$messageName' in Descriptor"))
+ }
+
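+ /**
+ * Parse a serialized FileDescriptorSet from `protoFilePath` into a FileDescriptor. Note that
+ * only the first FileDescriptorProto in the set is used, so descriptor sets whose messages
+ * depend on other files are not resolved here.
+ */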
+ def parseFileDescriptor(protoFilePath: String): Descriptors.FileDescriptor = {
+ // Read and close the descriptor file; its contents are parsed as a FileDescriptorSet.
+ val fileDescriptorSet: DescriptorProtos.FileDescriptorSet = try {
+ Utils.tryWithResource(new BufferedInputStream(new FileInputStream(protoFilePath))) { dscFile =>
+ DescriptorProtos.FileDescriptorSet.parseFrom(dscFile)
+ }
+ } catch {
+ case ex: InvalidProtocolBufferException =>
+ throw new RuntimeException("Error parsing descriptor byte[] into Descriptor object", ex)
+ case ex: IOException =>
+ throw new RuntimeException("Error reading proto file at path: " + protoFilePath, ex)
+ }
+
+ val descriptorProto: DescriptorProtos.FileDescriptorProto = fileDescriptorSet.getFile(0)
+ try {
+ val fileDescriptor: Descriptors.FileDescriptor = Descriptors.FileDescriptor.buildFrom(
+ descriptorProto, new Array[Descriptors.FileDescriptor](0))
+ if (fileDescriptor.getMessageTypes.isEmpty) {
+ throw new RuntimeException("No MessageTypes returned, " + fileDescriptor.getName())
+ }
+ fileDescriptor
+ } catch {
+ case e: Descriptors.DescriptorValidationException =>
+ throw new RuntimeException("Error constructing FileDescriptor", e)
+ }
+ }
+
+ /**
+ * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable
+ * string representing the field, like "field 'foo.bar'". If `names` is empty, the string
+ * "top-level record" is returned.
+ */
+ private[proto] def toFieldStr(names: Seq[String]): String = names match {
+ case Seq() => "top-level record"
+ case n => s"field '${n.mkString(".")}'"
+ }
+
+ /** Return true iff `protoField` is treated as nullable when checking for extra Proto fields. */
+ private[proto] def isNullable(protoField: FieldDescriptor): Boolean =
+ !protoField.isOptional
+
+}
diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/SchemaConverters.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/SchemaConverters.scala
new file mode 100644
index 0000000000000..f429344af9c69
--- /dev/null
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/SchemaConverters.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto.utils
+
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.proto.ScalaReflectionLock
+import org.apache.spark.sql.types._
+
+@DeveloperApi
+object SchemaConverters {
+ /**
+ * Internal wrapper for SQL data type and nullability.
+ *
+ * @since 3.4.0
+ */
+ case class SchemaType(dataType: DataType, nullable: Boolean)
+
+ /**
+ * Converts a Proto schema to a corresponding Spark SQL schema.
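+ *
+ * For example, a proto3 message declared as `message Person { string name = 1; int32 age = 2; }`
+ * maps to a StructType with nullable fields `name: string` and `age: int`.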
+ *
+ * @since 3.4.0
+ */
+ def toSqlType(protoSchema: Descriptor): SchemaType = {
+ toSqlTypeHelper(protoSchema)
+ }
+
+ def toSqlTypeHelper(descriptor: Descriptor): SchemaType = ScalaReflectionLock.synchronized {
+ SchemaType(StructType(descriptor.getFields.asScala.flatMap(structFieldFor).toSeq),
+ nullable = true)
+ }
+
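+ /**
+ * Map a single Proto field to an optional Catalyst StructField: repeated fields become
+ * ArrayType, enums become StringType, and message fields with no convertible sub-fields
+ * are dropped (None).
+ */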
+ def structFieldFor(fd: FieldDescriptor): Option[StructField] = {
+ import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+ val dataType = fd.getJavaType match {
+ case INT => Some(IntegerType)
+ case LONG => Some(LongType)
+ case FLOAT => Some(FloatType)
+ case DOUBLE => Some(DoubleType)
+ case BOOLEAN => Some(BooleanType)
+ case STRING => Some(StringType)
+ case BYTE_STRING => Some(BinaryType)
+ case ENUM => Some(StringType)
+ case MESSAGE =>
+ Option(fd.getMessageType.getFields.asScala.flatMap(structFieldFor).toSeq)
+ .filter(_.nonEmpty)
+ .map(StructType.apply)
+
+ }
+ dataType.map(dt => StructField(
+ fd.getName,
+ if (fd.isRepeated) ArrayType(dt, containsNull = false) else dt,
+ nullable = !fd.isRequired && !fd.isRepeated
+ ))
+ }
+
+ /**
+ * Converts a Spark SQL schema to a corresponding Proto Descriptor
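+ *
+ * For example, a StructType with a single LongType field `id` produces (roughly) a dynamic
+ * message definition equivalent to `message topLevelRecord { optional int64 id = 1; }`,
+ * following the `sparkToProtoTypeMap` mapping below.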
+ *
+ * @since 3.4.0
+ */
+ def toProtoType(catalystType: DataType,
+ nullable: Boolean = false,
+ recordName: String = "topLevelRecord",
+ nameSpace: String = "",
+ indexNum: Int = 0): Descriptor = {
+ val schemaBuilder: DynamicSchema#Builder = new DynamicSchema().newBuilder()
+ schemaBuilder.setName("DynamicSchema.proto")
+ toMessageDefinition(catalystType, recordName, nameSpace, schemaBuilder: DynamicSchema#Builder)
+ schemaBuilder.build().getMessageDescriptor(recordName)
+ }
+
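+ /**
+ * Add message definitions for `catalystType` to `schemaBuilder`. Nested StructTypes are
+ * processed breadth-first via a queue: each nested struct becomes its own message definition
+ * named after the field that holds it, and the parent references it by that field name.
+ */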
+ def toMessageDefinition(catalystType: DataType, recordName: String,
+ nameSpace: String, schemaBuilder: DynamicSchema#Builder): Unit = {
+ catalystType match {
+ case st: StructType =>
+ val queue = mutable.Queue[ProtoMessage]()
+ val list = new util.ArrayList[ProtoField]()
+ st.foreach { f =>
+ list.add(ProtoField(f.name, f.dataType))
+ }
+ queue += ProtoMessage(recordName, list)
+ while (!queue.isEmpty) {
+ val protoMessage = queue.dequeue()
+ val messageDefinition: MessageDefinition#Builder =
+ new MessageDefinition().newBuilder(protoMessage.messageName)
+ var index = 0
+ protoMessage.fieldList.forEach {
+ protoField =>
+ protoField.catalystType match {
+ case ArrayType(at, containsNull) =>
+ index = index + 1
+ at match {
+ case st: StructType =>
+ messageDefinition.addField("repeated", protoField.name,
+ protoField.name, index)
+ val list = new util.ArrayList[ProtoField]()
+ st.foreach { f =>
+ list.add(ProtoField(f.name, f.dataType))
+ }
+ queue += ProtoMessage(protoField.name, list)
+ case _ =>
+ convertBasicTypes(protoField.catalystType, messageDefinition, "repeated",
+ index, protoField)
+ }
+ case st: StructType =>
+ index = index + 1
+ messageDefinition.addField("optional", protoField.name, protoField.name, index)
+ val list = new util.ArrayList[ProtoField]()
+ st.foreach { f =>
+ list.add(ProtoField(f.name, f.dataType))
+ }
+ queue += ProtoMessage(protoField.name, list)
+ case _ =>
+ index = index + 1
+ convertBasicTypes(protoField.catalystType, messageDefinition, "optional",
+ index, protoField)
+ }
+ }
+ schemaBuilder.addMessageDefinition(messageDefinition.build())
+ }
+ }
+ }
+
+ def convertBasicTypes(catalystType: DataType, messageDefinition: MessageDefinition#Builder,
+ label: String, index: Int, protoField: ProtoField): Unit = {
+ sparkToProtoTypeMap.get(catalystType) match {
+ case Some(protoType) =>
+ messageDefinition.addField(label, protoType, protoField.name, index)
+ case None =>
+ throw new IncompatibleSchemaException(s"Cannot convert SQL type ${catalystType.sql} to " +
+ "a Proto type; try passing a proto descriptor file path to the to_proto function")
+ }
+ }
+
+ private val sparkToProtoTypeMap = Map[DataType, String](ByteType -> "int32", ShortType -> "int32",
+ IntegerType -> "int32", DateType -> "int32", LongType -> "int64", BinaryType -> "bytes",
+ DoubleType -> "double", FloatType -> "float", TimestampType -> "int64",
+ TimestampNTZType -> "int64", StringType -> "string", BooleanType -> "bool")
+
+ case class ProtoMessage(messageName: String, fieldList: util.ArrayList[ProtoField])
+
+ case class ProtoField(name: String, catalystType: DataType)
+
+ private[proto] class IncompatibleSchemaException(
+ msg: String,
+ ex: Throwable = null) extends Exception(msg, ex)
+
+ private[proto] class UnsupportedProtoTypeException(msg: String) extends Exception(msg)
+
+ private[proto] class UnsupportedProtoValueException(msg: String) extends Exception(msg)
+}
diff --git a/connector/proto/src/test/resources/protobuf/catalyst_types.desc b/connector/proto/src/test/resources/protobuf/catalyst_types.desc
new file mode 100644
index 0000000000000..7a85edcd472ae
--- /dev/null
+++ b/connector/proto/src/test/resources/protobuf/catalyst_types.desc
@@ -0,0 +1,37 @@
+
+”
+@connector/proto/src/test/resources/protobuf/catalyst_types.protoorg.apache.spark.sql.proto")
+
+BooleanMsg
+ bool_type (RboolType"+
+
+IntegerMsg
+
+int32_type (R int32Type",
+ DoubleMsg
+double_type (R
+doubleType")
+FloatMsg
+
+float_type (R floatType")
+BytesMsg
+
+bytes_type (R bytesType",
+ StringMsg
+string_type ( R
+stringType".
+Person
+name ( Rname
+age (Rage"n
+Bad
+col_0 (Rcol0
+col_1 (Rcol1
+col_2 ( Rcol2
+col_3 (Rcol3
+col_4 (Rcol4"q
+Actual
+col_0 ( Rcol0
+col_1 (Rcol1
+col_2 (Rcol2
+col_3 (Rcol3
+col_4 (Rcol4BB
CatalystTypesbproto3
\ No newline at end of file
diff --git a/connector/proto/src/test/resources/protobuf/catalyst_types.proto b/connector/proto/src/test/resources/protobuf/catalyst_types.proto
new file mode 100644
index 0000000000000..b0c6aa647535d
--- /dev/null
+++ b/connector/proto/src/test/resources/protobuf/catalyst_types.proto
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// protoc --java_out=connector/proto/src/test/resources/protobuf/ connector/proto/src/test/resources/protobuf/catalyst_types.proto
+// protoc --descriptor_set_out=connector/proto/src/test/resources/protobuf/catalyst_types.desc --java_out=connector/proto/src/test/resources/protobuf/org/apache/spark/sql/proto/ connector/proto/src/test/resources/protobuf/catalyst_types.proto
+
+syntax = "proto3";
+
+package org.apache.spark.sql.proto;
+option java_outer_classname = "CatalystTypes";
+
+message BooleanMsg {
+ bool bool_type = 1;
+}
+message IntegerMsg {
+ int32 int32_type = 1;
+}
+message DoubleMsg {
+ double double_type = 1;
+}
+message FloatMsg {
+ float float_type = 1;
+}
+message BytesMsg {
+ bytes bytes_type = 1;
+}
+message StringMsg {
+ string string_type = 1;
+}
+
+message Person {
+ string name = 1;
+ int32 age = 2;
+}
+
+message Bad {
+ bytes col_0 = 1;
+ double col_1 = 2;
+ string col_2 = 3;
+ float col_3 = 4;
+ int64 col_4 = 5;
+}
+
+message Actual {
+ string col_0 = 1;
+ int32 col_1 = 2;
+ float col_2 = 3;
+ bool col_3 = 4;
+ double col_4 = 5;
+}
\ No newline at end of file
diff --git a/connector/proto/src/test/resources/protobuf/proto_functions_suite.desc b/connector/proto/src/test/resources/protobuf/proto_functions_suite.desc
new file mode 100644
index 0000000000000..5fe1dacc6e568
Binary files /dev/null and b/connector/proto/src/test/resources/protobuf/proto_functions_suite.desc differ
diff --git a/connector/proto/src/test/resources/protobuf/proto_functions_suite.proto b/connector/proto/src/test/resources/protobuf/proto_functions_suite.proto
new file mode 100644
index 0000000000000..8ea2f3e05839b
--- /dev/null
+++ b/connector/proto/src/test/resources/protobuf/proto_functions_suite.proto
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// To compile and create test class:
+// protoc --java_out=connector/proto/src/test/resources/protobuf/ connector/proto/src/test/resources/protobuf/proto_functions_suite.proto
+// protoc --descriptor_set_out=connector/proto/src/test/resources/protobuf/proto_functions_suite.desc --java_out=connector/proto/src/test/resources/protobuf/org/apache/spark/sql/proto/ connector/proto/src/test/resources/protobuf/proto_functions_suite.proto
+
+syntax = "proto3";
+
+package org.apache.spark.sql.proto;
+option java_outer_classname = "SimpleMessageProtos";
+
+message SimpleMessageJavaTypes {
+ int64 id = 1;
+ string string_value = 2;
+ int32 int32_value = 3;
+ int64 int64_value = 4;
+ double double_value = 5;
+ float float_value = 6;
+ bool bool_value = 7;
+ bytes bytes_value = 8;
+}
+
+message SimpleMessage {
+ int64 id = 1;
+ string string_value = 2;
+ int32 int32_value = 3;
+ uint32 uint32_value = 4;
+ sint32 sint32_value = 5;
+ fixed32 fixed32_value = 6;
+ sfixed32 sfixed32_value = 7;
+ int64 int64_value = 8;
+ uint64 uint64_value = 9;
+ sint64 sint64_value = 10;
+ fixed64 fixed64_value = 11;
+ sfixed64 sfixed64_value = 12;
+ double double_value = 13;
+ float float_value = 14;
+ bool bool_value = 15;
+ bytes bytes_value = 16;
+}
+
+message SimpleMessageRepeated {
+ string key = 1;
+ string value = 2;
+ enum NestedEnum {
+ ESTED_NOTHING = 0;
+ NESTED_FIRST = 1;
+ NESTED_SECOND = 2;
+ }
+ repeated string rstring_value = 3;
+ repeated int32 rint32_value = 4;
+ repeated bool rbool_value = 5;
+ repeated int64 rint64_value = 6;
+ repeated float rfloat_value = 7;
+ repeated double rdouble_value = 8;
+ repeated bytes rbytes_value = 9;
+ repeated NestedEnum rnested_enum = 10;
+}
+
+message BasicMessage {
+ int64 id = 1;
+ string string_value = 2;
+ int32 int32_value = 3;
+ int64 int64_value = 4;
+ double double_value = 5;
+ float float_value = 6;
+ bool bool_value = 7;
+ bytes bytes_value = 8;
+}
+
+message RepeatedMessage {
+ repeated BasicMessage basic_message = 1;
+}
+
+message SimpleMessageMap {
+ string key = 1;
+ string value = 2;
+ map<string, string> string_mapdata = 3;
+ map<int32, int32> int32_mapdata = 4;
+ map<uint32, uint32> uint32_mapdata = 5;
+ map<sint32, sint32> sint32_mapdata = 6;
+ map<fixed32, fixed32> float32_mapdata = 7;
+ map<sfixed32, sfixed32> sfixed32_mapdata = 8;
+ map<int64, int64> int64_mapdata = 9;
+ map<uint64, uint64> uint64_mapdata = 10;
+ map<sint64, sint64> sint64_mapdata = 11;
+ map<fixed64, fixed64> fixed64_mapdata = 12;
+ map<sfixed64, sfixed64> sfixed64_mapdata = 13;
+ map<string, double> double_mapdata = 14;
+ map<string, float> float_mapdata = 15;
+ map<bool, bool> bool_mapdata = 16;
+ map<string, bytes> bytes_mapdata = 17;
+}
+
+message BasicEnumMessage {
+ enum BasicEnum {
+ NOTHING = 0;
+ FIRST = 1;
+ SECOND = 2;
+ }
+}
+
+message SimpleMessageEnum {
+ string key = 1;
+ string value = 2;
+ enum NestedEnum {
+ ESTED_NOTHING = 0;
+ NESTED_FIRST = 1;
+ NESTED_SECOND = 2;
+ }
+ BasicEnumMessage.BasicEnum basic_enum = 3;
+ NestedEnum nested_enum = 4;
+}
+
+
+message OtherExample {
+ string other = 1;
+}
+
+message IncludedExample {
+ string included = 1;
+ OtherExample other = 2;
+}
+
+message MultipleExample {
+ IncludedExample included_example = 1;
+}
+
diff --git a/connector/proto/src/test/resources/protobuf/proto_serde_suite.desc b/connector/proto/src/test/resources/protobuf/proto_serde_suite.desc
new file mode 100644
index 0000000000000..4a406ab0a6050
--- /dev/null
+++ b/connector/proto/src/test/resources/protobuf/proto_serde_suite.desc
@@ -0,0 +1,27 @@
+
+š
+Cconnector/proto/src/test/resources/protobuf/proto_serde_suite.protoorg.apache.spark.sql.proto"A
+BasicMessage1
+foo (2.org.apache.spark.sql.proto.FooRfoo"
+Foo
+bar (Rbar"'
+MissMatchTypeInRoot
+foo (Rfoo"Q
+FieldMissingInProto:
+foo (2(.org.apache.spark.sql.proto.MissingFieldRfoo"&
+MissingField
+barFoo (RbarFoo"Y
+MissMatchTypeInDeepNested<
+top (2*.org.apache.spark.sql.proto.TypeMissNestedRtop"H
+TypeMissNested6
+foo (2$.org.apache.spark.sql.proto.TypeMissRfoo"
+TypeMiss
+bar (Rbar"\
+FieldMissingInSQLRoot1
+foo (2.org.apache.spark.sql.proto.FooRfoo
+boo (Rboo"L
+FieldMissingInSQLNested1
+foo (2.org.apache.spark.sql.proto.BazRfoo")
+Baz
+bar (Rbar
+baz (RbazBBSimpleMessageProtosbproto3
\ No newline at end of file
diff --git a/connector/proto/src/test/resources/protobuf/proto_serde_suite.proto b/connector/proto/src/test/resources/protobuf/proto_serde_suite.proto
new file mode 100644
index 0000000000000..d535f47b317d1
--- /dev/null
+++ b/connector/proto/src/test/resources/protobuf/proto_serde_suite.proto
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// To compile and create test class:
+// protoc --java_out=connector/proto/src/test/resources/protobuf/ connector/proto/src/test/resources/protobuf/proto_serde_suite.proto
+// protoc --descriptor_set_out=connector/proto/src/test/resources/protobuf/proto_serde_suite.desc --java_out=connector/proto/src/test/resources/protobuf/org/apache/spark/sql/proto/ connector/proto/src/test/resources/protobuf/proto_serde_suite.proto
+
+syntax = "proto3";
+
+package org.apache.spark.sql.proto;
+option java_outer_classname = "SimpleMessageProtos";
+
+/* Clean Message*/
+message BasicMessage {
+ Foo foo = 1;
+}
+
+message Foo {
+ int32 bar = 1;
+}
+
+/* Field Type missMatch in root Message*/
+message MissMatchTypeInRoot {
+ int64 foo = 1;
+}
+
+/* Field bar missing from proto and Available in SQL*/
+message FieldMissingInProto {
+ MissingField foo = 1;
+}
+
+message MissingField {
+ int64 barFoo = 1;
+}
+
+/* Deep-nested field bar type missMatch Message*/
+message MissMatchTypeInDeepNested {
+ TypeMissNested top = 1;
+}
+
+message TypeMissNested {
+ TypeMiss foo = 1;
+}
+
+message TypeMiss {
+ int64 bar = 1;
+}
+
+/* Field boo missing from SQL root, but available in Proto root*/
+message FieldMissingInSQLRoot {
+ Foo foo = 1;
+ int32 boo = 2;
+}
+
+/* Field baz missing from SQL nested and available in Proto nested*/
+message FieldMissingInSQLNested {
+ Baz foo = 1;
+}
+
+message Baz {
+ int32 bar = 1;
+ int32 baz = 2;
+}
\ No newline at end of file
diff --git a/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoCatalystDataConversionSuite.scala b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoCatalystDataConversionSuite.scala
new file mode 100644
index 0000000000000..a6dca26f7e7ba
--- /dev/null
+++ b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoCatalystDataConversionSuite.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.proto
+
+import com.google.protobuf.{ByteString, DynamicMessage, Message}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters}
+import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.proto.utils.{ProtoOptions, ProtoUtils, SchemaConverters}
+import org.apache.spark.sql.sources.{EqualTo, Not}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class ProtoCatalystDataConversionSuite extends SparkFunSuite
+ with SharedSparkSession
+ with ExpressionEvalHelper {
+
+ private def checkResult(data: Literal, descFilePath: String,
+ messageName: String, expected: Any): Unit = {
+
+ checkEvaluation(
+ ProtoDataToCatalyst(CatalystDataToProto(data, Some(descFilePath),
+ Some(messageName)), descFilePath, messageName, Map.empty),
+ prepareExpectedResult(expected))
+ }
+
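+ // Serializes `data` with `actualSchema`, then deserializes it with an incompatible `badSchema`:
+ // FAILFAST mode is expected to throw, while PERMISSIVE mode yields a row of nulls.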
+ protected def checkUnsupportedRead(data: Literal, descFilePath: String,
+ actualSchema: String, badSchema: String): Unit = {
+
+ val binary = CatalystDataToProto(data, Some(descFilePath), Some(actualSchema))
+
+ intercept[Exception] {
+ ProtoDataToCatalyst(binary, descFilePath, badSchema,
+ Map("mode" -> "FAILFAST")).eval()
+ }
+
+ val expected = {
+ val protoOptions = ProtoOptions(Map("mode" -> "PERMISSIVE"))
+ val descriptor = ProtoUtils.buildDescriptor(descFilePath, badSchema)
+ val expectedSchema = protoOptions.schema.getOrElse(descriptor)
+ SchemaConverters.toSqlType(expectedSchema).dataType match {
+ case st: StructType => Row.fromSeq((0 until st.length).map {
+ _ => null
+ })
+ case _ => null
+ }
+ }
+
+ checkEvaluation(ProtoDataToCatalyst(binary, descFilePath, badSchema,
+ Map("mode" -> "PERMISSIVE")), expected)
+ }
+
+ protected def prepareExpectedResult(expected: Any): Any = expected match {
+ // Spark byte and short both map to proto int
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult))
+ case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult))
+ case map: MapData =>
+ val keys = new GenericArrayData(
+ map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult))
+ val values = new GenericArrayData(
+ map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult))
+ new ArrayBasedMapData(keys, values)
+ case other => other
+ }
+
+ private val testingTypes = Seq(
+ StructType(StructField("bool_type", BooleanType, nullable = false) :: Nil),
+ StructType(StructField("int32_type", IntegerType, nullable = false) :: Nil),
+ StructType(StructField("double_type", DoubleType, nullable = false) :: Nil),
+ StructType(StructField("float_type", FloatType, nullable = false) :: Nil),
+ StructType(StructField("bytes_type", BinaryType, nullable = false) :: Nil),
+ StructType(StructField("string_type", StringType, nullable = false) :: Nil),
+ StructType(StructField("int32_type", ByteType, nullable = false) :: Nil),
+ StructType(StructField("int32_type", ShortType, nullable = false) :: Nil)
+ )
+
+ private val catalystTypesToProtoMessages: Map[DataType, String] = Map(
+ BooleanType -> "BooleanMsg",
+ IntegerType -> "IntegerMsg",
+ DoubleType -> "DoubleMsg",
+ FloatType -> "FloatMsg",
+ BinaryType -> "BytesMsg",
+ StringType -> "StringMsg",
+ ByteType -> "IntegerMsg",
+ ShortType -> "IntegerMsg"
+ )
+
+ testingTypes.foreach { dt =>
+ val seed = scala.util.Random.nextLong()
+ val filePath = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ test(s"single $dt with seed $seed") {
+ val rand = new scala.util.Random(seed)
+ val data = RandomDataGenerator.forType(dt, rand = rand).get.apply()
+ val converter = CatalystTypeConverters.createToCatalystConverter(dt)
+ val input = Literal.create(converter(data), dt)
+ checkResult(input, filePath, catalystTypesToProtoMessages(dt.fields(0).dataType),
+ input.eval())
+ }
+ }
+
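+ // Deserializes `data` through ProtoDeserializer using the schema derived from the descriptor,
+ // optionally applying pushed-down `filters`, and compares against `expected` (None means the
+ // row was filtered out).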
+ private def checkDeserialization(
+ descFilePath: String,
+ messageName: String,
+ data: Message,
+ expected: Option[Any],
+ filters: StructFilters = new NoopFilters): Unit = {
+
+ val descriptor = ProtoUtils.buildDescriptor(descFilePath, messageName)
+ val dataType = SchemaConverters.toSqlType(descriptor).dataType
+
+ val deserializer = new ProtoDeserializer(
+ descriptor,
+ dataType,
+ false,
+ RebaseSpec(SQLConf.LegacyBehaviorPolicy.CORRECTED),
+ filters)
+
+ val dynMsg = DynamicMessage.parseFrom(descriptor, data.toByteArray)
+ val deserialized = deserializer.deserialize(dynMsg)
+ expected match {
+ case None => assert(deserialized.isEmpty)
+ case Some(d) =>
+ assert(checkResult(d, deserialized.get, dataType, exprNullable = false))
+ }
+ }
+
+ test("Handle unsupported input of message type") {
+ val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ val actualSchema = StructType(Seq(
+ StructField("col_0", StringType, nullable = false),
+ StructField("col_1", IntegerType, nullable = false),
+ StructField("col_2", FloatType, nullable = false),
+ StructField("col_3", BooleanType, nullable = false),
+ StructField("col_4", DoubleType, nullable = false)))
+
+ val seed = scala.util.Random.nextLong()
+ withClue(s"create random record with seed $seed") {
+ val data = RandomDataGenerator.randomRow(new scala.util.Random(seed), actualSchema)
+ val converter = CatalystTypeConverters.createToCatalystConverter(actualSchema)
+ val input = Literal.create(converter(data), actualSchema)
+ checkUnsupportedRead(input, testFileDesc, "Actual", "Bad")
+ }
+ }
+
+ test("filter push-down to proto deserializer") {
+
+ val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ val sqlSchema = new StructType()
+ .add("name", "string")
+ .add("age", "int")
+
+ val descriptor = ProtoUtils.buildDescriptor(testFileDesc, "Person")
+ val dynamicMessage = DynamicMessage.newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "Maxim")
+ .setField(descriptor.findFieldByName("age"), 39)
+ .build()
+
+ val expectedRow = Some(InternalRow(UTF8String.fromString("Maxim"), 39))
+ checkDeserialization(testFileDesc, "Person", dynamicMessage, expectedRow)
+ checkDeserialization(
+ testFileDesc,
+ "Person",
+ dynamicMessage,
+ expectedRow,
+ new OrderedFilters(Seq(EqualTo("age", 39)), sqlSchema))
+
+ checkDeserialization(
+ testFileDesc,
+ "Person",
+ dynamicMessage,
+ None,
+ new OrderedFilters(Seq(Not(EqualTo("name", "Maxim"))), sqlSchema))
+ }
+
+ test("ProtoDeserializer with binary type") {
+
+ val testFileDesc = testFile("protobuf/catalyst_types.desc").replace(
+ "file:/", "/")
+ val bb = java.nio.ByteBuffer.wrap(Array[Byte](97, 48, 53))
+
+ val descriptor = ProtoUtils.buildDescriptor(testFileDesc, "BytesMsg")
+
+ val dynamicMessage = DynamicMessage.newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("bytes_type"), ByteString.copyFrom(bb))
+ .build()
+
+ val expected = InternalRow(Array[Byte](97, 48, 53))
+ checkDeserialization(testFileDesc, "BytesMsg", dynamicMessage, Some(expected))
+ }
+}
diff --git a/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoFunctionsSuite.scala b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoFunctionsSuite.scala
new file mode 100644
index 0000000000000..0e68333039030
--- /dev/null
+++ b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoFunctionsSuite.scala
@@ -0,0 +1,540 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.proto
+
+import com.google.protobuf.{ByteString, Descriptors, DynamicMessage}
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.functions.{lit, struct}
+import org.apache.spark.sql.proto.utils.ProtoUtils
+import org.apache.spark.sql.test.SharedSparkSession
+
+class ProtoFunctionsSuite extends QueryTest with SharedSparkSession with Serializable {
+
+ import testImplicits._
+
+ val testFileDesc = testFile("protobuf/proto_functions_suite.desc").replace("file:/", "/")
+
+ test("roundtrip in to_proto and from_proto - struct") {
+ val df = spark.range(10).select(
+ struct(
+ $"id",
+ $"id".cast("string").as("string_value"),
+ $"id".cast("int").as("int32_value"),
+ $"id".cast("int").as("uint32_value"),
+ $"id".cast("int").as("sint32_value"),
+ $"id".cast("int").as("fixed32_value"),
+ $"id".cast("int").as("sfixed32_value"),
+ $"id".cast("long").as("int64_value"),
+ $"id".cast("long").as("uint64_value"),
+ $"id".cast("long").as("sint64_value"),
+ $"id".cast("long").as("fixed64_value"),
+ $"id".cast("long").as("sfixed64_value"),
+ $"id".cast("double").as("double_value"),
+ lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"),
+ lit(true).as("bool_value"),
+ lit("0".getBytes).as("bytes_value")
+ ).as("SimpleMessage")
+ )
+ val protoStructDF = df.select(functions.to_proto($"SimpleMessage", testFileDesc,
+ "SimpleMessage").as("proto"))
+ val actualDf = protoStructDF.select(functions.from_proto($"proto", testFileDesc,
+ "SimpleMessage").as("proto.*"))
+ checkAnswer(actualDf, df)
+ }
+
+ test("roundtrip in to_proto(without descriptor params) and from_proto - struct") {
+ val df = spark.range(10).select(
+ struct(
+ $"id",
+ $"id".cast("string").as("string_value"),
+ $"id".cast("int").as("int32_value"),
+ $"id".cast("long").as("int64_value"),
+ $"id".cast("double").as("double_value"),
+ lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"),
+ lit(true).as("bool_value"),
+ lit("0".getBytes).as("bytes_value")
+ ).as("SimpleMessageJavaTypes")
+ )
+ val protoStructDF = df.select(functions.to_proto($"SimpleMessageJavaTypes").as("proto"))
+ val actualDf = protoStructDF.select(functions.from_proto($"proto", testFileDesc,
+ "SimpleMessageJavaTypes").as("proto.*"))
+ checkAnswer(actualDf, df)
+ }
+
+ test("roundtrip in from_proto and to_proto - Repeated") {
+ val descriptor = ProtoUtils.buildDescriptor(testFileDesc, "SimpleMessageRepeated")
+
+ val dynamicMessage = DynamicMessage.newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("key"), "key")
+ .setField(descriptor.findFieldByName("value"), "value")
+ .addRepeatedField(descriptor.findFieldByName("rbool_value"), false)
+ .addRepeatedField(descriptor.findFieldByName("rbool_value"), true)
+ .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092092.654D)
+ .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092093.654D)
+ .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10903.0f)
+ .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10902.0f)
+ .addRepeatedField(descriptor.findFieldByName("rnested_enum"),
+ descriptor.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING"))
+ .addRepeatedField(descriptor.findFieldByName("rnested_enum"),
+ descriptor.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST"))
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "SimpleMessageRepeated").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc,
+ "SimpleMessageRepeated").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "SimpleMessageRepeated").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_proto and to_proto - Repeated Message Once") {
+ val repeatedMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "RepeatedMessage")
+ val basicMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val basicMessage = DynamicMessage.newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1111L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "value")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12345)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), true)
+ .setField(basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer"))
+ .build()
+
+ val dynamicMessage = DynamicMessage.newBuilder(repeatedMessageDesc)
+ .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "RepeatedMessage").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc,
+ "RepeatedMessage").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "RepeatedMessage").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_proto and to_proto(without descriptor params) - Repeated Message Once") {
+ val repeatedMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "RepeatedMessage")
+ val basicMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val basicMessage = DynamicMessage.newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1111L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "value")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12345)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), true)
+ .setField(basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer"))
+ .build()
+
+ val dynamicMessage = DynamicMessage.newBuilder(repeatedMessageDesc)
+ .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "RepeatedMessage").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "RepeatedMessage").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_proto and to_proto - Repeated Message Twice") {
+ val repeatedMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "RepeatedMessage")
+ val basicMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val basicMessage1 = DynamicMessage.newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1111L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "value1")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12345)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), true)
+ .setField(basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer1"))
+ .build()
+ val basicMessage2 = DynamicMessage.newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1112L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "value2")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12346)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10903.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), false)
+ .setField(basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer2"))
+ .build()
+
+ val dynamicMessage = DynamicMessage.newBuilder(repeatedMessageDesc)
+ .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage1)
+ .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage2)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "RepeatedMessage").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc,
+ "RepeatedMessage").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "RepeatedMessage").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_proto and to_proto(without descriptor params) - Repeated Message Twice") {
+ val repeatedMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "RepeatedMessage")
+ val basicMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val basicMessage1 = DynamicMessage.newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1111L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "value1")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12345)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), true)
+ .setField(basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer1"))
+ .build()
+ val basicMessage2 = DynamicMessage.newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1112L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "value2")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12346)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10903.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), false)
+ .setField(basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer2"))
+ .build()
+
+ val dynamicMessage = DynamicMessage.newBuilder(repeatedMessageDesc)
+ .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage1)
+ .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage2)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "RepeatedMessage").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "RepeatedMessage").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_proto and to_proto - Map") {
+ val messageMapDesc = ProtoUtils.buildDescriptor(testFileDesc, "SimpleMessageMap")
+
+ val mapStr1 = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("StringMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"),
+ "key value1")
+ .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"),
+ "value value2")
+ .build()
+ val mapStr2 = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("StringMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"),
+ "key value2")
+ .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"),
+ "value value2")
+ .build()
+ val mapInt64 = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("Int64MapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("key"),
+ 0x90000000000L)
+ .setField(messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("value"),
+ 0x90000000001L)
+ .build()
+ val mapInt32 = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("Int32MapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("key"),
+ 12345)
+ .setField(messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("value"),
+ 54321)
+ .build()
+ val mapFloat = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("FloatMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("key"),
+ "float key")
+ .setField(messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("value"),
+ 109202.234F)
+ .build()
+ val mapDouble = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("DoubleMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("key")
+ , "double key")
+ .setField(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("value")
+ , 109202.12D)
+ .build()
+ val mapBool = DynamicMessage.newBuilder(messageMapDesc.findNestedTypeByName("BoolMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("key"),
+ true)
+ .setField(messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("value"),
+ false)
+ .build()
+
+ val dynamicMessage = DynamicMessage.newBuilder(messageMapDesc)
+ .setField(messageMapDesc.findFieldByName("key"), "key")
+ .setField(messageMapDesc.findFieldByName("value"), "value")
+ .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr1)
+ .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr2)
+ .addRepeatedField(messageMapDesc.findFieldByName("int64_mapdata"), mapInt64)
+ .addRepeatedField(messageMapDesc.findFieldByName("int32_mapdata"), mapInt32)
+ .addRepeatedField(messageMapDesc.findFieldByName("float_mapdata"), mapFloat)
+ .addRepeatedField(messageMapDesc.findFieldByName("double_mapdata"), mapDouble)
+ .addRepeatedField(messageMapDesc.findFieldByName("bool_mapdata"), mapBool)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "SimpleMessageMap").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc,
+ "SimpleMessageMap").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "SimpleMessageMap").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_proto and to_proto(without descriptor params) - Map") {
+ val messageMapDesc = ProtoUtils.buildDescriptor(testFileDesc, "SimpleMessageMap")
+
+ val mapStr1 = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("StringMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"),
+ "key value1")
+ .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"),
+ "value value2")
+ .build()
+ val mapStr2 = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("StringMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"),
+ "key value2")
+ .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"),
+ "value value2")
+ .build()
+ val mapInt64 = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("Int64MapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("key"),
+ 0x90000000000L)
+ .setField(messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("value"),
+ 0x90000000001L)
+ .build()
+ val mapInt32 = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("Int32MapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("key"),
+ 12345)
+ .setField(messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("value"),
+ 54321)
+ .build()
+ val mapFloat = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("FloatMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("key"),
+ "float key")
+ .setField(messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("value"),
+ 109202.234F)
+ .build()
+ val mapDouble = DynamicMessage.newBuilder(
+ messageMapDesc.findNestedTypeByName("DoubleMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("key")
+ , "double key")
+ .setField(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("value")
+ , 109202.12D)
+ .build()
+ val mapBool = DynamicMessage.newBuilder(messageMapDesc.findNestedTypeByName("BoolMapdataEntry"))
+ .setField(messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("key"),
+ true)
+ .setField(messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("value"),
+ false)
+ .build()
+
+ val dynamicMessage = DynamicMessage.newBuilder(messageMapDesc)
+ .setField(messageMapDesc.findFieldByName("key"), "key")
+ .setField(messageMapDesc.findFieldByName("value"), "value")
+ .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr1)
+ .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr2)
+ .addRepeatedField(messageMapDesc.findFieldByName("int64_mapdata"), mapInt64)
+ .addRepeatedField(messageMapDesc.findFieldByName("int32_mapdata"), mapInt32)
+ .addRepeatedField(messageMapDesc.findFieldByName("float_mapdata"), mapFloat)
+ .addRepeatedField(messageMapDesc.findFieldByName("double_mapdata"), mapDouble)
+ .addRepeatedField(messageMapDesc.findFieldByName("bool_mapdata"), mapBool)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "SimpleMessageMap").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "SimpleMessageMap").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_proto and to_proto - Enum") {
+ val messageEnumDesc = ProtoUtils.buildDescriptor(testFileDesc, "SimpleMessageEnum")
+ val basicEnumDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicEnumMessage")
+
+ val dynamicMessage = DynamicMessage.newBuilder(messageEnumDesc)
+ .setField(messageEnumDesc.findFieldByName("key"), "key")
+ .setField(messageEnumDesc.findFieldByName("value"), "value")
+ .setField(messageEnumDesc.findFieldByName("nested_enum"),
+ messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING"))
+ .setField(messageEnumDesc.findFieldByName("nested_enum"),
+ messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST"))
+ .setField(messageEnumDesc.findFieldByName("basic_enum"),
+ basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("FIRST"))
+ .setField(messageEnumDesc.findFieldByName("basic_enum"),
+ basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("NOTHING"))
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "SimpleMessageEnum").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc,
+ "SimpleMessageEnum").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "SimpleMessageEnum").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_proto and to_proto - Multiple Message") {
+ val messageMultiDesc = ProtoUtils.buildDescriptor(testFileDesc, "MultipleExample")
+ val messageIncludeDesc = ProtoUtils.buildDescriptor(testFileDesc, "IncludedExample")
+ val messageOtherDesc = ProtoUtils.buildDescriptor(testFileDesc, "OtherExample")
+
+ val otherMessage = DynamicMessage.newBuilder(messageOtherDesc)
+ .setField(messageOtherDesc.findFieldByName("other"), "other value")
+ .build()
+
+ val includeMessage = DynamicMessage.newBuilder(messageIncludeDesc)
+ .setField(messageIncludeDesc.findFieldByName("included"), "included value")
+ .setField(messageIncludeDesc.findFieldByName("other"), otherMessage)
+ .build()
+
+ val dynamicMessage = DynamicMessage.newBuilder(messageMultiDesc)
+ .setField(messageMultiDesc.findFieldByName("included_example"), includeMessage)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "MultipleExample").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc,
+ "MultipleExample").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "MultipleExample").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_proto and to_proto(without descriptor params) - Multiple Message") {
+ val messageMultiDesc = ProtoUtils.buildDescriptor(testFileDesc, "MultipleExample")
+ val messageIncludeDesc = ProtoUtils.buildDescriptor(testFileDesc, "IncludedExample")
+ val messageOtherDesc = ProtoUtils.buildDescriptor(testFileDesc, "OtherExample")
+
+ val otherMessage = DynamicMessage.newBuilder(messageOtherDesc)
+ .setField(messageOtherDesc.findFieldByName("other"), "other value")
+ .build()
+
+ val includeMessage = DynamicMessage.newBuilder(messageIncludeDesc)
+ .setField(messageIncludeDesc.findFieldByName("included"), "included value")
+ .setField(messageIncludeDesc.findFieldByName("other"), otherMessage)
+ .build()
+
+ val dynamicMessage = DynamicMessage.newBuilder(messageMultiDesc)
+ .setField(messageMultiDesc.findFieldByName("included_example"), includeMessage)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc,
+ "MultipleExample").as("value_from"))
+ val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc,
+ "MultipleExample").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in to_proto and from_proto - with null") {
+ val df = spark.range(10).select(
+ struct(
+ $"id",
+ lit(null).cast("string").as("string_value"),
+ $"id".cast("int").as("int32_value"),
+ $"id".cast("int").as("uint32_value"),
+ $"id".cast("int").as("sint32_value"),
+ $"id".cast("int").as("fixed32_value"),
+ $"id".cast("int").as("sfixed32_value"),
+ $"id".cast("long").as("int64_value"),
+ $"id".cast("long").as("uint64_value"),
+ $"id".cast("long").as("sint64_value"),
+ $"id".cast("long").as("fixed64_value"),
+ $"id".cast("long").as("sfixed64_value"),
+ $"id".cast("double").as("double_value"),
+ lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"),
+ lit(true).as("bool_value"),
+ lit("0".getBytes).as("bytes_value")
+ ).as("SimpleMessage")
+ )
+
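+    // string_value is null, so serializing the struct to proto is expected to fail.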
+ intercept[SparkException] {
+ df.select(functions.to_proto($"SimpleMessage", testFileDesc,
+ "SimpleMessage").as("proto")).collect()
+ }
+ }
+
+ test("from_proto filter to_proto") {
+ val basicMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val basicMessage = DynamicMessage.newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1111L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "slam")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12345)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), true)
+ .setField(basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer"))
+ .build()
+
+ val df = Seq(basicMessage.toByteArray).toDF("value")
+    val resultFrom = df.select(functions.from_proto($"value", testFileDesc,
+      "BasicMessage").as("sample"))
+      .where("sample.string_value == \"slam\"")
+
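+    // The filter should keep the same rows after a to_proto/from_proto round trip.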
+    val resultToFrom = resultFrom.select(functions.to_proto($"sample").as("value"))
+      .select(functions.from_proto($"value", testFileDesc,
+        "BasicMessage").as("sample"))
+      .where("sample.string_value == \"slam\"")
+
+ assert(resultFrom.except(resultToFrom).isEmpty)
+ }
+
+  def buildDescriptor(desc: String): Descriptors.Descriptor = {
+    ProtoUtils.parseFileDescriptor(desc).getMessageTypes().get(0)
+  }
+}
diff --git a/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoSerdeSuite.scala b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoSerdeSuite.scala
new file mode 100644
index 0000000000000..b02c52487b480
--- /dev/null
+++ b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoSerdeSuite.scala
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.proto
+
+import com.google.protobuf.Descriptors.Descriptor
+import com.google.protobuf.DynamicMessage
+
+import org.apache.spark.sql.catalyst.NoopFilters
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy.CORRECTED
+import org.apache.spark.sql.proto.utils.ProtoUtils
+import org.apache.spark.sql.proto.utils.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+/**
+ * Tests for [[ProtoSerializer]] and [[ProtoDeserializer]], focusing on those classes
+ * rather than the end-to-end from_proto/to_proto functions.
+ */
+class ProtoSerdeSuite extends SharedSparkSession {
+
+ import ProtoSerdeSuite.MatchType._
+ import ProtoSerdeSuite._
+
+ val testFileDesc = testFile("protobuf/proto_serde_suite.desc").replace("file:/", "/")
+
+ test("Test basic conversion") {
+ withFieldMatchType { fieldMatch =>
+ val (top, nest) = fieldMatch match {
+ case BY_NAME => ("foo", "bar")
+ case BY_POSITION => ("NOTfoo", "NOTbar")
+ }
+ val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val dynamicMessageFoo = DynamicMessage.newBuilder(
+ protoFile.getFile.findMessageTypeByName("Foo")).setField(
+ protoFile.getFile.findMessageTypeByName("Foo").findFieldByName("bar"),
+ 10902).build()
+
+ val dynamicMessage = DynamicMessage.newBuilder(protoFile)
+ .setField(protoFile.findFieldByName("foo"), dynamicMessageFoo).build()
+
+ val serializer = Serializer.create(CATALYST_STRUCT, protoFile, fieldMatch)
+ val deserializer = Deserializer.create(CATALYST_STRUCT, protoFile, fieldMatch)
+
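+      // A deserialize/serialize round trip should reproduce the original message.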
+ assert(serializer.serialize(deserializer.deserialize(dynamicMessage).get) === dynamicMessage)
+ }
+ }
+
+ test("Fail to convert with field type mismatch") {
+ val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "MissMatchTypeInRoot")
+
+ withFieldMatchType { fieldMatch =>
+ assertFailedConversionMessage(protoFile, Deserializer, fieldMatch,
+ "Cannot convert Proto field 'foo' to SQL field 'foo' because schema is incompatible " +
+ s"(protoType = org.apache.spark.sql.proto.MissMatchTypeInRoot.foo " +
+ s"LABEL_OPTIONAL LONG INT64, sqlType = ${CATALYST_STRUCT.head.dataType.sql})"
+ .stripMargin)
+
+ assertFailedConversionMessage(protoFile, Serializer, fieldMatch,
+ s"Cannot convert SQL field 'foo' to Proto field 'foo' because schema is incompatible " +
+ s"""(sqlType = ${CATALYST_STRUCT.head.dataType.sql}, protoType = LONG)""")
+ }
+ }
+
+ test("Fail to convert with missing nested Proto fields") {
+ val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "FieldMissingInProto")
+
+ val nonnullCatalyst = new StructType()
+ .add("foo", new StructType().add("bar", IntegerType, nullable = false))
+ // Positional matching will work fine with the name change, so add a new field
+ val extraNonnullCatalyst = new StructType().add("foo",
+ new StructType().add("bar", IntegerType).add("baz", IntegerType, nullable = false))
+
+    // deserialize should have no issues when 'bar' is nullable but fail when it is non-nullable
+ Deserializer.create(CATALYST_STRUCT, protoFile, BY_NAME)
+ assertFailedConversionMessage(protoFile, Deserializer, BY_NAME,
+ "Cannot find field 'foo.bar' in Proto schema",
+ nonnullCatalyst)
+ assertFailedConversionMessage(protoFile, Deserializer, BY_POSITION,
+ "Cannot find field at position 1 of field 'foo' from Proto schema (using positional" +
+ " matching)",
+ extraNonnullCatalyst)
+
+ // serialize fails whether or not 'bar' is nullable
+ val byNameMsg = "Cannot find field 'foo.bar' in Proto schema"
+ assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg)
+ assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg, nonnullCatalyst)
+ assertFailedConversionMessage(protoFile, Serializer, BY_POSITION,
+ "Cannot find field at position 1 of field 'foo' from Proto schema (using positional" +
+ " matching)",
+ extraNonnullCatalyst)
+ }
+
+ test("Fail to convert with deeply nested field type mismatch") {
+ val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "MissMatchTypeInDeepNested")
+ val catalyst = new StructType().add("top", CATALYST_STRUCT)
+
+ withFieldMatchType { fieldMatch =>
+ assertFailedConversionMessage(protoFile, Deserializer, fieldMatch,
+ s"Cannot convert Proto field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " +
+ s"is incompatible (protoType = org.apache.spark.sql.proto.TypeMiss.bar " +
+ s"LABEL_OPTIONAL LONG INT64, sqlType = INT)".stripMargin,
+ catalyst)
+
+ assertFailedConversionMessage(protoFile, Serializer, fieldMatch,
+ "Cannot convert SQL field 'top.foo.bar' to Proto field 'top.foo.bar' because schema is " +
+ """incompatible (sqlType = INT, protoType = LONG)""",
+ catalyst)
+ }
+ }
+
+ test("Fail to convert with missing Catalyst fields") {
+ val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot")
+
+ // serializing with extra fails if extra field is missing in SQL Schema
+ assertFailedConversionMessage(protoFile, Serializer, BY_NAME,
+ "Found field 'boo' in Proto schema but there is no match in the SQL schema")
+ assertFailedConversionMessage(protoFile, Serializer, BY_POSITION,
+ "Found field 'boo' at position 1 of top-level record from Proto schema but there is no " +
+ "match in the SQL schema at top-level record (using positional matching)")
+
+ /* deserializing should work regardless of whether the extra field is missing
+ in SQL Schema or not */
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
+
+ val protoNestedFile = ProtoUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLNested")
+
+ // serializing with extra fails if extra field is missing in SQL Schema
+ assertFailedConversionMessage(protoNestedFile, Serializer, BY_NAME,
+ "Found field 'foo.baz' in Proto schema but there is no match in the SQL schema")
+ assertFailedConversionMessage(protoNestedFile, Serializer, BY_POSITION,
+ s"Found field 'baz' at position 1 of field 'foo' from Proto schema but there is no match " +
+ s"in the SQL schema at field 'foo' (using positional matching)")
+
+ /* deserializing should work regardless of whether the extra field is missing
+ in SQL Schema or not */
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
+ }
+
+ /**
+ * Attempt to convert `catalystSchema` to `protoSchema` (or vice-versa if `deserialize` is true),
+ * assert that it fails, and assert that the _cause_ of the thrown exception has a message
+ * matching `expectedCauseMessage`.
+ */
+ private def assertFailedConversionMessage(protoSchema: Descriptor,
+ serdeFactory: SerdeFactory[_],
+ fieldMatchType: MatchType,
+ expectedCauseMessage: String,
+ catalystSchema: StructType = CATALYST_STRUCT): Unit = {
+ val e = intercept[IncompatibleSchemaException] {
+ serdeFactory.create(catalystSchema, protoSchema, fieldMatchType)
+ }
+ val expectMsg = serdeFactory match {
+ case Deserializer =>
+ s"Cannot convert Proto type ${protoSchema.getName} to SQL type ${catalystSchema.sql}."
+ case Serializer =>
+ s"Cannot convert SQL type ${catalystSchema.sql} to Proto type ${protoSchema.getName}."
+ }
+
+ assert(e.getMessage === expectMsg)
+ assert(e.getCause.getMessage === expectedCauseMessage)
+ }
+
+ def withFieldMatchType(f: MatchType => Unit): Unit = {
+ MatchType.values.foreach { fieldMatchType =>
+ withClue(s"fieldMatchType == $fieldMatchType") {
+ f(fieldMatchType)
+ }
+ }
+ }
+}
+
+
+object ProtoSerdeSuite {
+
+ private val CATALYST_STRUCT =
+ new StructType().add("foo", new StructType().add("bar", IntegerType))
+
+ /**
+ * Specifier for type of field matching to be used for easy creation of tests that do both
+ * positional and by-name field matching.
+ */
+ private object MatchType extends Enumeration {
+ type MatchType = Value
+ val BY_NAME, BY_POSITION = Value
+
+ def isPositional(fieldMatchType: MatchType): Boolean = fieldMatchType == BY_POSITION
+ }
+
+ import MatchType._
+
+ /**
+ * Specifier for type of serde to be used for easy creation of tests that do both
+ * serialization and deserialization.
+ */
+ private sealed trait SerdeFactory[T] {
+ def create(sqlSchema: StructType, descriptor: Descriptor, fieldMatchType: MatchType): T
+ }
+
+ private object Serializer extends SerdeFactory[ProtoSerializer] {
+ override def create(sql: StructType, descriptor: Descriptor, matchType: MatchType):
+ ProtoSerializer = new ProtoSerializer(sql, descriptor, false, isPositional(matchType))
+ }
+
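+  // Deserializers in these tests use CORRECTED rebase mode and no pushed-down filters.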
+ private object Deserializer extends SerdeFactory[ProtoDeserializer] {
+ override def create(sql: StructType, descriptor: Descriptor, matchType: MatchType):
+      ProtoDeserializer = new ProtoDeserializer(descriptor, sql, isPositional(matchType),
+ RebaseSpec(CORRECTED), new NoopFilters)
+ }
+}
diff --git a/pom.xml b/pom.xml
index 5fbd82ad57add..27906a7b9162e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -101,6 +101,7 @@
connector/kafka-0-10-sql
connector/avro
connect
+ connector/proto
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 6ffc1d880c5d1..184e34b9a0c14 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -45,8 +45,8 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
- val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro) = Seq(
- "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro"
+ val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, proto) = Seq(
+ "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro", "proto"
).map(ProjectRef(buildLocation, _))
val streamingProjects@Seq(streaming, streamingKafka010) =
@@ -59,7 +59,7 @@ object BuildCommons {
) = Seq(
"core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
"tags", "sketch", "kvstore"
- ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect)
+ ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect) ++ Seq(proto)
val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn,
sparkGangliaLgpl, streamingKinesisAsl,
@@ -390,7 +390,7 @@ object SparkBuild extends PomBuild {
val mimaProjects = allProjects.filterNot { x =>
Seq(
spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn,
- unsafe, tags, tokenProviderKafka010, sqlKafka010, connect
+ unsafe, tags, tokenProviderKafka010, sqlKafka010, connect, proto
).contains(x)
}
@@ -433,6 +433,9 @@ object SparkBuild extends PomBuild {
enable(SparkConnect.settings)(connect)
+ /* Connector/proto settings */
+ enable(SparkProto.settings)(proto)
+
// SPARK-14738 - Remove docker tests from main Spark build
// enable(DockerIntegrationTests.settings)(dockerIntegrationTests)
@@ -651,6 +654,7 @@ object SparkConnect {
ShadeRule.rename("com.google.common.**" -> "org.sparkproject.connect.guava.@1").inAll,
ShadeRule.rename("com.google.thirdparty.**" -> "org.sparkproject.connect.guava.@1").inAll,
ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll,
+ ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connector.proto.protobuf.@1").inAll,
),
(assembly / assemblyMergeStrategy) := {
@@ -662,6 +666,48 @@ object SparkConnect {
)
}
+object SparkProto {
+
+ import BuildCommons.protoVersion
+
+ private val shadePrefix = "org.sparkproject.spark-proto"
+ val shadeJar = taskKey[Unit]("Shade the Jars")
+
+ lazy val settings = Seq(
+ // Setting version for the protobuf compiler. This has to be propagated to every sub-project
+ // even if the project is not using it.
+ PB.protocVersion := BuildCommons.protoVersion,
+
+ // For some reason the resolution from the imported Maven build does not work for some
+    // of these dependencies that we need to shade later on.
+ libraryDependencies ++= Seq(
+ "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
+ ),
+
+ dependencyOverrides ++= Seq(
+ "com.google.protobuf" % "protobuf-java" % protoVersion
+ ),
+
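+    // Generate plain Java classes from the .proto sources into the managed source directory.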
+ (Compile / PB.targets) := Seq(
+ PB.gens.java -> (Compile / sourceManaged).value,
+ ),
+
+ (assembly / test) := false,
+
+ (assembly / logLevel) := Level.Info,
+
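+    // Relocate the bundled protobuf-java classes so they do not clash with other protobuf
+    // versions on the classpath.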
+ (assembly / assemblyShadeRules) := Seq(
+ ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.spark-proto.protobuf.@1").inAll,
+ ),
+
+ (assembly / assemblyMergeStrategy) := {
+ case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard
+ // Drop all proto files that are not needed as artifacts of the build.
+ case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard
+ case _ => MergeStrategy.first
+ },
+ )
+}
object Unsafe {
lazy val settings = Seq(
// This option is needed to suppress warnings from sun.misc.Unsafe usage
@@ -1107,10 +1153,10 @@ object Unidoc {
(ScalaUnidoc / unidoc / unidocProjectFilter) :=
inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
- yarn, tags, streamingKafka010, sqlKafka010, connect),
+ yarn, tags, streamingKafka010, sqlKafka010, connect, proto),
(JavaUnidoc / unidoc / unidocProjectFilter) :=
inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
- yarn, tags, streamingKafka010, sqlKafka010, connect),
+ yarn, tags, streamingKafka010, sqlKafka010, connect, proto),
(ScalaUnidoc / unidoc / unidocAllClasspaths) := {
ignoreClasspaths((ScalaUnidoc / unidoc / unidocAllClasspaths).value)
@@ -1196,6 +1242,7 @@ object CopyDependencies {
// produce the shaded Jar which happens automatically in the case of Maven.
// Later, when the dependencies are copied, we manually copy the shaded Jar only.
val fid = (LocalProject("connect")/assembly).value
+ val fidProto = (LocalProject("proto")/assembly).value
(Compile / dependencyClasspath).value.map(_.data)
.filter { jar => jar.isFile() }
@@ -1208,6 +1255,9 @@ object CopyDependencies {
if (jar.getName.contains("spark-connect") &&
!SbtPomKeys.profiles.value.contains("noshade-connect")) {
Files.copy(fid.toPath, destJar.toPath)
+ } else if (jar.getName.contains("spark-proto") &&
+ !SbtPomKeys.profiles.value.contains("noshade-proto")) {
+        Files.copy(fidProto.toPath, destJar.toPath)
} else {
Files.copy(jar.toPath(), destJar.toPath())
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index bc6700a3b5616..adb9b91b3e5e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3673,6 +3673,22 @@ object SQLConf {
.checkValues(LegacyBehaviorPolicy.values.map(_.toString))
.createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString)
+ val PROTO_REBASE_MODE_IN_READ =
+ buildConf("spark.sql.proto.datetimeRebaseModeInRead")
+ .internal()
+ .doc("When LEGACY, Spark will rebase dates/timestamps from the legacy hybrid (Julian + " +
+ "Gregorian) calendar to Proleptic Gregorian calendar when reading Proto Events. " +
+ "When CORRECTED, Spark will not do rebase and read the dates/timestamps as it is. " +
+ "When EXCEPTION, which is the default, Spark will fail the reading if it sees " +
+ "ancient dates/timestamps that are ambiguous between the two calendars. This config is " +
+ "only effective if the writer info (like Spark, Hive) of the Proto events is unknown.")
+ .version("3.4.0")
+ .withAlternative("spark.sql.legacy.proto.datetimeRebaseModeInRead")
+ .stringConf
+ .transform(_.toUpperCase(Locale.ROOT))
+ .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
+ .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString)
+
val SCRIPT_TRANSFORMATION_EXIT_TIMEOUT =
buildConf("spark.sql.scriptTransformation.exitTimeoutInSeconds")
.internal()