diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java index 32f348367c7a0..bdfba82384c67 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java @@ -873,6 +873,43 @@ public void testWithUnsupportedFieldsAndConflicts() { assertTrue(e.getMessage().contains("Column [embedding] has conflicting data types")); } + public void testValidationsAfterFork() { + var firstQuery = """ + FROM test* + | FORK ( WHERE true ) + ( WHERE true ) + | DROP _fork + | STATS a = count_distinct(embedding) + """; + + var e = expectThrows(VerificationException.class, () -> run(firstQuery)); + assertTrue( + e.getMessage().contains("[count_distinct(embedding)] must be [any exact type except unsigned_long, _source, or counter types]") + ); + + var secondQuery = """ + FROM test* + | FORK ( WHERE true ) + ( WHERE true ) + | DROP _fork + | EVAL a = substring(1, 2, 3) + """; + + e = expectThrows(VerificationException.class, () -> run(secondQuery)); + assertTrue(e.getMessage().contains("first argument of [substring(1, 2, 3)] must be [string], found value [1] type [integer]")); + + var thirdQuery = """ + FROM test* + | FORK ( WHERE true ) + ( WHERE true ) + | DROP _fork + | EVAL a = b + 2 + """; + + e = expectThrows(VerificationException.class, () -> run(thirdQuery)); + assertTrue(e.getMessage().contains("Unknown column [b]")); + } + public void testWithEvalWithConflictingTypes() { var query = """ FROM test diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 0f0331992416d..f48c95397dcab 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -788,10 +788,8 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) { } List subPlanColumns = logicalPlan.output().stream().map(Attribute::name).toList(); - // We need to add an explicit Keep even if the outputs align - // This is because at the moment the sub plans are executed and optimized separately and the output might change - // during optimizations. Once we add streaming we might not need to add a Keep when the outputs already align. - if (logicalPlan instanceof Keep == false || subPlanColumns.equals(forkColumns) == false) { + // We need to add an explicit EsqlProject to align the outputs. + if (logicalPlan instanceof Project == false || subPlanColumns.equals(forkColumns) == false) { changed = true; List newOutput = new ArrayList<>(); for (String attrName : forkColumns) { @@ -801,7 +799,7 @@ private LogicalPlan resolveFork(Fork fork, AnalyzerContext context) { } } } - logicalPlan = new Keep(logicalPlan.source(), logicalPlan, newOutput); + logicalPlan = resolveKeep(new Keep(logicalPlan.source(), logicalPlan, newOutput), logicalPlan.output()); } newSubPlans.add(logicalPlan); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 0c7cdf4855196..b99050b8ef090 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -74,7 +74,6 @@ import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Fork; import org.elasticsearch.xpack.esql.plan.logical.Insist; -import org.elasticsearch.xpack.esql.plan.logical.Keep; import org.elasticsearch.xpack.esql.plan.logical.Limit; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Lookup; @@ -3090,27 +3089,27 @@ public void testBasicFork() { // fork branch 1 limit = as(subPlans.get(0), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT)); - Keep keep = as(limit.child(), Keep.class); - List keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); - assertThat(keptColumns, equalTo(expectedOutput)); - Eval eval = as(keep.child(), Eval.class); + EsqlProject project = as(limit.child(), EsqlProject.class); + List projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); + assertThat(projectColumns, equalTo(expectedOutput)); + Eval eval = as(project.child(), Eval.class); assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork1")))); Filter filter = as(eval.child(), Filter.class); assertThat(as(filter.condition(), GreaterThan.class).right(), equalTo(literal(1))); filter = as(filter.child(), Filter.class); assertThat(as(filter.condition(), Equals.class).right(), equalTo(string("Chris"))); - EsqlProject project = as(filter.child(), EsqlProject.class); + project = as(filter.child(), EsqlProject.class); var esRelation = as(project.child(), EsRelation.class); assertThat(esRelation.indexPattern(), equalTo("test")); // fork branch 2 limit = as(subPlans.get(1), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT)); - keep = as(limit.child(), Keep.class); - keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); - assertThat(keptColumns, equalTo(expectedOutput)); - eval = as(keep.child(), Eval.class); + project = as(limit.child(), EsqlProject.class); + projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); + assertThat(projectColumns, equalTo(expectedOutput)); + eval = as(project.child(), Eval.class); assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork2")))); filter = as(eval.child(), Filter.class); assertThat(as(filter.condition(), GreaterThan.class).right(), equalTo(literal(2))); @@ -3124,10 +3123,10 @@ public void testBasicFork() { // fork branch 3 limit = as(subPlans.get(2), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT)); - keep = as(limit.child(), Keep.class); - keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); - assertThat(keptColumns, equalTo(expectedOutput)); - eval = as(keep.child(), Eval.class); + project = as(limit.child(), EsqlProject.class); + projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); + assertThat(projectColumns, equalTo(expectedOutput)); + eval = as(project.child(), Eval.class); assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork3")))); limit = as(eval.child(), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(7)); @@ -3143,10 +3142,10 @@ public void testBasicFork() { // fork branch 4 limit = as(subPlans.get(3), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT)); - keep = as(limit.child(), Keep.class); - keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); - assertThat(keptColumns, equalTo(expectedOutput)); - eval = as(keep.child(), Eval.class); + project = as(limit.child(), EsqlProject.class); + projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); + assertThat(projectColumns, equalTo(expectedOutput)); + eval = as(project.child(), Eval.class); assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork4")))); orderBy = as(eval.child(), OrderBy.class); filter = as(orderBy.child(), Filter.class); @@ -3158,10 +3157,10 @@ public void testBasicFork() { // fork branch 5 limit = as(subPlans.get(4), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT)); - keep = as(limit.child(), Keep.class); - keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); - assertThat(keptColumns, equalTo(expectedOutput)); - eval = as(keep.child(), Eval.class); + project = as(limit.child(), EsqlProject.class); + projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); + assertThat(projectColumns, equalTo(expectedOutput)); + eval = as(project.child(), Eval.class); assertThat(as(eval.fields().get(0), Alias.class), equalTo(alias("_fork", string("fork5")))); limit = as(eval.child(), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(9)); @@ -3193,11 +3192,11 @@ public void testForkBranchesWithDifferentSchemas() { // fork branch 1 limit = as(subPlans.get(0), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT)); - Keep keep = as(limit.child(), Keep.class); - List keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); - assertThat(keptColumns, equalTo(expectedOutput)); + EsqlProject project = as(limit.child(), EsqlProject.class); + List projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); + assertThat(projectColumns, equalTo(expectedOutput)); - Eval eval = as(keep.child(), Eval.class); + Eval eval = as(project.child(), Eval.class); assertEquals(eval.fields().size(), 3); Set evalFieldNames = eval.fields().stream().map(a -> a.name()).collect(Collectors.toSet()); @@ -3215,7 +3214,7 @@ public void testForkBranchesWithDifferentSchemas() { Filter filter = as(orderBy.child(), Filter.class); assertThat(as(filter.condition(), GreaterThan.class).right(), equalTo(literal(3))); - EsqlProject project = as(filter.child(), EsqlProject.class); + project = as(filter.child(), EsqlProject.class); filter = as(project.child(), Filter.class); assertThat(as(filter.condition(), Equals.class).right(), equalTo(string("Chris"))); var esRelation = as(filter.child(), EsRelation.class); @@ -3224,10 +3223,10 @@ public void testForkBranchesWithDifferentSchemas() { // fork branch 2 limit = as(subPlans.get(1), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT)); - keep = as(limit.child(), Keep.class); - keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); - assertThat(keptColumns, equalTo(expectedOutput)); - eval = as(keep.child(), Eval.class); + project = as(limit.child(), EsqlProject.class); + projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); + assertThat(projectColumns, equalTo(expectedOutput)); + eval = as(project.child(), Eval.class); assertEquals(eval.fields().size(), 2); evalFieldNames = eval.fields().stream().map(a -> a.name()).collect(Collectors.toSet()); assertThat(evalFieldNames, equalTo(Set.of("x", "y"))); @@ -3254,10 +3253,10 @@ public void testForkBranchesWithDifferentSchemas() { // fork branch 3 limit = as(subPlans.get(2), Limit.class); assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT)); - keep = as(limit.child(), Keep.class); - keptColumns = keep.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); - assertThat(keptColumns, equalTo(expectedOutput)); - eval = as(keep.child(), Eval.class); + project = as(limit.child(), EsqlProject.class); + projectColumns = project.expressions().stream().map(exp -> as(exp, Attribute.class).name()).toList(); + assertThat(projectColumns, equalTo(expectedOutput)); + eval = as(project.child(), Eval.class); assertEquals(eval.fields().size(), 2); evalFieldNames = eval.fields().stream().map(a -> a.name()).collect(Collectors.toSet()); assertThat(evalFieldNames, equalTo(Set.of("emp_no", "first_name")));