From 1a1957d5b793f063a986a0030bdcd801819b0bf2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 8 Sep 2022 15:33:44 -0600 Subject: [PATCH 1/7] Use sqlparser-0.23 --- datafusion/common/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/expr/Cargo.toml | 2 +- datafusion/sql/Cargo.toml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 7d68e877bd1f..d4b182c7c35f 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -47,4 +47,4 @@ ordered-float = "3.0" parquet = { version = "22.0.0", features = ["arrow"], optional = true } pyo3 = { version = "0.17.1", optional = true } serde_json = "1.0" -sqlparser = "0.22" +sqlparser = "0.23" diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 63ca4070f31c..926489ed068b 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -85,7 +85,7 @@ pyo3 = { version = "0.17.1", optional = true } rand = "0.8" rayon = { version = "1.5", optional = true } smallvec = { version = "1.6", features = ["union"] } -sqlparser = "0.22" +sqlparser = "0.23" tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index e2c723992425..ec10918d20f4 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,4 +38,4 @@ path = "src/lib.rs" ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } arrow = { version = "22.0.0", features = ["prettyprint"] } datafusion-common = { path = "../common", version = "11.0.0" } -sqlparser = "0.22" +sqlparser = "0.23" diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 550c19d01d50..27701e056b49 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -42,5 +42,5 @@ arrow = { version = "22.0.0", features = ["prettyprint"] } datafusion-common = { path = "../common", version = "11.0.0" } datafusion-expr = { path = "../expr", version = "11.0.0" } hashbrown = "0.12" -sqlparser = "0.22" +sqlparser = "0.23" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } From a4ea9afc8c346b1e93c94f99930958330d4be14d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 8 Sep 2022 16:23:33 -0600 Subject: [PATCH 2/7] Add filter to aggregate expressions --- datafusion/core/src/physical_plan/planner.rs | 7 ++- datafusion/expr/src/expr.rs | 46 ++++++++++++++++--- datafusion/expr/src/expr_fn.rs | 10 ++++ datafusion/expr/src/expr_rewriter.rs | 5 +- datafusion/expr/src/udaf.rs | 1 + .../src/single_distinct_to_groupby.rs | 6 ++- datafusion/proto/src/from_proto.rs | 2 + datafusion/proto/src/lib.rs | 4 ++ datafusion/proto/src/to_proto.rs | 31 ++++++++----- datafusion/sql/src/planner.rs | 29 ++++++++++-- datafusion/sql/src/utils.rs | 5 +- 11 files changed, 121 insertions(+), 25 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 6ec35a59648a..aa3b84c3e1dc 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -195,7 +195,12 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { args, .. } => create_function_physical_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF { fun, args } => { + Expr::AggregateUDF { fun, args, filter } => { + if filter.is_some() { + return Err(DataFusionError::Execution( + "aggregate expression with filter is not supported".to_string(), + )); + } let mut names = Vec::with_capacity(args.len()); for e in args { names.push(create_physical_name(e, false)?); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ab45dd67d1b8..31a6cbcb8ef8 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -231,6 +231,8 @@ pub enum Expr { args: Vec, /// Whether this is a DISTINCT aggregation or not distinct: bool, + /// Optional filter + filter: Option>, }, /// Represents the call of a window function with arguments. WindowFunction { @@ -251,6 +253,8 @@ pub enum Expr { fun: Arc, /// List of expressions to feed to the functions as arguments args: Vec, + /// Optional filter + filter: Option>, }, /// Returns whether the list contains the expr value. InList { @@ -668,10 +672,26 @@ impl fmt::Debug for Expr { fun, distinct, ref args, + filter, .. - } => fmt_function(f, &fun.to_string(), *distinct, args, true), - Expr::AggregateUDF { fun, ref args, .. } => { - fmt_function(f, &fun.name, false, args, false) + } => { + fmt_function(f, &fun.to_string(), *distinct, args, true)?; + if let Some(fe) = filter { + write!(f, " FILTER (WHERE {})", fe)?; + } + Ok(()) + } + Expr::AggregateUDF { + fun, + ref args, + filter, + .. + } => { + fmt_function(f, &fun.name, false, args, false)?; + if let Some(fe) = filter { + write!(f, " FILTER (WHERE {})", fe)?; + } + Ok(()) } Expr::Between { expr, @@ -1010,14 +1030,26 @@ fn create_name(e: &Expr) -> Result { fun, distinct, args, - .. - } => create_function_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF { fun, args } => { + filter, + } => { + let name = create_function_name(&fun.to_string(), *distinct, args)?; + if let Some(fe) = filter { + Ok(format!("{} FILTER (WHERE {})", name, fe)) + } else { + Ok(name) + } + } + Expr::AggregateUDF { fun, args, filter } => { let mut names = Vec::with_capacity(args.len()); for e in args { names.push(create_name(e)?); } - Ok(format!("{}({})", fun.name, names.join(","))) + let filter = if let Some(fe) = filter { + format!(" FILTER (WHERE {})", fe) + } else { + "".to_string() + }; + Ok(format!("{}({}){}", fun.name, names.join(","), filter)) } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8b0f1646608d..9fa791c295be 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -66,6 +66,7 @@ pub fn min(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Min, distinct: false, args: vec![expr], + filter: None, } } @@ -75,6 +76,7 @@ pub fn max(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Max, distinct: false, args: vec![expr], + filter: None, } } @@ -84,6 +86,7 @@ pub fn sum(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Sum, distinct: false, args: vec![expr], + filter: None, } } @@ -93,6 +96,7 @@ pub fn avg(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Avg, distinct: false, args: vec![expr], + filter: None, } } @@ -102,6 +106,7 @@ pub fn count(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Count, distinct: false, args: vec![expr], + filter: None, } } @@ -111,6 +116,7 @@ pub fn count_distinct(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Count, distinct: true, args: vec![expr], + filter: None, } } @@ -163,6 +169,7 @@ pub fn approx_distinct(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::ApproxDistinct, distinct: false, args: vec![expr], + filter: None, } } @@ -172,6 +179,7 @@ pub fn approx_median(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::ApproxMedian, distinct: false, args: vec![expr], + filter: None, } } @@ -181,6 +189,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { fun: aggregate_function::AggregateFunction::ApproxPercentileCont, distinct: false, args: vec![expr, percentile], + filter: None, } } @@ -194,6 +203,7 @@ pub fn approx_percentile_cont_with_weight( fun: aggregate_function::AggregateFunction::ApproxPercentileContWithWeight, distinct: false, args: vec![expr, weight_expr, percentile], + filter: None, } } diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index b8b9fced921c..533f31ce1584 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -250,10 +250,12 @@ impl ExprRewritable for Expr { args, fun, distinct, + filter, } => Expr::AggregateFunction { args: rewrite_vec(args, rewriter)?, fun, distinct, + filter, }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { @@ -271,9 +273,10 @@ impl ExprRewritable for Expr { )) } }, - Expr::AggregateUDF { args, fun } => Expr::AggregateUDF { + Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF { args: rewrite_vec(args, rewriter)?, fun, + filter, }, Expr::InList { expr, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 00f48dda272c..0ecb5280a942 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -89,6 +89,7 @@ impl AggregateUDF { Expr::AggregateUDF { fun: Arc::new(self.clone()), args, + filter: None, } } } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index b61f0c25e5d1..e45fde725af2 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -87,7 +87,9 @@ fn optimize(plan: &LogicalPlan) -> Result { let new_aggr_exprs = aggr_expr .iter() .map(|aggr_expr| match aggr_expr { - Expr::AggregateFunction { fun, args, .. } => { + Expr::AggregateFunction { + fun, args, filter, .. + } => { // is_single_distinct_agg ensure args.len=1 if group_fields_set.insert(args[0].name()?) { inner_group_exprs @@ -97,6 +99,7 @@ fn optimize(plan: &LogicalPlan) -> Result { fun: fun.clone(), args: vec![col(SINGLE_DISTINCT_ALIAS)], distinct: false, // intentional to remove distinct here + filter: filter.clone(), }) } _ => Ok(aggr_expr.clone()), @@ -402,6 +405,7 @@ mod tests { fun: AggregateFunction::Max, distinct: true, args: vec![col("b")], + filter: None, }, ], )? diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index e0db97c0ca8e..f02bd5f1d6dd 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -891,6 +891,7 @@ pub fn parse_expr( .map(|e| parse_expr(e, registry)) .collect::, _>>()?, distinct: expr.distinct, + filter: None, // not supported yet }) } ExprType::Alias(alias) => Ok(Expr::Alias( @@ -1203,6 +1204,7 @@ pub fn parse_expr( .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, + filter: None, // not supported yet }) } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 7e009a846e69..8ef8563f9225 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -935,6 +935,7 @@ mod roundtrip_tests { fun: AggregateFunction::Count, args: vec![col("bananas")], distinct: false, + filter: None, }; let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -946,6 +947,7 @@ mod roundtrip_tests { fun: AggregateFunction::Count, args: vec![col("bananas")], distinct: true, + filter: None, }; let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -957,6 +959,7 @@ mod roundtrip_tests { fun: AggregateFunction::ApproxPercentileCont, args: vec![col("bananas"), lit(0.42_f32)], distinct: false, + filter: None, }; let ctx = SessionContext::new(); @@ -1009,6 +1012,7 @@ mod roundtrip_tests { let test_expr = Expr::AggregateUDF { fun: Arc::new(dummy_agg.clone()), args: vec![lit(1.0_f64)], + filter: None, }; let mut ctx = SessionContext::new(); diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index ef7e5bf219f5..e367dc00157f 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -561,7 +561,11 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref fun, ref args, ref distinct, + ref filter } => { + if filter.is_some() { + return Err(Error::General("Proto serialization error: aggregate expression with filter is not supported".to_string())); + } let aggr_function = match fun { AggregateFunction::ApproxDistinct => { protobuf::AggregateFunction::ApproxDistinct @@ -639,17 +643,22 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .collect::, Error>>()?, })), }, - Expr::AggregateUDF { fun, args } => Self { - expr_type: Some(ExprType::AggregateUdfExpr( - protobuf::AggregateUdfExprNode { - fun_name: fun.name.clone(), - args: args.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - }, - )), + Expr::AggregateUDF { fun, args, filter } => { + if filter.is_some() { + return Err(Error::General("Proto serialization error: aggregate expression with filter is not supported".to_string())); + } + Self { + expr_type: Some(ExprType::AggregateUdfExpr( + protobuf::AggregateUdfExprNode { + fun_name: fun.name.clone(), + args: args.iter().map(|expr| expr.try_into()).collect::, + Error, + >>( + )?, + }, + )), + } }, Expr::Not(expr) => { let expr = Box::new(protobuf::Not { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index d9dbdf143687..c101b2c2ccea 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -2053,6 +2053,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction { fun, args }) } + SQLExpr::AggregateExpressionWithFilter { expr, filter } => { + match self.sql_expr_to_logical_expr(*expr, schema, ctes)? { + Expr::AggregateFunction { + fun, args, distinct, .. + } => Ok(Expr::AggregateFunction { fun: fun.clone(), args: args.clone(), distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, ctes)?)) }), + _ => Err(DataFusionError::Internal("".to_string())) + } + + } + SQLExpr::Function(mut function) => { let name = if function.name.0.len() > 1 { // DF doesn't handle compound identifiers @@ -2149,6 +2159,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun, distinct, args, + filter: None }); }; @@ -2162,7 +2173,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None => match self.schema_provider.get_aggregate_meta(&name) { Some(fm) => { let args = self.function_args_to_expr(function.args, schema)?; - Ok(Expr::AggregateUDF { fun: fm, args }) + Ok(Expr::AggregateUDF { fun: fm, args, filter: None }) } _ => Err(DataFusionError::Plan(format!( "Invalid function '{}'", @@ -2181,7 +2192,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Subquery(subquery) => self.parse_scalar_subquery(&subquery, schema, ctes), _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported ast node {:?} in sqltorel", + "Unsupported ast node in sqltorel: {:?}", sql ))), } @@ -2688,7 +2699,7 @@ fn parse_sql_number(n: &str) -> Result { mod tests { use super::*; use crate::assert_contains; - use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; + use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use std::any::Any; #[test] @@ -4923,6 +4934,18 @@ mod tests { quick_test(sql, expected); } + #[test] + fn hive_aggregate_with_filter() -> Result<()> { + let dialect = &HiveDialect {}; + let sql = "SELECT SUM(age) FILTER (WHERE age > 4) FROM person"; + let plan = logical_plan_with_dialect(sql, dialect)?; + let expected = "Projection: #SUM(person.age) FILTER (WHERE #age > Int64(4))\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(#person.age) FILTER (WHERE #age > Int64(4))]]\ + \n TableScan: person".to_string(); + assert_eq!(expected, format!("{}", plan.display_indent())); + Ok(()) + } + #[test] fn order_by_unaliased_name() { // https://github.com/apache/arrow-datafusion/issues/3160 diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 25f5c549adbc..eb58509d0960 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -163,6 +163,7 @@ where fun, args, distinct, + filter, } => Ok(Expr::AggregateFunction { fun: fun.clone(), args: args @@ -170,6 +171,7 @@ where .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, distinct: *distinct, + filter: filter.clone(), }), Expr::WindowFunction { fun, @@ -193,12 +195,13 @@ where .collect::>>()?, window_frame: *window_frame, }), - Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF { + Expr::AggregateUDF { fun, args, filter } => Ok(Expr::AggregateUDF { fun: fun.clone(), args: args .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, + filter: filter.clone(), }), Expr::Alias(nested_expr, alias_name) => Ok(Expr::Alias( Box::new(clone_with_replacement(nested_expr, replacement_fn)?), From d309a3ca9b16ce0cfc638584cd00e620c2f56190 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 10 Sep 2022 08:37:57 -0600 Subject: [PATCH 3/7] clippy --- datafusion/sql/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index c101b2c2ccea..c1898de12f2f 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -2057,7 +2057,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match self.sql_expr_to_logical_expr(*expr, schema, ctes)? { Expr::AggregateFunction { fun, args, distinct, .. - } => Ok(Expr::AggregateFunction { fun: fun.clone(), args: args.clone(), distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, ctes)?)) }), + } => Ok(Expr::AggregateFunction { fun, args, distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, ctes)?)) }), _ => Err(DataFusionError::Internal("".to_string())) } From 7380307d4ced7d6bcc088ccdd924813aa4c04e28 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 10 Sep 2022 08:58:29 -0600 Subject: [PATCH 4/7] implement protobuf serde --- datafusion/proto/proto/datafusion.proto | 2 ++ datafusion/proto/src/from_proto.rs | 11 ++++++----- datafusion/proto/src/lib.rs | 2 +- datafusion/proto/src/to_proto.rs | 20 +++++++++++--------- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 8d4da0250b99..baabc04cfff7 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -504,11 +504,13 @@ message AggregateExprNode { AggregateFunction aggr_function = 1; repeated LogicalExprNode expr = 2; bool distinct = 3; + LogicalExprNode filter = 4; } message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; + LogicalExprNode filter = 3; } message ScalarUDFExprNode { diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index f02bd5f1d6dd..f005c99a66b8 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -891,7 +891,7 @@ pub fn parse_expr( .map(|e| parse_expr(e, registry)) .collect::, _>>()?, distinct: expr.distinct, - filter: None, // not supported yet + filter: parse_optional_expr(&expr.filter, registry)?.map(|e| Box::new(e)), }) } ExprType::Alias(alias) => Ok(Expr::Alias( @@ -1195,16 +1195,17 @@ pub fn parse_expr( .collect::, Error>>()?, }) } - ExprType::AggregateUdfExpr(protobuf::AggregateUdfExprNode { fun_name, args }) => { - let agg_fn = registry.udaf(fun_name.as_str())?; + ExprType::AggregateUdfExpr(pb) => { + let agg_fn = registry.udaf(pb.fun_name.as_str())?; Ok(Expr::AggregateUDF { fun: agg_fn, - args: args + args: pb + .args .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, - filter: None, // not supported yet + filter: parse_optional_expr(&pb.filter, registry)?.map(|e| Box::new(e)), }) } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 8ef8563f9225..28ae78d45a26 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -1012,7 +1012,7 @@ mod roundtrip_tests { let test_expr = Expr::AggregateUDF { fun: Arc::new(dummy_agg.clone()), args: vec![lit(1.0_f64)], - filter: None, + filter: Some(Box::new(lit(true))), }; let mut ctx = SessionContext::new(); diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index e367dc00157f..2756e17c56ec 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -563,9 +563,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref distinct, ref filter } => { - if filter.is_some() { - return Err(Error::General("Proto serialization error: aggregate expression with filter is not supported".to_string())); - } let aggr_function = match fun { AggregateFunction::ApproxDistinct => { protobuf::AggregateFunction::ApproxDistinct @@ -613,9 +610,13 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .map(|v| v.try_into()) .collect::, _>>()?, distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + } }; Self { - expr_type: Some(ExprType::AggregateExpr(aggregate_expr)), + expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), } } Expr::ScalarVariable(_, _) => return Err(Error::General("Proto serialization error: Scalar Variable not supported".to_string())), @@ -644,20 +645,21 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { })), }, Expr::AggregateUDF { fun, args, filter } => { - if filter.is_some() { - return Err(Error::General("Proto serialization error: aggregate expression with filter is not supported".to_string())); - } Self { expr_type: Some(ExprType::AggregateUdfExpr( - protobuf::AggregateUdfExprNode { + Box::new(protobuf::AggregateUdfExprNode { fun_name: fun.name.clone(), args: args.iter().map(|expr| expr.try_into()).collect::, Error, >>( )?, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + } }, - )), + ))), } }, Expr::Not(expr) => { From baf12a4f6c7348e1d9a214f155c625d2afec4316 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 12 Sep 2022 10:01:48 -0600 Subject: [PATCH 5/7] clippy --- datafusion/proto/src/from_proto.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index f005c99a66b8..9688ccda46c3 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -891,7 +891,7 @@ pub fn parse_expr( .map(|e| parse_expr(e, registry)) .collect::, _>>()?, distinct: expr.distinct, - filter: parse_optional_expr(&expr.filter, registry)?.map(|e| Box::new(e)), + filter: parse_optional_expr(&expr.filter, registry)?.map(Box::new), }) } ExprType::Alias(alias) => Ok(Expr::Alias( @@ -1205,7 +1205,7 @@ pub fn parse_expr( .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, - filter: parse_optional_expr(&pb.filter, registry)?.map(|e| Box::new(e)), + filter: parse_optional_expr(&pb.filter, registry)?.map(Box::new), }) } From 00b7fdb952ad8a27ba6078416778dd445aecc75b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 12 Sep 2022 10:11:06 -0600 Subject: [PATCH 6/7] fix error message --- datafusion/sql/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index d8476c9dd438..5d30b670f82a 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -2094,7 +2094,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::AggregateFunction { fun, args, distinct, .. } => Ok(Expr::AggregateFunction { fun, args, distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, ctes)?)) }), - _ => Err(DataFusionError::Internal("".to_string())) + _ => Err(DataFusionError::Internal("AggregateExpressionWithFilter expression was not an AggregateFunction".to_string())) } } From b0d158a3455074d01456ee192f1c6bdb6c228d95 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 12 Sep 2022 10:54:51 -0600 Subject: [PATCH 7/7] Update datafusion/expr/src/expr.rs Co-authored-by: Andrew Lamb --- datafusion/expr/src/expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 31a6cbcb8ef8..8b90fb9e4dfa 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -253,7 +253,7 @@ pub enum Expr { fun: Arc, /// List of expressions to feed to the functions as arguments args: Vec, - /// Optional filter + /// Optional filter applied prior to aggregating filter: Option>, }, /// Returns whether the list contains the expr value.