-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement window functions with order_by
clause
#520
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1273,7 +1273,17 @@ mod tests { | |
#[tokio::test] | ||
async fn window() -> Result<()> { | ||
let results = execute( | ||
"SELECT c1, c2, SUM(c2) OVER (), COUNT(c2) OVER (), MAX(c2) OVER (), MIN(c2) OVER (), AVG(c2) OVER () FROM test ORDER BY c1, c2 LIMIT 5", | ||
"SELECT \ | ||
c1, \ | ||
c2, \ | ||
SUM(c2) OVER (), \ | ||
COUNT(c2) OVER (), \ | ||
MAX(c2) OVER (), \ | ||
MIN(c2) OVER (), \ | ||
AVG(c2) OVER () \ | ||
FROM test \ | ||
ORDER BY c1, c2 \ | ||
LIMIT 5", | ||
4, | ||
) | ||
.await?; | ||
|
@@ -1299,6 +1309,49 @@ mod tests { | |
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn window_order_by() -> Result<()> { | ||
let results = execute( | ||
"SELECT \ | ||
c1, \ | ||
c2, \ | ||
ROW_NUMBER() OVER (ORDER BY c1, c2), \ | ||
FIRST_VALUE(c2) OVER (ORDER BY c1, c2), \ | ||
LAST_VALUE(c2) OVER (ORDER BY c1, c2), \ | ||
NTH_VALUE(c2, 2) OVER (ORDER BY c1, c2), \ | ||
SUM(c2) OVER (ORDER BY c1, c2), \ | ||
COUNT(c2) OVER (ORDER BY c1, c2), \ | ||
MAX(c2) OVER (ORDER BY c1, c2), \ | ||
MIN(c2) OVER (ORDER BY c1, c2), \ | ||
AVG(c2) OVER (ORDER BY c1, c2) \ | ||
FROM test \ | ||
ORDER BY c1, c2 \ | ||
LIMIT 5", | ||
4, | ||
) | ||
.await?; | ||
// result in one batch, although e.g. having 2 batches do not change | ||
// result semantics, having a len=1 assertion upfront keeps surprises | ||
// at bay | ||
assert_eq!(results.len(), 1); | ||
|
||
let expected = vec![ | ||
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+", | ||
"| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2) | LAST_VALUE(c2) | NTH_VALUE(c2,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |", | ||
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+", | ||
"| 0 | 1 | 1 | 1 | 10 | 2 | 1 | 1 | 1 | 1 | 1 |", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I double-checked these in Postgres. 👍 |
||
"| 0 | 2 | 2 | 1 | 10 | 2 | 3 | 2 | 2 | 1 | 1.5 |", | ||
"| 0 | 3 | 3 | 1 | 10 | 2 | 6 | 3 | 3 | 1 | 2 |", | ||
"| 0 | 4 | 4 | 1 | 10 | 2 | 10 | 4 | 4 | 1 | 2.5 |", | ||
"| 0 | 5 | 5 | 1 | 10 | 2 | 15 | 5 | 5 | 1 | 3 |", | ||
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+", | ||
]; | ||
|
||
// window function shall respect ordering | ||
assert_batches_eq!(expected, &results); | ||
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn aggregate() -> Result<()> { | ||
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,13 +18,11 @@ | |
//! Defines physical expressions that can evaluated at runtime during query execution | ||
|
||
use crate::error::{DataFusionError, Result}; | ||
use crate::physical_plan::{ | ||
window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator, | ||
}; | ||
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; | ||
use crate::scalar::ScalarValue; | ||
use arrow::array::{new_empty_array, ArrayRef}; | ||
use arrow::datatypes::{DataType, Field}; | ||
use std::any::Any; | ||
use std::convert::TryFrom; | ||
use std::sync::Arc; | ||
|
||
/// nth_value kind | ||
|
@@ -113,54 +111,32 @@ impl BuiltInWindowFunctionExpr for NthValue { | |
&self.name | ||
} | ||
|
||
fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> { | ||
Ok(Box::new(NthValueAccumulator::try_new( | ||
self.kind, | ||
self.data_type.clone(), | ||
)?)) | ||
} | ||
} | ||
|
||
#[derive(Debug)] | ||
struct NthValueAccumulator { | ||
kind: NthValueKind, | ||
offset: u32, | ||
value: ScalarValue, | ||
} | ||
|
||
impl NthValueAccumulator { | ||
/// new count accumulator | ||
pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result<Self> { | ||
Ok(Self { | ||
kind, | ||
offset: 0, | ||
// null value of that data_type by default | ||
value: ScalarValue::try_from(&data_type)?, | ||
}) | ||
} | ||
} | ||
|
||
impl WindowAccumulator for NthValueAccumulator { | ||
fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> { | ||
self.offset += 1; | ||
match self.kind { | ||
NthValueKind::Last => { | ||
self.value = values[0].clone(); | ||
} | ||
NthValueKind::First if self.offset == 1 => { | ||
self.value = values[0].clone(); | ||
} | ||
NthValueKind::Nth(n) if self.offset == n => { | ||
self.value = values[0].clone(); | ||
} | ||
_ => {} | ||
fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result<ArrayRef> { | ||
if values.is_empty() { | ||
return Err(DataFusionError::Execution(format!( | ||
"No arguments supplied to {}", | ||
self.name() | ||
))); | ||
} | ||
|
||
Ok(None) | ||
} | ||
|
||
fn evaluate(&self) -> Result<Option<ScalarValue>> { | ||
Ok(Some(self.value.clone())) | ||
let value = &values[0]; | ||
if value.len() != num_rows { | ||
return Err(DataFusionError::Execution(format!( | ||
"Invalid data supplied to {}, expect {} rows, got {} rows", | ||
self.name(), | ||
num_rows, | ||
value.len() | ||
))); | ||
} | ||
if num_rows == 0 { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could this function ever be passed a 0-row input? This check isn't a problem; I am just wondering if my mental model is correct. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This will be changed in a later pull request. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. But you are right, this would not be passed a 0-length input; this check is just being pedantic. |
||
return Ok(new_empty_array(value.data_type())); | ||
} | ||
let index: usize = match self.kind { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
NthValueKind::First => 0, | ||
NthValueKind::Last => (num_rows as usize) - 1, | ||
NthValueKind::Nth(n) => (n as usize) - 1, | ||
}; | ||
let value = ScalarValue::try_from_array(value, index)?; | ||
Ok(value.to_array_of_size(num_rows)) | ||
} | ||
} | ||
|
||
|
@@ -172,68 +148,47 @@ mod tests { | |
use arrow::record_batch::RecordBatch; | ||
use arrow::{array::*, datatypes::*}; | ||
|
||
fn test_i32_result(expr: Arc<NthValue>, expected: i32) -> Result<()> { | ||
fn test_i32_result(expr: NthValue, expected: Vec<i32>) -> Result<()> { | ||
let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); | ||
let values = vec![arr]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test change shows the nice refactoring |
||
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); | ||
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; | ||
|
||
let mut acc = expr.create_accumulator()?; | ||
let expr = expr.expressions(); | ||
let values = expr | ||
.iter() | ||
.map(|e| e.evaluate(&batch)) | ||
.map(|r| r.map(|v| v.into_array(batch.num_rows()))) | ||
.collect::<Result<Vec<_>>>()?; | ||
let result = acc.scan_batch(batch.num_rows(), &values)?; | ||
assert_eq!(false, result.is_some()); | ||
let result = acc.evaluate()?; | ||
assert_eq!(Some(ScalarValue::Int32(Some(expected))), result); | ||
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; | ||
let result = expr.evaluate(batch.num_rows(), &values)?; | ||
let result = result.as_any().downcast_ref::<Int32Array>().unwrap(); | ||
let result = result.values(); | ||
assert_eq!(expected, result); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn first_value() -> Result<()> { | ||
let first_value = Arc::new(NthValue::first_value( | ||
"first_value".to_owned(), | ||
col("arr"), | ||
DataType::Int32, | ||
)); | ||
test_i32_result(first_value, 1)?; | ||
let first_value = | ||
NthValue::first_value("first_value".to_owned(), col("arr"), DataType::Int32); | ||
test_i32_result(first_value, vec![1; 8])?; | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn last_value() -> Result<()> { | ||
let last_value = Arc::new(NthValue::last_value( | ||
"last_value".to_owned(), | ||
col("arr"), | ||
DataType::Int32, | ||
)); | ||
test_i32_result(last_value, 8)?; | ||
let last_value = | ||
NthValue::last_value("last_value".to_owned(), col("arr"), DataType::Int32); | ||
test_i32_result(last_value, vec![8; 8])?; | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn nth_value_1() -> Result<()> { | ||
let nth_value = Arc::new(NthValue::nth_value( | ||
"nth_value".to_owned(), | ||
col("arr"), | ||
DataType::Int32, | ||
1, | ||
)?); | ||
test_i32_result(nth_value, 1)?; | ||
let nth_value = | ||
NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 1)?; | ||
test_i32_result(nth_value, vec![1; 8])?; | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn nth_value_2() -> Result<()> { | ||
let nth_value = Arc::new(NthValue::nth_value( | ||
"nth_value".to_owned(), | ||
col("arr"), | ||
DataType::Int32, | ||
2, | ||
)?); | ||
test_i32_result(nth_value, -2)?; | ||
let nth_value = | ||
NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 2)?; | ||
test_i32_result(nth_value, vec![-2; 8])?; | ||
Ok(()) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -500,7 +500,7 @@ fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>( | |
let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap(); | ||
|
||
// look up the index in the values dictionary | ||
let keys_col = dict_col.keys_array(); | ||
let keys_col = dict_col.keys(); | ||
let values_index = keys_col.value(row).to_usize().ok_or_else(|| { | ||
DataFusionError::Internal(format!( | ||
"Can not convert index to usize in dictionary of type creating group by value {:?}", | ||
|
@@ -1083,7 +1083,7 @@ fn dictionary_create_group_by_value<K: ArrowDictionaryKeyType>( | |
let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap(); | ||
|
||
// look up the index in the values dictionary | ||
let keys_col = dict_col.keys_array(); | ||
let keys_col = dict_col.keys(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This change in the API was not published with arrow 4.3 (it will be in arrow 5.0), but it is cool to leave the changes in this PR anyway 👍 |
||
let values_index = keys_col.value(row).to_usize().ok_or_else(|| { | ||
DataFusionError::Internal(format!( | ||
"Can not convert index to usize in dictionary of type creating group by value {:?}", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