From 29fdc24ad61d3da78d299f339cbc49b9819ba66d Mon Sep 17 00:00:00 2001
From: Jiayu Liu
Date: Sun, 13 Jun 2021 13:22:41 +0800
Subject: [PATCH 1/3] add lead and lag

---
 ballista/rust/core/src/serde/scheduler/mod.rs |   1 -
 .../src/physical_plan/expressions/lead_lag.rs | 180 ++++++++++++++++++
 .../src/physical_plan/expressions/mod.rs      |   2 +
 .../physical_plan/expressions/nth_value.rs    |   3 +-
 datafusion/src/physical_plan/windows.rs       |  42 ++--
 5 files changed, 213 insertions(+), 15 deletions(-)
 create mode 100644 datafusion/src/physical_plan/expressions/lead_lag.rs

diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs b/ballista/rust/core/src/serde/scheduler/mod.rs
index 75e3ac496ff5..f66bb08189d2 100644
--- a/ballista/rust/core/src/serde/scheduler/mod.rs
+++ b/ballista/rust/core/src/serde/scheduler/mod.rs
@@ -34,7 +34,6 @@ pub mod to_proto;
 
 /// Action that can be sent to an executor
 #[derive(Debug, Clone)]
-
 pub enum Action {
     /// Execute a query and store the results in memory
     ExecutePartition(ExecutePartition),
diff --git a/datafusion/src/physical_plan/expressions/lead_lag.rs b/datafusion/src/physical_plan/expressions/lead_lag.rs
new file mode 100644
index 000000000000..3738c7d56003
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/lead_lag.rs
@@ -0,0 +1,180 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expression for `lead` and `lag` that can be evaluated
+//! at runtime during query execution
+
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::window_functions::PartitionEvaluator;
+use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
+use arrow::array::ArrayRef;
+use arrow::compute::kernels::window::shift;
+use arrow::datatypes::{DataType, Field};
+use arrow::record_batch::RecordBatch;
+use std::any::Any;
+use std::ops::Range;
+use std::sync::Arc;
+
+/// window shift expression
+#[derive(Debug)]
+pub struct WindowShift {
+    name: String,
+    data_type: DataType,
+    shift_offset: i64,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+/// lead() window function
+pub fn lead(
+    name: String,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+) -> WindowShift {
+    WindowShift {
+        name,
+        data_type,
+        shift_offset: -1,
+        expr,
+    }
+}
+
+/// lag() window function
+pub fn lag(
+    name: String,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+) -> WindowShift {
+    WindowShift {
+        name,
+        data_type,
+        shift_offset: 1,
+        expr,
+    }
+}
+
+impl BuiltInWindowFunctionExpr for WindowShift {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn create_evaluator(
+        &self,
+        batch: &RecordBatch,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
+        let values = self
+            .expressions()
+            .iter()
+            .map(|e| e.evaluate(batch))
+            .map(|r| r.map(|v| v.into_array(batch.num_rows())))
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Box::new(WindowShiftEvaluator {
+            shift_offset: self.shift_offset,
+            values,
+        }))
+    }
+}
+
+pub(crate) struct WindowShiftEvaluator {
+    shift_offset: i64,
+    values: Vec<ArrayRef>,
+}
+
+impl PartitionEvaluator for WindowShiftEvaluator {
+    fn evaluate_partition(&self, _partition: Range<usize>) -> Result<ArrayRef> {
+        let value = &self.values[0];
+        shift(value.as_ref(), self.shift_offset).map_err(DataFusionError::ArrowError)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::error::Result;
+    use crate::physical_plan::expressions::Column;
+    use arrow::record_batch::RecordBatch;
+    use arrow::{array::*, datatypes::*};
+
+    fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> {
+        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
+        let values = vec![arr];
+        let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
+        let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
+        let result = expr.create_evaluator(&batch)?.evaluate(vec![0..8])?;
+        assert_eq!(1, result.len());
+        let result = result[0].as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(expected, *result);
+        Ok(())
+    }
+
+    #[test]
+    fn lead_lag_window_shift() -> Result<()> {
+        test_i32_result(
+            lead(
+                "lead".to_owned(),
+                DataType::Float32,
+                Arc::new(Column::new("c3", 0)),
+            ),
+            vec![
+                Some(-2),
+                Some(3),
+                Some(-4),
+                Some(5),
+                Some(-6),
+                Some(7),
+                Some(8),
+                None,
+            ]
+            .iter()
+            .collect::<Int32Array>(),
+        )?;
+
+        test_i32_result(
+            lag(
+                "lead".to_owned(),
+                DataType::Float32,
+                Arc::new(Column::new("c3", 0)),
+            ),
+            vec![
+                None,
+                Some(1),
+                Some(-2),
+                Some(3),
+                Some(-4),
+                Some(5),
+                Some(-6),
+                Some(7),
+            ]
+            .iter()
+            .collect::<Int32Array>(),
+        )?;
+        Ok(())
+    }
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index 440cb5b4ec67..bd3dab65b05d 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -36,6 +36,7 @@ mod count;
 mod in_list;
 mod is_not_null;
 mod is_null;
+mod lead_lag;
 mod literal;
 mod min_max;
 mod negative;
@@ -58,6 +59,7 @@ pub use count::Count;
 pub use in_list::{in_list, InListExpr};
 pub use is_not_null::{is_not_null, IsNotNullExpr};
 pub use is_null::{is_null, IsNullExpr};
+pub use lead_lag::{lag, lead};
 pub use literal::{lit, Literal};
 pub use min_max::{Max, Min};
 pub use negative::{negative, NegativeExpr};
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs
index 854078e232f0..7542a251f50d 100644
--- a/datafusion/src/physical_plan/expressions/nth_value.rs
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -15,7 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines physical expressions that can evaluated at runtime during query execution
+//! Defines physical expressions for `first_value`, `last_value`, and `nth_value`
+//! that can be evaluated at runtime during query execution
 
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::window_functions::PartitionEvaluator;
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index cd603fd5134e..1b783782e164 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -21,7 +21,9 @@ use crate::error::{DataFusionError, Result};
 use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
 use crate::physical_plan::{
     aggregates, common,
-    expressions::{dense_rank, rank, Literal, NthValue, PhysicalSortExpr, RowNumber},
+    expressions::{
+        dense_rank, lag, lead, rank, Literal, NthValue, PhysicalSortExpr, RowNumber,
+    },
     type_coercion::coerce,
     window_functions::{
         signature_for_built_in, BuiltInWindowFunction, BuiltInWindowFunctionExpr,
@@ -100,10 +102,22 @@ fn create_built_in_window_expr(
     input_schema: &Schema,
     name: String,
 ) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
-    match fun {
-        BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))),
-        BuiltInWindowFunction::Rank => Ok(Arc::new(rank(name))),
-        BuiltInWindowFunction::DenseRank => Ok(Arc::new(dense_rank(name))),
+    Ok(match fun {
+        BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name)),
+        BuiltInWindowFunction::Rank => Arc::new(rank(name)),
+        BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name)),
+        BuiltInWindowFunction::Lag => {
+            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
+            let arg = coerced_args[0].clone();
+            let data_type = args[0].data_type(input_schema)?;
+            Arc::new(lag(name, data_type, arg))
+        }
+        BuiltInWindowFunction::Lead => {
+            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
+            let arg = coerced_args[0].clone();
+            let data_type = args[0].data_type(input_schema)?;
+            Arc::new(lead(name, data_type, arg))
+        }
         BuiltInWindowFunction::NthValue => {
             let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
             let arg = coerced_args[0].clone();
@@ -118,25 +132,27 @@
                 .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
             let n: u32 = n as u32;
             let data_type = args[0].data_type(input_schema)?;
-            Ok(Arc::new(NthValue::nth_value(name, arg, data_type, n)?))
+            Arc::new(NthValue::nth_value(name, arg, data_type, n)?)
         }
         BuiltInWindowFunction::FirstValue => {
             let arg =
                 coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
             let data_type = args[0].data_type(input_schema)?;
-            Ok(Arc::new(NthValue::first_value(name, arg, data_type)))
+            Arc::new(NthValue::first_value(name, arg, data_type))
         }
         BuiltInWindowFunction::LastValue => {
             let arg =
                 coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
             let data_type = args[0].data_type(input_schema)?;
-            Ok(Arc::new(NthValue::last_value(name, arg, data_type)))
+            Arc::new(NthValue::last_value(name, arg, data_type))
         }
-        _ => Err(DataFusionError::NotImplemented(format!(
-            "Window function with {:?} not yet implemented",
-            fun
-        ))),
-    }
+        _ => {
+            return Err(DataFusionError::NotImplemented(format!(
+                "Window function with {:?} not yet implemented",
+                fun
+            )))
+        }
+    })
 }
 
 /// A window expr that takes the form of a built in window function

