Skip to content

Commit

Permalink
Refactor Expr::AggregateFunction and Expr::WindowFunction to use stru…
Browse files Browse the repository at this point in the history
…ct (#4671)

* Refactor Expr::WindowFunction to struct

* Refactor Expr::AggregateFunction to struct

* Fix
  • Loading branch information
Jefffrey authored Dec 20, 2022
1 parent 975ff15 commit fe477e4
Show file tree
Hide file tree
Showing 16 changed files with 407 additions and 357 deletions.
16 changes: 8 additions & 8 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ mod tests {
use arrow::datatypes::DataType;

use datafusion_expr::{
avg, cast, count, count_distinct, create_udf, lit, max, min, sum,
avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum,
BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame,
WindowFunction,
};
Expand Down Expand Up @@ -861,13 +861,13 @@ mod tests {
async fn select_with_window_exprs() -> Result<()> {
// build plan using Table API
let t = test_table().await?;
let first_row = Expr::WindowFunction {
fun: WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
args: vec![col("aggregate_test_100.c1")],
partition_by: vec![col("aggregate_test_100.c2")],
order_by: vec![],
window_frame: WindowFrame::new(false),
};
let first_row = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
vec![col("aggregate_test_100.c1")],
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(false),
));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();

Expand Down
25 changes: 13 additions & 12 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ use arrow::datatypes::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::{
self, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, TryCast,
self, AggregateFunction, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet,
Like, TryCast, WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan;
Expand Down Expand Up @@ -190,15 +191,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
Expr::ScalarUDF { fun, args, .. } => {
create_function_physical_name(&fun.name, false, args)
}
Expr::WindowFunction { fun, args, .. } => {
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
create_function_physical_name(&fun.to_string(), false, args)
}
Expr::AggregateFunction {
Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
..
} => create_function_physical_name(&fun.to_string(), *distinct, args),
}) => create_function_physical_name(&fun.to_string(), *distinct, args),
Expr::AggregateUDF { fun, args, filter } => {
if filter.is_some() {
return Err(DataFusionError::Execution(
Expand Down Expand Up @@ -547,18 +548,18 @@ impl DefaultPhysicalPlanner {
};

let get_sort_keys = |expr: &Expr| match expr {
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction{
ref partition_by,
ref order_by,
..
} => generate_sort_key(partition_by, order_by),
}) => generate_sort_key(partition_by, order_by),
Expr::Alias(expr, _) => {
// Convert &Box<T> to &T
match &**expr {
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction{
ref partition_by,
ref order_by,
..} => generate_sort_key(partition_by, order_by),
..}) => generate_sort_key(partition_by, order_by),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -1502,13 +1503,13 @@ pub fn create_window_expr_with_name(
) -> Result<Arc<dyn WindowExpr>> {
let name = name.into();
match e {
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
} => {
}) => {
let args = args
.iter()
.map(|e| {
Expand Down Expand Up @@ -1608,12 +1609,12 @@ pub fn create_aggregate_expr_with_name(
execution_props: &ExecutionProps,
) -> Result<Arc<dyn AggregateExpr>> {
match e {
Expr::AggregateFunction {
Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
..
} => {
}) => {
let args = args
.iter()
.map(|e| {
Expand Down
103 changes: 73 additions & 30 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,29 +166,9 @@ pub enum Expr {
args: Vec<Expr>,
},
/// Represents the call of an aggregate built-in function with arguments.
AggregateFunction {
/// Name of the function
fun: aggregate_function::AggregateFunction,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
distinct: bool,
/// Optional filter
filter: Option<Box<Expr>>,
},
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
WindowFunction {
/// Name of the function
fun: window_function::WindowFunction,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
/// List of partition by expressions
partition_by: Vec<Expr>,
/// List of order by expressions
order_by: Vec<Expr>,
/// Window frame
window_frame: window_frame::WindowFrame,
},
WindowFunction(WindowFunction),
/// aggregate function
AggregateUDF {
/// The function
Expand Down Expand Up @@ -472,6 +452,69 @@ impl Sort {
}
}

/// Aggregate function
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct AggregateFunction {
/// Name of the function
pub fun: aggregate_function::AggregateFunction,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
pub distinct: bool,
/// Optional filter
pub filter: Option<Box<Expr>>,
}

impl AggregateFunction {
pub fn new(
fun: aggregate_function::AggregateFunction,
args: Vec<Expr>,
distinct: bool,
filter: Option<Box<Expr>>,
) -> Self {
Self {
fun,
args,
distinct,
filter,
}
}
}

/// Window function
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct WindowFunction {
/// Name of the function
pub fun: window_function::WindowFunction,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
/// List of partition by expressions
pub partition_by: Vec<Expr>,
/// List of order by expressions
pub order_by: Vec<Expr>,
/// Window frame
pub window_frame: window_frame::WindowFrame,
}

impl WindowFunction {
/// Create a new Window expression
pub fn new(
fun: window_function::WindowFunction,
args: Vec<Expr>,
partition_by: Vec<Expr>,
order_by: Vec<Expr>,
window_frame: window_frame::WindowFrame,
) -> Self {
Self {
fun,
args,
partition_by,
order_by,
window_frame,
}
}
}

/// Grouping sets
/// See https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS
/// for Postgres definition.
Expand Down Expand Up @@ -867,13 +910,13 @@ impl fmt::Debug for Expr {
Expr::ScalarUDF { fun, ref args, .. } => {
fmt_function(f, &fun.name, false, args, false)
}
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
} => {
}) => {
fmt_function(f, &fun.to_string(), false, args, false)?;
if !partition_by.is_empty() {
write!(f, " PARTITION BY {:?}", partition_by)?;
Expand All @@ -888,13 +931,13 @@ impl fmt::Debug for Expr {
)?;
Ok(())
}
Expr::AggregateFunction {
Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
ref args,
filter,
..
} => {
}) => {
fmt_function(f, &fun.to_string(), *distinct, args, true)?;
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {})", fe)?;
Expand Down Expand Up @@ -1223,13 +1266,13 @@ fn create_name(e: &Expr) -> Result<String> {
create_function_name(&fun.to_string(), false, args)
}
Expr::ScalarUDF { fun, args, .. } => create_function_name(&fun.name, false, args),
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction {
fun,
args,
window_frame,
partition_by,
order_by,
} => {
}) => {
let mut parts: Vec<String> =
vec![create_function_name(&fun.to_string(), false, args)?];
if !partition_by.is_empty() {
Expand All @@ -1241,12 +1284,12 @@ fn create_name(e: &Expr) -> Result<String> {
parts.push(format!("{}", window_frame));
Ok(parts.join(" "))
}
Expr::AggregateFunction {
Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
filter,
} => {
}) => {
let name = create_function_name(&fun.to_string(), *distinct, args)?;
if let Some(fe) = filter {
Ok(format!("{} FILTER (WHERE {})", name, fe))
Expand Down
Loading

0 comments on commit fe477e4

Please sign in to comment.