diff --git a/data/src/test/java/org/apache/iceberg/data/RandomGenericData.java b/data/src/test/java/org/apache/iceberg/data/RandomGenericData.java index 19179c0b1de4..7dd44b30f093 100644 --- a/data/src/test/java/org/apache/iceberg/data/RandomGenericData.java +++ b/data/src/test/java/org/apache/iceberg/data/RandomGenericData.java @@ -25,8 +25,10 @@ import java.time.LocalTime; import java.time.OffsetDateTime; import java.time.ZoneOffset; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Random; import java.util.Set; import java.util.UUID; @@ -46,13 +48,37 @@ public class RandomGenericData { private RandomGenericData() {} public static List generate(Schema schema, int numRecords, long seed) { - RandomRecordGenerator generator = new RandomRecordGenerator(seed); - List records = Lists.newArrayListWithExpectedSize(numRecords); - for (int i = 0; i < numRecords; i += 1) { - records.add((Record) TypeUtil.visit(schema, generator)); - } + return Lists.newArrayList(generateIcebergGenerics(schema, numRecords, () -> new RandomRecordGenerator(seed))); + } - return records; + public static Iterable generateFallbackRecords(Schema schema, int numRecords, long seed, long numDictRows) { + return generateIcebergGenerics(schema, numRecords, () -> new FallbackGenerator(seed, numDictRows)); + } + + public static Iterable generateDictionaryEncodableRecords(Schema schema, int numRecords, long seed) { + return generateIcebergGenerics(schema, numRecords, () -> new DictionaryEncodedGenerator(seed)); + } + + private static Iterable generateIcebergGenerics(Schema schema, int numRecords, + Supplier> supplier) { + return () -> new Iterator() { + private final RandomDataGenerator generator = supplier.get(); + private int count = 0; + + @Override + public boolean hasNext() { + return count < numRecords; + } + + @Override + public Record next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + ++count; + return (Record) TypeUtil.visit(schema, generator); + } + }; } private static class RandomRecordGenerator extends RandomDataGenerator { @@ -78,6 +104,46 @@ public Record struct(Types.StructType struct, Iterable fieldResults) { } } + private static class DictionaryEncodedGenerator extends RandomRecordGenerator { + DictionaryEncodedGenerator(long seed) { + super(seed); + } + + @Override + protected int getMaxEntries() { + // Here we limit the max entries in a LIST or MAP to 3 because the dictionary-encodable generator only + // produces strings with three distinct values ("0", "1", "2"), while RandomDataGenerator#map de-duplicates + // keys by retrying. Asking the generator for more than 3 unique keys would therefore send + // RandomDataGenerator#map into an infinite loop.
+ return 3; + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random random) { + return RandomUtil.generateDictionaryEncodablePrimitive(primitive, random); + } + } + + private static class FallbackGenerator extends RandomRecordGenerator { + private final long dictionaryEncodedRows; + private long rowCount = 0; + + FallbackGenerator(long seed, long numDictionaryEncoded) { + super(seed); + this.dictionaryEncodedRows = numDictionaryEncoded; + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random rand) { + this.rowCount += 1; + if (rowCount > dictionaryEncodedRows) { + return RandomUtil.generatePrimitive(primitive, rand); + } else { + return RandomUtil.generateDictionaryEncodablePrimitive(primitive, rand); + } + } + } + public abstract static class RandomDataGenerator extends TypeUtil.CustomOrderSchemaVisitor { private final Random random; private static final int MAX_ENTRIES = 20; diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetReaders.java b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetReaders.java index ccea5d6529c5..3012544cba83 100644 --- a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetReaders.java +++ b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetReaders.java @@ -19,64 +19,760 @@ package org.apache.iceberg.flink.data; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.time.Instant; +import java.time.ZoneOffset; import java.util.List; -import org.apache.flink.types.Row; +import java.util.Map; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.data.RawValueData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.TimestampData; import org.apache.iceberg.Schema; -import org.apache.iceberg.data.parquet.BaseParquetReaders; import org.apache.iceberg.parquet.ParquetValueReader; import org.apache.iceberg.parquet.ParquetValueReaders; +import org.apache.iceberg.parquet.TypeWithSchemaVisitor; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; import org.apache.parquet.schema.Type; -public class FlinkParquetReaders extends BaseParquetReaders { +class FlinkParquetReaders { + private FlinkParquetReaders() { + } - private static final FlinkParquetReaders INSTANCE = new FlinkParquetReaders(); + public static ParquetValueReader buildReader(Schema expectedSchema, MessageType fileSchema) { + return buildReader(expectedSchema, fileSchema, ImmutableMap.of()); + } - private FlinkParquetReaders() { + @SuppressWarnings("unchecked") + public static ParquetValueReader buildReader(Schema expectedSchema, + MessageType fileSchema, + Map idToConstant) { 
+ return (ParquetValueReader) TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, + new ReadBuilder(fileSchema, idToConstant) + ); + } + + private static class ReadBuilder extends TypeWithSchemaVisitor> { + private final MessageType type; + private final Map idToConstant; + + ReadBuilder(MessageType type, Map idToConstant) { + this.type = type; + this.idToConstant = idToConstant; + } + + @Override + public ParquetValueReader message(Types.StructType expected, MessageType message, + List> fieldReaders) { + return struct(expected, message.asGroupType(), fieldReaders); + } + + @Override + public ParquetValueReader struct(Types.StructType expected, GroupType struct, + List> fieldReaders) { + // match the expected struct's order + Map> readersById = Maps.newHashMap(); + Map typesById = Maps.newHashMap(); + List fields = struct.getFields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i); + int fieldD = type.getMaxDefinitionLevel(path(fieldType.getName())) - 1; + if (fieldType.getId() != null) { + int id = fieldType.getId().intValue(); + readersById.put(id, ParquetValueReaders.option(fieldType, fieldD, fieldReaders.get(i))); + typesById.put(id, fieldType); + } + } + + List expectedFields = expected != null ? + expected.fields() : ImmutableList.of(); + List> reorderedFields = Lists.newArrayListWithExpectedSize( + expectedFields.size()); + List types = Lists.newArrayListWithExpectedSize(expectedFields.size()); + for (Types.NestedField field : expectedFields) { + int id = field.fieldId(); + if (idToConstant.containsKey(id)) { + // containsKey is used because the constant may be null + reorderedFields.add(ParquetValueReaders.constant(idToConstant.get(id))); + types.add(null); + } else { + ParquetValueReader reader = readersById.get(id); + if (reader != null) { + reorderedFields.add(reader); + types.add(typesById.get(id)); + } else { + reorderedFields.add(ParquetValueReaders.nulls()); + types.add(null); + } + } + } + + return new RowDataReader(types, reorderedFields); + } + + @Override + public ParquetValueReader list(Types.ListType expectedList, GroupType array, + ParquetValueReader elementReader) { + GroupType repeated = array.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type elementType = repeated.getType(0); + int elementD = type.getMaxDefinitionLevel(path(elementType.getName())) - 1; + + return new ArrayReader<>(repeatedD, repeatedR, ParquetValueReaders.option(elementType, elementD, elementReader)); + } + + @Override + public ParquetValueReader map(Types.MapType expectedMap, GroupType map, + ParquetValueReader keyReader, + ParquetValueReader valueReader) { + GroupType repeatedKeyValue = map.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type keyType = repeatedKeyValue.getType(0); + int keyD = type.getMaxDefinitionLevel(path(keyType.getName())) - 1; + Type valueType = repeatedKeyValue.getType(1); + int valueD = type.getMaxDefinitionLevel(path(valueType.getName())) - 1; + + return new MapReader<>(repeatedD, repeatedR, + ParquetValueReaders.option(keyType, keyD, keyReader), + ParquetValueReaders.option(valueType, valueD, valueReader)); + } + + @Override + @SuppressWarnings("CyclomaticComplexity") + public ParquetValueReader 
primitive(org.apache.iceberg.types.Type.PrimitiveType expected, + PrimitiveType primitive) { + ColumnDescriptor desc = type.getColumnDescription(currentPath()); + + if (primitive.getOriginalType() != null) { + switch (primitive.getOriginalType()) { + case ENUM: + case JSON: + case UTF8: + return new StringReader(desc); + case INT_8: + case INT_16: + case INT_32: + if (expected != null && expected.typeId() == Types.LongType.get().typeId()) { + return new ParquetValueReaders.IntAsLongReader(desc); + } else { + return new ParquetValueReaders.UnboxedReader<>(desc); + } + case TIME_MICROS: + return new LossyMicrosToMillisTimeReader(desc); + case TIME_MILLIS: + return new MillisTimeReader(desc); + case DATE: + case INT_64: + return new ParquetValueReaders.UnboxedReader<>(desc); + case TIMESTAMP_MICROS: + if (((Types.TimestampType) expected).shouldAdjustToUTC()) { + return new MicrosToTimestampTzReader(desc); + } else { + return new MicrosToTimestampReader(desc); + } + case TIMESTAMP_MILLIS: + if (((Types.TimestampType) expected).shouldAdjustToUTC()) { + return new MillisToTimestampTzReader(desc); + } else { + return new MillisToTimestampReader(desc); + } + case DECIMAL: + DecimalLogicalTypeAnnotation decimal = (DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation(); + switch (primitive.getPrimitiveTypeName()) { + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return new BinaryDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + case INT64: + return new LongDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + case INT32: + return new IntegerDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + default: + throw new UnsupportedOperationException( + "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); + } + case BSON: + return new ParquetValueReaders.ByteArrayReader(desc); + default: + throw new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getOriginalType()); + } + } + + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return new ParquetValueReaders.ByteArrayReader(desc); + case INT32: + if (expected != null && expected.typeId() == org.apache.iceberg.types.Type.TypeID.LONG) { + return new ParquetValueReaders.IntAsLongReader(desc); + } else { + return new ParquetValueReaders.UnboxedReader<>(desc); + } + case FLOAT: + if (expected != null && expected.typeId() == org.apache.iceberg.types.Type.TypeID.DOUBLE) { + return new ParquetValueReaders.FloatAsDoubleReader(desc); + } else { + return new ParquetValueReaders.UnboxedReader<>(desc); + } + case BOOLEAN: + case INT64: + case DOUBLE: + return new ParquetValueReaders.UnboxedReader<>(desc); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + } + + private static class BinaryDecimalReader extends ParquetValueReaders.PrimitiveReader { + private final int precision; + private final int scale; + + BinaryDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public DecimalData read(DecimalData ignored) { + Binary binary = column.nextBinary(); + BigDecimal bigDecimal = new BigDecimal(new BigInteger(binary.getBytes()), scale); + // TODO: need a unit test to write-read-validate decimal via FlinkParquetWrite/Reader + return DecimalData.fromBigDecimal(bigDecimal, precision, scale); + } + } + + private static class IntegerDecimalReader extends ParquetValueReaders.PrimitiveReader { + 
private final int precision; + private final int scale; + + IntegerDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public DecimalData read(DecimalData ignored) { + return DecimalData.fromUnscaledLong(column.nextInteger(), precision, scale); + } + } + + private static class LongDecimalReader extends ParquetValueReaders.PrimitiveReader { + private final int precision; + private final int scale; + + LongDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public DecimalData read(DecimalData ignored) { + return DecimalData.fromUnscaledLong(column.nextLong(), precision, scale); + } + } + + private static class MicrosToTimestampTzReader extends ParquetValueReaders.UnboxedReader { + MicrosToTimestampTzReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public TimestampData read(TimestampData ignored) { + long value = readLong(); + return TimestampData.fromLocalDateTime(Instant.ofEpochSecond(Math.floorDiv(value, 1000_000), + Math.floorMod(value, 1000_000) * 1000) + .atOffset(ZoneOffset.UTC) + .toLocalDateTime()); + } + + @Override + public long readLong() { + return column.nextLong(); + } + } + + private static class MicrosToTimestampReader extends ParquetValueReaders.UnboxedReader { + MicrosToTimestampReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public TimestampData read(TimestampData ignored) { + long value = readLong(); + return TimestampData.fromInstant(Instant.ofEpochSecond(Math.floorDiv(value, 1000_000), + Math.floorMod(value, 1000_000) * 1000)); + } + + @Override + public long readLong() { + return column.nextLong(); + } + } + + private static class MillisToTimestampReader extends ParquetValueReaders.UnboxedReader { + MillisToTimestampReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public TimestampData read(TimestampData ignored) { + long millis = readLong(); + return TimestampData.fromEpochMillis(millis); + } + + @Override + public long readLong() { + return column.nextLong(); + } + } + + private static class MillisToTimestampTzReader extends ParquetValueReaders.UnboxedReader { + MillisToTimestampTzReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public TimestampData read(TimestampData ignored) { + long millis = readLong(); + return TimestampData.fromLocalDateTime(Instant.ofEpochMilli(millis) + .atOffset(ZoneOffset.UTC) + .toLocalDateTime()); + } + + @Override + public long readLong() { + return column.nextLong(); + } + } + + private static class StringReader extends ParquetValueReaders.PrimitiveReader { + StringReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public StringData read(StringData ignored) { + Binary binary = column.nextBinary(); + ByteBuffer buffer = binary.toByteBuffer(); + if (buffer.hasArray()) { + return StringData.fromBytes( + buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); + } else { + return StringData.fromBytes(binary.getBytes()); + } + } + } + + private static class LossyMicrosToMillisTimeReader extends ParquetValueReaders.PrimitiveReader { + LossyMicrosToMillisTimeReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public Integer read(Integer reuse) { + // Discard microseconds since Flink uses millisecond unit for TIME type. 
+ return (int) Math.floorDiv(column.nextLong(), 1000); + } } - public static ParquetValueReader buildReader(Schema expectedSchema, MessageType fileSchema) { - return INSTANCE.createReader(expectedSchema, fileSchema); + private static class MillisTimeReader extends ParquetValueReaders.PrimitiveReader { + MillisTimeReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public Integer read(Integer reuse) { + return (int) column.nextLong(); + } } - @Override - protected ParquetValueReader createStructReader(List types, - List> fieldReaders, - Types.StructType structType) { - return new RowReader(types, fieldReaders, structType); + private static class ArrayReader extends ParquetValueReaders.RepeatedReader { + private int readPos = 0; + private int writePos = 0; + + ArrayReader(int definitionLevel, int repetitionLevel, ParquetValueReader reader) { + super(definitionLevel, repetitionLevel, reader); + } + + @Override + protected ReusableArrayData newListData(ArrayData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableArrayData) { + return (ReusableArrayData) reuse; + } else { + return new ReusableArrayData(); + } + } + + @Override + @SuppressWarnings("unchecked") + protected E getElement(ReusableArrayData list) { + E value = null; + if (readPos < list.capacity()) { + value = (E) list.values[readPos]; + } + + readPos += 1; + + return value; + } + + @Override + protected void addElement(ReusableArrayData reused, E element) { + if (writePos >= reused.capacity()) { + reused.grow(); + } + + reused.values[writePos] = element; + + writePos += 1; + } + + @Override + protected ArrayData buildList(ReusableArrayData list) { + list.setNumElements(writePos); + return list; + } } - private static class RowReader extends ParquetValueReaders.StructReader { - private final Types.StructType structType; + private static class MapReader extends + ParquetValueReaders.RepeatedKeyValueReader { + private int readPos = 0; + private int writePos = 0; + + private final ParquetValueReaders.ReusableEntry entry = new ParquetValueReaders.ReusableEntry<>(); + private final ParquetValueReaders.ReusableEntry nullEntry = new ParquetValueReaders.ReusableEntry<>(); + + MapReader(int definitionLevel, int repetitionLevel, + ParquetValueReader keyReader, ParquetValueReader valueReader) { + super(definitionLevel, repetitionLevel, keyReader, valueReader); + } + + @Override + protected ReusableMapData newMapData(MapData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableMapData) { + return (ReusableMapData) reuse; + } else { + return new ReusableMapData(); + } + } - RowReader(List types, List> readers, Types.StructType struct) { + @Override + @SuppressWarnings("unchecked") + protected Map.Entry getPair(ReusableMapData map) { + Map.Entry kv = nullEntry; + if (readPos < map.capacity()) { + entry.set((K) map.keys.values[readPos], (V) map.values.values[readPos]); + kv = entry; + } + + readPos += 1; + + return kv; + } + + @Override + protected void addPair(ReusableMapData map, K key, V value) { + if (writePos >= map.capacity()) { + map.grow(); + } + + map.keys.values[writePos] = key; + map.values.values[writePos] = value; + + writePos += 1; + } + + @Override + protected MapData buildMap(ReusableMapData map) { + map.setNumElements(writePos); + return map; + } + } + + private static class RowDataReader extends ParquetValueReaders.StructReader { + private final int numFields; + + RowDataReader(List types, List> readers) { super(types, readers); - this.structType = 
struct; + this.numFields = readers.size(); } @Override - protected Row newStructData(Row reuse) { - if (reuse != null) { - return reuse; + protected GenericRowData newStructData(RowData reuse) { + if (reuse instanceof GenericRowData) { + return (GenericRowData) reuse; } else { - return new Row(structType.fields().size()); + return new GenericRowData(numFields); } } @Override - protected Object getField(Row row, int pos) { - return row.getField(pos); + protected Object getField(GenericRowData intermediate, int pos) { + return intermediate.getField(pos); + } + + @Override + protected RowData buildStruct(GenericRowData struct) { + return struct; + } + + @Override + protected void set(GenericRowData row, int pos, Object value) { + row.setField(pos, value); + } + + @Override + protected void setNull(GenericRowData row, int pos) { + row.setField(pos, null); + } + + @Override + protected void setBoolean(GenericRowData row, int pos, boolean value) { + row.setField(pos, value); + } + + @Override + protected void setInteger(GenericRowData row, int pos, int value) { + row.setField(pos, value); + } + + @Override + protected void setLong(GenericRowData row, int pos, long value) { + row.setField(pos, value); } @Override - protected Row buildStruct(Row row) { - return row; + protected void setFloat(GenericRowData row, int pos, float value) { + row.setField(pos, value); } @Override - protected void set(Row row, int pos, Object value) { + protected void setDouble(GenericRowData row, int pos, double value) { row.setField(pos, value); } } + + private static class ReusableMapData implements MapData { + private final ReusableArrayData keys; + private final ReusableArrayData values; + + private int numElements; + + private ReusableMapData() { + this.keys = new ReusableArrayData(); + this.values = new ReusableArrayData(); + } + + private void grow() { + keys.grow(); + values.grow(); + } + + private int capacity() { + return keys.capacity(); + } + + public void setNumElements(int numElements) { + this.numElements = numElements; + keys.setNumElements(numElements); + values.setNumElements(numElements); + } + + @Override + public int size() { + return numElements; + } + + @Override + public ReusableArrayData keyArray() { + return keys; + } + + @Override + public ReusableArrayData valueArray() { + return values; + } + } + + private static class ReusableArrayData implements ArrayData { + private static final Object[] EMPTY = new Object[0]; + + private Object[] values = EMPTY; + private int numElements = 0; + + private void grow() { + if (values.length == 0) { + this.values = new Object[20]; + } else { + Object[] old = values; + this.values = new Object[old.length << 1]; + // copy the old array in case it has values that can be reused + System.arraycopy(old, 0, values, 0, old.length); + } + } + + private int capacity() { + return values.length; + } + + public void setNumElements(int numElements) { + this.numElements = numElements; + } + + @Override + public int size() { + return numElements; + } + + @Override + public boolean isNullAt(int ordinal) { + return null == values[ordinal]; + } + + @Override + public boolean getBoolean(int ordinal) { + return (boolean) values[ordinal]; + } + + @Override + public byte getByte(int ordinal) { + return (byte) values[ordinal]; + } + + @Override + public short getShort(int ordinal) { + return (short) values[ordinal]; + } + + @Override + public int getInt(int ordinal) { + return (int) values[ordinal]; + } + + @Override + public long getLong(int ordinal) { + return (long) 
values[ordinal]; + } + + @Override + public float getFloat(int ordinal) { + return (float) values[ordinal]; + } + + @Override + public double getDouble(int ordinal) { + return (double) values[ordinal]; + } + + @Override + public StringData getString(int pos) { + return (StringData) values[pos]; + } + + @Override + public DecimalData getDecimal(int pos, int precision, int scale) { + return (DecimalData) values[pos]; + } + + @Override + public TimestampData getTimestamp(int pos, int precision) { + return (TimestampData) values[pos]; + } + + @SuppressWarnings("unchecked") + @Override + public RawValueData getRawValue(int pos) { + return (RawValueData) values[pos]; + } + + @Override + public byte[] getBinary(int ordinal) { + return (byte[]) values[ordinal]; + } + + @Override + public ArrayData getArray(int ordinal) { + return (ArrayData) values[ordinal]; + } + + @Override + public MapData getMap(int ordinal) { + return (MapData) values[ordinal]; + } + + @Override + public RowData getRow(int pos, int numFields) { + return (RowData) values[pos]; + } + + @Override + public boolean[] toBooleanArray() { + return ArrayUtils.toPrimitive((Boolean[]) values); + } + + @Override + public byte[] toByteArray() { + return ArrayUtils.toPrimitive((Byte[]) values); + } + + @Override + public short[] toShortArray() { + return ArrayUtils.toPrimitive((Short[]) values); + } + + @Override + public int[] toIntArray() { + return ArrayUtils.toPrimitive((Integer[]) values); + } + + @Override + public long[] toLongArray() { + return ArrayUtils.toPrimitive((Long[]) values); + } + + @Override + public float[] toFloatArray() { + return ArrayUtils.toPrimitive((Float[]) values); + } + + @Override + public double[] toDoubleArray() { + return ArrayUtils.toPrimitive((Double[]) values); + } + } } diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java b/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java index 843006197fdf..b1e14c6c0fc5 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java +++ b/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java @@ -22,16 +22,13 @@ import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; -import java.util.Random; import java.util.function.Supplier; import org.apache.flink.types.Row; import org.apache.iceberg.Schema; import org.apache.iceberg.data.RandomGenericData; import org.apache.iceberg.relocated.com.google.common.collect.Lists; -import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; -import org.apache.iceberg.util.RandomUtil; import static org.apache.iceberg.types.Types.NestedField.optional; import static org.apache.iceberg.types.Types.NestedField.required; @@ -92,16 +89,7 @@ public static Iterable generate(Schema schema, int numRecords, long seed) { return generateData(schema, numRecords, () -> new RandomRowGenerator(seed)); } - public static Iterable generateFallbackData(Schema schema, int numRecords, long seed, long numDictRows) { - return generateData(schema, numRecords, () -> new FallbackGenerator(seed, numDictRows)); - } - - public static Iterable generateDictionaryEncodableData(Schema schema, int numRecords, long seed) { - return generateData(schema, numRecords, () -> new DictionaryEncodedGenerator(seed)); - } - private static class RandomRowGenerator extends RandomGenericData.RandomDataGenerator { - RandomRowGenerator(long seed) { super(seed); } @@ -123,44 +111,4 @@ public Row struct(Types.StructType struct, 
Iterable fieldResults) { return row; } } - - private static class DictionaryEncodedGenerator extends RandomRowGenerator { - DictionaryEncodedGenerator(long seed) { - super(seed); - } - - @Override - protected int getMaxEntries() { - // Here we limited the max entries in LIST or MAP to be 3, because we have the mechanism to duplicate - // the keys in RandomDataGenerator#map while the dictionary encoder will generate a string with - // limited values("0","1","2"). It's impossible for us to request the generator to generate more than 3 keys, - // otherwise we will get in a infinite loop in RandomDataGenerator#map. - return 3; - } - - @Override - protected Object randomValue(Type.PrimitiveType primitive, Random random) { - return RandomUtil.generateDictionaryEncodablePrimitive(primitive, random); - } - } - - private static class FallbackGenerator extends RandomRowGenerator { - private final long dictionaryEncodedRows; - private long rowCount = 0; - - FallbackGenerator(long seed, long numDictionaryEncoded) { - super(seed); - this.dictionaryEncodedRows = numDictionaryEncoded; - } - - @Override - protected Object randomValue(Type.PrimitiveType primitive, Random rand) { - this.rowCount += 1; - if (rowCount > dictionaryEncodedRows) { - return RandomUtil.generatePrimitive(primitive, rand); - } else { - return RandomUtil.generateDictionaryEncodablePrimitive(primitive, rand); - } - } - } } diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReaderWriter.java b/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReader.java similarity index 54% rename from flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReaderWriter.java rename to flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReader.java index 41ea960b72c2..8a8a3d7aaf66 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReaderWriter.java +++ b/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReader.java @@ -22,64 +22,58 @@ import java.io.File; import java.io.IOException; import java.util.Iterator; -import org.apache.flink.types.Row; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.types.logical.LogicalType; import org.apache.iceberg.Files; import org.apache.iceberg.Schema; +import org.apache.iceberg.data.DataTest; +import org.apache.iceberg.data.RandomGenericData; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.parquet.GenericParquetWriter; +import org.apache.iceberg.flink.FlinkSchemaUtil; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.parquet.Parquet; import org.junit.Assert; import org.junit.Rule; -import org.junit.Test; import org.junit.rules.TemporaryFolder; -import static org.apache.iceberg.flink.data.RandomData.COMPLEX_SCHEMA; - -public class TestFlinkParquetReaderWriter { - private static final int NUM_RECORDS = 20_000; +public class TestFlinkParquetReader extends DataTest { + private static final int NUM_RECORDS = 100; @Rule public TemporaryFolder temp = new TemporaryFolder(); - private void testCorrectness(Schema schema, int numRecords, Iterable iterable) throws IOException { + private void writeAndValidate(Iterable iterable, Schema schema) throws IOException { File testFile = temp.newFile(); Assert.assertTrue("Delete should succeed", testFile.delete()); - try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) + try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) 
.schema(schema) - .createWriterFunc(FlinkParquetWriters::buildWriter) + .createWriterFunc(GenericParquetWriter::buildWriter) .build()) { writer.addAll(iterable); } - try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) + try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) .project(schema) .createReaderFunc(type -> FlinkParquetReaders.buildReader(schema, type)) .build()) { - Iterator expected = iterable.iterator(); - Iterator rows = reader.iterator(); - for (int i = 0; i < numRecords; i += 1) { + Iterator expected = iterable.iterator(); + Iterator rows = reader.iterator(); + LogicalType rowType = FlinkSchemaUtil.convert(schema); + for (int i = 0; i < NUM_RECORDS; i += 1) { Assert.assertTrue("Should have expected number of rows", rows.hasNext()); - Assert.assertEquals(expected.next(), rows.next()); + TestHelpers.assertRowData(schema.asStruct(), rowType, expected.next(), rows.next()); } Assert.assertFalse("Should not have extra rows", rows.hasNext()); } } - @Test - public void testNormalRowData() throws IOException { - testCorrectness(COMPLEX_SCHEMA, NUM_RECORDS, RandomData.generate(COMPLEX_SCHEMA, NUM_RECORDS, 19981)); - } - - @Test - public void testDictionaryEncodedData() throws IOException { - testCorrectness(COMPLEX_SCHEMA, NUM_RECORDS, - RandomData.generateDictionaryEncodableData(COMPLEX_SCHEMA, NUM_RECORDS, 21124)); - } - - @Test - public void testFallbackData() throws IOException { - testCorrectness(COMPLEX_SCHEMA, NUM_RECORDS, - RandomData.generateFallbackData(COMPLEX_SCHEMA, NUM_RECORDS, 21124, NUM_RECORDS / 20)); + @Override + protected void writeAndValidate(Schema schema) throws IOException { + writeAndValidate(RandomGenericData.generate(schema, NUM_RECORDS, 19981), schema); + writeAndValidate(RandomGenericData.generateDictionaryEncodableRecords(schema, NUM_RECORDS, 21124), schema); + writeAndValidate(RandomGenericData.generateFallbackRecords(schema, NUM_RECORDS, 21124, NUM_RECORDS / 20), schema); } } diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/TestHelpers.java b/flink/src/test/java/org/apache/iceberg/flink/data/TestHelpers.java new file mode 100644 index 000000000000..be427ce868b8 --- /dev/null +++ b/flink/src/test/java/org/apache/iceberg/flink/data/TestHelpers.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.flink.data; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.Supplier; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.TimestampData; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.MapType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.DateTimeUtil; +import org.junit.Assert; + +public class TestHelpers { + private TestHelpers() {} + + public static void assertRowData(Types.StructType structType, LogicalType rowType, Record expectedRecord, + RowData actualRowData) { + if (expectedRecord == null && actualRowData == null) { + return; + } + + Assert.assertTrue("expected Record and actual RowData should be both null or not null", + expectedRecord != null && actualRowData != null); + + List types = Lists.newArrayList(); + for (Types.NestedField field : structType.fields()) { + types.add(field.type()); + } + + for (int i = 0; i < types.size(); i += 1) { + Object expected = expectedRecord.get(i); + LogicalType logicalType = ((RowType) rowType).getTypeAt(i); + + final int fieldPos = i; + assertEquals(types.get(i), logicalType, expected, + () -> RowData.createFieldGetter(logicalType, fieldPos).getFieldOrNull(actualRowData)); + } + } + + private static void assertEquals(Type type, LogicalType logicalType, Object expected, Supplier supplier) { + Object actual = supplier.get(); + + if (expected == null && actual == null) { + return; + } + + Assert.assertTrue("expected and actual should be both null or not null", + expected != null && actual != null); + + switch (type.typeId()) { + case BOOLEAN: + Assert.assertEquals("boolean value should be equal", expected, actual); + break; + case INTEGER: + Assert.assertEquals("int value should be equal", expected, actual); + break; + case LONG: + Assert.assertEquals("long value should be equal", expected, actual); + break; + case FLOAT: + Assert.assertEquals("float value should be equal", expected, actual); + break; + case DOUBLE: + Assert.assertEquals("double value should be equal", expected, actual); + break; + case STRING: + Assert.assertTrue("Should expect a CharSequence", expected instanceof CharSequence); + Assert.assertEquals("string should be equal", String.valueOf(expected), actual.toString()); + break; + case DATE: + Assert.assertTrue("Should expect a Date", expected instanceof LocalDate); + LocalDate date = DateTimeUtil.dateFromDays((int) actual); + Assert.assertEquals("date should be equal", expected, date); + break; + case TIME: + Assert.assertTrue("Should expect a LocalTime", expected instanceof LocalTime); + int milliseconds = (int) (((LocalTime) expected).toNanoOfDay() / 1000_000); + Assert.assertEquals("time millis should be equal", milliseconds, actual); + break; + case TIMESTAMP: + if (((Types.TimestampType) type).shouldAdjustToUTC()) { + 
Assert.assertTrue("Should expect a OffsetDataTime", expected instanceof OffsetDateTime); + OffsetDateTime ts = (OffsetDateTime) expected; + Assert.assertEquals("OffsetDataTime should be equal", ts.toLocalDateTime(), + ((TimestampData) actual).toLocalDateTime()); + } else { + Assert.assertTrue("Should expect a LocalDataTime", expected instanceof LocalDateTime); + LocalDateTime ts = (LocalDateTime) expected; + Assert.assertEquals("LocalDataTime should be equal", ts, + ((TimestampData) actual).toLocalDateTime()); + } + break; + case BINARY: + Assert.assertTrue("Should expect a ByteBuffer", expected instanceof ByteBuffer); + Assert.assertEquals("binary should be equal", expected, ByteBuffer.wrap((byte[]) actual)); + break; + case DECIMAL: + Assert.assertTrue("Should expect a BigDecimal", expected instanceof BigDecimal); + BigDecimal bd = (BigDecimal) expected; + Assert.assertEquals("decimal value should be equal", bd, + ((DecimalData) actual).toBigDecimal()); + break; + case LIST: + Assert.assertTrue("Should expect a Collection", expected instanceof Collection); + Collection expectedArrayData = (Collection) expected; + ArrayData actualArrayData = (ArrayData) actual; + LogicalType elementType = ((ArrayType) logicalType).getElementType(); + Assert.assertEquals("array length should be equal", expectedArrayData.size(), actualArrayData.size()); + assertArrayValues(type.asListType().elementType(), elementType, expectedArrayData, actualArrayData); + break; + case MAP: + Assert.assertTrue("Should expect a Map", expected instanceof Map); + assertMapValues(type.asMapType(), logicalType, (Map) expected, (MapData) actual); + break; + case STRUCT: + Assert.assertTrue("Should expect a Record", expected instanceof Record); + assertRowData(type.asStructType(), logicalType, (Record) expected, (RowData) actual); + break; + case UUID: + Assert.assertTrue("Should expect a UUID", expected instanceof UUID); + Assert.assertEquals("UUID should be equal", expected.toString(), + UUID.nameUUIDFromBytes((byte[]) actual).toString()); + break; + case FIXED: + Assert.assertTrue("Should expect byte[]", expected instanceof byte[]); + Assert.assertArrayEquals("binary should be equal", (byte[]) expected, (byte[]) actual); + break; + default: + throw new IllegalArgumentException("Not a supported type: " + type); + } + } + + private static void assertArrayValues(Type type, LogicalType logicalType, Collection expectedArray, + ArrayData actualArray) { + List expectedElements = Lists.newArrayList(expectedArray); + for (int i = 0; i < expectedArray.size(); i += 1) { + if (expectedElements.get(i) == null) { + Assert.assertTrue(actualArray.isNullAt(i)); + continue; + } + + Object expected = expectedElements.get(i); + + final int pos = i; + assertEquals(type, logicalType, expected, + () -> ArrayData.createElementGetter(logicalType).getElementOrNull(actualArray, pos)); + } + } + + private static void assertMapValues(Types.MapType mapType, LogicalType type, Map expected, MapData actual) { + Assert.assertEquals("map size should be equal", expected.size(), actual.size()); + + ArrayData actualKeyArrayData = actual.keyArray(); + ArrayData actualValueArrayData = actual.valueArray(); + LogicalType actualKeyType = ((MapType) type).getKeyType(); + LogicalType actualValueType = ((MapType) type).getValueType(); + Type keyType = mapType.keyType(); + Type valueType = mapType.valueType(); + + ArrayData.ElementGetter keyGetter = ArrayData.createElementGetter(actualKeyType); + ArrayData.ElementGetter valueGetter = 
ArrayData.createElementGetter(actualValueType); + + for (Map.Entry entry : expected.entrySet()) { + Object matchedActualKey = null; + int matchedKeyIndex = 0; + for (int i = 0; i < actual.size(); i += 1) { + try { + Object key = keyGetter.getElementOrNull(actualKeyArrayData, i); + assertEquals(keyType, actualKeyType, entry.getKey(), () -> key); + matchedActualKey = key; + matchedKeyIndex = i; + break; + } catch (AssertionError e) { + // not found + } + } + Assert.assertNotNull("Should have a matching key", matchedActualKey); + final int valueIndex = matchedKeyIndex; + assertEquals(valueType, actualValueType, entry.getValue(), + () -> valueGetter.getElementOrNull(actualValueArrayData, valueIndex)); + } + } +} diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueReaders.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueReaders.java index 55bf925e0581..e0968566c33f 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueReaders.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueReaders.java @@ -309,6 +309,17 @@ public ByteBuffer read(ByteBuffer reuse) { } } + public static class ByteArrayReader extends ParquetValueReaders.PrimitiveReader { + public ByteArrayReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public byte[] read(byte[] ignored) { + return column.nextBinary().getBytes(); + } + } + private static class OptionReader implements ParquetValueReader { private final int definitionLevel; private final ParquetValueReader reader; diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java index 51ddc9432bc3..84fce5881cdc 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java @@ -250,7 +250,7 @@ public ParquetValueReader primitive(org.apache.iceberg.types.Type.PrimitiveTy "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); } case BSON: - return new BytesReader(desc); + return new ParquetValueReaders.ByteArrayReader(desc); default: throw new UnsupportedOperationException( "Unsupported logical type: " + primitive.getOriginalType()); @@ -260,7 +260,7 @@ public ParquetValueReader primitive(org.apache.iceberg.types.Type.PrimitiveTy switch (primitive.getPrimitiveTypeName()) { case FIXED_LEN_BYTE_ARRAY: case BINARY: - return new BytesReader(desc); + return new ParquetValueReaders.ByteArrayReader(desc); case INT32: if (expected != null && expected.typeId() == TypeID.LONG) { return new IntAsLongReader(desc); @@ -368,17 +368,6 @@ public UTF8String read(UTF8String ignored) { } } - private static class BytesReader extends PrimitiveReader { - BytesReader(ColumnDescriptor desc) { - super(desc); - } - - @Override - public byte[] read(byte[] ignored) { - return column.nextBinary().getBytes(); - } - } - private static class ArrayReader extends RepeatedReader { private int readPos = 0; private int writePos = 0;
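
Usage note: a minimal sketch of how the new RowData reader plugs into Iceberg's Parquet read builder, mirroring the wiring in TestFlinkParquetReader above. The class name, the schema, its field names, and the "/tmp/data.parquet" path are placeholders; and because FlinkParquetReaders is declared package-private in this patch, a standalone caller like this is assumed to live in the org.apache.iceberg.flink.data package.

package org.apache.iceberg.flink.data;

import org.apache.flink.table.data.RowData;
import org.apache.iceberg.Files;
import org.apache.iceberg.Schema;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.parquet.Parquet;
import org.apache.iceberg.types.Types;

import static org.apache.iceberg.types.Types.NestedField.optional;
import static org.apache.iceberg.types.Types.NestedField.required;

public class FlinkParquetReadSketch {
  public static void main(String[] args) throws Exception {
    // Hypothetical projection schema; any Iceberg schema is handled the same way.
    Schema schema = new Schema(
        required(1, "id", Types.LongType.get()),
        optional(2, "data", Types.StringType.get()));

    // "/tmp/data.parquet" is a placeholder for a Parquet file written with this schema.
    try (CloseableIterable<RowData> rows = Parquet.read(Files.localInput("/tmp/data.parquet"))
        .project(schema)
        // fileType is the Parquet MessageType discovered from the file footer.
        .createReaderFunc(fileType -> FlinkParquetReaders.buildReader(schema, fileType))
        .build()) {
      for (RowData row : rows) {
        System.out.println(row.getLong(0) + " -> " + row.getString(1));
      }
    }
  }
}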