From e2d40bc13aa48609a6ce645bb3daa9aa5963c94e Mon Sep 17 00:00:00 2001
From: Jiayu Liu
Date: Sun, 13 Jun 2021 13:22:41 +0800
Subject: [PATCH] add lead and lag

---
 ballista/rust/client/src/columnar_batch.rs    |   1 -
 ballista/rust/core/src/serde/scheduler/mod.rs |   1 -
 .../src/physical_plan/expressions/lead_lag.rs | 171 ++++++++++++++++++
 .../src/physical_plan/expressions/mod.rs      |   2 +
 .../physical_plan/expressions/nth_value.rs    |   3 +-
 datafusion/src/physical_plan/windows.rs       |  36 ++--
 6 files changed, 200 insertions(+), 14 deletions(-)
 create mode 100644 datafusion/src/physical_plan/expressions/lead_lag.rs

diff --git a/ballista/rust/client/src/columnar_batch.rs b/ballista/rust/client/src/columnar_batch.rs
index a40b68ff3ebd7..fc9df34a23a14 100644
--- a/ballista/rust/client/src/columnar_batch.rs
+++ b/ballista/rust/client/src/columnar_batch.rs
@@ -114,7 +114,6 @@ impl ColumnarBatch {
 /// A columnar value can either be a scalar value or an Arrow array.
 #[allow(dead_code)]
 #[derive(Debug, Clone)]
-
 pub enum ColumnarValue {
     Scalar(ScalarValue, usize),
     Columnar(ArrayRef),
diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs b/ballista/rust/core/src/serde/scheduler/mod.rs
index c9bd1e93db2c4..8092d3484fee2 100644
--- a/ballista/rust/core/src/serde/scheduler/mod.rs
+++ b/ballista/rust/core/src/serde/scheduler/mod.rs
@@ -34,7 +34,6 @@ pub mod to_proto;
 
 /// Action that can be sent to an executor
 #[derive(Debug, Clone)]
-
 pub enum Action {
     /// Execute a query and store the results in memory
     ExecutePartition(ExecutePartition),
diff --git a/datafusion/src/physical_plan/expressions/lead_lag.rs b/datafusion/src/physical_plan/expressions/lead_lag.rs
new file mode 100644
index 0000000000000..f03963730e92c
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/lead_lag.rs
@@ -0,0 +1,171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions for `lead` and `lag` that can be evaluated
+//! at runtime during query execution
+
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
+use arrow::array::ArrayRef;
+use arrow::compute::kernels::window::shift;
+use arrow::datatypes::{DataType, Field};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Window expression that shifts its input column by a fixed number of rows.
+///
+/// Shared implementation behind [`lead`] and [`lag`]: positions vacated by the
+/// shift are filled with nulls.
+#[derive(Debug)]
+pub struct WindowShift {
+    name: String,
+    data_type: DataType,
+    shift_offset: i64,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+/// Create the `lead()` window function: every row receives the value of the
+/// row that follows it (shift offset `-1`), so the last row is null.
+pub fn lead(
+    name: String,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+) -> WindowShift {
+    WindowShift {
+        name,
+        data_type,
+        shift_offset: -1,
+        expr,
+    }
+}
+
+/// Create the `lag()` window function: every row receives the value of the
+/// row that precedes it (shift offset `1`), so the first row is null.
+pub fn lag(
+    name: String,
+    data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+) -> WindowShift {
+    WindowShift {
+        name,
+        data_type,
+        shift_offset: 1,
+        expr,
+    }
+}
+
+impl BuiltInWindowFunctionExpr for WindowShift {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// The output field; always nullable because shifted-in slots are null
+    fn field(&self) -> Result<Field> {
+        let nullable = true;
+        Ok(Field::new(&self.name, self.data_type.clone(), nullable))
+    }
+
+    /// The single input expression whose values are shifted
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    /// The name this expression was created with
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// Shift the first input array by `shift_offset`.
+    ///
+    /// Returns an execution error when `values` is empty or when the first
+    /// array does not contain exactly `num_rows` rows.
+    fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result<ArrayRef> {
+        if values.is_empty() {
+            return Err(DataFusionError::Execution(format!(
+                "No arguments supplied to {}",
+                self.name()
+            )));
+        }
+        let value = &values[0];
+        if value.len() != num_rows {
+            return Err(DataFusionError::Execution(format!(
+                "Invalid data supplied to {}, expect {} rows, got {} rows",
+                self.name(),
+                num_rows,
+                value.len()
+            )));
+        }
+        shift(value, self.shift_offset).map_err(DataFusionError::ArrowError)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::error::Result;
+    use crate::physical_plan::expressions::col;
+    use arrow::record_batch::RecordBatch;
+    use arrow::{array::*, datatypes::*};
+
+    fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> {
+        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
+        let values = vec![arr];
+        let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
+        let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
+        let result = expr.evaluate(batch.num_rows(), &values)?;
+        let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+        assert_eq!(expected, *result);
+        Ok(())
+    }
+
+    #[test]
+    fn lead_lag_window_shift() -> Result<()> {
+        test_i32_result(
+            lead("lead".to_owned(), DataType::Int32, col("c3")),
+            vec![
+                Some(-2),
+                Some(3),
+                Some(-4),
+                Some(5),
+                Some(-6),
+                Some(7),
+                Some(8),
+                None,
+            ]
+            .iter()
+            .collect::<Int32Array>(),
+        )?;
+
+        test_i32_result(
+            lag("lag".to_owned(), DataType::Int32, col("c3")),
+            vec![
+                None,
+                Some(1),
+                Some(-2),
+                Some(3),
+                Some(-4),
+                Some(5),
+                Some(-6),
+                Some(7),
+            ]
+            .iter()
+            .collect::<Int32Array>(),
+        )?;
+        Ok(())
+    }
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index f8cb40cbacbdc..a299ced7fa82f 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -36,6 +36,7 @@ mod count;
 mod in_list;
 mod is_not_null;
 mod is_null;
+mod lead_lag;
 mod literal;
 mod min_max;
 mod negative;
@@ -57,6 +58,7 @@ pub use count::Count;
 pub use in_list::{in_list, InListExpr};
 pub use is_not_null::{is_not_null, IsNotNullExpr};
 pub use is_null::{is_null, IsNullExpr};
+pub use lead_lag::{lag, lead};
 pub use literal::{lit, Literal};
 pub use min_max::{Max, Min};
 pub use negative::{negative, NegativeExpr};
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs
index 16897d45119f0..13ee669e55f75 100644
--- a/datafusion/src/physical_plan/expressions/nth_value.rs
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -15,7 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines physical expressions that can evaluated at runtime during query execution
+//! Defines physical expressions for `first_value`, `last_value`, and `nth_value`
+//! that can be evaluated at runtime during query execution
 
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 466cc51b447d0..a66c92bff6ffe 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -21,7 +21,7 @@ use crate::error::{DataFusionError, Result};
 use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
 use crate::physical_plan::{
     aggregates, common,
-    expressions::{Literal, NthValue, PhysicalSortExpr, RowNumber},
+    expressions::{lag, lead, Literal, NthValue, PhysicalSortExpr, RowNumber},
     type_coercion::coerce,
     window_functions::signature_for_built_in,
     window_functions::BuiltInWindowFunctionExpr,
@@ -98,8 +98,20 @@ fn create_built_in_window_expr(
     input_schema: &Schema,
     name: String,
 ) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
-    match fun {
-        BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))),
+    Ok(match fun {
+        BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name)),
+        BuiltInWindowFunction::Lag => {
+            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
+            let arg = coerced_args[0].clone();
+            let data_type = args[0].data_type(input_schema)?;
+            Arc::new(lag(name, data_type, arg))
+        }
+        BuiltInWindowFunction::Lead => {
+            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
+            let arg = coerced_args[0].clone();
+            let data_type = args[0].data_type(input_schema)?;
+            Arc::new(lead(name, data_type, arg))
+        }
         BuiltInWindowFunction::NthValue => {
             let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
             let arg = coerced_args[0].clone();
@@ -114,25 +126,27 @@ fn create_built_in_window_expr(
                 .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
             let n: u32 = n as u32;
             let data_type = args[0].data_type(input_schema)?;
-            Ok(Arc::new(NthValue::nth_value(name, arg, data_type, n)?))
+            Arc::new(NthValue::nth_value(name, arg, data_type, n)?)
         }
         BuiltInWindowFunction::FirstValue => {
             let arg =
                 coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
             let data_type = args[0].data_type(input_schema)?;
-            Ok(Arc::new(NthValue::first_value(name, arg, data_type)))
+            Arc::new(NthValue::first_value(name, arg, data_type))
         }
         BuiltInWindowFunction::LastValue => {
             let arg =
                 coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
             let data_type = args[0].data_type(input_schema)?;
-            Ok(Arc::new(NthValue::last_value(name, arg, data_type)))
+            Arc::new(NthValue::last_value(name, arg, data_type))
         }
-        _ => Err(DataFusionError::NotImplemented(format!(
-            "Window function with {:?} not yet implemented",
-            fun
-        ))),
-    }
+        _ => {
+            return Err(DataFusionError::NotImplemented(format!(
+                "Window function with {:?} not yet implemented",
+                fun
+            )))
+        }
+    })
 }
 
 /// A window expr that takes the form of a built in window function