From 3dcbbb807673ad4946cde393cfcd46d45233d031 Mon Sep 17 00:00:00 2001 From: Venki Korukanti Date: Wed, 10 Apr 2024 14:53:56 -0700 Subject: [PATCH] [Kernel] Push predicate on partition values to checkpoint reader in state reconstruction (#2872) ## Description Converts the partition predicate into a filter on `add.partitionValues_parsed.<partitionPhysicalColName>`. This predicate is pushed to the Parquet reader when reading the checkpoint files during state reconstruction. This helps skip reading checkpoint files that can't possibly contain any scan files satisfying the given partition predicate. This can be extended in the future to also support pushdown of predicates on data columns. ## How was this patch tested? Unit tests --- .../internal/InternalScanFileUtils.java | 10 +++ .../io/delta/kernel/internal/ScanImpl.java | 36 +++++++---- .../internal/replay/ActionsIterator.java | 9 ++- .../kernel/internal/replay/LogReplay.java | 31 ++++++---- .../kernel/internal/util/PartitionUtils.java | 46 ++++++++++++++ .../internal/util/PartitionUtilsSuite.scala | 62 +++++++++++++------ 6 files changed, 148 insertions(+), 46 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java index b66348374f1..d6b9aaa55fd 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java @@ -176,4 +176,14 @@ public static DeletionVectorDescriptor getDeletionVectorDescriptorFromRow(Row sc Row addFile = getAddFileEntry(scanFile); return DeletionVectorDescriptor.fromRow(addFile.getStruct(ADD_FILE_DV_ORDINAL)); } + + /** + * Get a references column for given partition column name in partitionValues_parsed column in + * scan file row. 
+ * @param partitionColName Partition column name + * @return {@link Column} reference + */ + public static Column getPartitionValuesParsedRefInAddFile(String partitionColName) { + return new Column(new String[]{"add", "partitionValues_parsed", partitionColName}); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java index f07fa90f1eb..229b7cb8aeb 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.util.*; +import java.util.function.Supplier; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; @@ -37,6 +38,7 @@ import io.delta.kernel.internal.skipping.DataSkippingUtils; import io.delta.kernel.internal.util.*; import static io.delta.kernel.internal.skipping.StatsSchemaHelper.getStatsSchema; +import static io.delta.kernel.internal.util.PartitionUtils.rewritePartitionPredicateOnCheckpointFileSchema; import static io.delta.kernel.internal.util.PartitionUtils.rewritePartitionPredicateOnScanFileSchema; /** @@ -56,6 +58,7 @@ public class ScanImpl implements Scan { private final LogReplay logReplay; private final Path dataPath; private final Optional> partitionAndDataFilters; + private final Supplier> partitionColToStructFieldMap; private boolean accessedScanFiles; public ScanImpl( @@ -73,6 +76,15 @@ public ScanImpl( this.logReplay = logReplay; this.partitionAndDataFilters = splitFilters(filter); this.dataPath = dataPath; + this.partitionColToStructFieldMap = () -> { + Set partitionColNames = metadata.getPartitionColNames(); + return metadata.getSchema().fields().stream() + .filter(field -> partitionColNames.contains( + field.getName().toLowerCase(Locale.ROOT))) + .collect(toMap( + field -> field.getName().toLowerCase(Locale.ROOT), + identity())); + }; } 
/** @@ -92,8 +104,15 @@ public CloseableIterator getScanFiles(TableClient tableCl boolean shouldReadStats = dataSkippingFilter.isPresent(); // Get active AddFiles via log replay - CloseableIterator scanFileIter = logReplay - .getAddFilesAsColumnarBatches(shouldReadStats); + // If there is a partition predicate, construct a predicate to prune checkpoint files + // while constructing the table state. + CloseableIterator scanFileIter = + logReplay.getAddFilesAsColumnarBatches( + shouldReadStats, + getPartitionsFilters().map(predicate -> + rewritePartitionPredicateOnCheckpointFileSchema( + predicate, + partitionColToStructFieldMap.get()))); // Apply partition pruning scanFileIter = applyPartitionPruning(tableClient, scanFileIter); @@ -177,18 +196,9 @@ private CloseableIterator applyPartitionPruning( return scanFileIter; } - Set partitionColNames = metadata.getPartitionColNames(); - Map partitionColNameToStructFieldMap = - metadata.getSchema().fields().stream() - .filter(field -> - partitionColNames.contains(field.getName().toLowerCase(Locale.ROOT))) - .collect(toMap( - field -> field.getName().toLowerCase(Locale.ROOT), - identity())); - Predicate predicateOnScanFileBatch = rewritePartitionPredicateOnScanFileSchema( - partitionPredicate.get(), - partitionColNameToStructFieldMap); + partitionPredicate.get(), + partitionColToStructFieldMap.get()); return new CloseableIterator() { PredicateEvaluator predicateEvaluator = null; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/ActionsIterator.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/ActionsIterator.java index 9340cbc3a2a..a4bba9aa86e 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/ActionsIterator.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/ActionsIterator.java @@ -22,6 +22,7 @@ import io.delta.kernel.client.TableClient; import io.delta.kernel.data.ColumnarBatch; +import io.delta.kernel.expressions.Predicate; 
import io.delta.kernel.types.StructType; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.FileStatus; @@ -45,6 +46,8 @@ class ActionsIterator implements CloseableIterator { private final TableClient tableClient; + private final Optional checkpointPredicate; + /** * Linked list of iterator files (commit files and/or checkpoint files) * {@link LinkedList} to allow removing the head of the list and also to peek at the head @@ -69,8 +72,10 @@ class ActionsIterator implements CloseableIterator { ActionsIterator( TableClient tableClient, List files, - StructType readSchema) { + StructType readSchema, + Optional checkpointPredicate) { this.tableClient = tableClient; + this.checkpointPredicate = checkpointPredicate; this.filesList = new LinkedList<>(); this.filesList.addAll(files); this.readSchema = readSchema; @@ -191,7 +196,7 @@ private CloseableIterator getNextActionsIter() { tableClient.getParquetHandler().readParquetFiles( checkpointFilesIter, readSchema, - Optional.empty()); + checkpointPredicate); return combine(dataIter, true /* isFromCheckpoint */, fileVersion); } else { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/LogReplay.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/LogReplay.java index b49396a43dd..21ae51e0f52 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/LogReplay.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/LogReplay.java @@ -24,6 +24,7 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.data.FilteredColumnarBatch; +import io.delta.kernel.expressions.Predicate; import io.delta.kernel.types.StringType; import io.delta.kernel.types.StructType; import io.delta.kernel.utils.CloseableIterator; @@ -151,12 +152,14 @@ public Optional getLatestTransactionIdentifier(String applicationId) { * */ public CloseableIterator getAddFilesAsColumnarBatches( - boolean 
shouldReadStats) { + boolean shouldReadStats, + Optional checkpointPredicate) { final CloseableIterator addRemoveIter = - new ActionsIterator( - tableClient, - logSegment.allLogFilesReversed(), - getAddRemoveReadSchema(shouldReadStats)); + new ActionsIterator( + tableClient, + logSegment.allLogFilesReversed(), + getAddRemoveReadSchema(shouldReadStats), + checkpointPredicate); return new ActiveAddFilesIterator(tableClient, addRemoveIter, dataPath); } @@ -188,10 +191,11 @@ private Tuple2 loadTableProtocolAndMetadata( Metadata metadata = null; try (CloseableIterator reverseIter = - new ActionsIterator( - tableClient, - logSegment.allLogFilesReversed(), - PROTOCOL_METADATA_READ_SCHEMA)) { + new ActionsIterator( + tableClient, + logSegment.allLogFilesReversed(), + PROTOCOL_METADATA_READ_SCHEMA, + Optional.empty())) { while (reverseIter.hasNext()) { final ActionWrapper nextElem = reverseIter.next(); final long version = nextElem.getVersion(); @@ -271,10 +275,11 @@ private Tuple2 loadTableProtocolAndMetadata( private Optional loadLatestTransactionVersion(String applicationId) { try (CloseableIterator reverseIter = - new ActionsIterator( - tableClient, - logSegment.allLogFilesReversed(), - SET_TRANSACTION_READ_SCHEMA)) { + new ActionsIterator( + tableClient, + logSegment.allLogFilesReversed(), + SET_TRANSACTION_READ_SCHEMA, + Optional.empty())) { while (reverseIter.hasNext()) { final ColumnarBatch columnarBatch = reverseIter.next().getColumnarBatch(); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java index ad4b78fc0b1..cff57e734f5 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java @@ -133,6 +133,52 @@ public static Tuple2 splitMetadataAndDataPredicates( } } + /** + * Rewrite the given predicate on partition columns 
on `partitionValues_parsed` in checkpoint + * schema. The rewritten predicate can be pushed to the Parquet reader when reading the + * checkpoint files. + * + * @param predicate Predicate on partition columns. + * @param partitionColNameToField Map of partition column name (in lower case) to its + * {@link StructField}. + * @return Rewritten {@link Predicate} on `partitionValues_parsed` in `add`. + */ + public static Predicate rewritePartitionPredicateOnCheckpointFileSchema( + Predicate predicate, + Map partitionColNameToField) { + return new Predicate( + predicate.getName(), + predicate.getChildren().stream() + .map(child -> + rewriteColRefOnPartitionValuesParsed( + child, partitionColNameToField)) + .collect(Collectors.toList())); + } + + private static Expression rewriteColRefOnPartitionValuesParsed( + Expression expression, + Map partitionColMetadata) { + if (expression instanceof Column) { + Column column = (Column) expression; + String partColName = column.getNames()[0]; + StructField partColField = + partitionColMetadata.get(partColName.toLowerCase(Locale.ROOT)); + if (partColField == null) { + throw new IllegalArgumentException(partColName + " is not present in metadata"); + } + + String partColPhysicalName = ColumnMapping.getPhysicalName(partColField); + + return InternalScanFileUtils.getPartitionValuesParsedRefInAddFile(partColPhysicalName); + } else if (expression instanceof Predicate) { + return rewritePartitionPredicateOnCheckpointFileSchema( + (Predicate) expression, + partitionColMetadata); + } + + return expression; + } + /** * Utility method to rewrite the partition predicate referring to the table schema as predicate * referring to the {@code partitionValues} in scan files read from Delta log. 
The scan file diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala index ceb91ddc45b..1811c52ebc0 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala @@ -15,16 +15,15 @@ */ package io.delta.kernel.internal.util -import java.util - -import scala.collection.JavaConverters._ - -import io.delta.kernel.expressions._ import io.delta.kernel.expressions.Literal._ -import io.delta.kernel.internal.util.PartitionUtils.{rewritePartitionPredicateOnScanFileSchema, splitMetadataAndDataPredicates} +import io.delta.kernel.expressions._ +import io.delta.kernel.internal.util.PartitionUtils.{rewritePartitionPredicateOnCheckpointFileSchema, rewritePartitionPredicateOnScanFileSchema, splitMetadataAndDataPredicates} import io.delta.kernel.types._ import org.scalatest.funsuite.AnyFunSuite +import java.util +import scala.collection.JavaConverters._ + class PartitionUtilsSuite extends AnyFunSuite { // Table schema // Data columns: data1: int, data2: string, date3: struct(data31: boolean, data32: long) @@ -154,33 +153,60 @@ class PartitionUtilsSuite extends AnyFunSuite { } } - // Map entry format: (given predicate -> expected rewritten predicate) + // Map entry format: (given predicate -> \ + // (exp predicate for partition pruning, exp predicate for checkpoint reader pushdown)) val rewriteTestCases = Map( // single predicate on a partition column predicate("=", col("part2"), ofTimestamp(12)) -> - "(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part2), date) = 12)", + ( + // exp predicate for partition pruning + "(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part2), date) = 12)", + + // exp predicate for checkpoint reader pushdown + "(column(`add`.`partitionValues_parsed`.`part2`) = 
12)" + ), // multiple predicates on partition columns joined with AND predicate("AND", predicate("=", col("part1"), ofInt(12)), predicate(">=", col("part3"), ofString("sss"))) -> - """((partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 12) AND - |(ELEMENT_AT(column(`add`.`partitionValues`), part3) >= sss))""" - .stripMargin.replaceAll("\n", " "), + ( + // exp predicate for partition pruning + """((partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 12) AND + |(ELEMENT_AT(column(`add`.`partitionValues`), part3) >= sss))""" + .stripMargin.replaceAll("\n", " "), + + // exp predicate for checkpoint reader pushdown + """((column(`add`.`partitionValues_parsed`.`part1`) = 12) AND + |(column(`add`.`partitionValues_parsed`.`part3`) >= sss))""" + .stripMargin.replaceAll("\n", " ") + ), // multiple predicates on partition columns joined with OR predicate("OR", predicate("<=", col("part3"), ofString("sss")), predicate("=", col("part1"), ofInt(2781))) -> - """((ELEMENT_AT(column(`add`.`partitionValues`), part3) <= sss) OR - |(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 2781))""" - .stripMargin.replaceAll("\n", " ") + ( + // exp predicate for partition pruning + """((ELEMENT_AT(column(`add`.`partitionValues`), part3) <= sss) OR + |(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 2781))""" + .stripMargin.replaceAll("\n", " "), + + // exp predicate for checkpoint reader pushdown + """((column(`add`.`partitionValues_parsed`.`part3`) <= sss) OR + |(column(`add`.`partitionValues_parsed`.`part1`) = 2781))""" + .stripMargin.replaceAll("\n", " ") + ) ) - rewriteTestCases.foreach { - case (predicate, expRewrittenPredicate) => + case (predicate, (expPartitionPruningPredicate, expCheckpointReaderPushdownPredicate)) => test(s"rewrite partition predicate on scan file schema: $predicate") { - val actRewrittenPredicate = + val actPartitionPruningPredicate = 
rewritePartitionPredicateOnScanFileSchema(predicate, partitionColsMetadata) - assert(actRewrittenPredicate.toString === expRewrittenPredicate) + assert(actPartitionPruningPredicate.toString === expPartitionPruningPredicate) + + val actCheckpointReaderPushdownPredicate = + rewritePartitionPredicateOnCheckpointFileSchema(predicate, partitionColsMetadata) + assert(actCheckpointReaderPushdownPredicate.toString === + expCheckpointReaderPushdownPredicate) } }