Skip to content

Commit

Permalink
fixup! Implement predicate pushdown for table functions
Browse files Browse the repository at this point in the history
  • Loading branch information
homar committed Aug 5, 2023
1 parent 828374c commit 3dd2d90
Show file tree
Hide file tree
Showing 17 changed files with 62 additions and 203 deletions.
4 changes: 2 additions & 2 deletions core/trino-main/src/main/java/io/trino/metadata/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,9 @@ default boolean isView(Session session, QualifiedObjectName viewName)

Optional<LimitApplicationResult<TableHandle>> applyLimit(Session session, TableHandle table, long limit);

Optional<ConstraintApplicationResult<TableHandle, ColumnHandle>> applyFilter(Session session, TableHandle table, Constraint constraint);
Optional<ConstraintApplicationResult<TableHandle, ColumnHandle>> applyFilter(Session session, TableHandle table, Constraint<ColumnHandle> constraint);

Optional<ConstraintApplicationResult<ConnectorTableFunctionHandle>> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint);
Optional<ConstraintApplicationResult<ConnectorTableFunctionHandle, Integer>> applyFilter(Session session, TableFunctionHandle handle, Constraint<Integer> constraint);

Optional<ProjectionApplicationResult<TableHandle>> applyProjection(Session session, TableHandle table, List<ConnectorExpression> projections, Map<String, ColumnHandle> assignments);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1830,7 +1830,7 @@ public Optional<ConstraintApplicationResult<TableHandle, ColumnHandle>> applyFil
}

@Override
public Optional<ConstraintApplicationResult<ConnectorTableFunctionHandle>> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint)
public Optional<ConstraintApplicationResult<ConnectorTableFunctionHandle, Integer>> applyFilter(Session session, TableFunctionHandle handle, Constraint<Integer> constraint)
{
CatalogHandle catalogHandle = handle.getCatalogHandle();
ConnectorMetadata metadata = getMetadata(session, catalogHandle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,8 @@ public PlanOptimizers(
ImmutableSet.of(
new ApplyTableScanRedirection(plannerContext),
new PruneTableScanColumns(metadata),
new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false))));
new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false),
new PushFilterIntoTableFunction(plannerContext, typeAnalyzer))));

Set<Rule<?>> pushIntoTableScanRulesExceptJoins = ImmutableSet.<Rule<?>>builder()
.addAll(columnPruningRules)
Expand All @@ -605,6 +606,7 @@ public PlanOptimizers(
.add(new RemoveRedundantIdentityProjections())
.add(new PushLimitIntoTableScan(metadata))
.add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false))
.add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer))
.add(new PushSampleIntoTableScan(metadata))
.add(new PushAggregationIntoTableScan(plannerContext, typeAnalyzer))
.add(new PushDistinctLimitIntoTableScan(plannerContext, typeAnalyzer))
Expand Down Expand Up @@ -669,6 +671,7 @@ public PlanOptimizers(
ImmutableSet.<Rule<?>>builder()
.addAll(simplifyOptimizerRules) // Should be always run after PredicatePushDown
.add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false))
.add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer))
.build()),
new UnaliasSymbolReferences(metadata), // Run again because predicate pushdown and projection pushdown might add more projections
columnPruningOptimizer, // Make sure to run this before index join. Filtered projections may not have all the columns.
Expand Down Expand Up @@ -734,6 +737,7 @@ public PlanOptimizers(
ImmutableSet.<Rule<?>>builder()
.addAll(simplifyOptimizerRules) // Should be always run after PredicatePushDown
.add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false))
.add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer))
.build()),
pushProjectionIntoTableScanOptimizer,
// Projection pushdown rules may push reducing projections (e.g. dereferences) below filters for potential
Expand All @@ -748,6 +752,7 @@ public PlanOptimizers(
ImmutableSet.<Rule<?>>builder()
.addAll(simplifyOptimizerRules) // Should be always run after PredicatePushDown
.add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false))
.add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer))
.build()),
columnPruningOptimizer,
new IterativeOptimizer(
Expand Down Expand Up @@ -813,6 +818,7 @@ public PlanOptimizers(
costCalculator,
ImmutableSet.of(
new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, true),
new PushFilterIntoTableFunction(plannerContext, typeAnalyzer),
new RemoveEmptyUnionBranches(),
new EvaluateEmptyIntersect(),
new RemoveEmptyExceptBranches(),
Expand Down Expand Up @@ -908,6 +914,7 @@ public PlanOptimizers(
.addAll(simplifyOptimizerRules) // Should be always run after PredicatePushDown
.add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false))
.add(new RemoveRedundantPredicateAboveTableScan(plannerContext, typeAnalyzer))
.add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer))
.build()));
// Remove unsupported dynamic filters introduced by PredicatePushdown. Also, cleanup dynamic filters removed by
// PushPredicateIntoTableScan and RemoveRedundantPredicateAboveTableScan due to those rules replacing table scans with empty ValuesNode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.Sets;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.TableFunctionHandle;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.expression.ConnectorExpression;
Expand All @@ -42,13 +43,13 @@
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NodeRef;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors;
import static io.trino.matching.Capture.newCapture;
import static io.trino.sql.ExpressionUtils.combineConjuncts;
Expand All @@ -59,6 +60,8 @@
import static io.trino.sql.planner.plan.Patterns.source;
import static io.trino.sql.planner.plan.Patterns.tableFunctionProcessor;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;

