From a5c886619dd1573e96bbba058db099b47f0c147c Mon Sep 17 00:00:00 2001
From: Chao Sun
Date: Wed, 30 Jun 2021 14:21:18 -0700
Subject: [PATCH] [SPARK-34859][SQL] Handle column index when using vectorized Parquet reader

### What changes were proposed in this pull request?

Make the current vectorized Parquet reader work with the column index introduced in Parquet 1.11. In particular, this PR makes the following changes:

1. In `ParquetReadState`, track the row ranges returned via `PageReadStore.getRowIndexes` as well as the first row index of each page via `DataPage.getFirstRowIndex`.
2. Introduce a new API `ParquetVectorUpdater.skipValues`, which skips a batch of values from a Parquet value reader. As part of this, also rename the existing `updateBatch` to `readValues` and `update` to `readValue` to keep the method names consistent.
3. Correspondingly, introduce new APIs `VectorizedValuesReader.skipXXX` for the different data types, along with their implementations. These are useful when the reader knows that a given batch of values can be skipped, for instance because the batch is not covered by the row ranges generated by column index filtering.
4. Change `VectorizedRleValuesReader` to handle column index filtering. This is done by comparing the range that is going to be read next within the current RLE/PACKED block (call this the block range) against the current row range. There are three cases (a simplified sketch of this logic is included below):
   * if the block range is before the current row range, skip all the values in the block range;
   * if the block range is after the current row range, advance the row range and repeat;
   * if the block range overlaps with the current row range, read only the values within the overlapping area and skip the rest.

### Why are the changes needed?

[Parquet Column Index](https://github.com/apache/parquet-format/blob/master/PageIndex.md) is a new feature in Parquet 1.11 that allows very efficient filtering at the page level (some benchmark numbers can be found [here](https://blog.cloudera.com/speeding-up-select-queries-with-parquet-page-indexes/)), especially when the data is sorted. The feature is largely implemented in parquet-mr (via classes such as `ColumnIndex` and `ColumnIndexFilter`).

In Spark, the non-vectorized Parquet reader automatically benefits from the feature after upgrading to Parquet 1.11.x, without any code change. However, the same is not true for the vectorized Parquet reader, since Spark chose to implement its own logic for reading Parquet pages, handling definition levels, reading values into columnar batches, and so on.

Previously, [SPARK-26345](https://issues.apache.org/jira/browse/SPARK-26345) (#31393) updated Spark to only scan the pages filtered by column index on the parquet-mr side. This is done by calling the `ParquetFileReader.readNextFilteredRowGroup` and `ParquetFileReader.getFilteredRecordCount` APIs. That implementation, however, only works in a few limited cases: when there are multiple columns whose type widths differ (e.g., `int` and `bigint`), it can return incorrect results. See SPARK-34859 for a detailed description of this issue.

In order to fix the above, Spark needs to leverage the APIs `PageReadStore.getRowIndexes` and `DataPage.getFirstRowIndex`. The former returns the indexes of all rows (note the difference between rows and values: for a flat schema there is no difference between the two, but for a nested schema they differ) after filtering within a Parquet row group.
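To make the range-matching logic above concrete, here is a minimal sketch in plain Java. It is illustrative only and not the code added by this patch: the class `RowRangeSketch` and its method names are made up for this example, and only the range bookkeeping is shown (the actual reader also handles definition levels, nulls, batch offsets, etc.).

```java
import java.util.ArrayList;
import java.util.List;

public class RowRangeSketch {
  /** A closed interval [start, end] of row indexes that survived column index filtering. */
  static final class RowRange {
    final long start;
    final long end;
    RowRange(long start, long end) { this.start = start; this.end = end; }
    @Override public String toString() { return "[" + start + "-" + end + "]"; }
  }

  /**
   * Collapse sorted row indexes (e.g. the output of PageReadStore.getRowIndexes) into
   * contiguous ranges: [0, 1, 2, 4, 5, 7, 8, 9] becomes [0-2], [4-5], [7-9].
   * Assumes a non-empty, strictly increasing input.
   */
  static List<RowRange> constructRanges(long[] rowIndexes) {
    List<RowRange> ranges = new ArrayList<>();
    long start = rowIndexes[0];
    long previous = rowIndexes[0];
    for (int i = 1; i < rowIndexes.length; i++) {
      long idx = rowIndexes[i];
      if (previous + 1 != idx) {          // gap found: close the current range
        ranges.add(new RowRange(start, previous));
        start = idx;
      }
      previous = idx;
    }
    ranges.add(new RowRange(start, previous));
    return ranges;
  }

  /**
   * Decide what to do with the next block of `n` values starting at row index `rowId`,
   * given the current row range. This mirrors the three cases listed above.
   */
  static String classify(long rowId, int n, RowRange range) {
    long blockEnd = rowId + n - 1;
    if (blockEnd < range.start) {
      return "SKIP_BLOCK";                // block entirely before the range: skip all n values
    } else if (rowId > range.end) {
      return "ADVANCE_TO_NEXT_RANGE";     // block entirely after the range: move to the next range
    } else {
      long readStart = Math.max(range.start, rowId);
      long readEnd = Math.min(range.end, blockEnd);
      // read only [readStart, readEnd]; skip [rowId, readStart) and (readEnd, blockEnd]
      return "READ[" + readStart + "-" + readEnd + "]";
    }
  }

  public static void main(String[] args) {
    List<RowRange> ranges = constructRanges(new long[]{0, 1, 2, 4, 5, 7, 8, 9});
    System.out.println(ranges);                        // [[0-2], [4-5], [7-9]]
    System.out.println(classify(3, 1, ranges.get(1))); // SKIP_BLOCK: row 3 lies before [4-5]
    System.out.println(classify(6, 4, ranges.get(2))); // READ[7-9]: rows 6-9 overlap [7-9]
  }
}
```

This sketch only covers the range bookkeeping; in the actual reader the row ranges come from `PageReadStore.getRowIndexes` and the starting row index of each page from `DataPage.getFirstRowIndex`.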
The latter returns the first row index within a single data page. With the combination of the two, one is able to know which rows/values should be filtered while scanning a Parquet page. ### Does this PR introduce _any_ user-facing change? Yes. Now the vectorized Parquet reader should work correctly with column index. ### How was this patch tested? Borrowed tests from #31998 and added a few more tests. Closes #32753 from sunchao/SPARK-34859. Lead-authored-by: Chao Sun Co-authored-by: Li Xian Signed-off-by: Dongjoon Hyun --- .../datasources/parquet/ParquetReadState.java | 120 ++++++++- .../parquet/ParquetVectorUpdater.java | 12 +- .../parquet/ParquetVectorUpdaterFactory.java | 216 ++++++++++++---- .../parquet/VectorizedColumnReader.java | 28 +- .../VectorizedParquetRecordReader.java | 1 + .../parquet/VectorizedPlainValuesReader.java | 51 ++++ .../parquet/VectorizedRleValuesReader.java | 239 +++++++++++++----- .../parquet/VectorizedValuesReader.java | 13 + .../parquet/ParquetColumnIndexSuite.scala | 126 +++++++++ .../datasources/parquet/ParquetIOSuite.scala | 72 +++++- 10 files changed, 746 insertions(+), 132 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java index 28dcc44b28cad..b26088753465e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java @@ -17,13 +17,38 @@ package org.apache.spark.sql.execution.datasources.parquet; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.PrimitiveIterator; + /** * Helper class to store intermediate state while reading a Parquet column chunk. */ final class ParquetReadState { - /** Maximum definition level */ + /** A special row range used when there is no row indexes (hence all rows must be included) */ + private static final RowRange MAX_ROW_RANGE = new RowRange(Long.MIN_VALUE, Long.MAX_VALUE); + + /** + * A special row range used when the row indexes are present AND all the row ranges have been + * processed. This serves as a sentinel at the end indicating that all rows come after the last + * row range should be skipped. + */ + private static final RowRange END_ROW_RANGE = new RowRange(Long.MAX_VALUE, Long.MIN_VALUE); + + /** Iterator over all row ranges, only not-null if column index is present */ + private final Iterator rowRanges; + + /** The current row range */ + private RowRange currentRange; + + /** Maximum definition level for the Parquet column */ final int maxDefinitionLevel; + /** The current index over all rows within the column chunk. This is used to check if the + * current row should be skipped by comparing against the row ranges. 
*/ + long rowId; + /** The offset in the current batch to put the next value */ int offset; @@ -33,31 +58,108 @@ final class ParquetReadState { /** The remaining number of values to read in the current batch */ int valuesToReadInBatch; - ParquetReadState(int maxDefinitionLevel) { + ParquetReadState(int maxDefinitionLevel, PrimitiveIterator.OfLong rowIndexes) { this.maxDefinitionLevel = maxDefinitionLevel; + this.rowRanges = constructRanges(rowIndexes); + nextRange(); } /** - * Called at the beginning of reading a new batch. + * Construct a list of row ranges from the given `rowIndexes`. For example, suppose the + * `rowIndexes` are `[0, 1, 2, 4, 5, 7, 8, 9]`, it will be converted into 3 row ranges: + * `[0-2], [4-5], [7-9]`. */ - void resetForBatch(int batchSize) { + private Iterator constructRanges(PrimitiveIterator.OfLong rowIndexes) { + if (rowIndexes == null) { + return null; + } + + List rowRanges = new ArrayList<>(); + long currentStart = Long.MIN_VALUE; + long previous = Long.MIN_VALUE; + + while (rowIndexes.hasNext()) { + long idx = rowIndexes.nextLong(); + if (currentStart == Long.MIN_VALUE) { + currentStart = idx; + } else if (previous + 1 != idx) { + RowRange range = new RowRange(currentStart, previous); + rowRanges.add(range); + currentStart = idx; + } + previous = idx; + } + + if (previous != Long.MIN_VALUE) { + rowRanges.add(new RowRange(currentStart, previous)); + } + + return rowRanges.iterator(); + } + + /** + * Must be called at the beginning of reading a new batch. + */ + void resetForNewBatch(int batchSize) { this.offset = 0; this.valuesToReadInBatch = batchSize; } /** - * Called at the beginning of reading a new page. + * Must be called at the beginning of reading a new page. */ - void resetForPage(int totalValuesInPage) { + void resetForNewPage(int totalValuesInPage, long pageFirstRowIndex) { this.valuesToReadInPage = totalValuesInPage; + this.rowId = pageFirstRowIndex; } /** - * Advance the current offset to the new values. + * Returns the start index of the current row range. */ - void advanceOffset(int newOffset) { + long currentRangeStart() { + return currentRange.start; + } + + /** + * Returns the end index of the current row range. + */ + long currentRangeEnd() { + return currentRange.end; + } + + /** + * Advance the current offset and rowId to the new values. + */ + void advanceOffsetAndRowId(int newOffset, long newRowId) { valuesToReadInBatch -= (newOffset - offset); - valuesToReadInPage -= (newOffset - offset); + valuesToReadInPage -= (newRowId - rowId); offset = newOffset; + rowId = newRowId; + } + + /** + * Advance to the next range. + */ + void nextRange() { + if (rowRanges == null) { + currentRange = MAX_ROW_RANGE; + } else if (!rowRanges.hasNext()) { + currentRange = END_ROW_RANGE; + } else { + currentRange = rowRanges.next(); + } + } + + /** + * Helper struct to represent a range of row indexes `[start, end]`. 
+ */ + private static class RowRange { + final long start; + final long end; + + RowRange(long start, long end) { + this.start = start; + this.end = end; + } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java index b91d507a38786..9bb852987e656 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java @@ -30,12 +30,20 @@ public interface ParquetVectorUpdater { * @param values destination values vector * @param valuesReader reader to read values from */ - void updateBatch( + void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader); + /** + * Skip a batch of `total` values from `valuesReader`. + * + * @param total total number of values to skip + * @param valuesReader reader to skip values from + */ + void skipValues(int total, VectorizedValuesReader valuesReader); + /** * Read a single value from `valuesReader` into `values`, at `offset`. * @@ -43,7 +51,7 @@ void updateBatch( * @param values destination value vector * @param valuesReader reader to read values from */ - void update(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader); + void readValue(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader); /** * Process a batch of `total` values starting from `offset` in `values`, whose null slots diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java index 62e34fe549f04..2282dc798463b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java @@ -185,7 +185,7 @@ boolean isUnsignedIntTypeMatched(int bitWidth) { private static class BooleanUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -194,7 +194,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipBooleans(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -213,7 +218,7 @@ public void decodeSingleDictionaryId( private static class IntegerUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -222,7 +227,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -241,7 +251,7 @@ public void decodeSingleDictionaryId( private static class UnsignedIntegerUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -250,7 +260,12 @@ public void 
updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -270,7 +285,7 @@ public void decodeSingleDictionaryId( private static class ByteUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -279,7 +294,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipBytes(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -298,7 +318,7 @@ public void decodeSingleDictionaryId( private static class ShortUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -307,7 +327,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipShorts(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -332,7 +357,7 @@ private static class IntegerWithRebaseUpdater implements ParquetVectorUpdater { } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -341,7 +366,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -362,7 +392,7 @@ public void decodeSingleDictionaryId( private static class LongUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -371,7 +401,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -390,7 +425,7 @@ public void decodeSingleDictionaryId( private static class DowncastLongUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -401,7 +436,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -420,7 +460,7 @@ public void decodeSingleDictionaryId( private static class UnsignedLongUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -429,7 +469,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, 
VectorizedValuesReader valuesReader) { @@ -457,7 +502,7 @@ private static class LongWithRebaseUpdater implements ParquetVectorUpdater { } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -466,7 +511,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -487,18 +537,23 @@ public void decodeSingleDictionaryId( private static class LongAsMicrosUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; ++i) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -524,18 +579,23 @@ private static class LongAsMicrosRebaseUpdater implements ParquetVectorUpdater { } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; ++i) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -557,7 +617,7 @@ public void decodeSingleDictionaryId( private static class FloatUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -566,7 +626,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFloats(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -585,7 +650,7 @@ public void decodeSingleDictionaryId( private static class DoubleUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -594,7 +659,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipDoubles(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -613,7 +683,7 @@ public void decodeSingleDictionaryId( private static class BinaryUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -622,7 +692,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipBinary(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader 
valuesReader) { @@ -642,18 +717,23 @@ public void decodeSingleDictionaryId( private static class BinaryToSQLTimestampUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, 12); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -681,18 +761,23 @@ private static class BinaryToSQLTimestampConvertTzUpdater implements ParquetVect } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, 12); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -723,18 +808,23 @@ private static class BinaryToSQLTimestampRebaseUpdater implements ParquetVectorU } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, 12); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -767,18 +857,23 @@ private static class BinaryToSQLTimestampConvertTzRebaseUpdater implements Parqu } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, 12); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -811,18 +906,23 @@ private static class FixedLenByteArrayUpdater implements ParquetVectorUpdater { } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, arrayLen); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -848,18 +948,23 @@ private static class FixedLenByteArrayAsIntUpdater implements ParquetVectorUpdat } @Override - public void updateBatch( + public void 
readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, arrayLen); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -886,18 +991,23 @@ private static class FixedLenByteArrayAsLongUpdater implements ParquetVectorUpda } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, arrayLen); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index c61ee460880a8..92dea08102dfc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.time.ZoneId; +import java.util.PrimitiveIterator; import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesInput; @@ -74,6 +75,12 @@ public class VectorizedColumnReader { */ private final ParquetReadState readState; + /** + * The index for the first row in the current page, among all rows across all pages in the + * column chunk for this reader. If there is no column index, the value is 0. + */ + private long pageFirstRowIndex; + private final PageReader pageReader; private final ColumnDescriptor descriptor; private final LogicalTypeAnnotation logicalTypeAnnotation; @@ -83,12 +90,13 @@ public VectorizedColumnReader( ColumnDescriptor descriptor, LogicalTypeAnnotation logicalTypeAnnotation, PageReader pageReader, + PrimitiveIterator.OfLong rowIndexes, ZoneId convertTz, String datetimeRebaseMode, String int96RebaseMode) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; - this.readState = new ParquetReadState(descriptor.getMaxDefinitionLevel()); + this.readState = new ParquetReadState(descriptor.getMaxDefinitionLevel(), rowIndexes); this.logicalTypeAnnotation = logicalTypeAnnotation; this.updaterFactory = new ParquetVectorUpdaterFactory( logicalTypeAnnotation, convertTz, datetimeRebaseMode, int96RebaseMode); @@ -151,18 +159,19 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // page. dictionaryIds = column.reserveDictionaryIds(total); } - readState.resetForBatch(total); + readState.resetForNewBatch(total); while (readState.valuesToReadInBatch > 0) { - // Compute the number of values we want to read in this page. 
if (readState.valuesToReadInPage == 0) { int pageValueCount = readPage(); - readState.resetForPage(pageValueCount); + readState.resetForNewPage(pageValueCount, pageFirstRowIndex); } PrimitiveType.PrimitiveTypeName typeName = descriptor.getPrimitiveType().getPrimitiveTypeName(); if (isCurrentPageDictionaryEncoded) { // Save starting offset in case we need to decode dictionary IDs. int startOffset = readState.offset; + // Save starting row index so we can check if we need to eagerly decode dict ids later + long startRowId = readState.rowId; // Read and decode dictionary ids. defColumn.readIntegers(readState, dictionaryIds, column, @@ -170,10 +179,12 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process // the values to add microseconds precision. - if (column.hasDictionary() || (startOffset == 0 && isLazyDecodingSupported(typeName))) { + if (column.hasDictionary() || (startRowId == pageFirstRowIndex && + isLazyDecodingSupported(typeName))) { // Column vector supports lazy decoding of dictionary values so just set the dictionary. - // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some - // non-dictionary encoded values have already been added). + // We can't do this if startRowId is not the first row index in the page AND the column + // doesn't have a dictionary (i.e. some non-dictionary encoded values have already been + // added). PrimitiveType primitiveType = descriptor.getPrimitiveType(); // We need to make sure that we initialize the right type for the dictionary otherwise @@ -213,6 +224,8 @@ void readBatch(int total, WritableColumnVector column) throws IOException { private int readPage() { DataPage page = pageReader.readPage(); + this.pageFirstRowIndex = page.getFirstRowIndex().orElse(0L); + return page.accept(new DataPage.Visitor() { @Override public Integer visit(DataPageV1 dataPageV1) { @@ -268,7 +281,6 @@ private void initDataReader( } private int readPageV1(DataPageV1 page) throws IOException { - // Initialize the decoders. 
if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) { throw new UnsupportedOperationException("Unsupported encoding: " + page.getDlEncoding()); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 32455278c4fb7..9f7836ae4818d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -334,6 +334,7 @@ private void checkEndOfRowGroup() throws IOException { columns.get(i), types.get(i).getLogicalTypeAnnotation(), pages.getPageReader(columns.get(i)), + pages.getRowIndexes().orElse(null), convertTz, datetimeRebaseMode, int96RebaseMode); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 6a0038dbdc44c..39591be3b4be4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -61,6 +61,14 @@ public final void readBooleans(int total, WritableColumnVector c, int rowId) { } } + @Override + public final void skipBooleans(int total) { + // TODO: properly vectorize this + for (int i = 0; i < total; i++) { + readBoolean(); + } + } + private ByteBuffer getBuffer(int length) { try { return in.slice(length).order(ByteOrder.LITTLE_ENDIAN); @@ -84,6 +92,11 @@ public final void readIntegers(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipIntegers(int total) { + in.skip(total * 4L); + } + @Override public final void readUnsignedIntegers(int total, WritableColumnVector c, int rowId) { int requiredBytes = total * 4; @@ -140,6 +153,11 @@ public final void readLongs(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipLongs(int total) { + in.skip(total * 8L); + } + @Override public final void readUnsignedLongs(int total, WritableColumnVector c, int rowId) { int requiredBytes = total * 8; @@ -197,6 +215,11 @@ public final void readFloats(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipFloats(int total) { + in.skip(total * 4L); + } + @Override public final void readDoubles(int total, WritableColumnVector c, int rowId) { int requiredBytes = total * 8; @@ -212,6 +235,11 @@ public final void readDoubles(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipDoubles(int total) { + in.skip(total * 8L); + } + @Override public final void readBytes(int total, WritableColumnVector c, int rowId) { // Bytes are stored as a 4-byte little endian int. Just read the first byte. 
@@ -226,6 +254,11 @@ public final void readBytes(int total, WritableColumnVector c, int rowId) { } } + @Override + public final void skipBytes(int total) { + in.skip(total * 4L); + } + @Override public final void readShorts(int total, WritableColumnVector c, int rowId) { int requiredBytes = total * 4; @@ -236,6 +269,11 @@ public final void readShorts(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipShorts(int total) { + in.skip(total * 4L); + } + @Override public final boolean readBoolean() { // TODO: vectorize decoding and keep boolean[] instead of currentByte @@ -300,6 +338,14 @@ public final void readBinary(int total, WritableColumnVector v, int rowId) { } } + @Override + public void skipBinary(int total) { + for (int i = 0; i < total; i++) { + int len = readInteger(); + in.skip(len); + } + } + @Override public final Binary readBinary(int len) { ByteBuffer buffer = getBuffer(len); @@ -312,4 +358,9 @@ public final Binary readBinary(int len) { return Binary.fromConstantByteArray(bytes); } } + + @Override + public void skipFixedLenByteArray(int total, int len) { + in.skip(total * (long) len); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 538b69877e2d5..03bda0fedbd29 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -156,18 +156,12 @@ public int readInteger() { } /** - * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader - * reads the definition levels and then will read from `data` for the non-null values. - * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only - * necessary for readIntegers because we also use it to decode dictionaryIds and want to make - * sure it always has a value in range. - * - * This is a batched version of this logic: - * if (this.readInt() == level) { - * c[rowId] = data.readInteger(); - * } else { - * c[rowId] = null; - * } + * Reads a batch of values into vector `values`, using `valueReader`. The related states such + * as row index, offset, number of values left in the batch and page, etc, are tracked by + * `state`. The type-specific `updater` is used to update or skip values. + *

+ * This reader reads the definition levels and then will read from `valueReader` for the + * non-null values. If the value is null, `values` will be populated with null value. */ public void readBatch( ParquetReadState state, @@ -175,36 +169,68 @@ public void readBatch( VectorizedValuesReader valueReader, ParquetVectorUpdater updater) throws IOException { int offset = state.offset; - int left = Math.min(state.valuesToReadInBatch, state.valuesToReadInPage); + long rowId = state.rowId; + int leftInBatch = state.valuesToReadInBatch; + int leftInPage = state.valuesToReadInPage; - while (left > 0) { + while (leftInBatch > 0 && leftInPage > 0) { if (this.currentCount == 0) this.readNextGroup(); - int n = Math.min(left, this.currentCount); - - switch (mode) { - case RLE: - if (currentValue == state.maxDefinitionLevel) { - updater.updateBatch(n, offset, values, valueReader); - } else { - values.putNulls(offset, n); - } - break; - case PACKED: - for (int i = 0; i < n; ++i) { - if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) { - updater.update(offset + i, values, valueReader); + int n = Math.min(leftInBatch, Math.min(leftInPage, this.currentCount)); + + long rangeStart = state.currentRangeStart(); + long rangeEnd = state.currentRangeEnd(); + + if (rowId + n < rangeStart) { + updater.skipValues(n, valueReader); + advance(n); + rowId += n; + leftInPage -= n; + } else if (rowId > rangeEnd) { + state.nextRange(); + } else { + // the range [rowId, rowId + n) overlaps with the current row range in state + long start = Math.max(rangeStart, rowId); + long end = Math.min(rangeEnd, rowId + n - 1); + + // skip the part [rowId, start) + int toSkip = (int) (start - rowId); + if (toSkip > 0) { + updater.skipValues(toSkip, valueReader); + advance(toSkip); + rowId += toSkip; + leftInPage -= toSkip; + } + + // read the part [start, end] + n = (int) (end - start + 1); + + switch (mode) { + case RLE: + if (currentValue == state.maxDefinitionLevel) { + updater.readValues(n, offset, values, valueReader); } else { - values.putNull(offset + i); + values.putNulls(offset, n); } - } - break; + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) { + updater.readValue(offset + i, values, valueReader); + } else { + values.putNull(offset + i); + } + } + break; + } + offset += n; + leftInBatch -= n; + rowId += n; + leftInPage -= n; + currentCount -= n; } - offset += n; - left -= n; - currentCount -= n; } - state.advanceOffset(offset); + state.advanceOffsetAndRowId(offset, rowId); } /** @@ -217,36 +243,68 @@ public void readIntegers( WritableColumnVector nulls, VectorizedValuesReader data) throws IOException { int offset = state.offset; - int left = Math.min(state.valuesToReadInBatch, state.valuesToReadInPage); + long rowId = state.rowId; + int leftInBatch = state.valuesToReadInBatch; + int leftInPage = state.valuesToReadInPage; - while (left > 0) { + while (leftInBatch > 0 && leftInPage > 0) { if (this.currentCount == 0) this.readNextGroup(); - int n = Math.min(left, this.currentCount); - - switch (mode) { - case RLE: - if (currentValue == state.maxDefinitionLevel) { - data.readIntegers(n, values, offset); - } else { - nulls.putNulls(offset, n); - } - break; - case PACKED: - for (int i = 0; i < n; ++i) { - if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) { - values.putInt(offset + i, data.readInteger()); + int n = Math.min(leftInBatch, Math.min(leftInPage, this.currentCount)); + + long rangeStart = 
state.currentRangeStart(); + long rangeEnd = state.currentRangeEnd(); + + if (rowId + n < rangeStart) { + data.skipIntegers(n); + advance(n); + rowId += n; + leftInPage -= n; + } else if (rowId > rangeEnd) { + state.nextRange(); + } else { + // the range [rowId, rowId + n) overlaps with the current row range in state + long start = Math.max(rangeStart, rowId); + long end = Math.min(rangeEnd, rowId + n - 1); + + // skip the part [rowId, start) + int toSkip = (int) (start - rowId); + if (toSkip > 0) { + data.skipIntegers(toSkip); + advance(toSkip); + rowId += toSkip; + leftInPage -= toSkip; + } + + // read the part [start, end] + n = (int) (end - start + 1); + + switch (mode) { + case RLE: + if (currentValue == state.maxDefinitionLevel) { + data.readIntegers(n, values, offset); } else { - nulls.putNull(offset + i); + nulls.putNulls(offset, n); } - } - break; + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) { + values.putInt(offset + i, data.readInteger()); + } else { + nulls.putNull(offset + i); + } + } + break; + } + rowId += n; + leftInPage -= n; + offset += n; + leftInBatch -= n; + currentCount -= n; } - offset += n; - left -= n; - currentCount -= n; } - state.advanceOffset(offset); + state.advanceOffsetAndRowId(offset, rowId); } @@ -346,6 +404,71 @@ public Binary readBinary(int len) { throw new UnsupportedOperationException("only readInts is valid."); } + @Override + public void skipIntegers(int total) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + advance(n); + left -= n; + } + } + + @Override + public void skipBooleans(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipBytes(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipShorts(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipLongs(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipFloats(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipDoubles(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipBinary(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipFixedLenByteArray(int total, int len) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + /** + * Advance and skip the next `n` values in the current block. `n` MUST be <= `currentCount`. + */ + private void advance(int n) { + switch (mode) { + case RLE: + break; + case PACKED: + currentBufferIdx += n; + break; + } + currentCount -= n; + } + /** * Reads the next varint encoded int. 
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index a2d663fd8c8b6..fc4eac94d1c46 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -50,4 +50,17 @@ public interface VectorizedValuesReader { void readFloats(int total, WritableColumnVector c, int rowId); void readDoubles(int total, WritableColumnVector c, int rowId); void readBinary(int total, WritableColumnVector c, int rowId); + + /* + * Skips `total` values + */ + void skipBooleans(int total); + void skipBytes(int total); + void skipShorts(int total); + void skipIntegers(int total); + void skipLongs(int total); + void skipFloats(int total); + void skipDoubles(int total); + void skipBinary(int total); + void skipFixedLenByteArray(int total, int len); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala new file mode 100644 index 0000000000000..f10b7013185b3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.test.SharedSparkSession + +class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSparkSession { + import testImplicits._ + + /** + * create parquet file with two columns and unaligned pages + * pages will be of the following layout + * col_1 500 500 500 500 + * |---------|---------|---------|---------| + * |-------|-----|-----|---|---|---|---|---| + * col_2 400 300 200 200 200 200 200 200 + */ + def checkUnalignedPages(actions: (DataFrame => DataFrame)*): Unit = { + withTempPath(file => { + val ds = spark.range(0, 2000).map(i => (i, i + ":" + "o" * (i / 100).toInt)) + ds.coalesce(1) + .write + .option("parquet.page.size", "4096") + .parquet(file.getCanonicalPath) + + val parquetDf = spark.read.parquet(file.getCanonicalPath) + + actions.foreach { action => + checkAnswer(action(parquetDf), action(ds.toDF())) + } + }) + } + + test("reading from unaligned pages - test filters") { + checkUnalignedPages( + // single value filter + df => df.filter("_1 = 500"), + df => df.filter("_1 = 500 or _1 = 1500"), + df => df.filter("_1 = 500 or _1 = 501 or _1 = 1500"), + df => df.filter("_1 = 500 or _1 = 501 or _1 = 1000 or _1 = 1500"), + // range filter + df => df.filter("_1 >= 500 and _1 < 1000"), + df => df.filter("(_1 >= 500 and _1 < 1000) or (_1 >= 1500 and _1 < 1600)") + ) + } + + test("test reading unaligned pages - test all types") { + withTempPath(file => { + val df = spark.range(0, 2000).selectExpr( + "id as _1", + "cast(id as short) as _3", + "cast(id as int) as _4", + "cast(id as float) as _5", + "cast(id as double) as _6", + "cast(id as decimal(20,0)) as _7", + "cast(cast(1618161925000 + id * 1000 * 60 * 60 * 24 as timestamp) as date) as _9", + "cast(1618161925000 + id as timestamp) as _10" + ) + df.coalesce(1) + .write + .option("parquet.page.size", "4096") + .parquet(file.getCanonicalPath) + + val parquetDf = spark.read.parquet(file.getCanonicalPath) + val singleValueFilterExpr = "_1 = 500 or _1 = 1500" + checkAnswer( + parquetDf.filter(singleValueFilterExpr), + df.filter(singleValueFilterExpr) + ) + val rangeFilterExpr = "_1 > 500 " + checkAnswer( + parquetDf.filter(rangeFilterExpr), + df.filter(rangeFilterExpr) + ) + }) + } + + test("test reading unaligned pages - test all types (dict encode)") { + withTempPath(file => { + val df = spark.range(0, 2000).selectExpr( + "id as _1", + "cast(id % 10 as byte) as _2", + "cast(id % 10 as short) as _3", + "cast(id % 10 as int) as _4", + "cast(id % 10 as float) as _5", + "cast(id % 10 as double) as _6", + "cast(id % 10 as decimal(20,0)) as _7", + "cast(id % 2 as boolean) as _8", + "cast(cast(1618161925000 + (id % 10) * 1000 * 60 * 60 * 24 as timestamp) as date) as _9", + "cast(1618161925000 + (id % 10) as timestamp) as _10" + ) + df.coalesce(1) + .write + .option("parquet.page.size", "4096") + .parquet(file.getCanonicalPath) + + val parquetDf = spark.read.parquet(file.getCanonicalPath) + val singleValueFilterExpr = "_1 = 500 or _1 = 1500" + checkAnswer( + parquetDf.filter(singleValueFilterExpr), + df.filter(singleValueFilterExpr) + ) + val rangeFilterExpr = "_1 > 500" + checkAnswer( + parquetDf.filter(rangeFilterExpr), + df.filter(rangeFilterExpr) + ) + }) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 
bc4234f01b5fe..a330b82de2d0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -368,7 +368,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession private def createParquetWriter( schema: MessageType, path: Path, - dictionaryEnabled: Boolean = false): ParquetWriter[Group] = { + dictionaryEnabled: Boolean = false, + pageSize: Int = 1024, + dictionaryPageSize: Int = 1024): ParquetWriter[Group] = { val hadoopConf = spark.sessionState.newHadoopConf() ExampleParquetWriter @@ -378,11 +380,77 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession .withWriterVersion(PARQUET_1_0) .withCompressionCodec(GZIP) .withRowGroupSize(1024 * 1024) - .withPageSize(1024) + .withPageSize(pageSize) + .withDictionaryPageSize(dictionaryPageSize) .withConf(hadoopConf) .build() } + test("SPARK-34859: test multiple pages with different sizes and nulls") { + def makeRawParquetFile( + path: Path, + dictionaryEnabled: Boolean, + n: Int, + pageSize: Int): Seq[Option[Int]] = { + val schemaStr = + """ + |message root { + | optional boolean _1; + | optional int32 _2; + | optional int64 _3; + | optional float _4; + | optional double _5; + |} + """.stripMargin + + val schema = MessageTypeParser.parseMessageType(schemaStr) + val writer = createParquetWriter(schema, path, + dictionaryEnabled = dictionaryEnabled, pageSize = pageSize, dictionaryPageSize = pageSize) + + val rand = scala.util.Random + val expected = (0 until n).map { i => + if (rand.nextBoolean()) { + None + } else { + Some(i) + } + } + expected.foreach { opt => + val record = new SimpleGroup(schema) + opt match { + case Some(i) => + record.add(0, i % 2 == 0) + record.add(1, i) + record.add(2, i.toLong) + record.add(3, i.toFloat) + record.add(4, i.toDouble) + case _ => + } + writer.write(record) + } + + writer.close() + expected + } + + Seq(true, false).foreach { dictionaryEnabled => + Seq(64, 128, 89).foreach { pageSize => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + val expected = makeRawParquetFile(path, dictionaryEnabled, 1000, pageSize) + readParquetFile(path.toString) { df => + checkAnswer(df, expected.map { + case None => + Row(null, null, null, null, null) + case Some(i) => + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + }) + } + } + } + } + } + test("read raw Parquet file") { def makeRawParquetFile(path: Path): Unit = { val schemaStr =