Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Expr::AggregateFunction and Expr::WindowFunction to use struct #4671

Merged
merged 5 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ mod tests {
use arrow::datatypes::DataType;

use datafusion_expr::{
avg, cast, count, count_distinct, create_udf, lit, max, min, sum,
avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum,
BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame,
WindowFunction,
};
Expand Down Expand Up @@ -861,13 +861,13 @@ mod tests {
async fn select_with_window_exprs() -> Result<()> {
// build plan using Table API
let t = test_table().await?;
let first_row = Expr::WindowFunction {
fun: WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
args: vec![col("aggregate_test_100.c1")],
partition_by: vec![col("aggregate_test_100.c2")],
order_by: vec![],
window_frame: WindowFrame::new(false),
};
let first_row = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
vec![col("aggregate_test_100.c1")],
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(false),
));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();

Expand Down
25 changes: 13 additions & 12 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ use arrow::datatypes::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::{
self, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, TryCast,
self, AggregateFunction, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet,
Like, TryCast, WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan;
Expand Down Expand Up @@ -190,15 +191,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
Expr::ScalarUDF { fun, args, .. } => {
create_function_physical_name(&fun.name, false, args)
}
Expr::WindowFunction { fun, args, .. } => {
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
create_function_physical_name(&fun.to_string(), false, args)
}
Expr::AggregateFunction {
Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
..
} => create_function_physical_name(&fun.to_string(), *distinct, args),
}) => create_function_physical_name(&fun.to_string(), *distinct, args),
Expr::AggregateUDF { fun, args, filter } => {
if filter.is_some() {
return Err(DataFusionError::Execution(
Expand Down Expand Up @@ -547,18 +548,18 @@ impl DefaultPhysicalPlanner {
};

let get_sort_keys = |expr: &Expr| match expr {
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction{
ref partition_by,
ref order_by,
..
} => generate_sort_key(partition_by, order_by),
}) => generate_sort_key(partition_by, order_by),
Expr::Alias(expr, _) => {
// Convert &Box<T> to &T
match &**expr {
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction{
ref partition_by,
ref order_by,
..} => generate_sort_key(partition_by, order_by),
..}) => generate_sort_key(partition_by, order_by),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -1502,13 +1503,13 @@ pub fn create_window_expr_with_name(
) -> Result<Arc<dyn WindowExpr>> {
let name = name.into();
match e {
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
} => {
}) => {
let args = args
.iter()
.map(|e| {
Expand Down Expand Up @@ -1608,12 +1609,12 @@ pub fn create_aggregate_expr_with_name(
execution_props: &ExecutionProps,
) -> Result<Arc<dyn AggregateExpr>> {
match e {
Expr::AggregateFunction {
Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
..
} => {
}) => {
let args = args
.iter()
.map(|e| {
Expand Down
103 changes: 73 additions & 30 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,29 +166,9 @@ pub enum Expr {
args: Vec<Expr>,
},
/// Represents the call of an aggregate built-in function with arguments.
AggregateFunction {
/// Name of the function
fun: aggregate_function::AggregateFunction,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
distinct: bool,
/// Optional filter
filter: Option<Box<Expr>>,
},
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
WindowFunction {
/// Name of the function
fun: window_function::WindowFunction,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
/// List of partition by expressions
partition_by: Vec<Expr>,
/// List of order by expressions
order_by: Vec<Expr>,
/// Window frame
window_frame: window_frame::WindowFrame,
},
WindowFunction(WindowFunction),
/// aggregate function
AggregateUDF {
/// The function
Expand Down Expand Up @@ -472,6 +452,69 @@ impl Sort {
}
}

/// Aggregate function
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct AggregateFunction {
/// Name of the function
pub fun: aggregate_function::AggregateFunction,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
pub distinct: bool,
/// Optional filter
pub filter: Option<Box<Expr>>,
}

impl AggregateFunction {
pub fn new(
fun: aggregate_function::AggregateFunction,
args: Vec<Expr>,
distinct: bool,
filter: Option<Box<Expr>>,
) -> Self {
Self {
fun,
args,
distinct,
filter,
}
}
}

/// Window function
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct WindowFunction {
/// Name of the function
pub fun: window_function::WindowFunction,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
/// List of partition by expressions
pub partition_by: Vec<Expr>,
/// List of order by expressions
pub order_by: Vec<Expr>,
/// Window frame
pub window_frame: window_frame::WindowFrame,
}

impl WindowFunction {
/// Create a new Window expression
pub fn new(
fun: window_function::WindowFunction,
args: Vec<Expr>,
partition_by: Vec<Expr>,
order_by: Vec<Expr>,
window_frame: window_frame::WindowFrame,
) -> Self {
Self {
fun,
args,
partition_by,
order_by,
window_frame,
}
}
}

/// Grouping sets
/// See https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS
/// for Postgres definition.
Expand Down Expand Up @@ -867,13 +910,13 @@ impl fmt::Debug for Expr {
Expr::ScalarUDF { fun, ref args, .. } => {
fmt_function(f, &fun.name, false, args, false)
}
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
} => {
}) => {
fmt_function(f, &fun.to_string(), false, args, false)?;
if !partition_by.is_empty() {
write!(f, " PARTITION BY {:?}", partition_by)?;
Expand All @@ -888,13 +931,13 @@ impl fmt::Debug for Expr {
)?;
Ok(())
}
Expr::AggregateFunction {
Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
ref args,
filter,
..
} => {
}) => {
fmt_function(f, &fun.to_string(), *distinct, args, true)?;
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {})", fe)?;
Expand Down Expand Up @@ -1223,13 +1266,13 @@ fn create_name(e: &Expr) -> Result<String> {
create_function_name(&fun.to_string(), false, args)
}
Expr::ScalarUDF { fun, args, .. } => create_function_name(&fun.name, false, args),
Expr::WindowFunction {
Expr::WindowFunction(WindowFunction {
fun,
args,
window_frame,
partition_by,
order_by,
} => {
}) => {
let mut parts: Vec<String> =
vec![create_function_name(&fun.to_string(), false, args)?];
if !partition_by.is_empty() {
Expand All @@ -1241,12 +1284,12 @@ fn create_name(e: &Expr) -> Result<String> {
parts.push(format!("{}", window_frame));
Ok(parts.join(" "))
}
Expr::AggregateFunction {
Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
filter,
} => {
}) => {
let name = create_function_name(&fun.to_string(), *distinct, args)?;
if let Some(fe) = filter {
Ok(format!("{} FILTER (WHERE {})", name, fe))
Expand Down
Loading