Skip to content

Commit

Permalink
Simplify and encapsulate window function state management (#6621)
Browse files Browse the repository at this point in the history
* Simplify and encapsulate window function state management

* Fix docs

* Improve documentation

* Improve comment and readability
  • Loading branch information
alamb authored Jun 13, 2023
1 parent d602de2 commit a56ae74
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 96 deletions.
46 changes: 3 additions & 43 deletions datafusion/physical-expr/src/window/built_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,12 @@ use std::sync::Arc;
use super::window_frame_state::WindowFrameContext;
use super::BuiltInWindowFunctionExpr;
use super::WindowExpr;
use crate::window::window_expr::{
BuiltinWindowState, NthValueKind, NthValueState, WindowFn,
};
use crate::window::window_expr::WindowFn;
use crate::window::{
PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState,
};
use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr};
use arrow::array::{new_empty_array, Array, ArrayRef};
use arrow::array::{new_empty_array, ArrayRef};
use arrow::compute::SortOptions;
use arrow::datatypes::Field;
use arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -211,13 +209,7 @@ impl WindowExpr for BuiltInWindowExpr {

state.update(&out_col, partition_batch_state)?;
if self.window_frame.start_bound.is_unbounded() {
let mut evaluator_state = evaluator.state()?;
if let BuiltinWindowState::NthValue(nth_value_state) =
&mut evaluator_state
{
memoize_nth_value(state, nth_value_state)?;
evaluator.set_state(&evaluator_state)?;
}
evaluator.memoize(state)?;
}
}
Ok(())
Expand All @@ -244,35 +236,3 @@ impl WindowExpr for BuiltInWindowExpr {
|| !self.window_frame.end_bound.is_unbounded())
}
}

// Memoization for FIRST_VALUE, LAST_VALUE and NTH_VALUE when the window
// frame starts at a fixed position (e.g. UNBOUNDED PRECEDING). Once such a
// result has been produced it never changes, so earlier rows no longer
// have to be retained while the rest of the input is processed; moving the
// start of the frame range forward is what allows those rows to be pruned.
fn memoize_nth_value(
    state: &mut WindowAggState,
    nth_value_state: &mut NthValueState,
) -> Result<()> {
    let n_rows = state.out_col.len();
    // First flag: the frame start may be advanced now.
    // Second flag: this kind stores its result (LAST_VALUE does not).
    let (can_prune, caches_result) = match nth_value_state.kind {
        NthValueKind::First => {
            let frame_len = state.window_frame_range.end - state.window_frame_range.start;
            (frame_len > 0 && n_rows > 0, true)
        }
        NthValueKind::Last => (true, false),
        NthValueKind::Nth(n) => {
            let frame_len = state.window_frame_range.end - state.window_frame_range.start;
            let needed = n as usize;
            (frame_len >= needed && n_rows >= needed, true)
        }
    };
    if !can_prune {
        return Ok(());
    }
    if caches_result && nth_value_state.finalized_result.is_none() {
        // Store the most recent output value as the finalized result.
        let latest = ScalarValue::try_from_array(&state.out_col, n_rows - 1)?;
        nth_value_state.finalized_result = Some(latest);
    }
    // Keep only the last row of the current frame.
    state.window_frame_range.start = state.window_frame_range.end.saturating_sub(1);
    Ok(())
}
7 changes: 1 addition & 6 deletions datafusion/physical-expr/src/window/lead_lag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! at runtime during query execution
use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::window_expr::{BuiltinWindowState, LeadLagState};
use crate::window::window_expr::LeadLagState;
use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
Expand Down Expand Up @@ -182,11 +182,6 @@ fn shift_with_default_value(
}

