Implement window functions with order_by clause #520

Merged (1 commit, Jun 16, 2021)
55 changes: 54 additions & 1 deletion datafusion/src/execution/context.rs
@@ -1273,7 +1273,17 @@ mod tests {
#[tokio::test]
async fn window() -> Result<()> {
let results = execute(
"SELECT c1, c2, SUM(c2) OVER (), COUNT(c2) OVER (), MAX(c2) OVER (), MIN(c2) OVER (), AVG(c2) OVER () FROM test ORDER BY c1, c2 LIMIT 5",
"SELECT \
c1, \
c2, \
SUM(c2) OVER (), \
COUNT(c2) OVER (), \
MAX(c2) OVER (), \
MIN(c2) OVER (), \
AVG(c2) OVER () \
FROM test \
ORDER BY c1, c2 \
LIMIT 5",
4,
)
.await?;
@@ -1299,6 +1309,49 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn window_order_by() -> Result<()> {
let results = execute(
"SELECT \
c1, \
c2, \
ROW_NUMBER() OVER (ORDER BY c1, c2), \
FIRST_VALUE(c2) OVER (ORDER BY c1, c2), \
LAST_VALUE(c2) OVER (ORDER BY c1, c2), \
NTH_VALUE(c2, 2) OVER (ORDER BY c1, c2), \
SUM(c2) OVER (ORDER BY c1, c2), \
COUNT(c2) OVER (ORDER BY c1, c2), \
MAX(c2) OVER (ORDER BY c1, c2), \
MIN(c2) OVER (ORDER BY c1, c2), \
AVG(c2) OVER (ORDER BY c1, c2) \
FROM test \
ORDER BY c1, c2 \
LIMIT 5",
4,
)
.await?;
// results arrive in one batch, although e.g. having 2 batches would not
// change the result semantics; a len=1 assertion upfront keeps
// surprises at bay

Review comment (Contributor): 👍

assert_eq!(results.len(), 1);

let expected = vec![
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
"| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2) | LAST_VALUE(c2) | NTH_VALUE(c2,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
"| 0 | 1 | 1 | 1 | 10 | 2 | 1 | 1 | 1 | 1 | 1 |",
Review comment (Contributor): I double checked these in postgres. 👍

"| 0 | 2 | 2 | 1 | 10 | 2 | 3 | 2 | 2 | 1 | 1.5 |",
"| 0 | 3 | 3 | 1 | 10 | 2 | 6 | 3 | 3 | 1 | 2 |",
"| 0 | 4 | 4 | 1 | 10 | 2 | 10 | 4 | 4 | 1 | 2.5 |",
"| 0 | 5 | 5 | 1 | 10 | 2 | 15 | 5 | 5 | 1 | 3 |",
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
];

// window function shall respect ordering
assert_batches_eq!(expected, &results);
Ok(())
}
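
Note on the expected values, for readers of this diff: with an ORDER BY inside OVER() and no explicit frame, SQL's default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, which is why SUM, COUNT, and AVG above behave as running aggregates. A minimal plain-Rust sketch (not part of this PR; values copied from the expected table) of the running-sum column:

#[test]
fn running_sum_check() {
    // c2 values of the first five rows in the expected output above
    let c2 = [1u64, 2, 3, 4, 5];
    // with the default frame, row i's SUM(c2) covers rows 0..=i
    let running: Vec<u64> = c2
        .iter()
        .scan(0u64, |acc, v| {
            *acc += v;
            Some(*acc)
        })
        .collect();
    assert_eq!(running, vec![1, 3, 6, 10, 15]); // matches the SUM(c2) column
}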

#[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
137 changes: 46 additions & 91 deletions datafusion/src/physical_plan/expressions/nth_value.rs
@@ -18,13 +18,11 @@
//! Defines physical expressions that can be evaluated at runtime during query execution

use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
};
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use arrow::array::{new_empty_array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;

/// nth_value kind
@@ -113,54 +111,32 @@ impl BuiltInWindowFunctionExpr for NthValue {
&self.name
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
self.kind,
self.data_type.clone(),
)?))
}
}

#[derive(Debug)]
struct NthValueAccumulator {
kind: NthValueKind,
offset: u32,
value: ScalarValue,
}

impl NthValueAccumulator {
/// new count accumulator
pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result<Self> {
Ok(Self {
kind,
offset: 0,
// null value of that data_type by default
value: ScalarValue::try_from(&data_type)?,
})
}
}

impl WindowAccumulator for NthValueAccumulator {
fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
self.offset += 1;
match self.kind {
NthValueKind::Last => {
self.value = values[0].clone();
}
NthValueKind::First if self.offset == 1 => {
self.value = values[0].clone();
}
NthValueKind::Nth(n) if self.offset == n => {
self.value = values[0].clone();
}
_ => {}
fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result<ArrayRef> {
if values.is_empty() {
return Err(DataFusionError::Execution(format!(
"No arguments supplied to {}",
self.name()
)));
}

Ok(None)
}

fn evaluate(&self) -> Result<Option<ScalarValue>> {
Ok(Some(self.value.clone()))
let value = &values[0];
if value.len() != num_rows {
return Err(DataFusionError::Execution(format!(
"Invalid data supplied to {}, expect {} rows, got {} rows",
self.name(),
num_rows,
value.len()
)));
}
if num_rows == 0 {
Review comment (Contributor): Could this function ever be passed a 0 row input? This check isn't a problem, I am just wondering if my mental model is correct.

Review comment (Member, author): This will be changed in a later pull request.

Review comment (Member, author): But you are right, this would not be passed a 0 length input. This check is just being pedantic.

return Ok(new_empty_array(value.data_type()));
}
let index: usize = match self.kind {
Review comment (Contributor): 👍

NthValueKind::First => 0,
NthValueKind::Last => (num_rows as usize) - 1,
NthValueKind::Nth(n) => (n as usize) - 1,
};
let value = ScalarValue::try_from_array(value, index)?;
Ok(value.to_array_of_size(num_rows))
}
}
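
For readers of this diff: the core of the new evaluate is a pick-and-broadcast, using ScalarValue::try_from_array to select the element at the computed index and to_array_of_size to repeat it for every row, with new_empty_array covering the zero-row guard. A minimal sketch under those assumptions; broadcast_nth is a hypothetical helper name, not part of this PR:

use std::sync::Arc;

use arrow::array::{new_empty_array, ArrayRef, Int32Array};
use arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::scalar::ScalarValue;

// hypothetical helper mirroring the pick-and-broadcast core of NthValue::evaluate
fn broadcast_nth(values: &ArrayRef, index: usize, num_rows: usize) -> Result<ArrayRef> {
    // select the scalar at `index`, then repeat it `num_rows` times
    let value = ScalarValue::try_from_array(values, index)?;
    Ok(value.to_array_of_size(num_rows))
}

fn main() -> Result<()> {
    let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3]));
    // FIRST_VALUE over a 3-row partition: index 0 broadcast to all rows
    assert_eq!(broadcast_nth(&arr, 0, 3)?.len(), 3);
    // the zero-row guard yields an empty array of the input's type
    assert_eq!(new_empty_array(&DataType::Int32).len(), 0);
    Ok(())
}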

@@ -172,68 +148,47 @@ mod tests {
use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};

fn test_i32_result(expr: Arc<NthValue>, expected: i32) -> Result<()> {
fn test_i32_result(expr: NthValue, expected: Vec<i32>) -> Result<()> {
let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
let values = vec![arr];
Review comment (Contributor): This test change shows the nice refactoring.

let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;

let mut acc = expr.create_accumulator()?;
let expr = expr.expressions();
let values = expr
.iter()
.map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;
let result = acc.scan_batch(batch.num_rows(), &values)?;
assert_eq!(false, result.is_some());
let result = acc.evaluate()?;
assert_eq!(Some(ScalarValue::Int32(Some(expected))), result);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
let result = expr.evaluate(batch.num_rows(), &values)?;
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
let result = result.values();
assert_eq!(expected, result);
Ok(())
}

#[test]
fn first_value() -> Result<()> {
let first_value = Arc::new(NthValue::first_value(
"first_value".to_owned(),
col("arr"),
DataType::Int32,
));
test_i32_result(first_value, 1)?;
let first_value =
NthValue::first_value("first_value".to_owned(), col("arr"), DataType::Int32);
test_i32_result(first_value, vec![1; 8])?;
Ok(())
}

#[test]
fn last_value() -> Result<()> {
let last_value = Arc::new(NthValue::last_value(
"last_value".to_owned(),
col("arr"),
DataType::Int32,
));
test_i32_result(last_value, 8)?;
let last_value =
NthValue::last_value("last_value".to_owned(), col("arr"), DataType::Int32);
test_i32_result(last_value, vec![8; 8])?;
Ok(())
}

#[test]
fn nth_value_1() -> Result<()> {
let nth_value = Arc::new(NthValue::nth_value(
"nth_value".to_owned(),
col("arr"),
DataType::Int32,
1,
)?);
test_i32_result(nth_value, 1)?;
let nth_value =
NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 1)?;
test_i32_result(nth_value, vec![1; 8])?;
Ok(())
}

#[test]
fn nth_value_2() -> Result<()> {
let nth_value = Arc::new(NthValue::nth_value(
"nth_value".to_owned(),
col("arr"),
DataType::Int32,
2,
)?);
test_i32_result(nth_value, -2)?;
let nth_value =
NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 2)?;
test_i32_result(nth_value, vec![-2; 8])?;
Ok(())
}
}
89 changes: 9 additions & 80 deletions datafusion/src/physical_plan/expressions/row_number.rs
@@ -18,10 +18,7 @@
//! Defines physical expression for `row_number` that can be evaluated at runtime during query execution

use crate::error::Result;
use crate::physical_plan::{
window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
};
use crate::scalar::ScalarValue;
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field};
use std::any::Any;
@@ -60,46 +57,10 @@ impl BuiltInWindowFunctionExpr for RowNumber {
self.name.as_str()
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(RowNumberAccumulator::new()))
}
}

#[derive(Debug)]
struct RowNumberAccumulator {
row_number: u64,
}

impl RowNumberAccumulator {
/// new row_number accumulator
pub fn new() -> Self {
// row number is 1 based
Self { row_number: 1 }
}
}

impl WindowAccumulator for RowNumberAccumulator {
fn scan(&mut self, _values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
let result = Some(ScalarValue::UInt64(Some(self.row_number)));
self.row_number += 1;
Ok(result)
}

fn scan_batch(
&mut self,
num_rows: usize,
_values: &[ArrayRef],
) -> Result<Option<ArrayRef>> {
let new_row_number = self.row_number + (num_rows as u64);
// TODO: probably would be nice to have a (optimized) kernel for this at some point to
// generate an array like this.
let result = UInt64Array::from_iter_values(self.row_number..new_row_number);
self.row_number = new_row_number;
Ok(Some(Arc::new(result)))
}

fn evaluate(&self) -> Result<Option<ScalarValue>> {
Ok(None)
fn evaluate(&self, num_rows: usize, _values: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(UInt64Array::from_iter_values(
(1..num_rows + 1).map(|i| i as u64),
)))
}
}

Expand All @@ -117,27 +78,11 @@ mod tests {
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;

let row_number = Arc::new(RowNumber::new("row_number".to_owned()));

let mut acc = row_number.create_accumulator()?;
let expr = row_number.expressions();
let values = expr
.iter()
.map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;

let result = acc.scan_batch(batch.num_rows(), &values)?;
assert_eq!(true, result.is_some());

let result = result.unwrap();
let row_number = RowNumber::new("row_number".to_owned());
let result = row_number.evaluate(batch.num_rows(), &[])?;
let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
let result = result.values();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);

let result = acc.evaluate()?;
assert_eq!(false, result.is_some());
Ok(())
}

Expand All @@ -148,27 +93,11 @@ mod tests {
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;

let row_number = Arc::new(RowNumber::new("row_number".to_owned()));

let mut acc = row_number.create_accumulator()?;
let expr = row_number.expressions();
let values = expr
.iter()
.map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;

let result = acc.scan_batch(batch.num_rows(), &values)?;
assert_eq!(true, result.is_some());

let result = result.unwrap();
let row_number = RowNumber::new("row_number".to_owned());
let result = row_number.evaluate(batch.num_rows(), &[])?;
let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
let result = result.values();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);

let result = acc.evaluate()?;
assert_eq!(false, result.is_some());
Ok(())
}
}
4 changes: 2 additions & 2 deletions datafusion/src/physical_plan/hash_aggregate.rs
@@ -500,7 +500,7 @@ fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>(
let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
let keys_col = dict_col.keys_array();
let keys_col = dict_col.keys();
let values_index = keys_col.value(row).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
@@ -1083,7 +1083,7 @@ fn dictionary_create_group_by_value<K: ArrowDictionaryKeyType>(
let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
let keys_col = dict_col.keys_array();
let keys_col = dict_col.keys();
Review comment (Contributor): This change in API was not published with arrow 4.3 (it will be in arrow 5.0) but it is cool to leave the changes in this PR anyways. 👍

let values_index = keys_col.value(row).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
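
For readers of this diff: a DictionaryArray stores each distinct value once, plus a per-row key indexing into those values; the hunks above fetch the key for one row and use it to index the values dictionary. A minimal sketch, assuming the arrow 5.x API the review comment refers to (where keys() returns the typed key array):

use arrow::array::{DictionaryArray, StringArray};
use arrow::datatypes::Int8Type;

fn main() {
    // "a" and "b" are stored once; every row holds only a small integer key
    let dict: DictionaryArray<Int8Type> = vec!["a", "b", "a"].into_iter().collect();

    let row = 2;
    // keys().value(row) is the index of row 2's value in the dictionary
    let values_index = dict.keys().value(row) as usize;
    let values = dict
        .values()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(values.value(values_index), "a");
}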