diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java index 3b1e6d0dec76b..57275029dbde8 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java @@ -850,6 +850,49 @@ public void testFromLimit() { assertThat(results.values(), contains(anyOf(contains(1L), contains(2L)), anyOf(contains(1L), contains(2L)))); } + public void testProjectAfterTopN() { + EsqlQueryResponse results = run("from test | sort time | limit 2 | project count"); + logger.info(results); + assertEquals(1, results.columns().size()); + assertEquals(new ColumnInfo("count", "long"), results.columns().get(0)); + assertEquals(2, results.values().size()); + assertEquals(40L, results.values().get(0).get(0)); + assertEquals(42L, results.values().get(1).get(0)); + } + + public void testProjectAfterTopNDesc() { + EsqlQueryResponse results = run("from test | sort time desc | limit 2 | project count"); + logger.info(results); + assertEquals(1, results.columns().size()); + assertEquals(new ColumnInfo("count", "long"), results.columns().get(0)); + assertEquals(2, results.values().size()); + assertEquals(46L, results.values().get(0).get(0)); + assertEquals(44L, results.values().get(1).get(0)); + } + + public void testTopNProjectEval() { + EsqlQueryResponse results = run("from test | sort time | limit 2 | project count | eval x = count + 1"); + logger.info(results); + assertEquals(2, results.columns().size()); + assertEquals(new ColumnInfo("count", "long"), results.columns().get(0)); + assertEquals(new ColumnInfo("x", "long"), results.columns().get(1)); + assertEquals(2, results.values().size()); + assertEquals(40L, results.values().get(0).get(0)); + assertEquals(41L, results.values().get(0).get(1)); + assertEquals(42L, results.values().get(1).get(0)); + assertEquals(43L, results.values().get(1).get(1)); + } + + public void testTopNProjectEvalProject() { + EsqlQueryResponse results = run("from test | sort time | limit 2 | project count | eval x = count + 1 | project x"); + logger.info(results); + assertEquals(1, results.columns().size()); + assertEquals(new ColumnInfo("x", "long"), results.columns().get(0)); + assertEquals(2, results.values().size()); + assertEquals(41L, results.values().get(0).get(0)); + assertEquals(43L, results.values().get(1).get(0)); + } + public void testEmptyIndex() { ElasticsearchAssertions.assertAcked( client().admin().indices().prepareCreate("test_empty").setMapping("k", "type=keyword", "v", "type=long").get() diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java index a3646b19110b5..8aafcb89c4bf9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.expression.Expressions; import org.elasticsearch.xpack.ql.expression.FieldAttribute; +import org.elasticsearch.xpack.ql.expression.NamedExpression; import org.elasticsearch.xpack.ql.expression.predicate.Predicates; import org.elasticsearch.xpack.ql.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; @@ -126,32 +127,100 @@ protected PhysicalPlan rule(LocalPlanExec plan) { * Copy any limit/sort/topN in the local plan (before the exchange) after it so after gathering the data, * the limit still applies. */ - private static class LocalToGlobalLimitAndTopNExec extends Rule { + private static class LocalToGlobalLimitAndTopNExec extends OptimizerRule { - public PhysicalPlan apply(PhysicalPlan plan) { - return plan.transformUp(UnaryExec.class, u -> { - PhysicalPlan pl = u; - if (u.child()instanceof ExchangeExec exchange) { - var localLimit = findLocalLimitOrTopN(exchange); - if (localLimit != null) { - pl = localLimit.replaceChild(u); - } - } - return pl; - }); + private LocalToGlobalLimitAndTopNExec() { + super(OptimizerRules.TransformDirection.UP); + } + + @Override + protected PhysicalPlan rule(ExchangeExec exchange) { + if (exchange.getType() == ExchangeExec.Type.GATHER) { + return maybeAddGlobalLimitOrTopN(exchange); + } + return exchange; } - private UnaryExec findLocalLimitOrTopN(UnaryExec localPlan) { - for (var plan = localPlan.child();;) { - if (plan instanceof LimitExec || plan instanceof TopNExec) { - return (UnaryExec) plan; + /** + * This method copies any Limit/Sort/TopN in the local plan (before the exchange) after it, + * ensuring that all the inputs are available at that point + * eg. if between the exchange and the TopN there is a project that filters out + * some inputs needed by the topN (i.e. the sorting fields), this method also modifies + * the existing project to make these inputs available to the global TopN, and then adds + * another project at the end of the plan, to ensure that the original semantics + * are preserved. + * + * In detail: + *
    + *
  1. Traverse the plan down starting from the exchange, looking for the first Limit/Sort/TopN
  2. + *
  3. If a Limit is found, copy it after the Exchange to make it global limit
  4. + *
  5. If a TopN is found, copy it after the Exchange and ensure that it has all the inputs needed: + *
      + *
    1. Starting from the TopN, traverse the plan backwards and check that all the nodes propagate + * the inputs needed by the TopN
    2. + *
    3. If a Project node filters out some of the inputs needed by the TopN, + * replace it with another one that includes those inputs
    4. + *
    5. Copy the TopN after the exchange, to make it global
    6. + *
    7. If the outputs of the new global TopN are different from the outputs of the original Exchange, + * add another Project that filters out the unneeded outputs and preserves the original semantics
    8. + *
    + *
  6. + *
+ * @param exchange + * @return + */ + private PhysicalPlan maybeAddGlobalLimitOrTopN(ExchangeExec exchange) { + List visitedNodes = new ArrayList<>(); + visitedNodes.add(exchange); + AttributeSet exchangeOutputSet = exchange.outputSet(); + // step 1: traverse the plan and find Limit/TopN + for (var plan = exchange.child();;) { + if (plan instanceof LimitExec limit) { + // Step 2: just add a global Limit + return limit.replaceChild(exchange); + } + if (plan instanceof TopNExec topN) { + // Step 3: copy the TopN after the Exchange and ensure that it has all the inputs needed + Set requiredAttributes = Expressions.references(topN.order()).combine(topN.inputSet()); + if (exchangeOutputSet.containsAll(requiredAttributes)) { + return topN.replaceChild(exchange); + } + + PhysicalPlan subPlan = topN; + // Step 3.1: Traverse the plan backwards to check inputs available + for (int i = visitedNodes.size() - 1; i >= 0; i--) { + UnaryExec node = visitedNodes.get(i); + if (node instanceof ProjectExec proj && node.outputSet().containsAll(requiredAttributes) == false) { + // Step 3.2: a Project is filtering out some inputs needed by the global TopN, + // replace it with another one that preserves these inputs + List newProjections = new ArrayList<>(proj.projections()); + for (Attribute attr : requiredAttributes) { + if (newProjections.contains(attr) == false) { + newProjections.add(attr); + } + } + node = new ProjectExec(proj.source(), proj.child(), newProjections); + } + subPlan = node.replaceChild(subPlan); + } + + // Step 3.3: add the global TopN right after the exchange + topN = topN.replaceChild(subPlan); + if (exchangeOutputSet.containsAll(topN.output())) { + return topN; + } else { + // Step 3.4: the output propagation is leaking at the end of the plan, + // add one more Project to preserve the original query semantics + return new ProjectExec(topN.source(), topN, new ArrayList<>(exchangeOutputSet)); + } } - // possible to go deeper if (plan instanceof ProjectExec || plan instanceof EvalExec) { + visitedNodes.add((UnaryExec) plan); + // go deeper with step 1 plan = ((UnaryExec) plan).child(); } else { - // no limit specified - return null; + // no limit specified, return the original plan + return exchange; } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index 87a29d26193bc..d34f27db36733 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.esql.session.EsqlConfiguration; import org.elasticsearch.xpack.ql.expression.Expressions; import org.elasticsearch.xpack.ql.expression.FieldAttribute; +import org.elasticsearch.xpack.ql.expression.NamedExpression; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.ql.index.EsIndex; @@ -48,6 +49,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import static java.util.Arrays.asList; import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; @@ -639,7 +641,8 @@ public void testExtractorForEvalWithoutProject() throws Exception { | sort nullsum | limit 1 """)); - var topN = as(optimized, TopNExec.class); + var topProject = as(optimized, ProjectExec.class); + var topN = as(topProject.child(), TopNExec.class); var exchange = as(topN.child(), ExchangeExec.class); var project = as(exchange.child(), ProjectExec.class); var extract = as(project.child(), FieldExtractExec.class); @@ -647,6 +650,26 @@ public void testExtractorForEvalWithoutProject() throws Exception { var eval = as(topNLocal.child(), EvalExec.class); } + public void testProjectAfterTopN() throws Exception { + var optimized = optimizedPlan(physicalPlan(""" + from test + | sort emp_no + | project first_name + | limit 2 + """)); + var topProject = as(optimized, ProjectExec.class); + assertEquals(1, topProject.projections().size()); + assertEquals("first_name", topProject.projections().get(0).name()); + var topN = as(topProject.child(), TopNExec.class); + var exchange = as(topN.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + List projectionNames = project.projections().stream().map(NamedExpression::name).collect(Collectors.toList()); + assertTrue(projectionNames.containsAll(List.of("first_name", "emp_no"))); + var extract = as(project.child(), FieldExtractExec.class); + var topNLocal = as(extract.child(), TopNExec.class); + var fieldExtract = as(topNLocal.child(), FieldExtractExec.class); + } + private static EsQueryExec source(PhysicalPlan plan) { if (plan instanceof ExchangeExec exchange) { assertThat(exchange.getPartitioning(), is(ExchangeExec.Partitioning.FIXED_ARBITRARY_DISTRIBUTION));