Skip to content

Commit 8f48053

Browse files
authored
Minor: Make schema of grouping set columns nullable (#8248)
* Make output schema of aggregation grouping sets nullable * Improve * Fix tests
1 parent 76ced31 commit 8f48053

File tree

3 files changed

+57
-12
lines changed

3 files changed

+57
-12
lines changed

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,13 +2294,25 @@ impl Aggregate {
22942294
aggr_expr: Vec<Expr>,
22952295
) -> Result<Self> {
22962296
let group_expr = enumerate_grouping_sets(group_expr)?;
2297+
2298+
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
2299+
22972300
let grouping_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
2298-
let all_expr = grouping_expr.iter().chain(aggr_expr.iter());
22992301

2300-
let schema = DFSchema::new_with_metadata(
2301-
exprlist_to_fields(all_expr, &input)?,
2302-
input.schema().metadata().clone(),
2303-
)?;
2302+
let mut fields = exprlist_to_fields(grouping_expr.iter(), &input)?;
2303+
2304+
// Even columns that cannot be null will become nullable when used in a grouping set.
2305+
if is_grouping_set {
2306+
fields = fields
2307+
.into_iter()
2308+
.map(|field| field.with_nullable(true))
2309+
.collect::<Vec<_>>();
2310+
}
2311+
2312+
fields.extend(exprlist_to_fields(aggr_expr.iter(), &input)?);
2313+
2314+
let schema =
2315+
DFSchema::new_with_metadata(fields, input.schema().metadata().clone())?;
23042316

23052317
Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema))
23062318
}
@@ -2539,7 +2551,7 @@ pub struct Unnest {
25392551
mod tests {
25402552
use super::*;
25412553
use crate::logical_plan::table_scan;
2542-
use crate::{col, exists, in_subquery, lit, placeholder};
2554+
use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
25432555
use arrow::datatypes::{DataType, Field, Schema};
25442556
use datafusion_common::tree_node::TreeNodeVisitor;
25452557
use datafusion_common::{not_impl_err, DFSchema, TableReference};
@@ -3006,4 +3018,36 @@ digraph {
30063018
plan.replace_params_with_values(&[42i32.into()])
30073019
.expect_err("unexpectedly succeeded to replace an invalid placeholder");
30083020
}
3021+
3022+
#[test]
3023+
fn test_nullable_schema_after_grouping_set() {
3024+
let schema = Schema::new(vec![
3025+
Field::new("foo", DataType::Int32, false),
3026+
Field::new("bar", DataType::Int32, false),
3027+
]);
3028+
3029+
let plan = table_scan(TableReference::none(), &schema, None)
3030+
.unwrap()
3031+
.aggregate(
3032+
vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![
3033+
vec![col("foo")],
3034+
vec![col("bar")],
3035+
]))],
3036+
vec![count(lit(true))],
3037+
)
3038+
.unwrap()
3039+
.build()
3040+
.unwrap();
3041+
3042+
let output_schema = plan.schema();
3043+
3044+
assert!(output_schema
3045+
.field_with_name(None, "foo")
3046+
.unwrap()
3047+
.is_nullable(),);
3048+
assert!(output_schema
3049+
.field_with_name(None, "bar")
3050+
.unwrap()
3051+
.is_nullable());
3052+
}
30093053
}

datafusion/optimizer/src/single_distinct_to_groupby.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ mod tests {
322322
.build()?;
323323

324324
// Should not be optimized
325-
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
325+
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\
326326
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
327327

328328
assert_optimized_plan_equal(&plan, expected)
@@ -340,7 +340,7 @@ mod tests {
340340
.build()?;
341341

342342
// Should not be optimized
343-
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
343+
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\
344344
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
345345

346346
assert_optimized_plan_equal(&plan, expected)
@@ -359,7 +359,7 @@ mod tests {
359359
.build()?;
360360

361361
// Should not be optimized
362-
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
362+
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\
363363
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
364364

365365
assert_optimized_plan_equal(&plan, expected)

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2672,9 +2672,10 @@ query TT
26722672
EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
26732673
----
26742674
logical_plan
2675-
Limit: skip=0, fetch=3
2676-
--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
2677-
----TableScan: aggregate_test_100 projection=[c2, c3]
2675+
Projection: aggregate_test_100.c2, aggregate_test_100.c3
2676+
--Limit: skip=0, fetch=3
2677+
----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
2678+
------TableScan: aggregate_test_100 projection=[c2, c3]
26782679
physical_plan
26792680
GlobalLimitExec: skip=0, fetch=3
26802681
--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3]

0 commit comments

Comments
 (0)