Skip to content

Commit

Permalink
Implement IGNORE NULLS for FIRST_VALUE (#9411)
Browse files Browse the repository at this point in the history
* Implement IGNORE NULLS for FIRST_VALUE

* fix style

* fix clippy error

* fix clippy error

* address comments

* fix error

* add test to aggregate.slt

* address comments

* Trigger Build

* Add one additional column in order by to ensure a deterministic order in the output

---------

Co-authored-by: Huaxin Gao <huaxin.gao@apple.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
3 people authored Mar 5, 2024
1 parent 2873fd0 commit 3aba67e
Show file tree
Hide file tree
Showing 25 changed files with 286 additions and 39 deletions.
6 changes: 6 additions & 0 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
args,
filter,
order_by,
null_treatment: _,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(..) => {
create_function_physical_name(func_def.name(), *distinct, args)
Expand Down Expand Up @@ -1662,6 +1663,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
args,
filter,
order_by,
null_treatment,
}) => {
let args = args
.iter()
Expand Down Expand Up @@ -1689,6 +1691,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
),
None => None,
};
let ignore_nulls = null_treatment
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let ordering_reqs = order_by.clone().unwrap_or(vec![]);
Expand All @@ -1699,6 +1704,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
&ordering_reqs,
physical_input_schema,
name,
ignore_nulls,
)?;
(agg_expr, filter, order_by)
}
Expand Down
80 changes: 80 additions & 0 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,83 @@ async fn test_accumulator_row_accumulator() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_first_value() -> Result<()> {
let session_ctx = SessionContext::new();
session_ctx
.sql("CREATE TABLE abc AS VALUES (null,2,3), (4,5,6)")
.await?
.collect()
.await?;

let results1 = session_ctx
.sql("SELECT FIRST_VALUE(column1) ignore nulls FROM abc")
.await?
.collect()
.await?;
let expected1 = [
"+--------------------------+",
"| FIRST_VALUE(abc.column1) |",
"+--------------------------+",
"| 4 |",
"+--------------------------+",
];
assert_batches_eq!(expected1, &results1);

let results2 = session_ctx
.sql("SELECT FIRST_VALUE(column1) respect nulls FROM abc")
.await?
.collect()
.await?;
let expected2 = [
"+--------------------------+",
"| FIRST_VALUE(abc.column1) |",
"+--------------------------+",
"| |",
"+--------------------------+",
];
assert_batches_eq!(expected2, &results2);

Ok(())
}

#[tokio::test]
async fn test_first_value_with_sort() -> Result<()> {
let session_ctx = SessionContext::new();
session_ctx
.sql("CREATE TABLE abc AS VALUES (null,2,3), (null,1,6), (4, 5, 5), (1, 4, 7), (2, 3, 8)")
.await?
.collect()
.await?;

let results1 = session_ctx
.sql("SELECT FIRST_VALUE(column1 ORDER BY column2) ignore nulls FROM abc")
.await?
.collect()
.await?;
let expected1 = [
"+--------------------------+",
"| FIRST_VALUE(abc.column1) |",
"+--------------------------+",
"| 2 |",
"+--------------------------+",
];
assert_batches_eq!(expected1, &results1);

let results2 = session_ctx
.sql("SELECT FIRST_VALUE(column1 ORDER BY column2) respect nulls FROM abc")
.await?
.collect()
.await?;
let expected2 = [
"+--------------------------+",
"| FIRST_VALUE(abc.column1) |",
"+--------------------------+",
"| |",
"+--------------------------+",
];
assert_batches_eq!(expected2, &results2);

Ok(())
}
13 changes: 13 additions & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ pub struct AggregateFunction {
pub filter: Option<Box<Expr>>,
/// Optional ordering
pub order_by: Option<Vec<Expr>>,
pub null_treatment: Option<NullTreatment>,
}

impl AggregateFunction {
Expand All @@ -552,13 +553,15 @@ impl AggregateFunction {
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
null_treatment: Option<NullTreatment>,
) -> Self {
Self {
func_def: AggregateFunctionDefinition::BuiltIn(fun),
args,
distinct,
filter,
order_by,
null_treatment,
}
}

Expand All @@ -576,6 +579,7 @@ impl AggregateFunction {
distinct,
filter,
order_by,
null_treatment: None,
}
}
}
Expand Down Expand Up @@ -646,6 +650,7 @@ pub struct WindowFunction {
pub order_by: Vec<Expr>,
/// Window frame
pub window_frame: window_frame::WindowFrame,
/// Specifies how NULL value is treated: ignore or respect
pub null_treatment: Option<NullTreatment>,
}

