Skip to content

Commit

Permalink
Infer count() aggregation is not null
Browse files Browse the repository at this point in the history
`count([DISTINCT [expr]])` aggregate function never returns null. Infer
non-nullness of such aggregate expression. This allows elimination of
the HAVING filter for a query such as

    SELECT ... count(*) AS c
    FROM ...
    GROUP BY ...
    HAVING c IS NOT NULL
  • Loading branch information
findepi committed Jul 4, 2024
1 parent f7ceed3 commit 53b34ae
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 43 deletions.
8 changes: 7 additions & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,16 @@ impl ExprSchemable for Expr {
}
}
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::AggregateFunction(AggregateFunction { func_def, .. }) => {
Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => {
match func_def {
// TODO: min/max over non-nullable input should be recognized as non-nullable
AggregateFunctionDefinition::BuiltIn(fun) => fun.nullable(),
// TODO: UDF should be able to customize nullability
AggregateFunctionDefinition::UDF(udf)
if udf.name() == "count" && args.len() <= 1 =>
{
Ok(false)
}
AggregateFunctionDefinition::UDF(_) => Ok(true),
}
}
Expand Down
32 changes: 16 additions & 16 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ mod tests {
.project(vec![count(wildcard())])?
.sort(vec![count(wildcard()).sort(true, false)])?
.build()?;
let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64;N]\
\n Projection: count(*) [count(*):Int64;N]\
\n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64;N]\
let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
\n Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
Expand All @@ -152,9 +152,9 @@ mod tests {
.build()?;

let expected = "Filter: t1.a IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [count(*):Int64;N]\
\n Projection: count(*) [count(*):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\
\n Subquery: [count(*):Int64]\
\n Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
Expand All @@ -175,9 +175,9 @@ mod tests {
.build()?;

let expected = "Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [count(*):Int64;N]\
\n Projection: count(*) [count(*):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\
\n Subquery: [count(*):Int64]\
\n Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
Expand Down Expand Up @@ -207,9 +207,9 @@ mod tests {

let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\
\n Filter: (<subquery>) > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [count(Int64(1)):Int64;N]\
\n Projection: count(Int64(1)) [count(Int64(1)):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64;N]\
\n Subquery: [count(Int64(1)):Int64]\
\n Projection: count(Int64(1)) [count(Int64(1)):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64]\
\n Filter: outer_ref(t1.a) = t2.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
Expand All @@ -235,7 +235,7 @@ mod tests {
.project(vec![count(wildcard())])?
.build()?;

let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64;N]\
let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\
\n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
Expand All @@ -249,8 +249,8 @@ mod tests {
.project(vec![count(wildcard())])?
.build()?;

let expected = "Projection: count(*) [count(*):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\
let expected = "Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
Expand All @@ -272,7 +272,7 @@ mod tests {
.project(vec![count(wildcard())])?
.build()?;

let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64;N]\
let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[MAX(count(Int64(1))) AS MAX(count(*))]] [MAX(count(*)):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
Expand Down
52 changes: 26 additions & 26 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\
let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\
\n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -419,7 +419,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -437,7 +437,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -456,7 +456,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -470,8 +470,8 @@ mod tests {
.aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
.build()?;

let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\
let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\
\n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -487,8 +487,8 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64;N]\
let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -507,7 +507,7 @@ mod tests {
.build()?;

// Do nothing
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(DISTINCT test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -534,8 +534,8 @@ mod tests {
)?
.build()?;
// Should work
let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), MAX(alias1)]] [a:UInt32, count(alias1):Int64;N, MAX(alias1):UInt32;N]\
let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), MAX(alias1)]] [a:UInt32, count(alias1):Int64, MAX(alias1):UInt32;N]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -554,7 +554,7 @@ mod tests {
.build()?;

// Do nothing
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(test.c):Int64;N]\
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -569,8 +569,8 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64;N]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\
let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand Down Expand Up @@ -599,8 +599,8 @@ mod tests {
)?
.build()?;
// Should work
let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64;N, MAX(alias1):UInt32;N]\
let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, MAX(alias1):UInt32;N]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -618,8 +618,8 @@ mod tests {
)?
.build()?;
// Should work
let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, count(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, count(alias1):Int64;N]\
let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, count(DISTINCT test.b):Int64]\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -637,8 +637,8 @@ mod tests {
)?
.build()?;
// Should work
let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, count(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), count(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, count(alias1):Int64;N]\
let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, count(DISTINCT test.b):Int64]\
\n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), count(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, count(alias1):Int64]\
\n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

Expand All @@ -662,7 +662,7 @@ mod tests {
.aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -682,7 +682,7 @@ mod tests {
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -705,7 +705,7 @@ mod tests {
.aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -725,7 +725,7 @@ mod tests {
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -746,7 +746,7 @@ mod tests {
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand Down
Loading

0 comments on commit 53b34ae

Please sign in to comment.