From 325e2989d2fd5dfa3617e084c8113690bdd4bc51 Mon Sep 17 00:00:00 2001
From: Himanshu Kohli <40983327+himanshukohli09@users.noreply.github.com>
Date: Thu, 13 May 2021 21:05:12 +0530
Subject: [PATCH] Issue #367: Solved column order of struct bug (#391)

Column order of struct variables doesn't need to be the same as that of
BigQuery schema
---
 .../spark/bigquery/ArrowBinaryIterator.java   |  27 ++++-
 .../spark/bigquery/AvroBinaryIterator.java    |   9 +-
 ...esponseToInternalRowIteratorConverter.java |  38 +++++--
 .../spark/bigquery/SchemaConverters.java      | 107 +++++++++++++-----
 .../v2/ArrowColumnBatchPartitionReader.java   |  23 +++-
 .../bigquery/v2/ArrowInputPartition.java      |  15 ++-
 .../bigquery/v2/BigQueryDataSourceReader.java |  17 ++-
 .../spark/bigquery/direct/BigQueryRDD.scala   |  23 ++--
 .../v2/BigQueryInputPartitionReaderTest.java  |   4 +-
 .../spark/bigquery/SchemaIteratorSuite.scala  |   8 +-
 .../it/SparkBigQueryEndToEndReadITSuite.scala |  26 ++++-
 .../spark/bigquery/it/TestConstants.scala     |  41 +++++++
 .../spark/bigquery/ArrowSchemaConverter.java  |  53 +++++++--
 13 files changed, 322 insertions(+), 69 deletions(-)

diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/ArrowBinaryIterator.java b/connector/src/main/java/com/google/cloud/spark/bigquery/ArrowBinaryIterator.java
index 4ee2f2f30..4a3b8a36f 100644
--- a/connector/src/main/java/com/google/cloud/spark/bigquery/ArrowBinaryIterator.java
+++ b/connector/src/main/java/com/google/cloud/spark/bigquery/ArrowBinaryIterator.java
@@ -18,13 +18,18 @@
 import com.google.cloud.bigquery.connector.common.ArrowUtil;
 import com.google.common.collect.ImmutableList;
 import com.google.protobuf.ByteString;
+import java.util.Arrays;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
 import org.apache.arrow.compression.CommonsCompressionFactory;
 import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.vectorized.ColumnVector;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 import org.slf4j.Logger;
@@ -44,9 +49,13 @@ public class ArrowBinaryIterator implements Iterator<InternalRow> {
   ArrowReaderIterator arrowReaderIterator;
   Iterator<InternalRow> currentIterator;
   List<String> columnsInOrder;
+  Map<String, StructField> userProvidedFieldMap;
 
   public ArrowBinaryIterator(
-      List<String> columnsInOrder, ByteString schema, ByteString rowsInBytes) {
+      List<String> columnsInOrder,
+      ByteString schema,
+      ByteString rowsInBytes,
+      Optional<StructType> userProvidedSchema) {
     BufferAllocator allocator =
         ArrowUtil.newRootAllocator(maxAllocation)
             .newChildAllocator("ArrowBinaryIterator", 0, maxAllocation);
@@ -61,6 +70,16 @@ public ArrowBinaryIterator(
     arrowReaderIterator = new ArrowReaderIterator(arrowStreamReader);
     currentIterator = ImmutableList.<InternalRow>of().iterator();
     this.columnsInOrder = columnsInOrder;
+
+    List<StructField> userProvidedFieldList =
+        Arrays
+            .stream(userProvidedSchema.orElse(new StructType()).fields())
+            .collect(Collectors.toList());
+
+    this.userProvidedFieldMap =
+        userProvidedFieldList
+            .stream()
+            .collect(Collectors.toMap(StructField::name, Function.identity()));
   }
 
   @Override
@@ -84,7 +103,9 @@ private Iterator<InternalRow> toArrowRows(VectorSchemaRoot root, List<String> namesInOrder) {
     ColumnVector[] columns =
         namesInOrder.stream()
             .map(name -> root.getVector(name))
-            .map(vector -> new ArrowSchemaConverter(vector))
+            .map(
+                vector ->
+                    new ArrowSchemaConverter(vector, userProvidedFieldMap.get(vector.getName())))
             .collect(Collectors.toList())
             .toArray(new ColumnVector[0]);
 
diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/AvroBinaryIterator.java b/connector/src/main/java/com/google/cloud/spark/bigquery/AvroBinaryIterator.java
index 5ae447941..de3a48181 100644
--- a/connector/src/main/java/com/google/cloud/spark/bigquery/AvroBinaryIterator.java
+++ b/connector/src/main/java/com/google/cloud/spark/bigquery/AvroBinaryIterator.java
@@ -17,11 +17,13 @@
 import com.google.cloud.bigquery.Schema;
 import com.google.protobuf.ByteString;
+import java.util.Optional;
 import org.apache.avro.generic.GenericDatumReader;
 import org.apache.avro.generic.GenericRecord;
 import org.apache.avro.io.BinaryDecoder;
 import org.apache.avro.io.DecoderFactory;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -37,6 +39,7 @@ public class AvroBinaryIterator implements Iterator<InternalRow> {
   List<String> columnsInOrder;
   BinaryDecoder in;
   Schema bqSchema;
+  Optional<StructType> userProvidedSchema;
 
   /**
    * An iterator for scanning over rows serialized in Avro format
@@ -50,11 +53,13 @@ public AvroBinaryIterator(
       Schema bqSchema,
       List<String> columnsInOrder,
       org.apache.avro.Schema schema,
-      ByteString rowsInBytes) {
+      ByteString rowsInBytes,
+      Optional<StructType> userProvidedSchema) {
     reader = new GenericDatumReader(schema);
     this.bqSchema = bqSchema;
     this.columnsInOrder = columnsInOrder;
     in = new DecoderFactory().binaryDecoder(rowsInBytes.toByteArray(), null);
+    this.userProvidedSchema = userProvidedSchema;
   }
 
   @Override
@@ -70,7 +75,7 @@ public boolean hasNext() {
   public InternalRow next() {
     try {
       return SchemaConverters.convertToInternalRow(
-          bqSchema, columnsInOrder, (GenericRecord) reader.read(null, in));
+          bqSchema, columnsInOrder, (GenericRecord) reader.read(null, in), userProvidedSchema);
     } catch (IOException e) {
       throw new UncheckedIOException(e);
     }
diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/ReadRowsResponseToInternalRowIteratorConverter.java b/connector/src/main/java/com/google/cloud/spark/bigquery/ReadRowsResponseToInternalRowIteratorConverter.java
index 20e62106e..fc776818e 100644
--- a/connector/src/main/java/com/google/cloud/spark/bigquery/ReadRowsResponseToInternalRowIteratorConverter.java
+++ b/connector/src/main/java/com/google/cloud/spark/bigquery/ReadRowsResponseToInternalRowIteratorConverter.java
@@ -19,23 +19,28 @@
 import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
 import com.google.protobuf.ByteString;
 import org.apache.spark.sql.catalyst.InternalRow;
-
 import java.io.Serializable;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Optional;
+import org.apache.spark.sql.types.StructType;
+import static com.google.common.base.Optional.fromJavaUtil;
 
 public interface ReadRowsResponseToInternalRowIteratorConverter {
 
   static ReadRowsResponseToInternalRowIteratorConverter avro(
       final com.google.cloud.bigquery.Schema bqSchema,
       final List<String> columnsInOrder,
-      final String rawAvroSchema) {
-    return new Avro(bqSchema, columnsInOrder, rawAvroSchema);
+      final String rawAvroSchema,
+      final Optional<StructType> userProvidedSchema) {
+    return new Avro(bqSchema, columnsInOrder, rawAvroSchema, fromJavaUtil(userProvidedSchema));
   }
 
   static ReadRowsResponseToInternalRowIteratorConverter arrow(
-      final List<String> columnsInOrder, final ByteString arrowSchema) {
-    return new Arrow(columnsInOrder, arrowSchema);
+      final List<String> columnsInOrder,
+      final ByteString arrowSchema,
+      final Optional<StructType> userProvidedSchema) {
+    return new Arrow(columnsInOrder, arrowSchema, fromJavaUtil(userProvidedSchema));
   }
 
   Iterator<InternalRow> convert(ReadRowsResponse response);
 
@@ -45,11 +50,17 @@ class Avro implements ReadRowsResponseToInternalRowIteratorConverter, Serializable {
     private final com.google.cloud.bigquery.Schema bqSchema;
     private final List<String> columnsInOrder;
     private final String rawAvroSchema;
+    private final com.google.common.base.Optional<StructType> userProvidedSchema;
 
-    public Avro(Schema bqSchema, List<String> columnsInOrder, String rawAvroSchema) {
+    public Avro(
+        Schema bqSchema,
+        List<String> columnsInOrder,
+        String rawAvroSchema,
+        com.google.common.base.Optional<StructType> userProvidedSchema) {
       this.bqSchema = bqSchema;
       this.columnsInOrder = columnsInOrder;
       this.rawAvroSchema = rawAvroSchema;
+      this.userProvidedSchema = userProvidedSchema;
     }
 
     @Override
@@ -58,7 +69,8 @@ public Iterator<InternalRow> convert(ReadRowsResponse response) {
           bqSchema,
           columnsInOrder,
           new org.apache.avro.Schema.Parser().parse(rawAvroSchema),
-          response.getAvroRows().getSerializedBinaryRows());
+          response.getAvroRows().getSerializedBinaryRows(),
+          userProvidedSchema.toJavaUtil());
     }
   }
 
@@ -66,16 +78,24 @@ class Arrow implements ReadRowsResponseToInternalRowIteratorConverter, Serializable {
 
     private final List<String> columnsInOrder;
     private final ByteString arrowSchema;
+    private final com.google.common.base.Optional<StructType> userProvidedSchema;
 
-    public Arrow(List<String> columnsInOrder, ByteString arrowSchema) {
+    public Arrow(
+        List<String> columnsInOrder,
+        ByteString arrowSchema,
+        com.google.common.base.Optional<StructType> userProvidedSchema) {
       this.columnsInOrder = columnsInOrder;
       this.arrowSchema = arrowSchema;
+      this.userProvidedSchema = userProvidedSchema;
     }
 
     @Override
    public Iterator<InternalRow> convert(ReadRowsResponse response) {
       return new ArrowBinaryIterator(
-          columnsInOrder, arrowSchema, response.getArrowRecordBatch().getSerializedRecordBatch());
+          columnsInOrder,
+          arrowSchema,
+          response.getArrowRecordBatch().getSerializedRecordBatch(),
+          userProvidedSchema.toJavaUtil());
     }
   }
 }
diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java
index 4e76ea933..f40907f3b 100644
--- a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java
+++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java
@@ -25,6 +25,7 @@
 import com.google.cloud.bigquery.TimePartitioning;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import java.util.function.Function;
 import org.apache.avro.generic.GenericRecord;
 import org.apache.avro.util.Utf8;
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -91,11 +92,19 @@ public static Schema getSchemaWithPseudoColumns(TableInfo tableInfo) {
   }
 
   public static InternalRow convertToInternalRow(
-      Schema schema, List<String> namesInOrder, GenericRecord record) {
-    return convertAll(schema.getFields(), record, namesInOrder);
+      Schema schema,
+      List<String> namesInOrder,
+      GenericRecord record,
+      Optional<StructType> userProvidedSchema) {
+    List<StructField> userProvidedFieldList =
+        Arrays
+            .stream(userProvidedSchema.orElse(new StructType()).fields())
+            .collect(Collectors.toList());
+
+    return convertAll(schema.getFields(), record, namesInOrder, userProvidedFieldList);
   }
 
-  static Object convert(Field field, Object value) {
+  static Object convert(Field field, Object value, StructField userProvidedField) {
     if (value == null) {
       return null;
     }
@@ -113,38 +122,55 @@ static Object convert(Field field, Object value) {
             .build();
 
     List valueList = (List) value;
 
-    return new GenericArrayData(
-        valueList.stream().map(v -> convert(nestedField, v)).collect(Collectors.toList()));
+    return new GenericArrayData(
+        valueList
+            .stream()
+            .map(v -> convert(nestedField, v, getStructFieldForRepeatedMode(userProvidedField)))
+            .collect(Collectors.toList()));
     }
 
-    Object datum = convertByBigQueryType(field, value);
+    Object datum = convertByBigQueryType(field, value, userProvidedField);
     Optional customDatum =
         getCustomDataType(field).map(dt -> ((UserDefinedType) dt).deserialize(datum));
     return customDatum.orElse(datum);
   }
 
-  static Object convertByBigQueryType(Field field, Object value) {
-    if (LegacySQLTypeName.INTEGER.equals(field.getType())
-        || LegacySQLTypeName.FLOAT.equals(field.getType())
-        || LegacySQLTypeName.BOOLEAN.equals(field.getType())
-        || LegacySQLTypeName.DATE.equals(field.getType())
-        || LegacySQLTypeName.TIME.equals(field.getType())
-        || LegacySQLTypeName.TIMESTAMP.equals(field.getType())) {
+  private static StructField getStructFieldForRepeatedMode(StructField field) {
+    StructField nestedField = null;
+
+    if (field != null) {
+      ArrayType arrayType = ((ArrayType) field.dataType());
+      nestedField =
+          new StructField(
+              field.name(),
+              arrayType.elementType(),
+              arrayType.containsNull(),
+              Metadata.empty()); // safe to pass empty metadata as it is not used anywhere
+    }
+    return nestedField;
+  }
+
+  static Object convertByBigQueryType(Field bqField, Object value, StructField userProvidedField) {
+    if (LegacySQLTypeName.INTEGER.equals(bqField.getType())
+        || LegacySQLTypeName.FLOAT.equals(bqField.getType())
+        || LegacySQLTypeName.BOOLEAN.equals(bqField.getType())
+        || LegacySQLTypeName.DATE.equals(bqField.getType())
+        || LegacySQLTypeName.TIME.equals(bqField.getType())
+        || LegacySQLTypeName.TIMESTAMP.equals(bqField.getType())) {
       return value;
     }
 
-    if (LegacySQLTypeName.STRING.equals(field.getType())
-        || LegacySQLTypeName.DATETIME.equals(field.getType())
-        || LegacySQLTypeName.GEOGRAPHY.equals(field.getType())) {
+    if (LegacySQLTypeName.STRING.equals(bqField.getType())
+        || LegacySQLTypeName.DATETIME.equals(bqField.getType())
+        || LegacySQLTypeName.GEOGRAPHY.equals(bqField.getType())) {
       return UTF8String.fromBytes(((Utf8) value).getBytes());
     }
 
-    if (LegacySQLTypeName.BYTES.equals(field.getType())) {
+    if (LegacySQLTypeName.BYTES.equals(bqField.getType())) {
       return getBytes((ByteBuffer) value);
     }
 
-    if (LegacySQLTypeName.NUMERIC.equals(field.getType())) {
+    if (LegacySQLTypeName.NUMERIC.equals(bqField.getType())) {
       byte[] bytes = getBytes((ByteBuffer) value);
       BigDecimal b = new BigDecimal(new BigInteger(bytes), BQ_NUMERIC_SCALE);
       Decimal d = Decimal.apply(b, BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE);
@@ -152,14 +178,26 @@ static Object convertByBigQueryType(Field field, Object value) {
       return d;
     }
 
-    if (LegacySQLTypeName.RECORD.equals(field.getType())) {
-      return convertAll(
-          field.getSubFields(),
-          (GenericRecord) value,
-          field.getSubFields().stream().map(f -> f.getName()).collect(Collectors.toList()));
+    if (LegacySQLTypeName.RECORD.equals(bqField.getType())) {
+      List<String> namesInOrder = null;
+      List<StructField> structList = null;
+
+      if (userProvidedField != null) {
+        structList =
+            Arrays
+                .stream(((StructType)userProvidedField.dataType()).fields())
+                .collect(Collectors.toList());
+
+        namesInOrder = structList.stream().map(StructField::name).collect(Collectors.toList());
+      } else {
+        namesInOrder =
+            bqField.getSubFields().stream().map(Field::getName).collect(Collectors.toList());
+      }
+
+      return convertAll(bqField.getSubFields(), (GenericRecord) value, namesInOrder, structList);
     }
 
-    throw new IllegalStateException("Unexpected type: " + field.getType());
+    throw new IllegalStateException("Unexpected type: " + bqField.getType());
   }
 
   private static byte[] getBytes(ByteBuffer buf) {
@@ -171,13 +209,28 @@ private static byte[] getBytes(ByteBuffer buf) {
 
   // Schema is not recursive so add helper for sequence of fields
   static GenericInternalRow convertAll(
-      FieldList fieldList, GenericRecord record, List<String> namesInOrder) {
-
+      FieldList fieldList,
+      GenericRecord record,
+      List<String> namesInOrder,
+      List<StructField> userProvidedFieldList) {
     Map<String, Object> fieldMap = new HashMap<>();
 
+    Map<String, StructField> userProvidedFieldMap =
+        userProvidedFieldList == null
+            ? new HashMap<>()
+            : userProvidedFieldList
+                .stream()
+                .collect(Collectors.toMap(StructField::name, Function.identity()));
+
     fieldList.stream()
         .forEach(
-            field -> fieldMap.put(field.getName(), convert(field, record.get(field.getName()))));
+            field ->
+                fieldMap.put(
+                    field.getName(),
+                    convert(
+                        field,
+                        record.get(field.getName()),
+                        userProvidedFieldMap.get(field.getName()))));
 
     Object[] values = new Object[namesInOrder.size()];
     for (int i = 0; i < namesInOrder.size(); i++) {
diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowColumnBatchPartitionReader.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowColumnBatchPartitionReader.java
index 17178bf4a..5f5c72db9 100644
--- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowColumnBatchPartitionReader.java
+++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowColumnBatchPartitionReader.java
@@ -24,15 +24,20 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.SequenceInputStream;
+import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.NoSuchElementException;
+import java.util.Optional;
+import java.util.stream.Collectors;
 import org.apache.arrow.compression.CommonsCompressionFactory;
 import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.util.AutoCloseables;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.vectorized.ColumnVector;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
@@ -47,6 +52,7 @@ class ArrowColumnBatchPartitionColumnBatchReader implements InputPartitionReader<ColumnarBatch> {
   private ColumnarBatch currentBatch;
   private final BigQueryStorageReadRowsTracer tracer;
   private boolean closed = false;
+  private final Map<String, StructField> userProvidedFieldMap;
 
   static class ReadRowsResponseInputStreamEnumeration
       implements java.util.Enumeration<InputStream> {
@@ -95,7 +101,8 @@ void loadNextResponse() {
       ByteString schema,
       ReadRowsHelper readRowsHelper,
       List<String> namesInOrder,
-      BigQueryStorageReadRowsTracer tracer) {
+      BigQueryStorageReadRowsTracer tracer,
+      Optional<StructType> userProvidedSchema) {
     this.allocator =
         ArrowUtil.newRootAllocator(maxAllocation)
             .newChildAllocator("ArrowBinaryIterator", 0, maxAllocation);
@@ -103,6 +110,14 @@ void loadNextResponse() {
     this.namesInOrder = namesInOrder;
     this.tracer = tracer;
 
+    List<StructField> userProvidedFieldList =
+        Arrays
+            .stream(userProvidedSchema.orElse(new StructType()).fields())
+            .collect(Collectors.toList());
+
+    this.userProvidedFieldMap =
+        userProvidedFieldList.stream().collect(Collectors.toMap(StructField::name, field -> field));
+
     InputStream batchStream =
         new SequenceInputStream(
             new ReadRowsResponseInputStreamEnumeration(readRowsResponses, tracer));
@@ -132,7 +147,9 @@ public boolean next() throws IOException {
     ColumnVector[] columns =
         namesInOrder.stream()
             .map(root::getVector)
-            .map(ArrowSchemaConverter::new)
+            .map(
+                vector ->
+                    new ArrowSchemaConverter(vector, userProvidedFieldMap.get(vector.getName())))
            .toArray(ColumnVector[]::new);
 
     currentBatch = new ColumnarBatch(columns);
diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowInputPartition.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowInputPartition.java
index 5842fe247..5f7e31b06 100644
--- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowInputPartition.java
+++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowInputPartition.java
@@ -24,9 +24,12 @@
 import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
 import com.google.common.collect.ImmutableList;
 import com.google.protobuf.ByteString;
+import java.util.Optional;
 import org.apache.spark.sql.sources.v2.reader.InputPartition;
 import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
+import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
+import static com.google.common.base.Optional.fromJavaUtil;
 
 import java.util.Iterator;
 
@@ -38,6 +41,7 @@ public class ArrowInputPartition implements InputPartition<ColumnarBatch> {
   private final int maxReadRowsRetries;
   private final ImmutableList<String> selectedFields;
   private final ByteString serializedArrowSchema;
+  private final com.google.common.base.Optional<StructType> userProvidedSchema;
 
   public ArrowInputPartition(
       BigQueryReadClientFactory bigQueryReadClientFactory,
@@ -45,7 +49,8 @@ public ArrowInputPartition(
       String name,
       int maxReadRowsRetries,
       ImmutableList<String> selectedFields,
-      ReadSessionResponse readSessionResponse) {
+      ReadSessionResponse readSessionResponse,
+      Optional<StructType> userProvidedSchema) {
     this.bigQueryReadClientFactory = bigQueryReadClientFactory;
     this.streamName = name;
     this.maxReadRowsRetries = maxReadRowsRetries;
@@ -53,6 +58,7 @@ public ArrowInputPartition(
     this.serializedArrowSchema =
         readSessionResponse.getReadSession().getArrowSchema().getSerializedSchema();
     this.tracerFactory = tracerFactory;
+    this.userProvidedSchema = fromJavaUtil(userProvidedSchema);
   }
 
   @Override
@@ -65,6 +71,11 @@ public InputPartitionReader<ColumnarBatch> createPartitionReader() {
     tracer.startStream();
     Iterator<ReadRowsResponse> readRowsResponses = readRowsHelper.readRows();
     return new ArrowColumnBatchPartitionColumnBatchReader(
-        readRowsResponses, serializedArrowSchema, readRowsHelper, selectedFields, tracer);
+        readRowsResponses,
+        serializedArrowSchema,
+        readRowsHelper,
+        selectedFields,
+        tracer,
+        userProvidedSchema.toJavaUtil());
   }
 }
diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java
index 78fa46120..6759a5c27 100644
--- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java
+++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java
@@ -92,6 +92,7 @@ public OptionalLong numRows() {
   private final ReadSessionCreator readSessionCreator;
   private final Optional<String> globalFilter;
   private Optional<StructType> schema;
+  private Optional<StructType> userProvidedSchema;
   private Filter[] pushedFilters = new Filter[] {};
   private Map<String, StructField> fields;
 
@@ -116,8 +117,10 @@ public BigQueryDataSourceReader(
         SchemaConverters.toSpark(SchemaConverters.getSchemaWithPseudoColumns(table));
     if (schema.isPresent()) {
       this.schema = schema;
+      this.userProvidedSchema = schema;
     } else {
       this.schema = Optional.of(convertedSchema);
+      this.userProvidedSchema = Optional.empty();
     }
     // We want to keep the key order
     this.fields = new LinkedHashMap<>();
@@ -164,7 +167,7 @@ public List<InputPartition<InternalRow>> planInputPartitions() {
                     bigQueryReadClientFactory,
                     stream.getName(),
                     readSessionCreatorConfig.getMaxReadRowsRetries(),
-                    createConverter(selectedFields, readSessionResponse)))
+                    createConverter(selectedFields, readSessionResponse, userProvidedSchema)))
         .collect(Collectors.toList());
   }
 
@@ -206,7 +209,8 @@ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
                     stream.getName(),
                     readSessionCreatorConfig.getMaxReadRowsRetries(),
                     partitionSelectedFields,
-                    readSessionResponse))
+                    readSessionResponse,
+                    userProvidedSchema))
         .collect(Collectors.toList());
   }
 
@@ -215,7 +219,9 @@ private boolean isEmptySchema() {
   }
 
   private ReadRowsResponseToInternalRowIteratorConverter createConverter(
-      ImmutableList<String> selectedFields, ReadSessionResponse readSessionResponse) {
+      ImmutableList<String> selectedFields,
+      ReadSessionResponse readSessionResponse,
+      Optional<StructType> userProvidedSchema) {
     ReadRowsResponseToInternalRowIteratorConverter converter;
     DataFormat format = readSessionCreatorConfig.getReadDataFormat();
     if (format == DataFormat.AVRO) {
@@ -236,7 +242,10 @@ private ReadRowsResponseToInternalRowIteratorConverter createConverter(
                 .collect(Collectors.toList()));
       }
       return ReadRowsResponseToInternalRowIteratorConverter.avro(
-          schema, selectedFields, readSessionResponse.getReadSession().getAvroSchema().getSchema());
+          schema,
+          selectedFields,
+          readSessionResponse.getReadSession().getAvroSchema().getSchema(),
+          userProvidedSchema);
     }
     throw new IllegalArgumentException(
         "No known converted for " + readSessionCreatorConfig.getReadDataFormat());
diff --git a/connector/src/main/scala/com/google/cloud/spark/bigquery/direct/BigQueryRDD.scala b/connector/src/main/scala/com/google/cloud/spark/bigquery/direct/BigQueryRDD.scala
index 4c11ca400..e58896555 100644
--- a/connector/src/main/scala/com/google/cloud/spark/bigquery/direct/BigQueryRDD.scala
+++ b/connector/src/main/scala/com/google/cloud/spark/bigquery/direct/BigQueryRDD.scala
@@ -15,9 +15,11 @@
  */
 package com.google.cloud.spark.bigquery.direct
 
+import java.util.Optional
+
 import com.google.api.gax.rpc.ServerStreamingCallable
 import com.google.cloud.bigquery.connector.common.BigQueryUtil
-import com.google.cloud.bigquery.storage.v1.{BigQueryReadClient, DataFormat, ReadRowsRequest, ReadRowsResponse, ReadSession, ReadStream}
+import com.google.cloud.bigquery.storage.v1.{BigQueryReadClient, DataFormat, ReadRowsRequest, ReadRowsResponse, ReadSession}
 import com.google.cloud.bigquery.{BigQuery, Schema}
 import com.google.cloud.spark.bigquery.{ArrowBinaryIterator, AvroBinaryIterator, SparkBigQueryConfig}
 import com.google.protobuf.ByteString
@@ -26,6 +28,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 
 import scala.collection.JavaConverters._
@@ -59,12 +62,14 @@ class BigQueryRDD(sc: SparkContext,
       AvroConverter(bqSchema,
         columnsInOrder,
         session.getAvroSchema.getSchema,
-        readRowResponses).getIterator()
+        readRowResponses,
+        options.getSchema).getIterator()
     } else {
       ArrowConverter(columnsInOrder,
         session.getArrowSchema.getSerializedSchema,
-        readRowResponses).getIterator()
+        readRowResponses,
+        options.getSchema).getIterator()
     }
 
     new InterruptibleIterator(context, it)
@@ -82,13 +87,15 @@ class BigQueryRDD(sc: SparkContext,
   */
 case class ArrowConverter(columnsInOrder: Seq[String],
                           rawArrowSchema : ByteString,
-                          rowResponseIterator : Iterator[ReadRowsResponse])
+                          rowResponseIterator : Iterator[ReadRowsResponse],
+                          userProvidedSchema: Optional[StructType])
 {
   def getIterator(): Iterator[InternalRow] = {
     rowResponseIterator.flatMap(readRowResponse =>
       new ArrowBinaryIterator(columnsInOrder.asJava,
         rawArrowSchema,
-        readRowResponse.getArrowRecordBatch.getSerializedRecordBatch).asScala);
+        readRowResponse.getArrowRecordBatch.getSerializedRecordBatch,
+        userProvidedSchema).asScala);
   }
 }
 
@@ -103,7 +110,8 @@ case class ArrowConverter(columnsInOrder: Seq[String],
 case class AvroConverter (bqSchema: Schema,
                           columnsInOrder: Seq[String],
                           rawAvroSchema: String,
-                          rowResponseIterator : Iterator[ReadRowsResponse])
+                          rowResponseIterator : Iterator[ReadRowsResponse],
+                          userProvidedSchema: Optional[StructType])
 {
   @transient private lazy val avroSchema = new AvroSchema.Parser().parse(rawAvroSchema)
 
@@ -116,7 +124,8 @@ case class AvroConverter (bqSchema: Schema,
       bqSchema,
       columnsInOrder.asJava,
       avroSchema,
-      response.getAvroRows.getSerializedBinaryRows).asScala
+      response.getAvroRows.getSerializedBinaryRows,
+      userProvidedSchema).asScala
   }
 
 case class BigQueryPartition(stream: String, index: Int) extends Partition
diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReaderTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReaderTest.java
index 4aa9481e5..6c1c757e5 100644
--- a/connector/src/test/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReaderTest.java
+++ b/connector/src/test/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReaderTest.java
@@ -21,6 +21,7 @@
 import com.google.cloud.spark.bigquery.ReadRowsResponseToInternalRowIteratorConverter;
 import com.google.common.collect.ImmutableList;
 import com.google.protobuf.TextFormat;
+import java.util.Optional;
 import org.apache.log4j.Logger;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.junit.Test;
@@ -100,7 +101,8 @@ public void testReadAvro() throws Exception {
         ReadRowsResponseToInternalRowIteratorConverter.avro(
             ALL_TYPES_TABLE_BIGQUERY_SCHEMA,
             ALL_TYPES_TABLE_FIELDS,
-            ALL_TYPES_TABLE_AVRO_RAW_SCHEMA);
+            ALL_TYPES_TABLE_AVRO_RAW_SCHEMA,
+            Optional.empty());
 
     BigQueryInputPartitionReader reader =
         new BigQueryInputPartitionReader(readRowsResponses, converter, null);
diff --git a/connector/src/test/scala/com/google/cloud/spark/bigquery/SchemaIteratorSuite.scala b/connector/src/test/scala/com/google/cloud/spark/bigquery/SchemaIteratorSuite.scala
index 2116db960..60e8aaaf7 100644
--- a/connector/src/test/scala/com/google/cloud/spark/bigquery/SchemaIteratorSuite.scala
+++ b/connector/src/test/scala/com/google/cloud/spark/bigquery/SchemaIteratorSuite.scala
@@ -15,6 +15,7 @@
  */
 package com.google.cloud.spark.bigquery
 
+import java.util.Optional
 import com.google.cloud.bigquery.Field.Mode
 import com.google.cloud.bigquery.LegacySQLTypeName.{BOOLEAN, BYTES, DATE, DATETIME, FLOAT, INTEGER, NUMERIC, RECORD, STRING, TIME, TIMESTAMP}
 import com.google.cloud.bigquery.{Field, Schema}
@@ -76,14 +77,15 @@ class SchemaIteratorSuite extends FunSuite {
     val arrowBinaryIterator = new ArrowBinaryIterator(columnsInOrder.asJava,
       arrowSchema,
-      arrowByteString).asScala
+      arrowByteString,
+      Optional.empty()).asScala
 
     if (arrowBinaryIterator.hasNext) {
       arrowSparkRow = arrowBinaryIterator.next();
     }
 
     val avroBinaryIterator = new AvroBinaryIterator(bqSchema,
-      columnsInOrder.asJava, avroSchema, avroByteString)
+      columnsInOrder.asJava, avroSchema, avroByteString, Optional.empty())
 
     if (avroBinaryIterator.hasNext) {
       avroSparkRow = avroBinaryIterator.next()
@@ -156,7 +158,7 @@ class SchemaIteratorSuite extends FunSuite {
 
     val arrowBinaryIterator = new ArrowBinaryIterator(
-      columnsInOrder.asJava, arrowSchema, arrowByteString).asScala
+      columnsInOrder.asJava, arrowSchema, arrowByteString, Optional.empty()).asScala
 
     while (arrowBinaryIterator.hasNext) {
       val arrowSparkRow = arrowBinaryIterator.next()
diff --git a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndReadITSuite.scala b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndReadITSuite.scala
index a46d74292..da764bda4 100644
--- a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndReadITSuite.scala
+++ b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndReadITSuite.scala
@@ -20,7 +20,7 @@ import com.google.cloud.spark.bigquery.TestUtils
 import com.google.cloud.spark.bigquery.direct.DirectBigQueryRelation
 import com.google.cloud.spark.bigquery.it.TestConstants._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
+import org.apache.spark.sql.{DataFrame, Encoders, SparkSession}
 import org.scalatest.concurrent.TimeLimits
 import org.scalatest.prop.TableDrivenPropertyChecks
 import org.scalatest.time.SpanSugar._
@@ -83,6 +83,7 @@ class SparkBigQueryEndToEndReadITSuite extends FunSuite
   private val LARGE_TABLE_FIELD = "is_male"
   private val LARGE_TABLE_NUM_ROWS = 33271914L
   private val NON_EXISTENT_TABLE = "non-existent.non-existent.non-existent"
+  private val STRUCT_COLUMN_ORDER_TEST_TABLE_NAME = "struct_column_order"
   private val ALL_TYPES_TABLE_NAME = "all_types"
   private val ALL_TYPES_VIEW_NAME = "all_types_view"
   private var spark: SparkSession = _
@@ -143,6 +144,10 @@ class SparkBigQueryEndToEndReadITSuite extends FunSuite
     IntegrationTestUtils.runQuery(
       TestConstants.ALL_TYPES_TABLE_QUERY_TEMPLATE.format(s"$testDataset.$ALL_TYPES_TABLE_NAME"))
     IntegrationTestUtils.createView(testDataset, ALL_TYPES_TABLE_NAME, ALL_TYPES_VIEW_NAME)
+    IntegrationTestUtils.runQuery(
+      TestConstants
+        .STRUCT_COLUMN_ORDER_TEST_TABLE_QUERY_TEMPLATE
+        .format(s"$testDataset.$STRUCT_COLUMN_ORDER_TEST_TABLE_NAME"))
   }
 
   test("test filters") {
@@ -357,6 +362,25 @@ class SparkBigQueryEndToEndReadITSuite extends FunSuite
 
       newBehaviourWords should equal(oldBehaviourWords)
     }
+
Data Format %s" + .format(dataSourceFormat, dataFormat)) { + val sqlContext = spark.sqlContext + import sqlContext.implicits._ + val schema = Encoders.product[ColumnOrderTestClass].schema + + val dataset = spark.read + .schema(schema) + .option("dataset", testDataset) + .option("table", STRUCT_COLUMN_ORDER_TEST_TABLE_NAME) + .format(dataSourceFormat) + .option("readDataFormat", dataFormat) + .load() + .as[ColumnOrderTestClass] + + val row = Seq(dataset.head())(0) + assert(row == STRUCT_COLUMN_ORDER_TEST_TABLE_COLS) + } } def getViewDataFrame(dataSourceFormat: String, dataFormat: String): DataFrame = diff --git a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/TestConstants.scala b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/TestConstants.scala index 52c2a8d07..3b0e394b9 100644 --- a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/TestConstants.scala +++ b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/TestConstants.scala @@ -108,4 +108,45 @@ object TestConstants { array(lit(1), lit(2), lit(3)), array(struct(lit(1))) ) + + val STRUCT_COLUMN_ORDER_TEST_TABLE_QUERY_TEMPLATE = + """ + |create table %s ( + |str string, + |nums struct < + | num1 int64, + | num2 int64, + | num3 int64, + | string_struct_arr array > + | > + |) + |as + |select + |"outer_string" as str, + |struct( + | 1 as num1, + | 2 as num2, + | 3 as num3, + | [ + | struct("0:str1" as str1, "0:str2" as str2, "0:str3" as str3), + | struct("1:str1" as str1, "1:str2" as str2, "1:str3" as str3) + | ] as string_struct_arr + |) as nums + """.stripMargin + + val STRUCT_COLUMN_ORDER_TEST_TABLE_COLS = + ColumnOrderTestClass( + NumStruct( + 3, + 2, + 1, + List( + StringStruct("0:str3", "0:str1", "0:str2"), + StringStruct("1:str3", "1:str1", "1:str2") + )), + "outer_string") + + case class StringStruct(str3: String, str1: String, str2: String) + case class NumStruct(num3: Long, num2: Long, num1: Long, string_struct_arr: List[StringStruct]) + case class ColumnOrderTestClass(nums: NumStruct, str: String) } diff --git a/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/ArrowSchemaConverter.java b/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/ArrowSchemaConverter.java index 9baeea389..81238e9c5 100644 --- a/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/ArrowSchemaConverter.java +++ b/connector/third_party/apache-spark/src/main/java/com/google/cloud/spark/bigquery/ArrowSchemaConverter.java @@ -15,6 +15,10 @@ */ package com.google.cloud.spark.bigquery; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import org.apache.arrow.memory.ArrowBuf; import java.time.LocalDateTime; import java.time.ZoneOffset; @@ -188,7 +192,7 @@ private static DataType fromArrowField(Field field) } - public ArrowSchemaConverter(ValueVector vector) { + public ArrowSchemaConverter(ValueVector vector, StructField userProvidedField) { super(fromArrowField(vector.getField())); @@ -214,14 +218,36 @@ public ArrowSchemaConverter(ValueVector vector) { accessor = new ArrowSchemaConverter.TimestampMicroTZVectorAccessor((TimeStampMicroTZVector) vector); } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; - accessor = new ArrowSchemaConverter.ArrayAccessor(listVector); + accessor = new ArrowSchemaConverter.ArrayAccessor(listVector, userProvidedField); } else if (vector instanceof StructVector) { StructVector structVector = (StructVector) 
       StructVector structVector = (StructVector) vector;
       accessor = new ArrowSchemaConverter.StructAccessor(structVector);
 
-      childColumns = new ArrowSchemaConverter[structVector.size()];
-      for (int i = 0; i < childColumns.length; ++i) {
-        childColumns[i] = new ArrowSchemaConverter(structVector.getVectorById(i));
+      if(userProvidedField != null) {
+        List<StructField> structList =
+            Arrays
+                .stream(((StructType)userProvidedField.dataType()).fields())
+                .collect(Collectors.toList());
+
+        childColumns = new ArrowSchemaConverter[structList.size()];
+
+        Map<String, ValueVector> valueVectorMap =
+            structVector
+                .getChildrenFromFields()
+                .stream()
+                .collect(Collectors.toMap(ValueVector::getName, valueVector -> valueVector));
+
+        for (int i = 0; i < childColumns.length; ++i) {
+          StructField structField = structList.get(i);
+          childColumns[i] =
+              new ArrowSchemaConverter(valueVectorMap.get(structField.name()), structField);
+        }
+
+      } else {
+        childColumns = new ArrowSchemaConverter[structVector.size()];
+        for (int i = 0; i < childColumns.length; ++i) {
+          childColumns[i] = new ArrowSchemaConverter(structVector.getVectorById(i), null);
+        }
       }
     } else {
       throw new UnsupportedOperationException();
@@ -510,10 +536,23 @@ private static class ArrayAccessor extends ArrowSchemaConverter.ArrowVectorAccessor {
     private final ListVector accessor;
     private final ArrowSchemaConverter arrayData;
 
-    ArrayAccessor(ListVector vector) {
+    ArrayAccessor(ListVector vector, StructField userProvidedField) {
       super(vector);
       this.accessor = vector;
-      this.arrayData = new ArrowSchemaConverter(vector.getDataVector());
+      StructField structField = null;
+
+      // this is to support Array of StructType/StructVector
+      if(userProvidedField != null) {
+        ArrayType arrayType = ((ArrayType)userProvidedField.dataType());
+        structField =
+            new StructField(
+                vector.getDataVector().getName(),
+                arrayType.elementType(),
+                arrayType.containsNull(),
+                Metadata.empty());// safe to pass empty metadata as it is not used anywhere
+      }
+
+      this.arrayData = new ArrowSchemaConverter(vector.getDataVector(), structField);
     }
 
     @Override
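
Note (not part of the patch): a minimal usage sketch of the behavior this change enables. With the reordering fix, struct subfields in a user-provided Spark schema are matched to the BigQuery table schema by name rather than by position, so a schema that lists the subfields in a different order still reads correctly. The dataset and table names below are hypothetical.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

object StructColumnOrderExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("struct-column-order").getOrCreate()

    // Assume the BigQuery table declares `nums` as STRUCT<num1 INT64, num2 INT64>,
    // while the user schema lists the subfields in the opposite order.
    val userSchema = StructType(Seq(
      StructField("nums", StructType(Seq(
        StructField("num2", LongType),
        StructField("num1", LongType)))),
      StructField("str", StringType)))

    // Subfield values land in the right columns despite the reordering,
    // for both Avro and Arrow read formats.
    val df = spark.read
      .format("bigquery")
      .schema(userSchema)
      .option("table", "my_dataset.struct_table") // hypothetical table
      .load()

    df.show()
  }
}
```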