add window function implementation with order_by clause
Jiayu Liu committed Jun 15, 2021
1 parent e3e7e29 commit 9f6a56b
Showing 13 changed files with 476 additions and 391 deletions.
55 changes: 54 additions & 1 deletion datafusion/src/execution/context.rs
@@ -1273,7 +1273,17 @@ mod tests {
#[tokio::test]
async fn window() -> Result<()> {
let results = execute(
"SELECT c1, c2, SUM(c2) OVER (), COUNT(c2) OVER (), MAX(c2) OVER (), MIN(c2) OVER (), AVG(c2) OVER () FROM test ORDER BY c1, c2 LIMIT 5",
"SELECT \
c1, \
c2, \
SUM(c2) OVER (), \
COUNT(c2) OVER (), \
MAX(c2) OVER (), \
MIN(c2) OVER (), \
AVG(c2) OVER () \
FROM test \
ORDER BY c1, c2 \
LIMIT 5",
4,
)
.await?;
@@ -1299,6 +1309,49 @@ mod tests {
Ok(())
}
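
For context (not part of the diff): an empty OVER () evaluates each aggregate over the entire partition, so every output row of the query above repeats the same SUM/COUNT/MAX/MIN/AVG values; the test added below contrasts this with OVER (ORDER BY ...). A rough sketch of the two query shapes, as hypothetical strings rather than anything from this commit:

fn main() {
    // Whole-partition frame: the same aggregate value appears on every row.
    let whole_partition = "SELECT c2, SUM(c2) OVER () FROM test";
    // ORDER BY implies a growing frame: the aggregate becomes cumulative.
    let running_frame = "SELECT c2, SUM(c2) OVER (ORDER BY c2) FROM test";
    println!("{}\n{}", whole_partition, running_frame);
}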

#[tokio::test]
async fn window_order_by() -> Result<()> {
let results = execute(
"SELECT \
c1, \
c2, \
ROW_NUMBER() OVER (ORDER BY c1, c2), \
FIRST_VALUE(c2) OVER (ORDER BY c1, c2), \
LAST_VALUE(c2) OVER (ORDER BY c1, c2), \
NTH_VALUE(c2, 2) OVER (ORDER BY c1, c2), \
SUM(c2) OVER (ORDER BY c1, c2), \
COUNT(c2) OVER (ORDER BY c1, c2), \
MAX(c2) OVER (ORDER BY c1, c2), \
MIN(c2) OVER (ORDER BY c1, c2), \
AVG(c2) OVER (ORDER BY c1, c2) \
FROM test \
ORDER BY c1, c2 \
LIMIT 5",
4,
)
.await?;
// The result arrives in a single batch; while e.g. two batches would not
// change the result semantics, asserting len == 1 up front keeps surprises
// at bay
assert_eq!(results.len(), 1);

let expected = vec![
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
"| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2) | LAST_VALUE(c2) | NTH_VALUE(c2,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
"| 0 | 1 | 1 | 1 | 10 | 2 | 1 | 1 | 1 | 1 | 1 |",
"| 0 | 2 | 2 | 1 | 10 | 2 | 3 | 2 | 2 | 1 | 1.5 |",
"| 0 | 3 | 3 | 1 | 10 | 2 | 6 | 3 | 3 | 1 | 2 |",
"| 0 | 4 | 4 | 1 | 10 | 2 | 10 | 4 | 4 | 1 | 2.5 |",
"| 0 | 5 | 5 | 1 | 10 | 2 | 15 | 5 | 5 | 1 | 3 |",
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
];

// window functions must respect the ordering
assert_batches_eq!(expected, &results);
Ok(())
}
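
Reading the expected table above: with ORDER BY c1, c2 the aggregates (SUM, COUNT, MAX, MIN, AVG) are cumulative, growing one row at a time, while the built-in value functions are computed once against the whole sorted partition and broadcast to every row; hence LAST_VALUE(c2) is 10, the partition's last value, on all five visible rows. A minimal sketch of the running-sum column, illustration only:

fn main() {
    // c2 values of the five visible rows, in ORDER BY c1, c2 order.
    let c2 = [1u64, 2, 3, 4, 5];
    // Prefix sum, as SUM(c2) OVER (ORDER BY c1, c2) produces per row.
    let running: Vec<u64> = c2
        .iter()
        .scan(0, |acc, v| {
            *acc += v;
            Some(*acc)
        })
        .collect();
    assert_eq!(vec![1, 3, 6, 10, 15], running); // matches the SUM(c2) column
}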

#[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
137 changes: 46 additions & 91 deletions datafusion/src/physical_plan/expressions/nth_value.rs
@@ -18,13 +18,11 @@
//! Defines physical expressions that can be evaluated at runtime during query execution
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
};
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use arrow::array::{new_empty_array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;

/// nth_value kind
@@ -113,54 +111,32 @@ impl BuiltInWindowFunctionExpr for NthValue {
&self.name
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
self.kind,
self.data_type.clone(),
)?))
}
}

#[derive(Debug)]
struct NthValueAccumulator {
kind: NthValueKind,
offset: u32,
value: ScalarValue,
}

impl NthValueAccumulator {
/// new count accumulator
pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result<Self> {
Ok(Self {
kind,
offset: 0,
// null value of that data_type by default
value: ScalarValue::try_from(&data_type)?,
})
}
}

impl WindowAccumulator for NthValueAccumulator {
fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
self.offset += 1;
match self.kind {
NthValueKind::Last => {
self.value = values[0].clone();
}
NthValueKind::First if self.offset == 1 => {
self.value = values[0].clone();
}
NthValueKind::Nth(n) if self.offset == n => {
self.value = values[0].clone();
}
_ => {}
fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result<ArrayRef> {
if values.is_empty() {
return Err(DataFusionError::Execution(format!(
"No arguments supplied to {}",
self.name()
)));
}

Ok(None)
}

fn evaluate(&self) -> Result<Option<ScalarValue>> {
Ok(Some(self.value.clone()))
let value = &values[0];
if value.len() != num_rows {
return Err(DataFusionError::Execution(format!(
"Invalid data supplied to {}, expect {} rows, got {} rows",
self.name(),
num_rows,
value.len()
)));
}
if num_rows == 0 {
return Ok(new_empty_array(value.data_type()));
}
let index: usize = match self.kind {
NthValueKind::First => 0,
NthValueKind::Last => (num_rows as usize) - 1,
NthValueKind::Nth(n) => (n as usize) - 1,
};
let value = ScalarValue::try_from_array(value, index)?;
Ok(value.to_array_of_size(num_rows))
}
}
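
The new evaluate picks a single index into the input column (0 for FIRST_VALUE, num_rows - 1 for LAST_VALUE, n - 1 for NTH_VALUE, where the Nth(n) arm assumes n is within the partition), extracts that element as a scalar, and broadcasts it across the partition. A minimal standalone sketch of the pick-and-broadcast pattern, assuming the public datafusion and arrow crates rather than this commit's internal paths:

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use datafusion::error::Result;
use datafusion::scalar::ScalarValue;

// Extract one element as a ScalarValue, then repeat it num_rows times.
fn broadcast_value(values: &ArrayRef, index: usize, num_rows: usize) -> Result<ArrayRef> {
    let scalar = ScalarValue::try_from_array(values, index)?;
    Ok(scalar.to_array_of_size(num_rows))
}

fn main() -> Result<()> {
    let arr: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
    let first = broadcast_value(&arr, 0, arr.len())?; // FIRST_VALUE: [10, 10, 10]
    assert_eq!(first.len(), 3);
    Ok(())
}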

@@ -172,68 +148,47 @@ mod tests {
use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};

fn test_i32_result(expr: Arc<NthValue>, expected: i32) -> Result<()> {
fn test_i32_result(expr: NthValue, expected: Vec<i32>) -> Result<()> {
let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;

let mut acc = expr.create_accumulator()?;
let expr = expr.expressions();
let values = expr
.iter()
.map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;
let result = acc.scan_batch(batch.num_rows(), &values)?;
assert_eq!(false, result.is_some());
let result = acc.evaluate()?;
assert_eq!(Some(ScalarValue::Int32(Some(expected))), result);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
let result = expr.evaluate(batch.num_rows(), &values)?;
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
let result = result.values();
assert_eq!(expected, result);
Ok(())
}

#[test]
fn first_value() -> Result<()> {
let first_value = Arc::new(NthValue::first_value(
"first_value".to_owned(),
col("arr"),
DataType::Int32,
));
test_i32_result(first_value, 1)?;
let first_value =
NthValue::first_value("first_value".to_owned(), col("arr"), DataType::Int32);
test_i32_result(first_value, vec![1; 8])?;
Ok(())
}

#[test]
fn last_value() -> Result<()> {
let last_value = Arc::new(NthValue::last_value(
"last_value".to_owned(),
col("arr"),
DataType::Int32,
));
test_i32_result(last_value, 8)?;
let last_value =
NthValue::last_value("last_value".to_owned(), col("arr"), DataType::Int32);
test_i32_result(last_value, vec![8; 8])?;
Ok(())
}

#[test]
fn nth_value_1() -> Result<()> {
let nth_value = Arc::new(NthValue::nth_value(
"nth_value".to_owned(),
col("arr"),
DataType::Int32,
1,
)?);
test_i32_result(nth_value, 1)?;
let nth_value =
NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 1)?;
test_i32_result(nth_value, vec![1; 8])?;
Ok(())
}

#[test]
fn nth_value_2() -> Result<()> {
let nth_value = Arc::new(NthValue::nth_value(
"nth_value".to_owned(),
col("arr"),
DataType::Int32,
2,
)?);
test_i32_result(nth_value, -2)?;
let nth_value =
NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 2)?;
test_i32_result(nth_value, vec![-2; 8])?;
Ok(())
}
}
89 changes: 9 additions & 80 deletions datafusion/src/physical_plan/expressions/row_number.rs
@@ -18,10 +18,7 @@
//! Defines physical expression for `row_number` that can be evaluated at runtime during query execution
use crate::error::Result;
use crate::physical_plan::{
window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
};
use crate::scalar::ScalarValue;
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field};
use std::any::Any;
@@ -60,46 +57,10 @@ impl BuiltInWindowFunctionExpr for RowNumber {
self.name.as_str()
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(RowNumberAccumulator::new()))
}
}

#[derive(Debug)]
struct RowNumberAccumulator {
row_number: u64,
}

impl RowNumberAccumulator {
/// new row_number accumulator
pub fn new() -> Self {
// row number is 1 based
Self { row_number: 1 }
}
}

impl WindowAccumulator for RowNumberAccumulator {
fn scan(&mut self, _values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
let result = Some(ScalarValue::UInt64(Some(self.row_number)));
self.row_number += 1;
Ok(result)
}

fn scan_batch(
&mut self,
num_rows: usize,
_values: &[ArrayRef],
) -> Result<Option<ArrayRef>> {
let new_row_number = self.row_number + (num_rows as u64);
// TODO: probably would be nice to have a (optimized) kernel for this at some point to
// generate an array like this.
let result = UInt64Array::from_iter_values(self.row_number..new_row_number);
self.row_number = new_row_number;
Ok(Some(Arc::new(result)))
}

fn evaluate(&self) -> Result<Option<ScalarValue>> {
Ok(None)
fn evaluate(&self, num_rows: usize, _values: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(UInt64Array::from_iter_values(
(1..num_rows + 1).map(|i| i as u64),
)))
}
}
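
ROW_NUMBER needs no input column: the new evaluate simply materializes 1..=num_rows as a UInt64Array (the commit spells the range as 1..num_rows + 1 with a map to u64). A standalone sketch of the same idea, assuming the public arrow crate:

use arrow::array::UInt64Array;

// 1-based row numbers for a partition of num_rows rows.
fn row_numbers(num_rows: usize) -> UInt64Array {
    UInt64Array::from_iter_values(1..=num_rows as u64)
}

fn main() {
    assert_eq!(vec![1, 2, 3], row_numbers(3).values());
}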

@@ -117,27 +78,11 @@ mod tests {
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;

let row_number = Arc::new(RowNumber::new("row_number".to_owned()));

let mut acc = row_number.create_accumulator()?;
let expr = row_number.expressions();
let values = expr
.iter()
.map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;

let result = acc.scan_batch(batch.num_rows(), &values)?;
assert_eq!(true, result.is_some());

let result = result.unwrap();
let row_number = RowNumber::new("row_number".to_owned());
let result = row_number.evaluate(batch.num_rows(), &[])?;
let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
let result = result.values();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);

let result = acc.evaluate()?;
assert_eq!(false, result.is_some());
Ok(())
}

@@ -148,27 +93,11 @@
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;

let row_number = Arc::new(RowNumber::new("row_number".to_owned()));

let mut acc = row_number.create_accumulator()?;
let expr = row_number.expressions();
let values = expr
.iter()
.map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;

let result = acc.scan_batch(batch.num_rows(), &values)?;
assert_eq!(true, result.is_some());

let result = result.unwrap();
let row_number = RowNumber::new("row_number".to_owned());
let result = row_number.evaluate(batch.num_rows(), &[])?;
let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
let result = result.values();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);

let result = acc.evaluate()?;
assert_eq!(false, result.is_some());
Ok(())
}
}
4 changes: 2 additions & 2 deletions datafusion/src/physical_plan/hash_aggregate.rs
@@ -500,7 +500,7 @@ fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>(
let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
let keys_col = dict_col.keys_array();
let keys_col = dict_col.keys();
let values_index = keys_col.value(row).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
@@ -1083,7 +1083,7 @@ fn dictionary_create_group_by_value<K: ArrowDictionaryKeyType>(
let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();

// look up the index in the values dictionary
let keys_col = dict_col.keys_array();
let keys_col = dict_col.keys();
let values_index = keys_col.value(row).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
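
Both hash_aggregate.rs hunks are the same mechanical rename: arrow's DictionaryArray accessor keys_array() became keys(), which hands back the keys as a primitive array that can be indexed per row. A minimal sketch of the renamed accessor, assuming arrow's FromIterator<&str> impl for DictionaryArray:

use arrow::array::DictionaryArray;
use arrow::datatypes::Int8Type;

fn main() {
    // Dictionary-encoded ["a", "b", "a"]: values ["a", "b"], keys [0, 1, 0].
    let dict: DictionaryArray<Int8Type> = vec!["a", "b", "a"].into_iter().collect();
    let keys = dict.keys(); // formerly dict.keys_array()
    assert_eq!(keys.value(0), 0); // row 0 points at dictionary entry "a"
    assert_eq!(keys.value(2), 0); // row 2 reuses the same entry
}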