From 4fce4064063a35441fa5b8a8266a643f1087be95 Mon Sep 17 00:00:00 2001
From: Jiayu Liu
Date: Thu, 1 Jul 2021 08:21:55 +0800
Subject: [PATCH 2/3] add integration tests

---
 integration-tests/sqls/simple_window_built_in_functions.sql | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/integration-tests/sqls/simple_window_built_in_functions.sql b/integration-tests/sqls/simple_window_built_in_functions.sql
index e76b38306002..05c34dd12fca 100644
--- a/integration-tests/sqls/simple_window_built_in_functions.sql
+++ b/integration-tests/sqls/simple_window_built_in_functions.sql
@@ -17,6 +17,8 @@
 SELECT
   c9,
   row_number() OVER (ORDER BY c9) row_num,
+  lead(c9) OVER (ORDER BY c9) lead_c9,
+  lag(c9) OVER (ORDER BY c9) lag_c9,
   first_value(c9) OVER (ORDER BY c9) first_c9,
   first_value(c9) OVER (ORDER BY c9 DESC) first_c9_desc,
   last_value(c9) OVER (ORDER BY c9) last_c9,

From 20114add06a0d769d3ce03b69b44f85f2bf630f1 Mon Sep 17 00:00:00 2001
From: Jiayu Liu
Date: Thu, 1 Jul 2021 08:31:17 +0800
Subject: [PATCH 3/3] add partitioned window functions

