Issue #367: Solved column order of struct bug (#391)
The column order of struct fields no longer needs to match that of the BigQuery schema.
himanshukohli09 authored May 13, 2021
1 parent 73d1b8e commit 325e298
Showing 13 changed files with 322 additions and 69 deletions.
ArrowBinaryIterator.java
@@ -18,13 +18,18 @@
import com.google.cloud.bigquery.connector.common.ArrowUtil;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.apache.arrow.compression.CommonsCompressionFactory;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.slf4j.Logger;
@@ -44,9 +49,13 @@ public class ArrowBinaryIterator implements Iterator<InternalRow> {
ArrowReaderIterator arrowReaderIterator;
Iterator<InternalRow> currentIterator;
List<String> columnsInOrder;
Map<String, StructField> userProvidedFieldMap;

public ArrowBinaryIterator(
List<String> columnsInOrder, ByteString schema, ByteString rowsInBytes) {
List<String> columnsInOrder,
ByteString schema,
ByteString rowsInBytes,
Optional<StructType> userProvidedSchema) {
BufferAllocator allocator =
ArrowUtil.newRootAllocator(maxAllocation)
.newChildAllocator("ArrowBinaryIterator", 0, maxAllocation);
@@ -61,6 +70,16 @@ public ArrowBinaryIterator(
arrowReaderIterator = new ArrowReaderIterator(arrowStreamReader);
currentIterator = ImmutableList.<InternalRow>of().iterator();
this.columnsInOrder = columnsInOrder;

List<StructField> userProvidedFieldList =
Arrays
.stream(userProvidedSchema.orElse(new StructType()).fields())
.collect(Collectors.toList());

this.userProvidedFieldMap =
userProvidedFieldList
.stream()
.collect(Collectors.toMap(StructField::name, Function.identity()));
}

@Override
@@ -84,7 +103,9 @@ private Iterator<InternalRow> toArrowRows(VectorSchemaRoot root, List<String> namesInOrder) {
ColumnVector[] columns =
namesInOrder.stream()
.map(name -> root.getVector(name))
.map(vector -> new ArrowSchemaConverter(vector))
.map(
vector ->
new ArrowSchemaConverter(vector, userProvidedFieldMap.get(vector.getName())))
.collect(Collectors.toList())
.toArray(new ColumnVector[0]);

AvroBinaryIterator.java
@@ -17,11 +17,13 @@

import com.google.cloud.bigquery.Schema;
import com.google.protobuf.ByteString;
import java.util.Optional;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.BinaryDecoder;
import org.apache.avro.io.DecoderFactory;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -37,6 +39,7 @@ public class AvroBinaryIterator implements Iterator<InternalRow> {
List<String> columnsInOrder;
BinaryDecoder in;
Schema bqSchema;
Optional<StructType> userProvidedSchema;

/**
* An iterator for scanning over rows serialized in Avro format
@@ -50,11 +53,13 @@ public AvroBinaryIterator(
Schema bqSchema,
List<String> columnsInOrder,
org.apache.avro.Schema schema,
ByteString rowsInBytes) {
ByteString rowsInBytes,
Optional<StructType> userProvidedSchema) {
reader = new GenericDatumReader<GenericRecord>(schema);
this.bqSchema = bqSchema;
this.columnsInOrder = columnsInOrder;
in = new DecoderFactory().binaryDecoder(rowsInBytes.toByteArray(), null);
this.userProvidedSchema = userProvidedSchema;
}

@Override
@@ -70,7 +75,7 @@ public boolean hasNext() {
public InternalRow next() {
try {
return SchemaConverters.convertToInternalRow(
bqSchema, columnsInOrder, (GenericRecord) reader.read(null, in));
bqSchema, columnsInOrder, (GenericRecord) reader.read(null, in), userProvidedSchema);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
ReadRowsResponseToInternalRowIteratorConverter.java
@@ -19,23 +19,28 @@
import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
import com.google.protobuf.ByteString;
import org.apache.spark.sql.catalyst.InternalRow;

import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.apache.spark.sql.types.StructType;
import static com.google.common.base.Optional.fromJavaUtil;

public interface ReadRowsResponseToInternalRowIteratorConverter {

static ReadRowsResponseToInternalRowIteratorConverter avro(
final com.google.cloud.bigquery.Schema bqSchema,
final List<String> columnsInOrder,
final String rawAvroSchema) {
return new Avro(bqSchema, columnsInOrder, rawAvroSchema);
final String rawAvroSchema,
final Optional<StructType> userProvidedSchema) {
return new Avro(bqSchema, columnsInOrder, rawAvroSchema, fromJavaUtil(userProvidedSchema));
}

static ReadRowsResponseToInternalRowIteratorConverter arrow(
final List<String> columnsInOrder, final ByteString arrowSchema) {
return new Arrow(columnsInOrder, arrowSchema);
final List<String> columnsInOrder,
final ByteString arrowSchema,
final Optional<StructType> userProvidedSchema) {
return new Arrow(columnsInOrder, arrowSchema, fromJavaUtil(userProvidedSchema));
}

Iterator<InternalRow> convert(ReadRowsResponse response);
@@ -45,11 +50,17 @@ class Avro implements ReadRowsResponseToInternalRowIteratorConverter, Serializable {
private final com.google.cloud.bigquery.Schema bqSchema;
private final List<String> columnsInOrder;
private final String rawAvroSchema;
private final com.google.common.base.Optional<StructType> userProvidedSchema;

public Avro(Schema bqSchema, List<String> columnsInOrder, String rawAvroSchema) {
public Avro(
Schema bqSchema,
List<String> columnsInOrder,
String rawAvroSchema,
com.google.common.base.Optional<StructType> userProvidedSchema) {
this.bqSchema = bqSchema;
this.columnsInOrder = columnsInOrder;
this.rawAvroSchema = rawAvroSchema;
this.userProvidedSchema = userProvidedSchema;
}

@Override
@@ -58,24 +69,33 @@ public Iterator<InternalRow> convert(ReadRowsResponse response) {
bqSchema,
columnsInOrder,
new org.apache.avro.Schema.Parser().parse(rawAvroSchema),
response.getAvroRows().getSerializedBinaryRows());
response.getAvroRows().getSerializedBinaryRows(),
userProvidedSchema.toJavaUtil());
}
}

class Arrow implements ReadRowsResponseToInternalRowIteratorConverter, Serializable {

private final List<String> columnsInOrder;
private final ByteString arrowSchema;
private final com.google.common.base.Optional<StructType> userProvidedSchema;

public Arrow(List<String> columnsInOrder, ByteString arrowSchema) {
public Arrow(
List<String> columnsInOrder,
ByteString arrowSchema,
com.google.common.base.Optional<StructType> userProvidedSchema) {
this.columnsInOrder = columnsInOrder;
this.arrowSchema = arrowSchema;
this.userProvidedSchema = userProvidedSchema;
}

@Override
public Iterator<InternalRow> convert(ReadRowsResponse response) {
return new ArrowBinaryIterator(
columnsInOrder, arrowSchema, response.getArrowRecordBatch().getSerializedRecordBatch());
columnsInOrder,
arrowSchema,
response.getArrowRecordBatch().getSerializedRecordBatch(),
userProvidedSchema.toJavaUtil());
}
}
}
SchemaConverters.java
@@ -25,6 +25,7 @@
import com.google.cloud.bigquery.TimePartitioning;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.function.Function;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -91,11 +92,19 @@ public static Schema getSchemaWithPseudoColumns(TableInfo tableInfo) {
}

public static InternalRow convertToInternalRow(
Schema schema, List<String> namesInOrder, GenericRecord record) {
return convertAll(schema.getFields(), record, namesInOrder);
Schema schema,
List<String> namesInOrder,
GenericRecord record,
Optional<StructType> userProvidedSchema) {
List<StructField> userProvidedFieldList =
Arrays
.stream(userProvidedSchema.orElse(new StructType()).fields())
.collect(Collectors.toList());

return convertAll(schema.getFields(), record, namesInOrder, userProvidedFieldList);
}

static Object convert(Field field, Object value) {
static Object convert(Field field, Object value, StructField userProvidedField) {
if (value == null) {
return null;
}
@@ -113,53 +122,82 @@ static Object convert(Field field, Object value) {
.build();

List<Object> valueList = (List<Object>) value;

return new GenericArrayData(
valueList.stream().map(v -> convert(nestedField, v)).collect(Collectors.toList()));
valueList
.stream()
.map(v -> convert(nestedField, v, getStructFieldForRepeatedMode(userProvidedField)))
.collect(Collectors.toList()));
}

Object datum = convertByBigQueryType(field, value);
Object datum = convertByBigQueryType(field, value, userProvidedField);
Optional<Object> customDatum =
getCustomDataType(field).map(dt -> ((UserDefinedType) dt).deserialize(datum));
return customDatum.orElse(datum);
}

static Object convertByBigQueryType(Field field, Object value) {
if (LegacySQLTypeName.INTEGER.equals(field.getType())
|| LegacySQLTypeName.FLOAT.equals(field.getType())
|| LegacySQLTypeName.BOOLEAN.equals(field.getType())
|| LegacySQLTypeName.DATE.equals(field.getType())
|| LegacySQLTypeName.TIME.equals(field.getType())
|| LegacySQLTypeName.TIMESTAMP.equals(field.getType())) {
private static StructField getStructFieldForRepeatedMode(StructField field) {
StructField nestedField = null;

if (field != null) {
ArrayType arrayType = ((ArrayType) field.dataType());
nestedField =
new StructField(
field.name(),
arrayType.elementType(),
arrayType.containsNull(),
Metadata.empty()); // safe to pass empty metadata as it is not used anywhere
}
return nestedField;
}

static Object convertByBigQueryType(Field bqField, Object value, StructField userProvidedField) {
if (LegacySQLTypeName.INTEGER.equals(bqField.getType())
|| LegacySQLTypeName.FLOAT.equals(bqField.getType())
|| LegacySQLTypeName.BOOLEAN.equals(bqField.getType())
|| LegacySQLTypeName.DATE.equals(bqField.getType())
|| LegacySQLTypeName.TIME.equals(bqField.getType())
|| LegacySQLTypeName.TIMESTAMP.equals(bqField.getType())) {
return value;
}

if (LegacySQLTypeName.STRING.equals(field.getType())
|| LegacySQLTypeName.DATETIME.equals(field.getType())
|| LegacySQLTypeName.GEOGRAPHY.equals(field.getType())) {
if (LegacySQLTypeName.STRING.equals(bqField.getType())
|| LegacySQLTypeName.DATETIME.equals(bqField.getType())
|| LegacySQLTypeName.GEOGRAPHY.equals(bqField.getType())) {
return UTF8String.fromBytes(((Utf8) value).getBytes());
}

if (LegacySQLTypeName.BYTES.equals(field.getType())) {
if (LegacySQLTypeName.BYTES.equals(bqField.getType())) {
return getBytes((ByteBuffer) value);
}

if (LegacySQLTypeName.NUMERIC.equals(field.getType())) {
if (LegacySQLTypeName.NUMERIC.equals(bqField.getType())) {
byte[] bytes = getBytes((ByteBuffer) value);
BigDecimal b = new BigDecimal(new BigInteger(bytes), BQ_NUMERIC_SCALE);
Decimal d = Decimal.apply(b, BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE);

return d;
}

if (LegacySQLTypeName.RECORD.equals(field.getType())) {
return convertAll(
field.getSubFields(),
(GenericRecord) value,
field.getSubFields().stream().map(f -> f.getName()).collect(Collectors.toList()));
if (LegacySQLTypeName.RECORD.equals(bqField.getType())) {
List<String> namesInOrder = null;
List<StructField> structList = null;

if (userProvidedField != null) {
structList =
Arrays
.stream(((StructType)userProvidedField.dataType()).fields())
.collect(Collectors.toList());

namesInOrder = structList.stream().map(StructField::name).collect(Collectors.toList());
} else {
namesInOrder =
bqField.getSubFields().stream().map(Field::getName).collect(Collectors.toList());
}

return convertAll(bqField.getSubFields(), (GenericRecord) value, namesInOrder, structList);
}

throw new IllegalStateException("Unexpected type: " + field.getType());
throw new IllegalStateException("Unexpected type: " + bqField.getType());
}

private static byte[] getBytes(ByteBuffer buf) {
@@ -171,13 +209,28 @@ private static byte[] getBytes(ByteBuffer buf) {

// Schema is not recursive so add helper for sequence of fields
static GenericInternalRow convertAll(
FieldList fieldList, GenericRecord record, List<String> namesInOrder) {

FieldList fieldList,
GenericRecord record,
List<String> namesInOrder,
List<StructField> userProvidedFieldList) {
Map<String, Object> fieldMap = new HashMap<>();

Map<String, StructField> userProvidedFieldMap =
userProvidedFieldList == null
? new HashMap<>()
: userProvidedFieldList
.stream()
.collect(Collectors.toMap(StructField::name, Function.identity()));

fieldList.stream()
.forEach(
field -> fieldMap.put(field.getName(), convert(field, record.get(field.getName()))));
field ->
fieldMap.put(
field.getName(),
convert(
field,
record.get(field.getName()),
userProvidedFieldMap.get(field.getName()))));

Object[] values = new Object[namesInOrder.size()];
for (int i = 0; i < namesInOrder.size(); i++) {
(Diffs for the remaining nine changed files were not loaded.)
