Skip to content

Commit

Permalink
add lag and lead
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiayu Liu committed Jun 15, 2021
1 parent 1b4268c commit 7db8d17
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 17 deletions.
1 change: 0 additions & 1 deletion ballista/rust/client/src/columnar_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ impl ColumnarBatch {
/// A columnar value can either be a scalar value or an Arrow array.
#[allow(dead_code)]
#[derive(Debug, Clone)]

pub enum ColumnarValue {
Scalar(ScalarValue, usize),
Columnar(ArrayRef),
Expand Down
1 change: 0 additions & 1 deletion ballista/rust/core/src/serde/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ pub mod to_proto;

/// Action that can be sent to an executor
#[derive(Debug, Clone)]

pub enum Action {
/// Execute a query and store the results in memory
ExecutePartition(ExecutePartition),
Expand Down
159 changes: 159 additions & 0 deletions datafusion/src/physical_plan/expressions/lead_lag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines physical expression for `lead` and `lag` that can evaluated
//! at runtime during query execution
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use arrow::array::ArrayRef;
use arrow::compute::kernels::window::shift;
use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::sync::Arc;

/// window shift expression
#[derive(Debug)]
pub struct WindowShift {
name: String,
data_type: DataType,
shift_offset: i64,
expr: Arc<dyn PhysicalExpr>,
}

/// lead() window function
pub fn lead(
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
) -> WindowShift {
WindowShift {
name,
data_type,
shift_offset: -1,
expr,
}
}

/// lag() window function
pub fn lag(
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
) -> WindowShift {
WindowShift {
name,
data_type,
shift_offset: 1,
expr,
}
}

impl BuiltInWindowFunctionExpr for WindowShift {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
let nullable = true;
Ok(Field::new(&self.name, self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}

fn name(&self) -> &str {
&self.name
}

fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result<ArrayRef> {
if values.is_empty() {
return Err(DataFusionError::Execution(format!(
"No arguments supplied to {}",
self.name()
)));
}
let value = &values[0];
if value.len() != num_rows {
return Err(DataFusionError::Execution(format!(
"Invalid data supplied to {}, expect {} rows, got {} rows",
self.name(),
num_rows,
value.len()
)));
}
shift(value.as_ref(), self.shift_offset).map_err(DataFusionError::ArrowError)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result;
use crate::physical_plan::expressions::col;
use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};

fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> {
let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
let result = expr.evaluate(batch.num_rows(), &values)?;
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(expected, *result);
Ok(())
}

#[test]
fn lead_lag_window_shift() -> Result<()> {
test_i32_result(
lead("lead".to_owned(), DataType::Float32, col("c3")),
vec![
Some(-2),
Some(3),
Some(-4),
Some(5),
Some(-6),
Some(7),
Some(8),
None,
]
.iter()
.collect::<Int32Array>(),
)?;

test_i32_result(
lag("lead".to_owned(), DataType::Float32, col("c3")),
vec![
None,
Some(1),
Some(-2),
Some(3),
Some(-4),
Some(5),
Some(-6),
Some(7),
]
.iter()
.collect::<Int32Array>(),
)?;
Ok(())
}
}
2 changes: 2 additions & 0 deletions datafusion/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ mod count;
mod in_list;
mod is_not_null;
mod is_null;
mod lead_lag;
mod literal;
mod min_max;
mod negative;
Expand All @@ -55,6 +56,7 @@ pub use count::Count;
pub use in_list::{in_list, InListExpr};
pub use is_not_null::{is_not_null, IsNotNullExpr};
pub use is_null::{is_null, IsNullExpr};
pub use lead_lag::{lag, lead};
pub use literal::{lit, Literal};
pub use min_max::{Max, Min};
pub use negative::{negative, NegativeExpr};
Expand Down
3 changes: 2 additions & 1 deletion datafusion/src/physical_plan/expressions/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions that can evaluated at runtime during query execution
//! Defines physical expressions for `first_value`, `last_value`, and `nth_value`
//! that can evaluated at runtime during query execution
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
Expand Down
4 changes: 2 additions & 2 deletions datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>(
))
})?;

create_key_for_col(&dict_col.values(), values_index, vec)
create_key_for_col(dict_col.values(), values_index, vec)
}

/// Appends a sequence of [u8] bytes for the value in `col[row]` to
Expand Down Expand Up @@ -1091,7 +1091,7 @@ fn dictionary_create_group_by_value<K: ArrowDictionaryKeyType>(
))
})?;

create_group_by_value(&dict_col.values(), values_index)
create_group_by_value(dict_col.values(), values_index)
}

/// Extract the value in `col[row]` as a GroupByScalar
Expand Down
36 changes: 25 additions & 11 deletions datafusion/src/physical_plan/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::error::{DataFusionError, Result};
use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
use crate::physical_plan::{
aggregates, common,
expressions::{Literal, NthValue, PhysicalSortExpr, RowNumber},
expressions::{lag, lead, Literal, NthValue, PhysicalSortExpr, RowNumber},
type_coercion::coerce,
window_functions::signature_for_built_in,
window_functions::BuiltInWindowFunctionExpr,
Expand Down Expand Up @@ -98,8 +98,20 @@ fn create_built_in_window_expr(
input_schema: &Schema,
name: String,
) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
match fun {
BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))),
Ok(match fun {
BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name)),
BuiltInWindowFunction::Lag => {
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
let arg = coerced_args[0].clone();
let data_type = args[0].data_type(input_schema)?;
Arc::new(lag(name, data_type, arg))
}
BuiltInWindowFunction::Lead => {
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
let arg = coerced_args[0].clone();
let data_type = args[0].data_type(input_schema)?;
Arc::new(lead(name, data_type, arg))
}
BuiltInWindowFunction::NthValue => {
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
let arg = coerced_args[0].clone();
Expand All @@ -114,25 +126,27 @@ fn create_built_in_window_expr(
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
let n: u32 = n as u32;
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(NthValue::nth_value(name, arg, data_type, n)?))
Arc::new(NthValue::nth_value(name, arg, data_type, n)?)
}
BuiltInWindowFunction::FirstValue => {
let arg =
coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(NthValue::first_value(name, arg, data_type)))
Arc::new(NthValue::first_value(name, arg, data_type))
}
BuiltInWindowFunction::LastValue => {
let arg =
coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(NthValue::last_value(name, arg, data_type)))
Arc::new(NthValue::last_value(name, arg, data_type))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Window function with {:?} not yet implemented",
fun
))),
}
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Window function with {:?} not yet implemented",
fun
)))
}
})
}

/// A window expr that takes the form of a built in window function
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ impl ScalarValue {
keys_col.data_type()
))
})?;
Self::try_from_array(&dict_array.values(), values_index)
Self::try_from_array(dict_array.values(), values_index)
}
}

Expand Down

0 comments on commit 7db8d17

Please sign in to comment.