---
 .../src/physical_plan/expressions/lead_lag.rs |  3 +-
 .../partitioned_window_built_in_functions.sql | 29 +++++++++++++++++++
 integration-tests/test_psql_parity.py         |  2 +-
 3 files changed, 32 insertions(+), 2 deletions(-)
 create mode 100644 integration-tests/sqls/partitioned_window_built_in_functions.sql

diff --git a/datafusion/src/physical_plan/expressions/lead_lag.rs b/datafusion/src/physical_plan/expressions/lead_lag.rs
index 3738c7d56003..352d97c1e116 100644
--- a/datafusion/src/physical_plan/expressions/lead_lag.rs
+++ b/datafusion/src/physical_plan/expressions/lead_lag.rs
@@ -108,8 +108,9 @@ pub(crate) struct WindowShiftEvaluator {
 }
 
 impl PartitionEvaluator for WindowShiftEvaluator {
-    fn evaluate_partition(&self, _partition: Range<usize>) -> Result<ArrayRef> {
+    fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
         let value = &self.values[0];
+        let value = value.slice(partition.start, partition.end - partition.start);
         shift(value.as_ref(), self.shift_offset).map_err(DataFusionError::ArrowError)
     }
 }
diff --git a/integration-tests/sqls/partitioned_window_built_in_functions.sql b/integration-tests/sqls/partitioned_window_built_in_functions.sql
new file mode 100644
index 000000000000..f27b085f5033
--- /dev/null
+++ b/integration-tests/sqls/partitioned_window_built_in_functions.sql
@@ -0,0 +1,29 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+
+--   http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+  c9,
+  row_number() OVER (PARTITION BY c2 ORDER BY c9) row_num,
+  lead(c9) OVER (PARTITION BY c2 ORDER BY c9) lead_c9,
+  lag(c9) OVER (PARTITION BY c2 ORDER BY c9) lag_c9,
+  first_value(c9) OVER (PARTITION BY c2 ORDER BY c9) first_c9,
+  first_value(c9) OVER (PARTITION BY c2 ORDER BY c9 DESC) first_c9_desc,
+  last_value(c9) OVER (PARTITION BY c2 ORDER BY c9) last_c9,
+  last_value(c9) OVER (PARTITION BY c2 ORDER BY c9 DESC) last_c9_desc,
+  nth_value(c9, 2) OVER (PARTITION BY c2 ORDER BY c9) second_c9,
+  nth_value(c9, 2) OVER (PARTITION BY c2 ORDER BY c9 DESC) second_c9_desc
+FROM test
+ORDER BY c9;
diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py
index 2bb8da9fd5c5..766f403f3e54 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
     def test_parity(self):
         root = Path(os.path.dirname(__file__)) / "sqls"
         files = set(root.glob("*.sql"))
-        self.assertEqual(len(files), 11, msg="tests are missed")
+        self.assertEqual(len(files), 12, msg="tests are missed")
         for fname in files:
             with self.subTest(fname=fname):
                 datafusion_output = pd.read_csv(