diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java index a755a1cd67a1..1cf35a16e355 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java @@ -484,9 +484,9 @@ default boolean isView(Session session, QualifiedObjectName viewName) Optional> applyLimit(Session session, TableHandle table, long limit); - Optional> applyFilter(Session session, TableHandle table, Constraint constraint); + Optional> applyFilter(Session session, TableHandle table, Constraint constraint); - Optional> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint); + Optional> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint); Optional> applyProjection(Session session, TableHandle table, List projections, Map assignments); diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 4bf1eba99c6b..66cf7ded2127 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -1830,7 +1830,7 @@ public Optional> applyFil } @Override - public Optional> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint) + public Optional> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint) { CatalogHandle catalogHandle = handle.getCatalogHandle(); ConnectorMetadata metadata = getMetadata(session, catalogHandle); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index c682f4063bf2..cf73191a37bc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -596,7 +596,8 @@ public PlanOptimizers( ImmutableSet.of( new ApplyTableScanRedirection(plannerContext), new PruneTableScanColumns(metadata), - new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false)))); + new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false), + new PushFilterIntoTableFunction(plannerContext, typeAnalyzer)))); Set> pushIntoTableScanRulesExceptJoins = ImmutableSet.>builder() .addAll(columnPruningRules) @@ -605,6 +606,7 @@ public PlanOptimizers( .add(new RemoveRedundantIdentityProjections()) .add(new PushLimitIntoTableScan(metadata)) .add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false)) + .add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer)) .add(new PushSampleIntoTableScan(metadata)) .add(new PushAggregationIntoTableScan(plannerContext, typeAnalyzer)) .add(new PushDistinctLimitIntoTableScan(plannerContext, typeAnalyzer)) @@ -669,6 +671,7 @@ public PlanOptimizers( ImmutableSet.>builder() .addAll(simplifyOptimizerRules) // Should be always run after PredicatePushDown .add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false)) + .add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer)) .build()), new UnaliasSymbolReferences(metadata), // Run again because predicate pushdown and projection pushdown might add more projections columnPruningOptimizer, // Make sure to run this before index join. Filtered projections may not have all the columns. @@ -734,6 +737,7 @@ public PlanOptimizers( ImmutableSet.>builder() .addAll(simplifyOptimizerRules) // Should be always run after PredicatePushDown .add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false)) + .add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer)) .build()), pushProjectionIntoTableScanOptimizer, // Projection pushdown rules may push reducing projections (e.g. dereferences) below filters for potential @@ -748,6 +752,7 @@ public PlanOptimizers( ImmutableSet.>builder() .addAll(simplifyOptimizerRules) // Should be always run after PredicatePushDown .add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false)) + .add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer)) .build()), columnPruningOptimizer, new IterativeOptimizer( @@ -813,6 +818,7 @@ public PlanOptimizers( costCalculator, ImmutableSet.of( new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, true), + new PushFilterIntoTableFunction(plannerContext, typeAnalyzer), new RemoveEmptyUnionBranches(), new EvaluateEmptyIntersect(), new RemoveEmptyExceptBranches(), @@ -908,6 +914,7 @@ public PlanOptimizers( .addAll(simplifyOptimizerRules) // Should be always run after PredicatePushDown .add(new PushPredicateIntoTableScan(plannerContext, typeAnalyzer, false)) .add(new RemoveRedundantPredicateAboveTableScan(plannerContext, typeAnalyzer)) + .add(new PushFilterIntoTableFunction(plannerContext, typeAnalyzer)) .build())); // Remove unsupported dynamic filters introduced by PredicatePushdown. Also, cleanup dynamic filters removed by // PushPredicateIntoTableScan and RemoveRedundantPredicateAboveTableScan due to those rules replacing table scans with empty ValuesNode diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterIntoTableFunction.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterIntoTableFunction.java index 9e3cee05d421..f13d0f14af38 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterIntoTableFunction.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterIntoTableFunction.java @@ -13,13 +13,14 @@ */ package io.trino.sql.planner.iterative.rule; -import com.google.common.collect.Sets; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableMap; import io.trino.Session; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; import io.trino.metadata.TableFunctionHandle; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.expression.ConnectorExpression; @@ -42,13 +43,13 @@ import io.trino.sql.tree.Expression; import io.trino.sql.tree.NodeRef; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.function.Function; -import java.util.stream.Collectors; +import java.util.stream.IntStream; -import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors; import static io.trino.matching.Capture.newCapture; import static io.trino.sql.ExpressionUtils.combineConjuncts; @@ -59,6 +60,8 @@ import static io.trino.sql.planner.plan.Patterns.source; import static io.trino.sql.planner.plan.Patterns.tableFunctionProcessor; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toMap; public class PushFilterIntoTableFunction implements Rule @@ -116,10 +119,6 @@ public static Optional pushFilterIntoTableFunctionProcessorNode( return Optional.empty(); } - if (!node.getHandle().getFunctionHandle().supportsPredicatePushdown()) { - return Optional.empty(); - } - PushPredicateIntoTableScan.SplitExpression splitExpression = splitExpression(plannerContext, filterNode.getPredicate()); DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.getExtractionResult( plannerContext, @@ -127,19 +126,13 @@ public static Optional pushFilterIntoTableFunctionProcessorNode( splitExpression.getDeterministicPredicate(), symbolAllocator.getTypes()); - Map assignments = decomposedPredicate.getTupleDomain().getDomains().get() - .keySet().stream() - .collect(Collectors.toMap( - symbol -> node.getHandle().getFunctionHandle().getColumnHandles().get(symbol.getName()), - Function.identity())); - if (!Sets.difference( - decomposedPredicate.getTupleDomain().getDomains().orElseThrow().keySet().stream().map(Symbol::getName).collect(toImmutableSet()), - node.getHandle().getFunctionHandle().getColumnHandles().keySet()).isEmpty()) { - return Optional.empty(); - } + List outputSymbols = node.getOutputSymbols(); + + BiMap assignments = HashBiMap.create(IntStream.range(0, outputSymbols.size()).boxed() + .collect(toImmutableMap(identity(), outputSymbols::get))); - TupleDomain newDomain = decomposedPredicate.getTupleDomain() - .transformKeys(symbol -> node.getHandle().getFunctionHandle().getColumnHandles().get(symbol.getName())) + TupleDomain newDomain = decomposedPredicate.getTupleDomain() + .transformKeys(assignments.inverse()::get) .intersect(node.getEnforcedConstraint()); ConnectorExpressionTranslator.ConnectorExpressionTranslation expressionTranslation = ConnectorExpressionTranslator.translateConjuncts( @@ -148,15 +141,16 @@ public static Optional pushFilterIntoTableFunctionProcessorNode( symbolAllocator.getTypes(), plannerContext, typeAnalyzer); + ImmutableMap nameToPosition = assignments.inverse().entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)); + Constraint constraint = new Constraint<>(newDomain, expressionTranslation.connectorExpression(), nameToPosition); - Constraint constraint = new Constraint(newDomain, expressionTranslation.connectorExpression(), node.getHandle().getFunctionHandle().getColumnHandles()); - - Optional> result = plannerContext.getMetadata().applyFilter(session, node.getHandle(), constraint); + Optional> result = plannerContext.getMetadata().applyFilter(session, node.getHandle(), constraint); if (result.isEmpty()) { return Optional.empty(); } - TupleDomain remainingFilter = result.get().getRemainingFilter(); + TupleDomain remainingFilter = result.get().getRemainingFilter(); Optional remainingConnectorExpression = result.get().getRemainingExpression(); TableFunctionProcessorNode tableFunctionProcessorNode = new TableFunctionProcessorNode( @@ -180,7 +174,7 @@ public static Optional pushFilterIntoTableFunctionProcessorNode( remainingDecomposedPredicate = decomposedPredicate.getRemainingExpression(); } else { - Map variableMappings = node.getOutputSymbols().stream().collect(Collectors.toMap(Symbol::getName, Function.identity())); + Map variableMappings = node.getOutputSymbols().stream().collect(toMap(Symbol::getName, identity())); LiteralEncoder literalEncoder = new LiteralEncoder(plannerContext); Expression translatedExpression = ConnectorExpressionTranslator.translate(session, remainingConnectorExpression.get(), plannerContext, variableMappings, literalEncoder); Map, Type> translatedExpressionTypes = typeAnalyzer.getTypes(session, symbolAllocator.getTypes(), translatedExpression); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java index a3b84fb21112..740c305ff97b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPredicateIntoTableScan.java @@ -386,7 +386,7 @@ static Expression createResultingPredicate( return expression; } - public static TupleDomain computeEnforced(TupleDomain predicate, TupleDomain unenforced) + public static TupleDomain computeEnforced(TupleDomain predicate, TupleDomain unenforced) { // The engine requested the connector to apply a filter with a non-none TupleDomain. // A TupleDomain is effectively a list of column-Domain pairs. @@ -398,23 +398,23 @@ public static TupleDomain computeEnforced(TupleDomain predicateDomains = predicate.getDomains().get(); - Map unenforcedDomains = unenforced.getDomains().get(); - ImmutableMap.Builder enforcedDomainsBuilder = ImmutableMap.builder(); - for (Map.Entry entry : predicateDomains.entrySet()) { - ColumnHandle predicateColumnHandle = entry.getKey(); + Map predicateDomains = predicate.getDomains().get(); + Map unenforcedDomains = unenforced.getDomains().get(); + ImmutableMap.Builder enforcedDomainsBuilder = ImmutableMap.builder(); + for (Map.Entry entry : predicateDomains.entrySet()) { + T column = entry.getKey(); Domain predicateDomain = entry.getValue(); - if (unenforcedDomains.containsKey(predicateColumnHandle)) { - Domain unenforcedDomain = unenforcedDomains.get(predicateColumnHandle); + if (unenforcedDomains.containsKey(column)) { + Domain unenforcedDomain = unenforcedDomains.get(column); checkArgument( predicateDomain.contains(unenforcedDomain), "Unexpected unenforced domain %s on column %s. Expected all, none, or a domain equal to or narrower than %s", unenforcedDomain, - predicateColumnHandle, + column, predicateDomain); } else { - enforcedDomainsBuilder.put(predicateColumnHandle, predicateDomain); + enforcedDomainsBuilder.put(column, predicateDomain); } } return TupleDomain.withColumnDomains(enforcedDomainsBuilder.buildOrThrow()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java index cace6091e9b7..5caf584d378a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java @@ -20,13 +20,11 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.trino.metadata.TableFunctionHandle; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Collection; import java.util.List; @@ -77,7 +75,7 @@ public class TableFunctionProcessorNode private final TableFunctionHandle handle; @Nullable // null on workers - private final TupleDomain enforcedConstraint; + private final TupleDomain enforcedConstraint; @JsonCreator public TableFunctionProcessorNode( @@ -93,54 +91,8 @@ public TableFunctionProcessorNode( @JsonProperty("prePartitioned") Set prePartitioned, @JsonProperty("preSorted") int preSorted, @JsonProperty("hashSymbol") Optional hashSymbol, - @JsonProperty("handle") TableFunctionHandle handle) - { - super(id); - this.name = requireNonNull(name, "name is null"); - this.properOutputs = ImmutableList.copyOf(properOutputs); - this.source = requireNonNull(source, "source is null"); - this.pruneWhenEmpty = pruneWhenEmpty; - this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); - this.requiredSymbols = requiredSymbols.stream() - .map(ImmutableList::copyOf) - .collect(toImmutableList()); - this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); - this.specification = requireNonNull(specification, "specification is null"); - this.prePartitioned = ImmutableSet.copyOf(prePartitioned); - Set partitionBy = specification - .map(DataOrganizationSpecification::getPartitionBy) - .map(ImmutableSet::copyOf) - .orElse(ImmutableSet.of()); - checkArgument(partitionBy.containsAll(prePartitioned), "all pre-partitioned symbols must be contained in the partitioning list"); - this.preSorted = preSorted; - checkArgument( - specification - .flatMap(DataOrganizationSpecification::getOrderingScheme) - .map(OrderingScheme::getOrderBy) - .map(List::size) - .orElse(0) >= preSorted, - "the number of pre-sorted symbols cannot be greater than the number of all ordering symbols"); - checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned"); - this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); - this.handle = requireNonNull(handle, "handle is null"); - this.enforcedConstraint = null; - } - - public TableFunctionProcessorNode( - PlanNodeId id, - String name, - List properOutputs, - Optional source, - boolean pruneWhenEmpty, - List passThroughSpecifications, - List> requiredSymbols, - Optional> markerSymbols, - Optional specification, - Set prePartitioned, - int preSorted, - Optional hashSymbol, - TableFunctionHandle handle, - TupleDomain enforcedConstraint) + @JsonProperty("handle") TableFunctionHandle handle, + @JsonProperty("tupleDomain") TupleDomain enforcedConstraint) { super(id); this.name = requireNonNull(name, "name is null"); @@ -170,7 +122,7 @@ public TableFunctionProcessorNode( checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned"); this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); this.handle = requireNonNull(handle, "handle is null"); - this.enforcedConstraint = requireNonNull(enforcedConstraint, "enforcedConstraint is null"); + this.enforcedConstraint = enforcedConstraint; } @JsonProperty @@ -270,7 +222,7 @@ public List getOutputSymbols() @Nullable @JsonIgnore - public TupleDomain getEnforcedConstraint() + public TupleDomain getEnforcedConstraint() { checkState(enforcedConstraint != null, "enforcedConstraint should only be used in planner. It is not transported to workers."); return enforcedConstraint; diff --git a/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java b/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java index 8a7e395153f8..d19f9da60028 100644 --- a/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java +++ b/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java @@ -1053,7 +1053,7 @@ public Optional> } @Override - public Optional> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint) + public Optional> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint) { Span span = startSpan("applyFilter", handle); try (var ignored = scopedSpan(span)) { diff --git a/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java b/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java index aebc0d025b2f..933c97521888 100644 --- a/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java +++ b/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java @@ -895,7 +895,7 @@ public Optional> applyFil } @Override - public Optional> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint) + public Optional> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint) { Span span = startSpan("applyFilter"); if (span.isRecording()) { diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java index f4e1b3045dd0..60296eeb212e 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java @@ -458,7 +458,7 @@ public Optional> } @Override - public Optional> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint) + public Optional> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint) { return applyFilterForPtf.apply(session, handle, constraint); } diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java index bd2cb3d3d202..71c2d5406b4c 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java @@ -352,7 +352,7 @@ public interface ApplyFilter @FunctionalInterface public interface ApplyFilterForPtf { - Optional> apply(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint); + Optional> apply(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint); } @FunctionalInterface diff --git a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java index aae4a5483114..d362600c1112 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java @@ -602,7 +602,7 @@ public Optional> applyFil } @Override - public Optional> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint) + public Optional> applyFilter(Session session, TableFunctionHandle handle, Constraint constraint) { return Optional.empty(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterIntoTableFunction.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterIntoTableFunction.java index bd0b4c1902a2..02dae83642c9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterIntoTableFunction.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushFilterIntoTableFunction.java @@ -18,7 +18,6 @@ import io.trino.Session; import io.trino.connector.MockConnectorFactory; import io.trino.spi.connector.CatalogHandle; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.function.table.ConnectorTableFunctionHandle; import io.trino.spi.predicate.NullableValue; @@ -28,7 +27,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.parallel.ResourceLock; -import java.util.Map; import java.util.Optional; import static io.trino.spi.type.BigintType.BIGINT; @@ -43,36 +41,10 @@ public class TestPushFilterIntoTableFunction extends BaseRuleTest { private static final String MOCK_CATALOG = "mock_catalog"; - private static final ConnectorTableFunctionHandle TABLE_FUNCTION_CONSUMES_ENTIRE_PREDICATE = new ConnectorTableFunctionHandle() - { - @Override - public Map getColumnHandles() - { - return ImmutableMap.of("p", new ColumnHandle() {}); - } - - @Override - public boolean supportsPredicatePushdown() - { - return true; - } - }; - - public static final ColumnHandle COLUMN_HANDLE = new ColumnHandle() {}; - private static final ConnectorTableFunctionHandle TABLE_FUNCTION_CONSUMES_PREDICATE_PARTIALLY = new ConnectorTableFunctionHandle() - { - @Override - public Map getColumnHandles() - { - return ImmutableMap.of("p", new ColumnHandle() {}, "z", COLUMN_HANDLE); - } + private static final ConnectorTableFunctionHandle TABLE_FUNCTION_CONSUMES_ENTIRE_PREDICATE = new ConnectorTableFunctionHandle() {}; - @Override - public boolean supportsPredicatePushdown() - { - return true; - } - }; + public static final int PUSHDOWN_COLUMN = 1; + private static final ConnectorTableFunctionHandle TABLE_FUNCTION_CONSUMES_PREDICATE_PARTIALLY = new ConnectorTableFunctionHandle() {}; private static final ConnectorTableFunctionHandle RESULT_TABLE_FUNCTION_HANDLE = new ConnectorTableFunctionHandle() {}; private PushFilterIntoTableFunction pushFilterIntoTableFunction; @@ -93,7 +65,7 @@ public void init() if (tableFunctionHandle.equals(TABLE_FUNCTION_CONSUMES_PREDICATE_PARTIALLY)) { return Optional.of(new ConstraintApplicationResult<>( RESULT_TABLE_FUNCTION_HANDLE, - TupleDomain.fromFixedValues(ImmutableMap.of(COLUMN_HANDLE, NullableValue.of(BIGINT, (long) 1))), + TupleDomain.fromFixedValues(ImmutableMap.of(PUSHDOWN_COLUMN, NullableValue.of(BIGINT, (long) 1))), false)); } return Optional.empty(); @@ -111,37 +83,6 @@ public void testDoesNotFireIfNoTableFunctionProcessor() .doesNotFire(); } - @Test - public void testDoesNotFireWhenFunctionDoesntSupportPredicatePushdown() - { - tester().assertThat(pushFilterIntoTableFunction) - .on(p -> p.filter( - expression("nationkey % 17 = BIGINT '44' AND nationkey % 15 = BIGINT '43'"), - p.tableFunctionProcessor( - builder -> builder - .name("test_function") - .properOutputs(p.symbol("p")) - .source(p.values(p.symbol("x"))) - .supportsPredicatePushdown(false)))) - .doesNotFire(); - } - - @Test - public void testDoesNotFireWhenFunctionDoesntSupportAllTheDomains() - { - tester().assertThat(pushFilterIntoTableFunction) - .on(p -> p.filter( - expression("p = BIGINT '44' AND x = BIGINT '21'"), - p.tableFunctionProcessor( - builder -> builder - .name("test_function") - .properOutputs(p.symbol("p")) - .source(p.values(p.symbol("x"))) - .supportsPredicatePushdown(true) - .supportedColumnHandles(ImmutableMap.of("p", new ColumnHandle() {}))))) - .doesNotFire(); - } - @Test public void testDoesNotFireWhenApplyFilterReturnsEmptyResult() { @@ -154,8 +95,7 @@ public void testDoesNotFireWhenApplyFilterReturnsEmptyResult() .properOutputs(p.symbol("p")) .source(p.values(p.symbol("x"))) .supportsPredicatePushdown(true) - .catalogHandle(catalogHandle) - .supportedColumnHandles(ImmutableMap.of("p", new ColumnHandle() {}))))) + .catalogHandle(catalogHandle)))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java index 7d82c2ceb8cd..4f5e3426d7ad 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java @@ -14,11 +14,9 @@ package io.trino.sql.planner.iterative.rule.test; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.trino.metadata.TableFunctionHandle; import io.trino.spi.connector.CatalogHandle; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.function.table.ConnectorTableFunctionHandle; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -50,7 +48,6 @@ public class TableFunctionProcessorBuilder private int preSorted; private Optional hashSymbol = Optional.empty(); private boolean supportsPredicatePushdown; - private Map supportedColumnHandles = ImmutableMap.of(); private CatalogHandle catalogHandle = TEST_CATALOG_HANDLE; private ConnectorTableFunctionHandle connectorTableFunctionHandle; @@ -128,12 +125,6 @@ public TableFunctionProcessorBuilder supportsPredicatePushdown(boolean supports) return this; } - public TableFunctionProcessorBuilder supportedColumnHandles(Map columnHandles) - { - this.supportedColumnHandles = columnHandles; - return this; - } - public TableFunctionProcessorBuilder catalogHandle(CatalogHandle catalogHandle) { this.catalogHandle = catalogHandle; @@ -164,21 +155,7 @@ public TableFunctionProcessorNode build(PlanNodeIdAllocator idAllocator) new TableFunctionHandle( catalogHandle, connectorTableFunctionHandle == null ? - new ConnectorTableFunctionHandle() - { - @Override - public boolean supportsPredicatePushdown() - { - return supportsPredicatePushdown; - } - - @Override - public Map getColumnHandles() - { - return supportedColumnHandles; - } - } : - connectorTableFunctionHandle, + new ConnectorTableFunctionHandle() {} : connectorTableFunctionHandle, TestingTransactionHandle.create()), TupleDomain.all()); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index 723556b5d767..7023bdfafb92 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -1102,7 +1102,7 @@ default Optional return Optional.empty(); } - default Optional> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint) + default Optional> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint) { // applyFilter is expected not to be invoked with a "false" constraint if (constraint.getSummary().isNone()) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/table/ConnectorTableFunctionHandle.java b/core/trino-spi/src/main/java/io/trino/spi/function/table/ConnectorTableFunctionHandle.java index ea1237ef7f37..8514e215af7c 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/table/ConnectorTableFunctionHandle.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/table/ConnectorTableFunctionHandle.java @@ -14,9 +14,6 @@ package io.trino.spi.function.table; import io.trino.spi.Experimental; -import io.trino.spi.connector.ColumnHandle; - -import java.util.Map; /** * An area to store all information necessary to execute the table function, gathered at analysis time @@ -24,13 +21,4 @@ @Experimental(eta = "2022-10-31") public interface ConnectorTableFunctionHandle { - default Map getColumnHandles() - { - return Map.of(); - } - - default boolean supportsPredicatePushdown() - { - return false; - } } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index f242bde58711..308adf78af32 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -892,7 +892,7 @@ public Optional> } @Override - public Optional> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint) + public Optional> applyFilter(ConnectorSession session, ConnectorTableFunctionHandle handle, Constraint constraint) { try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { return delegate.applyFilter(session, handle, constraint); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index 05bbfcd5e642..1ac6ef92222e 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -3783,7 +3783,8 @@ private Object getHiveTableProperty(String tableName, Function new AssertionError("table not found: " + name)); - table = metadata.applyFilter(transactionSession, table, Constraint.alwaysTrue()) + Constraint constraint = Constraint.alwaysTrue(); + table = metadata.applyFilter(transactionSession, table, constraint) .orElseThrow(() -> new AssertionError("applyFilter did not return a result")) .getHandle(); return propertyGetter.apply((HiveTableHandle) table.getConnectorHandle());