implement window functions with partition by
Jiayu Liu committed Jun 19, 2021
1 parent 5900b4c commit 2488674
Showing 9 changed files with 275 additions and 30 deletions.
74 changes: 74 additions & 0 deletions datafusion/src/execution/context.rs
@@ -1355,6 +1355,80 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn window_partition_by() -> Result<()> {
let results = execute(
"SELECT \
c1, \
c2, \
SUM(c2) OVER (PARTITION BY c2), \
COUNT(c2) OVER (PARTITION BY c2), \
MAX(c2) OVER (PARTITION BY c2), \
MIN(c2) OVER (PARTITION BY c2), \
AVG(c2) OVER (PARTITION BY c2) \
FROM test \
ORDER BY c1, c2 \
LIMIT 5",
4,
)
.await?;

let expected = vec![
"+----+----+---------+-----------+---------+---------+---------+",
"| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
"+----+----+---------+-----------+---------+---------+---------+",
"| 0 | 1 | 4 | 4 | 1 | 1 | 1 |",
"| 0 | 2 | 8 | 4 | 2 | 2 | 2 |",
"| 0 | 3 | 12 | 4 | 3 | 3 | 3 |",
"| 0 | 4 | 16 | 4 | 4 | 4 | 4 |",
"| 0 | 5 | 20 | 4 | 5 | 5 | 5 |",
"+----+----+---------+-----------+---------+---------+---------+",
];

        // window functions should respect the ordering
assert_batches_eq!(expected, &results);
Ok(())
}

#[tokio::test]
async fn window_partition_by_order_by() -> Result<()> {
let results = execute(
"SELECT \
c1, \
c2, \
ROW_NUMBER() OVER (PARTITION BY c2 ORDER BY c1), \
FIRST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
LAST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
NTH_VALUE(c2 + c1, 2) OVER (PARTITION BY c2 ORDER BY c1), \
SUM(c2) OVER (PARTITION BY c2 ORDER BY c1), \
COUNT(c2) OVER (PARTITION BY c2 ORDER BY c1), \
MAX(c2) OVER (PARTITION BY c2 ORDER BY c1), \
MIN(c2) OVER (PARTITION BY c2 ORDER BY c1), \
AVG(c2) OVER (PARTITION BY c2 ORDER BY c1) \
FROM test \
ORDER BY c1, c2 \
LIMIT 5",
4,
)
.await?;

let expected = vec![
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
"| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2 Plus c1) | LAST_VALUE(c2 Plus c1) | NTH_VALUE(c2 Plus c1,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
"| 0 | 1 | 1 | 1 | 4 | 2 | 1 | 1 | 1 | 1 | 1 |",
"| 0 | 2 | 1 | 2 | 5 | 3 | 2 | 1 | 2 | 2 | 2 |",
"| 0 | 3 | 1 | 3 | 6 | 4 | 3 | 1 | 3 | 3 | 3 |",
"| 0 | 4 | 1 | 4 | 7 | 5 | 4 | 1 | 4 | 4 | 4 |",
"| 0 | 5 | 1 | 5 | 8 | 6 | 5 | 1 | 5 | 5 | 5 |",
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
];

        // window functions should respect the ordering
assert_batches_eq!(expected, &results);
Ok(())
}

#[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
10 changes: 7 additions & 3 deletions datafusion/src/physical_plan/expressions/nth_value.rs
@@ -20,7 +20,7 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use arrow::array::{new_empty_array, ArrayRef};
use arrow::array::{new_empty_array, new_null_array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::sync::Arc;
@@ -135,8 +135,12 @@ impl BuiltInWindowFunctionExpr for NthValue {
NthValueKind::Last => (num_rows as usize) - 1,
NthValueKind::Nth(n) => (n as usize) - 1,
};
let value = ScalarValue::try_from_array(value, index)?;
Ok(value.to_array_of_size(num_rows))
Ok(if index >= num_rows {
new_null_array(value.data_type(), num_rows)
} else {
let value = ScalarValue::try_from_array(value, index)?;
value.to_array_of_size(num_rows)
})
}
}

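As a standalone illustration of the new guard (a minimal sketch assuming only the arrow crate; nth_or_null and the Int64 types are hypothetical stand-ins, not the actual NthValue implementation): when the requested position falls outside the partition, the expression now yields an all-null column rather than an error, which is what lets NTH_VALUE(x, 2) return NULL over single-row partitions in the tests below.

use arrow::array::{new_null_array, ArrayRef, Int64Array};
use arrow::datatypes::DataType;
use std::sync::Arc;

// Hypothetical helper mirroring the fix above: broadcast the value at `index`
// across `num_rows`, or return an all-null column when the index is out of
// bounds (e.g. NTH_VALUE(x, 2) evaluated over a single-row partition).
fn nth_or_null(values: &Int64Array, index: usize, num_rows: usize) -> ArrayRef {
    if index >= num_rows {
        new_null_array(&DataType::Int64, num_rows)
    } else {
        Arc::new(Int64Array::from(vec![values.value(index); num_rows]))
    }
}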
36 changes: 22 additions & 14 deletions datafusion/src/physical_plan/mod.rs
@@ -485,19 +485,20 @@ pub trait WindowExpr: Send + Sync + Debug {
/// evaluate the window function values against the batch
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;

/// evaluate the sort partition points
fn evaluate_sort_partition_points(
    /// evaluate the partition points given the columns to split on; if the columns are
    /// empty, the result is a single range covering all `num_rows` rows.
fn evaluate_partition_points(
&self,
batch: &RecordBatch,
num_rows: usize,
partition_columns: &[SortColumn],
) -> Result<Vec<Range<usize>>> {
let sort_columns = self.sort_columns(batch)?;
if sort_columns.is_empty() {
if partition_columns.is_empty() {
Ok(vec![Range {
start: 0,
end: batch.num_rows(),
end: num_rows,
}])
} else {
lexicographical_partition_ranges(&sort_columns)
lexicographical_partition_ranges(partition_columns)
.map_err(DataFusionError::ArrowError)
}
}
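For intuition, a hand-rolled sketch of what these partition points represent (this is not arrow's lexicographical_partition_ranges kernel, which handles multiple sort columns and null ordering; the hypothetical helper below only scans a single pre-sorted key slice):

use std::ops::Range;

// Given keys already sorted on the partition columns, return one Range per run
// of equal keys, e.g. [1, 1, 2, 2, 2, 3] -> [0..2, 2..5, 5..6].
fn partition_ranges<T: PartialEq>(keys: &[T]) -> Vec<Range<usize>> {
    let mut ranges = Vec::new();
    let mut start = 0;
    for i in 1..=keys.len() {
        if i == keys.len() || keys[i] != keys[start] {
            ranges.push(start..i);
            start = i;
        }
    }
    ranges
}

Each returned range then becomes the unit of work for a window expression; the empty-columns branch above degenerates to the single range 0..num_rows.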
@@ -508,8 +509,8 @@ pub trait WindowExpr: Send + Sync + Debug {
/// expressions that's from the window function's order by clause, empty if absent
fn order_by(&self) -> &[PhysicalSortExpr];

/// get sort columns that can be used for partitioning, empty if absent
fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
    /// get the partition by columns that can be used to split the input, empty if absent
fn partition_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
self.partition_by()
.iter()
.map(|expr| {
@@ -519,13 +520,20 @@
}
.evaluate_to_sort_column(batch)
})
.chain(
self.order_by()
.iter()
.map(|e| e.evaluate_to_sort_column(batch)),
)
.collect()
}

/// get sort columns that can be used for peer evaluation, empty if absent
fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
let mut sort_columns = self.partition_columns(batch)?;
let order_by_columns = self
.order_by()
.iter()
.map(|e| e.evaluate_to_sort_column(batch))
.collect::<Result<Vec<SortColumn>>>()?;
sort_columns.extend(order_by_columns);
Ok(sort_columns)
}
}

/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
6 changes: 0 additions & 6 deletions datafusion/src/physical_plan/planner.rs
@@ -775,12 +775,6 @@ impl DefaultPhysicalPlanner {
)),
})
.collect::<Result<Vec<_>>>()?;
if !partition_by.is_empty() {
return Err(DataFusionError::NotImplemented(
"window expression with non-empty partition by clause is not yet supported"
.to_owned(),
));
}
if window_frame.is_some() {
return Err(DataFusionError::NotImplemented(
"window expression with window frame definition is not yet supported"
61 changes: 55 additions & 6 deletions datafusion/src/physical_plan/windows.rs
@@ -175,10 +175,45 @@ impl WindowExpr for BuiltInWindowExpr {
// case when partition_by is supported, in which case we'll parallelize the calls.
// See https://github.com/apache/arrow-datafusion/issues/299
let values = self.evaluate_args(batch)?;
self.window.evaluate(batch.num_rows(), &values)
let partition_points = self.evaluate_partition_points(
batch.num_rows(),
&self.partition_columns(batch)?,
)?;
let results = partition_points
.iter()
.map(|partition_range| {
let start = partition_range.start;
let len = partition_range.end - start;
let values = values
.iter()
.map(|arr| arr.slice(start, len))
.collect::<Vec<_>>();
self.window.evaluate(len, &values)
})
.collect::<Result<Vec<_>>>()?
.into_iter()
.collect::<Vec<ArrayRef>>();
let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
concat(&results).map_err(DataFusionError::ArrowError)
}
}
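A minimal sketch of the slice-then-concat pattern used in evaluate above (assuming the arrow crate; the column values and partition ranges are made up for illustration): each partition is evaluated against zero-copy slices of the argument arrays, and the per-partition outputs are stitched back into one column in the original row order.

use arrow::array::{Array, ArrayRef, Int64Array};
use arrow::compute::concat;
use arrow::error::Result as ArrowResult;
use std::sync::Arc;

fn slice_then_concat() -> ArrowResult<ArrayRef> {
    let col: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 2, 2, 2]));
    // Hypothetical partition ranges 0..2 and 2..5, as if split on the column value.
    let parts: Vec<ArrayRef> = vec![col.slice(0, 2), col.slice(2, 3)];
    // A real window expression would compute its result per slice here; concatenating
    // the per-partition outputs restores a single array covering the whole batch.
    let refs: Vec<&dyn Array> = parts.iter().map(|a| a.as_ref()).collect();
    concat(&refs)
}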

/// Given a partition range and the full list of sort partition points, return the slice of
/// sort partition points that falls within the partition range. Because the sort partition
/// points are computed over [partition columns..., order columns...], their boundaries are
/// guaranteed to align with the partition boundaries (rows equal on the partition columns
/// can only be split further by the order columns), so a binary search over starts and ends
/// suffices to locate the contained ranges.
fn find_ranges_in_range<'a>(
partition_range: &Range<usize>,
sort_partition_points: &'a [Range<usize>],
) -> &'a [Range<usize>] {
let start_idx = sort_partition_points
.partition_point(|sort_range| sort_range.start < partition_range.start);
let end_idx = sort_partition_points
.partition_point(|sort_range| sort_range.end <= partition_range.end);
&sort_partition_points[start_idx..end_idx]
}
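A hedged usage sketch (a hypothetical unit test with invented ranges, assuming it sits next to the function above): partitions are split on the partition columns alone while peer groups are split on partition plus order columns, so every peer range nests inside exactly one partition range and the two partition_point searches recover exactly the nested slice.

#[test]
fn find_ranges_in_range_example() {
    // partition ranges:      [0..4,       4..10,       10..12]
    // sort partition points: [0..2, 2..4, 4..7, 7..10, 10..12]
    let partition_range = 4..10;
    let sort_partition_points = vec![0..2, 2..4, 4..7, 7..10, 10..12];
    assert_eq!(
        find_ranges_in_range(&partition_range, &sort_partition_points),
        &[4..7, 7..10]
    );
}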

/// A window expr that takes the form of an aggregate function
#[derive(Debug)]
pub struct AggregateWindowExpr {
@@ -205,13 +240,27 @@ impl AggregateWindowExpr {
/// and then per partition point we'll evaluate the peer group (e.g. SUM or MAX gives the same
/// results for peers) and concatenate the results.
fn peer_based_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let sort_partition_points = self.evaluate_sort_partition_points(batch)?;
let mut window_accumulators = self.create_accumulator()?;
let num_rows = batch.num_rows();
let partition_points =
self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?;
let sort_partition_points =
self.evaluate_partition_points(num_rows, &self.sort_columns(batch)?)?;
let values = self.evaluate_args(batch)?;
let results = sort_partition_points
let results = partition_points
.iter()
.map(|peer_range| window_accumulators.scan_peers(&values, peer_range))
.collect::<Result<Vec<_>>>()?;
.map(|partition_range| {
let sort_partition_points =
find_ranges_in_range(partition_range, &sort_partition_points);
let mut window_accumulators = self.create_accumulator()?;
sort_partition_points
.iter()
.map(|range| window_accumulators.scan_peers(&values, range))
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<Vec<ArrayRef>>>>()?
.into_iter()
.flatten()
.collect::<Vec<ArrayRef>>();
let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
concat(&results).map_err(DataFusionError::ArrowError)
}
64 changes: 64 additions & 0 deletions datafusion/tests/sql.rs
@@ -868,6 +868,70 @@ async fn csv_query_window_with_empty_over() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_window_with_partition_by() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx)?;
let sql = "select \
c9, \
sum(cast(c4 as Int)) over (partition by c3), \
avg(cast(c4 as Int)) over (partition by c3), \
count(cast(c4 as Int)) over (partition by c3), \
max(cast(c4 as Int)) over (partition by c3), \
min(cast(c4 as Int)) over (partition by c3), \
first_value(cast(c4 as Int)) over (partition by c3), \
last_value(cast(c4 as Int)) over (partition by c3), \
nth_value(cast(c4 as Int), 2) over (partition by c3) \
from aggregate_test_100 \
order by c9 \
limit 5";
let actual = execute(&mut ctx, sql).await;
let expected = vec![
vec![
"28774375", "-16110", "-16110", "1", "-16110", "-16110", "-16110", "-16110",
"NULL",
],
vec![
"63044568", "3917", "3917", "1", "3917", "3917", "3917", "3917", "NULL",
],
vec![
"141047417",
"-38455",
"-19227.5",
"2",
"-16974",
"-21481",
"-16974",
"-21481",
"-21481",
],
vec![
"141680161",
"-1114",
"-1114",
"1",
"-1114",
"-1114",
"-1114",
"-1114",
"NULL",
],
vec![
"145294611",
"15673",
"15673",
"1",
"15673",
"15673",
"15673",
"15673",
"NULL",
],
];
assert_eq!(expected, actual);
Ok(())
}

#[tokio::test]
async fn csv_query_window_with_order_by() -> Result<()> {
let mut ctx = ExecutionContext::new();
26 changes: 26 additions & 0 deletions integration-tests/sqls/simple_window_partition_aggregation.sql
@@ -0,0 +1,26 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at

-- http://www.apache.org/licenses/LICENSE-2.0

-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

SELECT
c9,
row_number() OVER (PARTITION BY c2, c9) AS row_number,
count(c3) OVER (PARTITION BY c2) AS count_c3,
avg(c3) OVER (PARTITION BY c2) AS avg_c3_by_c2,
sum(c3) OVER (PARTITION BY c2) AS sum_c3_by_c2,
max(c3) OVER (PARTITION BY c2) AS max_c3_by_c2,
min(c3) OVER (PARTITION BY c2) AS min_c3_by_c2
FROM test
ORDER BY c9;
@@ -0,0 +1,26 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at

-- http://www.apache.org/licenses/LICENSE-2.0

-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

SELECT
c9,
row_number() OVER (PARTITION BY c2 ORDER BY c9) AS row_number,
count(c3) OVER (PARTITION BY c2 ORDER BY c9) AS count_c3,
avg(c3) OVER (PARTITION BY c2 ORDER BY c9) AS avg_c3_by_c2,
sum(c3) OVER (PARTITION BY c2 ORDER BY c9) AS sum_c3_by_c2,
max(c3) OVER (PARTITION BY c2 ORDER BY c9) AS max_c3_by_c2,
min(c3) OVER (PARTITION BY c2 ORDER BY c9) AS min_c3_by_c2
FROM test
ORDER BY c9;
2 changes: 1 addition & 1 deletion integration-tests/test_psql_parity.py
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
def test_parity(self):
root = Path(os.path.dirname(__file__)) / "sqls"
files = set(root.glob("*.sql"))
self.assertEqual(len(files), 7, msg="tests are missed")
self.assertEqual(len(files), 9, msg="tests are missed")
for fname in files:
with self.subTest(fname=fname):
datafusion_output = pd.read_csv(