diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java index 8c79fd13e55b..6b3a72027e01 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java @@ -119,7 +119,7 @@ public static boolean isTimestamptz(Schema schema) { return false; } - static boolean isOptionSchema(Schema schema) { + public static boolean isOptionSchema(Schema schema) { if (schema.getType() == UNION && schema.getTypes().size() == 2) { if (schema.getTypes().get(0).getType() == Schema.Type.NULL) { return true; @@ -162,7 +162,7 @@ static Schema fromOptions(List options) { } } - static boolean isKeyValueSchema(Schema schema) { + public static boolean isKeyValueSchema(Schema schema) { return schema.getType() == RECORD && schema.getFields().size() == 2; } diff --git a/core/src/main/java/org/apache/iceberg/avro/ValueWriters.java b/core/src/main/java/org/apache/iceberg/avro/ValueWriters.java index 02a40908cc5e..d41f73a17aa1 100644 --- a/core/src/main/java/org/apache/iceberg/avro/ValueWriters.java +++ b/core/src/main/java/org/apache/iceberg/avro/ValueWriters.java @@ -48,6 +48,14 @@ public static ValueWriter booleans() { return BooleanWriter.INSTANCE; } + public static ValueWriter tinyints() { + return ByteToIntegerWriter.INSTANCE; + } + + public static ValueWriter shorts() { + return ShortToIntegerWriter.INSTANCE; + } + public static ValueWriter ints() { return IntegerWriter.INSTANCE; } @@ -142,6 +150,30 @@ public void write(Boolean bool, Encoder encoder) throws IOException { } } + private static class ByteToIntegerWriter implements ValueWriter { + private static final ByteToIntegerWriter INSTANCE = new ByteToIntegerWriter(); + + private ByteToIntegerWriter() { + } + + @Override + public void write(Byte b, Encoder encoder) throws IOException { + encoder.writeInt(b.intValue()); + } + } + + private static class ShortToIntegerWriter implements ValueWriter { + private static final ShortToIntegerWriter INSTANCE = new ShortToIntegerWriter(); + + private ShortToIntegerWriter() { + } + + @Override + public void write(Short s, Encoder encoder) throws IOException { + encoder.writeInt(s.intValue()); + } + } + private static class IntegerWriter implements ValueWriter { private static final IntegerWriter INSTANCE = new IntegerWriter(); diff --git a/data/src/main/java/org/apache/iceberg/data/parquet/GenericParquetWriter.java b/data/src/main/java/org/apache/iceberg/data/parquet/GenericParquetWriter.java index d20db91af8c4..a675a557580d 100644 --- a/data/src/main/java/org/apache/iceberg/data/parquet/GenericParquetWriter.java +++ b/data/src/main/java/org/apache/iceberg/data/parquet/GenericParquetWriter.java @@ -126,8 +126,9 @@ public ParquetValueWriter primitive(PrimitiveType primitive) { case INT_8: case INT_16: case INT_32: + return ParquetValueWriters.ints(desc); case INT_64: - return ParquetValueWriters.unboxed(desc); + return ParquetValueWriters.longs(desc); case DATE: return new DateWriter(desc); case TIME_MICROS: @@ -162,11 +163,15 @@ public ParquetValueWriter primitive(PrimitiveType primitive) { case BINARY: return ParquetValueWriters.byteBuffers(desc); case BOOLEAN: + return ParquetValueWriters.booleans(desc); case INT32: + return ParquetValueWriters.ints(desc); case INT64: + return ParquetValueWriters.longs(desc); case FLOAT: + return ParquetValueWriters.floats(desc); case DOUBLE: - return ParquetValueWriters.unboxed(desc); + return ParquetValueWriters.doubles(desc); 
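The ValueWriters hunks above add widening writers because Avro has no 8- or 16-bit integer type: Spark ByteType and ShortType values have to be written as Avro ints. A minimal standalone sketch of that widening against Avro's public Encoder API (the class and variable names here are illustrative, not part of this patch):

import java.io.ByteArrayOutputStream;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.EncoderFactory;

public class WidenToAvroIntExample {
  public static void main(String[] args) throws Exception {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(out, null);

    Byte b = (byte) 7;
    Short s = (short) 300;

    // Same move as ByteToIntegerWriter / ShortToIntegerWriter: both values go
    // through Encoder.writeInt, so the Avro file only ever sees 32-bit ints.
    encoder.writeInt(b.intValue());
    encoder.writeInt(s.intValue());
    encoder.flush();

    System.out.println("encoded " + out.size() + " bytes");
  }
}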
default: throw new UnsupportedOperationException("Unsupported type: " + primitive); } diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetAvroWriter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetAvroWriter.java index a900669cd09a..c5cd97774565 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetAvroWriter.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetAvroWriter.java @@ -118,10 +118,11 @@ public ParquetValueWriter primitive(PrimitiveType primitive) { case INT_8: case INT_16: case INT_32: + return ParquetValueWriters.ints(desc); case INT_64: case TIME_MICROS: case TIMESTAMP_MICROS: - return ParquetValueWriters.unboxed(desc); + return ParquetValueWriters.longs(desc); case DECIMAL: DecimalMetadata decimal = primitive.getDecimalMetadata(); switch (primitive.getPrimitiveTypeName()) { @@ -153,11 +154,15 @@ public ParquetValueWriter primitive(PrimitiveType primitive) { case BINARY: return ParquetValueWriters.byteBuffers(desc); case BOOLEAN: + return ParquetValueWriters.booleans(desc); case INT32: + return ParquetValueWriters.ints(desc); case INT64: + return ParquetValueWriters.longs(desc); case FLOAT: + return ParquetValueWriters.floats(desc); case DOUBLE: - return ParquetValueWriters.unboxed(desc); + return ParquetValueWriters.doubles(desc); default: throw new UnsupportedOperationException("Unsupported type: " + primitive); } diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueWriters.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueWriters.java index 18f1410e7ed9..8de2e7201d15 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueWriters.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueWriters.java @@ -49,7 +49,31 @@ public static ParquetValueWriter option(Type type, return writer; } - public static UnboxedWriter unboxed(ColumnDescriptor desc) { + public static UnboxedWriter booleans(ColumnDescriptor desc) { + return new UnboxedWriter<>(desc); + } + + public static UnboxedWriter tinyints(ColumnDescriptor desc) { + return new ByteWriter(desc); + } + + public static UnboxedWriter shorts(ColumnDescriptor desc) { + return new ShortWriter(desc); + } + + public static UnboxedWriter ints(ColumnDescriptor desc) { + return new UnboxedWriter<>(desc); + } + + public static UnboxedWriter longs(ColumnDescriptor desc) { + return new UnboxedWriter<>(desc); + } + + public static UnboxedWriter floats(ColumnDescriptor desc) { + return new UnboxedWriter<>(desc); + } + + public static UnboxedWriter doubles(ColumnDescriptor desc) { return new UnboxedWriter<>(desc); } @@ -138,6 +162,28 @@ public void writeDouble(int repetitionLevel, double value) { } } + private static class ByteWriter extends UnboxedWriter { + private ByteWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, Byte value) { + writeInteger(repetitionLevel, value.intValue()); + } + } + + private static class ShortWriter extends UnboxedWriter { + private ShortWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, Short value) { + writeInteger(repetitionLevel, value.intValue()); + } + } + private static class IntegerDecimalWriter extends PrimitiveWriter { private final int precision; private final int scale; diff --git a/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java 
b/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java index 8cb5b07020e5..e8b62d97ea78 100644 --- a/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java +++ b/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersFlatDataBenchmark.java @@ -95,7 +95,7 @@ public void tearDownBenchmark() { @Threads(1) public void writeUsingIcebergWriter() throws IOException { try (FileAppender writer = Parquet.write(Files.localOutput(dataFile)) - .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SCHEMA, msgType)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType)) .schema(SCHEMA) .build()) { diff --git a/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java b/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java index dd395f519916..73d9a5f5140d 100644 --- a/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java +++ b/spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetWritersNestedDataBenchmark.java @@ -94,7 +94,7 @@ public void tearDownBenchmark() { @Threads(1) public void writeUsingIcebergWriter() throws IOException { try (FileAppender writer = Parquet.write(Files.localOutput(dataFile)) - .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SCHEMA, msgType)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(SCHEMA), msgType)) .schema(SCHEMA) .build()) { diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java b/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java new file mode 100644 index 000000000000..7f01e10d2969 --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/data/AvroWithSparkSchemaVisitor.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import java.util.Deque; +import java.util.List; +import org.apache.avro.Schema; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.avro.LogicalMap; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public abstract class AvroWithSparkSchemaVisitor { + public static T visit(StructType struct, Schema schema, AvroWithSparkSchemaVisitor visitor) { + return visitRecord(struct, schema, visitor); + } + + public static T visit(DataType type, Schema schema, AvroWithSparkSchemaVisitor visitor) { + switch (schema.getType()) { + case RECORD: + Preconditions.checkArgument(type instanceof StructType, "Invalid struct: %s is not a struct", type); + return visitRecord((StructType) type, schema, visitor); + + case UNION: + return visitUnion(type, schema, visitor); + + case ARRAY: + return visitArray(type, schema, visitor); + + case MAP: + Preconditions.checkArgument(type instanceof MapType, "Invalid map: %s is not a map", type); + MapType map = (MapType) type; + Preconditions.checkArgument(map.keyType() instanceof StringType, + "Invalid map: %s is not a string", map.keyType()); + return visitor.map(map, schema, visit(map.valueType(), schema.getValueType(), visitor)); + + default: + return visitor.primitive(type, schema); + } + } + + private static T visitRecord(StructType struct, Schema record, AvroWithSparkSchemaVisitor visitor) { + // check to make sure this hasn't been visited before + String name = record.getFullName(); + Preconditions.checkState(!visitor.recordLevels.contains(name), + "Cannot process recursive Avro record %s", name); + StructField[] sFields = struct.fields(); + List fields = record.getFields(); + Preconditions.checkArgument(sFields.length == fields.size(), + "Structs do not match: %s != %s", struct, record); + + visitor.recordLevels.push(name); + + List names = Lists.newArrayListWithExpectedSize(fields.size()); + List results = Lists.newArrayListWithExpectedSize(fields.size()); + for (int i = 0; i < sFields.length; i += 1) { + StructField sField = sFields[i]; + Schema.Field field = fields.get(i); + Preconditions.checkArgument(AvroSchemaUtil.makeCompatibleName(sField.name()).equals(field.name()), + "Structs do not match: field %s != %s", sField.name(), field.name()); + results.add(visit(sField.dataType(), field.schema(), visitor)); + } + + visitor.recordLevels.pop(); + + return visitor.record(struct, record, names, results); + } + + private static T visitUnion(DataType type, Schema union, AvroWithSparkSchemaVisitor visitor) { + List types = union.getTypes(); + Preconditions.checkArgument(AvroSchemaUtil.isOptionSchema(union), + "Cannot visit non-option union: %s", union); + List options = Lists.newArrayListWithExpectedSize(types.size()); + for (Schema branch : types) { + if (branch.getType() == Schema.Type.NULL) { + options.add(visit(DataTypes.NullType, branch, visitor)); + } else { + options.add(visit(type, branch, visitor)); + } + } + return visitor.union(type, union, options); + } + + private static T visitArray(DataType type, Schema array, AvroWithSparkSchemaVisitor visitor) { + if (array.getLogicalType() instanceof LogicalMap || type instanceof MapType) { 
+ Preconditions.checkState( + AvroSchemaUtil.isKeyValueSchema(array.getElementType()), + "Cannot visit invalid logical map type: %s", array); + Preconditions.checkArgument(type instanceof MapType, "Invalid map: %s is not a map", type); + MapType map = (MapType) type; + List keyValueFields = array.getElementType().getFields(); + return visitor.map(map, array, + visit(map.keyType(), keyValueFields.get(0).schema(), visitor), + visit(map.valueType(), keyValueFields.get(1).schema(), visitor)); + + } else { + Preconditions.checkArgument(type instanceof ArrayType, "Invalid array: %s is not an array", type); + ArrayType list = (ArrayType) type; + return visitor.array(list, array, visit(list.elementType(), array.getElementType(), visitor)); + } + } + + private Deque recordLevels = Lists.newLinkedList(); + + public T record(StructType struct, Schema record, List names, List fields) { + return null; + } + + public T union(DataType type, Schema union, List options) { + return null; + } + + public T array(ArrayType sArray, Schema array, T element) { + return null; + } + + public T map(MapType sMap, Schema map, T key, T value) { + return null; + } + + public T map(MapType sMap, Schema map, T value) { + return null; + } + + public T primitive(DataType type, Schema primitive) { + return null; + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java b/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java new file mode 100644 index 000000000000..111356dbc479 --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/data/ParquetWithSparkSchemaVisitor.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import java.util.Deque; +import java.util.List; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Type.Repetition; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Visitor for traversing a Parquet type with a companion Spark type. 
+ * + * @param the Java class returned by the visitor + */ +public class ParquetWithSparkSchemaVisitor { + private final Deque fieldNames = Lists.newLinkedList(); + + public static T visit(DataType sType, Type type, ParquetWithSparkSchemaVisitor visitor) { + Preconditions.checkArgument(sType != null, "Invalid DataType: null"); + if (type instanceof MessageType) { + Preconditions.checkArgument(sType instanceof StructType, "Invalid struct: %s is not a struct", sType); + StructType struct = (StructType) sType; + return visitor.message(struct, (MessageType) type, visitFields(struct, type.asGroupType(), visitor)); + + } else if (type.isPrimitive()) { + return visitor.primitive(sType, type.asPrimitiveType()); + + } else { + // if not a primitive, the typeId must be a group + GroupType group = type.asGroupType(); + OriginalType annotation = group.getOriginalType(); + if (annotation != null) { + switch (annotation) { + case LIST: + Preconditions.checkArgument(!group.isRepetition(Repetition.REPEATED), + "Invalid list: top-level group is repeated: %s", group); + Preconditions.checkArgument(group.getFieldCount() == 1, + "Invalid list: does not contain single repeated field: %s", group); + + GroupType repeatedElement = group.getFields().get(0).asGroupType(); + Preconditions.checkArgument(repeatedElement.isRepetition(Repetition.REPEATED), + "Invalid list: inner group is not repeated"); + Preconditions.checkArgument(repeatedElement.getFieldCount() <= 1, + "Invalid list: repeated group is not a single field: %s", group); + + Preconditions.checkArgument(sType instanceof ArrayType, "Invalid list: %s is not an array", sType); + ArrayType array = (ArrayType) sType; + StructField element = new StructField( + "element", array.elementType(), array.containsNull(), Metadata.empty()); + + visitor.fieldNames.push(repeatedElement.getName()); + try { + T elementResult = null; + if (repeatedElement.getFieldCount() > 0) { + elementResult = visitField(element, repeatedElement.getType(0), visitor); + } + + return visitor.list(array, group, elementResult); + + } finally { + visitor.fieldNames.pop(); + } + + case MAP: + Preconditions.checkArgument(!group.isRepetition(Repetition.REPEATED), + "Invalid map: top-level group is repeated: %s", group); + Preconditions.checkArgument(group.getFieldCount() == 1, + "Invalid map: does not contain single repeated field: %s", group); + + GroupType repeatedKeyValue = group.getType(0).asGroupType(); + Preconditions.checkArgument(repeatedKeyValue.isRepetition(Repetition.REPEATED), + "Invalid map: inner group is not repeated"); + Preconditions.checkArgument(repeatedKeyValue.getFieldCount() <= 2, + "Invalid map: repeated group does not have 2 fields"); + + Preconditions.checkArgument(sType instanceof MapType, "Invalid map: %s is not a map", sType); + MapType map = (MapType) sType; + StructField keyField = new StructField("key", map.keyType(), false, Metadata.empty()); + StructField valueField = new StructField( + "value", map.valueType(), map.valueContainsNull(), Metadata.empty()); + + visitor.fieldNames.push(repeatedKeyValue.getName()); + try { + T keyResult = null; + T valueResult = null; + switch (repeatedKeyValue.getFieldCount()) { + case 2: + // if there are 2 fields, both key and value are projected + keyResult = visitField(keyField, repeatedKeyValue.getType(0), visitor); + valueResult = visitField(valueField, repeatedKeyValue.getType(1), visitor); + break; + case 1: + // if there is just one, use the name to determine what it is + Type keyOrValue = repeatedKeyValue.getType(0); + if 
(keyOrValue.getName().equalsIgnoreCase("key")) { + keyResult = visitField(keyField, keyOrValue, visitor); + // value result remains null + } else { + valueResult = visitField(valueField, keyOrValue, visitor); + // key result remains null + } + break; + default: + // both results will remain null + } + + return visitor.map(map, group, keyResult, valueResult); + + } finally { + visitor.fieldNames.pop(); + } + + default: + } + } + + Preconditions.checkArgument(sType instanceof StructType, "Invalid struct: %s is not a struct", sType); + StructType struct = (StructType) sType; + return visitor.struct(struct, group, visitFields(struct, group, visitor)); + } + } + + private static T visitField(StructField sField, Type field, ParquetWithSparkSchemaVisitor visitor) { + visitor.fieldNames.push(field.getName()); + try { + return visit(sField.dataType(), field, visitor); + } finally { + visitor.fieldNames.pop(); + } + } + + private static List visitFields(StructType struct, GroupType group, + ParquetWithSparkSchemaVisitor visitor) { + StructField[] sFields = struct.fields(); + Preconditions.checkArgument(sFields.length == group.getFieldCount(), + "Structs do not match: %s and %s", struct, group); + List results = Lists.newArrayListWithExpectedSize(group.getFieldCount()); + for (int i = 0; i < sFields.length; i += 1) { + Type field = group.getFields().get(i); + StructField sField = sFields[i]; + Preconditions.checkArgument(field.getName().equals(AvroSchemaUtil.makeCompatibleName(sField.name())), + "Structs do not match: field %s != %s", field.getName(), sField.name()); + results.add(visitField(sField, field, visitor)); + } + + return results; + } + + public T message(StructType sStruct, MessageType message, List fields) { + return null; + } + + public T struct(StructType sStruct, GroupType struct, List fields) { + return null; + } + + public T list(ArrayType sArray, GroupType array, T element) { + return null; + } + + public T map(MapType sMap, GroupType map, T key, T value) { + return null; + } + + public T primitive(DataType sPrimitive, PrimitiveType primitive) { + return null; + } + + protected String[] currentPath() { + return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]); + } + + protected String[] path(String name) { + List list = Lists.newArrayList(fieldNames.descendingIterator()); + list.add(name); + return list.toArray(new String[0]); + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java index a24c205fb7ac..bc457be4b852 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroWriter.java @@ -20,36 +20,39 @@ package org.apache.iceberg.spark.data; import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; import java.io.IOException; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.avro.LogicalType; import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; import org.apache.avro.io.DatumWriter; import org.apache.avro.io.Encoder; -import org.apache.iceberg.avro.AvroSchemaUtil; -import org.apache.iceberg.avro.AvroSchemaVisitor; import org.apache.iceberg.avro.ValueWriter; import org.apache.iceberg.avro.ValueWriters; -import org.apache.iceberg.spark.SparkSchemaUtil; -import org.apache.iceberg.types.Type; import org.apache.spark.sql.catalyst.InternalRow; +import 
org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; -import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; public class SparkAvroWriter implements DatumWriter { - private final org.apache.iceberg.Schema schema; + private final StructType dsSchema; private ValueWriter writer = null; - public SparkAvroWriter(org.apache.iceberg.Schema schema) { - this.schema = schema; + public SparkAvroWriter(StructType dsSchema) { + this.dsSchema = dsSchema; } @Override @SuppressWarnings("unchecked") public void setSchema(Schema schema) { - this.writer = (ValueWriter) AvroSchemaVisitor.visit(schema, new WriteBuilder(this.schema)); + this.writer = (ValueWriter) AvroWithSparkSchemaVisitor + .visit(dsSchema, schema, new WriteBuilder()); } @Override @@ -57,26 +60,15 @@ public void write(InternalRow datum, Encoder out) throws IOException { writer.write(datum, out); } - private static class WriteBuilder extends AvroSchemaVisitor> { - private final org.apache.iceberg.Schema schema; - - private WriteBuilder(org.apache.iceberg.Schema schema) { - this.schema = schema; - } - + private static class WriteBuilder extends AvroWithSparkSchemaVisitor> { @Override - public ValueWriter record(Schema record, List names, List> fields) { - List types = Lists.newArrayList(); - for (Schema.Field field : record.getFields()) { - int fieldId = AvroSchemaUtil.getFieldId(field); - Type fieldType = schema.findType(fieldId); - types.add(SparkSchemaUtil.convert(fieldType)); - } + public ValueWriter record(StructType struct, Schema record, List names, List> fields) { + List types = Stream.of(struct.fields()).map(StructField::dataType).collect(Collectors.toList()); return SparkValueWriters.struct(fields, types); } @Override - public ValueWriter union(Schema union, List> options) { + public ValueWriter union(DataType type, Schema union, List> options) { Preconditions.checkArgument(options.contains(ValueWriters.nulls()), "Cannot create writer for non-option union: %s", union); Preconditions.checkArgument(options.size() == 2, @@ -89,33 +81,22 @@ public ValueWriter union(Schema union, List> options) { } @Override - public ValueWriter array(Schema array, ValueWriter elementWriter) { - LogicalType logical = array.getLogicalType(); - if (logical != null && "map".equals(logical.getName())) { - int keyFieldId = AvroSchemaUtil.getFieldId(array.getElementType().getField("key")); - Type keyType = schema.findType(keyFieldId); - int valueFieldId = AvroSchemaUtil.getFieldId(array.getElementType().getField("value")); - Type valueType = schema.findType(valueFieldId); - ValueWriter[] writers = ((SparkValueWriters.StructWriter) elementWriter).writers(); - return SparkValueWriters.arrayMap( - writers[0], SparkSchemaUtil.convert(keyType), writers[1], SparkSchemaUtil.convert(valueType)); - } + public ValueWriter array(ArrayType sArray, Schema array, ValueWriter elementWriter) { + return SparkValueWriters.array(elementWriter, sArray.elementType()); + } - Type elementType = schema.findType(AvroSchemaUtil.getElementId(array)); - return SparkValueWriters.array(elementWriter, SparkSchemaUtil.convert(elementType)); + @Override + public ValueWriter map(MapType sMap, Schema map, ValueWriter valueReader) { + return SparkValueWriters.map(SparkValueWriters.strings(), sMap.keyType(), valueReader, sMap.valueType()); } 
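The rewritten WriteBuilder above no longer resolves field types through the Iceberg schema by field id; it reads them straight off the Spark StructType, which is what lets ByteType and ShortType survive long enough to pick the new widening writers. A small self-contained sketch of that lookup pattern (class and field names are made up for illustration):

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class StructFieldTypesExample {
  public static void main(String[] args) {
    // ByteType is preserved here; a round trip through the Iceberg schema
    // would have turned it into a plain int and lost that information.
    StructType struct = new StructType()
        .add("id", DataTypes.LongType)
        .add("flag", DataTypes.ByteType);

    List<DataType> types = Stream.of(struct.fields())
        .map(StructField::dataType)
        .collect(Collectors.toList());

    System.out.println(types);
  }
}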
@Override - public ValueWriter map(Schema map, ValueWriter valueReader) { - Type keyType = schema.findType(AvroSchemaUtil.getKeyId(map)); - Type valueType = schema.findType(AvroSchemaUtil.getValueId(map)); - ValueWriter writer = SparkValueWriters.strings(); - return SparkValueWriters.map( - writer, SparkSchemaUtil.convert(keyType), valueReader, SparkSchemaUtil.convert(valueType)); + public ValueWriter map(MapType sMap, Schema map, ValueWriter keyWriter, ValueWriter valueWriter) { + return SparkValueWriters.arrayMap(keyWriter, sMap.keyType(), valueWriter, sMap.valueType()); } @Override - public ValueWriter primitive(Schema primitive) { + public ValueWriter primitive(DataType type, Schema primitive) { LogicalType logicalType = primitive.getLogicalType(); if (logicalType != null) { switch (logicalType.getName()) { @@ -145,6 +126,11 @@ public ValueWriter primitive(Schema primitive) { case BOOLEAN: return ValueWriters.booleans(); case INT: + if (type instanceof ByteType) { + return ValueWriters.tinyints(); + } else if (type instanceof ShortType) { + return ValueWriters.shorts(); + } return ValueWriters.ints(); case LONG: return ValueWriters.longs(); diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java index 52ebd823335a..99c957c5277a 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetWriters.java @@ -26,15 +26,12 @@ import java.util.List; import java.util.Map; import java.util.NoSuchElementException; -import org.apache.iceberg.Schema; -import org.apache.iceberg.parquet.ParquetTypeVisitor; import org.apache.iceberg.parquet.ParquetValueReaders.ReusableEntry; import org.apache.iceberg.parquet.ParquetValueWriter; import org.apache.iceberg.parquet.ParquetValueWriters; import org.apache.iceberg.parquet.ParquetValueWriters.PrimitiveWriter; import org.apache.iceberg.parquet.ParquetValueWriters.RepeatedKeyValueWriter; import org.apache.iceberg.parquet.ParquetValueWriters.RepeatedWriter; -import org.apache.iceberg.spark.SparkSchemaUtil; import org.apache.iceberg.types.TypeUtil; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.io.api.Binary; @@ -46,8 +43,14 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.UTF8String; public class SparkParquetWriters { @@ -55,84 +58,73 @@ private SparkParquetWriters() { } @SuppressWarnings("unchecked") - public static ParquetValueWriter buildWriter(Schema schema, MessageType type) { - return (ParquetValueWriter) ParquetTypeVisitor.visit(type, new WriteBuilder(schema, type)); + public static ParquetValueWriter buildWriter(StructType dfSchema, MessageType type) { + return (ParquetValueWriter) ParquetWithSparkSchemaVisitor.visit(dfSchema, type, new WriteBuilder(type)); } - private static class WriteBuilder extends ParquetTypeVisitor> { - private final Schema schema; + private static class WriteBuilder extends 
ParquetWithSparkSchemaVisitor> { private final MessageType type; - WriteBuilder(Schema schema, MessageType type) { - this.schema = schema; + WriteBuilder(MessageType type) { this.type = type; } @Override - public ParquetValueWriter message(MessageType message, + public ParquetValueWriter message(StructType sStruct, MessageType message, List> fieldWriters) { - return struct(message.asGroupType(), fieldWriters); + return struct(sStruct, message.asGroupType(), fieldWriters); } @Override - public ParquetValueWriter struct(GroupType struct, + public ParquetValueWriter struct(StructType sStruct, GroupType struct, List> fieldWriters) { List fields = struct.getFields(); + StructField[] sparkFields = sStruct.fields(); List> writers = Lists.newArrayListWithExpectedSize(fieldWriters.size()); List sparkTypes = Lists.newArrayList(); for (int i = 0; i < fields.size(); i += 1) { - Type fieldType = struct.getType(i); - int fieldD = type.getMaxDefinitionLevel(path(fieldType.getName())); - writers.add(ParquetValueWriters.option(fieldType, fieldD, fieldWriters.get(i))); - sparkTypes.add(SparkSchemaUtil.convert(schema.findType(fieldType.getId().intValue()))); + writers.add(newOption(struct.getType(i), fieldWriters.get(i))); + sparkTypes.add(sparkFields[i].dataType()); } return new InternalRowWriter(writers, sparkTypes); } @Override - public ParquetValueWriter list(GroupType array, ParquetValueWriter elementWriter) { + public ParquetValueWriter list(ArrayType sArray, GroupType array, ParquetValueWriter elementWriter) { GroupType repeated = array.getFields().get(0).asGroupType(); String[] repeatedPath = currentPath(); int repeatedD = type.getMaxDefinitionLevel(repeatedPath); int repeatedR = type.getMaxRepetitionLevel(repeatedPath); - org.apache.parquet.schema.Type elementType = repeated.getType(0); - int elementD = type.getMaxDefinitionLevel(path(elementType.getName())); - - DataType elementSparkType = SparkSchemaUtil.convert(schema.findType(elementType.getId().intValue())); - return new ArrayDataWriter<>(repeatedD, repeatedR, - ParquetValueWriters.option(elementType, elementD, elementWriter), - elementSparkType); + newOption(repeated.getType(0), elementWriter), + sArray.elementType()); } @Override - public ParquetValueWriter map(GroupType map, - ParquetValueWriter keyWriter, - ParquetValueWriter valueWriter) { + public ParquetValueWriter map(MapType sMap, GroupType map, + ParquetValueWriter keyWriter, ParquetValueWriter valueWriter) { GroupType repeatedKeyValue = map.getFields().get(0).asGroupType(); String[] repeatedPath = currentPath(); int repeatedD = type.getMaxDefinitionLevel(repeatedPath); int repeatedR = type.getMaxRepetitionLevel(repeatedPath); - org.apache.parquet.schema.Type keyType = repeatedKeyValue.getType(0); - int keyD = type.getMaxDefinitionLevel(path(keyType.getName())); - DataType keySparkType = SparkSchemaUtil.convert(schema.findType(keyType.getId().intValue())); - org.apache.parquet.schema.Type valueType = repeatedKeyValue.getType(1); - int valueD = type.getMaxDefinitionLevel(path(valueType.getName())); - DataType valueSparkType = SparkSchemaUtil.convert(schema.findType(valueType.getId().intValue())); - return new MapDataWriter<>(repeatedD, repeatedR, - ParquetValueWriters.option(keyType, keyD, keyWriter), - ParquetValueWriters.option(valueType, valueD, valueWriter), - keySparkType, valueSparkType); + newOption(repeatedKeyValue.getType(0), keyWriter), + newOption(repeatedKeyValue.getType(1), valueWriter), + sMap.keyType(), sMap.valueType()); + } + + private ParquetValueWriter 
newOption(org.apache.parquet.schema.Type fieldType, ParquetValueWriter writer) { + int maxD = type.getMaxDefinitionLevel(path(fieldType.getName())); + return ParquetValueWriters.option(fieldType, maxD, writer); } @Override - public ParquetValueWriter primitive(PrimitiveType primitive) { + public ParquetValueWriter primitive(DataType sType, PrimitiveType primitive) { ColumnDescriptor desc = type.getColumnDescription(currentPath()); if (primitive.getOriginalType() != null) { @@ -145,10 +137,11 @@ public ParquetValueWriter primitive(PrimitiveType primitive) { case INT_8: case INT_16: case INT_32: + return ints(sType, desc); case INT_64: case TIME_MICROS: case TIMESTAMP_MICROS: - return ParquetValueWriters.unboxed(desc); + return ParquetValueWriters.longs(desc); case DECIMAL: DecimalMetadata decimal = primitive.getDecimalMetadata(); switch (primitive.getPrimitiveTypeName()) { @@ -176,17 +169,30 @@ public ParquetValueWriter primitive(PrimitiveType primitive) { case BINARY: return byteArrays(desc); case BOOLEAN: + return ParquetValueWriters.booleans(desc); case INT32: + return ints(sType, desc); case INT64: + return ParquetValueWriters.longs(desc); case FLOAT: + return ParquetValueWriters.floats(desc); case DOUBLE: - return ParquetValueWriters.unboxed(desc); + return ParquetValueWriters.doubles(desc); default: throw new UnsupportedOperationException("Unsupported type: " + primitive); } } } + private static PrimitiveWriter ints(DataType type, ColumnDescriptor desc) { + if (type instanceof ByteType) { + return ParquetValueWriters.tinyints(desc); + } else if (type instanceof ShortType) { + return ParquetValueWriters.shorts(desc); + } + return ParquetValueWriters.ints(desc); + } + private static PrimitiveWriter utf8Strings(ColumnDescriptor desc) { return new UTF8StringWriter(desc); } diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java b/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java index 18c5aeaf6e16..aa7ce09f4fe1 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/IcebergSource.java @@ -100,8 +100,8 @@ public Optional createWriter(String jobId, StructType dsStruct "Save mode %s is not supported", mode); Configuration conf = new Configuration(lazyBaseConf()); Table table = getTableAndResolveHadoopConfiguration(options, conf); - Schema dsSchema = SparkSchemaUtil.convert(table.schema(), dsStruct); - validateWriteSchema(table.schema(), dsSchema, checkNullability(options), checkOrdering(options)); + Schema writeSchema = SparkSchemaUtil.convert(table.schema(), dsStruct); + validateWriteSchema(table.schema(), writeSchema, checkNullability(options), checkOrdering(options)); validatePartitionTransforms(table.spec()); String appId = lazySparkSession().sparkContext().applicationId(); String wapId = lazySparkSession().conf().get("spark.wap.id", null); @@ -110,7 +110,8 @@ public Optional createWriter(String jobId, StructType dsStruct Broadcast io = lazySparkContext().broadcast(fileIO(table)); Broadcast encryptionManager = lazySparkContext().broadcast(table.encryption()); - return Optional.of(new Writer(table, io, encryptionManager, options, replacePartitions, appId, wapId, dsSchema)); + return Optional.of(new Writer( + table, io, encryptionManager, options, replacePartitions, appId, wapId, writeSchema, dsStruct)); } @Override @@ -121,8 +122,8 @@ public StreamWriter createStreamWriter(String runId, StructType dsStruct, "Output mode %s is not supported", mode); 
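The IcebergSource and Writer changes in these hunks keep two schemas side by side: the converted Iceberg writeSchema for validation and file metadata, and the original Spark dsSchema (a StructType) for building row writers. The reason is that Iceberg's type system has no byte or short, so the conversion is lossy in exactly the dimension this patch cares about. A hedged sketch of that loss (field names are illustrative):

import org.apache.iceberg.Schema;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class LossyConversionExample {
  public static void main(String[] args) {
    StructType dsSchema = new StructType()
        .add("id", DataTypes.LongType)
        .add("flag", DataTypes.ByteType);

    // The converted Iceberg schema reports "flag" as an int; only the original
    // StructType still knows it is a byte, which is why Writer now carries both.
    Schema writeSchema = SparkSchemaUtil.convert(dsSchema);
    System.out.println(writeSchema.asStruct());
  }
}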
Configuration conf = new Configuration(lazyBaseConf()); Table table = getTableAndResolveHadoopConfiguration(options, conf); - Schema dsSchema = SparkSchemaUtil.convert(table.schema(), dsStruct); - validateWriteSchema(table.schema(), dsSchema, checkNullability(options), checkOrdering(options)); + Schema writeSchema = SparkSchemaUtil.convert(table.schema(), dsStruct); + validateWriteSchema(table.schema(), writeSchema, checkNullability(options), checkOrdering(options)); validatePartitionTransforms(table.spec()); // Spark 2.4.x passes runId to createStreamWriter instead of real queryId, // so we fetch it directly from sparkContext to make writes idempotent @@ -132,7 +133,7 @@ public StreamWriter createStreamWriter(String runId, StructType dsStruct, Broadcast io = lazySparkContext().broadcast(fileIO(table)); Broadcast encryptionManager = lazySparkContext().broadcast(table.encryption()); - return new StreamingWriter(table, io, encryptionManager, options, queryId, mode, appId, dsSchema); + return new StreamingWriter(table, io, encryptionManager, options, queryId, mode, appId, writeSchema, dsStruct); } protected Table findTable(DataSourceOptions options, Configuration conf) { diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/StreamingWriter.java b/spark/src/main/java/org/apache/iceberg/spark/source/StreamingWriter.java index 9a3fd633c328..f3a1a40b32c3 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/StreamingWriter.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/StreamingWriter.java @@ -35,6 +35,7 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; import org.apache.spark.sql.streaming.OutputMode; +import org.apache.spark.sql.types.StructType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,8 +50,8 @@ public class StreamingWriter extends Writer implements StreamWriter { StreamingWriter(Table table, Broadcast io, Broadcast encryptionManager, DataSourceOptions options, String queryId, OutputMode mode, String applicationId, - Schema dsSchema) { - super(table, io, encryptionManager, options, false, applicationId, dsSchema); + Schema writeSchema, StructType dsSchema) { + super(table, io, encryptionManager, options, false, applicationId, writeSchema, dsSchema); this.queryId = queryId; this.mode = mode; } diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/Writer.java b/spark/src/main/java/org/apache/iceberg/spark/source/Writer.java index 88f8a240ce1a..8049dc27df28 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/Writer.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/Writer.java @@ -66,6 +66,7 @@ import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.DataWriterFactory; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; +import org.apache.spark.sql.types.StructType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -94,16 +95,18 @@ class Writer implements DataSourceWriter { private final String applicationId; private final String wapId; private final long targetFileSize; - private final Schema dsSchema; + private final Schema writeSchema; + private final StructType dsSchema; Writer(Table table, Broadcast io, Broadcast encryptionManager, - DataSourceOptions options, boolean replacePartitions, String applicationId, Schema dsSchema) { - this(table, io, encryptionManager, options, replacePartitions, applicationId, null, dsSchema); + 
DataSourceOptions options, boolean replacePartitions, String applicationId, Schema writeSchema, + StructType dsSchema) { + this(table, io, encryptionManager, options, replacePartitions, applicationId, null, writeSchema, dsSchema); } Writer(Table table, Broadcast io, Broadcast encryptionManager, DataSourceOptions options, boolean replacePartitions, String applicationId, String wapId, - Schema dsSchema) { + Schema writeSchema, StructType dsSchema) { this.table = table; this.format = getFileFormat(table.properties(), options); this.io = io; @@ -111,6 +114,7 @@ class Writer implements DataSourceWriter { this.replacePartitions = replacePartitions; this.applicationId = applicationId; this.wapId = wapId; + this.writeSchema = writeSchema; this.dsSchema = dsSchema; long tableTargetFileSize = PropertyUtil.propertyAsLong( @@ -134,7 +138,7 @@ private boolean isWapTable() { public DataWriterFactory createWriterFactory() { return new WriterFactory( table.spec(), format, table.locationProvider(), table.properties(), io, encryptionManager, targetFileSize, - dsSchema); + writeSchema, dsSchema); } @Override @@ -260,12 +264,13 @@ private static class WriterFactory implements DataWriterFactory { private final Broadcast io; private final Broadcast encryptionManager; private final long targetFileSize; - private final Schema dsSchema; + private final Schema writeSchema; + private final StructType dsSchema; WriterFactory(PartitionSpec spec, FileFormat format, LocationProvider locations, Map properties, Broadcast io, Broadcast encryptionManager, long targetFileSize, - Schema dsSchema) { + Schema writeSchema, StructType dsSchema) { this.spec = spec; this.format = format; this.locations = locations; @@ -273,6 +278,7 @@ private static class WriterFactory implements DataWriterFactory { this.io = io; this.encryptionManager = encryptionManager; this.targetFileSize = targetFileSize; + this.writeSchema = writeSchema; this.dsSchema = dsSchema; } @@ -284,7 +290,8 @@ public DataWriter createDataWriter(int partitionId, long taskId, lo if (spec.fields().isEmpty()) { return new UnpartitionedWriter(spec, format, appenderFactory, fileFactory, io.value(), targetFileSize); } else { - return new PartitionedWriter(spec, format, appenderFactory, fileFactory, io.value(), targetFileSize, dsSchema); + return new PartitionedWriter( + spec, format, appenderFactory, fileFactory, io.value(), targetFileSize, writeSchema); } } @@ -299,7 +306,7 @@ public FileAppender newAppender(OutputFile file, FileFormat fileFor .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(dsSchema, msgType)) .setAll(properties) .metricsConfig(metricsConfig) - .schema(dsSchema) + .schema(writeSchema) .overwrite() .build(); @@ -307,7 +314,7 @@ public FileAppender newAppender(OutputFile file, FileFormat fileFor return Avro.write(file) .createWriterFunc(ignored -> new SparkAvroWriter(dsSchema)) .setAll(properties) - .schema(dsSchema) + .schema(writeSchema) .overwrite() .build(); @@ -315,7 +322,7 @@ public FileAppender newAppender(OutputFile file, FileFormat fileFor return ORC.write(file) .createWriterFunc(SparkOrcWriter::new) .setAll(properties) - .schema(dsSchema) + .schema(writeSchema) .overwrite() .build(); diff --git a/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java b/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java index 64c9d519f16b..c67f7d91f49f 100644 --- a/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java +++ b/spark/src/test/java/org/apache/iceberg/TestDataFileSerialization.java @@ -38,6 
+38,7 @@ import java.util.UUID; import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; import org.apache.iceberg.spark.data.RandomData; import org.apache.iceberg.spark.data.SparkParquetWriters; import org.apache.iceberg.types.Types; @@ -164,7 +165,7 @@ public void testParquetWriterSplitOffsets() throws IOException { FileAppender writer = Parquet.write(Files.localOutput(parquetFile)) .schema(DATE_SCHEMA) - .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(DATE_SCHEMA, msgType)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(DATE_SCHEMA), msgType)) .build(); try { writer.addAll(records); diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java index 00f95f382a15..4ff784448e80 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java @@ -27,6 +27,7 @@ import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.spark.SparkSchemaUtil; import org.apache.iceberg.types.Types; import org.apache.spark.sql.catalyst.InternalRow; import org.junit.Assert; @@ -78,7 +79,7 @@ public void testCorrectness() throws IOException { try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) .schema(COMPLEX_SCHEMA) - .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(COMPLEX_SCHEMA, msgType)) + .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(COMPLEX_SCHEMA), msgType)) .build()) { writer.addAll(records); }
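The updated tests and benchmarks all follow the same wiring: the Iceberg schema still goes to .schema(), while the writer function now receives the Spark StructType produced by SparkSchemaUtil.convert. A standalone sketch of that wiring under assumed names (the output path and schema below are illustrative, not taken from the patch):

import java.io.File;
import java.io.IOException;
import org.apache.iceberg.Files;
import org.apache.iceberg.Schema;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.parquet.Parquet;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.data.SparkParquetWriters;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;

public class WriterWiringExample {
  public static void main(String[] args) throws IOException {
    Schema schema = new Schema(
        Types.NestedField.required(1, "id", Types.LongType.get()),
        Types.NestedField.optional(2, "data", Types.StringType.get()));

    File outputFile = new File("/tmp/example.parquet");

    // The Iceberg schema describes the file; the converted StructType tells the
    // value writers how the incoming InternalRow fields are typed.
    try (FileAppender<InternalRow> writer = Parquet.write(Files.localOutput(outputFile))
        .schema(schema)
        .createWriterFunc(msgType -> SparkParquetWriters.buildWriter(SparkSchemaUtil.convert(schema), msgType))
        .build()) {
      // writer.addAll(rows) would append Spark InternalRow values here
    }
  }
}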