@@ -2294,13 +2294,25 @@ impl Aggregate {
22942294 aggr_expr : Vec < Expr > ,
22952295 ) -> Result < Self > {
22962296 let group_expr = enumerate_grouping_sets ( group_expr) ?;
2297+
2298+ let is_grouping_set = matches ! ( group_expr. as_slice( ) , [ Expr :: GroupingSet ( _) ] ) ;
2299+
22972300 let grouping_expr: Vec < Expr > = grouping_set_to_exprlist ( group_expr. as_slice ( ) ) ?;
2298- let all_expr = grouping_expr. iter ( ) . chain ( aggr_expr. iter ( ) ) ;
22992301
2300- let schema = DFSchema :: new_with_metadata (
2301- exprlist_to_fields ( all_expr, & input) ?,
2302- input. schema ( ) . metadata ( ) . clone ( ) ,
2303- ) ?;
2302+ let mut fields = exprlist_to_fields ( grouping_expr. iter ( ) , & input) ?;
2303+
2304+ // Even columns that cannot be null will become nullable when used in a grouping set.
2305+ if is_grouping_set {
2306+ fields = fields
2307+ . into_iter ( )
2308+ . map ( |field| field. with_nullable ( true ) )
2309+ . collect :: < Vec < _ > > ( ) ;
2310+ }
2311+
2312+ fields. extend ( exprlist_to_fields ( aggr_expr. iter ( ) , & input) ?) ;
2313+
2314+ let schema =
2315+ DFSchema :: new_with_metadata ( fields, input. schema ( ) . metadata ( ) . clone ( ) ) ?;
23042316
23052317 Self :: try_new_with_schema ( input, group_expr, aggr_expr, Arc :: new ( schema) )
23062318 }
@@ -2539,7 +2551,7 @@ pub struct Unnest {
25392551mod tests {
25402552 use super :: * ;
25412553 use crate :: logical_plan:: table_scan;
2542- use crate :: { col, exists, in_subquery, lit, placeholder} ;
2554+ use crate :: { col, count , exists, in_subquery, lit, placeholder, GroupingSet } ;
25432555 use arrow:: datatypes:: { DataType , Field , Schema } ;
25442556 use datafusion_common:: tree_node:: TreeNodeVisitor ;
25452557 use datafusion_common:: { not_impl_err, DFSchema , TableReference } ;
@@ -3006,4 +3018,36 @@ digraph {
30063018 plan. replace_params_with_values ( & [ 42i32 . into ( ) ] )
30073019 . expect_err ( "unexpectedly succeeded to replace an invalid placeholder" ) ;
30083020 }
3021+
3022+ #[ test]
3023+ fn test_nullable_schema_after_grouping_set ( ) {
3024+ let schema = Schema :: new ( vec ! [
3025+ Field :: new( "foo" , DataType :: Int32 , false ) ,
3026+ Field :: new( "bar" , DataType :: Int32 , false ) ,
3027+ ] ) ;
3028+
3029+ let plan = table_scan ( TableReference :: none ( ) , & schema, None )
3030+ . unwrap ( )
3031+ . aggregate (
3032+ vec ! [ Expr :: GroupingSet ( GroupingSet :: GroupingSets ( vec![
3033+ vec![ col( "foo" ) ] ,
3034+ vec![ col( "bar" ) ] ,
3035+ ] ) ) ] ,
3036+ vec ! [ count( lit( true ) ) ] ,
3037+ )
3038+ . unwrap ( )
3039+ . build ( )
3040+ . unwrap ( ) ;
3041+
3042+ let output_schema = plan. schema ( ) ;
3043+
3044+ assert ! ( output_schema
3045+ . field_with_name( None , "foo" )
3046+ . unwrap( )
3047+ . is_nullable( ) , ) ;
3048+ assert ! ( output_schema
3049+ . field_with_name( None , "bar" )
3050+ . unwrap( )
3051+ . is_nullable( ) ) ;
3052+ }
30093053}
0 commit comments