Skip to content

Commit 111e4e1

Browse files
committed
SQL: fix multi full-text functions usage with aggregate functions (#47444)
* Skip functions involving full-text predicates when replacing multiple aggregate functions with "stats" or "matrix_stats" aggregations. (cherry picked from commit bb14ba8)
1 parent abeca45 commit 111e4e1

File tree

3 files changed

+146
-3
lines changed

3 files changed

+146
-3
lines changed

x-pack/plugin/sql/qa/src/main/resources/fulltext.csv-spec

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,64 @@ SELECT emp_no, first_name, SCORE() as s FROM test_emp WHERE MATCH(first_name, 'E
139139
emp_no:i | first_name:s | s:f
140140
10076 |Erez |4.1053944
141141
;
142+
143+
//
144+
// Mixture of aggs that triggers promotion of aggs to stats, combined with multiple full-text filters
145+
//
146+
multiAggWithCountMatchAndQuery
147+
SELECT MIN(salary) min, MAX(salary) max, gender g, COUNT(*) c FROM "test_emp" WHERE languages > 0 AND (MATCH(gender, 'F') OR MATCH(gender, 'M')) AND QUERY('M*', 'default_field=last_name;lenient=true', 'fuzzy_rewrite=scoring_boolean') GROUP BY g HAVING max > 50000 ORDER BY gender;
148+
149+
min:i | max:i | g:s | c:l
150+
---------------+---------------+---------------+---------------
151+
37112 |69904 |F |3
152+
32568 |70011 |M |8
153+
;
154+
155+
multiAggWithCountAndMultiMatch
156+
SELECT MIN(salary) min, MAX(salary) max, gender g, COUNT(*) c FROM "test_emp" WHERE MATCH(gender, 'F') OR MATCH(gender, 'M') GROUP BY g HAVING max > 50000 ORDER BY gender;
157+
158+
min:i | max:i | g:s | c:l
159+
---------------+---------------+---------------+---------------
160+
25976 |74572 |F |33
161+
25945 |74999 |M |57
162+
;
163+
164+
multiAggWithMultiMatchOrderByCount
165+
SELECT MIN(salary) min, MAX(salary) max, ROUND(AVG(salary)) avg, gender g, COUNT(*) c FROM "test_emp" WHERE MATCH(gender, 'F') OR MATCH('first_name^3,last_name^5', 'geo hir', 'fuzziness=2;operator=or') GROUP BY g ORDER BY c DESC;
166+
167+
min:i | max:i | avg:d | g:s | c:l
168+
---------------+---------------+---------------+---------------+---------------
169+
25976 |74572 |50491 |F |33
170+
32568 |32568 |32568 |M |1
171+
;
172+
173+
multiAggWithMultiMatchOrderByCountAndSimpleCondition
174+
SELECT MIN(salary) min, MAX(salary) max, ROUND(AVG(salary)) avg, gender g, COUNT(*) c FROM "test_emp" WHERE (MATCH(gender, 'F') AND languages > 4) OR MATCH('first_name^3,last_name^5', 'geo hir', 'fuzziness=2;operator=or') GROUP BY g ORDER BY c DESC;
175+
176+
min:i | max:i | avg:d | g:s | c:l
177+
---------------+---------------+---------------+---------------+---------------
178+
32272 |66817 |48081 |F |11
179+
32568 |32568 |32568 |M |1
180+
;
181+
182+
multiAggWithPercentileAndMultiQuery
183+
SELECT languages, PERCENTILE(salary, 95) "95th", ROUND(PERCENTILE_RANK(salary, 65000)) AS rank, MAX(salary), MIN(salary), COUNT(*) c FROM test_emp WHERE QUERY('A*','default_field=first_name') OR QUERY('B*', 'default_field=first_name') OR languages IS NULL GROUP BY languages;
184+
185+
languages:bt | 95th:d | rank:d | MAX(salary):i | MIN(salary):i | c:l
186+
---------------+---------------+---------------+---------------+---------------+---------------
187+
null |74999 |74 |74999 |28336 |10
188+
2 |44307 |100 |44307 |29175 |3
189+
3 |65030 |100 |65030 |38376 |4
190+
5 |66817 |100 |66817 |37137 |4
191+
;
192+
193+
multiAggWithStatsAndMatrixStatsAndMultiQuery
194+
SELECT languages, KURTOSIS(salary) k, SKEWNESS(salary) s, MAX(salary), MIN(salary), COUNT(*) c FROM test_emp WHERE QUERY('A*','default_field=first_name') OR QUERY('B*', 'default_field=first_name') OR languages IS NULL GROUP BY languages;
195+
196+
languages:bt | k:d | s:d | MAX(salary):i | MIN(salary):i | c:l
197+
---------------+------------------+-------------------+---------------+---------------+---------------
198+
null |1.9161749939033146|0.1480828817161133 |74999 |28336 |10
199+
2 |1.5000000000000002|0.484743245141609 |44307 |29175 |3
200+
3 |1.0732551278666582|0.05483979801873433|65030 |38376 |4
201+
5 |1.322529094661261 |0.24501477738153868|66817 |37137 |4
202+
;

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Case;
5252
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Coalesce;
5353
import org.elasticsearch.xpack.sql.expression.predicate.conditional.IfConditional;
54+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.FullTextPredicate;
5455
import org.elasticsearch.xpack.sql.expression.predicate.logical.And;
5556
import org.elasticsearch.xpack.sql.expression.predicate.logical.Not;
5657
import org.elasticsearch.xpack.sql.expression.predicate.logical.Or;
@@ -451,11 +452,11 @@ static LogicalPlan updateAggAttributes(LogicalPlan p, Map<String, AggregateFunct
451452
}
452453
}
453454

454-
else if (e instanceof ScalarFunction) {
455+
else if (e instanceof ScalarFunction && false == Expressions.anyMatch(e.children(), c -> c instanceof FullTextPredicate)) {
455456
ScalarFunction sf = (ScalarFunction) e;
456457

457458
// if it's an unseen function, check whether the function's children/arguments refer to any of the promoted aggs
458-
if (!updatedScalarAttrs.containsKey(sf.functionId()) && e.anyMatch(c -> {
459+
if (newAggIds.isEmpty() == false && !updatedScalarAttrs.containsKey(sf.functionId()) && e.anyMatch(c -> {
459460
Attribute a = Expressions.attribute(c);
460461
if (a instanceof FunctionAttribute) {
461462
return newAggIds.contains(((FunctionAttribute) a).functionId());

x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,19 @@
2020
import org.elasticsearch.xpack.sql.expression.Order.OrderDirection;
2121
import org.elasticsearch.xpack.sql.expression.function.Function;
2222
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction;
23+
import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg;
2324
import org.elasticsearch.xpack.sql.expression.function.aggregate.Count;
25+
import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStats;
2426
import org.elasticsearch.xpack.sql.expression.function.aggregate.First;
27+
import org.elasticsearch.xpack.sql.expression.function.aggregate.InnerAggregate;
2528
import org.elasticsearch.xpack.sql.expression.function.aggregate.Last;
2629
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
2730
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
31+
import org.elasticsearch.xpack.sql.expression.function.aggregate.Stats;
32+
import org.elasticsearch.xpack.sql.expression.function.aggregate.StddevPop;
33+
import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum;
34+
import org.elasticsearch.xpack.sql.expression.function.aggregate.SumOfSquares;
35+
import org.elasticsearch.xpack.sql.expression.function.aggregate.VarPop;
2836
import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
2937
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayName;
3038
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfMonth;
@@ -55,7 +63,12 @@
5563
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Iif;
5664
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Least;
5765
import org.elasticsearch.xpack.sql.expression.predicate.conditional.NullIf;
66+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.FullTextPredicate;
67+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate;
68+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MultiMatchQueryPredicate;
69+
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.StringQueryPredicate;
5870
import org.elasticsearch.xpack.sql.expression.predicate.logical.And;
71+
import org.elasticsearch.xpack.sql.expression.predicate.logical.BinaryLogic;
5972
import org.elasticsearch.xpack.sql.expression.predicate.logical.Not;
6073
import org.elasticsearch.xpack.sql.expression.predicate.logical.Or;
6174
import org.elasticsearch.xpack.sql.expression.predicate.nulls.IsNotNull;
@@ -85,6 +98,8 @@
8598
import org.elasticsearch.xpack.sql.optimizer.Optimizer.FoldNull;
8699
import org.elasticsearch.xpack.sql.optimizer.Optimizer.PropagateEquals;
87100
import org.elasticsearch.xpack.sql.optimizer.Optimizer.PruneDuplicateFunctions;
101+
import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceAggsWithExtendedStats;
102+
import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceAggsWithStats;
88103
import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceFoldableAttributes;
89104
import org.elasticsearch.xpack.sql.optimizer.Optimizer.ReplaceMinMaxWithTopHits;
90105
import org.elasticsearch.xpack.sql.optimizer.Optimizer.SimplifyCase;
@@ -1498,4 +1513,70 @@ public void testSortAggregateOnOrderByOnlyAliases() {
14981513
assertEquals(firstAlias, groupings.get(0));
14991514
assertEquals(secondAlias, groupings.get(1));
15001515
}
1501-
}
1516+
1517+
/**
1518+
* Test queries like SELECT MIN(agg_field), MAX(agg_field) FROM table WHERE MATCH(match_field,'A') AND/OR QUERY('match_field:A')
1519+
* or SELECT STDDEV_POP(agg_field), VAR_POP(agg_field) FROM table WHERE MATCH(match_field,'A') AND/OR QUERY('match_field:A')
1520+
*/
1521+
public void testAggregatesPromoteToStats_WithFullTextPredicatesConditions() {
1522+
FieldAttribute matchField = new FieldAttribute(EMPTY, "match_field", new EsField("match_field", DataType.TEXT, emptyMap(), true));
1523+
FieldAttribute aggField = new FieldAttribute(EMPTY, "agg_field", new EsField("agg_field", DataType.INTEGER, emptyMap(), true));
1524+
1525+
FullTextPredicate matchPredicate = new MatchQueryPredicate(EMPTY, matchField, "A", StringUtils.EMPTY);
1526+
FullTextPredicate multiMatchPredicate = new MultiMatchQueryPredicate(EMPTY, "match_field", "A", StringUtils.EMPTY);
1527+
FullTextPredicate stringQueryPredicate = new StringQueryPredicate(EMPTY, "match_field:A", StringUtils.EMPTY);
1528+
List<FullTextPredicate> predicates = Arrays.asList(matchPredicate, multiMatchPredicate, stringQueryPredicate);
1529+
1530+
FullTextPredicate left = randomFrom(predicates);
1531+
FullTextPredicate right = randomFrom(predicates);
1532+
1533+
BinaryLogic or = new Or(EMPTY, left, right);
1534+
BinaryLogic and = new And(EMPTY, left, right);
1535+
BinaryLogic condition = randomFrom(or, and);
1536+
Filter filter = new Filter(EMPTY, FROM(), condition);
1537+
1538+
List<AggregateFunction> aggregates;
1539+
boolean isSimpleStats = randomBoolean();
1540+
if (isSimpleStats) {
1541+
aggregates = Arrays.asList(new Avg(EMPTY, aggField), new Sum(EMPTY, aggField), new Min(EMPTY, aggField),
1542+
new Max(EMPTY, aggField));
1543+
} else {
1544+
aggregates = Arrays.asList(new StddevPop(EMPTY, aggField), new SumOfSquares(EMPTY, aggField), new VarPop(EMPTY, aggField));
1545+
}
1546+
AggregateFunction firstAggregate = randomFrom(aggregates);
1547+
AggregateFunction secondAggregate = randomValueOtherThan(firstAggregate, () -> randomFrom(aggregates));
1548+
Aggregate aggregatePlan = new Aggregate(EMPTY, filter, Collections.singletonList(matchField),
1549+
Arrays.asList(firstAggregate, secondAggregate));
1550+
LogicalPlan result;
1551+
if (isSimpleStats) {
1552+
result = new ReplaceAggsWithStats().apply(aggregatePlan);
1553+
} else {
1554+
result = new ReplaceAggsWithExtendedStats().apply(aggregatePlan);
1555+
}
1556+
1557+
assertTrue(result instanceof Aggregate);
1558+
Aggregate resultAgg = (Aggregate) result;
1559+
assertEquals(2, resultAgg.aggregates().size());
1560+
assertTrue(resultAgg.aggregates().get(0) instanceof InnerAggregate);
1561+
assertTrue(resultAgg.aggregates().get(1) instanceof InnerAggregate);
1562+
1563+
InnerAggregate resultFirstAgg = (InnerAggregate) resultAgg.aggregates().get(0);
1564+
InnerAggregate resultSecondAgg = (InnerAggregate) resultAgg.aggregates().get(1);
1565+
assertEquals(resultFirstAgg.inner(), firstAggregate);
1566+
assertEquals(resultSecondAgg.inner(), secondAggregate);
1567+
if (isSimpleStats) {
1568+
assertTrue(resultFirstAgg.outer() instanceof Stats);
1569+
assertTrue(resultSecondAgg.outer() instanceof Stats);
1570+
assertEquals(((Stats) resultFirstAgg.outer()).field(), aggField);
1571+
assertEquals(((Stats) resultSecondAgg.outer()).field(), aggField);
1572+
} else {
1573+
assertTrue(resultFirstAgg.outer() instanceof ExtendedStats);
1574+
assertTrue(resultSecondAgg.outer() instanceof ExtendedStats);
1575+
assertEquals(((ExtendedStats) resultFirstAgg.outer()).field(), aggField);
1576+
assertEquals(((ExtendedStats) resultSecondAgg.outer()).field(), aggField);
1577+
}
1578+
1579+
assertTrue(resultAgg.child() instanceof Filter);
1580+
assertEquals(resultAgg.child(), filter);
1581+
}
1582+
}

0 commit comments

Comments
 (0)