Expand Down Expand Up @@ -1471,9 +1476,13 @@ impl fmt::Display for Expr {
ref args,
filter,
order_by,
null_treatment,
..
}) => {
fmt_function(f, func_def.name(), *distinct, args, true)?;
if let Some(nt) = null_treatment {
write!(f, " {}", nt)?;
}
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {fe})")?;
}
Expand Down Expand Up @@ -1804,6 +1813,7 @@ fn create_name(e: &Expr) -> Result<String> {
args,
filter,
order_by,
null_treatment,
}) => {
let name = match func_def {
AggregateFunctionDefinition::BuiltIn(..)
Expand All @@ -1823,6 +1833,9 @@ fn create_name(e: &Expr) -> Result<String> {
if let Some(order_by) = order_by {
info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by));
};
if let Some(nt) = null_treatment {
info += &format!(" {}", nt);
}
match func_def {
AggregateFunctionDefinition::BuiltIn(..)
| AggregateFunctionDefinition::Name(..) => {
Expand Down
13 changes: 13 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ pub fn min(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand All @@ -161,6 +162,7 @@ pub fn max(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand All @@ -172,6 +174,7 @@ pub fn sum(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand All @@ -183,6 +186,7 @@ pub fn array_agg(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand All @@ -194,6 +198,7 @@ pub fn avg(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand All @@ -205,6 +210,7 @@ pub fn count(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand Down Expand Up @@ -261,6 +267,7 @@ pub fn count_distinct(expr: Expr) -> Expr {
true,
None,
None,
None,
))
}

Expand Down Expand Up @@ -313,6 +320,7 @@ pub fn approx_distinct(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand All @@ -324,6 +332,7 @@ pub fn median(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand All @@ -335,6 +344,7 @@ pub fn approx_median(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand All @@ -346,6 +356,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand All @@ -361,6 +372,7 @@ pub fn approx_percentile_cont_with_weight(
false,
None,
None,
None,
))
}

Expand Down Expand Up @@ -431,6 +443,7 @@ pub fn stddev(expr: Expr) -> Expr {
false,
None,
None,
None,
))
}

Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/tree_node/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ impl TreeNode for Expr {
distinct,
filter,
order_by,
null_treatment,
}) => transform_vec(args, &mut f)?
.update_data(|new_args| (new_args, filter, order_by))
.try_transform_node(|(new_args, filter, order_by)| {
Expand All @@ -368,6 +369,7 @@ impl TreeNode for Expr {
distinct,
new_filter,
new_order_by,
null_treatment,
)))
}
AggregateFunctionDefinition::UDF(fun) => {
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
distinct,
filter,
order_by,
null_treatment,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Transformed::yes(Expr::AggregateFunction(AggregateFunction::new(
Expand All @@ -166,6 +167,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
distinct,
filter,
order_by,
null_treatment,
)))
}
_ => Transformed::no(old_expr),
Expand Down
12 changes: 11 additions & 1 deletion datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
distinct,
filter,
order_by,
null_treatment,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let new_expr = coerce_agg_exprs_for_signature(
Expand All @@ -355,7 +356,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
)?;
Ok(Transformed::yes(Expr::AggregateFunction(
expr::AggregateFunction::new(
fun, new_expr, distinct, filter, order_by,
fun,
new_expr,
distinct,
filter,
order_by,
null_treatment,
),
)))
}
Expand Down Expand Up @@ -946,6 +952,7 @@ mod test {
false,
None,
None,
None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
let expected = "Projection: AVG(CAST(Int64(12) AS Float64))\n EmptyRelation";
Expand All @@ -959,6 +966,7 @@ mod test {
false,
None,
None,
None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
let expected = "Projection: AVG(CAST(a AS Float64))\n EmptyRelation";
Expand All @@ -976,6 +984,7 @@ mod test {
false,
None,
None,
None,
));
let err = Projection::try_new(vec![agg_expr], empty)
.err()
Expand All @@ -998,6 +1007,7 @@ mod test {
false,
None,
None,
None,
));

let err = Projection::try_new(vec![agg_expr], empty)
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/push_down_projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ mod tests {
false,
Some(Box::new(col("c").gt(lit(42)))),
None,
None,
));

let plan = LogicalPlanBuilder::from(table_scan)
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/replace_distinct_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
false,
None,
sort_expr.clone(),
None,
))
})
.collect::<Vec<Expr>>();
Expand Down
Loading

0 comments on commit 3aba67e

Please sign in to comment.