Skip to content

Commit

Permalink
Evaluate expressions after type coercion (#3444)
Browse files: browse the repository at this point in the history
* Evaluate expressions after type coercion

* Fix some explains

* Fix some explains

* Fix some explains

* Update test

* Update test

* Update test

* Update more tests

* Fix tests

* Use supported date string
  • Loading branch information
Dandandan authored Sep 12, 2022
1 parent 97b3a4b commit f48a997
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 88 deletions.
10 changes: 5 additions & 5 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> {
assert_eq!(results.len(), 1);

let expected = vec![
"+--------------+-------------------------+-------------------------+-------------------------+",
"| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |",
"+--------------+-------------------------+-------------------------+-------------------------+",
"| 1.5 | 2.5 | 3.5 | 2.5 |",
"+--------------+-------------------------+-------------------------+-------------------------+",
"+--------------+---------------------------+---------------------------+---------------------------+",
"| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) + Float64(2) | Float64(1) + AVG(test.c1) |",
"+--------------+---------------------------+---------------------------+---------------------------+",
"| 1.5 | 2.5 | 3.5 | 2.5 |",
"+--------------+---------------------------+---------------------------+---------------------------+",
];
assert_batches_sorted_eq!(expected, &results);

Expand Down
114 changes: 57 additions & 57 deletions datafusion/core/tests/sql/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,25 +376,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+------------------------------+",
"| decimal_simple.c1 + Int64(1) |",
"+------------------------------+",
"| 1.000010 |",
"| 1.000020 |",
"| 1.000020 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"+------------------------------+",
"+----------------------------------------------------+",
"| decimal_simple.c1 + Decimal128(Some(1000000),27,6) |",
"+----------------------------------------------------+",
"| 1.000010 |",
"| 1.000020 |",
"| 1.000020 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"+----------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
// array decimal(10,6) + array decimal(12,7) => decimal(13,7)
Expand Down Expand Up @@ -434,25 +434,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+------------------------------+",
"| decimal_simple.c1 - Int64(1) |",
"+------------------------------+",
"| -0.999990 |",
"| -0.999980 |",
"| -0.999980 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"+------------------------------+",
"+----------------------------------------------------+",
"| decimal_simple.c1 - Decimal128(Some(1000000),27,6) |",
"+----------------------------------------------------+",
"| -0.999990 |",
"| -0.999980 |",
"| -0.999980 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"+----------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down Expand Up @@ -492,25 +492,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+-------------------------------+",
"| decimal_simple.c1 * Int64(20) |",
"+-------------------------------+",
"| 0.000200 |",
"| 0.000400 |",
"| 0.000400 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"+-------------------------------+",
"+-----------------------------------------------------+",
"| decimal_simple.c1 * Decimal128(Some(20000000),31,6) |",
"+-----------------------------------------------------+",
"| 0.000200 |",
"| 0.000400 |",
"| 0.000400 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"+-----------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(CAST(Int64(1) AS Decimal128(23, 2)) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,9 @@ order by s_name;
Projection: #part.p_partkey AS p_partkey, alias=__sq_1
Filter: #part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name], partial_filters=[#part.p_name LIKE Utf8("forest%")]
Projection: #lineitem.l_partkey, #lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(#SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]]
Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
Filter: #lineitem.l_shipdate >= Date32("8766")
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
.to_string();
assert_eq!(actual, expected);
Expand Down Expand Up @@ -393,7 +393,7 @@ order by cntrycode;"#;
TableScan: orders projection=[o_custkey]
Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
.to_string();
assert_eq!(actual, expected);
Expand Down Expand Up @@ -453,7 +453,7 @@ order by value desc;
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: #nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")]
Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1
Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
Expand Down
69 changes: 51 additions & 18 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Optimizer rule for type validation and coercion
use crate::simplify_expressions::ConstEvaluator;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
Expand All @@ -26,6 +27,7 @@ use datafusion_expr::type_coercion::data_types;
use datafusion_expr::utils::from_plan;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::{ExprSchemable, Signature};
use datafusion_physical_expr::execution_props::ExecutionProps;
use std::sync::Arc;

#[derive(Default)]
Expand Down Expand Up @@ -64,8 +66,14 @@ impl OptimizerRule for TypeCoercion {
},
);

let mut execution_props = ExecutionProps::new();
execution_props.query_execution_start_time =
optimizer_config.query_execution_start_time;
let const_evaluator = ConstEvaluator::try_new(&execution_props)?;

let mut expr_rewrite = TypeCoercionRewriter {
schema: Arc::new(schema),
const_evaluator,
};

let new_expr = plan
Expand All @@ -78,11 +86,12 @@ impl OptimizerRule for TypeCoercion {
}
}

struct TypeCoercionRewriter {
struct TypeCoercionRewriter<'a> {
schema: DFSchemaRef,
const_evaluator: ConstEvaluator<'a>,
}

impl ExprRewriter for TypeCoercionRewriter {
impl ExprRewriter for TypeCoercionRewriter<'_> {
fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}
Expand All @@ -106,15 +115,17 @@ impl ExprRewriter for TypeCoercionRewriter {
}
_ => {
let coerced_type = coerce_types(&left_type, &op, &right_type)?;
Ok(Expr::BinaryExpr {
let expr = Expr::BinaryExpr {
left: Box::new(
left.clone().cast_to(&coerced_type, &self.schema)?,
),
op,
right: Box::new(
right.clone().cast_to(&coerced_type, &self.schema)?,
),
})
};

expr.rewrite(&mut self.const_evaluator)
}
}
}
Expand All @@ -133,23 +144,25 @@ impl ExprRewriter for TypeCoercionRewriter {
expr_type, low_type
))
})?;
Ok(Expr::Between {
let expr = Expr::Between {
expr: Box::new(expr.cast_to(&coerced_type, &self.schema)?),
negated,
low: Box::new(low.cast_to(&coerced_type, &self.schema)?),
high: Box::new(high.cast_to(&coerced_type, &self.schema)?),
})
};
expr.rewrite(&mut self.const_evaluator)
}
Expr::ScalarUDF { fun, args } => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
Ok(Expr::ScalarUDF {
let expr = Expr::ScalarUDF {
fun,
args: new_expr,
})
};
expr.rewrite(&mut self.const_evaluator)
}
expr => Ok(expr),
}
Expand Down Expand Up @@ -188,7 +201,8 @@ mod test {
use crate::type_coercion::TypeCoercion;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, Result, ScalarValue};
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::{col, ColumnarValue};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
Expand All @@ -199,28 +213,40 @@ mod test {

#[test]
fn simple_case() -> Result<()> {
let expr = lit(1.2_f64).lt(lit(2_u32));
let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
schema: Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(None, "a", DataType::Float64, true)],
std::collections::HashMap::new(),
)
.unwrap(),
),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n EmptyRelation",
"Projection: #a < Float64(2)\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

#[test]
fn nested_case() -> Result<()> {
let expr = lit(1.2_f64).lt(lit(2_u32));
let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
schema: Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(None, "a", DataType::Float64, true)],
std::collections::HashMap::new(),
)
.unwrap(),
),
}));
let plan = LogicalPlan::Projection(Projection::try_new(
vec![expr.clone().or(expr)],
Expand All @@ -230,8 +256,11 @@ mod test {
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR Float64(1.2) < CAST(UInt32(2) AS Float64)\
\n EmptyRelation", &format!("{:?}", plan));
assert_eq!(
"Projection: #a < Float64(2) OR #a < Float64(2)\
\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

Expand All @@ -240,7 +269,11 @@ mod test {
let empty = empty();
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!());
let fun: ScalarFunctionImplementation = Arc::new(move |_| {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
"a".to_string(),
))))
});
let udf = Expr::ScalarUDF {
fun: Arc::new(ScalarUDF::new(
"TestScalarUDF",
Expand All @@ -255,7 +288,7 @@ mod test {
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation",
"Projection: Utf8(\"a\")\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
Expand Down
Loading

0 comments on commit f48a997

Please sign in to comment.