@@ -119,7 +119,7 @@ public static boolean isTimestamptz(Schema schema) {
return false;
}

static boolean isOptionSchema(Schema schema) {
public static boolean isOptionSchema(Schema schema) {
if (schema.getType() == UNION && schema.getTypes().size() == 2) {
if (schema.getTypes().get(0).getType() == Schema.Type.NULL) {
return true;
@@ -162,7 +162,7 @@ static Schema fromOptions(List<Schema> options) {
}
}

static boolean isKeyValueSchema(Schema schema) {
public static boolean isKeyValueSchema(Schema schema) {
return schema.getType() == RECORD && schema.getFields().size() == 2;
}

32 changes: 32 additions & 0 deletions core/src/main/java/org/apache/iceberg/avro/ValueWriters.java
@@ -48,6 +48,14 @@ public static ValueWriter<Boolean> booleans() {
return BooleanWriter.INSTANCE;
}

public static ValueWriter<Byte> tinyints() {
return ByteToIntegerWriter.INSTANCE;
}

public static ValueWriter<Short> shorts() {
return ShortToIntegerWriter.INSTANCE;
}

public static ValueWriter<Integer> ints() {
return IntegerWriter.INSTANCE;
}
@@ -142,6 +150,30 @@ public void write(Boolean bool, Encoder encoder) throws IOException {
}
}

private static class ByteToIntegerWriter implements ValueWriter<Byte> {
private static final ByteToIntegerWriter INSTANCE = new ByteToIntegerWriter();

private ByteToIntegerWriter() {
}

@Override
public void write(Byte b, Encoder encoder) throws IOException {
encoder.writeInt(b.intValue());
}
}

private static class ShortToIntegerWriter implements ValueWriter<Short> {
private static final ShortToIntegerWriter INSTANCE = new ShortToIntegerWriter();

private ShortToIntegerWriter() {
}

@Override
public void write(Short s, Encoder encoder) throws IOException {
encoder.writeInt(s.intValue());
}
}

private static class IntegerWriter implements ValueWriter<Integer> {
private static final IntegerWriter INSTANCE = new IntegerWriter();

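The new Avro writers simply widen byte and short values to int before encoding, since Avro has no 8- or 16-bit integer types. A minimal sketch of exercising them directly against an Avro BinaryEncoder (the standalone class, its name, and the byte-array output are illustrative assumptions, not part of this change):

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.iceberg.avro.ValueWriters;

public class TinyintWriteSketch {
  public static void main(String[] args) throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(out, null);

    // both writers promote the value to int before calling Encoder.writeInt
    ValueWriters.tinyints().write((byte) 7, encoder);
    ValueWriters.shorts().write((short) 300, encoder);

    encoder.flush();
    System.out.println("encoded " + out.size() + " bytes");
  }
}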
@@ -126,8 +126,9 @@ public ParquetValueWriter<?> primitive(PrimitiveType primitive) {
case INT_8:
case INT_16:
case INT_32:
return ParquetValueWriters.ints(desc);
case INT_64:
return ParquetValueWriters.unboxed(desc);
return ParquetValueWriters.longs(desc);
case DATE:
return new DateWriter(desc);
case TIME_MICROS:
@@ -162,11 +163,15 @@ public ParquetValueWriter<?> primitive(PrimitiveType primitive) {
case BINARY:
return ParquetValueWriters.byteBuffers(desc);
case BOOLEAN:
return ParquetValueWriters.booleans(desc);
case INT32:
return ParquetValueWriters.ints(desc);
case INT64:
return ParquetValueWriters.longs(desc);
case FLOAT:
return ParquetValueWriters.floats(desc);
case DOUBLE:
return ParquetValueWriters.unboxed(desc);
return ParquetValueWriters.doubles(desc);
default:
throw new UnsupportedOperationException("Unsupported type: " + primitive);
}
@@ -118,10 +118,11 @@ public ParquetValueWriter<?> primitive(PrimitiveType primitive) {
case INT_8:
case INT_16:
case INT_32:
return ParquetValueWriters.ints(desc);
case INT_64:
case TIME_MICROS:
case TIMESTAMP_MICROS:
return ParquetValueWriters.unboxed(desc);
return ParquetValueWriters.longs(desc);
case DECIMAL:
DecimalMetadata decimal = primitive.getDecimalMetadata();
switch (primitive.getPrimitiveTypeName()) {
@@ -153,11 +154,15 @@ public ParquetValueWriter<?> primitive(PrimitiveType primitive) {
case BINARY:
return ParquetValueWriters.byteBuffers(desc);
case BOOLEAN:
return ParquetValueWriters.booleans(desc);
case INT32:
return ParquetValueWriters.ints(desc);
case INT64:
return ParquetValueWriters.longs(desc);
case FLOAT:
return ParquetValueWriters.floats(desc);
case DOUBLE:
return ParquetValueWriters.unboxed(desc);
return ParquetValueWriters.doubles(desc);
default:
throw new UnsupportedOperationException("Unsupported type: " + primitive);
}
@@ -49,7 +49,31 @@ public static <T> ParquetValueWriter<T> option(Type type,
return writer;
}

public static <T> UnboxedWriter<T> unboxed(ColumnDescriptor desc) {
public static UnboxedWriter<Boolean> booleans(ColumnDescriptor desc) {
return new UnboxedWriter<>(desc);
}

public static UnboxedWriter<Byte> tinyints(ColumnDescriptor desc) {
return new ByteWriter(desc);
}

public static UnboxedWriter<Short> shorts(ColumnDescriptor desc) {
return new ShortWriter(desc);
}

public static UnboxedWriter<Integer> ints(ColumnDescriptor desc) {
return new UnboxedWriter<>(desc);
}

public static UnboxedWriter<Long> longs(ColumnDescriptor desc) {
return new UnboxedWriter<>(desc);
}

public static UnboxedWriter<Float> floats(ColumnDescriptor desc) {
return new UnboxedWriter<>(desc);
}

public static UnboxedWriter<Double> doubles(ColumnDescriptor desc) {
return new UnboxedWriter<>(desc);
}

@@ -138,6 +162,28 @@ public void writeDouble(int repetitionLevel, double value) {
}
}

private static class ByteWriter extends UnboxedWriter<Byte> {
private ByteWriter(ColumnDescriptor desc) {
super(desc);
}

@Override
public void write(int repetitionLevel, Byte value) {
writeInteger(repetitionLevel, value.intValue());
}
}

private static class ShortWriter extends UnboxedWriter<Short> {
private ShortWriter(ColumnDescriptor desc) {
super(desc);
}

@Override
public void write(int repetitionLevel, Short value) {
writeInteger(repetitionLevel, value.intValue());
}
}

private static class IntegerDecimalWriter extends PrimitiveWriter<BigDecimal> {
private final int precision;
private final int scale;
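The Parquet side follows the same widening pattern: ByteWriter and ShortWriter delegate to writeInteger, so tinyint and smallint columns are stored as INT32. A rough sketch of building the new writers from a hand-written message type (the parse string, the class name, and the org.apache.iceberg.parquet package are assumptions for illustration):

import org.apache.iceberg.parquet.ParquetValueWriter;
import org.apache.iceberg.parquet.ParquetValueWriters;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.MessageTypeParser;

public class WidenedParquetWriterSketch {
  public static void main(String[] args) {
    MessageType schema = MessageTypeParser.parseMessageType(
        "message m { required int32 tiny_col (INT_8); required int32 short_col (INT_16); }");

    ColumnDescriptor tinyDesc = schema.getColumnDescription(new String[] {"tiny_col"});
    ColumnDescriptor shortDesc = schema.getColumnDescription(new String[] {"short_col"});

    // both writers widen the boxed value to int before handing it to the column writer,
    // because Parquet has no physical 8- or 16-bit integer type
    ParquetValueWriter<Byte> tinyWriter = ParquetValueWriters.tinyints(tinyDesc);
    ParquetValueWriter<Short> shortWriter = ParquetValueWriters.shorts(shortDesc);
  }
}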
@@ -95,7 +95,7 @@ public void tearDownBenchmark() {
@Threads(1)
public void writeUsingIcebergWriter() throws IOException {
try (FileAppender<InternalRow> writer = Parquet.write(Files.localOutput(dataFile))
.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SCHEMA, msgType))
.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType))
.schema(SCHEMA)
.build()) {

@@ -94,7 +94,7 @@ public void tearDownBenchmark() {
@Threads(1)
public void writeUsingIcebergWriter() throws IOException {
try (FileAppender<InternalRow> writer = Parquet.write(Files.localOutput(dataFile))
.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SCHEMA, msgType))
.createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType))
.schema(SCHEMA)
.build()) {

@@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iceberg.spark.data;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.Deque;
import java.util.List;
import org.apache.avro.Schema;
import org.apache.iceberg.avro.AvroSchemaUtil;
import org.apache.iceberg.avro.LogicalMap;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public abstract class AvroWithSparkSchemaVisitor<T> {
public static <T> T visit(StructType struct, Schema schema, AvroWithSparkSchemaVisitor<T> visitor) {
return visitRecord(struct, schema, visitor);
}

public static <T> T visit(DataType type, Schema schema, AvroWithSparkSchemaVisitor<T> visitor) {
switch (schema.getType()) {
case RECORD:
Preconditions.checkArgument(type instanceof StructType, "Invalid struct: %s is not a struct", type);
return visitRecord((StructType) type, schema, visitor);

case UNION:
return visitUnion(type, schema, visitor);

case ARRAY:
return visitArray(type, schema, visitor);

case MAP:
Preconditions.checkArgument(type instanceof MapType, "Invalid map: %s is not a map", type);
MapType map = (MapType) type;
Preconditions.checkArgument(map.keyType() instanceof StringType,
"Invalid map: %s is not a string", map.keyType());
return visitor.map(map, schema, visit(map.valueType(), schema.getValueType(), visitor));

default:
return visitor.primitive(type, schema);
}
}

private static <T> T visitRecord(StructType struct, Schema record, AvroWithSparkSchemaVisitor<T> visitor) {
// check to make sure this hasn't been visited before
String name = record.getFullName();
Preconditions.checkState(!visitor.recordLevels.contains(name),
"Cannot process recursive Avro record %s", name);
StructField[] sFields = struct.fields();
List<Schema.Field> fields = record.getFields();
Preconditions.checkArgument(sFields.length == fields.size(),
"Structs do not match: %s != %s", struct, record);

visitor.recordLevels.push(name);

List<String> names = Lists.newArrayListWithExpectedSize(fields.size());
List<T> results = Lists.newArrayListWithExpectedSize(fields.size());
for (int i = 0; i < sFields.length; i += 1) {
StructField sField = sFields[i];
Schema.Field field = fields.get(i);
Preconditions.checkArgument(AvroSchemaUtil.makeCompatibleName(sField.name()).equals(field.name()),
"Structs do not match: field %s != %s", sField.name(), field.name());
names.add(field.name());
results.add(visit(sField.dataType(), field.schema(), visitor));
}

visitor.recordLevels.pop();

return visitor.record(struct, record, names, results);
}

private static <T> T visitUnion(DataType type, Schema union, AvroWithSparkSchemaVisitor<T> visitor) {
List<Schema> types = union.getTypes();
Preconditions.checkArgument(AvroSchemaUtil.isOptionSchema(union),
"Cannot visit non-option union: %s", union);
List<T> options = Lists.newArrayListWithExpectedSize(types.size());
for (Schema branch : types) {
if (branch.getType() == Schema.Type.NULL) {
options.add(visit(DataTypes.NullType, branch, visitor));
} else {
options.add(visit(type, branch, visitor));
}
}
return visitor.union(type, union, options);
}

private static <T> T visitArray(DataType type, Schema array, AvroWithSparkSchemaVisitor<T> visitor) {
if (array.getLogicalType() instanceof LogicalMap || type instanceof MapType) {
Preconditions.checkState(
AvroSchemaUtil.isKeyValueSchema(array.getElementType()),
"Cannot visit invalid logical map type: %s", array);
Preconditions.checkArgument(type instanceof MapType, "Invalid map: %s is not a map", type);
MapType map = (MapType) type;
List<Schema.Field> keyValueFields = array.getElementType().getFields();
return visitor.map(map, array,
visit(map.keyType(), keyValueFields.get(0).schema(), visitor),
visit(map.valueType(), keyValueFields.get(1).schema(), visitor));

} else {
Preconditions.checkArgument(type instanceof ArrayType, "Invalid array: %s is not an array", type);
ArrayType list = (ArrayType) type;
return visitor.array(list, array, visit(list.elementType(), array.getElementType(), visitor));
}
}

private Deque<String> recordLevels = Lists.newLinkedList();

public T record(StructType struct, Schema record, List<String> names, List<T> fields) {
return null;
}

public T union(DataType type, Schema union, List<T> options) {
return null;
}

public T array(ArrayType sArray, Schema array, T element) {
return null;
}

public T map(MapType sMap, Schema map, T key, T value) {
return null;
}

public T map(MapType sMap, Schema map, T value) {
return null;
}

public T primitive(DataType type, Schema primitive) {
return null;
}
}
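To show how the new visitor is meant to be subclassed, here is a sketch of a visitor that counts primitive leaf fields while walking a Spark StructType together with the corresponding Avro schema. The class is illustrative and not part of this change; the map and array callbacks are left at their null defaults for brevity, so only struct-nested primitives are counted:

import java.util.List;
import org.apache.avro.Schema;
import org.apache.iceberg.spark.data.AvroWithSparkSchemaVisitor;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;

class PrimitiveCounter extends AvroWithSparkSchemaVisitor<Integer> {
  @Override
  public Integer record(StructType struct, Schema record, List<String> names, List<Integer> fields) {
    // nested results may be null for callbacks this sketch does not override
    return fields.stream().mapToInt(f -> f == null ? 0 : f).sum();
  }

  @Override
  public Integer union(DataType type, Schema union, List<Integer> options) {
    // sum the branches; the NULL branch contributes 0 via primitive() below
    return options.stream().mapToInt(o -> o == null ? 0 : o).sum();
  }

  @Override
  public Integer primitive(DataType type, Schema primitive) {
    return primitive.getType() == Schema.Type.NULL ? 0 : 1;
  }
}

// usage, assuming sparkType and avroSchema describe the same structure:
// Integer leafCount = AvroWithSparkSchemaVisitor.visit(sparkType, avroSchema, new PrimitiveCounter());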