Skip to content

Commit

Permalink
Avoid generate duplicate sort Keys from Window Expressions, fix bug w…
Browse files Browse the repository at this point in the history
…hen decide Window Expressions ordering (#4643)

* Avoid generate duplicate sort Keys from Window Expressions, fix bug when decide Window Expressions ordering

* fix test comment

* fix UT

* fix UT

* Fix create Sort Columns from Partition Columns in WindowExpr, add more UTs for Null String sort testing

* fix clippy check

* tiny change

* merge with upstream, fix issue
  • Loading branch information
mingmwang authored Dec 19, 2022
1 parent dba34fc commit 891a800
Show file tree
Hide file tree
Showing 7 changed files with 422 additions and 28 deletions.
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,12 +564,12 @@ impl DefaultPhysicalPlanner {
}
_ => unreachable!(),
};
let sort_keys = get_sort_keys(&window_expr[0]);
let sort_keys = get_sort_keys(&window_expr[0])?;
if window_expr.len() > 1 {
debug_assert!(
window_expr[1..]
.iter()
.all(|expr| get_sort_keys(expr) == sort_keys),
.all(|expr| get_sort_keys(expr).unwrap() == sort_keys),
"all window expressions shall have the same sort keys, as guaranteed by logical planning"
);
}
Expand Down
82 changes: 82 additions & 0 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,88 @@ async fn query_on_string_dictionary() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn sort_on_window_null_string() -> Result<()> {
let d1: DictionaryArray<Int32Type> =
vec![Some("one"), None, Some("three")].into_iter().collect();
let d2: StringArray = vec![Some("ONE"), None, Some("THREE")].into_iter().collect();
let d3: LargeStringArray =
vec![Some("One"), None, Some("Three")].into_iter().collect();

let batch = RecordBatch::try_from_iter(vec![
("d1", Arc::new(d1) as ArrayRef),
("d2", Arc::new(d2) as ArrayRef),
("d3", Arc::new(d3) as ArrayRef),
])
.unwrap();

let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(2));
ctx.register_batch("test", batch)?;

let sql =
"SELECT d1, row_number() OVER (partition by d1) as rn1 FROM test order by d1 asc";

let actual = execute_to_batches(&ctx, sql).await;
// NULLS LAST
let expected = vec![
"+-------+-----+",
"| d1 | rn1 |",
"+-------+-----+",
"| one | 1 |",
"| three | 1 |",
"| | 1 |",
"+-------+-----+",
];
assert_batches_eq!(expected, &actual);

let sql = "SELECT d2, row_number() OVER (partition by d2) as rn1 FROM test";
let actual = execute_to_batches(&ctx, sql).await;
// NULLS LAST
let expected = vec![
"+-------+-----+",
"| d2 | rn1 |",
"+-------+-----+",
"| ONE | 1 |",
"| THREE | 1 |",
"| | 1 |",
"+-------+-----+",
];
assert_batches_eq!(expected, &actual);

let sql =
"SELECT d2, row_number() OVER (partition by d2 order by d2 desc) as rn1 FROM test";

let actual = execute_to_batches(&ctx, sql).await;
// NULLS FIRST
let expected = vec![
"+-------+-----+",
"| d2 | rn1 |",
"+-------+-----+",
"| | 1 |",
"| THREE | 1 |",
"| ONE | 1 |",
"+-------+-----+",
];
assert_batches_eq!(expected, &actual);

// FIXME sort on LargeUtf8 String has bug.
// let sql =
// "SELECT d3, row_number() OVER (partition by d3) as rn1 FROM test";
// let actual = execute_to_batches(&ctx, sql).await;
// let expected = vec![
// "+-------+-----+",
// "| d3 | rn1 |",
// "+-------+-----+",
// "| | 1 |",
// "| One | 1 |",
// "| Three | 1 |",
// "+-------+-----+",
// ];
// assert_batches_eq!(expected, &actual);

Ok(())
}

#[tokio::test]
async fn filter_with_time32second() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
115 changes: 114 additions & 1 deletion datafusion/core/tests/sql/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,120 @@ async fn test_window_agg_sort() -> Result<()> {
assert_eq!(
expected, actual_trim_last,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
expected, actual_trim_last
);
Ok(())
}

#[tokio::test]
async fn over_order_by_sort_keys_sorting_prefix_compacting() -> Result<()> {
let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(2));
register_aggregate_csv(&ctx).await?;

let sql = "SELECT c2, MAX(c9) OVER (ORDER BY c2), SUM(c9) OVER (), MIN(c9) OVER (ORDER BY c2, c9) from aggregate_test_100";

let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx.create_logical_plan(sql).expect(&msg);
let state = ctx.state();
let logical_plan = state.optimize(&plan)?;
let physical_plan = state.create_physical_plan(&logical_plan).await?;
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
// Only 1 SortExec was added
let expected = {
vec![
"ProjectionExec: expr=[c2@3 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as MAX(aggregate_test_100.c9), SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@0 as SUM(aggregate_test_100.c9), MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9)]",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }]",
" WindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: \"MAX(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
" WindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: \"MIN(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
" SortExec: [c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST]"
]
};

