Skip to content

Commit 09aea09

Browse files
authored
Refactor Builtin Window Function Implementation (#4441)
1 parent a0485e7 commit 09aea09

File tree

10 files changed

+122
-164
lines changed

10 files changed

+122
-164
lines changed

datafusion/physical-expr/src/window/aggregate.rs

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,9 @@ impl WindowExpr for AggregateWindowExpr {
8787
let partition_columns = self.partition_columns(batch)?;
8888
let partition_points =
8989
self.evaluate_partition_points(batch.num_rows(), &partition_columns)?;
90-
let values = self.evaluate_args(batch)?;
91-
9290
let sort_options: Vec<SortOptions> =
9391
self.order_by.iter().map(|o| o.options).collect();
94-
let columns = self.sort_columns(batch)?;
95-
let order_columns: Vec<&ArrayRef> = columns.iter().map(|s| &s.values).collect();
96-
// Sort values, this will make the same partitions consecutive. Also, within the partition
97-
// range, values will be sorted.
98-
let order_bys = &order_columns[self.partition_by.len()..];
92+
let (_, order_bys) = self.get_values_orderbys(batch)?;
9993
let window_frame = if !order_bys.is_empty() && self.window_frame.is_none() {
10094
// OVER (ORDER BY a) case
10195
// We create an implicit window for ORDER BY.
@@ -107,14 +101,8 @@ impl WindowExpr for AggregateWindowExpr {
107101
for partition_range in &partition_points {
108102
let mut accumulator = self.aggregate.create_accumulator()?;
109103
let length = partition_range.end - partition_range.start;
110-
let slice_order_bys = order_bys
111-
.iter()
112-
.map(|v| v.slice(partition_range.start, length))
113-
.collect::<Vec<_>>();
114-
let value_slice = values
115-
.iter()
116-
.map(|v| v.slice(partition_range.start, length))
117-
.collect::<Vec<_>>();
104+
let (values, order_bys) =
105+
self.get_values_orderbys(&batch.slice(partition_range.start, length))?;
118106

119107
let mut window_frame_ctx = WindowFrameContext::new(&window_frame);
120108
let mut last_range: (usize, usize) = (0, 0);
@@ -123,7 +111,7 @@ impl WindowExpr for AggregateWindowExpr {
123111
// First, cur_range is calculated, then it is compared with last_range.
124112
for i in 0..length {
125113
let cur_range = window_frame_ctx.calculate_range(
126-
&slice_order_bys,
114+
&order_bys,
127115
&sort_options,
128116
length,
129117
i,
@@ -135,7 +123,7 @@ impl WindowExpr for AggregateWindowExpr {
135123
// Accumulate any new rows that have entered the window:
136124
let update_bound = cur_range.1 - last_range.1;
137125
if update_bound > 0 {
138-
let update: Vec<ArrayRef> = value_slice
126+
let update: Vec<ArrayRef> = values
139127
.iter()
140128
.map(|v| v.slice(last_range.1, update_bound))
141129
.collect();
@@ -144,7 +132,7 @@ impl WindowExpr for AggregateWindowExpr {
144132
// Remove rows that have now left the window:
145133
let retract_bound = cur_range.0 - last_range.0;
146134
if retract_bound > 0 {
147-
let retract: Vec<ArrayRef> = value_slice
135+
let retract: Vec<ArrayRef> = values
148136
.iter()
149137
.map(|v| v.slice(last_range.0, retract_bound))
150138
.collect();

datafusion/physical-expr/src/window/built_in.rs

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ use super::window_frame_state::WindowFrameContext;
2121
use super::BuiltInWindowFunctionExpr;
2222
use super::WindowExpr;
2323
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
24-
use arrow::array::Array;
2524
use arrow::compute::{concat, SortOptions};
2625
use arrow::record_batch::RecordBatch;
2726
use arrow::{array::ArrayRef, datatypes::Field};
@@ -85,7 +84,7 @@ impl WindowExpr for BuiltInWindowExpr {
8584
}
8685

8786
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
88-
let evaluator = self.expr.create_evaluator(batch)?;
87+
let evaluator = self.expr.create_evaluator()?;
8988
let num_rows = batch.num_rows();
9089
let partition_columns = self.partition_columns(batch)?;
9190
let partition_points =
@@ -94,12 +93,7 @@ impl WindowExpr for BuiltInWindowExpr {
9493
let results = if evaluator.uses_window_frame() {
9594
let sort_options: Vec<SortOptions> =
9695
self.order_by.iter().map(|o| o.options).collect();
97-
let columns = self.sort_columns(batch)?;
98-
let order_columns: Vec<&ArrayRef> =
99-
columns.iter().map(|s| &s.values).collect();
100-
// Sort values, this will make the same partitions consecutive. Also, within the partition
101-
// range, values will be sorted.
102-
let order_bys = &order_columns[self.partition_by.len()..];
96+
let (_, order_bys) = self.get_values_orderbys(batch)?;
10397
let window_frame = if !order_bys.is_empty() && self.window_frame.is_none() {
10498
// OVER (ORDER BY a) case
10599
// We create an implicit window for ORDER BY.
@@ -110,24 +104,22 @@ impl WindowExpr for BuiltInWindowExpr {
110104
let mut row_wise_results = vec![];
111105
for partition_range in &partition_points {
112106
let length = partition_range.end - partition_range.start;
113-
let slice_order_bys = order_bys
114-
.iter()
115-
.map(|v| v.slice(partition_range.start, length))
116-
.collect::<Vec<_>>();
107+
let (values, order_bys) = self
108+
.get_values_orderbys(&batch.slice(partition_range.start, length))?;
117109
let mut window_frame_ctx = WindowFrameContext::new(&window_frame);
118110
// We iterate on each row to calculate window frame range and and window function result
119111
for idx in 0..length {
120112
let range = window_frame_ctx.calculate_range(
121-
&slice_order_bys,
113+
&order_bys,
122114
&sort_options,
123115
num_rows,
124116
idx,
125117
)?;
126118
let range = Range {
127-
start: partition_range.start + range.0,
128-
end: partition_range.start + range.1,
119+
start: range.0,
120+
end: range.1,
129121
};
130-
let value = evaluator.evaluate_inside_range(range)?;
122+
let value = evaluator.evaluate_inside_range(&values, range)?;
131123
row_wise_results.push(value.to_array());
132124
}
133125
}
@@ -138,7 +130,8 @@ impl WindowExpr for BuiltInWindowExpr {
138130
self.evaluate_partition_points(num_rows, &columns)?;
139131
evaluator.evaluate_with_rank(partition_points, sort_partition_points)?
140132
} else {
141-
evaluator.evaluate(partition_points)?
133+
let (values, _) = self.get_values_orderbys(batch)?;
134+
evaluator.evaluate(&values, partition_points)?
142135
};
143136
let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
144137
concat(&results).map_err(DataFusionError::ArrowError)

datafusion/physical-expr/src/window/built_in_window_function_expr.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
use super::partition_evaluator::PartitionEvaluator;
1919
use crate::PhysicalExpr;
20+
use arrow::array::ArrayRef;
2021
use arrow::datatypes::Field;
2122
use arrow::record_batch::RecordBatch;
2223
use datafusion_common::Result;
@@ -45,9 +46,16 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug {
4546
"BuiltInWindowFunctionExpr: default name"
4647
}
4748

49+
/// Evaluate window function arguments against the batch and return
50+
/// an array ref. Typically, the resulting vector is a single element vector.
51+
fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
52+
self.expressions()
53+
.iter()
54+
.map(|e| e.evaluate(batch))
55+
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
56+
.collect()
57+
}
58+
4859
/// Create built-in window evaluator with a batch
49-
fn create_evaluator(
50-
&self,
51-
batch: &RecordBatch,
52-
) -> Result<Box<dyn PartitionEvaluator>>;
60+
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>>;
5361
}

datafusion/physical-expr/src/window/cume_dist.rs

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ use crate::PhysicalExpr;
2424
use arrow::array::ArrayRef;
2525
use arrow::array::Float64Array;
2626
use arrow::datatypes::{DataType, Field};
27-
use arrow::record_batch::RecordBatch;
2827
use datafusion_common::Result;
2928
use std::any::Any;
3029
use std::iter;
@@ -62,10 +61,7 @@ impl BuiltInWindowFunctionExpr for CumeDist {
6261
&self.name
6362
}
6463

65-
fn create_evaluator(
66-
&self,
67-
_batch: &RecordBatch,
68-
) -> Result<Box<dyn PartitionEvaluator>> {
64+
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
6965
Ok(Box::new(CumeDistEvaluator {}))
7066
}
7167
}
@@ -77,12 +73,6 @@ impl PartitionEvaluator for CumeDistEvaluator {
7773
true
7874
}
7975

80-
fn evaluate_partition(&self, _partition: Range<usize>) -> Result<ArrayRef> {
81-
unreachable!(
82-
"cume_dist evaluation must be called with evaluate_partition_with_rank"
83-
)
84-
}
85-
8676
fn evaluate_partition_with_rank(
8777
&self,
8878
partition: Range<usize>,
@@ -108,22 +98,16 @@ impl PartitionEvaluator for CumeDistEvaluator {
10898
#[cfg(test)]
10999
mod tests {
110100
use super::*;
111-
use arrow::{array::*, datatypes::*};
112101
use datafusion_common::cast::as_float64_array;
113102

114103
fn test_i32_result(
115104
expr: &CumeDist,
116-
data: Vec<i32>,
117105
partition: Range<usize>,
118106
ranks: Vec<Range<usize>>,
119107
expected: Vec<f64>,
120108
) -> Result<()> {
121-
let arr: ArrayRef = Arc::new(Int32Array::from(data));
122-
let values = vec![arr];
123-
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
124-
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
125109
let result = expr
126-
.create_evaluator(&batch)?
110+
.create_evaluator()?
127111
.evaluate_with_rank(vec![partition], ranks)?;
128112
assert_eq!(1, result.len());
129113
let result = as_float64_array(&result[0])?;
@@ -137,25 +121,19 @@ mod tests {
137121
let r = cume_dist("arr".into());
138122

139123
let expected = vec![0.0; 0];
140-
test_i32_result(&r, vec![], 0..0, vec![], expected)?;
124+
test_i32_result(&r, 0..0, vec![], expected)?;
141125

142126
let expected = vec![1.0; 1];
143-
test_i32_result(&r, vec![20; 1], 0..1, vec![0..1], expected)?;
127+
test_i32_result(&r, 0..1, vec![0..1], expected)?;
144128

145129
let expected = vec![1.0; 2];
146-
test_i32_result(&r, vec![20; 2], 0..2, vec![0..2], expected)?;
130+
test_i32_result(&r, 0..2, vec![0..2], expected)?;
147131

148132
let expected = vec![0.5, 0.5, 1.0, 1.0];
149-
test_i32_result(&r, vec![1, 1, 2, 2], 0..4, vec![0..2, 2..4], expected)?;
133+
test_i32_result(&r, 0..4, vec![0..2, 2..4], expected)?;
150134

151135
let expected = vec![0.25, 0.5, 0.75, 1.0];
152-
test_i32_result(
153-
&r,
154-
vec![1, 2, 4, 5],
155-
0..4,
156-
vec![0..1, 1..2, 2..3, 3..4],
157-
expected,
158-
)?;
136+
test_i32_result(&r, 0..4, vec![0..1, 1..2, 2..3, 3..4], expected)?;
159137

160138
Ok(())
161139
}

datafusion/physical-expr/src/window/lead_lag.rs

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ use crate::PhysicalExpr;
2424
use arrow::array::ArrayRef;
2525
use arrow::compute::cast;
2626
use arrow::datatypes::{DataType, Field};
27-
use arrow::record_batch::RecordBatch;
2827
use datafusion_common::ScalarValue;
2928
use datafusion_common::{DataFusionError, Result};
3029
use std::any::Any;
@@ -95,27 +94,16 @@ impl BuiltInWindowFunctionExpr for WindowShift {
9594
&self.name
9695
}
9796

98-
fn create_evaluator(
99-
&self,
100-
batch: &RecordBatch,
101-
) -> Result<Box<dyn PartitionEvaluator>> {
102-
let values = self
103-
.expressions()
104-
.iter()
105-
.map(|e| e.evaluate(batch))
106-
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
107-
.collect::<Result<Vec<_>>>()?;
97+
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
10898
Ok(Box::new(WindowShiftEvaluator {
10999
shift_offset: self.shift_offset,
110-
values,
111100
default_value: self.default_value.clone(),
112101
}))
113102
}
114103
}
115104

116105
pub(crate) struct WindowShiftEvaluator {
117106
shift_offset: i64,
118-
values: Vec<ArrayRef>,
119107
default_value: Option<ScalarValue>,
120108
}
121109

@@ -169,8 +157,13 @@ fn shift_with_default_value(
169157
}
170158

171159
impl PartitionEvaluator for WindowShiftEvaluator {
172-
fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
173-
let value = &self.values[0];
160+
fn evaluate_partition(
161+
&self,
162+
values: &[ArrayRef],
163+
partition: Range<usize>,
164+
) -> Result<ArrayRef> {
165+
// LEAD, LAG window functions take single column, values will have size 1
166+
let value = &values[0];
174167
let value = value.slice(partition.start, partition.end - partition.start);
175168
shift_with_default_value(&value, self.shift_offset, self.default_value.as_ref())
176169
}
@@ -190,7 +183,8 @@ mod tests {
190183
let values = vec![arr];
191184
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
192185
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
193-
let result = expr.create_evaluator(&batch)?.evaluate(vec![0..8])?;
186+
let values = expr.evaluate_args(&batch)?;
187+
let result = expr.create_evaluator()?.evaluate(&values, vec![0..8])?;
194188
assert_eq!(1, result.len());
195189
let result = as_int32_array(&result[0])?;
196190
assert_eq!(expected, *result);

datafusion/physical-expr/src/window/nth_value.rs

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ use crate::window::BuiltInWindowFunctionExpr;
2323
use crate::PhysicalExpr;
2424
use arrow::array::{Array, ArrayRef};
2525
use arrow::datatypes::{DataType, Field};
26-
use arrow::record_batch::RecordBatch;
2726
use datafusion_common::ScalarValue;
2827
use datafusion_common::{DataFusionError, Result};
2928
use std::any::Any;
@@ -116,40 +115,28 @@ impl BuiltInWindowFunctionExpr for NthValue {
116115
&self.name
117116
}
118117

119-
fn create_evaluator(
120-
&self,
121-
batch: &RecordBatch,
122-
) -> Result<Box<dyn PartitionEvaluator>> {
123-
let values = self
124-
.expressions()
125-
.iter()
126-
.map(|e| e.evaluate(batch))
127-
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
128-
.collect::<Result<Vec<_>>>()?;
129-
Ok(Box::new(NthValueEvaluator {
130-
kind: self.kind,
131-
values,
132-
}))
118+
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
119+
Ok(Box::new(NthValueEvaluator { kind: self.kind }))
133120
}
134121
}
135122

136123
/// Value evaluator for nth_value functions
137124
pub(crate) struct NthValueEvaluator {
138125
kind: NthValueKind,
139-
values: Vec<ArrayRef>,
140126
}
141127

142128
impl PartitionEvaluator for NthValueEvaluator {
143129
fn uses_window_frame(&self) -> bool {
144130
true
145131
}
146132

147-
fn evaluate_partition(&self, _partition: Range<usize>) -> Result<ArrayRef> {
148-
unreachable!("first, last, and nth_value evaluation must be called with evaluate_partition_with_rank")
149-
}
150-
151-
fn evaluate_inside_range(&self, range: Range<usize>) -> Result<ScalarValue> {
152-
let arr = &self.values[0];
133+
fn evaluate_inside_range(
134+
&self,
135+
values: &[ArrayRef],
136+
range: Range<usize>,
137+
) -> Result<ScalarValue> {
138+
// FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take single column, values will have size 1
139+
let arr = &values[0];
153140
let n_range = range.end - range.start;
154141
match self.kind {
155142
NthValueKind::First => ScalarValue::try_from_array(arr, range.start),
@@ -188,10 +175,11 @@ mod tests {
188175
end: i + 1,
189176
})
190177
}
191-
let evaluator = expr.create_evaluator(&batch)?;
178+
let evaluator = expr.create_evaluator()?;
179+
let values = expr.evaluate_args(&batch)?;
192180
let result = ranges
193181
.into_iter()
194-
.map(|range| evaluator.evaluate_inside_range(range))
182+
.map(|range| evaluator.evaluate_inside_range(&values, range))
195183
.into_iter()
196184
.collect::<Result<Vec<ScalarValue>>>()?;
197185
let result = ScalarValue::iter_to_array(result.into_iter())?;

0 commit comments

Comments
 (0)