Skip to content

Commit

Permalink
Fix logical plan serialization (#3574)
Browse files Browse the repository at this point in the history
* Fix issue in type coercion optimizer

* Update datafusion/optimizer/src/type_coercion.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
thinkharderdev and alamb authored Sep 21, 2022
1 parent 634d912 commit 488be64
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]\
\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
Expand Down
23 changes: 22 additions & 1 deletion datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,31 @@ impl OptimizerRule for TypeCoercion {
const_evaluator,
};

let original_expr_names: Vec<Option<String>> = plan
.expressions()
.iter()
.map(|expr| expr.name().ok())
.collect();

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| expr.rewrite(&mut expr_rewrite))
.zip(original_expr_names)
.map(|(expr, original_name)| {
let expr = expr.rewrite(&mut expr_rewrite)?;

// ensure aggregate names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
if matches!(expr, Expr::AggregateFunction { .. }) {
if let Some((alias, name)) = original_name.zip(expr.name().ok()) {
if alias != name {
return Ok(expr.alias(&alias));
}
}
}

Ok(expr)
})
.collect::<Result<Vec<_>>>()?;

from_plan(plan, &new_expr, &new_inputs)
Expand Down
30 changes: 30 additions & 0 deletions datafusion/proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ mod roundtrip_tests {
logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
};
use crate::logical_plan::LogicalExtensionCodec;
use arrow::datatypes::Schema;
use arrow::{
array::ArrayRef,
datatypes::{DataType, Field, IntervalUnit, TimeUnit, UnionMode},
Expand Down Expand Up @@ -128,6 +129,35 @@ mod roundtrip_tests {
Ok(())
}

/// Checks that a logical plan containing an aliased aggregate survives a
/// round trip through the byte encoding.
///
/// The plan is built from a SQL query that groups by `a` and sums `b + 1`,
/// where `b` is registered as `Decimal128(15, 2)`. The plan is serialized
/// with `logical_plan_to_bytes`, decoded with `logical_plan_from_bytes`, and
/// the `Debug` renderings of the two plans are required to be identical.
///
/// NOTE(review): mixing a decimal column with an integer literal presumably
/// exercises the type-coercion rewrite of the aggregate expression — confirm
/// against the optimizer rule this test accompanies.
///
/// # Errors
///
/// Propagates any `DataFusionError` raised while registering the CSV table,
/// planning the query, or encoding/decoding the plan.
#[tokio::test]
async fn roundtrip_logical_plan_aggregation() -> Result<(), DataFusionError> {
    let ctx = SessionContext::new();

    // Supply the schema explicitly so `b` is typed as a decimal.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Decimal128(15, 2), true),
    ]);

    ctx.register_csv(
        "t1",
        "testdata/test.csv",
        CsvReadOptions::default().schema(&schema),
    )
    .await?;

    let query =
        "SELECT a, SUM(b + 1) as b_sum FROM t1 GROUP BY a ORDER BY b_sum DESC";
    let plan = ctx.sql(query).await?.to_logical_plan()?;

    let bytes = logical_plan_to_bytes(&plan)?;
    let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;

    // Compare textual renderings; on failure `assert_eq!` prints both plans,
    // so no separate debug print of `plan` is needed.
    assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));

    Ok(())
}

#[tokio::test]
async fn roundtrip_logical_plan_with_extension() -> Result<(), DataFusionError> {
let ctx = SessionContext::new();
Expand Down

0 comments on commit 488be64

Please sign in to comment.