let actual: Vec<&str> = formatted.trim().lines().collect();
let actual_len = actual.len();
let actual_trim_last = &actual[..actual_len - 1];
assert_eq!(
expected, actual_trim_last,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual_trim_last
);
Ok(())
}

/// FIXME: for now we are not detecting prefix of sorting keys in order to re-arrange with global and save one SortExec
#[tokio::test]
async fn over_order_by_sort_keys_sorting_global_order_compacting() -> Result<()> {
let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(2));
register_aggregate_csv(&ctx).await?;

let sql = "SELECT c2, MAX(c9) OVER (ORDER BY c9, c2), SUM(c9) OVER (), MIN(c9) OVER (ORDER BY c2, c9) from aggregate_test_100 ORDER BY c2";
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx.create_logical_plan(sql).expect(&msg);
let state = ctx.state();
let logical_plan = state.optimize(&plan)?;
let physical_plan = state.create_physical_plan(&logical_plan).await?;
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
// 3 SortExec are added
let expected = {
vec![
"SortExec: [c2@0 ASC NULLS LAST]",
" CoalescePartitionsExec",
" ProjectionExec: expr=[c2@3 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as MAX(aggregate_test_100.c9), SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@0 as SUM(aggregate_test_100.c9), MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9)]",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }]",
" WindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: \"MAX(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
" SortExec: [c9@2 ASC NULLS LAST,c2@1 ASC NULLS LAST]",
" WindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: \"MIN(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
" SortExec: [c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST]",
]
};

let actual: Vec<&str> = formatted.trim().lines().collect();
let actual_len = actual.len();
let actual_trim_last = &actual[..actual_len - 1];
assert_eq!(
expected, actual_trim_last,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual_trim_last
);
Ok(())
}

#[tokio::test]
async fn test_window_partition_by_order_by() -> Result<()> {
let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(2));
register_aggregate_csv(&ctx).await?;

let sql = "SELECT \
SUM(c4) OVER(PARTITION BY c1, c2 ORDER BY c2 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
COUNT(*) OVER(PARTITION BY c1 ORDER BY c2 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) \
FROM aggregate_test_100";

let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx.create_logical_plan(sql).expect(&msg);
let state = ctx.state();
let logical_plan = state.optimize(&plan)?;
let physical_plan = state.create_physical_plan(&logical_plan).await?;
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
// Only 1 SortExec was added
let expected = {
vec![
"ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as COUNT(UInt8(1))]",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]",
" SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2)",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
]
};

let actual: Vec<&str> = formatted.trim().lines().collect();
let actual_len = actual.len();
let actual_trim_last = &actual[..actual_len - 1];
assert_eq!(
expected, actual_trim_last,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual_trim_last
);
Ok(())
}
32 changes: 23 additions & 9 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::expr_rewriter::{
normalize_cols, rewrite_sort_cols_by_aggs,
};
use crate::type_coercion::binary::comparison_coercion;
use crate::utils::{columnize_expr, exprlist_to_fields, from_plan};
use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan};
use crate::{and, binary_expr, Operator};
use crate::{
logical_plan::{
Expand All @@ -43,6 +43,7 @@ use datafusion_common::{
ToDFSchema,
};
use std::any::Any;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::sync::Arc;
Expand Down Expand Up @@ -250,16 +251,29 @@ impl LogicalPlanBuilder {
) -> Result<LogicalPlan> {
let mut plan = input;
let mut groups = group_window_expr_by_sort_keys(&window_exprs)?;
// sort by sort_key len descending, so that more deeply sorted plans gets nested further
// down as children; to further mimic the behavior of PostgreSQL, we want stable sort
// and a reverse so that tieing sort keys are reversed in order; note that by this rule
// if there's an empty over, it'll be at the top level
groups.sort_by(|(key_a, _), (key_b, _)| key_a.len().cmp(&key_b.len()));
groups.reverse();
// To align with the behavior of PostgreSQL, we want the sort_keys sorted as same rule as PostgreSQL that first
// we compare the sort key themselves and if one window's sort keys are a prefix of another
// put the window with more sort keys first. so more deeply sorted plans gets nested further down as children.
// The sort_by() implementation here is a stable sort.
// Note that by this rule if there's an empty over, it'll be at the top level
groups.sort_by(|(key_a, _), (key_b, _)| {
for (first, second) in key_a.iter().zip(key_b.iter()) {
let key_ordering = compare_sort_expr(first, second, plan.schema());
match key_ordering {
Ordering::Less => {
return Ordering::Less;
}
Ordering::Greater => {
return Ordering::Greater;
}
Ordering::Equal => {}
}
}
key_b.len().cmp(&key_a.len())
});
for (_, exprs) in groups {
let window_exprs = exprs.into_iter().cloned().collect::<Vec<_>>();
// the partition and sort itself is done at physical level, see physical_planner's
// fn create_initial_plan
// the partition and sort itself is done at physical level, see the BasicEnforcement rule
plan = LogicalPlanBuilder::from(plan)
.window(window_exprs)?
.build()?;
Expand Down
Loading

0 comments on commit 891a800

Please sign in to comment.