diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index a75e0e3fa515..f2069126c5ff 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -653,7 +653,7 @@ order by let expected = "\ Sort: #revenue DESC NULLS FIRST\ \n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\ - \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\ + \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]\ \n Inner Join: #customer.c_nationkey = #nation.n_nationkey\ \n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\ \n Inner Join: #customer.c_custkey = #orders.o_custkey\ diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 7be327713183..0f22c01d4837 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -79,10 +79,31 @@ impl OptimizerRule for TypeCoercion { const_evaluator, }; + let original_expr_names: Vec<Option<String>> = plan + .expressions() + .iter() + .map(|expr| expr.name().ok()) + .collect(); + let new_expr = plan .expressions() .into_iter() - .map(|expr| expr.rewrite(&mut expr_rewrite)) + 
.zip(original_expr_names) + .map(|(expr, original_name)| { + let expr = expr.rewrite(&mut expr_rewrite)?; + + // ensure aggregate names don't change: + // https://github.com/apache/arrow-datafusion/issues/3555 + if matches!(expr, Expr::AggregateFunction { .. }) { + if let Some((alias, name)) = original_name.zip(expr.name().ok()) { + if alias != name { + return Ok(expr.alias(&alias)); + } + } + } + + Ok(expr) + }) .collect::<Result<Vec<_>>>()?; from_plan(plan, &new_expr, &new_inputs) diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index dfe2bbaaa45f..e2abc92d6dff 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -54,6 +54,7 @@ mod roundtrip_tests { logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; use crate::logical_plan::LogicalExtensionCodec; + use arrow::datatypes::Schema; use arrow::{ array::ArrayRef, datatypes::{DataType, Field, IntervalUnit, TimeUnit, UnionMode}, @@ -128,6 +129,35 @@ mod roundtrip_tests { Ok(()) } + #[tokio::test] + async fn roundtrip_logical_plan_aggregation() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = + "SELECT a, SUM(b + 1) as b_sum FROM t1 GROUP BY a ORDER BY b_sum DESC"; + let plan = ctx.sql(query).await?.to_logical_plan()?; + + println!("{:?}", plan); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); + + Ok(()) + } + #[tokio::test] async fn roundtrip_logical_plan_with_extension() -> Result<(), DataFusionError> { let ctx = SessionContext::new();