diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 958e9477f9d8..50db6348923c 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -167,14 +167,13 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { case (ARRAY, ArrayType(elementType, containsNull)) => val elementWriter = newWriter(avroType.getElementType, elementType, path) (updater, ordinal, value) => - val array = value.asInstanceOf[GenericData.Array[Any]] + val array = value.asInstanceOf[java.util.Collection[Any]] val len = array.size() val result = createArrayData(elementType, len) val elementUpdater = new ArrayDataUpdater(result) var i = 0 - while (i < len) { - val element = array.get(i) + for (element <- array.asScala) { if (element == null) { if (!containsNull) { throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 4b39e711aa28..c8a1f670bda9 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -17,7 +17,12 @@ package org.apache.spark.sql.avro +import java.util +import java.util.Collections + import org.apache.avro.Schema +import org.apache.avro.generic.{GenericData, GenericRecordBuilder} +import org.apache.avro.message.{BinaryMessageDecoder, BinaryMessageEncoder} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.{RandomDataGenerator, Row} @@ -127,6 +132,26 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite } } + test("array of nested schema with seed") { + val seed = scala.util.Random.nextLong() + val rand = new scala.util.Random(seed) + val schema = StructType( + StructField("a", + ArrayType( + RandomDataGenerator.randomNestedSchema(rand, 10, testingTypes), + containsNull = false), + nullable = false + ) :: Nil + ) + + withClue(s"Schema: $schema\nseed: $seed") { + val data = RandomDataGenerator.randomRow(rand, schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val input = Literal.create(converter(data), schema) + roundTripTest(input) + } + } + test("read int as string") { val data = Literal(1) val avroTypeJson = @@ -246,4 +271,46 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite }.getMessage assert(message == "Cannot convert Catalyst type StringType to Avro type \"long\".") } + + test("avro array can be generic java collection") { + val jsonFormatSchema = + """ + |{ "type": "record", + | "name": "record", + | "fields" : [{ + | "name": "array", + | "type": { + | "type": "array", + | "items": ["null", "int"] + | } + | }] + |} + """.stripMargin + val avroSchema = new Schema.Parser().parse(jsonFormatSchema) + val dataType = SchemaConverters.toSqlType(avroSchema).dataType + val deserializer = new AvroDeserializer(avroSchema, dataType) + + def checkDeserialization(data: GenericData.Record, expected: Any): Unit = { + assert(checkResult( + expected, + deserializer.deserialize(data), + dataType, exprNullable = false + )) + } + + def validateDeserialization(array: java.util.Collection[Integer]): Unit = { + val data = new GenericRecordBuilder(avroSchema) + .set("array", array) + .build() + val expected = InternalRow(new GenericArrayData(new util.ArrayList[Any](array))) + checkDeserialization(data, expected) + + val reEncoded = new BinaryMessageDecoder[GenericData.Record](new GenericData(), avroSchema) + .decode(new BinaryMessageEncoder(new GenericData(), avroSchema).encode(data)) + checkDeserialization(reEncoded, expected) + } + + validateDeserialization(Collections.emptySet()) + validateDeserialization(util.Arrays.asList(1, null, 3)) + } }