impl PartitionEvaluator for WindowShiftEvaluator {
/// Returns a clone of this evaluator's state, wrapped in
/// `BuiltinWindowState::LeadLag`.
fn state(&self) -> Result<BuiltinWindowState> {
// The state is cloned, so the caller receives an independent copy.
Ok(BuiltinWindowState::LeadLag(self.state.clone()))
}

fn update_state(
&mut self,
_state: &WindowAggState,
Expand Down
39 changes: 30 additions & 9 deletions datafusion/physical-expr/src/window/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! that can be evaluated at runtime during query execution
use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::window_expr::{BuiltinWindowState, NthValueKind, NthValueState};
use crate::window::window_expr::{NthValueKind, NthValueState};
use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
use crate::PhysicalExpr;
use arrow::array::{Array, ArrayRef};
Expand Down Expand Up @@ -152,11 +152,6 @@ pub(crate) struct NthValueEvaluator {
}

impl PartitionEvaluator for NthValueEvaluator {
/// Returns a clone of this evaluator's state, wrapped in
/// `BuiltinWindowState::NthValue`.
fn state(&self) -> Result<BuiltinWindowState> {
// The state is cloned, so the caller receives an independent copy.
Ok(BuiltinWindowState::NthValue(self.state.clone()))
}

fn update_state(
&mut self,
state: &WindowAggState,
Expand All @@ -169,9 +164,35 @@ impl PartitionEvaluator for NthValueEvaluator {
Ok(())
}

fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> {
if let BuiltinWindowState::NthValue(nth_value_state) = state {
self.state = nth_value_state.clone()
/// Memoizes the result of FIRST_VALUE, LAST_VALUE and NTH_VALUE when the
/// window frame has a fixed beginning (e.g. UNBOUNDED PRECEDING).
///
/// Once the result has been calculated it always stays the same, so past
/// data does not need to be kept while the rest of the dataset is
/// processed. When enough rows have been seen, the result is stored in
/// `self.state.finalized_result` (never for LAST_VALUE) and the start of
/// `state.window_frame_range` is moved to the last row of the frame,
/// which enables rows to be pruned from the table.
fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
    let n_rows = state.out_col.len();
    let (can_prune, keeps_latest) = match self.state.kind {
        NthValueKind::First => {
            let frame_len =
                state.window_frame_range.end - state.window_frame_range.start;
            (frame_len > 0 && n_rows > 0, false)
        }
        NthValueKind::Last => (true, true),
        NthValueKind::Nth(n) => {
            let frame_len =
                state.window_frame_range.end - state.window_frame_range.start;
            let needed = n as usize;
            (frame_len >= needed && n_rows >= needed, false)
        }
    };
    if !can_prune {
        return Ok(());
    }
    if !keeps_latest && self.state.finalized_result.is_none() {
        // Store the most recent output value as the finalized result.
        let latest = ScalarValue::try_from_array(&state.out_col, n_rows - 1)?;
        self.state.finalized_result = Some(latest);
    }
    // Keep only the last row of the current frame.
    state.window_frame_range.start =
        state.window_frame_range.end.saturating_sub(1);
    Ok(())
}
Expand Down
27 changes: 10 additions & 17 deletions datafusion/physical-expr/src/window/partition_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

//! Partition evaluation module
use crate::window::window_expr::BuiltinWindowState;
use crate::window::WindowAggState;
use arrow::array::ArrayRef;
use datafusion_common::Result;
Expand Down Expand Up @@ -78,8 +77,7 @@ use std::ops::Range;
///
/// In this case, [`Self::evaluate_stateful`] is called to calculate
/// the results of the window function incrementally for each new
/// batch, saving and restoring any state needed to do so as
/// [`BuiltinWindowState`].
/// batch.
///
/// For example, when computing `ROW_NUMBER` incrementally,
/// [`Self::evaluate_stateful`] will be called multiple times with
Expand All @@ -91,14 +89,6 @@ use std::ops::Range;
/// [`BuiltInWindowFunctionExpr`]: crate::window::BuiltInWindowFunctionExpr
/// [`BuiltInWindowFunctionExpr::create_evaluator`]: crate::window::BuiltInWindowFunctionExpr::create_evaluator
pub trait PartitionEvaluator: Debug + Send {
/// Returns the internal state of the window function as a
/// [`BuiltinWindowState`].
///
/// Only used for stateful evaluation.
///
/// The default implementation returns `BuiltinWindowState::Default`.
fn state(&self) -> Result<BuiltinWindowState> {
// If we do not use state we just return Default
Ok(BuiltinWindowState::Default)
}

/// Updates the internal state for window function
///
/// Only used for stateful evaluation
Expand All @@ -118,13 +108,16 @@ pub trait PartitionEvaluator: Debug + Send {
Ok(())
}

/// Sets the internal state for window function
/// When the window frame has a fixed beginning (e.g UNBOUNDED
/// PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and
/// NTH_VALUE do not need the (unbounded) input once they have
/// seen a certain amount of input.
///
/// Only used for stateful evaluation
fn set_state(&mut self, _state: &BuiltinWindowState) -> Result<()> {
Err(DataFusionError::NotImplemented(
"set_state is not implemented for this window function".to_string(),
))
/// `memoize` is called after each input batch is processed, and
/// such functions can save whatever they need and modify
/// [`WindowAggState`] appropriately to allow rows to be pruned
fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> {
// Default is a no-op: `_state` is left untouched by this implementation.
Ok(())
}

/// Gets the range where the window function result is calculated.
Expand Down
6 changes: 1 addition & 5 deletions datafusion/physical-expr/src/window/rank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! at runtime during query execution
use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::window_expr::{BuiltinWindowState, RankState};
use crate::window::window_expr::RankState;
use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
Expand Down Expand Up @@ -129,10 +129,6 @@ impl PartitionEvaluator for RankEvaluator {
Ok(Range { start, end })
}

/// Returns a clone of this evaluator's state, wrapped in
/// `BuiltinWindowState::Rank`.
fn state(&self) -> Result<BuiltinWindowState> {
Ok(BuiltinWindowState::Rank(self.state.clone()))
}

fn update_state(
&mut self,
state: &WindowAggState,
Expand Down
7 changes: 1 addition & 6 deletions datafusion/physical-expr/src/window/row_number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! Defines physical expression for `row_number` that can be evaluated at runtime during query execution
use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::window_expr::{BuiltinWindowState, NumRowsState};
use crate::window::window_expr::NumRowsState;
use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
use arrow::array::{ArrayRef, UInt64Array};
Expand Down Expand Up @@ -76,11 +76,6 @@ pub(crate) struct NumRowsEvaluator {
}

impl PartitionEvaluator for NumRowsEvaluator {
/// Returns a clone of this evaluator's state, wrapped in
/// `BuiltinWindowState::NumRows`.
fn state(&self) -> Result<BuiltinWindowState> {
// The state is cloned, so the caller receives an independent copy.
Ok(BuiltinWindowState::NumRows(self.state.clone()))
}

fn get_range(&self, idx: usize, _n_rows: usize) -> Result<Range<usize>> {
let start = idx;
let end = idx + 1;
Expand Down
11 changes: 1 addition & 10 deletions datafusion/physical-expr/src/window/window_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,16 +327,7 @@ pub struct LeadLagState {
pub idx: usize,
}

/// Internal state of a built-in window function evaluator, with one variant
/// per kind of state plus a stateless [`BuiltinWindowState::Default`].
#[derive(Debug, Clone, Default)]
pub enum BuiltinWindowState {
/// State carried as a [`RankState`].
Rank(RankState),
/// State carried as a [`NumRowsState`].
NumRows(NumRowsState),
/// State carried as a [`NthValueState`].
NthValue(NthValueState),
/// State carried as a [`LeadLagState`].
LeadLag(LeadLagState),
/// No state; this is the variant produced by `Default::default()`.
#[default]
Default,
}

/// Holds the state of evaluating a window function
#[derive(Debug)]
pub struct WindowAggState {
/// The range that we calculate the window function
Expand Down

0 comments on commit a56ae74

Please sign in to comment.