Skip to content

Commit

Permalink
parallelize window function calls
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiayu Liu committed Jun 12, 2021
1 parent ad70a1e commit ddd3418
Showing 1 changed file with 44 additions and 76 deletions.
120 changes: 44 additions & 76 deletions datafusion/src/physical_plan/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,53 +329,29 @@ pin_project! {
}
}

type WindowAccumulatorItem = Box<dyn WindowAccumulator>;

/// Gathers the argument expressions of each window expression.
///
/// The returned outer vector has one entry per element of `window_expr`, in
/// the same order; each entry is the value returned by that element's
/// `expressions()` call. This function never produces an `Err` itself.
fn window_expressions(
    window_expr: &[Arc<dyn WindowExpr>],
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
    let mut per_window = Vec::with_capacity(window_expr.len());
    for window in window_expr {
        per_window.push(window.expressions());
    }
    Ok(per_window)
}

/// Feeds one record batch through every window accumulator.
///
/// `window_accumulators` and `expressions` are walked in lockstep; iteration
/// stops at the shorter of the two. For each pair, every argument expression
/// is evaluated against `batch` and converted to an array of
/// `batch.num_rows()` rows, then the arrays are handed to the accumulator's
/// `scan_batch`.
///
/// Returns the `scan_batch` results in accumulator order. The first error
/// from `evaluate` or `scan_batch` aborts the whole call and is returned.
fn window_aggregate_batch(
    batch: &RecordBatch,
    window_accumulators: &mut [WindowAccumulatorItem],
    expressions: &[Vec<Arc<dyn PhysicalExpr>>],
) -> Result<Vec<Option<ArrayRef>>> {
    let num_rows = batch.num_rows();
    let pair_count = window_accumulators.len().min(expressions.len());
    let mut scanned = Vec::with_capacity(pair_count);
    for (accumulator, args) in window_accumulators.iter_mut().zip(expressions) {
        // Materialize every argument of this window function as an array.
        let mut values = Vec::with_capacity(args.len());
        for arg in args {
            let evaluated = arg.evaluate(batch)?;
            values.push(evaluated.into_array(num_rows));
        }
        scanned.push(accumulator.scan_batch(num_rows, &values)?);
    }
    Ok(scanned)
}

/// Asks every window accumulator for its final value.
///
/// Returns one `Option<ScalarValue>` per accumulator, in the order of
/// `window_accumulators`, as produced by each accumulator's `evaluate()`.
/// The first `evaluate()` error stops the walk and is returned.
fn finalize_window_aggregation(
    window_accumulators: &[WindowAccumulatorItem],
) -> Result<Vec<Option<ScalarValue>>> {
    let mut finals = Vec::with_capacity(window_accumulators.len());
    for accumulator in window_accumulators {
        finals.push(accumulator.evaluate()?);
    }
    Ok(finals)
}