public class PushFilterIntoTableFunction
implements Rule<FilterNode>
Expand Down Expand Up @@ -116,30 +119,20 @@ public static Optional<PlanNode> pushFilterIntoTableFunctionProcessorNode(
return Optional.empty();
}

if (!node.getHandle().getFunctionHandle().supportsPredicatePushdown()) {
return Optional.empty();
}

PushPredicateIntoTableScan.SplitExpression splitExpression = splitExpression(plannerContext, filterNode.getPredicate());
DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.getExtractionResult(
plannerContext,
session,
splitExpression.getDeterministicPredicate(),
symbolAllocator.getTypes());

Map<ColumnHandle, Symbol> assignments = decomposedPredicate.getTupleDomain().getDomains().get()
.keySet().stream()
.collect(Collectors.toMap(
symbol -> node.getHandle().getFunctionHandle().getColumnHandles().get(symbol.getName()),
Function.identity()));
if (!Sets.difference(
decomposedPredicate.getTupleDomain().getDomains().orElseThrow().keySet().stream().map(Symbol::getName).collect(toImmutableSet()),
node.getHandle().getFunctionHandle().getColumnHandles().keySet()).isEmpty()) {
return Optional.empty();
}
List<Symbol> outputSymbols = node.getOutputSymbols();

BiMap<Integer, Symbol> assignments = HashBiMap.create(IntStream.range(0, outputSymbols.size()).boxed()
.collect(toImmutableMap(identity(), outputSymbols::get)));

TupleDomain<ColumnHandle> newDomain = decomposedPredicate.getTupleDomain()
.transformKeys(symbol -> node.getHandle().getFunctionHandle().getColumnHandles().get(symbol.getName()))
TupleDomain<Integer> newDomain = decomposedPredicate.getTupleDomain()
.transformKeys(assignments.inverse()::get)
.intersect(node.getEnforcedConstraint());

ConnectorExpressionTranslator.ConnectorExpressionTranslation expressionTranslation = ConnectorExpressionTranslator.translateConjuncts(
Expand All @@ -148,15 +141,16 @@ public static Optional<PlanNode> pushFilterIntoTableFunctionProcessorNode(
symbolAllocator.getTypes(),
plannerContext,
typeAnalyzer);
ImmutableMap<String, Integer> nameToPosition = assignments.inverse().entrySet().stream()
.collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));
Constraint<Integer> constraint = new Constraint<>(newDomain, expressionTranslation.connectorExpression(), nameToPosition);

Constraint constraint = new Constraint(newDomain, expressionTranslation.connectorExpression(), node.getHandle().getFunctionHandle().getColumnHandles());

