From 2488674bc33e2b414c393756759f6c944fb06539 Mon Sep 17 00:00:00 2001
From: Jiayu Liu
Date: Sun, 13 Jun 2021 13:22:41 +0800
Subject: [PATCH] implement window functions with partition by

---
 datafusion/src/execution/context.rs           | 74 +++++++++++++++++++
 .../physical_plan/expressions/nth_value.rs    | 10 ++-
 datafusion/src/physical_plan/mod.rs           | 36 +++++----
 datafusion/src/physical_plan/planner.rs       |  6 --
 datafusion/src/physical_plan/windows.rs       | 61 +++++++++++++--
 datafusion/tests/sql.rs                       | 64 ++++++++++++++++
 .../simple_window_partition_aggregation.sql   | 26 +++++++
 ...ple_window_partition_order_aggregation.sql | 26 +++++++
 integration-tests/test_psql_parity.py         |  2 +-
 9 files changed, 275 insertions(+), 30 deletions(-)
 create mode 100644 integration-tests/sqls/simple_window_partition_aggregation.sql
 create mode 100644 integration-tests/sqls/simple_window_partition_order_aggregation.sql

diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index ef652c28d1ed..b42695b0c4c6 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1355,6 +1355,80 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn window_partition_by() -> Result<()> {
+        let results = execute(
+            "SELECT \
+            c1, \
+            c2, \
+            SUM(c2) OVER (PARTITION BY c2), \
+            COUNT(c2) OVER (PARTITION BY c2), \
+            MAX(c2) OVER (PARTITION BY c2), \
+            MIN(c2) OVER (PARTITION BY c2), \
+            AVG(c2) OVER (PARTITION BY c2) \
+            FROM test \
+            ORDER BY c1, c2 \
+            LIMIT 5",
+            4,
+        )
+        .await?;
+
+        let expected = vec![
+            "+----+----+---------+-----------+---------+---------+---------+",
+            "| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
+            "+----+----+---------+-----------+---------+---------+---------+",
+            "| 0  | 1  | 4       | 4         | 1       | 1       | 1       |",
+            "| 0  | 2  | 8       | 4         | 2       | 2       | 2       |",
+            "| 0  | 3  | 12      | 4         | 3       | 3       | 3       |",
+            "| 0  | 4  | 16      | 4         | 4       | 4       | 4       |",
+            "| 0  | 5  | 20      | 4         | 5       | 5       | 5       |",
+            "+----+----+---------+-----------+---------+---------+---------+",
+        ];
+
+        // window function results should respect the ordering
+        assert_batches_eq!(expected, &results);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn window_partition_by_order_by() -> Result<()> {
+        let results = execute(
+            "SELECT \
+            c1, \
+            c2, \
+            ROW_NUMBER() OVER (PARTITION BY c2 ORDER BY c1), \
+            FIRST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
+            LAST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
+            NTH_VALUE(c2 + c1, 2) OVER (PARTITION BY c2 ORDER BY c1), \
+            SUM(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+            COUNT(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+            MAX(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+            MIN(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+            AVG(c2) OVER (PARTITION BY c2 ORDER BY c1) \
+            FROM test \
+            ORDER BY c1, c2 \
+            LIMIT 5",
+            4,
+        )
+        .await?;
+
+        let expected = vec![
+            "+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
+            "| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2 Plus c1) | LAST_VALUE(c2 Plus c1) | NTH_VALUE(c2 Plus c1,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
+            "+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
+            "| 0  | 1  | 1            | 1                       | 4                      | 2                              | 1       | 1         | 1       | 1       | 1       |",
+            "| 0  | 2  | 1            | 2                       | 5                      | 3                              | 2       | 1         | 2       | 2       | 2       |",
+            "| 0  | 3  | 1            | 3                       | 6                      | 4                              | 3       | 1         | 3       | 3       | 3       |",
+            "| 0  | 4  | 1            | 4                       | 7                      | 5                              | 4       | 1         | 4       | 4       | 4       |",
+            "| 0  | 5  | 1            | 5                       | 8                      | 6                              | 5       | 1         | 5       | 5       | 5       |",
+ "+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+", + ]; + + // window function shall respect ordering + assert_batches_eq!(expected, &results); + Ok(()) + } + #[tokio::test] async fn aggregate() -> Result<()> { let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs index 98083fa26eaa..16897d45119f 100644 --- a/datafusion/src/physical_plan/expressions/nth_value.rs +++ b/datafusion/src/physical_plan/expressions/nth_value.rs @@ -20,7 +20,7 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; use crate::scalar::ScalarValue; -use arrow::array::{new_empty_array, ArrayRef}; +use arrow::array::{new_empty_array, new_null_array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use std::any::Any; use std::sync::Arc; @@ -135,8 +135,12 @@ impl BuiltInWindowFunctionExpr for NthValue { NthValueKind::Last => (num_rows as usize) - 1, NthValueKind::Nth(n) => (n as usize) - 1, }; - let value = ScalarValue::try_from_array(value, index)?; - Ok(value.to_array_of_size(num_rows)) + Ok(if index >= num_rows { + new_null_array(value.data_type(), num_rows) + } else { + let value = ScalarValue::try_from_array(value, index)?; + value.to_array_of_size(num_rows) + }) } } diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 713956f00a9e..50c30a57b5fe 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -485,19 +485,20 @@ pub trait WindowExpr: Send + Sync + Debug { /// evaluate the window function values against the batch fn evaluate(&self, batch: &RecordBatch) -> Result; - /// evaluate the sort partition points - fn evaluate_sort_partition_points( + /// evaluate the partition points given the sort columns; if the sort columns are + /// empty then the result will be a single element vec of the whole column rows. 
+    fn evaluate_partition_points(
         &self,
-        batch: &RecordBatch,
+        num_rows: usize,
+        partition_columns: &[SortColumn],
     ) -> Result<Vec<Range<usize>>> {
-        let sort_columns = self.sort_columns(batch)?;
-        if sort_columns.is_empty() {
+        if partition_columns.is_empty() {
             Ok(vec![Range {
                 start: 0,
-                end: batch.num_rows(),
+                end: num_rows,
             }])
         } else {
-            lexicographical_partition_ranges(&sort_columns)
+            lexicographical_partition_ranges(partition_columns)
                 .map_err(DataFusionError::ArrowError)
         }
     }
@@ -508,8 +509,8 @@
     /// expressions that's from the window function's order by clause, empty if absent
     fn order_by(&self) -> &[PhysicalSortExpr];
 
-    /// get sort columns that can be used for partitioning, empty if absent
-    fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
+    /// get partition columns that can be used for partitioning, empty if absent
+    fn partition_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
         self.partition_by()
             .iter()
             .map(|expr| {
@@ -519,13 +520,20 @@
                 }
                 .evaluate_to_sort_column(batch)
             })
-            .chain(
-                self.order_by()
-                    .iter()
-                    .map(|e| e.evaluate_to_sort_column(batch)),
-            )
             .collect()
     }
+
+    /// get sort columns that can be used for peer evaluation, empty if absent
+    fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
+        let mut sort_columns = self.partition_columns(batch)?;
+        let order_by_columns = self
+            .order_by()
+            .iter()
+            .map(|e| e.evaluate_to_sort_column(batch))
+            .collect::<Result<Vec<SortColumn>>>()?;
+        sort_columns.extend(order_by_columns);
+        Ok(sort_columns)
+    }
 }
 
 /// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 1121c28184bd..af0e60f2194c 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -775,12 +775,6 @@ impl DefaultPhysicalPlanner {
                 )),
             })
             .collect::<Result<Vec<_>>>()?;
-        if !partition_by.is_empty() {
-            return Err(DataFusionError::NotImplemented(
-                "window expression with non-empty partition by clause is not yet supported"
-                    .to_owned(),
-            ));
-        }
         if window_frame.is_some() {
             return Err(DataFusionError::NotImplemented(
                 "window expression with window frame definition is not yet supported"
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index e5570971cf16..466cc51b447d 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -175,10 +175,45 @@ impl WindowExpr for BuiltInWindowExpr {
         // case when partition_by is supported, in which case we'll parallelize the calls.
         // See https://github.com/apache/arrow-datafusion/issues/299
         let values = self.evaluate_args(batch)?;
-        self.window.evaluate(batch.num_rows(), &values)
+        let partition_points = self.evaluate_partition_points(
+            batch.num_rows(),
+            &self.partition_columns(batch)?,
+        )?;
+        let results = partition_points
+            .iter()
+            .map(|partition_range| {
+                let start = partition_range.start;
+                let len = partition_range.end - start;
+                let values = values
+                    .iter()
+                    .map(|arr| arr.slice(start, len))
+                    .collect::<Vec<ArrayRef>>();
+                self.window.evaluate(len, &values)
+            })
+            .collect::<Result<Vec<ArrayRef>>>()?
+            .into_iter()
+            .collect::<Vec<ArrayRef>>();
+        let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
+        concat(&results).map_err(DataFusionError::ArrowError)
     }
 }
 
+/// Given a partition range and the full list of sort partition points, return the sub-slice
+/// of sort partition points that falls within the partition range. Because the sort partition
+/// points are sorted on [partition columns..., order columns...], the partition boundaries
+/// align with a subset of the sort partition boundaries (whatever is sorted on the partition
+/// columns is also sorted on the finer columns), so binary search can locate the valid slice.
+fn find_ranges_in_range<'a>(
+    partition_range: &Range<usize>,
+    sort_partition_points: &'a [Range<usize>],
+) -> &'a [Range<usize>] {
+    let start_idx = sort_partition_points
+        .partition_point(|sort_range| sort_range.start < partition_range.start);
+    let end_idx = sort_partition_points
+        .partition_point(|sort_range| sort_range.end <= partition_range.end);
+    &sort_partition_points[start_idx..end_idx]
+}
+
 /// A window expr that takes the form of an aggregate function
 #[derive(Debug)]
 pub struct AggregateWindowExpr {
@@ -205,13 +240,27 @@ impl AggregateWindowExpr {
     /// and then per partition point we'll evaluate the peer group (e.g. SUM or MAX gives the same
     /// results for peers) and concatenate the results.
     fn peer_based_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        let sort_partition_points = self.evaluate_sort_partition_points(batch)?;
-        let mut window_accumulators = self.create_accumulator()?;
+        let num_rows = batch.num_rows();
+        let partition_points =
+            self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?;
+        let sort_partition_points =
+            self.evaluate_partition_points(num_rows, &self.sort_columns(batch)?)?;
         let values = self.evaluate_args(batch)?;
-        let results = sort_partition_points
+        let results = partition_points
             .iter()
-            .map(|peer_range| window_accumulators.scan_peers(&values, peer_range))
-            .collect::<Result<Vec<ArrayRef>>>()?;
+            .map(|partition_range| {
+                let sort_partition_points =
+                    find_ranges_in_range(partition_range, &sort_partition_points);
+                let mut window_accumulators = self.create_accumulator()?;
+                sort_partition_points
+                    .iter()
+                    .map(|range| window_accumulators.scan_peers(&values, range))
+                    .collect::<Result<Vec<ArrayRef>>>()
+            })
+            .collect::<Result<Vec<Vec<ArrayRef>>>>()?
+            .into_iter()
+            .flatten()
+            .collect::<Vec<ArrayRef>>();
         let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
         concat(&results).map_err(DataFusionError::ArrowError)
     }
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index b6393e91e321..cfdb6f4bc9e4 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -868,6 +868,70 @@ async fn csv_query_window_with_empty_over() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_window_with_partition_by() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx)?;
+    let sql = "select \
+    c9, \
+    sum(cast(c4 as Int)) over (partition by c3), \
+    avg(cast(c4 as Int)) over (partition by c3), \
+    count(cast(c4 as Int)) over (partition by c3), \
+    max(cast(c4 as Int)) over (partition by c3), \
+    min(cast(c4 as Int)) over (partition by c3), \
+    first_value(cast(c4 as Int)) over (partition by c3), \
+    last_value(cast(c4 as Int)) over (partition by c3), \
+    nth_value(cast(c4 as Int), 2) over (partition by c3) \
+    from aggregate_test_100 \
+    order by c9 \
+    limit 5";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![
+        vec![
+            "28774375", "-16110", "-16110", "1", "-16110", "-16110", "-16110", "-16110",
+            "NULL",
+        ],
+        vec![
+            "63044568", "3917", "3917", "1", "3917", "3917", "3917", "3917", "NULL",
+        ],
+        vec![
+            "141047417",
+            "-38455",
+            "-19227.5",
+            "2",
+            "-16974",
+            "-21481",
+            "-16974",
+            "-21481",
+            "-21481",
+        ],
+        vec![
+            "141680161",
+            "-1114",
+            "-1114",
+            "1",
+            "-1114",
+            "-1114",
+            "-1114",
+            "-1114",
+            "NULL",
+        ],
+        vec![
+            "145294611",
+            "15673",
+            "15673",
+            "1",
+            "15673",
+            "15673",
+            "15673",
+            "15673",
+            "NULL",
+        ],
+    ];
+    assert_eq!(expected, actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_window_with_order_by() -> Result<()> {
     let mut ctx = ExecutionContext::new();
diff --git a/integration-tests/sqls/simple_window_partition_aggregation.sql b/integration-tests/sqls/simple_window_partition_aggregation.sql
new file mode 100644
index 000000000000..f395671db8cc
--- /dev/null
+++ b/integration-tests/sqls/simple_window_partition_aggregation.sql
@@ -0,0 +1,26 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+
+--   http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+  c9,
+  row_number() OVER (PARTITION BY c2, c9) AS row_number,
+  count(c3) OVER (PARTITION BY c2) AS count_c3,
+  avg(c3) OVER (PARTITION BY c2) AS avg_c3_by_c2,
+  sum(c3) OVER (PARTITION BY c2) AS sum_c3_by_c2,
+  max(c3) OVER (PARTITION BY c2) AS max_c3_by_c2,
+  min(c3) OVER (PARTITION BY c2) AS min_c3_by_c2
+FROM test
+ORDER BY c9;
diff --git a/integration-tests/sqls/simple_window_partition_order_aggregation.sql b/integration-tests/sqls/simple_window_partition_order_aggregation.sql
new file mode 100644
index 000000000000..a11a9ec6e4b1
--- /dev/null
+++ b/integration-tests/sqls/simple_window_partition_order_aggregation.sql
@@ -0,0 +1,26 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+
+--   http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+  c9,
+  row_number() OVER (PARTITION BY c2 ORDER BY c9) AS row_number,
+  count(c3) OVER (PARTITION BY c2 ORDER BY c9) AS count_c3,
+  avg(c3) OVER (PARTITION BY c2 ORDER BY c9) AS avg_c3_by_c2,
+  sum(c3) OVER (PARTITION BY c2 ORDER BY c9) AS sum_c3_by_c2,
+  max(c3) OVER (PARTITION BY c2 ORDER BY c9) AS max_c3_by_c2,
+  min(c3) OVER (PARTITION BY c2 ORDER BY c9) AS min_c3_by_c2
+FROM test
+ORDER BY c9;
diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py
index 4e0878c24b81..c4b5a7596ae9 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
     def test_parity(self):
         root = Path(os.path.dirname(__file__)) / "sqls"
         files = set(root.glob("*.sql"))
-        self.assertEqual(len(files), 7, msg="tests are missed")
+        self.assertEqual(len(files), 9, msg="tests are missing")
         for fname in files:
             with self.subTest(fname=fname):
                 datafusion_output = pd.read_csv(
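
A few of the mechanics above are worth illustrating outside the patch. `evaluate_partition_points` returns one `Range<usize>` per run of rows whose partition (or sort) columns compare equal, delegating to arrow's `lexicographical_partition_ranges` whenever any columns are present. The sketch below is a toy, single-column stand-in for that arrow call (`partition_ranges` and its equality-only key handling are illustrative, not DataFusion or arrow API), applied to a sorted `c2` column like the one in the new `window_partition_by` test:

```rust
use std::ops::Range;

/// Toy stand-in for arrow's `lexicographical_partition_ranges`, restricted to a
/// single, already-sorted key column: split the slice into ranges of equal
/// adjacent values.
fn partition_ranges<T: PartialEq>(sorted: &[T]) -> Vec<Range<usize>> {
    let mut ranges = Vec::new();
    let mut start = 0;
    for i in 1..=sorted.len() {
        // close the current range at end of input or when the key changes
        if i == sorted.len() || sorted[i] != sorted[start] {
            ranges.push(start..i);
            start = i;
        }
    }
    ranges
}

fn main() {
    // Each distinct c2 value becomes one partition; the per-partition
    // SUM/COUNT/MAX/MIN/AVG in the test are then computed over each range.
    let c2 = vec![1, 1, 2, 2, 2, 3];
    assert_eq!(partition_ranges(&c2), vec![0..2, 2..5, 5..6]);
}
```

When both PARTITION BY and ORDER BY are given, the same routine runs twice: once over the partition columns alone and once over [partition columns..., order columns...], producing the coarse partition ranges and the finer peer ranges respectively.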
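The binary searches in `find_ranges_in_range` are sound precisely because the peer ranges are computed over [partition columns..., order columns...]: every partition boundary is also a peer boundary, so each partition range is an exact union of consecutive peer ranges. Here is a standalone copy of the helper with a quick check (the concrete ranges are made up for illustration):

```rust
use std::ops::Range;

/// Standalone copy of the patch's helper: return the sub-slice of peer ranges
/// that falls inside one partition range, using two binary searches.
fn find_ranges_in_range<'a>(
    partition_range: &Range<usize>,
    sort_partition_points: &'a [Range<usize>],
) -> &'a [Range<usize>] {
    // index of the first peer range starting at or after the partition start
    let start_idx = sort_partition_points
        .partition_point(|sort_range| sort_range.start < partition_range.start);
    // one past the last peer range ending at or before the partition end
    let end_idx = sort_partition_points
        .partition_point(|sort_range| sort_range.end <= partition_range.end);
    &sort_partition_points[start_idx..end_idx]
}

fn main() {
    // Peer ranges over an 8-row batch; the partition covering rows 2..6
    // contains exactly the peer ranges 2..4 and 4..6.
    let peers = vec![0..2, 2..4, 4..6, 6..8];
    assert_eq!(find_ranges_in_range(&(2..6), &peers), &peers[1..3]);
}
```

Note that `peer_based_evaluate` creates a fresh accumulator per partition and calls `scan_peers` once per peer range inside it, which is why the aggregates in `window_partition_by_order_by` accumulate within a partition (COUNT is 1 on each partition's first peer row) without leaking across partition boundaries.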
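Finally, the nth_value.rs change exists because PARTITION BY can now hand `NthValue::evaluate` a partition shorter than the requested position, and that case should produce NULLs rather than an out-of-bounds scalar extraction; the NULL cells for the size-1 `c3` groups in `csv_query_window_with_partition_by` exercise exactly this. A minimal model of the new semantics outside arrow (`nth_value` here is a hypothetical helper, with `Option` standing in for an arrow null array):

```rust
/// Model of the patched NthValue::evaluate: if the 1-based position `n` points
/// past the end of the partition, every row in it gets NULL (here: None);
/// otherwise every row gets the partition's n-th value.
fn nth_value<T: Clone>(partition: &[T], n: usize) -> Vec<Option<T>> {
    let index = n - 1; // NTH_VALUE positions are 1-based
    let value = partition.get(index).cloned(); // None when index >= len
    vec![value; partition.len()]
}

fn main() {
    // A single-row partition: NTH_VALUE(x, 2) is NULL for the whole group.
    assert_eq!(nth_value(&[10], 2), vec![None]);
    // A three-row partition: every row sees the second value.
    assert_eq!(nth_value(&[10, 20, 30], 2), vec![Some(20); 3]);
}
```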