diff --git a/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala
index a1545721fa11..545ada0f82d7 100644
--- a/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala
+++ b/connector/proto/src/main/scala/org/apache/spark/sql/proto/ProtoDeserializer.scala
@@ -78,6 +78,42 @@ private[sql] class ProtoDeserializer(
 
   def deserialize(data: Any): Option[Any] = converter(data)
 
+  /**
+   * Creates a writer that converts a collection-valued proto field into a Catalyst array and
+   * stores it at the given ordinal. Each element is converted by the writer `newWriter` returns
+   * for `elementType`. A null element is stored as null when `containsNull` is true and causes
+   * a RuntimeException otherwise. `catalystType` is accepted but never read by this method.
+   */
+  private def newArrayWriter(
+      protoField: FieldDescriptor,
+      catalystType: DataType,
+      protoPath: Seq[String],
+      catalystPath: Seq[String],
+      elementType: DataType,
+      containsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = {
+    val protoElementPath = protoPath :+ "element"
+    val writeItem = newWriter(protoField, elementType, protoElementPath, catalystPath :+ "element")
+    (updater, ordinal, value) =>
+      val elements = value.asInstanceOf[java.util.Collection[Any]]
+      val arrayData = createArrayData(elementType, elements.size())
+      val arrayUpdater = new ArrayDataUpdater(arrayData)
+      val it = elements.iterator()
+      var idx = 0
+      while (it.hasNext) {
+        val item = it.next()
+        if (item != null) {
+          writeItem(arrayUpdater, idx, item)
+        } else if (containsNull) {
+          arrayUpdater.setNullAt(idx)
+        } else {
+          throw new RuntimeException(
+            s"Array value at path ${toFieldStr(protoElementPath)} is not allowed to be null")
+        }
+        idx += 1
+      }
+      updater.set(ordinal, arrayData)
+  }
+
   /**
    * Creates a writer to write proto values to Catalyst values at the given ordinal with the given
    * updater.
@@ -102,32 +138,50 @@ private[sql] class ProtoDeserializer( case (BOOLEAN, BooleanType) => (updater, ordinal, value) => updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + case (BOOLEAN, ArrayType(BooleanType, containsNull)) => + newArrayWriter(protoType, catalystType, protoPath, + catalystPath, BooleanType, containsNull) + case (INT, IntegerType) => (updater, ordinal, value) => updater.setInt(ordinal, value.asInstanceOf[Int]) + case (INT, ArrayType(IntegerType, containsNull)) => + newArrayWriter(protoType, catalystType, protoPath, + catalystPath, IntegerType, containsNull) + case (INT, DateType) => (updater, ordinal, value) => updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) case (LONG, LongType) => (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) + case (LONG, ArrayType(LongType, containsNull)) => + newArrayWriter(protoType, catalystType, protoPath, + catalystPath, LongType, containsNull) + case (FLOAT, FloatType) => (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float]) + case (FLOAT, ArrayType(FloatType, containsNull)) => + newArrayWriter(protoType, catalystType, protoPath, + catalystPath, FloatType, containsNull) + case (DOUBLE, DoubleType) => (updater, ordinal, value) => updater.setDouble(ordinal, value.asInstanceOf[Double]) + case (DOUBLE, ArrayType(DoubleType, containsNull)) => + newArrayWriter(protoType, catalystType, protoPath, + catalystPath, DoubleType, containsNull) + case (STRING, StringType) => (updater, ordinal, value) => val str = value match { case s: String => UTF8String.fromString(s) } updater.set(ordinal, str) - case (STRING, ArrayType(StringType, containsNull)) => (updater, ordinal, value) => - val str = value match { - case s: String => UTF8String.fromString(s) - } - updater.set(ordinal, str) + case (STRING, ArrayType(StringType, containsNull)) => + newArrayWriter(protoType, catalystType, protoPath, + catalystPath, StringType, containsNull) case (BYTE_STRING, 
BinaryType) => (updater, ordinal, value) => val byte_array = value match { @@ -136,51 +190,30 @@ private[sql] class ProtoDeserializer( } updater.set(ordinal, byte_array) - case (BYTE_STRING, ArrayType(BinaryType, containsNull)) => (updater, ordinal, value) => - val byte_array = value match { - case s: ByteString => s.toByteArray - case _ => throw new Exception("Invalid ByteString format") - } - updater.set(ordinal, byte_array) - - case (DOUBLE, ArrayType(DoubleType, containsNull)) => (updater, ordinal, value) => - updater.setDouble(ordinal, value.asInstanceOf[Double]) - - case (FLOAT, ArrayType(FloatType, containsNull)) => (updater, ordinal, value) => - updater.setFloat(ordinal, value.asInstanceOf[Float]) - - case (LONG, ArrayType(LongType, containsNull)) => (updater, ordinal, value) => - updater.setLong(ordinal, value.asInstanceOf[Long]) - - case (INT, ArrayType(DateType, containsNull)) => (updater, ordinal, value) => - updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) - - case (INT, ArrayType(IntegerType, containsNull)) => (updater, ordinal, value) => - updater.setInt(ordinal, value.asInstanceOf[Int]) - - case (BOOLEAN, ArrayType(BooleanType, containsNull)) => (updater, ordinal, value) => - updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + case (BYTE_STRING, ArrayType(BinaryType, containsNull)) => + newArrayWriter(protoType, catalystType, protoPath, + catalystPath, BinaryType, containsNull) case (MESSAGE, st: StructType) => - val writeRecord = getRecordWriter(protoType.getMessageType, st, protoPath, catalystPath, applyFilters = _ => false) + val writeRecord = getRecordWriter(protoType.getMessageType, st, protoPath, + catalystPath, applyFilters = _ => false) (updater, ordinal, value) => val row = new SpecificInternalRow(st) writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage]) updater.set(ordinal, row) case (MESSAGE, ArrayType(st: StructType, containsNull)) => - val writeRecord = getRecordWriter(protoType.getMessageType, st, 
protoPath, catalystPath, applyFilters = _ => false) - (updater, ordinal, value) => - val row = new SpecificInternalRow(st) - writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage]) - updater.set(ordinal, row) + newArrayWriter(protoType, catalystType, protoPath, + catalystPath, st, containsNull) case (ENUM, StringType) => (updater, ordinal, value) => updater.set(ordinal, UTF8String.fromString(value.toString)) - case (ENUM, ArrayType(StringType, containsNull)) => (updater, ordinal, value) => - updater.set(ordinal, UTF8String.fromString(value.toString)) + case (ENUM, ArrayType(StringType, containsNull)) => + newArrayWriter(protoType, catalystType, protoPath, + catalystPath, StringType, containsNull) + // TBD: Do we need this here ? case (INT, _: YearMonthIntervalType) => (updater, ordinal, value) => updater.setInt(ordinal, value.asInstanceOf[Int]) @@ -207,45 +240,16 @@ private[sql] class ProtoDeserializer( val (validFieldIndexes, fieldWriters) = protoSchemaHelper.matchedFields.map { case ProtoMatchedField(catalystField, ordinal, protoField) => - if(protoField.isRepeated) { - val protoElementPath = protoPath :+ "element" - val elementWriter = newWriter(protoField, catalystField.dataType, protoPath :+ protoField.getName, - catalystPath :+ catalystField.name) - val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { - val collection = value.asInstanceOf[java.util.List[Object]] - val result = createArrayData(catalystField.dataType, collection.size()) - val elementUpdater = new ArrayDataUpdater(result) - var i = 0 - val iter = collection.iterator() - while (iter.hasNext) { - val element = iter.next() - if (element == null) { - if (value != null) { - throw new RuntimeException( - s"Array value at path ${toFieldStr(protoElementPath)} is not allowed to be null") - } else { - elementUpdater.setNullAt(i) - } - } else { - elementWriter(elementUpdater, i, element) - } - i += 1 - } - fieldUpdater.set(ordinal, result) - } - (protoField, fieldWriter) - } 
else { - val baseWriter = newWriter(protoField, catalystField.dataType, - protoPath :+ protoField.getName, catalystPath :+ catalystField.name) - val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { - if (value == null) { - fieldUpdater.setNullAt(ordinal) - } else { - baseWriter(fieldUpdater, ordinal, value) - } + val baseWriter = newWriter(protoField, catalystField.dataType, + protoPath :+ protoField.getName, catalystPath :+ catalystField.name) + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) } - (protoField, fieldWriter) } + (protoField, fieldWriter) }.toArray.unzip (fieldUpdater, record) => {