fn create_window_accumulators(
window_expr: &[Arc<dyn WindowExpr>],
) -> Result<Vec<WindowAccumulatorItem>> {
window_expr
/// compute value for one window function
///
/// Evaluates each argument expression of `window_expr` against `batch`,
/// converts every result into an array of `batch.num_rows()` rows, passes
/// those arrays to a newly created accumulator through `scan_batch`, and then
/// calls `evaluate` on that accumulator.
///
/// Exactly one of the two results is expected to carry a value:
/// * only `evaluate` returned a value: it is converted with
///   `to_array_of_size(num_rows)`;
/// * only `scan_batch` returned an array and its length equals `num_rows`:
///   that array is returned.
///
/// # Errors
/// Propagates any error from expression evaluation, accumulator creation,
/// `scan_batch` or `evaluate`. Every other combination of the two results
/// yields `DataFusionError::Execution("Invalid window function behavior")`.
fn compute_window_aggregate(
window_expr: Arc<dyn WindowExpr>,
batch: Arc<RecordBatch>,
) -> Result<ArrayRef> {
let num_rows = batch.num_rows();
let values = window_expr
.expressions()
.iter()
// NOTE(review): the next two lines look like removed lines of the old
// `create_window_accumulators` body that the diff rendering interleaved
// here; they do not fit the surrounding chain — confirm against the commit.
.map(|expr| expr.create_accumulator())
.collect::<Result<Vec<_>>>()
// Evaluate one argument expression and expand the result to `num_rows` rows;
// collecting into `Result` stops at the first failing expression.
.map(|e| e.evaluate(batch.as_ref()).map(|v| v.into_array(num_rows)))
.collect::<Result<Vec<_>>>()?;
// A fresh accumulator is created per call, so no state is shared between calls.
let mut window_accumulator = window_expr.create_accumulator()?;
let window_aggregate = window_accumulator.scan_batch(num_rows, &values)?;
let final_aggregate = window_accumulator.evaluate()?;
Ok(match (window_aggregate, final_aggregate) {
// Only a final value is available: turn it into an array of `num_rows`.
(None, Some(fa)) => fa.to_array_of_size(num_rows),
// Only a scanned array is available: accept it if it has one entry per row.
(Some(wa), None) if wa.len() == num_rows => wa.clone(),
// Both present, both absent, or a length mismatch.
_ => {
return Err(DataFusionError::Execution(
"Invalid window function behavior".to_owned(),
))
}
})
}

/// Compute the window aggregate columns
Expand All @@ -397,39 +373,29 @@ fn create_window_accumulators(
/// a. some can be grow-only window-accumulating
/// b. some can be grow-and-shrink window-accumulating
/// c. some can be based on segment tree
// NOTE(review): this span appears to be a rendered diff with the +/- markers
// stripped, so lines of the earlier synchronous implementation and of the new
// `async` implementation are interleaved. The comments below describe what
// each visible group of lines does; confirm against the actual commit before
// relying on which lines belong to which version.
fn compute_window_aggregates(
async fn compute_window_aggregates(
window_expr: Vec<Arc<dyn WindowExpr>>,
batch: &RecordBatch,
batch: Arc<RecordBatch>,
) -> Result<Vec<ArrayRef>> {
// Synchronous variant: build all accumulators and argument expressions up
// front, scan the batch once, then finalize every accumulator.
let mut window_accumulators = create_window_accumulators(&window_expr)?;
let expressions = Arc::new(window_expressions(&window_expr)?);
let num_rows = batch.num_rows();
let window_aggregates =
window_aggregate_batch(batch, &mut window_accumulators, &expressions)?;
let final_aggregates = finalize_window_aggregation(&window_accumulators)?;

// both must equal to window_expr.len()
if window_aggregates.len() != final_aggregates.len() {
return Err(DataFusionError::Internal(
"Impossibly got len mismatch".to_owned(),
));
}

window_aggregates
// Async variant: one spawned task per window expression.
let handles = window_expr
.iter()
// Synchronous variant: pair each scanned result with its final value and
// pick whichever one is present (same rule as `compute_window_aggregate`).
.zip(final_aggregates)
.map(|(wa, fa)| {
Ok(match (wa, fa) {
(None, Some(fa)) => fa.to_array_of_size(num_rows),
(Some(wa), None) if wa.len() == num_rows => wa.clone(),
_ => {
return Err(DataFusionError::Execution(
"Invalid window function behavior".to_owned(),
))
}
})
// Async variant: each task gets its own `Arc` clones of the batch and the
// window expression so they can be moved into the `async move` block.
.map(|window_expr| {
let batch = batch.clone();
let window_expr = window_expr.clone();
// NOTE(review): `compute_window_aggregate` is a plain (non-async) function,
// so its work runs inside the spawned task without yielding — consider
// whether `tokio::task::spawn_blocking` is the better fit; confirm.
tokio::spawn(async move { compute_window_aggregate(window_expr, batch) })
})
.collect()
.collect::<Vec<_>>();
// Await the handles in the order they were spawned, so `result[i]`
// corresponds to `window_expr[i]`.
let mut result = vec![];
for handle in handles {
// First `?` surfaces a failed join (mapped to an `Execution` error);
// second `?` surfaces the error returned by the task itself.
let arr = handle.await.map_err(|e| {
DataFusionError::Execution(format!(
"Failed to join window aggregation handle {}",
e
))
})??;
result.push(arr);
}
Ok(result)
}

impl WindowAggStream {
Expand Down Expand Up @@ -465,8 +431,10 @@ impl WindowAggStream {
.map_err(DataFusionError::into_arrow_external_error)?;
let batch = common::combine_batches(&batches, input_schema.clone())?;
if let Some(batch) = batch {
let batch = Arc::new(batch);
// calculate window cols
let mut columns = compute_window_aggregates(window_expr, &batch)
let mut columns = compute_window_aggregates(window_expr, batch.clone())
.await
.map_err(DataFusionError::into_arrow_external_error)?;
// combine with the original cols
// note the setup of window aggregates is that they newly calculated window
Expand Down

0 comments on commit ddd3418

Please sign in to comment.