Skip to content

Commit

Permalink
Make group expressions nullable more accurate (#12256)
Browse files Browse the repository at this point in the history
* Make group expressions nullable more accurate

* Add test
  • Loading branch information
lewiszlw authored Sep 2, 2024
1 parent dd32089 commit 53de592
Showing 1 changed file with 61 additions and 8 deletions.
69 changes: 61 additions & 8 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,17 @@ impl PhysicalGroupBy {
}
}

/// Returns true if this GROUP BY contains NULL expressions
pub fn contains_null(&self) -> bool {
self.groups.iter().flatten().any(|is_null| *is_null)
/// Computes, for every GROUP BY expression, whether it must be treated as nullable.
///
/// The returned vector has one entry per element of `self.expr`. Entry `i` is
/// `true` when at least one group in `self.groups` has its flag at position `i`
/// set, and `false` otherwise.
///
/// # Panics
///
/// Panics if a group carries a set flag at a position that is not a valid
/// index into `self.expr`.
pub fn exprs_nullable(&self) -> Vec<bool> {
    let mut nullable = vec![false; self.expr.len()];
    for group in &self.groups {
        for (position, &is_null) in group.iter().enumerate() {
            // Only touch the slot when the flag is set, so unset flags never index.
            if is_null {
                nullable[position] = true;
            }
        }
    }
    nullable
}

/// Returns the group expressions
Expand Down Expand Up @@ -278,7 +286,7 @@ pub struct AggregateExec {
}

impl AggregateExec {
/// Function used in `ConvertFirstLast` optimizer rule,
/// Function used in `OptimizeAggregateOrder` optimizer rule,
/// where we need parts of the new value, others cloned from the old one
/// Rewrites aggregate exec with new aggregate expressions.
pub fn with_new_aggr_exprs(
Expand Down Expand Up @@ -319,7 +327,7 @@ impl AggregateExec {
&input.schema(),
&group_by.expr,
&aggr_expr,
group_by.contains_null(),
group_by.exprs_nullable(),
mode,
)?;

Expand Down Expand Up @@ -793,18 +801,18 @@ fn create_schema(
input_schema: &Schema,
group_expr: &[(Arc<dyn PhysicalExpr>, String)],
aggr_expr: &[Arc<AggregateFunctionExpr>],
contains_null_expr: bool,
group_expr_nullable: Vec<bool>,
mode: AggregateMode,
) -> Result<Schema> {
let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
for (expr, name) in group_expr {
for (index, (expr, name)) in group_expr.iter().enumerate() {
fields.push(Field::new(
name,
expr.data_type(input_schema)?,
// In cases where we have multiple grouping sets, we will use NULL expressions in
// order to align the grouping sets. So the field must be nullable even if the underlying
// schema field is not.
contains_null_expr || expr.nullable(input_schema)?,
group_expr_nullable[index] || expr.nullable(input_schema)?,
))
}

Expand Down Expand Up @@ -2489,4 +2497,49 @@ mod tests {

Ok(())
}

/// Checks that `create_schema` marks a group column nullable only when that
/// column is flagged as NULL in at least one grouping set.
#[test]
fn group_exprs_nullable() -> Result<()> {
    // Both input columns are declared non-nullable.
    let input_fields = vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, false),
    ];
    let schema = Arc::new(Schema::new(input_fields));

    let count_a = AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("COUNT(a)")
        .build()?;
    let aggregates = vec![count_a];

    let group_by = PhysicalGroupBy {
        expr: vec![
            (col("a", &schema)?, "a".to_string()),
            (col("b", &schema)?, "b".to_string()),
        ],
        null_expr: vec![
            (lit(ScalarValue::Float32(None)), "a".to_string()),
            (lit(ScalarValue::Float32(None)), "b".to_string()),
        ],
        groups: vec![
            // Grouping set (a, NULL): `b` is flagged here.
            vec![false, true],
            // Grouping set (a, b): nothing is flagged.
            vec![false, false],
        ],
    };

    let actual = create_schema(
        &schema,
        &group_by.expr,
        &aggregates,
        group_by.exprs_nullable(),
        AggregateMode::Final,
    )?;

    // `a` is never flagged, so it stays non-nullable; `b` is flagged once, so it is nullable.
    let expected = Schema::new(vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, true),
        Field::new("COUNT(a)", DataType::Int64, false),
    ]);
    assert_eq!(actual, expected);

    Ok(())
}
}

0 comments on commit 53de592

Please sign in to comment.