diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 06cd9ea2d242..84e9e0ee0e2f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -61,6 +61,8 @@ import org.apache.parquet.hadoop.util.ConfigurationUtil; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.Types; +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter; +import org.apache.spark.sql.execution.datasources.parquet.ParquetStruct; import org.apache.spark.TaskContext; import org.apache.spark.TaskContext$; import org.apache.spark.sql.types.StructType; @@ -81,6 +83,7 @@ public abstract class SpecificParquetRecordReaderBase extends RecordReader= endOfPageValueCount) { - if (valuesRead >= totalValueCount) { - // How do we get here? Throw end of stream exception? - return false; - } - readPage(); - } - ++valuesRead; - // TODO: Don't read for flat schemas - //repetitionLevel = repetitionLevelColumn.nextInt(); - return definitionLevelColumn.nextInt() == maxDefLevel; - } + boolean resetNestedRecord = true; /** * Reads `total` values from this columnReader into column. */ - void readBatch(int total, ColumnVector column) throws IOException { + public void readBatch(int total, ColumnVector column) throws IOException { + asComplexColElement = column.getParentColumn() != null; + boolean isRepeatedColumn = maxRepLevel > 0; int rowId = 0; + int repeatedRowId = 0; + int remaining = total; + + // The number of values to read. + int num = 0; + + // Stores row ids and offsets during constructing nested records. + int[] rowIds = new int[maxRepLevel + 2]; + int[] offsets = new int[maxRepLevel + 2]; + + // Keeps repetition levels and corresponding repetition counts. + int[] repetitions = new int[maxRepLevel + 2]; + ColumnVector dictionaryIds = null; if (dictionary != null) { // SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to @@ -143,18 +162,60 @@ void readBatch(int total, ColumnVector column) throws IOException { // page. dictionaryIds = column.reserveDictionaryIds(total); } - while (total > 0) { + + while (true) { // Compute the number of values we want to read in this page. int leftInPage = (int) (endOfPageValueCount - valuesRead); + + // Stop condition: + // If we are going to read data in repeated column, the stop condition is that we + // read `total` repeated columns. Eg., if we want to read 5 records of an array of int column. + // we can't just read 5 integers. Instead, we have to read the integers until 5 arrays are put + // into this array column. + if (isRepeatedColumn) { + if (repeatedRowId == total) break; + } else { + if (remaining == 0) break; + } + + // Reaching the end of current page. if (leftInPage == 0) { - readPage(); + boolean pageExists = readPage(); + if (!pageExists) { + if (!resetNestedRecord) { + insertRepeatedArray(column, rowIds, offsets, repetitions, total, 0); + resetNestedRecord = true; + repeatedRowId = rowIds[1]; + if (repeatedRowId == total) break; + } + // Should not reach here. + throw new IOException("Failed to read page. 
No page exists anymore!"); + } leftInPage = (int) (endOfPageValueCount - valuesRead); } - int num = Math.min(total, leftInPage); + + // Determine the number of values to read for this column in the current page. + if (asComplexColElement) { + // Using repetition and definition level encodings to construct nested/repeated records. + // When constructing nested/repeated records, we returns the number of values to read in + // this page for this column. + num = constructComplexRecords(column, repetitions, rowIds, offsets, leftInPage, total); + repeatedRowId = rowIds[1]; + } else { + // If this column is not a repeated/nested column, just read minimum of remaining values + // and all values left in the current page. + num = Math.min(remaining, leftInPage); + } + if (isCurrentPageDictionaryEncoded) { // Read and decode dictionary ids. + if (asComplexColElement) { + int dictionaryCapacity = Math.max(remaining, rowId + num); + dictionaryIds = column.reserveDictionaryIds(dictionaryCapacity); + } defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + if (column.hasDictionary() || (rowId == 0 && (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 || @@ -171,10 +232,13 @@ void readBatch(int total, ColumnVector column) throws IOException { } else { if (column.hasDictionary() && rowId != 0) { // This batch already has dictionary encoded values but this new page is not. The batch - // does not support a mix of dictionary and not so we will decode the dictionary. + // does not support a mix of dictionary and not, so we will decode the dictionary. decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); } column.setDictionary(null); + if (asComplexColElement) { + column.reserve(Math.max(remaining, rowId + num)); + } switch (descriptor.getType()) { case BOOLEAN: readBooleanBatch(rowId, num, column); @@ -204,10 +268,9 @@ void readBatch(int total, ColumnVector column) throws IOException { throw new IOException("Unsupported type: " + descriptor.getType()); } } - valuesRead += num; rowId += num; - total -= num; + remaining -= num; } } @@ -425,30 +488,35 @@ private void readFixedLenByteArrayBatch(int rowId, int num, } } - private void readPage() throws IOException { + private boolean readPage() throws IOException { DataPage page = pageReader.readPage(); - // TODO: Why is this a visitor? - page.accept(new DataPage.Visitor() { - @Override - public Void visit(DataPageV1 dataPageV1) { - try { - readPageV1(dataPageV1); - return null; - } catch (IOException e) { - throw new RuntimeException(e); + if (page == null) { + return false; + } else { + // TODO: Why is this a visitor? 
+ page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } } - } - @Override - public Void visit(DataPageV2 dataPageV2) { - try { - readPageV2(dataPageV2); - return null; - } catch (IOException e) { - throw new RuntimeException(e); + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } } - } - }); + }); + return true; + } } private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException { @@ -482,6 +550,292 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr } } + /** + * Inserts records into parent columns of a column. These parent columns are repeated columns. As + * the real data are read into the column, we only need to insert array into its repeated columns. + * @param column The ColumnVector which the data in the page are read into. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param offsets The beginning offsets in columns which we use to construct nested records. + * @param repetitions Mapping between repetition levels and their corresponding counts. + * @param total The total number of rows to construct. + * @param repLevel The current repetition level. + */ + private void insertRepeatedArray( + ColumnVector column, + int[] rowIds, + int[] offsets, + int[] repetitions, + int total, + int repLevel) throws IOException { + ColumnVector parentRepeatedColumn = column; + int curRepLevel = maxRepLevel; + while (true) { + parentRepeatedColumn = parentRepeatedColumn.getNearestParentArrayColumn(); + if (parentRepeatedColumn != null) { + int parentColRepLevel = parentRepeatedColumn.getRepLevel(); + // The current repetition level means the beginning level of the current value. Thus, + // we only need to insert array into the parent columns whose repetition levels are + // equal to or more than the given repetition level. + if (parentColRepLevel >= repLevel) { + parentRepeatedColumn.reserve(rowIds[curRepLevel] + 1); + parentRepeatedColumn.putArray(rowIds[curRepLevel], + offsets[curRepLevel], repetitions[curRepLevel]); + + offsets[curRepLevel] += repetitions[curRepLevel]; + repetitions[curRepLevel] = 0; + rowIds[curRepLevel]++; + + // Increase the repetition count for parent repetition level as we add a new record. + if (curRepLevel > 1) { + repetitions[curRepLevel - 1]++; + } + + // In vectorization, the most outside repeated element is at the repetition 1. + if (curRepLevel == 1 && rowIds[curRepLevel] == total) { + return; + } + curRepLevel--; + } else { + break; + } + } else { + break; + } + } + } + + /** + * Finds the outside element of an inner element which is defined as Catalyst DataType, + * with the specified definition level. + * @param column The column as the beginning level for looking up the inner element. + * @param defLevel The specified definition level. + * @return the column which is the outside group element of the inner element. 
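+ * Concretely, this walks up the chain of parent columns and returns the child of the first parent group whose definition level equals {@code defLevel} while that group's own parent (if any) sits at a lower definition level.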
+ */ + private ColumnVector findInnerElementWithDefLevel(ColumnVector column, int defLevel) { + while (true) { + if (column == null) { + return null; + } + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() == defLevel) { + ColumnVector outside = parent.getParentColumn(); + if (outside == null || outside.getDefLevel() < defLevel) { + return column; + } + } + column = parent; + } + } + + /** + * Finds the outside element of an inner element which is not defined as a Catalyst DataType, + * with the specified definition level. + * @param column The column as the beginning level for looking up the inner element. + * @param defLevel The specified definition level. + * @return the column which is the outside group element of the inner element. + */ + private ColumnVector findHiddenInnerElementWithDefLevel(ColumnVector column, int defLevel) { + while (true) { + if (column == null) { + return null; + } + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() <= defLevel) { + ColumnVector outside = parent.getParentColumn(); + if (outside == null || outside.getDefLevel() < defLevel) { + return column; + } + } + column = parent; + } + } + + /** + * Checks if the given column is a legacy array in the Parquet schema. + * @param column The column we want to check. + * @return whether the given column is a legacy array in the Parquet schema. + */ + private boolean isLegacyArray(ColumnVector column) { + ColumnVector parent = column.getNearestParentArrayColumn(); + if (parent == null) { + return false; + } else if (parent.getRepLevel() <= maxRepLevel && parent.getDefLevel() < maxDefLevel) { + return true; + } + return false; + } + + /** + * Inserts a null record at the specified column. + * @param column The ColumnVector which the data in the page are read into. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param repetitions Mapping between repetition levels and their corresponding counts. + */ + private void insertNullRecord( + ColumnVector column, + int[] rowIds, + int[] repetitions) { + int repLevel = column.getRepLevel(); + + if (repLevel == 0) { + repLevel = 1; + } + + rowIds[repLevel] += repetitions[repLevel]; + repetitions[repLevel] = 0; + + column.reserve(rowIds[repLevel] + 1); + column.putNull(rowIds[repLevel]); + rowIds[repLevel]++; + } + + /** + * Returns the array of repetition level values. + */ + private int[] getRepetitionLevels() throws IOException { + int[] repetitions = new int[this.pageValueCount]; + for (int i = 0; i < this.pageValueCount; i++) { + repetitions[i] = this.repetitionLevelColumn.nextInt(); + } + return repetitions; + } + + /** + * Iterates over the definition and repetition levels of the values read in the page, + * and constructs complex records accordingly. + * @param column The ColumnVector which the data in the page are read into. + * @param repetitions Mapping between repetition levels and their counts. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param offsets The beginning offsets in columns which we use to construct nested records. + * @param leftInPage The number of values that can be read in the current page. + * @param total The total number of rows to construct. + * @return the number of values that need to be read in the current page.
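+ * + * For example (an illustrative schema, not one taken from this patch's tests): given + * <pre> message m { repeated int32 a; } </pre> + * (maxRepLevel = 1, maxDefLevel = 1), the records {a: [1, 2]}, {a: []} and {a: [3]} arrive as the + * (repetition, definition) pairs (0, 1), (1, 1), (0, 0), (0, 1); a repetition level of 0 starts a + * new record, and a definition level below the max marks a missing or empty value.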
+ */ + private int constructComplexRecords( + ColumnVector column, + int[] repetitions, + int[] rowIds, + int[] offsets, + int leftInPage, + int total) throws IOException { + for (int i = 0; i < leftInPage; i++) { + int repLevel = repetitionLevelColumn.nextInt(); + int defLevel = definitionLevelColumn.nextInt(); + + // If there are previous values and counts needed to be consider. + if (!resetNestedRecord) { + // When a new record begins at lower repetition level, + // we insert array into repeated column. + if (repLevel < maxRepLevel) { + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); + } + } + resetNestedRecord = false; + + // When definition level is less than max definition level, + // there is a null value. + if (defLevel < maxDefLevel) { + int offset = offsets[maxRepLevel]; + + // The null value is defined at the root level. + // Insert a null record. + if (repLevel == 0 && defLevel == 0) { + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() == maxDefLevel + && parent.getRepLevel() == maxRepLevel) { + // A repeated element at root level. + // E.g., The repeatedPrimitive at the following schema. + // Going to insert an empty record. + // messageType: message spark_schema { + // optional int32 optionalPrimitive; + // required int32 requiredPrimitive; + // + // repeated int32 repeatedPrimitive; + // + // optional group optionalMessage { + // optional int32 someId; + // } + // required group requiredMessage { + // optional int32 someId; + // } + // repeated group repeatedMessage { + // optional int32 someId; + // } + // } + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); + } else { + // Obtain most outside column. + ColumnVector topColumn = column.getParentColumn(); + while (topColumn.getParentColumn() != null) { + topColumn = topColumn.getParentColumn(); + } + + insertNullRecord(topColumn, rowIds, repetitions); + } + // Move to next offset in max repetition level as we processed the current value. + offsets[maxRepLevel]++; + resetNestedRecord = true; + } else if (isLegacyArray(column) && + column.getNearestParentArrayColumn().getDefLevel() == defLevel) { + // For a legacy array, if a null is defined at the repeated group column, it actually + // means an element with null value. + + repetitions[maxRepLevel]++; + } else if (!column.getParentColumn().isArray() && + column.getParentColumn().getDefLevel() == defLevel) { + // A null element defined in the wrapping non-repeated group. + rowIds[1]++; + } else { + // An empty element defined in outside group. + // E.g., the element in the following schema. + // messageType: message spark_schema { + // required int32 index; + // optional group col { + // optional float f1; + // optional group f2 (LIST) { + // repeated group list { + // optional boolean element; + // } + // } + // } + // } + ColumnVector parent = findInnerElementWithDefLevel(column, defLevel); + if (parent != null) { + // Found the group with the same definition level. + // Insert a null record at definition level. + // E.g, R=0, D=1 for above schema. + insertNullRecord(parent, rowIds, repetitions); + offsets[maxRepLevel]++; + resetNestedRecord = true; + } else { + // Found the group with lower definition level. + // Insert an empty record. + // E.g, R=0, D=2 for above schema. 
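+ // (With the example schema above: D=1 means col is present but f2 is null, so a null record is inserted; D=2 means f2 is present but its list has no elements, so an empty array is inserted.)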
+ parent = findHiddenInnerElementWithDefLevel(column, defLevel); + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); + offsets[maxRepLevel]++; + resetNestedRecord = true; + } + } + } else { + // Determine the repetition level of non-null values. + // A new record begins with non-null value. + if (maxRepLevel == 0) { + // A required record at root level. + repetitions[1]++; + insertRepeatedArray(column, rowIds, offsets, repetitions, total, maxRepLevel - 1); + } else { + // Repeated values. We increase repetition count. + repetitions[maxRepLevel]++; + } + } + // If we have constructed `total` records, return the number of values to read. + if (rowIds[1] == total) return i + 1; + } + // All `leftInPage` values in the current page are needed to read. + return leftInPage; + } + private void readPageV1(DataPageV1 page) throws IOException { this.pageValueCount = page.getValueCount(); ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); @@ -495,12 +849,20 @@ private void readPageV1(DataPageV1 page) throws IOException { this.defColumn = new VectorizedRleValuesReader(bitWidth); dlReader = this.defColumn; this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); - this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); try { byte[] bytes = page.getBytes().toByteArray(); rlReader.initFromPage(pageValueCount, bytes, 0); int next = rlReader.getNextOffset(); dlReader.initFromPage(pageValueCount, bytes, next); + + if (asComplexColElement) { + ValuesReader dlReaderCopy; + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); + dlReaderCopy = this.defColumnCopy; + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReaderCopy); + dlReaderCopy.initFromPage(pageValueCount, bytes, next); + } + next = dlReader.getNextOffset(); initDataReader(page.getValueEncoding(), bytes, next); } catch (IOException e) { @@ -515,9 +877,16 @@ private void readPageV2(DataPageV2 page) throws IOException { int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); this.defColumn = new VectorizedRleValuesReader(bitWidth); - this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); this.defColumn.initFromBuffer( this.pageValueCount, page.getDefinitionLevels().toByteArray()); + + if (asComplexColElement) { + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); + this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumnCopy); + this.defColumnCopy.initFromBuffer( + this.pageValueCount, page.getDefinitionLevels().toByteArray()); + } + try { initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); } catch (IOException e) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 51bdf0f0f229..89b36df103cc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.types.StructField; import 
org.apache.spark.sql.types.StructType; @@ -178,7 +179,10 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns, } } + // Allocate ColumnVectors in ColumnarBatch columnarBatch = ColumnarBatch.allocate(batchSchema, memMode); + columnarBatch.setParquetSchema(this.parquetSchema); + columnarBatch.initColumnVectors(); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { @@ -188,12 +192,30 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns, } // Initialize missing columns with nulls. - for (int i = 0; i < missingColumns.length; i++) { - if (missingColumns[i]) { - columnarBatch.column(i).putNulls(0, columnarBatch.capacity()); - columnarBatch.column(i).setIsConstant(); + int missingColumnIdx = 0; + int partitionIdxBase = missingColumns.length; + if (partitionColumns != null) { + partitionIdxBase = sparkSchema.fields().length; + } + for (int i = 0; i < columnarBatch.numFields(); i++) { + if (i < partitionIdxBase) { + missingColumnIdx = initColumnWithNulls(columnarBatch.column(i), missingColumnIdx); + } + } + } + + private int initColumnWithNulls(ColumnVector column, int missingColumnIdx) { + if (column.isComplex()) { + for (int j = 0; j < column.getChildColumnNums(); j++) { + missingColumnIdx = initColumnWithNulls(column.getChildColumn(j), missingColumnIdx); + } + } else { + if (missingColumns[missingColumnIdx++]) { + column.putNulls(0, columnarBatch.capacity()); + column.setIsConstant(); } } + return missingColumnIdx; } public void initBatch() { @@ -225,10 +247,10 @@ public boolean nextBatch() throws IOException { checkEndOfRowGroup(); int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); - for (int i = 0; i < columnReaders.length; ++i) { - if (columnReaders[i] == null) continue; - columnReaders[i].readBatch(num, columnarBatch.column(i)); + for (int i = 0; i < columnarBatch.numFields(); i++) { + readBatchOnColumnVector(columnarBatch.column(i), num); } + rowsReturned += num; columnarBatch.setNumRows(num); numBatched = num; @@ -236,17 +258,23 @@ public boolean nextBatch() throws IOException { return true; } + private void readBatchOnColumnVector(ColumnVector column, int num) throws IOException { + if (column.hasColumnReader()) { + column.readBatch(num); + } else { + for (int j = 0; j < column.getChildColumnNums(); j++) { + readBatchOnColumnVector(column.getChildColumn(j), num); + } + } + } + private void initializeInternal() throws IOException, UnsupportedOperationException { /** * Check that the requested schema is supported. */ - missingColumns = new boolean[requestedSchema.getFieldCount()]; - for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { - Type t = requestedSchema.getFields().get(i); - if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { - throw new UnsupportedOperationException("Complex types not supported."); - } - + missingColumns = new boolean[requestedSchema.getColumns().size()]; + // For loop on each physical columns. 
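+ // For example, a requested field of type struct<a: int, b: string> contributes two leaf columns here, so missingColumns is indexed by leaf column rather than by top-level field.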
+ for (int i = 0; i < requestedSchema.getColumns().size(); ++i) { String[] colPath = requestedSchema.getPaths().get(i); if (fileSchema.containsPath(colPath)) { ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); @@ -265,6 +293,25 @@ private void initializeInternal() throws IOException, UnsupportedOperationExcept } } + private int setupColumnReader( + ColumnVector column, + VectorizedColumnReader[] columnReaders, + int readerIdx) { + if (column.isComplex()) { + column.setColumnReader(null); + for (int j = 0; j < column.getChildColumnNums(); j++) { + readerIdx = setupColumnReader(column.getChildColumn(j), columnReaders, readerIdx); + } + } else { + if (!missingColumns[readerIdx]) { + column.setColumnReader(columnReaders[readerIdx++]); + } else { + readerIdx++; + } + } + return readerIdx; + } + private void checkEndOfRowGroup() throws IOException { if (rowsReturned != totalCountLoadedSoFar) return; PageReadStore pages = reader.readNextRowGroup(); @@ -272,13 +319,25 @@ private void checkEndOfRowGroup() throws IOException { throw new IOException("expecting more rows but reached last block. Read " + rowsReturned + " out of " + totalRowCount); } + // Return physical columns stored in Parquet file. Not logical fields. + // For example, a nested StructType field in requestedSchema might have many columns. + // A column is always a primitive type column. List columns = requestedSchema.getColumns(); columnReaders = new VectorizedColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; columnReaders[i] = new VectorizedColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); } + + // Associate ColumnReaders to ColumnVectors in ColumnarBatch. + int readerIdx = 0; + int partitionIdx = sparkSchema.fields().length; + for (int i = 0; i < columnarBatch.numFields(); i++) { + if (i >= partitionIdx) break; + readerIdx = setupColumnReader(columnarBatch.column(i), columnReaders, readerIdx); + } totalCountLoadedSoFar += pages.getRowCount(); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index 25a565d32638..f7a2e28b5c63 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -63,6 +63,7 @@ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int this.maxSteps = maxSteps; numBuckets = (int) (capacity / loadFactor); batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP, capacity); + batch.initColumnVectors(); buckets = new int[numBuckets]; Arrays.fill(buckets, -1); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a7cb3b11f687..cfca27290cd7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.io.IOException; import java.math.BigDecimal; import java.math.BigInteger; @@ -27,6 +28,12 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; +import 
org.apache.spark.sql.execution.datasources.parquet.ParquetArray; +import org.apache.spark.sql.execution.datasources.parquet.ParquetField; +import org.apache.spark.sql.execution.datasources.parquet.ParquetMap; +import org.apache.spark.sql.execution.datasources.parquet.ParquetStruct; +import org.apache.spark.sql.execution.datasources.parquet.RepetitionDefinitionInfo; +import org.apache.spark.sql.execution.datasources.parquet.VectorizedColumnReader; import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; @@ -179,9 +186,7 @@ public Object[] array() { public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); } @Override - public boolean getBoolean(int ordinal) { - throw new UnsupportedOperationException(); - } + public boolean getBoolean(int ordinal) { return data.getBoolean(offset + ordinal); } @Override public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } @@ -198,9 +203,7 @@ public short getShort(int ordinal) { public long getLong(int ordinal) { return data.getLong(offset + ordinal); } @Override - public float getFloat(int ordinal) { - throw new UnsupportedOperationException(); - } + public float getFloat(int ordinal) { return data.getFloat(offset + ordinal); } @Override public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } @@ -312,6 +315,11 @@ private void throwUnsupportedException(int newCapacity, int requiredCapacity, Th */ protected abstract void reserveInternal(int capacity); + /** + * Ensures that there is enough storage to store null information. + */ + protected abstract void reserveNulls(int capacity); + /** * Returns the number of nulls in this column. */ @@ -516,7 +524,8 @@ private void throwUnsupportedException(int newCapacity, int requiredCapacity, Th public abstract double getDouble(int rowId); /** - * Puts a byte array that already exists in this column. + * Puts an array that already exists in this column. + * This method only updates array length and offset data in this column. */ public abstract void putArray(int rowId, int offset, int length); @@ -862,6 +871,28 @@ public final int appendStruct(boolean isNull) { */ public final ColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + /** + * Returns the number of childColumns. + */ + public final int getChildColumnNums() { + if (childColumns == null) { + return 0; + } else { + return childColumns.length; + } + } + + /** + * Returns whether this ColumnVector represents complex types such as Array, Map, Struct. + */ + public final boolean isComplex() { + if (type instanceof ArrayType || type instanceof StructType || type instanceof MapType) { + return true; + } else { + return false; + } + } + /** * Returns the elements appended. */ @@ -877,6 +908,16 @@ public final int appendStruct(boolean isNull) { */ public final void setIsConstant() { isConstant = true; } + /** + * Returns definition level for this column. This value is valid only if isComplex() return true. + */ + public final int getDefLevel() { return defLevel; } + + /** + * Returns repetition level for this column. This value is valid only if isComplex() return true. + */ + public final int getRepLevel() { return repLevel; } + /** * Maximum number of rows that can be stored in this column. */ @@ -910,6 +951,16 @@ public final int appendStruct(boolean isNull) { */ protected boolean isConstant; + /** + * Max definition level of this column. This value is valid only if this is a nested column. 
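+ * (In Parquet, every optional or repeated group on the path from the root adds one to this level.)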
+ */ + protected int defLevel; + + /** + * Max repetition level of this column. This value is valid only if this is a nested column. + */ + protected int repLevel; + /** * Default size of each array length value. This grows as necessary. */ @@ -947,6 +998,113 @@ public final int appendStruct(boolean isNull) { */ protected ColumnVector dictionaryIds; + /** + * Represents a field in Parquet schema used to capture schema structure and metadata + * such as repetition and definition levels. + */ + protected ParquetField parquetField; + + /** + * Associated VectorizedColumnReader which is used to load data into this ColumnVector. + * If this is a complex type such as array or struct, the VectorizedColumnReader will be + * null. + */ + protected VectorizedColumnReader columnReader; + + /** + * The parent ColumnVector of this column. If this column is not an element of nested column, + * then this is null. + */ + protected ColumnVector parentColumn; + + /** + * Sets the ParquetField for this column. + */ + public void setParquetField(ParquetField field) { this.parquetField = field; } + + /** + * Gets the repetition and definition metadata object from Parquet schema. + */ + private RepetitionDefinitionInfo getRepetitionDefinitionInfo(DataType type) { + if (this.parquetField != null) { + if (type instanceof StructType) { + ParquetStruct struct = (ParquetStruct)this.parquetField; + return struct.metadata(); + } else if (type instanceof ArrayType) { + ParquetArray array = (ParquetArray)this.parquetField; + return array.metadata(); + } else if (type instanceof MapType) { + ParquetMap map = (ParquetMap)this.parquetField; + return map.metadata(); + } + } + return null; + } + + /** + * Sets the columnReader for this column. + */ + public void setColumnReader(VectorizedColumnReader columnReader) { + this.columnReader = columnReader; + } + + /** + * Sets the parent column for this column. + */ + public void setParentColumn(ColumnVector column) { + this.parentColumn = column; + } + + /** + * Returns the parent column for this column. + */ + public ColumnVector getParentColumn() { + return this.parentColumn; + } + + /** + * The flag shows if the nearest parent column is initialized. + */ + private boolean isNearestParentArrayColumnInited = false; + + /** + * The nearest parent column which is an Array column. + */ + private ColumnVector nearestParentArrayColumn; + + /** + * Returns the nearest parent column which is an Array column. + */ + public ColumnVector getNearestParentArrayColumn() { + if (!isNearestParentArrayColumnInited) { + nearestParentArrayColumn = this.parentColumn; + while (nearestParentArrayColumn != null && !nearestParentArrayColumn.isArray()) { + nearestParentArrayColumn = nearestParentArrayColumn.parentColumn; + } + isNearestParentArrayColumnInited = true; + } + return nearestParentArrayColumn; + } + + /** + * Returns if this ColumnVector has initialized VectorizedColumnReader. + */ + public boolean hasColumnReader() { + return this.columnReader != null; + } + + /** + * Reads `total` values from associated columnReader into this column. + */ + public void readBatch(int total) throws IOException { + if (this.columnReader != null) { + this.columnReader.readBatch(total, this); + } else { + throw new RuntimeException("The reader of this ColumnVector is not initialized yet. " + + "Failed to call readBatch()."); + } + } + /** * Update the dictionary. 
*/ @@ -970,6 +1128,7 @@ public ColumnVector reserveDictionaryIds(int capacity) { dictionaryIds.reset(); dictionaryIds.reserve(capacity); } + reserveNulls(capacity); return dictionaryIds; } @@ -980,6 +1139,33 @@ public ColumnVector getDictionaryIds() { return dictionaryIds; } + public void initRepetitionAndDefinitionLevels() { + if (type instanceof ArrayType) { + DataType childType; + ArrayType arrayType = (ArrayType)type; + RepetitionDefinitionInfo metadata = getRepetitionDefinitionInfo(type); + if (metadata != null) { + this.defLevel = metadata.definition(); + this.repLevel = metadata.repetition(); + } + ParquetArray parquetArray = (ParquetArray)this.parquetField; + this.childColumns[0].setParquetField(parquetArray.element()); + this.childColumns[0].initRepetitionAndDefinitionLevels(); + } else if (type instanceof StructType) { + RepetitionDefinitionInfo metadata = getRepetitionDefinitionInfo(type); + if (metadata != null) { + this.defLevel = metadata.definition(); + this.repLevel = metadata.repetition(); + } + StructType st = (StructType)type; + for (int i = 0; i < childColumns.length; ++i) { + ParquetStruct parquetStruct = (ParquetStruct)this.parquetField; + this.childColumns[i].setParquetField(parquetStruct.fields()[i]); + this.childColumns[i].initRepetitionAndDefinitionLevels(); + } + } + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. @@ -993,6 +1179,7 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { DataType childType; int childCapacity = capacity; if (type instanceof ArrayType) { + ArrayType arrayType = (ArrayType)type; childType = ((ArrayType)type).elementType(); } else { childType = DataTypes.ByteType; @@ -1000,20 +1187,24 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { } this.childColumns = new ColumnVector[1]; this.childColumns[0] = ColumnVector.allocate(childCapacity, childType, memMode); + this.childColumns[0].setParentColumn(this); this.resultArray = new Array(this.childColumns[0]); this.resultStruct = null; } else if (type instanceof StructType) { StructType st = (StructType)type; this.childColumns = new ColumnVector[st.fields().length]; for (int i = 0; i < childColumns.length; ++i) { - this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); + this.childColumns[i] = + ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); + this.childColumns[i].setParentColumn(this); } this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. 
this.childColumns = new ColumnVector[2]; - this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); + this.childColumns[0] = + ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode); this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 2fa476b9cfb7..b62bf2db6e5a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -198,6 +198,7 @@ private static void appendValue(ColumnVector dst, DataType t, Row src, int field public static ColumnarBatch toBatch( StructType schema, MemoryMode memMode, Iterator row) { ColumnarBatch batch = ColumnarBatch.allocate(schema, memMode); + batch.initColumnVectors(); int n = 0; while (row.hasNext()) { Row r = row.next(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index f3afa8f938f8..850092a10cae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.execution.datasources.parquet.ParquetField; +import org.apache.spark.sql.execution.datasources.parquet.ParquetStruct; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -52,6 +54,10 @@ public final class ColumnarBatch { private final int capacity; private int numRows; private final ColumnVector[] columns; + private MemoryMode memMode; + + // This captures Parquet schema structure and metadata such as definition and repetition levels. + private ParquetStruct parquetStruct; // True if the row is filtered. private final boolean[] filteredRows; @@ -77,6 +83,10 @@ public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int return new ColumnarBatch(schema, maxRows, memMode); } + public int numFields() { return columns.length; } + + public void setParquetSchema(ParquetStruct struct) { this.parquetStruct = struct; } + /** * Called to close all the columns in this batch. It is not valid to access the data after * calling this. This must be called at the end to clean up memory allocations. 
@@ -461,18 +471,27 @@ public void filterNullsInColumn(int ordinal) { nullFilteredColumns.add(ordinal); } + public void initColumnVectors() { + for (int i = 0; i < this.schema.fields().length; ++i) { + StructField field = this.schema.fields()[i]; + this.columns[i] = ColumnVector.allocate(this.capacity, field.dataType(), this.memMode); + if (this.parquetStruct != null) { + if (i < this.parquetStruct.fields().length) { + ParquetField parquetField = this.parquetStruct.fields()[i]; + this.columns[i].setParquetField(parquetField); + this.columns[i].initRepetitionAndDefinitionLevels(); + } + } + } + } + private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) { this.schema = schema; this.capacity = maxRows; + this.memMode = memMode; this.columns = new ColumnVector[schema.size()]; this.nullFilteredColumns = new HashSet<>(); this.filteredRows = new boolean[maxRows]; - - for (int i = 0; i < schema.fields().length; ++i) { - StructField field = schema.fields()[i]; - columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode); - } - this.row = new Row(this); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 12fa109cec82..b8948cae1ae8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -456,8 +456,13 @@ protected void reserveInternal(int newCapacity) { } else { throw new RuntimeException("Unhandled " + type); } - this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); - Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + reserveNulls(newCapacity); capacity = newCapacity; } + + @Override + protected void reserveNulls(int capacity) { + this.nulls = Platform.reallocateMemory(nulls, elementsAppended, capacity); + Platform.setMemory(nulls + elementsAppended, (byte)0, capacity - elementsAppended); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 9b410bacff5d..5ee9389aeb2f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -410,53 +410,53 @@ protected void reserveInternal(int newCapacity) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { - System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); - System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + System.arraycopy(this.arrayLengths, 0, newLengths, 0, this.arrayLengths.length); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, this.arrayOffsets.length); } arrayLengths = newLengths; arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, byteData.length); byteData = newData; } } else if (type instanceof ByteType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - 
if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, byteData.length); byteData = newData; } } else if (type instanceof ShortType) { if (shortData == null || shortData.length < newCapacity) { short[] newData = new short[newCapacity]; - if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, shortData.length); shortData = newData; } } else if (type instanceof IntegerType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { if (intData == null || intData.length < newCapacity) { int[] newData = new int[newCapacity]; - if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + if (intData != null) System.arraycopy(intData, 0, newData, 0, intData.length); intData = newData; } } else if (type instanceof LongType || type instanceof TimestampType || DecimalType.is64BitDecimalType(type)) { if (longData == null || longData.length < newCapacity) { long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + if (longData != null) System.arraycopy(longData, 0, newData, 0, longData.length); longData = newData; } } else if (type instanceof FloatType) { if (floatData == null || floatData.length < newCapacity) { float[] newData = new float[newCapacity]; - if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, floatData.length); floatData = newData; } } else if (type instanceof DoubleType) { if (doubleData == null || doubleData.length < newCapacity) { double[] newData = new double[newCapacity]; - if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, doubleData.length); doubleData = newData; } } else if (resultStruct != null) { @@ -465,10 +465,17 @@ protected void reserveInternal(int newCapacity) { throw new RuntimeException("Unhandled " + type); } - byte[] newNulls = new byte[newCapacity]; - if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); - nulls = newNulls; + reserveNulls(newCapacity); capacity = newCapacity; } + + @Override + protected void reserveNulls(int capacity) { + if (nulls == null || nulls.length < capacity) { + byte[] newNulls = new byte[capacity]; + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, nulls.length); + nulls = newNulls; + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 7418df90b824..163074bcee47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -89,9 +89,11 @@ class VectorizedHashMapGenerator( | public $generatedClassName() { | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); + | batch.initColumnVectors(); | // TODO: Possibly generate this projection in HashAggregate directly | aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate( | aggregateBufferSchema, 
org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); + | aggregateBufferBatch.initColumnVectors(); | for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) { | aggregateBufferBatch.setColumn(i, batch.column(i+${groupingKeys.length})); | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 7794f31331a8..02d9c6d076d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -259,7 +259,7 @@ class ParquetFileFormat val conf = sparkSession.sessionState.conf conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && - schema.forall(_.dataType.isInstanceOf[AtomicType]) + !schema.existsRecursively(f => f.isInstanceOf[MapType] || f.isInstanceOf[UserDefinedType[_]]) } override def isSplitable( @@ -335,7 +335,9 @@ class ParquetFileFormat val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) val enableVectorizedReader: Boolean = sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && - resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + !resultSchema.existsRecursively(f => f.isInstanceOf[MapType] || + f.isInstanceOf[UserDefinedType[_]]) + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index c81a65f4973e..12a3682d8529 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -94,6 +94,86 @@ private[parquet] class ParquetSchemaConverter( StructType(fields) } + /** + * Obtains schema structure and metadata (repetition and definition information) for + * Parquet [[MessageType]] `parquetSchema`. 
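+ * The returned [[ParquetStruct]] mirrors the group structure of the message; every nested group carries the maximum repetition and definition levels computed from its column path.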
+ */ + def getParquetStruct(parquetSchema: MessageType): ParquetStruct = { + getParquetStruct(parquetSchema.asGroupType(), parquetSchema, Seq.empty[String]) + } + + private def getParquetStruct( + parquetSchema: GroupType, + messageType: MessageType, + path: Seq[String]): ParquetStruct = { + val fields = parquetSchema.getFields.asScala.map { field => + field.getRepetition match { + case OPTIONAL => + getParquetField(field, messageType, path) + + case REQUIRED => + getParquetField(field, messageType, path) + + case REPEATED => + val curPath = path ++ Seq(field.getName) + val defLevel = messageType.getMaxDefinitionLevel(curPath: _*) + val repLevel = messageType.getMaxRepetitionLevel(curPath: _*) + + val inner = getParquetField(field, messageType, path) + ParquetArray(inner, RepetitionDefinitionInfo(repLevel, defLevel)) + } + } + + if (path.isEmpty) { + ParquetStruct(fields.toArray, RepetitionDefinitionInfo(0, 0)) + } else { + val defLevel = messageType.getMaxDefinitionLevel(path: _*) + val repLevel = messageType.getMaxRepetitionLevel(path: _*) + ParquetStruct(fields.toArray, RepetitionDefinitionInfo(repLevel, defLevel)) + } + } + + private def getParquetField( + parquetType: Type, + messageType: MessageType, + path: Seq[String]): ParquetField = parquetType match { + case t: PrimitiveType => new ParquetField() + case t: GroupType => + val curPath = path ++ Seq(t.getName()) + getParquetGroupField(t.asGroupType(), messageType, curPath) + } + + private def getParquetGroupField( + field: GroupType, + messageType: MessageType, + path: Seq[String]): ParquetField = { + Option(field.getOriginalType).fold(getParquetStruct(field, messageType, path): ParquetField) { + case LIST => + ParquetSchemaConverter.checkConversionRequirement( + field.getFieldCount == 1, s"Invalid list type $field") + + val repeatedType = field.getType(0) + ParquetSchemaConverter.checkConversionRequirement( + repeatedType.isRepetition(REPEATED), s"Invalid list type $field") + if (isElementType(repeatedType, field.getName)) { + val defLevel = messageType.getMaxDefinitionLevel(path: _*) + val repLevel = messageType.getMaxRepetitionLevel(path: _*) + val inner = getParquetField(repeatedType, messageType, path) + ParquetArray(inner, RepetitionDefinitionInfo(repLevel, defLevel)) + } else { + val elementType = repeatedType.asGroupType().getType(0) + val curPath = path ++ Seq(repeatedType.getName) + val defLevel = messageType.getMaxDefinitionLevel(curPath: _*) + val repLevel = messageType.getMaxRepetitionLevel(path: _*) + val inner = getParquetField(elementType, messageType, curPath) + ParquetArray(inner, RepetitionDefinitionInfo(repLevel, defLevel)) + } + + case _ => + throw new AnalysisException(s"Unrecognized Parquet type: $field") + } + } + /** * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/RepetitionDefinitionInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/RepetitionDefinitionInfo.scala new file mode 100644 index 000000000000..43888113dff7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/RepetitionDefinitionInfo.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +/** + * This is a wrapper class for the repetition and definition level information used in Parquet + * encoding of complex types. + */ +private[sql] case class RepetitionDefinitionInfo(repetition: Int, definition: Int) + +/** + * The following classes are defined to capture the structure of a Parquet schema. + * We don't care about the actual types but only use these to capture the structure and metadata + * such as repetition and definition levels. + */ +private[sql] class ParquetField + +private[sql] case class ParquetStruct( + fields: Array[ParquetField], + metadata: RepetitionDefinitionInfo) extends ParquetField + +private[sql] case class ParquetArray( + element: ParquetField, + metadata: RepetitionDefinitionInfo) extends ParquetField + +private[sql] case class ParquetMap( + keyElement: ParquetField, + valueElement: ParquetField, + metadata: RepetitionDefinitionInfo) extends ParquetField diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 100cc4daca87..26380f19835f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -477,6 +477,7 @@ class ColumnarBatchSuite extends SparkFunSuite { .add("string", BinaryType) val batch = ColumnarBatch.allocate(schema, memMode) + batch.initColumnVectors(); assert(batch.numCols() == 4) assert(batch.numRows() == 0) assert(batch.numValidRows() == 0)
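With this change, ColumnarBatch.allocate no longer creates the per-field ColumnVectors; callers have to invoke initColumnVectors() (after setParquetSchema() when reading Parquet) before using the batch, as the call sites touched above show. A minimal sketch of the new call sequence, assuming only the APIs introduced in this patch; the schema and memory mode below are arbitrary examples:

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

class ColumnarBatchInitSketch {
  static ColumnarBatch newBatch() {
    StructType schema = new StructType()
        .add("id", DataTypes.IntegerType)
        .add("tags", DataTypes.createArrayType(DataTypes.StringType));
    // Step 1: allocate the batch; with this patch the child ColumnVectors are not created yet.
    ColumnarBatch batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP);
    // Step 2 (Parquet reads only): batch.setParquetSchema(parquetStruct) would attach the
    // ParquetStruct so that nested columns learn their repetition/definition levels.
    // Step 3: create the ColumnVectors and wire up parent/child links for nested types.
    batch.initColumnVectors();
    return batch;
  }
}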