Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support IGNORE NULLS for LAG window function #9221

Merged
merged 7 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1681,6 +1681,7 @@ mod tests {
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(None),
None,
));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/physical_optimizer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ pub fn bounded_window_exec(
&sort_exprs,
Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
false,
)
.unwrap()],
input.clone(),
Expand Down
6 changes: 6 additions & 0 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
use itertools::{multiunzip, Itertools};
use log::{debug, trace};
use sqlparser::ast::NullTreatment;

fn create_function_physical_name(
fun: &str,
Expand Down Expand Up @@ -1581,6 +1582,7 @@ pub fn create_window_expr_with_name(
partition_by,
order_by,
window_frame,
null_treatment,
}) => {
let args = args
.iter()
Expand All @@ -1605,6 +1607,9 @@ pub fn create_window_expr_with_name(
}

let window_frame = Arc::new(window_frame.clone());
let ignore_nulls = null_treatment
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
windows::create_window_expr(
fun,
name,
Expand All @@ -1613,6 +1618,7 @@ pub fn create_window_expr_with_name(
&order_by,
window_frame,
physical_input_schema,
ignore_nulls,
)
}
other => plan_err!("Invalid window expression '{other:?}'"),
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
None,
))])?
.explain(false, false)?
.collect()
Expand Down
3 changes: 3 additions & 0 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
&orderby_exprs,
Arc::new(window_frame),
schema.as_ref(),
false,
)?;
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
vec![window_expr],
Expand Down Expand Up @@ -642,6 +643,7 @@ async fn run_window_test(
&orderby_exprs,
Arc::new(window_frame.clone()),
schema.as_ref(),
false,
)
.unwrap()],
exec1,
Expand All @@ -664,6 +666,7 @@ async fn run_window_test(
&orderby_exprs,
Arc::new(window_frame.clone()),
schema.as_ref(),
false,
)
.unwrap()],
exec2,
Expand Down
18 changes: 18 additions & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use arrow::datatypes::DataType;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, DFSchema, OwnedTableReference};
use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue};
use sqlparser::ast::NullTreatment;
use std::collections::HashSet;
use std::fmt;
use std::fmt::{Display, Formatter, Write};
Expand Down Expand Up @@ -646,6 +647,7 @@ pub struct WindowFunction {
pub order_by: Vec<Expr>,
/// Window frame
pub window_frame: window_frame::WindowFrame,
pub null_treatment: Option<NullTreatment>,
}

impl WindowFunction {
Expand All @@ -656,13 +658,15 @@ impl WindowFunction {
partition_by: Vec<Expr>,
order_by: Vec<Expr>,
window_frame: window_frame::WindowFrame,
null_treatment: Option<NullTreatment>,
) -> Self {
Self {
fun,
args,
partition_by,
order_by,
window_frame,
null_treatment,
}
}
}
Expand Down Expand Up @@ -1440,8 +1444,14 @@ impl fmt::Display for Expr {
partition_by,
order_by,
window_frame,
null_treatment,
}) => {
fmt_function(f, &fun.to_string(), false, args, true)?;

if let Some(nt) = null_treatment {
write!(f, "{}", nt)?;
}

if !partition_by.is_empty() {
write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?;
}
Expand Down Expand Up @@ -1768,15 +1778,23 @@ fn create_name(e: &Expr) -> Result<String> {
window_frame,
partition_by,
order_by,
null_treatment,
}) => {
let mut parts: Vec<String> =
vec![create_function_name(&fun.to_string(), false, args)?];

if let Some(nt) = null_treatment {
parts.push(format!("{}", nt));
}

if !partition_by.is_empty() {
parts.push(format!("PARTITION BY [{}]", expr_vec_fmt!(partition_by)));
}

if !order_by.is_empty() {
parts.push(format!("ORDER BY [{}]", expr_vec_fmt!(order_by)));
}

parts.push(format!("{window_frame}"));
Ok(parts.join(" "))
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/tree_node/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,14 @@ impl TreeNode for Expr {
partition_by,
order_by,
window_frame,
null_treatment,
}) => Expr::WindowFunction(WindowFunction::new(
fun,
transform_vec(args, &mut transform)?,
transform_vec(partition_by, &mut transform)?,
transform_vec(order_by, &mut transform)?,
window_frame,
null_treatment,
)),
Expr::AggregateFunction(AggregateFunction {
args,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ impl WindowUDF {
partition_by,
order_by,
window_frame,
null_treatment: None,
})
}

Expand Down
10 changes: 10 additions & 0 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1255,27 +1255,31 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
None,
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
Expand All @@ -1298,27 +1302,31 @@ mod tests {
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(Some(false)),
None,
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(None),
None,
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(Some(false)),
None,
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
WindowFrame::new(Some(false)),
None,
));
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
Expand Down Expand Up @@ -1353,6 +1361,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
],
WindowFrame::new(Some(false)),
None,
)),
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
Expand All @@ -1364,6 +1373,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
],
WindowFrame::new(Some(false)),
None,
)),
];
let expected = vec![
Expand Down
3 changes: 3 additions & 0 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
partition_by,
order_by,
window_frame,
null_treatment,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Expr::WindowFunction(expr::WindowFunction {
Expand All @@ -138,6 +139,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
partition_by,
order_by,
window_frame,
null_treatment,
})
}

Expand Down Expand Up @@ -351,6 +353,7 @@ mod tests {
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
None,
))])?
.project(vec![count(wildcard())])?
.build()?;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
partition_by,
order_by,
window_frame,
null_treatment,
}) => {
let window_frame =
coerce_window_frame(window_frame, &self.schema, &order_by)?;
Expand All @@ -414,6 +415,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
partition_by,
order_by,
window_frame,
null_treatment,
));
Ok(expr)
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/push_down_projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ mod tests {
vec![col("test.b")],
vec![],
WindowFrame::new(None),
None,
));

let max2 = Expr::WindowFunction(expr::WindowFunction::new(
Expand All @@ -595,6 +596,7 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
None,
));
let col1 = col(max1.display_name()?);
let col2 = col(max2.display_name()?);
Expand Down
Loading
Loading