Skip to content

Commit

Permalink
implement window functions with partition by
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiayu Liu committed Jun 24, 2021
1 parent c82c29c commit 20c19ad
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 71 deletions.
11 changes: 9 additions & 2 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1452,11 +1452,18 @@ impl fmt::Debug for Expr {
}
Expr::WindowFunction {
fun,
ref args,
args,
partition_by,
order_by,
window_frame,
..
} => {
fmt_function(f, &fun.to_string(), false, args)?;
if !partition_by.is_empty() {
write!(f, " PARTITION BY {:?}", partition_by)?;
}
if !order_by.is_empty() {
write!(f, " ORDER BY {:?}", order_by)?;
}
if let Some(window_frame) = window_frame {
write!(
f,
Expand Down
52 changes: 49 additions & 3 deletions datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use crate::physical_plan::{
};
use crate::prelude::JoinType;
use crate::scalar::ScalarValue;
use crate::sql::utils::generate_sort_key;
use crate::variable::VarType;
use crate::{
error::{DataFusionError, Result},
Expand Down Expand Up @@ -263,11 +264,56 @@ impl DefaultPhysicalPlanner {
"Impossibly got empty window expression".to_owned(),
));
}
let get_sort_keys = |expr: &Expr| match expr {
Expr::WindowFunction {
ref partition_by,
ref order_by,
..
} => generate_sort_key(partition_by, order_by),
_ => unreachable!(),
};

let sort_keys = get_sort_keys(&window_expr[0]);
if window_expr.len() > 1 {
debug_assert!(
window_expr[1..]
.iter()
.all(|expr| get_sort_keys(expr) == sort_keys),
"all window expressions shall have the same sort keys, as guaranteed by logical planning"
);
}

let input_exec = self.create_initial_plan(input, ctx_state)?;
let physical_input_schema = input_exec.schema();
let logical_input_schema = input.as_ref().schema();
let logical_input_schema = input.schema();

let input_exec = if sort_keys.is_empty() {
input_exec
} else {
let physical_input_schema = input_exec.schema();
let sort_keys = sort_keys
.iter()
.map(|e| match e {
Expr::Sort {
expr,
asc,
nulls_first,
} => self.create_physical_sort_expr(
expr,
logical_input_schema,
&physical_input_schema,
SortOptions {
descending: !*asc,
nulls_first: *nulls_first,
},
ctx_state,
),
_ => unreachable!(),
})
.collect::<Result<Vec<_>>>()?;
Arc::new(SortExec::try_new(sort_keys, input_exec)?)
};

let physical_input_schema = input_exec.schema();
let window_expr = window_expr
.iter()
.map(|e| {
Expand All @@ -282,7 +328,7 @@ impl DefaultPhysicalPlanner {

Ok(Arc::new(WindowAggExec::try_new(
window_expr,
input_exec.clone(),
input_exec,
physical_input_schema,
)?))
}
Expand Down
105 changes: 40 additions & 65 deletions datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,12 +695,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// if there's an empty over, it'll be at the top level
groups.sort_by(|(key_a, _), (key_b, _)| key_a.len().cmp(&key_b.len()));
groups.reverse();
for (sort_keys, exprs) in groups {
if !sort_keys.is_empty() {
let sort_keys: Vec<Expr> = sort_keys.to_vec();
plan = LogicalPlanBuilder::from(&plan).sort(sort_keys)?.build()?;
}
let window_exprs: Vec<Expr> = exprs.into_iter().cloned().collect();
for (_, exprs) in groups {
let window_exprs = exprs.into_iter().cloned().collect::<Vec<_>>();
// the partitioning and sorting are done at the physical level; see the physical planner's
// fn create_initial_plan
plan = LogicalPlanBuilder::from(&plan)
.window(window_exprs)?
.build()?;
Expand Down Expand Up @@ -2861,9 +2859,8 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) PARTITION BY [#orders.order_id]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -2884,11 +2881,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty), #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.order_id DESC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ORDER BY [#orders.order_id ASC NULLS FIRST]]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) ORDER BY [#orders.order_id DESC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -2897,11 +2892,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\
\n Sort: #orders.order_id ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.order_id DESC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ORDER BY [#orders.order_id ASC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) ORDER BY [#orders.order_id DESC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -2910,11 +2903,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty) ROWS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]]\
\n Sort: #orders.order_id ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.order_id DESC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ORDER BY [#orders.order_id ASC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) ORDER BY [#orders.order_id DESC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand Down Expand Up @@ -2955,11 +2946,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\
\n Sort: #orders.order_id ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.order_id DESC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ORDER BY [#orders.order_id ASC NULLS FIRST] GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) ORDER BY [#orders.order_id DESC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -2980,11 +2969,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY (order_id + 1)) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty), #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.order_id Plus Int64(1) ASC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ORDER BY [#orders.order_id ASC NULLS FIRST]]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) ORDER BY [#orders.order_id Plus Int64(1) ASC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -3007,11 +2994,9 @@ mod tests {
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty), #SUM(orders.qty), #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[SUM(#orders.qty)]]\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.qty ASC NULLS FIRST, #orders.order_id ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ORDER BY [#orders.qty ASC NULLS FIRST, #orders.order_id ASC NULLS FIRST]]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) ORDER BY [#orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -3034,11 +3019,9 @@ mod tests {
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty), #SUM(orders.qty), #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[SUM(#orders.qty)]]\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ORDER BY [#orders.order_id ASC NULLS FIRST]]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) ORDER BY [#orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -3065,11 +3048,9 @@ mod tests {
Sort: #orders.order_id ASC NULLS FIRST\
\n Projection: #orders.order_id, #MAX(orders.qty), #SUM(orders.qty), #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[SUM(#orders.qty)]]\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.qty ASC NULLS FIRST, #orders.order_id ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) ORDER BY [#orders.qty ASC NULLS FIRST, #orders.order_id ASC NULLS FIRST]]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) ORDER BY [#orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -3088,9 +3069,8 @@ mod tests {
"SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) PARTITION BY [#orders.order_id] ORDER BY [#orders.qty ASC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -3109,9 +3089,8 @@ mod tests {
"SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) PARTITION BY [#orders.order_id, #orders.qty] ORDER BY [#orders.qty ASC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -3133,11 +3112,9 @@ mod tests {
"SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty), MIN(qty) OVER (PARTITION BY qty ORDER BY order_id) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty), #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.qty ASC NULLS FIRST, #orders.order_id ASC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) PARTITION BY [#orders.order_id, #orders.qty] ORDER BY [#orders.qty ASC NULLS FIRST]]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) PARTITION BY [#orders.qty] ORDER BY [#orders.order_id ASC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand All @@ -3158,11 +3135,9 @@ mod tests {
"SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty), MIN(qty) OVER (PARTITION BY order_id, qty ORDER BY price) from orders";
let expected = "\
Projection: #orders.order_id, #MAX(orders.qty), #MIN(orders.qty)\
\n WindowAggr: windowExpr=[[MAX(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST\
\n WindowAggr: windowExpr=[[MIN(#orders.qty)]]\
\n Sort: #orders.order_id ASC NULLS FIRST, #orders.qty ASC NULLS FIRST, #orders.price ASC NULLS FIRST\
\n TableScan: orders projection=None";
\n WindowAggr: windowExpr=[[MAX(#orders.qty) PARTITION BY [#orders.order_id] ORDER BY [#orders.qty ASC NULLS FIRST]]]\
\n WindowAggr: windowExpr=[[MIN(#orders.qty) PARTITION BY [#orders.order_id, #orders.qty] ORDER BY [#orders.price ASC NULLS FIRST]]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}

Expand Down
6 changes: 5 additions & 1 deletion datafusion/src/sql/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,11 @@ pub(crate) fn resolve_aliases_to_exprs(

type WindowSortKey = Vec<Expr>;

fn generate_sort_key(partition_by: &[Expr], order_by: &[Expr]) -> WindowSortKey {
/// Generate a sort key for a given window expr's partition_by and order_by exprs
pub(crate) fn generate_sort_key(
partition_by: &[Expr],
order_by: &[Expr],
) -> WindowSortKey {
let mut sort_key = vec![];
partition_by.iter().for_each(|e| {
let e = e.clone().sort(true, true);
Expand Down

0 comments on commit 20c19ad

Please sign in to comment.