Optional<ConstraintApplicationResult<ConnectorTableFunctionHandle>> result = plannerContext.getMetadata().applyFilter(session, node.getHandle(), constraint);
Optional<ConstraintApplicationResult<ConnectorTableFunctionHandle, Integer>> result = plannerContext.getMetadata().applyFilter(session, node.getHandle(), constraint);
if (result.isEmpty()) {
return Optional.empty();
}

TupleDomain<ColumnHandle> remainingFilter = result.get().getRemainingFilter();
TupleDomain<Integer> remainingFilter = result.get().getRemainingFilter();
Optional<ConnectorExpression> remainingConnectorExpression = result.get().getRemainingExpression();

TableFunctionProcessorNode tableFunctionProcessorNode = new TableFunctionProcessorNode(
Expand All @@ -180,7 +174,7 @@ public static Optional<PlanNode> pushFilterIntoTableFunctionProcessorNode(
remainingDecomposedPredicate = decomposedPredicate.getRemainingExpression();
}
else {
Map<String, Symbol> variableMappings = node.getOutputSymbols().stream().collect(Collectors.toMap(Symbol::getName, Function.identity()));
Map<String, Symbol> variableMappings = node.getOutputSymbols().stream().collect(toMap(Symbol::getName, identity()));
LiteralEncoder literalEncoder = new LiteralEncoder(plannerContext);
Expression translatedExpression = ConnectorExpressionTranslator.translate(session, remainingConnectorExpression.get(), plannerContext, variableMappings, literalEncoder);
Map<NodeRef<Expression>, Type> translatedExpressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), translatedExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ static Expression createResultingPredicate(
return expression;
}

public static TupleDomain<ColumnHandle> computeEnforced(TupleDomain<ColumnHandle> predicate, TupleDomain<ColumnHandle> unenforced)
public static <T> TupleDomain<T> computeEnforced(TupleDomain<T> predicate, TupleDomain<T> unenforced)
{
// The engine requested the connector to apply a filter with a non-none TupleDomain.
// A TupleDomain is effectively a list of column-Domain pairs.
Expand All @@ -398,23 +398,23 @@ public static TupleDomain<ColumnHandle> computeEnforced(TupleDomain<ColumnHandle
// In all 3 cases shown above, the unenforced is not TupleDomain.none().
checkArgument(!unenforced.isNone(), "Unexpected unenforced none tuple domain");

Map<ColumnHandle, Domain> predicateDomains = predicate.getDomains().get();
Map<ColumnHandle, Domain> unenforcedDomains = unenforced.getDomains().get();
ImmutableMap.Builder<ColumnHandle, Domain> enforcedDomainsBuilder = ImmutableMap.builder();
for (Map.Entry<ColumnHandle, Domain> entry : predicateDomains.entrySet()) {
ColumnHandle predicateColumnHandle = entry.getKey();
Map<T, Domain> predicateDomains = predicate.getDomains().get();
Map<T, Domain> unenforcedDomains = unenforced.getDomains().get();
ImmutableMap.Builder<T, Domain> enforcedDomainsBuilder = ImmutableMap.builder();
for (Map.Entry<T, Domain> entry : predicateDomains.entrySet()) {
T column = entry.getKey();
Domain predicateDomain = entry.getValue();
if (unenforcedDomains.containsKey(predicateColumnHandle)) {
Domain unenforcedDomain = unenforcedDomains.get(predicateColumnHandle);
if (unenforcedDomains.containsKey(column)) {
Domain unenforcedDomain = unenforcedDomains.get(column);
checkArgument(
predicateDomain.contains(unenforcedDomain),
"Unexpected unenforced domain %s on column %s. Expected all, none, or a domain equal to or narrower than %s",
unenforcedDomain,
predicateColumnHandle,
column,
predicateDomain);
}
else {
enforcedDomainsBuilder.put(predicateColumnHandle, predicateDomain);
enforcedDomainsBuilder.put(column, predicateDomain);
}
}
return TupleDomain.withColumnDomains(enforcedDomainsBuilder.buildOrThrow());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.metadata.TableFunctionHandle;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification;

import javax.annotation.Nullable;
import jakarta.annotation.Nullable;

import java.util.Collection;
import java.util.List;
Expand Down Expand Up @@ -77,7 +75,7 @@ public class TableFunctionProcessorNode
private final TableFunctionHandle handle;

@Nullable // null on workers
private final TupleDomain<ColumnHandle> enforcedConstraint;
private final TupleDomain<Integer> enforcedConstraint;

@JsonCreator
public TableFunctionProcessorNode(
Expand All @@ -93,54 +91,8 @@ public TableFunctionProcessorNode(
@JsonProperty("prePartitioned") Set<Symbol> prePartitioned,
@JsonProperty("preSorted") int preSorted,
@JsonProperty("hashSymbol") Optional<Symbol> hashSymbol,
@JsonProperty("handle") TableFunctionHandle handle)
{
super(id);
this.name = requireNonNull(name, "name is null");
this.properOutputs = ImmutableList.copyOf(properOutputs);
this.source = requireNonNull(source, "source is null");
this.pruneWhenEmpty = pruneWhenEmpty;
this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications);
this.requiredSymbols = requiredSymbols.stream()
.map(ImmutableList::copyOf)
.collect(toImmutableList());
this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf);
this.specification = requireNonNull(specification, "specification is null");
this.prePartitioned = ImmutableSet.copyOf(prePartitioned);
Set<Symbol> partitionBy = specification
.map(DataOrganizationSpecification::getPartitionBy)
.map(ImmutableSet::copyOf)
.orElse(ImmutableSet.of());
checkArgument(partitionBy.containsAll(prePartitioned), "all pre-partitioned symbols must be contained in the partitioning list");
this.preSorted = preSorted;
checkArgument(
specification
.flatMap(DataOrganizationSpecification::getOrderingScheme)
.map(OrderingScheme::getOrderBy)
.map(List::size)
.orElse(0) >= preSorted,
"the number of pre-sorted symbols cannot be greater than the number of all ordering symbols");
checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned");
this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null");
this.handle = requireNonNull(handle, "handle is null");
this.enforcedConstraint = null;
}

public TableFunctionProcessorNode(
PlanNodeId id,
String name,
List<Symbol> properOutputs,
Optional<PlanNode> source,
boolean pruneWhenEmpty,
List<PassThroughSpecification> passThroughSpecifications,
List<List<Symbol>> requiredSymbols,
Optional<Map<Symbol, Symbol>> markerSymbols,
Optional<DataOrganizationSpecification> specification,
Set<Symbol> prePartitioned,
int preSorted,
Optional<Symbol> hashSymbol,
TableFunctionHandle handle,
TupleDomain<ColumnHandle> enforcedConstraint)
@JsonProperty("handle") TableFunctionHandle handle,
@JsonProperty("tupleDomain") TupleDomain<Integer> enforcedConstraint)
{
super(id);
this.name = requireNonNull(name, "name is null");
Expand Down Expand Up @@ -170,7 +122,7 @@ public TableFunctionProcessorNode(
checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned");
this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null");
this.handle = requireNonNull(handle, "handle is null");
this.enforcedConstraint = requireNonNull(enforcedConstraint, "enforcedConstraint is null");
this.enforcedConstraint = enforcedConstraint;
}

@JsonProperty
Expand Down Expand Up @@ -270,7 +222,7 @@ public List<Symbol> getOutputSymbols()

@Nullable
@JsonIgnore
public TupleDomain<ColumnHandle> getEnforcedConstraint()
public TupleDomain<Integer> getEnforcedConstraint()
{
checkState(enforcedConstraint != null, "enforcedConstraint should only be used in planner. It is not transported to workers.");
return enforcedConstraint;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,7 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle, ColumnHandle>>
}

@Override
public Optional<ConstraintApplicationResult<ConnectorTableFunctionHandle>> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint)
public Optional<ConstraintApplicationResult<ConnectorTableFunctionHandle, Integer>> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint<Integer> constraint)
{
Span span = startSpan("applyFilter", handle);
try (var ignored = scopedSpan(span)) {
Expand Down
Loading

0 comments on commit 3dd2d90

Please sign in to comment.