From 6af374953301ff959f1c1d3e4e0d5812a6e6a6e3 Mon Sep 17 00:00:00 2001 From: SandishKumarHN Date: Sun, 25 Sep 2022 16:08:22 -0700 Subject: [PATCH 1/3] connector/proto: from_proto and to_proto support --- .github/workflows/build_and_test.yml | 2 +- connector/proto/pom.xml | 119 ++++ .../spark/sql/proto/CatalystDataToProto.scala | 58 ++ .../spark/sql/proto/ProtoDataToCatalyst.scala | 144 +++++ .../spark/sql/proto/ProtoDeserializer.scala | 347 +++++++++++ .../spark/sql/proto/ProtoSerializer.scala | 241 ++++++++ .../apache/spark/sql/proto/functions.scala | 86 +++ .../org/apache/spark/sql/proto/package.scala | 21 + .../spark/sql/proto/utils/DynamicSchema.scala | 172 ++++++ .../sql/proto/utils/MessageDefinition.scala | 99 ++++ .../spark/sql/proto/utils/ProtoOptions.scala | 102 ++++ .../spark/sql/proto/utils/ProtoUtils.scala | 300 ++++++++++ .../sql/proto/utils/SchemaConverters.scala | 172 ++++++ .../resources/protobuf/catalyst_types.desc | 37 ++ .../resources/protobuf/catalyst_types.proto | 63 ++ .../protobuf/proto_functions_suite.desc | Bin 0 -> 5060 bytes .../protobuf/proto_functions_suite.proto | 142 +++++ .../resources/protobuf/proto_serde_suite.desc | 27 + .../protobuf/proto_serde_suite.proto | 76 +++ .../ProtoCatalystDataConversionSuite.scala | 213 +++++++ .../spark/sql/proto/ProtoFunctionsSuite.scala | 540 ++++++++++++++++++ .../spark/sql/proto/ProtoSerdeSuite.scala | 229 ++++++++ pom.xml | 1 + project/SparkBuild.scala | 58 +- .../apache/spark/sql/internal/SQLConf.scala | 16 + 25 files changed, 3260 insertions(+), 5 deletions(-) create mode 100644 connector/proto/pom.xml create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/CatalystDataToProto.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDataToCatalyst.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoSerializer.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/functions.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/package.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/DynamicSchema.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/MessageDefinition.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoOptions.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoUtils.scala create mode 100644 connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/SchemaConverters.scala create mode 100644 connector/proto/src/test/resources/protobuf/catalyst_types.desc create mode 100644 connector/proto/src/test/resources/protobuf/catalyst_types.proto create mode 100644 connector/proto/src/test/resources/protobuf/proto_functions_suite.desc create mode 100644 connector/proto/src/test/resources/protobuf/proto_functions_suite.proto create mode 100644 connector/proto/src/test/resources/protobuf/proto_serde_suite.desc create mode 100644 connector/proto/src/test/resources/protobuf/proto_serde_suite.proto create mode 100644 connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoCatalystDataConversionSuite.scala create mode 100644 connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoFunctionsSuite.scala create mode 100644 connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoSerdeSuite.scala diff 
--git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index b0847187dffdd..d53133e09b33a 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -146,7 +146,7 @@ jobs: - >- core, unsafe, kvstore, avro, network-common, network-shuffle, repl, launcher, - examples, sketch, graphx + examples, sketch, graphx, proto - >- catalyst, hive-thriftserver - >- diff --git a/connector/proto/pom.xml b/connector/proto/pom.xml new file mode 100644 index 0000000000000..12cdae3b12202 --- /dev/null +++ b/connector/proto/pom.xml @@ -0,0 +1,119 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.12 + 3.4.0-SNAPSHOT + ../../pom.xml + + + spark-proto_2.12 + + proto + 3.21.1 + + jar + Spark Proto + https://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + org.tukaani + xz + + + com.google.protobuf + protobuf-java + ${protobuf.version} + compile + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + com.google.protobuf:* + + + + + com.google.protobuf + ${spark.shade.packageName}.spark-proto.protobuf + + com.google.protobuf.** + + + + + + + + diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/CatalystDataToProto.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/CatalystDataToProto.scala new file mode 100644 index 0000000000000..1846d360db13d --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/CatalystDataToProto.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.proto + +import com.google.protobuf.DynamicMessage + +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.proto.utils.{ProtoUtils, SchemaConverters} +import org.apache.spark.sql.types.{BinaryType, DataType} + +private[proto] case class CatalystDataToProto( + child: Expression, + descFilePath: Option[String], + messageName: Option[String]) + extends UnaryExpression { + + override def dataType: DataType = BinaryType + + @transient private lazy val protoType = (descFilePath, messageName) match { + case (Some(a), Some(b)) => ProtoUtils.buildDescriptor(a, b) + case _ => SchemaConverters.toProtoType(child.dataType, child.nullable) + } + + @transient private lazy val serializer = new ProtoSerializer(child.dataType, protoType, + child.nullable) + + override def nullSafeEval(input: Any): Any = { + val dynamicMessage = serializer.serialize(input).asInstanceOf[DynamicMessage] + dynamicMessage.toByteArray + } + + override def prettyName: String = "to_proto" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => + s"(byte[]) $expr.nullSafeEval($input)") + } + + override protected def withNewChildInternal(newChild: Expression): CatalystDataToProto = + copy(child = newChild) +} + diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDataToCatalyst.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDataToCatalyst.scala new file mode 100644 index 0000000000000..42a3f4be890d7 --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDataToCatalyst.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.proto + +import java.io.ByteArrayInputStream + +import scala.util.control.NonFatal + +import com.google.protobuf.DynamicMessage + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, + SpecificInternalRow, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.proto.utils.{ProtoOptions, ProtoUtils, SchemaConverters} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, StructType} + +private[proto] case class ProtoDataToCatalyst(child: Expression, descFilePath: String, + messageName: String, + options: Map[String, String]) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override lazy val dataType: DataType = { + val dt = SchemaConverters.toSqlType(expectedSchema).dataType + parseMode match { + // With PermissiveMode, the output Catalyst row might contain columns of null values for + // corrupt records, even if some of the columns are not nullable in the user-provided schema. + // Therefore we force the schema to be all nullable here. + case PermissiveMode => dt.asNullable + case _ => dt + } + } + + override def nullable: Boolean = true + + private lazy val protoOptions = ProtoOptions(options) + + @transient private lazy val descriptor = ProtoUtils.buildDescriptor(descFilePath, messageName) + + @transient private lazy val expectedSchema = protoOptions.schema.getOrElse(descriptor) + + @transient private lazy val deserializer = new ProtoDeserializer(expectedSchema, dataType, + protoOptions.datetimeRebaseModeInRead) + + @transient private var result: Any = _ + + @transient private lazy val parseMode: ParseMode = { + val mode = protoOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw new AnalysisException(unacceptableModeMessage(mode.name)) + } + mode + } + + private def unacceptableModeMessage(name: String): String = { + s"from_proto() doesn't support the $name mode. " + + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}." + } + + @transient private lazy val nullResultRow: Any = dataType match { + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + for (i <- 0 until st.length) { + resultRow.setNullAt(i) + } + resultRow + + case _ => + null + } + + private def handleException(e: Throwable): Any = { + parseMode match { + case PermissiveMode => + nullResultRow + case FailFastMode => + throw new SparkException("Malformed records are detected in record parsing. " + + s"Current parse Mode: ${FailFastMode.name}. 
To process malformed records as null " + + "result, try setting the option 'mode' as 'PERMISSIVE'.", e) + case _ => + throw new AnalysisException(unacceptableModeMessage(parseMode.name)) + } + } + + override def nullSafeEval(input: Any): Any = { + val binary = input.asInstanceOf[Array[Byte]] + try { + result = DynamicMessage.parseFrom(descriptor, new ByteArrayInputStream(binary)) + val unknownFields = result.asInstanceOf[DynamicMessage].getUnknownFields + if (!unknownFields.asMap().isEmpty) { + return handleException(new Throwable("UnknownFields encountered")) + } + val deserialized = deserializer.deserialize(result) + assert(deserialized.isDefined, + "Proto deserializer cannot return an empty result because filters are not pushed down") + deserialized.get + } catch { + // There could be multiple possible exceptions here, e.g. java.io.IOException, + // ProtoRuntimeException, ArrayIndexOutOfBoundsException, etc. + // To make it simple, catch all the exceptions here. + case NonFatal(e) => + handleException(e) + } + } + + override def prettyName: String = "from_proto" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + nullSafeCodeGen(ctx, ev, eval => { + val result = ctx.freshName("result") + val dt = CodeGenerator.boxedType(dataType) + s""" + $dt $result = ($dt) $expr.nullSafeEval($eval); + if ($result == null) { + ${ev.isNull} = true; + } else { + ${ev.value} = $result; + } + """ + }) + } + + override protected def withNewChildInternal(newChild: Expression): ProtoDataToCatalyst = + copy(child = newChild) +} diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala new file mode 100644 index 0000000000000..bb243b7795802 --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.proto + +import com.google.protobuf.{ByteString, DynamicMessage} +import com.google.protobuf.Descriptors._ +import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ + +import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters} +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.proto.utils.ProtoUtils +import org.apache.spark.sql.proto.utils.ProtoUtils.ProtoMatchedField +import org.apache.spark.sql.proto.utils.ProtoUtils.toFieldStr +import org.apache.spark.sql.proto.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, + DateType, Decimal, DoubleType, FloatType, IntegerType, LongType, NullType, + ShortType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String + +private[sql] class ProtoDeserializer( + rootProtoType: Descriptor, + rootCatalystType: DataType, + positionalFieldMatch: Boolean, + datetimeRebaseSpec: RebaseSpec, + filters: StructFilters) { + + def this( + rootProtoType: Descriptor, + rootCatalystType: DataType, + datetimeRebaseMode: String) = { + this( + rootProtoType, + rootCatalystType, + positionalFieldMatch = false, + RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), + new NoopFilters) + } + + private val dateRebaseFunc = DataSourceUtils.createDateRebaseFuncInRead( + datetimeRebaseSpec.mode, "Proto") + + private val converter: Any => Option[Any] = try { + rootCatalystType match { + // A shortcut for empty schema. 
+ case st: StructType if st.isEmpty => + (_: Any) => Some(InternalRow.empty) + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val applyFilters = filters.skipRow(resultRow, _) + val writer = getRecordWriter(rootProtoType, st, Nil, Nil, applyFilters) + (data: Any) => { + val record = data.asInstanceOf[DynamicMessage] + val skipRow = writer(fieldUpdater, record) + if (skipRow) None else Some(resultRow) + } + } + } catch { + case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException( + s"Cannot convert Proto type ${rootProtoType.getName} " + + s"to SQL type ${rootCatalystType.sql}.", ise) + } + + def deserialize(data: Any): Option[Any] = converter(data) + + private def newArrayWriter( + protoField: FieldDescriptor, + protoPath: Seq[String], + catalystPath: Seq[String], + elementType: DataType, + containsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = { + + + val protoElementPath = protoPath :+ "element" + val elementWriter = newWriter(protoField, elementType, + protoElementPath, catalystPath :+ "element") + (updater, ordinal, value) => + val collection = value.asInstanceOf[java.util.Collection[Any]] + val result = createArrayData(elementType, collection.size()) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + val iterator = collection.iterator() + while (iterator.hasNext) { + val element = iterator.next() + if (element == null) { + if (!containsNull) { + throw new RuntimeException( + s"Array value at path ${toFieldStr(protoElementPath)} is not allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + } + + /** + * Creates a writer to write proto values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter( + protoType: FieldDescriptor, + catalystType: DataType, + protoPath: Seq[String], + catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = { + val errorPrefix = s"Cannot convert Proto ${toFieldStr(protoPath)} to " + + s"SQL ${toFieldStr(catalystPath)} because " + val incompatibleMsg = errorPrefix + + s"schema is incompatible (protoType = ${protoType} ${protoType.toProto.getLabel} " + + s"${protoType.getJavaType} ${protoType.getType}, sqlType = ${catalystType.sql})" + + (protoType.getJavaType, catalystType) match { + + case (null, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of proto provide primitive accessors. 
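+ // Most scalar mappings below also have a repeated-field counterpart that writes the
+ // values into a Catalyst ArrayType via newArrayWriter.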
+ case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (BOOLEAN, ArrayType(BooleanType, containsNull)) => + newArrayWriter(protoType, protoPath, + catalystPath, BooleanType, containsNull) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (INT, ArrayType(IntegerType, containsNull)) => + newArrayWriter(protoType, protoPath, + catalystPath, IntegerType, containsNull) + + case (INT, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, ArrayType(LongType, containsNull)) => + newArrayWriter(protoType, protoPath, + catalystPath, LongType, containsNull) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (FLOAT, ArrayType(FloatType, containsNull)) => + newArrayWriter(protoType, protoPath, + catalystPath, FloatType, containsNull) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (DOUBLE, ArrayType(DoubleType, containsNull)) => + newArrayWriter(protoType, protoPath, + catalystPath, DoubleType, containsNull) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + } + updater.set(ordinal, str) + + case (STRING, ArrayType(StringType, containsNull)) => + newArrayWriter(protoType, protoPath, + catalystPath, StringType, containsNull) + + case (BYTE_STRING, BinaryType) => (updater, ordinal, value) => + val byte_array = value match { + case s: ByteString => s.toByteArray + case _ => throw new Exception("Invalid ByteString format") + } + updater.set(ordinal, byte_array) + + case (BYTE_STRING, ArrayType(BinaryType, containsNull)) => + newArrayWriter(protoType, protoPath, + catalystPath, BinaryType, containsNull) + + case (MESSAGE, st: StructType) => + val writeRecord = getRecordWriter(protoType.getMessageType, st, protoPath, + catalystPath, applyFilters = _ => false) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage]) + updater.set(ordinal, row) + + case (MESSAGE, ArrayType(st: StructType, containsNull)) => + newArrayWriter(protoType, protoPath, + catalystPath, st, containsNull) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (ENUM, ArrayType(StringType, containsNull)) => + newArrayWriter(protoType, protoPath, + catalystPath, StringType, containsNull) + + case _ => throw new IncompatibleSchemaException(incompatibleMsg) + } + } + + + private def getRecordWriter( + protoType: Descriptor, + catalystType: StructType, + protoPath: Seq[String], + catalystPath: Seq[String], + applyFilters: Int => Boolean): + (CatalystDataUpdater, DynamicMessage) => Boolean = { + + val protoSchemaHelper = new ProtoUtils.ProtoSchemaHelper( + protoType, catalystType, protoPath, catalystPath, positionalFieldMatch) + + protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true) + // no need to validateNoExtraProtoFields since extra Proto fields are ignored + + val (validFieldIndexes, fieldWriters) = protoSchemaHelper.matchedFields.map { + case ProtoMatchedField(catalystField, ordinal, protoField) => + val 
baseWriter = newWriter(protoField, catalystField.dataType, + protoPath :+ protoField.getName, catalystPath :+ catalystField.name) + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + (protoField, fieldWriter) + }.toArray.unzip + + (fieldUpdater, record) => { + var i = 0 + var skipRow = false + while (i < validFieldIndexes.length && !skipRow) { + fieldWriters(i)(fieldUpdater, record.getField(validFieldIndexes(i))) + skipRow = applyFilters(i) + i += 1 + } + skipRow + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. + */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + + def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + + override def setDecimal(ordinal: Int, value: Decimal): Unit = + row.setDecimal(ordinal, value, value.precision) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + + override def setByte(ordinal: 
Int, value: Byte): Unit = array.setByte(ordinal, value) + + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + + override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) + } + +} diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoSerializer.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoSerializer.scala new file mode 100644 index 0000000000000..75fed55f00d6f --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoSerializer.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.proto + +import scala.collection.JavaConverters._ + +import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} +import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ +import com.google.protobuf.DynamicMessage + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.proto.utils.ProtoUtils +import org.apache.spark.sql.proto.utils.ProtoUtils.{toFieldStr, ProtoMatchedField} +import org.apache.spark.sql.proto.utils.SchemaConverters.{IncompatibleSchemaException, UnsupportedProtoValueException} +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize data in catalyst format to data in proto format. 
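+ *
+ * @param rootCatalystType the Catalyst type of the input data
+ * @param rootProtoType the Proto message descriptor to serialize into
+ * @param nullable whether the top-level input value may be null
+ * @param positionalFieldMatch if true, match Catalyst and Proto fields by position rather than
+ *                             by name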
+ */ +private[sql] class ProtoSerializer( + rootCatalystType: DataType, + rootProtoType: Descriptor, + nullable: Boolean, + positionalFieldMatch: Boolean) extends Logging { + + def this(rootCatalystType: DataType, rootProtoType: Descriptor, nullable: Boolean) = { + this(rootCatalystType, rootProtoType, nullable, positionalFieldMatch = false) + } + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val converter: Any => Any = { + val baseConverter = try { + rootCatalystType match { + case st: StructType => + newStructConverter(st, rootProtoType, Nil, Nil).asInstanceOf[Any => Any] + } + } catch { + case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException( + s"Cannot convert SQL type ${rootCatalystType.sql} to Proto type " + + s"${rootProtoType.getName}.", ise) + } + if (nullable) { + (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + + private def newConverter( + catalystType: DataType, + protoType: FieldDescriptor, + catalystPath: Seq[String], + protoPath: Seq[String]): Converter = { + val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " + + s"to Proto ${toFieldStr(protoPath)} because " + (catalystType, protoType.getJavaType) match { + case (NullType, _) => + (getter, ordinal) => null + case (BooleanType, BOOLEAN) => + (getter, ordinal) => getter.getBoolean(ordinal) + case (ByteType, INT) => + (getter, ordinal) => getter.getByte(ordinal).toInt + case (ShortType, INT) => + (getter, ordinal) => getter.getShort(ordinal).toInt + case (IntegerType, INT) => + (getter, ordinal) => { + getter.getInt(ordinal) + } + case (LongType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + case (FloatType, FLOAT) => + (getter, ordinal) => getter.getFloat(ordinal) + case (DoubleType, DOUBLE) => + (getter, ordinal) => getter.getDouble(ordinal) + case (StringType, ENUM) => + val enumSymbols: Set[String] = protoType.getEnumType.getValues.asScala.map( + e => e.toString).toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException(errorPrefix + + s""""$data" cannot be written since it's not defined in enum """ + + enumSymbols.mkString("\"", "\", \"", "\"")) + } + protoType.getEnumType.findValueByName(data) + case (StringType, STRING) => + (getter, ordinal) => { + String.valueOf(getter.getUTF8String(ordinal)) + } + + case (BinaryType, BYTE_STRING) => + (getter, ordinal) => getter.getBinary(ordinal) + + case (DateType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + + case (TimestampType, LONG) => protoType.getContainingType match { + // For backward compatibility, if the Proto type is Long and it is not logical type + // (the `null` case), output the timestamp value as with millisecond precision. + case null => (getter, ordinal) => + DateTimeUtils.microsToMillis(getter.getLong(ordinal)) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"SQL type ${TimestampType.sql} cannot be converted to Proto logical type $other") + } + + case (TimestampNTZType, LONG) => protoType.getContainingType match { + // To keep consistent with TimestampType, if the Proto type is Long and it is not + // logical type (the `null` case), output the TimestampNTZ as long value + // in millisecond precision. 
+ case null => (getter, ordinal) => + DateTimeUtils.microsToMillis(getter.getLong(ordinal)) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"SQL type ${TimestampNTZType.sql} cannot be converted to Proto logical type $other") + } + + case (ArrayType(et, containsNull), _) => + val elementConverter = newConverter( + et, protoType, + catalystPath :+ "element", protoPath :+ "element") + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val len = arrayData.numElements() + val result = new Array[Any](len) + var i = 0 + while (i < len) { + if (containsNull && arrayData.isNullAt(i)) { + result(i) = null + } else { + result(i) = elementConverter(arrayData, i) + } + i += 1 + } + // proto writer is expecting a Java Collection, so we convert it into + // `ArrayList` backed by the specified array without data copying. + java.util.Arrays.asList(result: _*) + } + + case (st: StructType, MESSAGE) => + val structConverter = newStructConverter( + st, protoType.getMessageType, catalystPath, protoPath) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + case (MapType(kt, vt, valueContainsNull), _) if kt == StringType => + val valueConverter = newConverter( + vt, protoType, catalystPath :+ "value", protoPath :+ "value") + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val len = mapData.numElements() + val result = new java.util.HashMap[String, Any](len) + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < len) { + val key = keyArray.getUTF8String(i).toString + if (valueContainsNull && valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case (_: YearMonthIntervalType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + + case (_: DayTimeIntervalType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + + case _ => + throw new IncompatibleSchemaException(errorPrefix + + s"schema is incompatible (sqlType = ${catalystType.sql}, " + + s"protoType = ${protoType.getJavaType})") + } + } + + private def newStructConverter( + catalystStruct: StructType, + protoStruct: Descriptor, + catalystPath: Seq[String], + protoPath: Seq[String]): InternalRow => DynamicMessage = { + + val protoSchemaHelper = new ProtoUtils.ProtoSchemaHelper( + protoStruct, catalystStruct, protoPath, catalystPath, positionalFieldMatch) + + protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false) + protoSchemaHelper.validateNoExtraRequiredProtoFields() + + val (protoIndices, fieldConverters) = protoSchemaHelper.matchedFields.map { + case ProtoMatchedField(catalystField, _, protoField) => + val converter = newConverter(catalystField.dataType, + protoField, + catalystPath :+ catalystField.name, protoPath :+ protoField.getName) + (protoField, converter) + }.toArray.unzip + + val numFields = catalystStruct.length + row: InternalRow => + val result = DynamicMessage.newBuilder(protoStruct) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + throw new UnsupportedProtoValueException( + s"Cannot set ${protoIndices(i).getName} a Null, Proto does not allow Null values") + } else { + result.setField(protoIndices(i), fieldConverters(i).apply(row, i)) + } + i += 1 + } + result.build() + } +} diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/functions.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/functions.scala new file mode 100644 index 
0000000000000..21b6ed77949bd --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/functions.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.proto + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Column + +// scalastyle:off: object.name +object functions { +// scalastyle:on: object.name + + /** + * Converts a binary column of Proto format into its corresponding catalyst value. + * The specified schema must match actual schema of the read data, otherwise the behavior + * is undefined: it may fail or return arbitrary result. + * To deserialize the data with a compatible and evolved schema, the expected Proto schema can be + * set via the option protoSchema. + * + * @param data the binary column. + * @param descFilePath the proto schema in Message GeneratedMessageV3 format. + * @param messageName the proto MessageName to look for in descriptorFile. + * @since 3.4.0 + */ + @Experimental + def from_proto(data: Column, descFilePath: String, messageName: String, + options: java.util.Map[String, String]): Column = { + new Column(ProtoDataToCatalyst(data.expr, descFilePath, messageName, options.asScala.toMap)) + } + + /** + * Converts a binary column of Proto format into its corresponding catalyst value. + * The specified schema must match actual schema of the read data, otherwise the behavior + * is undefined: it may fail or return arbitrary result. + * To deserialize the data with a compatible and evolved schema, the expected Proto schema can be + * set via the option protoSchema. + * + * @param data the binary column. + * @param descFilePath the proto schema in Message GeneratedMessageV3 format. + * @param messageName the proto MessageName to look for in descriptorFile. + * @since 3.4.0 + */ + @Experimental + def from_proto(data: Column, descFilePath: String, messageName: String): Column = { + new Column(ProtoDataToCatalyst(data.expr, descFilePath, messageName, Map.empty)) + } + + /** + * Converts a column into binary of proto format. + * + * @param data the data column. + * @since 3.4.0 + */ + @Experimental + def to_proto(data: Column): Column = { + new Column(CatalystDataToProto(data.expr, None, None)) + } + + /** + * Converts a column into binary of proto format. + * + * @param data the data column. + * @param descFilePath the proto schema in Message GeneratedMessageV3 format. + * @param messageName the proto MessageName to look for in descriptorFile. 
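+ *
+ * A minimal usage sketch (the descriptor path, message name and column below are illustrative):
+ * {{{
+ *   df.select(to_proto(df.col("value"), "/tmp/schema.desc", "SomeMessage").as("proto"))
+ * }}}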
+ * @since 3.4.0 + */ + @Experimental + def to_proto(data: Column, descFilePath: String, messageName: String): Column = { + new Column(CatalystDataToProto(data.expr, Some(descFilePath), Some(messageName))) + } +} diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/package.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/package.scala new file mode 100644 index 0000000000000..de88d564062a9 --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/package.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +package object proto { + protected[proto] object ScalaReflectionLock +} diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/DynamicSchema.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/DynamicSchema.scala new file mode 100644 index 0000000000000..962007ecaff80 --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/DynamicSchema.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.proto.utils + +import java.util + +import scala.collection.JavaConverters._ +import scala.util.control.Breaks.{break, breakable} + +import com.google.protobuf.DescriptorProtos.{FileDescriptorProto, FileDescriptorSet} +import com.google.protobuf.Descriptors.{Descriptor, FileDescriptor} + +class DynamicSchema { + var fileDescSet: FileDescriptorSet = null + + def this(fileDescSet: FileDescriptorSet) = { + this + this.fileDescSet = fileDescSet; + val fileDescMap: util.Map[String, FileDescriptor] = init(fileDescSet) + val msgDupes: util.Set[String] = new util.HashSet[String]() + for (fileDesc: FileDescriptor <- fileDescMap.values().asScala) { + for (msgType: Descriptor <- fileDesc.getMessageTypes().asScala) { + addMessageType(msgType, null, msgDupes) + } + } + + for (msgName: String <- msgDupes.asScala) { + messageMsgDescriptorMapShort.remove(msgName) + } + } + + val messageDescriptorMapFull: util.Map[String, Descriptor] = + new util.HashMap[String, Descriptor]() + val messageMsgDescriptorMapShort: util.Map[String, Descriptor] = + new util.HashMap[String, Descriptor]() + + def newBuilder(): Builder = { + new Builder() + } + + def addMessageType(msgType: Descriptor, scope: String, msgDupes: util.Set[String]) : Unit = { + val msgTypeNameFull: String = msgType.getFullName() + val msgTypeNameShort: String = { + if (scope == null) { + msgType.getName() + } else { + scope + "." + msgType.getName () + } + } + + if (messageDescriptorMapFull.containsKey(msgTypeNameFull)) { + throw new IllegalArgumentException("duplicate name: " + msgTypeNameFull) + } + if (messageMsgDescriptorMapShort.containsKey(msgTypeNameShort)) { + msgDupes.add(msgTypeNameShort) + } + messageDescriptorMapFull.put(msgTypeNameFull, msgType); + messageMsgDescriptorMapShort.put(msgTypeNameShort, msgType); + + + for (nestedType <- msgType.getNestedTypes.asScala) { + addMessageType(nestedType, msgTypeNameShort, msgDupes) + } + } + + def init(fileDescSet: FileDescriptorSet) : util.Map[String, FileDescriptor] = { + // check for dupes + val allFdProtoNames: util.Set[String] = new util.HashSet[String]() + for (fdProto: FileDescriptorProto <- fileDescSet.getFileList().asScala) { + if (allFdProtoNames.contains(fdProto.getName())) { + throw new IllegalArgumentException("duplicate name: " + fdProto.getName()) + } + allFdProtoNames.add(fdProto.getName()) + } + + // build FileDescriptors, resolve dependencies (imports) if any + val resolvedFileDescMap: util.Map[String, FileDescriptor] = + new util.HashMap[String, FileDescriptor]() + while (resolvedFileDescMap.size() < fileDescSet.getFileCount()) { + for (fdProto : FileDescriptorProto <- fileDescSet.getFileList().asScala) { + breakable { + if (resolvedFileDescMap.containsKey(fdProto.getName())) { + break + } + + val dependencyList: util.List[String] = fdProto.getDependencyList(); + val resolvedFdList: util.List[FileDescriptor] = + new util.ArrayList[FileDescriptor]() + for (depName: String <- dependencyList.asScala) { + if (!allFdProtoNames.contains(depName)) { + throw new IllegalArgumentException("cannot resolve import " + depName + " in " + + fdProto.getName()) + } + val fd: FileDescriptor = resolvedFileDescMap.get(depName) + if (fd != null) resolvedFdList.add(fd) + } + + if (resolvedFdList.size() == dependencyList.size()) { + val fds = new Array[FileDescriptor](resolvedFdList.size) + val fd: FileDescriptor = FileDescriptor.buildFrom(fdProto, resolvedFdList.toArray(fds)) + resolvedFileDescMap.put(fdProto.getName(), fd) + } + } + } + } + + resolvedFileDescMap + } 
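+ // The accessors below expose the parsed schema: toString/toByteArray dump the descriptor set,
+ // and getMessageDescriptor resolves a message by its short name first, then by its full name.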
+ + override def toString() : String = { + val msgTypes: util.Set[String] = getMessageTypes() + ("types: " + msgTypes + "\n" + fileDescSet) + } + + def toByteArray() : Array[Byte] = { + fileDescSet.toByteArray() + } + + def getMessageTypes(): util.Set[String] = { + new util.TreeSet[String](messageDescriptorMapFull.keySet()) + } + + + def getMessageDescriptor(messageTypeName: String) : Descriptor = { + var messageType: Descriptor = messageMsgDescriptorMapShort.get(messageTypeName) + if (messageType == null) { + messageType = messageDescriptorMapFull.get(messageTypeName) + } + messageType + } + + class Builder { + val messageFileDescProtoBuilder: FileDescriptorProto.Builder = FileDescriptorProto.newBuilder() + val messageFileDescSetBuilder: FileDescriptorSet.Builder = FileDescriptorSet.newBuilder() + + def build(): DynamicSchema = { + val fileDescSetBuilder: FileDescriptorSet.Builder = FileDescriptorSet.newBuilder() + fileDescSetBuilder.addFile(messageFileDescProtoBuilder.build()) + fileDescSetBuilder.mergeFrom(messageFileDescSetBuilder.build()) + new DynamicSchema(fileDescSetBuilder.build()) + } + + def setName(name: String) : Builder = { + messageFileDescProtoBuilder.setName(name) + this + } + + def setPackage(name: String): Builder = { + messageFileDescProtoBuilder.setPackage(name) + this + } + + def addMessageDefinition(msgDef: MessageDefinition) : Builder = { + messageFileDescProtoBuilder.addMessageType(msgDef.getMessageType()) + this + } + } +} diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/MessageDefinition.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/MessageDefinition.scala new file mode 100644 index 0000000000000..7692b604a878e --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/MessageDefinition.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.proto.utils + +import com.google.protobuf.DescriptorProtos.{DescriptorProto, FieldDescriptorProto} + +class MessageDefinition(val messageType: DescriptorProto = null) { + + def newBuilder(msgTypeName: String): Builder = { + new Builder(msgTypeName) + } + + def getMessageType(): DescriptorProto = { + messageType + } + + class Builder(msgTypeName: String) { + var messageTypeBuilder: DescriptorProto.Builder = DescriptorProto.newBuilder() + messageTypeBuilder.setName(msgTypeName) + + def addField(label: String, typeName: String, name: String, num: Int): Builder = { + addField(label, typeName, name, num, null) + } + + def addField(label: String, typeName: String, name: String, num: Int, + defaultVal: String): Builder = { + val protoLabel: FieldDescriptorProto.Label = protoLabelMap.getOrElse(label, null) + if (protoLabel == null) { + throw new IllegalArgumentException("Illegal label: " + label) + } + addField(protoLabel, typeName, name, num, defaultVal) + this + } + + def addMessageDefinition(msgDef: MessageDefinition): Builder = { + messageTypeBuilder.addNestedType(msgDef.getMessageType()) + this + } + + def build(): MessageDefinition = { + new MessageDefinition(messageTypeBuilder.build()) + } + + def addField(label: FieldDescriptorProto.Label, typeName: String, name: String, num: Int, + defaultVal: String): DescriptorProto.Builder = { + val fieldBuilder: FieldDescriptorProto.Builder = FieldDescriptorProto.newBuilder() + fieldBuilder.setLabel(label) + val primType: FieldDescriptorProto.Type = protoTypeMap.getOrElse(typeName, null) + if (primType != null) { + fieldBuilder.setType(primType) + } else { + fieldBuilder.setTypeName(typeName) + } + + fieldBuilder.setName(name).setNumber(num); + if (defaultVal != null) fieldBuilder.setDefaultValue(defaultVal); + messageTypeBuilder.addField(fieldBuilder.build()); + } + } + + + private val protoLabelMap: Map[String, FieldDescriptorProto.Label] = + Map("optional" -> FieldDescriptorProto.Label.LABEL_OPTIONAL, + "required" -> FieldDescriptorProto.Label.LABEL_REQUIRED, + "repeated" -> FieldDescriptorProto.Label.LABEL_REPEATED + ) + + private val protoTypeMap: Map[String, FieldDescriptorProto.Type] = + Map("double" -> FieldDescriptorProto.Type.TYPE_DOUBLE, + "float" -> FieldDescriptorProto.Type.TYPE_FLOAT, + "int32" -> FieldDescriptorProto.Type.TYPE_INT32, + "int64" -> FieldDescriptorProto.Type.TYPE_INT64, + "uint32" -> FieldDescriptorProto.Type.TYPE_UINT32, + "uint64" -> FieldDescriptorProto.Type.TYPE_UINT64, + "sint32" -> FieldDescriptorProto.Type.TYPE_SINT32, + "sint64" -> FieldDescriptorProto.Type.TYPE_SINT64, + "fixed32" -> FieldDescriptorProto.Type.TYPE_FIXED32, + "fixed64" -> FieldDescriptorProto.Type.TYPE_FIXED64, + "sfixed32" -> FieldDescriptorProto.Type.TYPE_SFIXED32, + "sfixed64" -> FieldDescriptorProto.Type.TYPE_SFIXED64, + "bool" -> FieldDescriptorProto.Type.TYPE_BOOL, + "string" -> FieldDescriptorProto.Type.TYPE_STRING, + "bytes" -> FieldDescriptorProto.Type.TYPE_BYTES) +} + diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoOptions.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoOptions.scala new file mode 100644 index 0000000000000..a2d32a0a07fa2 --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoOptions.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.proto.utils + +import java.io.FileInputStream +import java.net.URI + +import com.google.protobuf.DescriptorProtos +import com.google.protobuf.Descriptors.Descriptor +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} +import org.apache.spark.sql.internal.SQLConf + +/** + * Options for Proto Reader and Writer stored in case insensitive manner. + */ +private[sql] class ProtoOptions( + @transient val parameters: CaseInsensitiveMap[String], + @transient val conf: Configuration) + extends FileSourceOptions(parameters) with Logging { + + def this(parameters: Map[String, String], conf: Configuration) = { + this(CaseInsensitiveMap(parameters), conf) + } + + /** + * Optional schema provided by a user in schema file or in JSON format. + * + * When reading Proto, this option can be set to an evolved schema, which is compatible but + * different with the actual Proto schema. The deserialization schema will be consistent with + * the evolved schema. For example, if we set an evolved schema containing one additional + * column with a default value, the reading result in Spark will contain the new column too. + * + * When writing Proto, this option can be set if the expected output Proto schema doesn't match + * the schema converted by Spark. For example, the expected schema of one column is of "enum" + * type, instead of "string" type in the default converted schema. + */ + val schema: Option[Descriptor] = { + parameters.get("protoSchema").map(a => DescriptorProtos.DescriptorProto.parseFrom( + new FileInputStream(a)).getDescriptorForType).orElse({ + val protoUrlSchema = parameters.get("protoSchemaUrl").map(url => { + log.debug("loading proto schema from url: " + url) + val fs = FileSystem.get(new URI(url), conf) + val in = fs.open(new Path(url)) + try { + DescriptorProtos.DescriptorProto.parseFrom(in).getDescriptorForType + } finally { + in.close() + } + }) + protoUrlSchema + }) + } + + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) + + /** + * The rebasing mode for the DATE and TIMESTAMP_MICROS, TIMESTAMP_MILLIS values in reads. 
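+ * Accepted values are `EXCEPTION`, `CORRECTED` and `LEGACY`, mirroring the SQL config
+ * `spark.sql.proto.datetimeRebaseModeInRead` and the `datetimeRebaseMode` option below.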
+ */ + val datetimeRebaseModeInRead: String = parameters + .get(ProtoOptions.DATETIME_REBASE_MODE) + .getOrElse(SQLConf.get.getConf(SQLConf.PROTO_REBASE_MODE_IN_READ)) +} + +private[sql] object ProtoOptions { + def apply(parameters: Map[String, String]): ProtoOptions = { + val hadoopConf = SparkSession + .getActiveSession + .map(_.sessionState.newHadoopConf()) + .getOrElse(new Configuration()) + new ProtoOptions(CaseInsensitiveMap(parameters), hadoopConf) + } + + val ignoreExtensionKey = "ignoreExtension" + + // The option controls rebasing of the DATE and TIMESTAMP values between + // Julian and Proleptic Gregorian calendars. It impacts on the behaviour of the Proto + // datasource similarly to the SQL config `spark.sql.proto.datetimeRebaseModeInRead`, + // and can be set to the same values: `EXCEPTION`, `LEGACY` or `CORRECTED`. + val DATETIME_REBASE_MODE = "datetimeRebaseMode" +} + diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoUtils.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoUtils.scala new file mode 100644 index 0000000000000..29d3a9b6a33a0 --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/ProtoUtils.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.proto.utils + +import java.io.{BufferedInputStream, FileInputStream, FileNotFoundException, IOException} +import java.util.Locale + +import scala.collection.JavaConverters._ + +import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException} +import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} +import org.apache.hadoop.fs.FileStatus + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.proto.utils.ProtoOptions.ignoreExtensionKey +import org.apache.spark.sql.proto.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +private[sql] object ProtoUtils extends Logging { + + def inferSchema( + spark: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val conf = spark.sessionState.newHadoopConfWithOptions(options) + val parsedOptions = new ProtoOptions(options, conf) + + if (parsedOptions.parameters.contains(ignoreExtensionKey)) { + logWarning(s"Option $ignoreExtensionKey is deprecated. 
Please use the " + + "general data source option pathGlobFilter for filtering file names.") + } + // User can specify an optional proto json schema. + val protoSchema = parsedOptions.schema + .getOrElse { + inferProtoSchemaFromFiles(files, + new FileSourceOptions(CaseInsensitiveMap(options)).ignoreCorruptFiles) + } + + SchemaConverters.toSqlType(protoSchema).dataType match { + case t: StructType => Some(t) + case _ => throw new RuntimeException( + s"""Proto schema cannot be converted to a Spark SQL StructType: + | + |${protoSchema.toString()} + |""".stripMargin) + } + } + + private def inferProtoSchemaFromFiles( + files: Seq[FileStatus], + ignoreCorruptFiles: Boolean): Descriptor = { + // Schema evolution is not supported yet. Here we only pick first random readable sample file to + // figure out the schema of the whole dataset. + val protoReader = files.iterator.map { f => + val path = f.getPath + if (!path.getName.endsWith(".pb")) { + None + } else { + Utils.tryWithResource { + new FileInputStream("saved_model.pb") + } { in => + try { + Some(DescriptorProtos.DescriptorProto.parseFrom(in).getDescriptorForType) + } catch { + case e: IOException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $path", e) + None + } else { + throw new SparkException(s"Could not read file: $path", e) + } + } + } + } + }.collectFirst { + case Some(reader) => reader + } + + protoReader match { + case Some(reader) => + reader.getContainingType + case None => + throw new FileNotFoundException( + "No Proto files found. If files don't have .proto extension, set ignoreExtension to true") + } + } + + def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportsDataType(f.dataType) } + + case ArrayType(elementType, _) => supportsDataType(elementType) + + case MapType(keyType, valueType, _) => + supportsDataType(keyType) && supportsDataType(valueType) + + case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _: NullType => true + + case _ => false + } + + /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Proto field. */ + private[sql] case class ProtoMatchedField( + catalystField: StructField, + catalystPosition: Int, + protoField: FieldDescriptor) + + /** + * Helper class to perform field lookup/matching on Proto schemas. + * + * This will match `protoSchema` against `catalystSchema`, attempting to find a matching field in + * the Proto schema for each field in the Catalyst schema and vice-versa, respecting settings for + * case sensitivity. The match results can be accessed using the getter methods. + * + * @param protoSchema The schema in which to search for fields. Must be of type RECORD. + * @param catalystSchema The Catalyst schema to use for matching. + * @param protoPath The seq of parent field names leading to `protoSchema`. + * @param catalystPath The seq of parent field names leading to `catalystSchema`. + * @param positionalFieldMatch If true, perform field matching in a positional fashion + * (structural comparison between schemas, ignoring names); + * otherwise, perform field matching using field names. 
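+   *
+   * A minimal usage sketch (illustrative; `descriptor` and `catalystSchema` are assumed to
+   * describe the same record):
+   * {{{
+   *   val helper = new ProtoSchemaHelper(descriptor, catalystSchema, Nil, Nil,
+   *     positionalFieldMatch = false)
+   *   helper.validateNoExtraCatalystFields(ignoreNullable = true)
+   *   helper.validateNoExtraRequiredProtoFields()
+   *   val matches = helper.matchedFields  // Seq[ProtoMatchedField], one per resolved pair
+   * }}}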
+ */ + class ProtoSchemaHelper( + protoSchema: Descriptor, + catalystSchema: StructType, + protoPath: Seq[String], + catalystPath: Seq[String], + positionalFieldMatch: Boolean) { + if (protoSchema.getName == null) { + throw new IncompatibleSchemaException( + s"Attempting to treat ${protoSchema.getName} as a RECORD, " + + s"but it was: ${protoSchema.getContainingType}") + } + + private[this] val protoFieldArray = protoSchema.getFields.asScala.toArray + private[this] val fieldMap = protoSchema.getFields.asScala + .groupBy(_.getName.toLowerCase(Locale.ROOT)) + .mapValues(_.toSeq) // toSeq needed for scala 2.13 + + /** The fields which have matching equivalents in both Proto and Catalyst schemas. */ + val matchedFields: Seq[ProtoMatchedField] = catalystSchema.zipWithIndex.flatMap { + case (sqlField, sqlPos) => + getProtoField(sqlField.name, sqlPos).map(ProtoMatchedField(sqlField, sqlPos, _)) + } + + /** + * Validate that there are no Catalyst fields which don't have a matching Proto field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. If `ignoreNullable` is false, + * consider nullable Catalyst fields to be eligible to be an extra field; otherwise, + * ignore nullable Catalyst fields when checking for extras. + */ + def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit = + catalystSchema.zipWithIndex.foreach { case (sqlField, sqlPos) => + if (getProtoField(sqlField.name, sqlPos).isEmpty && + (!ignoreNullable || !sqlField.nullable)) { + if (positionalFieldMatch) { + throw new IncompatibleSchemaException("Cannot find field at position " + + s"$sqlPos of ${toFieldStr(protoPath)} from Proto schema (using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Proto schema") + } + } + } + + /** + * Validate that there are no Proto fields which don't have a matching Catalyst field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. Only required (non-nullable) + * fields are checked; nullable fields are ignored. + */ + def validateNoExtraRequiredProtoFields(): Unit = { + val extraFields = protoFieldArray.toSet -- matchedFields.map(_.protoField) + extraFields.filterNot(isNullable).foreach { extraField => + if (positionalFieldMatch) { + throw new IncompatibleSchemaException(s"Found field '${extraField.getName()}'" + + s" at position ${extraField.getIndex} of ${toFieldStr(protoPath)} from Proto schema " + + s"but there is no match in the SQL schema at ${toFieldStr(catalystPath)} " + + s"(using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Found ${toFieldStr(protoPath :+ extraField.getName())} in Proto schema " + + "but there is no match in the SQL schema") + } + } + } + + /** + * Extract a single field from the contained proto schema which has the desired field name, + * performing the matching with proper case sensitivity according to SQLConf.resolver. + * + * @param name The name of the field to search for. + * @return `Some(match)` if a matching Proto field is found, otherwise `None`. 
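+     *
+     * For example (illustrative), with the default case-insensitive resolver a lookup for
+     * "ID" matches a Proto field named "id"; if the schema also contained a second field
+     * "Id", the lookup would throw [[IncompatibleSchemaException]] because the match is
+     * ambiguous.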
+ */ + private[proto] def getFieldByName(name: String): Option[FieldDescriptor] = { + + // get candidates, ignoring case of field name + val candidates = fieldMap.getOrElse(name.toLowerCase(Locale.ROOT), Seq.empty) + + // search candidates, taking into account case sensitivity settings + candidates.filter(f => SQLConf.get.resolver(f.getName(), name)) match { + case Seq(protoField) => Some(protoField) + case Seq() => None + case matches => throw new IncompatibleSchemaException(s"Searching for '$name' in Proto " + + s"schema at ${toFieldStr(protoPath)} gave ${matches.size} matches. Candidates: " + + matches.map(_.getName()).mkString("[", ", ", "]") + ) + } + } + + /** Get the Proto field corresponding to the provided Catalyst field name/position, if any. */ + def getProtoField(fieldName: String, catalystPos: Int): Option[FieldDescriptor] = { + if (positionalFieldMatch) { + protoFieldArray.lift(catalystPos) + } else { + getFieldByName(fieldName) + } + } + } + + def buildDescriptor(protoFilePath: String, messageName: String): Descriptor = { + val fileDescriptor: Descriptors.FileDescriptor = parseFileDescriptor(protoFilePath) + var result: Descriptors.Descriptor = null; + + for (descriptor <- fileDescriptor.getMessageTypes.asScala) { + if (descriptor.getName().equals(messageName)) { + result = descriptor + } + } + + if (null == result) { + throw new RuntimeException("Unable to locate Message '" + messageName + "' in Descriptor"); + } + result + } + + def parseFileDescriptor(protoFilePath: String): Descriptors.FileDescriptor = { + var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null + try { + val dscFile = new BufferedInputStream(new FileInputStream(protoFilePath)) + fileDescriptorSet = DescriptorProtos.FileDescriptorSet.parseFrom(dscFile) + } catch { + case ex: InvalidProtocolBufferException => + throw new RuntimeException("Error parsing descriptor byte[] into Descriptor object", ex) + case ex: IOException => + throw new RuntimeException("Error reading proto file at path: " + protoFilePath, ex) + } + + val descriptorProto: DescriptorProtos.FileDescriptorProto = fileDescriptorSet.getFile(0) + try { + val fileDescriptor: Descriptors.FileDescriptor = Descriptors.FileDescriptor.buildFrom( + descriptorProto, new Array[Descriptors.FileDescriptor](0)) + if (fileDescriptor.getMessageTypes().isEmpty()) { + throw new RuntimeException("No MessageTypes returned, " + fileDescriptor.getName()); + } + fileDescriptor + } catch { + case e: Descriptors.DescriptorValidationException => + throw new RuntimeException("Error constructing FileDescriptor", e) + } + } + + /** + * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable + * string representing the field, like "field 'foo.bar'". If `names` is empty, the string + * "top-level record" is returned. + */ + private[proto] def toFieldStr(names: Seq[String]): String = names match { + case Seq() => "top-level record" + case n => s"field '${n.mkString(".")}'" + } + + /** Return true iff `protoField` is nullable, i.e. `UNION` type and has `NULL` as an option. 
*/ + private[proto] def isNullable(protoField: FieldDescriptor): Boolean = + !protoField.isOptional + +} diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/SchemaConverters.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/SchemaConverters.scala new file mode 100644 index 0000000000000..f429344af9c69 --- /dev/null +++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/utils/SchemaConverters.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.proto.utils + +import java.util + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.proto.ScalaReflectionLock +import org.apache.spark.sql.types._ + +@DeveloperApi +object SchemaConverters { + /** + * Internal wrapper for SQL data type and nullability. + * + * @since 3.4.0 + */ + case class SchemaType(dataType: DataType, nullable: Boolean) + + /** + * Converts an Proto schema to a corresponding Spark SQL schema. 
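+   *
+   * For example (illustrative; assumes `descriptor` describes a message with a string field
+   * `name` and an int32 field `age`):
+   * {{{
+   *   val SchemaType(dataType, _) = SchemaConverters.toSqlType(descriptor)
+   *   // dataType is a StructType with StructField("name", StringType)
+   *   // and StructField("age", IntegerType)
+   * }}}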
+ * + * @since 3.4.0 + */ + def toSqlType(protoSchema: Descriptor): SchemaType = { + toSqlTypeHelper(protoSchema) + } + + def toSqlTypeHelper(descriptor: Descriptor): SchemaType = ScalaReflectionLock.synchronized { + SchemaType(StructType(descriptor.getFields.asScala.flatMap(structFieldFor).toSeq), + nullable = true) + } + + def structFieldFor(fd: FieldDescriptor): Option[StructField] = { + import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ + val dataType = fd.getJavaType match { + case INT => Some(IntegerType) + case LONG => Some(LongType) + case FLOAT => Some(FloatType) + case DOUBLE => Some(DoubleType) + case BOOLEAN => Some(BooleanType) + case STRING => Some(StringType) + case BYTE_STRING => Some(BinaryType) + case ENUM => Some(StringType) + case MESSAGE => + Option(fd.getMessageType.getFields.asScala.flatMap(structFieldFor).toSeq) + .filter(_.nonEmpty) + .map(StructType.apply) + + } + dataType.map(dt => StructField( + fd.getName, + if (fd.isRepeated) ArrayType(dt, containsNull = false) else dt, + nullable = !fd.isRequired && !fd.isRepeated + )) + } + + /** + * Converts a Spark SQL schema to a corresponding Proto Descriptor + * + * @since 3.4.0 + */ + def toProtoType(catalystType: DataType, + nullable: Boolean = false, + recordName: String = "topLevelRecord", + nameSpace: String = "", + indexNum: Int = 0): Descriptor = { + val schemaBuilder: DynamicSchema#Builder = new DynamicSchema().newBuilder() + schemaBuilder.setName("DynamicSchema.proto") + toMessageDefinition(catalystType, recordName, nameSpace, schemaBuilder: DynamicSchema#Builder) + schemaBuilder.build().getMessageDescriptor(recordName) + } + + def toMessageDefinition(catalystType: DataType, recordName: String, + nameSpace: String, schemaBuilder: DynamicSchema#Builder): Unit = { + catalystType match { + case st: StructType => + val queue = mutable.Queue[ProtoMessage]() + val list = new util.ArrayList[ProtoField]() + st.foreach { f => + list.add(ProtoField(f.name, f.dataType)) + } + queue += ProtoMessage(recordName, list) + while (!queue.isEmpty) { + val protoMessage = queue.dequeue() + val messageDefinition: MessageDefinition#Builder = + new MessageDefinition().newBuilder(protoMessage.messageName) + var index = 0 + protoMessage.fieldList.forEach { + protoField => + protoField.catalystType match { + case ArrayType(at, containsNull) => + index = index + 1 + at match { + case st: StructType => + messageDefinition.addField("repeated", protoField.name, + protoField.name, index) + val list = new util.ArrayList[ProtoField]() + st.foreach { f => + list.add(ProtoField(f.name, f.dataType)) + } + queue += ProtoMessage(protoField.name, list) + case _ => + convertBasicTypes(protoField.catalystType, messageDefinition, "repeated", + index, protoField) + } + case st: StructType => + index = index + 1 + messageDefinition.addField("optional", protoField.name, protoField.name, index) + val list = new util.ArrayList[ProtoField]() + st.foreach { f => + list.add(ProtoField(f.name, f.dataType)) + } + queue += ProtoMessage(protoField.name, list) + case _ => + index = index + 1 + convertBasicTypes(protoField.catalystType, messageDefinition, "optional", + index, protoField) + } + } + schemaBuilder.addMessageDefinition(messageDefinition.build()) + } + } + } + + def convertBasicTypes(catalystType: DataType, messageDefinition: MessageDefinition#Builder, + label: String, index: Int, protoField: ProtoField): Unit = { + if (sparkToProtoTypeMap.contains(catalystType)) { + messageDefinition.addField(label, 
sparkToProtoTypeMap.get(catalystType).orNull, + protoField.name, index) + } else { + throw new IncompatibleSchemaException(s"Cannot convert SQL type ${catalystType.sql} to " + + s"Proto type, try passing proto Descriptor file path to_proto function") + } + } + + private val sparkToProtoTypeMap = Map[DataType, String](ByteType -> "int32", ShortType -> "int32", + IntegerType -> "int32", DateType -> "int32", LongType -> "int64", BinaryType -> "bytes", + DoubleType -> "double", FloatType -> "float", TimestampType -> "int64", + TimestampNTZType -> "int64", StringType -> "string", BooleanType -> "bool") + + case class ProtoMessage(messageName: String, fieldList: util.ArrayList[ProtoField]) + + case class ProtoField(name: String, catalystType: DataType) + + private[proto] class IncompatibleSchemaException( + msg: String, + ex: Throwable = null) extends Exception(msg, ex) + + private[proto] class UnsupportedProtoTypeException(msg: String) extends Exception(msg) + + private[proto] class UnsupportedProtoValueException(msg: String) extends Exception(msg) +} diff --git a/connector/proto/src/test/resources/protobuf/catalyst_types.desc b/connector/proto/src/test/resources/protobuf/catalyst_types.desc new file mode 100644 index 0000000000000..7a85edcd472ae --- /dev/null +++ b/connector/proto/src/test/resources/protobuf/catalyst_types.desc @@ -0,0 +1,37 @@ + +” +@connector/proto/src/test/resources/protobuf/catalyst_types.protoorg.apache.spark.sql.proto") + +BooleanMsg + bool_type (RboolType"+ + +IntegerMsg + +int32_type (R int32Type", + DoubleMsg + double_type (R +doubleType") +FloatMsg + +float_type (R floatType") +BytesMsg + +bytes_type ( R bytesType", + StringMsg + string_type ( R +stringType". +Person +name ( Rname +age (Rage"n +Bad +col_0 ( Rcol0 +col_1 (Rcol1 +col_2 ( Rcol2 +col_3 (Rcol3 +col_4 (Rcol4"q +Actual +col_0 ( Rcol0 +col_1 (Rcol1 +col_2 (Rcol2 +col_3 (Rcol3 +col_4 (Rcol4BB CatalystTypesbproto3 \ No newline at end of file diff --git a/connector/proto/src/test/resources/protobuf/catalyst_types.proto b/connector/proto/src/test/resources/protobuf/catalyst_types.proto new file mode 100644 index 0000000000000..b0c6aa647535d --- /dev/null +++ b/connector/proto/src/test/resources/protobuf/catalyst_types.proto @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// protoc --java_out=connector/proto/src/test/resources/protobuf/ connector/proto/src/test/resources/protobuf/catalyst_types.proto +// protoc --descriptor_set_out=connector/proto/src/test/resources/protobuf/catalyst_types.desc --java_out=connector/proto/src/test/resources/protobuf/org/apache/spark/sql/proto/ connector/proto/src/test/resources/protobuf/catalyst_types.proto + +syntax = "proto3"; + +package org.apache.spark.sql.proto; +option java_outer_classname = "CatalystTypes"; + +message BooleanMsg { + bool bool_type = 1; +} +message IntegerMsg { + int32 int32_type = 1; +} +message DoubleMsg { + double double_type = 1; +} +message FloatMsg { + float float_type = 1; +} +message BytesMsg { + bytes bytes_type = 1; +} +message StringMsg { + string string_type = 1; +} + +message Person { + string name = 1; + int32 age = 2; +} + +message Bad { + bytes col_0 = 1; + double col_1 = 2; + string col_2 = 3; + float col_3 = 4; + int64 col_4 = 5; +} + +message Actual { + string col_0 = 1; + int32 col_1 = 2; + float col_2 = 3; + bool col_3 = 4; + double col_4 = 5; +} \ No newline at end of file diff --git a/connector/proto/src/test/resources/protobuf/proto_functions_suite.desc b/connector/proto/src/test/resources/protobuf/proto_functions_suite.desc new file mode 100644 index 0000000000000000000000000000000000000000..5fe1dacc6e568646d4b5a3bc3a3af1108fe26ed4 GIT binary patch literal 5060 zcmd5=+iu%N5UoUAhonR?PGJ}*nyM{opnwc0h7lC4lPFeF7qC-MPJlisltep1BvB

A!UL>Y+r-k#K~S1#Ni?u2C-NeEY8J`{`q1yOvd3X?#GL9614bp?oJq8wEVe0d>gdlxgWi2 z#lKF(TJ!fjeA*jN=ab+hh-3dE_{qQY&#vY{?A9PZ9&P6ysQe%AO^Sg6C_f#c31^O5L-S zRR@Nj1PP46Pmdf9b59>%6?c#FV7V3Zz}uC<{@Ja#_Ka z)5S*6`gExwVqLn_aeZC7IJllpmwybnpGp@mm8c?l;Y4aEg`}X93-L&KYe-jr~jDAA7wsC&xh2((68Z zd3xmJo4@73>iF?^C_Cgw=?0>6>Vx@zHpVwlSGlcXaRT;$#il*Bv4a zrf2cTC0+B+EpSq2eB#ePxQx!R&4~+Y>d!}h;%~D$GTV>WIkT+R>TwFxWenX}5?x`_ zp*D92wkSj;!%ER?A8sUb%t|?>VsTPwa&a^AdQd8?oHyOfE6OQ1BWQ_^yknFOYYLm{ z^fGh=tF$sX2m2`B#F_?Z&)=4$B73 zsY;02T$T-NY#|%6xX^~#AUINit4U&MbagfmEt_lIB^#EuC9^?WSpenLr~+8KLjaHX z2%OrMEojxQFza~M>#VDDnaz3-FQ$}=#YvUO8?g8|F_rDq;Tk_|R74$|yHgS-qf`pG z4KXAriN+na(P7hkO2gu$(zNpg+|{Q3p|{F+`A4~($Iv)l?$?x{m@(ZIJu%&J0(4=u z6;lQ~gTY>6SZ%VDp^?GRpbSsp9zXDO_oNi%42m+P=)hK1Q=}aA430YG5V6!1!`lNT zFfs@XN+4r-GeaSRp|HfT+M=irZU%!(8CWb=dk>{3Wl)qTg^uOT8*wr?9Lj;ASZx*5 zi>%j<9A&`MSm^?{5#%YsAvnt=PM@!2Y@tG*M+a8wml9;{z#;_AXXpT%K-t8v#>hRW zYMX09`Fi57LQ#$S03|)+|CNVEY7%5@x4BW3Zn&@Tg99Ie1AGdH&K$nR9IWe-+BxZP zOB?V*h1!h!3z+ovh`kfk+K<-pu55w68o_`5X7m+UFO#=H)P3)>FG%d|He~R*!AJpjpfZ*wiPitzLq)>AFWvGcLDfAehoi;GVv>x}Atm=393|Hegs)uI hY=65hnoC}Lbwwh1JNG;HQW5`|eWr^CCH~p%;lGR(1+V}B literal 0 HcmV?d00001 diff --git a/connector/proto/src/test/resources/protobuf/proto_functions_suite.proto b/connector/proto/src/test/resources/protobuf/proto_functions_suite.proto new file mode 100644 index 0000000000000..8ea2f3e05839b --- /dev/null +++ b/connector/proto/src/test/resources/protobuf/proto_functions_suite.proto @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// To compile and create test class: +// protoc --java_out=connector/proto/src/test/resources/protobuf/ connector/proto/src/test/resources/protobuf/proto_functions_suite.proto +// protoc --descriptor_set_out=connector/proto/src/test/resources/protobuf/proto_functions_suite.desc --java_out=connector/proto/src/test/resources/protobuf/org/apache/spark/sql/proto/ connector/proto/src/test/resources/protobuf/proto_functions_suite.proto + +syntax = "proto3"; + +package org.apache.spark.sql.proto; +option java_outer_classname = "SimpleMessageProtos"; + +message SimpleMessageJavaTypes { + int64 id = 1; + string string_value = 2; + int32 int32_value = 3; + int64 int64_value = 4; + double double_value = 5; + float float_value = 6; + bool bool_value = 7; + bytes bytes_value = 8; +} + +message SimpleMessage { + int64 id = 1; + string string_value = 2; + int32 int32_value = 3; + uint32 uint32_value = 4; + sint32 sint32_value = 5; + fixed32 fixed32_value = 6; + sfixed32 sfixed32_value = 7; + int64 int64_value = 8; + uint64 uint64_value = 9; + sint64 sint64_value = 10; + fixed64 fixed64_value = 11; + sfixed64 sfixed64_value = 12; + double double_value = 13; + float float_value = 14; + bool bool_value = 15; + bytes bytes_value = 16; +} + +message SimpleMessageRepeated { + string key = 1; + string value = 2; + enum NestedEnum { + ESTED_NOTHING = 0; + NESTED_FIRST = 1; + NESTED_SECOND = 2; + } + repeated string rstring_value = 3; + repeated int32 rint32_value = 4; + repeated bool rbool_value = 5; + repeated int64 rint64_value = 6; + repeated float rfloat_value = 7; + repeated double rdouble_value = 8; + repeated bytes rbytes_value = 9; + repeated NestedEnum rnested_enum = 10; +} + +message BasicMessage { + int64 id = 1; + string string_value = 2; + int32 int32_value = 3; + int64 int64_value = 4; + double double_value = 5; + float float_value = 6; + bool bool_value = 7; + bytes bytes_value = 8; +} + +message RepeatedMessage { + repeated BasicMessage basic_message = 1; +} + +message SimpleMessageMap { + string key = 1; + string value = 2; + map string_mapdata = 3; + map int32_mapdata = 4; + map uint32_mapdata = 5; + map sint32_mapdata = 6; + map float32_mapdata = 7; + map sfixed32_mapdata = 8; + map int64_mapdata = 9; + map uint64_mapdata = 10; + map sint64_mapdata = 11; + map fixed64_mapdata = 12; + map sfixed64_mapdata = 13; + map double_mapdata = 14; + map float_mapdata = 15; + map bool_mapdata = 16; + map bytes_mapdata = 17; +} + +message BasicEnumMessage { + enum BasicEnum { + NOTHING = 0; + FIRST = 1; + SECOND = 2; + } +} + +message SimpleMessageEnum { + string key = 1; + string value = 2; + enum NestedEnum { + ESTED_NOTHING = 0; + NESTED_FIRST = 1; + NESTED_SECOND = 2; + } + BasicEnumMessage.BasicEnum basic_enum = 3; + NestedEnum nested_enum = 4; +} + + +message OtherExample { + string other = 1; +} + +message IncludedExample { + string included = 1; + OtherExample other = 2; +} + +message MultipleExample { + IncludedExample included_example = 1; +} + diff --git a/connector/proto/src/test/resources/protobuf/proto_serde_suite.desc b/connector/proto/src/test/resources/protobuf/proto_serde_suite.desc new file mode 100644 index 0000000000000..4a406ab0a6050 --- /dev/null +++ b/connector/proto/src/test/resources/protobuf/proto_serde_suite.desc @@ -0,0 +1,27 @@ + +š +Cconnector/proto/src/test/resources/protobuf/proto_serde_suite.protoorg.apache.spark.sql.proto"A + BasicMessage1 +foo ( 2.org.apache.spark.sql.proto.FooRfoo" +Foo +bar (Rbar"' +MissMatchTypeInRoot +foo (Rfoo"Q +FieldMissingInProto: +foo ( 
2(.org.apache.spark.sql.proto.MissingFieldRfoo"& + MissingField +barFoo (RbarFoo"Y +MissMatchTypeInDeepNested< +top ( 2*.org.apache.spark.sql.proto.TypeMissNestedRtop"H +TypeMissNested6 +foo ( 2$.org.apache.spark.sql.proto.TypeMissRfoo" +TypeMiss +bar (Rbar"\ +FieldMissingInSQLRoot1 +foo ( 2.org.apache.spark.sql.proto.FooRfoo +boo (Rboo"L +FieldMissingInSQLNested1 +foo ( 2.org.apache.spark.sql.proto.BazRfoo") +Baz +bar (Rbar +baz (RbazBBSimpleMessageProtosbproto3 \ No newline at end of file diff --git a/connector/proto/src/test/resources/protobuf/proto_serde_suite.proto b/connector/proto/src/test/resources/protobuf/proto_serde_suite.proto new file mode 100644 index 0000000000000..d535f47b317d1 --- /dev/null +++ b/connector/proto/src/test/resources/protobuf/proto_serde_suite.proto @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// To compile and create test class: +// protoc --java_out=connector/proto/src/test/resources/protobuf/ connector/proto/src/test/resources/protobuf/proto_serde_suite.proto +// protoc --descriptor_set_out=connector/proto/src/test/resources/protobuf/proto_serde_suite.desc --java_out=connector/proto/src/test/resources/protobuf/org/apache/spark/sql/proto/ connector/proto/src/test/resources/protobuf/proto_serde_suite.proto + +syntax = "proto3"; + +package org.apache.spark.sql.proto; +option java_outer_classname = "SimpleMessageProtos"; + +/* Clean Message*/ +message BasicMessage { + Foo foo = 1; +} + +message Foo { + int32 bar = 1; +} + +/* Field Type missMatch in root Message*/ +message MissMatchTypeInRoot { + int64 foo = 1; +} + +/* Field bar missing from proto and Available in SQL*/ +message FieldMissingInProto { + MissingField foo = 1; +} + +message MissingField { + int64 barFoo = 1; +} + +/* Deep-nested field bar type missMatch Message*/ +message MissMatchTypeInDeepNested { + TypeMissNested top = 1; +} + +message TypeMissNested { + TypeMiss foo = 1; +} + +message TypeMiss { + int64 bar = 1; +} + +/* Field boo missing from SQL root, but available in Proto root*/ +message FieldMissingInSQLRoot { + Foo foo = 1; + int32 boo = 2; +} + +/* Field baz missing from SQL nested and available in Proto nested*/ +message FieldMissingInSQLNested { + Baz foo = 1; +} + +message Baz { + int32 bar = 1; + int32 baz = 2; +} \ No newline at end of file diff --git a/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoCatalystDataConversionSuite.scala b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoCatalystDataConversionSuite.scala new file mode 100644 index 0000000000000..a6dca26f7e7ba --- /dev/null +++ b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoCatalystDataConversionSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.proto + +import com.google.protobuf.{ByteString, DynamicMessage, Message} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters} +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.proto.utils.{ProtoOptions, ProtoUtils, SchemaConverters} +import org.apache.spark.sql.sources.{EqualTo, Not} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ProtoCatalystDataConversionSuite extends SparkFunSuite + with SharedSparkSession + with ExpressionEvalHelper { + + private def checkResult(data: Literal, descFilePath: String, + messageName: String, expected: Any): Unit = { + + checkEvaluation( + ProtoDataToCatalyst(CatalystDataToProto(data, Some(descFilePath), + Some(messageName)), descFilePath, messageName, Map.empty), + prepareExpectedResult(expected)) + } + + protected def checkUnsupportedRead(data: Literal, descFilePath: String, + actualSchema: String, badSchema: String): Unit = { + + val binary = CatalystDataToProto(data, Some(descFilePath), Some(actualSchema)) + + intercept[Exception] { + ProtoDataToCatalyst(binary, descFilePath, badSchema, + Map("mode" -> "FAILFAST")).eval() + } + + val expected = { + val protoOptions = ProtoOptions(Map("mode" -> "PERMISSIVE")) + val descriptor = ProtoUtils.buildDescriptor(descFilePath, badSchema) + val expectedSchema = protoOptions.schema.getOrElse(descriptor) + SchemaConverters.toSqlType(expectedSchema).dataType match { + case st: StructType => Row.fromSeq((0 until st.length).map { + _ => null + }) + case _ => null + } + } + + checkEvaluation(ProtoDataToCatalyst(binary, descFilePath, badSchema, + Map("mode" -> "PERMISSIVE")), expected) + } + + protected def prepareExpectedResult(expected: Any): Any = expected match { + // Spark byte and short both map to proto int + case b: Byte => b.toInt + case s: Short => s.toInt + case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) + case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) + case map: MapData => + val keys = new GenericArrayData( + map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + val values = new GenericArrayData( + map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + new 
ArrayBasedMapData(keys, values) + case other => other + } + + private val testingTypes = Seq( + StructType(StructField("bool_type", BooleanType, nullable = false) :: Nil), + StructType(StructField("int32_type", IntegerType, nullable = false) :: Nil), + StructType(StructField("double_type", DoubleType, nullable = false) :: Nil), + StructType(StructField("float_type", FloatType, nullable = false) :: Nil), + StructType(StructField("bytes_type", BinaryType, nullable = false) :: Nil), + StructType(StructField("string_type", StringType, nullable = false) :: Nil), + StructType(StructField("int32_type", ByteType, nullable = false) :: Nil), + StructType(StructField("int32_type", ShortType, nullable = false) :: Nil) + ) + + private val catalystTypesToProtoMessages: Map[DataType, String] = Map( + BooleanType -> "BooleanMsg", + IntegerType -> "IntegerMsg", + DoubleType -> "DoubleMsg", + FloatType -> "FloatMsg", + BinaryType -> "BytesMsg", + StringType -> "StringMsg", + ByteType -> "IntegerMsg", + ShortType -> "IntegerMsg" + ) + + testingTypes.foreach { dt => + val seed = scala.util.Random.nextLong() + val filePath = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + test(s"single $dt with seed $seed") { + val rand = new scala.util.Random(seed) + val data = RandomDataGenerator.forType(dt, rand = rand).get.apply() + val converter = CatalystTypeConverters.createToCatalystConverter(dt) + val input = Literal.create(converter(data), dt) + checkResult(input, filePath, catalystTypesToProtoMessages(dt.fields(0).dataType), + input.eval()) + } + } + + private def checkDeserialization( + descFilePath: String, + messageName: String, + data: Message, + expected: Option[Any], + filters: StructFilters = new NoopFilters): Unit = { + + val descriptor = ProtoUtils.buildDescriptor(descFilePath, messageName) + val dataType = SchemaConverters.toSqlType(descriptor).dataType + + val deserializer = new ProtoDeserializer( + descriptor, + dataType, + false, + RebaseSpec(SQLConf.LegacyBehaviorPolicy.CORRECTED), + filters) + + val dynMsg = DynamicMessage.parseFrom(descriptor, data.toByteArray) + val deserialized = deserializer.deserialize(dynMsg) + expected match { + case None => assert(deserialized.isEmpty) + case Some(d) => + assert(checkResult(d, deserialized.get, dataType, exprNullable = false)) + } + } + + test("Handle unsupported input of message type") { + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val actualSchema = StructType(Seq( + StructField("col_0", StringType, nullable = false), + StructField("col_1", IntegerType, nullable = false), + StructField("col_2", FloatType, nullable = false), + StructField("col_3", BooleanType, nullable = false), + StructField("col_4", DoubleType, nullable = false))) + + val seed = scala.util.Random.nextLong() + withClue(s"create random record with seed $seed") { + val data = RandomDataGenerator.randomRow(new scala.util.Random(seed), actualSchema) + val converter = CatalystTypeConverters.createToCatalystConverter(actualSchema) + val input = Literal.create(converter(data), actualSchema) + checkUnsupportedRead(input, testFileDesc, "Actual", "Bad") + } + } + + test("filter push-down to proto deserializer") { + + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val sqlSchema = new StructType() + .add("name", "string") + .add("age", "int") + + val descriptor = ProtoUtils.buildDescriptor(testFileDesc, "Person") + val dynamicMessage = DynamicMessage.newBuilder(descriptor) + 
.setField(descriptor.findFieldByName("name"), "Maxim") + .setField(descriptor.findFieldByName("age"), 39) + .build() + + val expectedRow = Some(InternalRow(UTF8String.fromString("Maxim"), 39)) + checkDeserialization(testFileDesc, "Person", dynamicMessage, expectedRow) + checkDeserialization( + testFileDesc, + "Person", + dynamicMessage, + expectedRow, + new OrderedFilters(Seq(EqualTo("age", 39)), sqlSchema)) + + checkDeserialization( + testFileDesc, + "Person", + dynamicMessage, + None, + new OrderedFilters(Seq(Not(EqualTo("name", "Maxim"))), sqlSchema)) + } + + test("ProtoDeserializer with binary type") { + + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace( + "file:/", "/") + val bb = java.nio.ByteBuffer.wrap(Array[Byte](97, 48, 53)) + + val descriptor = ProtoUtils.buildDescriptor(testFileDesc, "BytesMsg") + + val dynamicMessage = DynamicMessage.newBuilder(descriptor) + .setField(descriptor.findFieldByName("bytes_type"), ByteString.copyFrom(bb)) + .build() + + val expected = InternalRow(Array[Byte](97, 48, 53)) + checkDeserialization(testFileDesc, "BytesMsg", dynamicMessage, Some(expected)) + } +} diff --git a/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoFunctionsSuite.scala b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoFunctionsSuite.scala new file mode 100644 index 0000000000000..0e68333039030 --- /dev/null +++ b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoFunctionsSuite.scala @@ -0,0 +1,540 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.proto + +import com.google.protobuf.{ByteString, Descriptors, DynamicMessage} + +import org.apache.spark.SparkException +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.{lit, struct} +import org.apache.spark.sql.proto.utils.ProtoUtils +import org.apache.spark.sql.test.SharedSparkSession + +class ProtoFunctionsSuite extends QueryTest with SharedSparkSession with Serializable { + + import testImplicits._ + + val testFileDesc = testFile("protobuf/proto_functions_suite.desc").replace("file:/", "/") + + test("roundtrip in to_proto and from_proto - struct") { + val df = spark.range(10).select( + struct( + $"id", + $"id".cast("string").as("string_value"), + $"id".cast("int").as("int32_value"), + $"id".cast("int").as("uint32_value"), + $"id".cast("int").as("sint32_value"), + $"id".cast("int").as("fixed32_value"), + $"id".cast("int").as("sfixed32_value"), + $"id".cast("long").as("int64_value"), + $"id".cast("long").as("uint64_value"), + $"id".cast("long").as("sint64_value"), + $"id".cast("long").as("fixed64_value"), + $"id".cast("long").as("sfixed64_value"), + $"id".cast("double").as("double_value"), + lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"), + lit(true).as("bool_value"), + lit("0".getBytes).as("bytes_value") + ).as("SimpleMessage") + ) + val protoStructDF = df.select(functions.to_proto($"SimpleMessage", testFileDesc, + "SimpleMessage").as("proto")) + val actualDf = protoStructDF.select(functions.from_proto($"proto", testFileDesc, + "SimpleMessage").as("proto.*")) + checkAnswer(actualDf, df) + } + + test("roundtrip in to_proto(without descriptor params) and from_proto - struct") { + val df = spark.range(10).select( + struct( + $"id", + $"id".cast("string").as("string_value"), + $"id".cast("int").as("int32_value"), + $"id".cast("long").as("int64_value"), + $"id".cast("double").as("double_value"), + lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"), + lit(true).as("bool_value"), + lit("0".getBytes).as("bytes_value") + ).as("SimpleMessageJavaTypes") + ) + val protoStructDF = df.select(functions.to_proto($"SimpleMessageJavaTypes").as("proto")) + val actualDf = protoStructDF.select(functions.from_proto($"proto", testFileDesc, + "SimpleMessageJavaTypes").as("proto.*")) + checkAnswer(actualDf, df) + } + + test("roundtrip in from_proto and to_proto - Repeated") { + val descriptor = ProtoUtils.buildDescriptor(testFileDesc, "SimpleMessageRepeated") + + val dynamicMessage = DynamicMessage.newBuilder(descriptor) + .setField(descriptor.findFieldByName("key"), "key") + .setField(descriptor.findFieldByName("value"), "value") + .addRepeatedField(descriptor.findFieldByName("rbool_value"), false) + .addRepeatedField(descriptor.findFieldByName("rbool_value"), true) + .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092092.654D) + .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092093.654D) + .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10903.0f) + .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10902.0f) + .addRepeatedField(descriptor.findFieldByName("rnested_enum"), + descriptor.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING")) + .addRepeatedField(descriptor.findFieldByName("rnested_enum"), + descriptor.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST")) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + 
"SimpleMessageRepeated").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc, + "SimpleMessageRepeated").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "SimpleMessageRepeated").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_proto and to_proto - Repeated Message Once") { + val repeatedMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "RepeatedMessage") + val basicMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage = DynamicMessage.newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField(basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer")) + .build() + + val dynamicMessage = DynamicMessage.newBuilder(repeatedMessageDesc) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + "RepeatedMessage").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc, + "RepeatedMessage").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "RepeatedMessage").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_proto and to_proto(without descriptor params) - Repeated Message Once") { + val repeatedMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "RepeatedMessage") + val basicMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage = DynamicMessage.newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField(basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer")) + .build() + + val dynamicMessage = DynamicMessage.newBuilder(repeatedMessageDesc) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + "RepeatedMessage").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "RepeatedMessage").as("value_to_from")) + 
checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_proto and to_proto - Repeated Message Twice") { + val repeatedMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "RepeatedMessage") + val basicMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage1 = DynamicMessage.newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value1") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField(basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer1")) + .build() + val basicMessage2 = DynamicMessage.newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1112L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value2") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12346) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D) + .setField(basicMessageDesc.findFieldByName("float_value"), 10903.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), false) + .setField(basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer2")) + .build() + + val dynamicMessage = DynamicMessage.newBuilder(repeatedMessageDesc) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage1) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage2) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + "RepeatedMessage").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc, + "RepeatedMessage").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "RepeatedMessage").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_proto and to_proto(without descriptor params) - Repeated Message Twice") { + val repeatedMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "RepeatedMessage") + val basicMessageDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage1 = DynamicMessage.newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value1") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField(basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer1")) + .build() + val basicMessage2 = DynamicMessage.newBuilder(basicMessageDesc) + 
.setField(basicMessageDesc.findFieldByName("id"), 1112L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value2") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12346) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D) + .setField(basicMessageDesc.findFieldByName("float_value"), 10903.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), false) + .setField(basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer2")) + .build() + + val dynamicMessage = DynamicMessage.newBuilder(repeatedMessageDesc) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage1) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage2) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + "RepeatedMessage").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "RepeatedMessage").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_proto and to_proto - Map") { + val messageMapDesc = ProtoUtils.buildDescriptor(testFileDesc, "SimpleMessageMap") + + val mapStr1 = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("StringMapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"), + "key value1") + .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"), + "value value2") + .build() + val mapStr2 = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("StringMapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"), + "key value2") + .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"), + "value value2") + .build() + val mapInt64 = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("Int64MapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("key"), + 0x90000000000L) + .setField(messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("value"), + 0x90000000001L) + .build() + val mapInt32 = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("Int32MapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("key"), + 12345) + .setField(messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("value"), + 54321) + .build() + val mapFloat = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("FloatMapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("key"), + "float key") + .setField(messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("value"), + 109202.234F) + .build() + val mapDouble = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("DoubleMapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("key") + , "double key") + .setField(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("value") + , 109202.12D) + .build() + val mapBool = 
DynamicMessage.newBuilder(messageMapDesc.findNestedTypeByName("BoolMapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("key"), + true) + .setField(messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("value"), + false) + .build() + + val dynamicMessage = DynamicMessage.newBuilder(messageMapDesc) + .setField(messageMapDesc.findFieldByName("key"), "key") + .setField(messageMapDesc.findFieldByName("value"), "value") + .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr1) + .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr2) + .addRepeatedField(messageMapDesc.findFieldByName("int64_mapdata"), mapInt64) + .addRepeatedField(messageMapDesc.findFieldByName("int32_mapdata"), mapInt32) + .addRepeatedField(messageMapDesc.findFieldByName("float_mapdata"), mapFloat) + .addRepeatedField(messageMapDesc.findFieldByName("double_mapdata"), mapDouble) + .addRepeatedField(messageMapDesc.findFieldByName("bool_mapdata"), mapBool) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + "SimpleMessageMap").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc, + "SimpleMessageMap").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "SimpleMessageMap").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_proto and to_proto(without descriptor params) - Map") { + val messageMapDesc = ProtoUtils.buildDescriptor(testFileDesc, "SimpleMessageMap") + + val mapStr1 = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("StringMapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"), + "key value1") + .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"), + "value value2") + .build() + val mapStr2 = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("StringMapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"), + "key value2") + .setField(messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"), + "value value2") + .build() + val mapInt64 = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("Int64MapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("key"), + 0x90000000000L) + .setField(messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("value"), + 0x90000000001L) + .build() + val mapInt32 = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("Int32MapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("key"), + 12345) + .setField(messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("value"), + 54321) + .build() + val mapFloat = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("FloatMapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("key"), + "float key") + .setField(messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("value"), + 109202.234F) + .build() + val mapDouble = DynamicMessage.newBuilder( + messageMapDesc.findNestedTypeByName("DoubleMapdataEntry")) + 
.setField(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("key") + , "double key") + .setField(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("value") + , 109202.12D) + .build() + val mapBool = DynamicMessage.newBuilder(messageMapDesc.findNestedTypeByName("BoolMapdataEntry")) + .setField(messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("key"), + true) + .setField(messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("value"), + false) + .build() + + val dynamicMessage = DynamicMessage.newBuilder(messageMapDesc) + .setField(messageMapDesc.findFieldByName("key"), "key") + .setField(messageMapDesc.findFieldByName("value"), "value") + .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr1) + .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr2) + .addRepeatedField(messageMapDesc.findFieldByName("int64_mapdata"), mapInt64) + .addRepeatedField(messageMapDesc.findFieldByName("int32_mapdata"), mapInt32) + .addRepeatedField(messageMapDesc.findFieldByName("float_mapdata"), mapFloat) + .addRepeatedField(messageMapDesc.findFieldByName("double_mapdata"), mapDouble) + .addRepeatedField(messageMapDesc.findFieldByName("bool_mapdata"), mapBool) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + "SimpleMessageMap").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "SimpleMessageMap").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_proto and to_proto - Enum") { + val messageEnumDesc = ProtoUtils.buildDescriptor(testFileDesc, "SimpleMessageEnum") + val basicEnumDesc = ProtoUtils.buildDescriptor(testFileDesc, "BasicEnumMessage") + + val dynamicMessage = DynamicMessage.newBuilder(messageEnumDesc) + .setField(messageEnumDesc.findFieldByName("key"), "key") + .setField(messageEnumDesc.findFieldByName("value"), "value") + .setField(messageEnumDesc.findFieldByName("nested_enum"), + messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING")) + .setField(messageEnumDesc.findFieldByName("nested_enum"), + messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST")) + .setField(messageEnumDesc.findFieldByName("basic_enum"), + basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("FIRST")) + .setField(messageEnumDesc.findFieldByName("basic_enum"), + basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("NOTHING")) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + "SimpleMessageEnum").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc, + "SimpleMessageEnum").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "SimpleMessageEnum").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_proto and to_proto - Multiple Message") { + val messageMultiDesc = ProtoUtils.buildDescriptor(testFileDesc, "MultipleExample") + val messageIncludeDesc = ProtoUtils.buildDescriptor(testFileDesc, "IncludedExample") + val 
messageOtherDesc = ProtoUtils.buildDescriptor(testFileDesc, "OtherExample") + + val otherMessage = DynamicMessage.newBuilder(messageOtherDesc) + .setField(messageOtherDesc.findFieldByName("other"), "other value") + .build() + + val includeMessage = DynamicMessage.newBuilder(messageIncludeDesc) + .setField(messageIncludeDesc.findFieldByName("included"), "included value") + .setField(messageIncludeDesc.findFieldByName("other"), otherMessage) + .build() + + val dynamicMessage = DynamicMessage.newBuilder(messageMultiDesc) + .setField(messageMultiDesc.findFieldByName("included_example"), includeMessage) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + "MultipleExample").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", testFileDesc, + "MultipleExample").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "MultipleExample").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_proto and to_proto(without descriptor params) - Multiple Message") { + val messageMultiDesc = ProtoUtils.buildDescriptor(testFileDesc, "MultipleExample") + val messageIncludeDesc = ProtoUtils.buildDescriptor(testFileDesc, "IncludedExample") + val messageOtherDesc = ProtoUtils.buildDescriptor(testFileDesc, "OtherExample") + + val otherMessage = DynamicMessage.newBuilder(messageOtherDesc) + .setField(messageOtherDesc.findFieldByName("other"), "other value") + .build() + + val includeMessage = DynamicMessage.newBuilder(messageIncludeDesc) + .setField(messageIncludeDesc.findFieldByName("included"), "included value") + .setField(messageIncludeDesc.findFieldByName("other"), otherMessage) + .build() + + val dynamicMessage = DynamicMessage.newBuilder(messageMultiDesc) + .setField(messageMultiDesc.findFieldByName("included_example"), includeMessage) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select(functions.from_proto($"value", testFileDesc, + "MultipleExample").as("value_from")) + val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from").as("value_to")) + val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", testFileDesc, + "MultipleExample").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in to_proto and from_proto - with null") { + val df = spark.range(10).select( + struct( + $"id", + lit(null).cast("string").as("string_value"), + $"id".cast("int").as("int32_value"), + $"id".cast("int").as("uint32_value"), + $"id".cast("int").as("sint32_value"), + $"id".cast("int").as("fixed32_value"), + $"id".cast("int").as("sfixed32_value"), + $"id".cast("long").as("int64_value"), + $"id".cast("long").as("uint64_value"), + $"id".cast("long").as("sint64_value"), + $"id".cast("long").as("fixed64_value"), + $"id".cast("long").as("sfixed64_value"), + $"id".cast("double").as("double_value"), + lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"), + lit(true).as("bool_value"), + lit("0".getBytes).as("bytes_value") + ).as("SimpleMessage") + ) + + intercept[SparkException] { + df.select(functions.to_proto($"SimpleMessage", testFileDesc, + "SimpleMessage").as("proto")).collect() + } + } + + test("from_proto filter to_proto") { + val basicMessageDesc = 
ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage = DynamicMessage.newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "slam") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0D) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField(basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer")) + .build() + + + val df = Seq(basicMessage.toByteArray).toDF("value") + val resultFrom = df.select(functions.from_proto($"value", testFileDesc, + "BasicMessage") as 'sample) + .where("sample.string_value == \"slam\"") + + val resultToFrom = resultFrom.select(functions.to_proto($"sample") as 'value) + .select(functions.from_proto($"value", testFileDesc, + "BasicMessage") as 'sample) + .where("sample.string_value == \"slam\"") + + assert(resultFrom.except(resultToFrom).isEmpty) + } + + def buildDescriptor(desc: String): Descriptors.Descriptor = { + val descriptor = ProtoUtils.parseFileDescriptor(testFileDesc).getMessageTypes().get(0) + descriptor + } +} diff --git a/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoSerdeSuite.scala b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoSerdeSuite.scala new file mode 100644 index 0000000000000..b02c52487b480 --- /dev/null +++ b/connector/proto/src/test/scala/org/apache/spark/sql/proto/ProtoSerdeSuite.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.proto + +import com.google.protobuf.Descriptors.Descriptor +import com.google.protobuf.DynamicMessage + +import org.apache.spark.sql.catalyst.NoopFilters +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy.CORRECTED +import org.apache.spark.sql.proto.utils.ProtoUtils +import org.apache.spark.sql.proto.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StructType} + +/** + * Tests for [[ProtoSerializer]] and [[ProtoDeserializer]] + * with a more specific focus on those classes. 
+ */ +class ProtoSerdeSuite extends SharedSparkSession { + + import ProtoSerdeSuite.MatchType._ + import ProtoSerdeSuite._ + + val testFileDesc = testFile("protobuf/proto_serde_suite.desc").replace("file:/", "/") + + test("Test basic conversion") { + withFieldMatchType { fieldMatch => + val (top, nest) = fieldMatch match { + case BY_NAME => ("foo", "bar") + case BY_POSITION => ("NOTfoo", "NOTbar") + } + val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val dynamicMessageFoo = DynamicMessage.newBuilder( + protoFile.getFile.findMessageTypeByName("Foo")).setField( + protoFile.getFile.findMessageTypeByName("Foo").findFieldByName("bar"), + 10902).build() + + val dynamicMessage = DynamicMessage.newBuilder(protoFile) + .setField(protoFile.findFieldByName("foo"), dynamicMessageFoo).build() + + val serializer = Serializer.create(CATALYST_STRUCT, protoFile, fieldMatch) + val deserializer = Deserializer.create(CATALYST_STRUCT, protoFile, fieldMatch) + + assert(serializer.serialize(deserializer.deserialize(dynamicMessage).get) === dynamicMessage) + } + } + + test("Fail to convert with field type mismatch") { + val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "MissMatchTypeInRoot") + + withFieldMatchType { fieldMatch => + assertFailedConversionMessage(protoFile, Deserializer, fieldMatch, + "Cannot convert Proto field 'foo' to SQL field 'foo' because schema is incompatible " + + s"(protoType = org.apache.spark.sql.proto.MissMatchTypeInRoot.foo " + + s"LABEL_OPTIONAL LONG INT64, sqlType = ${CATALYST_STRUCT.head.dataType.sql})" + .stripMargin) + + assertFailedConversionMessage(protoFile, Serializer, fieldMatch, + s"Cannot convert SQL field 'foo' to Proto field 'foo' because schema is incompatible " + + s"""(sqlType = ${CATALYST_STRUCT.head.dataType.sql}, protoType = LONG)""") + } + } + + test("Fail to convert with missing nested Proto fields") { + val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "FieldMissingInProto") + + val nonnullCatalyst = new StructType() + .add("foo", new StructType().add("bar", IntegerType, nullable = false)) + // Positional matching will work fine with the name change, so add a new field + val extraNonnullCatalyst = new StructType().add("foo", + new StructType().add("bar", IntegerType).add("baz", IntegerType, nullable = false)) + + // deserialize should have no issues when 'bar' is nullable=false but fail when nullable=true + Deserializer.create(CATALYST_STRUCT, protoFile, BY_NAME) + assertFailedConversionMessage(protoFile, Deserializer, BY_NAME, + "Cannot find field 'foo.bar' in Proto schema", + nonnullCatalyst) + assertFailedConversionMessage(protoFile, Deserializer, BY_POSITION, + "Cannot find field at position 1 of field 'foo' from Proto schema (using positional" + + " matching)", + extraNonnullCatalyst) + + // serialize fails whether or not 'bar' is nullable + val byNameMsg = "Cannot find field 'foo.bar' in Proto schema" + assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg) + assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg, nonnullCatalyst) + assertFailedConversionMessage(protoFile, Serializer, BY_POSITION, + "Cannot find field at position 1 of field 'foo' from Proto schema (using positional" + + " matching)", + extraNonnullCatalyst) + } + + test("Fail to convert with deeply nested field type mismatch") { + val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "MissMatchTypeInDeepNested") + val catalyst = new StructType().add("top", CATALYST_STRUCT) + + withFieldMatchType { fieldMatch 
=> + assertFailedConversionMessage(protoFile, Deserializer, fieldMatch, + s"Cannot convert Proto field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " + + s"is incompatible (protoType = org.apache.spark.sql.proto.TypeMiss.bar " + + s"LABEL_OPTIONAL LONG INT64, sqlType = INT)".stripMargin, + catalyst) + + assertFailedConversionMessage(protoFile, Serializer, fieldMatch, + "Cannot convert SQL field 'top.foo.bar' to Proto field 'top.foo.bar' because schema is " + + """incompatible (sqlType = INT, protoType = LONG)""", + catalyst) + } + } + + test("Fail to convert with missing Catalyst fields") { + val protoFile = ProtoUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot") + + // serializing with extra fails if extra field is missing in SQL Schema + assertFailedConversionMessage(protoFile, Serializer, BY_NAME, + "Found field 'boo' in Proto schema but there is no match in the SQL schema") + assertFailedConversionMessage(protoFile, Serializer, BY_POSITION, + "Found field 'boo' at position 1 of top-level record from Proto schema but there is no " + + "match in the SQL schema at top-level record (using positional matching)") + + /* deserializing should work regardless of whether the extra field is missing + in SQL Schema or not */ + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) + + + val protoNestedFile = ProtoUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLNested") + + // serializing with extra fails if extra field is missing in SQL Schema + assertFailedConversionMessage(protoNestedFile, Serializer, BY_NAME, + "Found field 'foo.baz' in Proto schema but there is no match in the SQL schema") + assertFailedConversionMessage(protoNestedFile, Serializer, BY_POSITION, + s"Found field 'baz' at position 1 of field 'foo' from Proto schema but there is no match " + + s"in the SQL schema at field 'foo' (using positional matching)") + + /* deserializing should work regardless of whether the extra field is missing + in SQL Schema or not */ + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) + } + + /** + * Attempt to convert `catalystSchema` to `protoSchema` (or vice-versa if `deserialize` is true), + * assert that it fails, and assert that the _cause_ of the thrown exception has a message + * matching `expectedCauseMessage`. + */ + private def assertFailedConversionMessage(protoSchema: Descriptor, + serdeFactory: SerdeFactory[_], + fieldMatchType: MatchType, + expectedCauseMessage: String, + catalystSchema: StructType = CATALYST_STRUCT): Unit = { + val e = intercept[IncompatibleSchemaException] { + serdeFactory.create(catalystSchema, protoSchema, fieldMatchType) + } + val expectMsg = serdeFactory match { + case Deserializer => + s"Cannot convert Proto type ${protoSchema.getName} to SQL type ${catalystSchema.sql}." + case Serializer => + s"Cannot convert SQL type ${catalystSchema.sql} to Proto type ${protoSchema.getName}." 
+ } + + assert(e.getMessage === expectMsg) + assert(e.getCause.getMessage === expectedCauseMessage) + } + + def withFieldMatchType(f: MatchType => Unit): Unit = { + MatchType.values.foreach { fieldMatchType => + withClue(s"fieldMatchType == $fieldMatchType") { + f(fieldMatchType) + } + } + } +} + + +object ProtoSerdeSuite { + + private val CATALYST_STRUCT = + new StructType().add("foo", new StructType().add("bar", IntegerType)) + + /** + * Specifier for type of field matching to be used for easy creation of tests that do both + * positional and by-name field matching. + */ + private object MatchType extends Enumeration { + type MatchType = Value + val BY_NAME, BY_POSITION = Value + + def isPositional(fieldMatchType: MatchType): Boolean = fieldMatchType == BY_POSITION + } + + import MatchType._ + + /** + * Specifier for type of serde to be used for easy creation of tests that do both + * serialization and deserialization. + */ + private sealed trait SerdeFactory[T] { + def create(sqlSchema: StructType, descriptor: Descriptor, fieldMatchType: MatchType): T + } + + private object Serializer extends SerdeFactory[ProtoSerializer] { + override def create(sql: StructType, descriptor: Descriptor, matchType: MatchType): + ProtoSerializer = new ProtoSerializer(sql, descriptor, false, isPositional(matchType)) + } + + private object Deserializer extends SerdeFactory[ProtoDeserializer] { + override def create(sql: StructType, descriptor: Descriptor, matchType: MatchType): + ProtoDeserializer = new ProtoDeserializer( descriptor, sql, isPositional(matchType), + RebaseSpec(CORRECTED), new NoopFilters) + } +} diff --git a/pom.xml b/pom.xml index 5fbd82ad57add..27906a7b9162e 100644 --- a/pom.xml +++ b/pom.xml @@ -101,6 +101,7 @@ connector/kafka-0-10-sql connector/avro connect + connector/proto diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 6ffc1d880c5d1..49f44474a2748 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, proto) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro", "proto" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = @@ -433,6 +433,9 @@ object SparkBuild extends PomBuild { enable(SparkConnect.settings)(connect) + /* Connector/proto settings */ + enable(SparkProto.settings)(proto) + // SPARK-14738 - Remove docker tests from main Spark build // enable(DockerIntegrationTests.settings)(dockerIntegrationTests) @@ -651,6 +654,7 @@ object SparkConnect { ShadeRule.rename("com.google.common.**" -> "org.sparkproject.connect.guava.@1").inAll, ShadeRule.rename("com.google.thirdparty.**" -> "org.sparkproject.connect.guava.@1").inAll, ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll, + ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connector.proto.protobuf.@1").inAll, ), (assembly / assemblyMergeStrategy) := { @@ -662,6 +666,48 @@ object SparkConnect { ) } +object SparkProto { + + import BuildCommons.protoVersion + + private val shadePrefix = "org.sparkproject.spark-proto" + val shadeJar 
= taskKey[Unit]("Shade the Jars") + + lazy val settings = Seq( + // Setting version for the protobuf compiler. This has to be propagated to every sub-project + // even if the project is not using it. + PB.protocVersion := BuildCommons.protoVersion, + + // For some reason the resolution from the imported Maven build does not work for some + // of these dependendencies that we need to shade later on. + libraryDependencies ++= Seq( + "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf" + ), + + dependencyOverrides ++= Seq( + "com.google.protobuf" % "protobuf-java" % protoVersion + ), + + (Compile / PB.targets) := Seq( + PB.gens.java -> (Compile / sourceManaged).value, + ), + + (assembly / test) := false, + + (assembly / logLevel) := Level.Info, + + (assembly / assemblyShadeRules) := Seq( + ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.spark-proto.protobuf.@1").inAll, + ), + + (assembly / assemblyMergeStrategy) := { + case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard + // Drop all proto files that are not needed as artifacts of the build. + case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard + case _ => MergeStrategy.first + }, + ) +} object Unsafe { lazy val settings = Seq( // This option is needed to suppress warnings from sun.misc.Unsafe usage @@ -1107,10 +1153,10 @@ object Unidoc { (ScalaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010, connect), + yarn, tags, streamingKafka010, sqlKafka010, connect, proto), (JavaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010, connect), + yarn, tags, streamingKafka010, sqlKafka010, connect, proto), (ScalaUnidoc / unidoc / unidocAllClasspaths) := { ignoreClasspaths((ScalaUnidoc / unidoc / unidocAllClasspaths).value) @@ -1196,6 +1242,7 @@ object CopyDependencies { // produce the shaded Jar which happens automatically in the case of Maven. // Later, when the dependencies are copied, we manually copy the shaded Jar only. 
val fid = (LocalProject("connect")/assembly).value + val fidProto = (LocalProject("proto")/assembly).value (Compile / dependencyClasspath).value.map(_.data) .filter { jar => jar.isFile() } @@ -1208,6 +1255,9 @@ object CopyDependencies { if (jar.getName.contains("spark-connect") && !SbtPomKeys.profiles.value.contains("noshade-connect")) { Files.copy(fid.toPath, destJar.toPath) + } else if (jar.getName.contains("spark-proto") && + !SbtPomKeys.profiles.value.contains("noshade-proto")) { + Files.copy(fid.toPath, destJar.toPath) } else { Files.copy(jar.toPath(), destJar.toPath()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bc6700a3b5616..adb9b91b3e5e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3673,6 +3673,22 @@ object SQLConf { .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + val PROTO_REBASE_MODE_IN_READ = + buildConf("spark.sql.proto.datetimeRebaseModeInRead") + .internal() + .doc("When LEGACY, Spark will rebase dates/timestamps from the legacy hybrid (Julian + " + + "Gregorian) calendar to Proleptic Gregorian calendar when reading Proto Events. " + + "When CORRECTED, Spark will not do rebase and read the dates/timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the reading if it sees " + + "ancient dates/timestamps that are ambiguous between the two calendars. This config is " + + "only effective if the writer info (like Spark, Hive) of the Proto events is unknown.") + .version("3.4.0") + .withAlternative("spark.sql.legacy.proto.datetimeRebaseModeInRead") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + val SCRIPT_TRANSFORMATION_EXIT_TIMEOUT = buildConf("spark.sql.scriptTransformation.exitTimeoutInSeconds") .internal() From dd6a8a8cc695977e19493d6b8d66c7b5c7cc7c49 Mon Sep 17 00:00:00 2001 From: SandishKumarHN Date: Thu, 29 Sep 2022 01:16:36 -0700 Subject: [PATCH 2/3] CI issues --- project/SparkBuild.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 49f44474a2748..184e34b9a0c14 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -59,7 +59,7 @@ object BuildCommons { ) = Seq( "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", "tags", "sketch", "kvstore" - ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect) + ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect) ++ Seq(proto) val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, sparkGangliaLgpl, streamingKinesisAsl, @@ -390,7 +390,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, tokenProviderKafka010, sqlKafka010, connect + unsafe, tags, tokenProviderKafka010, sqlKafka010, connect, proto ).contains(x) } From b597f36e11cb104add4b267be2c62e8e6d522336 Mon Sep 17 00:00:00 2001 From: SandishKumarHN Date: Thu, 29 Sep 2022 08:56:43 -0700 Subject: [PATCH 3/3] CI issues --- .github/workflows/build_and_test.yml 
| 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index d53133e09b33a..b0847187dffdd 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -146,7 +146,7 @@ jobs: - >- core, unsafe, kvstore, avro, network-common, network-shuffle, repl, launcher, - examples, sketch, graphx, proto + examples, sketch, graphx - >- catalyst, hive-thriftserver - >-
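
A minimal usage sketch of the API this patch adds, assuming the spark-proto module is on the classpath. The descriptor path, message name, and input DataFrame below are illustrative placeholders; only functions.from_proto, functions.to_proto, and the spark.sql.proto.datetimeRebaseModeInRead config come from this patch.

    // Sketch only: parse serialized protobuf bytes into a Catalyst struct and
    // serialize the struct back, mirroring the round-trip tests in ProtoFunctionsSuite.
    // "/tmp/events.desc" and "SimpleMessage" are placeholders for a descriptor set
    // (generated with `protoc --descriptor_set_out`) and a message defined in it.
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.proto.functions

    val spark = SparkSession.builder().appName("proto-roundtrip-sketch").getOrCreate()
    import spark.implicits._

    // Config added in this patch; CORRECTED skips rebasing of ancient dates/timestamps on read.
    spark.conf.set("spark.sql.proto.datetimeRebaseModeInRead", "CORRECTED")

    val descFilePath = "/tmp/events.desc"            // placeholder descriptor file path
    val binaryDF = Seq(Array[Byte]()).toDF("value")  // placeholder for a column of serialized messages

    // Binary -> struct, then struct -> binary, using the same descriptor and message name.
    val parsedDF = binaryDF.select(
      functions.from_proto($"value", descFilePath, "SimpleMessage").as("event"))
    val reserializedDF = parsedDF.select(
      functions.to_proto($"event", descFilePath, "SimpleMessage").as("value"))

The test suites above also exercise a single-argument to_proto(column) variant that re-serializes without descriptor parameters.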