[Kernel] Push predicate on partition values to checkpoint reader in state reconstruction (#2872)

## Description

Converts the partition predicate into a filter on
`add.partitionValues_parsed.<partitionPhysicalColName>`. This predicate
is pushed down to the Parquet reader when reading the checkpoint files during
state reconstruction, allowing the reader to skip checkpoint data that
cannot possibly contain any scan files satisfying the given partition
predicate.
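For illustration, a minimal end-to-end sketch of the path this change optimizes, assuming the default table client and a table partitioned by an `INTEGER` column `part1` (the table path and column name are hypothetical):

```java
import org.apache.hadoop.conf.Configuration;

import io.delta.kernel.Scan;
import io.delta.kernel.Snapshot;
import io.delta.kernel.Table;
import io.delta.kernel.client.TableClient;
import io.delta.kernel.data.FilteredColumnarBatch;
import io.delta.kernel.defaults.client.DefaultTableClient;
import io.delta.kernel.expressions.Column;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.expressions.Predicate;
import io.delta.kernel.utils.CloseableIterator;

public class PartitionFilteredScan {
    public static void main(String[] args) throws Exception {
        TableClient tableClient = DefaultTableClient.create(new Configuration());
        Snapshot snapshot = Table.forPath(tableClient, "/tmp/my-partitioned-table")
            .getLatestSnapshot(tableClient);

        // Filter on partition column part1. The partition part of this filter is now
        // additionally rewritten onto add.partitionValues_parsed.* and pushed into the
        // Parquet reads of checkpoint files during state reconstruction.
        Predicate filter = new Predicate("=", new Column("part1"), Literal.ofInt(12));

        Scan scan = snapshot.getScanBuilder(tableClient)
            .withFilter(tableClient, filter)
            .build();

        try (CloseableIterator<FilteredColumnarBatch> scanFiles =
                 scan.getScanFiles(tableClient)) {
            while (scanFiles.hasNext()) {
                FilteredColumnarBatch batch = scanFiles.next(); // surviving scan files
            }
        }
    }
}
```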

This can be extended in the future to also support pushdown of predicates
on data columns.

## How was this patch tested?
Unit tests.
vkorukanti authored Apr 10, 2024
1 parent 0fe578b commit 3dcbbb8
Showing 6 changed files with 148 additions and 46 deletions.
InternalScanFileUtils.java
@@ -176,4 +176,14 @@ public static DeletionVectorDescriptor getDeletionVectorDescriptorFromRow(Row scanFile) {
         Row addFile = getAddFileEntry(scanFile);
         return DeletionVectorDescriptor.fromRow(addFile.getStruct(ADD_FILE_DV_ORDINAL));
     }
+
+    /**
+     * Get a reference {@link Column} for the given partition column name within the
+     * `partitionValues_parsed` column in the scan file row.
+     * @param partitionColName Partition column name
+     * @return {@link Column} reference
+     */
+    public static Column getPartitionValuesParsedRefInAddFile(String partitionColName) {
+        return new Column(new String[]{"add", "partitionValues_parsed", partitionColName});
+    }
 }
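For reference, the new helper builds a nested column reference whose rendered form matches the expectations in the updated `PartitionUtilsSuite`. A minimal sketch, assuming the class's usual `io.delta.kernel.internal` package:

```java
import io.delta.kernel.expressions.Column;
import io.delta.kernel.internal.InternalScanFileUtils;

public class PartitionValuesParsedRef {
    public static void main(String[] args) {
        // Builds a reference to add.partitionValues_parsed.<physical column name>
        Column ref = InternalScanFileUtils.getPartitionValuesParsedRefInAddFile("part1");
        // Prints: column(`add`.`partitionValues_parsed`.`part1`)
        System.out.println(ref);
    }
}
```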
ScanImpl.java
@@ -17,6 +17,7 @@
 
 import java.io.IOException;
 import java.util.*;
+import java.util.function.Supplier;
 import static java.util.function.Function.identity;
 import static java.util.stream.Collectors.toMap;

@@ -37,6 +38,7 @@
 import io.delta.kernel.internal.skipping.DataSkippingUtils;
 import io.delta.kernel.internal.util.*;
 import static io.delta.kernel.internal.skipping.StatsSchemaHelper.getStatsSchema;
+import static io.delta.kernel.internal.util.PartitionUtils.rewritePartitionPredicateOnCheckpointFileSchema;
 import static io.delta.kernel.internal.util.PartitionUtils.rewritePartitionPredicateOnScanFileSchema;

/**
@@ -56,6 +58,7 @@ public class ScanImpl implements Scan {
     private final LogReplay logReplay;
     private final Path dataPath;
     private final Optional<Tuple2<Predicate, Predicate>> partitionAndDataFilters;
+    private final Supplier<Map<String, StructField>> partitionColToStructFieldMap;
     private boolean accessedScanFiles;

public ScanImpl(
@@ -73,6 +76,15 @@ public ScanImpl(
         this.logReplay = logReplay;
         this.partitionAndDataFilters = splitFilters(filter);
         this.dataPath = dataPath;
+        this.partitionColToStructFieldMap = () -> {
+            Set<String> partitionColNames = metadata.getPartitionColNames();
+            return metadata.getSchema().fields().stream()
+                .filter(field -> partitionColNames.contains(
+                    field.getName().toLowerCase(Locale.ROOT)))
+                .collect(toMap(
+                    field -> field.getName().toLowerCase(Locale.ROOT),
+                    identity()));
+        };
     }

/**
@@ -92,8 +104,15 @@ public CloseableIterator<FilteredColumnarBatch> getScanFiles(TableClient tableClient) {
         boolean shouldReadStats = dataSkippingFilter.isPresent();
 
         // Get active AddFiles via log replay
-        CloseableIterator<FilteredColumnarBatch> scanFileIter = logReplay
-            .getAddFilesAsColumnarBatches(shouldReadStats);
+        // If there is a partition predicate, construct a predicate to prune checkpoint files
+        // while constructing the table state.
+        CloseableIterator<FilteredColumnarBatch> scanFileIter =
+            logReplay.getAddFilesAsColumnarBatches(
+                shouldReadStats,
+                getPartitionsFilters().map(predicate ->
+                    rewritePartitionPredicateOnCheckpointFileSchema(
+                        predicate,
+                        partitionColToStructFieldMap.get())));
 
         // Apply partition pruning
         scanFileIter = applyPartitionPruning(tableClient, scanFileIter);
@@ -177,18 +196,9 @@ private CloseableIterator<FilteredColumnarBatch> applyPartitionPruning(
             return scanFileIter;
         }
 
-        Set<String> partitionColNames = metadata.getPartitionColNames();
-        Map<String, StructField> partitionColNameToStructFieldMap =
-            metadata.getSchema().fields().stream()
-                .filter(field ->
-                    partitionColNames.contains(field.getName().toLowerCase(Locale.ROOT)))
-                .collect(toMap(
-                    field -> field.getName().toLowerCase(Locale.ROOT),
-                    identity()));
-
         Predicate predicateOnScanFileBatch = rewritePartitionPredicateOnScanFileSchema(
-            partitionPredicate.get(),
-            partitionColNameToStructFieldMap);
+            partitionPredicate.get(),
+            partitionColToStructFieldMap.get());
 
         return new CloseableIterator<FilteredColumnarBatch>() {
             PredicateEvaluator predicateEvaluator = null;
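Concretely, the predicate that `getScanFiles` now hands to `LogReplay` is the checkpoint-schema rewrite of the partition filter, built over the same lazily computed column map that partition pruning uses. A minimal sketch of that rewrite in isolation (the column name and type are hypothetical):

```java
import java.util.Collections;
import java.util.Map;

import io.delta.kernel.expressions.Column;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.expressions.Predicate;
import io.delta.kernel.internal.util.PartitionUtils;
import io.delta.kernel.types.IntegerType;
import io.delta.kernel.types.StructField;

public class CheckpointPredicateRewrite {
    public static void main(String[] args) {
        // part1 = 12, expressed against the table schema
        Predicate onTableSchema =
            new Predicate("=", new Column("part1"), Literal.ofInt(12));

        // lower-cased partition column name -> its StructField; this is what the
        // Supplier in ScanImpl lazily computes from the metadata
        Map<String, StructField> partitionCols = Collections.singletonMap(
            "part1", new StructField("part1", IntegerType.INTEGER, true));

        // Prints: (column(`add`.`partitionValues_parsed`.`part1`) = 12)
        System.out.println(
            PartitionUtils.rewritePartitionPredicateOnCheckpointFileSchema(
                onTableSchema, partitionCols));
    }
}
```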
ActionsIterator.java
@@ -22,6 +22,7 @@
 
 import io.delta.kernel.client.TableClient;
 import io.delta.kernel.data.ColumnarBatch;
+import io.delta.kernel.expressions.Predicate;
 import io.delta.kernel.types.StructType;
 import io.delta.kernel.utils.CloseableIterator;
 import io.delta.kernel.utils.FileStatus;
@@ -45,6 +46,8 @@
 class ActionsIterator implements CloseableIterator<ActionWrapper> {
     private final TableClient tableClient;
 
+    private final Optional<Predicate> checkpointPredicate;
+
     /**
      * Linked list of iterator files (commit files and/or checkpoint files)
      * {@link LinkedList} to allow removing the head of the list and also to peek at the head
@@ -69,8 +72,10 @@ class ActionsIterator implements CloseableIterator<ActionWrapper> {
     ActionsIterator(
             TableClient tableClient,
             List<FileStatus> files,
-            StructType readSchema) {
+            StructType readSchema,
+            Optional<Predicate> checkpointPredicate) {
         this.tableClient = tableClient;
+        this.checkpointPredicate = checkpointPredicate;
         this.filesList = new LinkedList<>();
         this.filesList.addAll(files);
         this.readSchema = readSchema;
@@ -191,7 +196,7 @@ private CloseableIterator<ActionWrapper> getNextActionsIter() {
                 tableClient.getParquetHandler().readParquetFiles(
                     checkpointFilesIter,
                     readSchema,
-                    Optional.empty());
+                    checkpointPredicate);
 
             return combine(dataIter, true /* isFromCheckpoint */, fileVersion);
         } else {
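The pushed predicate arrives through the existing `Optional<Predicate>` parameter of `ParquetHandler.readParquetFiles`, so how much pruning it buys depends on the engine's handler; log replay still filters the returned batches, so honoring it is best-effort. As one hedged possibility (illustrative only, not the default handler's actual implementation), a handler could translate simple equality predicates into parquet-mr row-group filters:

```java
import org.apache.parquet.filter2.compat.FilterCompat;
import org.apache.parquet.filter2.predicate.FilterApi;
import org.apache.parquet.filter2.predicate.FilterPredicate;

import io.delta.kernel.expressions.Column;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.expressions.Predicate;

public class CheckpointPredicateToParquetFilter {
    /**
     * Translate a kernel predicate of the shape `column = int-literal` into a
     * parquet-mr filter usable for row-group skipping. Anything else returns
     * null, meaning "no pushdown" (the caller falls back to a full read).
     */
    public static FilterCompat.Filter toParquetFilter(Predicate predicate) {
        if ("=".equals(predicate.getName())
                && predicate.getChildren().size() == 2
                && predicate.getChildren().get(0) instanceof Column
                && predicate.getChildren().get(1) instanceof Literal) {
            Column column = (Column) predicate.getChildren().get(0);
            Literal literal = (Literal) predicate.getChildren().get(1);
            if (literal.getValue() instanceof Integer) {
                // e.g. add.partitionValues_parsed.part1 = 12
                FilterPredicate eq = FilterApi.eq(
                    FilterApi.intColumn(String.join(".", column.getNames())),
                    (Integer) literal.getValue());
                return FilterCompat.get(eq);
            }
        }
        return null; // unsupported shape: skip pushdown
    }
}
```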
LogReplay.java
@@ -24,6 +24,7 @@
 import io.delta.kernel.data.ColumnVector;
 import io.delta.kernel.data.ColumnarBatch;
 import io.delta.kernel.data.FilteredColumnarBatch;
+import io.delta.kernel.expressions.Predicate;
 import io.delta.kernel.types.StringType;
 import io.delta.kernel.types.StructType;
 import io.delta.kernel.utils.CloseableIterator;
@@ -151,12 +152,14 @@ public Optional<Long> getLatestTransactionIdentifier(String applicationId) {
      * </ol>
      */
     public CloseableIterator<FilteredColumnarBatch> getAddFilesAsColumnarBatches(
-            boolean shouldReadStats) {
+            boolean shouldReadStats,
+            Optional<Predicate> checkpointPredicate) {
         final CloseableIterator<ActionWrapper> addRemoveIter =
-            new ActionsIterator(
-                tableClient,
-                logSegment.allLogFilesReversed(),
-                getAddRemoveReadSchema(shouldReadStats));
+            new ActionsIterator(
+                tableClient,
+                logSegment.allLogFilesReversed(),
+                getAddRemoveReadSchema(shouldReadStats),
+                checkpointPredicate);
         return new ActiveAddFilesIterator(tableClient, addRemoveIter, dataPath);
     }

@@ -188,10 +191,11 @@ private Tuple2<Protocol, Metadata> loadTableProtocolAndMetadata(
         Metadata metadata = null;
 
         try (CloseableIterator<ActionWrapper> reverseIter =
-            new ActionsIterator(
-                tableClient,
-                logSegment.allLogFilesReversed(),
-                PROTOCOL_METADATA_READ_SCHEMA)) {
+            new ActionsIterator(
+                tableClient,
+                logSegment.allLogFilesReversed(),
+                PROTOCOL_METADATA_READ_SCHEMA,
+                Optional.empty())) {
             while (reverseIter.hasNext()) {
                 final ActionWrapper nextElem = reverseIter.next();
                 final long version = nextElem.getVersion();
@@ -271,10 +275,11 @@ private Tuple2<Protocol, Metadata> loadTableProtocolAndMetadata(
 
     private Optional<Long> loadLatestTransactionVersion(String applicationId) {
         try (CloseableIterator<ActionWrapper> reverseIter =
-            new ActionsIterator(
-                tableClient,
-                logSegment.allLogFilesReversed(),
-                SET_TRANSACTION_READ_SCHEMA)) {
+            new ActionsIterator(
+                tableClient,
+                logSegment.allLogFilesReversed(),
+                SET_TRANSACTION_READ_SCHEMA,
+                Optional.empty())) {
             while (reverseIter.hasNext()) {
                 final ColumnarBatch columnarBatch =
                     reverseIter.next().getColumnarBatch();
PartitionUtils.java
@@ -133,6 +133,52 @@ public static Tuple2<Predicate, Predicate> splitMetadataAndDataPredicates(
         }
     }
 
+    /**
+     * Rewrite the given predicate on partition columns as a predicate on
+     * `partitionValues_parsed` in the checkpoint schema. The rewritten predicate can be
+     * pushed to the Parquet reader when reading the checkpoint files.
+     *
+     * @param predicate               Predicate on partition columns.
+     * @param partitionColNameToField Map of partition column name (in lower case) to its
+     *                                {@link StructField}.
+     * @return Rewritten {@link Predicate} on `partitionValues_parsed` in `add`.
+     */
+    public static Predicate rewritePartitionPredicateOnCheckpointFileSchema(
+            Predicate predicate,
+            Map<String, StructField> partitionColNameToField) {
+        return new Predicate(
+            predicate.getName(),
+            predicate.getChildren().stream()
+                .map(child ->
+                    rewriteColRefOnPartitionValuesParsed(
+                        child, partitionColNameToField))
+                .collect(Collectors.toList()));
+    }
+
+    private static Expression rewriteColRefOnPartitionValuesParsed(
+            Expression expression,
+            Map<String, StructField> partitionColMetadata) {
+        if (expression instanceof Column) {
+            Column column = (Column) expression;
+            String partColName = column.getNames()[0];
+            StructField partColField =
+                partitionColMetadata.get(partColName.toLowerCase(Locale.ROOT));
+            if (partColField == null) {
+                throw new IllegalArgumentException(partColName + " is not present in metadata");
+            }
+
+            String partColPhysicalName = ColumnMapping.getPhysicalName(partColField);
+
+            return InternalScanFileUtils.getPartitionValuesParsedRefInAddFile(partColPhysicalName);
+        } else if (expression instanceof Predicate) {
+            return rewritePartitionPredicateOnCheckpointFileSchema(
+                (Predicate) expression,
+                partitionColMetadata);
+        }
+
+        return expression;
+    }
+
     /**
      * Utility method to rewrite the partition predicate referring to the table schema as predicate
      * referring to the {@code partitionValues} in scan files read from Delta log. The scan file
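Because the rewrite resolves each partition column through `ColumnMapping.getPhysicalName`, tables with column mapping enabled get a predicate on the physical field names. A minimal sketch, assuming the standard Delta physical-name metadata key and a made-up physical name:

```java
import java.util.Collections;
import java.util.Map;

import io.delta.kernel.expressions.Column;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.expressions.Predicate;
import io.delta.kernel.internal.util.PartitionUtils;
import io.delta.kernel.types.FieldMetadata;
import io.delta.kernel.types.IntegerType;
import io.delta.kernel.types.StructField;

public class ColumnMappingRewriteExample {
    public static void main(String[] args) {
        // Partition column `part1` whose physical name under column mapping is
        // `col-a7f3` (hypothetical values; the key follows the Delta protocol).
        FieldMetadata fieldMetadata = FieldMetadata.builder()
            .putString("delta.columnMapping.physicalName", "col-a7f3")
            .build();
        Map<String, StructField> partitionCols = Collections.singletonMap(
            "part1",
            new StructField("part1", IntegerType.INTEGER, true, fieldMetadata));

        Predicate predicate =
            new Predicate("=", new Column("part1"), Literal.ofInt(12));

        // Expected: (column(`add`.`partitionValues_parsed`.`col-a7f3`) = 12)
        System.out.println(
            PartitionUtils.rewritePartitionPredicateOnCheckpointFileSchema(
                predicate, partitionCols));
    }
}
```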
PartitionUtilsSuite.scala
@@ -15,16 +15,15 @@
  */
 package io.delta.kernel.internal.util
 
-import java.util
-
-import scala.collection.JavaConverters._
-
-import io.delta.kernel.expressions._
 import io.delta.kernel.expressions.Literal._
-import io.delta.kernel.internal.util.PartitionUtils.{rewritePartitionPredicateOnScanFileSchema, splitMetadataAndDataPredicates}
+import io.delta.kernel.expressions._
+import io.delta.kernel.internal.util.PartitionUtils.{rewritePartitionPredicateOnCheckpointFileSchema, rewritePartitionPredicateOnScanFileSchema, splitMetadataAndDataPredicates}
 import io.delta.kernel.types._
 import org.scalatest.funsuite.AnyFunSuite
 
+import java.util
+import scala.collection.JavaConverters._
+
 class PartitionUtilsSuite extends AnyFunSuite {
   // Table schema
   // Data columns: data1: int, data2: string, date3: struct(data31: boolean, data32: long)
@@ -154,33 +153,60 @@ class PartitionUtilsSuite extends AnyFunSuite {
     }
   }
 
-  // Map entry format: (given predicate -> expected rewritten predicate)
+  // Map entry format: (given predicate -> \
+  //   (exp predicate for partition pruning, exp predicate for checkpoint reader pushdown))
   val rewriteTestCases = Map(
     // single predicate on a partition column
     predicate("=", col("part2"), ofTimestamp(12)) ->
-      "(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part2), date) = 12)",
+      (
+        // exp predicate for partition pruning
+        "(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part2), date) = 12)",
+
+        // exp predicate for checkpoint reader pushdown
+        "(column(`add`.`partitionValues_parsed`.`part2`) = 12)"
+      ),
     // multiple predicates on partition columns joined with AND
     predicate("AND",
       predicate("=", col("part1"), ofInt(12)),
       predicate(">=", col("part3"), ofString("sss"))) ->
-      """((partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 12) AND
-        |(ELEMENT_AT(column(`add`.`partitionValues`), part3) >= sss))"""
-        .stripMargin.replaceAll("\n", " "),
+      (
+        // exp predicate for partition pruning
+        """((partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 12) AND
+          |(ELEMENT_AT(column(`add`.`partitionValues`), part3) >= sss))"""
+          .stripMargin.replaceAll("\n", " "),
+
+        // exp predicate for checkpoint reader pushdown
+        """((column(`add`.`partitionValues_parsed`.`part1`) = 12) AND
+          |(column(`add`.`partitionValues_parsed`.`part3`) >= sss))"""
+          .stripMargin.replaceAll("\n", " ")
+      ),
     // multiple predicates on partition columns joined with OR
     predicate("OR",
       predicate("<=", col("part3"), ofString("sss")),
       predicate("=", col("part1"), ofInt(2781))) ->
-      """((ELEMENT_AT(column(`add`.`partitionValues`), part3) <= sss) OR
-        |(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 2781))"""
-        .stripMargin.replaceAll("\n", " ")
+      (
+        // exp predicate for partition pruning
+        """((ELEMENT_AT(column(`add`.`partitionValues`), part3) <= sss) OR
+          |(partition_value(ELEMENT_AT(column(`add`.`partitionValues`), part1), integer) = 2781))"""
+          .stripMargin.replaceAll("\n", " "),
+
+        // exp predicate for checkpoint reader pushdown
+        """((column(`add`.`partitionValues_parsed`.`part3`) <= sss) OR
+          |(column(`add`.`partitionValues_parsed`.`part1`) = 2781))"""
+          .stripMargin.replaceAll("\n", " ")
+      )
   )
 
   rewriteTestCases.foreach {
-    case (predicate, expRewrittenPredicate) =>
+    case (predicate, (expPartitionPruningPredicate, expCheckpointReaderPushdownPredicate)) =>
       test(s"rewrite partition predicate on scan file schema: $predicate") {
-        val actRewrittenPredicate =
+        val actPartitionPruningPredicate =
           rewritePartitionPredicateOnScanFileSchema(predicate, partitionColsMetadata)
-        assert(actRewrittenPredicate.toString === expRewrittenPredicate)
+        assert(actPartitionPruningPredicate.toString === expPartitionPruningPredicate)
+
+        val actCheckpointReaderPushdownPredicate =
+          rewritePartitionPredicateOnCheckpointFileSchema(predicate, partitionColsMetadata)
+        assert(actCheckpointReaderPushdownPredicate.toString ===
+          expCheckpointReaderPushdownPredicate)
       }
   }

