Skip to content

Commit

Permalink
add integration tests for rank, dense_rank (#638)
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist authored Jun 30, 2021
1 parent e861d01 commit cab3a98
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 54 deletions.
24 changes: 12 additions & 12 deletions datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1335,11 +1335,11 @@ mod tests {
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
"| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2) | LAST_VALUE(c2) | NTH_VALUE(c2,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
"| 0 | 1 | 1 | 1 | 10 | 2 | 1 | 1 | 1 | 1 | 1 |",
"| 0 | 2 | 2 | 1 | 10 | 2 | 3 | 2 | 2 | 1 | 1.5 |",
"| 0 | 3 | 3 | 1 | 10 | 2 | 6 | 3 | 3 | 1 | 2 |",
"| 0 | 4 | 4 | 1 | 10 | 2 | 10 | 4 | 4 | 1 | 2.5 |",
"| 0 | 5 | 5 | 1 | 10 | 2 | 15 | 5 | 5 | 1 | 3 |",
"| 0 | 1 | 1 | 1 | 1 | | 1 | 1 | 1 | 1 | 1 |",
"| 0 | 2 | 2 | 1 | 2 | 2 | 3 | 2 | 2 | 1 | 1.5 |",
"| 0 | 3 | 3 | 1 | 3 | 2 | 6 | 3 | 3 | 1 | 2 |",
"| 0 | 4 | 4 | 1 | 4 | 2 | 10 | 4 | 4 | 1 | 2.5 |",
"| 0 | 5 | 5 | 1 | 5 | 2 | 15 | 5 | 5 | 1 | 3 |",
"+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
];

Expand Down Expand Up @@ -1392,7 +1392,7 @@ mod tests {
ROW_NUMBER() OVER (PARTITION BY c2 ORDER BY c1), \
FIRST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
LAST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
NTH_VALUE(c2 + c1, 2) OVER (PARTITION BY c2 ORDER BY c1), \
NTH_VALUE(c2 + c1, 1) OVER (PARTITION BY c2 ORDER BY c1), \
SUM(c2) OVER (PARTITION BY c2 ORDER BY c1), \
COUNT(c2) OVER (PARTITION BY c2 ORDER BY c1), \
MAX(c2) OVER (PARTITION BY c2 ORDER BY c1), \
Expand All @@ -1407,13 +1407,13 @@ mod tests {

let expected = vec![
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
"| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2 Plus c1) | LAST_VALUE(c2 Plus c1) | NTH_VALUE(c2 Plus c1,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
"| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2 Plus c1) | LAST_VALUE(c2 Plus c1) | NTH_VALUE(c2 Plus c1,Int64(1)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
"| 0 | 1 | 1 | 1 | 4 | 2 | 1 | 1 | 1 | 1 | 1 |",
"| 0 | 2 | 1 | 2 | 5 | 3 | 2 | 1 | 2 | 2 | 2 |",
"| 0 | 3 | 1 | 3 | 6 | 4 | 3 | 1 | 3 | 3 | 3 |",
"| 0 | 4 | 1 | 4 | 7 | 5 | 4 | 1 | 4 | 4 | 4 |",
"| 0 | 5 | 1 | 5 | 8 | 6 | 5 | 1 | 5 | 5 | 5 |",
"| 0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |",
"| 0 | 2 | 1 | 2 | 2 | 2 | 2 | 1 | 2 | 2 | 2 |",
"| 0 | 3 | 1 | 3 | 3 | 3 | 3 | 1 | 3 | 3 | 3 |",
"| 0 | 4 | 1 | 4 | 4 | 4 | 4 | 1 | 4 | 4 | 4 |",
"| 0 | 5 | 1 | 5 | 5 | 5 | 5 | 1 | 5 | 5 | 5 |",
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
];

Expand Down
94 changes: 72 additions & 22 deletions datafusion/src/physical_plan/expressions/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ use crate::physical_plan::window_functions::PartitionEvaluator;
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use arrow::array::{new_null_array, ArrayRef};
use arrow::compute::kernels::window::shift;
use arrow::datatypes::{DataType, Field};
use arrow::record_batch::RecordBatch;
use std::any::Any;
use std::iter;
use std::ops::Range;
use std::sync::Arc;

Expand Down Expand Up @@ -138,21 +140,56 @@ pub(crate) struct NthValueEvaluator {
}

impl PartitionEvaluator for NthValueEvaluator {
fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
let value = &self.values[0];
fn include_rank(&self) -> bool {
true
}

fn evaluate_partition(&self, _partition: Range<usize>) -> Result<ArrayRef> {
unreachable!("first, last, and nth_value evaluation must be called with evaluate_partition_with_rank")
}

fn evaluate_partition_with_rank(
&self,
partition: Range<usize>,
ranks_in_partition: &[Range<usize>],
) -> Result<ArrayRef> {
let arr = &self.values[0];
let num_rows = partition.end - partition.start;
let value = value.slice(partition.start, num_rows);
let index: usize = match self.kind {
NthValueKind::First => 0,
NthValueKind::Last => (num_rows as usize) - 1,
NthValueKind::Nth(n) => (n as usize) - 1,
};
Ok(if index >= num_rows {
new_null_array(value.data_type(), num_rows)
} else {
let value = ScalarValue::try_from_array(&value, index)?;
value.to_array_of_size(num_rows)
})
match self.kind {
NthValueKind::First => {
let value = ScalarValue::try_from_array(arr, partition.start)?;
Ok(value.to_array_of_size(num_rows))
}
NthValueKind::Last => {
// The default window frame spans from unbounded preceding to the current row and
// includes the current row's peers, so the last value for each row is the final row of its peer group.
let values = ranks_in_partition
.iter()
.map(|range| {
let len = range.end - range.start;
let value = ScalarValue::try_from_array(arr, range.end - 1)?;
Ok(iter::repeat(value).take(len))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
.flatten();
ScalarValue::iter_to_array(values)
}
NthValueKind::Nth(n) => {
let index = (n as usize) - 1;
if index >= num_rows {
Ok(new_null_array(arr.data_type(), num_rows))
} else {
let value =
ScalarValue::try_from_array(arr, partition.start + index)?;
let arr = value.to_array_of_size(num_rows);
// The default window frame spans from unbounded preceding to the current row, so rows
// at positions before `index` cannot see the nth value yet and must be null; shifting
// the array by `index` produces those nulls. This will change once non-default window frames are implemented.
shift(arr.as_ref(), index as i64).map_err(DataFusionError::ArrowError)
}
}
}
}
}

Expand All @@ -164,16 +201,17 @@ mod tests {
use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};

fn test_i32_result(expr: NthValue, expected: Vec<i32>) -> Result<()> {
fn test_i32_result(expr: NthValue, expected: Int32Array) -> Result<()> {
let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
let result = expr.create_evaluator(&batch)?.evaluate(vec![0..8])?;
let result = expr
.create_evaluator(&batch)?
.evaluate_with_rank(vec![0..8], vec![0..8])?;
assert_eq!(1, result.len());
let result = result[0].as_any().downcast_ref::<Int32Array>().unwrap();
let result = result.values();
assert_eq!(expected, result);
assert_eq!(expected, *result);
Ok(())
}

Expand All @@ -184,7 +222,7 @@ mod tests {
Arc::new(Column::new("arr", 0)),
DataType::Int32,
);
test_i32_result(first_value, vec![1; 8])?;
test_i32_result(first_value, Int32Array::from_iter_values(vec![1; 8]))?;
Ok(())
}

Expand All @@ -195,7 +233,7 @@ mod tests {
Arc::new(Column::new("arr", 0)),
DataType::Int32,
);
test_i32_result(last_value, vec![8; 8])?;
test_i32_result(last_value, Int32Array::from_iter_values(vec![8; 8]))?;
Ok(())
}

Expand All @@ -207,7 +245,7 @@ mod tests {
DataType::Int32,
1,
)?;
test_i32_result(nth_value, vec![1; 8])?;
test_i32_result(nth_value, Int32Array::from_iter_values(vec![1; 8]))?;
Ok(())
}

Expand All @@ -219,7 +257,19 @@ mod tests {
DataType::Int32,
2,
)?;
test_i32_result(nth_value, vec![-2; 8])?;
test_i32_result(
nth_value,
Int32Array::from(vec![
None,
Some(-2),
Some(-2),
Some(-2),
Some(-2),
Some(-2),
Some(-2),
Some(-2),
]),
)?;
Ok(())
}
}
21 changes: 7 additions & 14 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ async fn csv_query_window_with_partition_by() -> Result<()> {
"-21481",
"-16974",
"-21481",
"-21481",
"NULL",
],
vec![
"141680161",
Expand Down Expand Up @@ -952,15 +952,8 @@ async fn csv_query_window_with_order_by() -> Result<()> {
let actual = execute(&mut ctx, sql).await;
let expected = vec![
vec![
"28774375",
"61035129",
"61035129",
"1",
"61035129",
"61035129",
"61035129",
"2025611582",
"-108973366",
"28774375", "61035129", "61035129", "1", "61035129", "61035129", "61035129",
"61035129", "NULL",
],
vec![
"63044568",
Expand All @@ -970,7 +963,7 @@ async fn csv_query_window_with_order_by() -> Result<()> {
"61035129",
"-108973366",
"61035129",
"2025611582",
"-108973366",
"-108973366",
],
vec![
Expand All @@ -981,7 +974,7 @@ async fn csv_query_window_with_order_by() -> Result<()> {
"623103518",
"-108973366",
"61035129",
"2025611582",
"623103518",
"-108973366",
],
vec![
Expand All @@ -992,7 +985,7 @@ async fn csv_query_window_with_order_by() -> Result<()> {
"623103518",
"-1927628110",
"61035129",
"2025611582",
"-1927628110",
"-108973366",
],
vec![
Expand All @@ -1003,7 +996,7 @@ async fn csv_query_window_with_order_by() -> Result<()> {
"623103518",
"-1927628110",
"61035129",
"2025611582",
"-1899175111",
"-108973366",
],
];
Expand Down
27 changes: 27 additions & 0 deletions integration-tests/sqls/simple_window_built_in_functions.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at

-- http://www.apache.org/licenses/LICENSE-2.0

-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- Computes row_number, first_value, last_value and nth_value (n = 2) as window
-- functions over c9, each with an ascending and, except for row_number, a
-- descending ORDER BY. No explicit window frame is given, so every function
-- is evaluated with the engine's default frame.
-- The final ORDER BY c9 sorts the result; presumably this keeps the output
-- comparable row by row across engines -- confirm against the parity test.
SELECT
c9,
row_number() OVER (ORDER BY c9) row_num,
first_value(c9) OVER (ORDER BY c9) first_c9,
first_value(c9) OVER (ORDER BY c9 DESC) first_c9_desc,
last_value(c9) OVER (ORDER BY c9) last_c9,
last_value(c9) OVER (ORDER BY c9 DESC) last_c9_desc,
nth_value(c9, 2) OVER (ORDER BY c9) second_c9,
nth_value(c9, 2) OVER (ORDER BY c9 DESC) second_c9_desc
FROM test
ORDER BY c9;
2 changes: 1 addition & 1 deletion integration-tests/sqls/simple_window_full_aggregation.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language gOVERning permissions and
-- See the License for the specific language governing permissions and
-- limitations under the License.

SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language gOVERning permissions and
-- See the License for the specific language governing permissions and
-- limitations under the License.

SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language gOVERning permissions and
-- See the License for the specific language governing permissions and
-- limitations under the License.

SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language gOVERning permissions and
-- See the License for the specific language governing permissions and
-- limitations under the License.

SELECT
Expand Down
22 changes: 22 additions & 0 deletions integration-tests/sqls/simple_window_ranked_built_in_functions.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at

-- http://www.apache.org/licenses/LICENSE-2.0

-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- Computes rank() and dense_rank() for every row of `test`, partitioned by c2
-- and ordered by c3 within each partition. No explicit window frame is given.
-- The result is sorted by c9.
SELECT
c9,
rank() OVER (PARTITION BY c2 ORDER BY c3) rank_by_c3,
dense_rank() OVER (PARTITION BY c2 ORDER BY c3) dense_rank_by_c3
FROM test
ORDER BY c9;
4 changes: 2 additions & 2 deletions integration-tests/test_psql_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ class PsqlParityTest(unittest.TestCase):
def test_parity(self):
root = Path(os.path.dirname(__file__)) / "sqls"
files = set(root.glob("*.sql"))
self.assertEqual(len(files), 9, msg="tests are missed")
self.assertEqual(len(files), 11, msg="tests are missed")
for fname in files:
with self.subTest(fname=fname):
datafusion_output = pd.read_csv(
io.BytesIO(generate_csv_from_datafusion(fname))
)
psql_output = pd.read_csv(io.BytesIO(generate_csv_from_psql(fname)))
self.assertTrue(
np.allclose(datafusion_output, psql_output),
np.allclose(datafusion_output, psql_output, equal_nan=True),
msg=f"datafusion output=\n{datafusion_output}, psql_output=\n{psql_output}",
)

Expand Down

0 comments on commit cab3a98

Please sign in to comment.