From f48a9971bb4822427395d641c071b7ea825ec496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 12 Sep 2022 21:29:05 +0200 Subject: [PATCH] Evaluate expressions after type coercion (#3444) * Evaluate expressions after type coercion * Fix some explains * Fix some explains * Fix some explains * Update test * Update test * Update test * Update more tests * Fix tests * Use supported date string --- datafusion/core/tests/sql/aggregates.rs | 10 +- datafusion/core/tests/sql/decimal.rs | 114 +++++++++--------- datafusion/core/tests/sql/explain_analyze.rs | 2 +- datafusion/core/tests/sql/subqueries.rs | 8 +- datafusion/optimizer/src/type_coercion.rs | 69 ++++++++--- .../optimizer/tests/integration-test.rs | 6 +- 6 files changed, 121 insertions(+), 88 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 357addbc0e21..b7f24992cb65 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> { assert_eq!(results.len(), 1); let expected = vec![ - "+--------------+-------------------------+-------------------------+-------------------------+", - "| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |", - "+--------------+-------------------------+-------------------------+-------------------------+", - "| 1.5 | 2.5 | 3.5 | 2.5 |", - "+--------------+-------------------------+-------------------------+-------------------------+", + "+--------------+---------------------------+---------------------------+---------------------------+", + "| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) + Float64(2) | Float64(1) + AVG(test.c1) |", + "+--------------+---------------------------+---------------------------+---------------------------+", + "| 1.5 | 2.5 | 3.5 | 2.5 |", + "+--------------+---------------------------+---------------------------+---------------------------+", ]; assert_batches_sorted_eq!(expected, &results); diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs index 7c74cdd52f0e..db686deb7070 100644 --- a/datafusion/core/tests/sql/decimal.rs +++ b/datafusion/core/tests/sql/decimal.rs @@ -376,25 +376,25 @@ async fn decimal_arithmetic_op() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+------------------------------+", - "| decimal_simple.c1 + Int64(1) |", - "+------------------------------+", - "| 1.000010 |", - "| 1.000020 |", - "| 1.000020 |", - "| 1.000030 |", - "| 1.000030 |", - "| 1.000030 |", - "| 1.000040 |", - "| 1.000040 |", - "| 1.000040 |", - "| 1.000040 |", - "| 1.000050 |", - "| 1.000050 |", - "| 1.000050 |", - "| 1.000050 |", - "| 1.000050 |", - "+------------------------------+", + "+----------------------------------------------------+", + "| decimal_simple.c1 + Decimal128(Some(1000000),27,6) |", + "+----------------------------------------------------+", + "| 1.000010 |", + "| 1.000020 |", + "| 1.000020 |", + "| 1.000030 |", + "| 1.000030 |", + "| 1.000030 |", + "| 1.000040 |", + "| 1.000040 |", + "| 1.000040 |", + "| 1.000040 |", + "| 1.000050 |", + "| 1.000050 |", + "| 1.000050 |", + "| 1.000050 |", + "| 1.000050 |", + "+----------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); // array decimal(10,6) + array decimal(12,7) => decimal(13,7) @@ -434,25 +434,25 @@ async fn decimal_arithmetic_op() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+------------------------------+", - "| decimal_simple.c1 - Int64(1) |", - "+------------------------------+", - "| -0.999990 |", - "| -0.999980 |", - "| -0.999980 |", - "| -0.999970 |", - "| -0.999970 |", - "| -0.999970 |", - "| -0.999960 |", - "| -0.999960 |", - "| -0.999960 |", - "| -0.999960 |", - "| -0.999950 |", - "| -0.999950 |", - "| -0.999950 |", - "| -0.999950 |", - "| -0.999950 |", - "+------------------------------+", + "+----------------------------------------------------+", + "| decimal_simple.c1 - Decimal128(Some(1000000),27,6) |", + "+----------------------------------------------------+", + "| -0.999990 |", + "| -0.999980 |", + "| -0.999980 |", + "| -0.999970 |", + "| -0.999970 |", + "| -0.999970 |", + "| -0.999960 |", + "| -0.999960 |", + "| -0.999960 |", + "| -0.999960 |", + "| -0.999950 |", + "| -0.999950 |", + "| -0.999950 |", + "| -0.999950 |", + "| -0.999950 |", + "+----------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -492,25 +492,25 @@ async fn decimal_arithmetic_op() -> Result<()> { actual[0].schema().field(0).data_type() ); let expected = vec![ - "+-------------------------------+", - "| decimal_simple.c1 * Int64(20) |", - "+-------------------------------+", - "| 0.000200 |", - "| 0.000400 |", - "| 0.000400 |", - "| 0.000600 |", - "| 0.000600 |", - "| 0.000600 |", - "| 0.000800 |", - "| 0.000800 |", - "| 0.000800 |", - "| 0.000800 |", - "| 0.001000 |", - "| 0.001000 |", - "| 0.001000 |", - "| 0.001000 |", - "| 0.001000 |", - "+-------------------------------+", + "+-----------------------------------------------------+", + "| decimal_simple.c1 * Decimal128(Some(20000000),31,6) |", + "+-----------------------------------------------------+", + "| 0.000200 |", + "| 0.000400 |", + "| 0.000400 |", + "| 0.000600 |", + "| 0.000600 |", + "| 0.000600 |", + "| 0.000800 |", + "| 0.000800 |", + "| 0.000800 |", + "| 0.000800 |", + "| 0.001000 |", + "| 0.001000 |", + "| 0.001000 |", + "| 0.001000 |", + "| 0.001000 |", + "+-----------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 91dd9401ee1f..7f465c4c697f 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -653,7 +653,7 @@ order by let expected = "\ Sort: #revenue DESC NULLS FIRST\ \n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\ - \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(CAST(Int64(1) AS Decimal128(23, 2)) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\ + \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\ \n Inner Join: #customer.c_nationkey = #nation.n_nationkey\ \n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\ \n Inner Join: #customer.c_custkey = #orders.o_custkey\ diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 0d9fe37f9a1e..1ae5bc68e0d6 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -336,9 +336,9 @@ order by s_name; Projection: #part.p_partkey AS p_partkey, alias=__sq_1 Filter: #part.p_name LIKE Utf8("forest%") TableScan: part projection=[p_partkey, p_name], partial_filters=[#part.p_name LIKE Utf8("forest%")] - Projection: #lineitem.l_partkey, #lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3 + Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(#SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3 Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]] - Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32) + Filter: #lineitem.l_shipdate >= Date32("8766") TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"# .to_string(); assert_eq!(actual, expected); @@ -393,7 +393,7 @@ order by cntrycode;"#; TableScan: orders projection=[o_custkey] Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1 Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]] - Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) + Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"# .to_string(); assert_eq!(actual, expected); @@ -453,7 +453,7 @@ order by value desc; TableScan: supplier projection=[s_suppkey, s_nationkey] Filter: #nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")] - Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1 + Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1 Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]] Inner Join: #supplier.s_nationkey = #nation.n_nationkey Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 77580c0632e3..72ee5d19adcd 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -17,6 +17,7 @@ //! Optimizer rule for type validation and coercion +use crate::simplify_expressions::ConstEvaluator; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; @@ -26,6 +27,7 @@ use datafusion_expr::type_coercion::data_types; use datafusion_expr::utils::from_plan; use datafusion_expr::{Expr, LogicalPlan}; use datafusion_expr::{ExprSchemable, Signature}; +use datafusion_physical_expr::execution_props::ExecutionProps; use std::sync::Arc; #[derive(Default)] @@ -64,8 +66,14 @@ impl OptimizerRule for TypeCoercion { }, ); + let mut execution_props = ExecutionProps::new(); + execution_props.query_execution_start_time = + optimizer_config.query_execution_start_time; + let const_evaluator = ConstEvaluator::try_new(&execution_props)?; + let mut expr_rewrite = TypeCoercionRewriter { schema: Arc::new(schema), + const_evaluator, }; let new_expr = plan @@ -78,11 +86,12 @@ impl OptimizerRule for TypeCoercion { } } -struct TypeCoercionRewriter { +struct TypeCoercionRewriter<'a> { schema: DFSchemaRef, + const_evaluator: ConstEvaluator<'a>, } -impl ExprRewriter for TypeCoercionRewriter { +impl ExprRewriter for TypeCoercionRewriter<'_> { fn pre_visit(&mut self, _expr: &Expr) -> Result { Ok(RewriteRecursion::Continue) } @@ -106,7 +115,7 @@ impl ExprRewriter for TypeCoercionRewriter { } _ => { let coerced_type = coerce_types(&left_type, &op, &right_type)?; - Ok(Expr::BinaryExpr { + let expr = Expr::BinaryExpr { left: Box::new( left.clone().cast_to(&coerced_type, &self.schema)?, ), @@ -114,7 +123,9 @@ impl ExprRewriter for TypeCoercionRewriter { right: Box::new( right.clone().cast_to(&coerced_type, &self.schema)?, ), - }) + }; + + expr.rewrite(&mut self.const_evaluator) } } } @@ -133,12 +144,13 @@ impl ExprRewriter for TypeCoercionRewriter { expr_type, low_type )) })?; - Ok(Expr::Between { + let expr = Expr::Between { expr: Box::new(expr.cast_to(&coerced_type, &self.schema)?), negated, low: Box::new(low.cast_to(&coerced_type, &self.schema)?), high: Box::new(high.cast_to(&coerced_type, &self.schema)?), - }) + }; + expr.rewrite(&mut self.const_evaluator) } Expr::ScalarUDF { fun, args } => { let new_expr = coerce_arguments_for_signature( @@ -146,10 +158,11 @@ impl ExprRewriter for TypeCoercionRewriter { &self.schema, &fun.signature, )?; - Ok(Expr::ScalarUDF { + let expr = Expr::ScalarUDF { fun, args: new_expr, - }) + }; + expr.rewrite(&mut self.const_evaluator) } expr => Ok(expr), } @@ -188,7 +201,8 @@ mod test { use crate::type_coercion::TypeCoercion; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; - use datafusion_common::{DFSchema, Result, ScalarValue}; + use datafusion_common::{DFField, DFSchema, Result, ScalarValue}; + use datafusion_expr::{col, ColumnarValue}; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, @@ -199,17 +213,23 @@ mod test { #[test] fn simple_case() -> Result<()> { - let expr = lit(1.2_f64).lt(lit(2_u32)); + let expr = col("a").lt(lit(2_u32)); let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: Arc::new(DFSchema::empty()), + schema: Arc::new( + DFSchema::new_with_metadata( + vec![DFField::new(None, "a", DataType::Float64, true)], + std::collections::HashMap::new(), + ) + .unwrap(), + ), })); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?); let rule = TypeCoercion::new(); let mut config = OptimizerConfig::default(); let plan = rule.optimize(&plan, &mut config)?; assert_eq!( - "Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n EmptyRelation", + "Projection: #a < Float64(2)\n EmptyRelation", &format!("{:?}", plan) ); Ok(()) @@ -217,10 +237,16 @@ mod test { #[test] fn nested_case() -> Result<()> { - let expr = lit(1.2_f64).lt(lit(2_u32)); + let expr = col("a").lt(lit(2_u32)); let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: Arc::new(DFSchema::empty()), + schema: Arc::new( + DFSchema::new_with_metadata( + vec![DFField::new(None, "a", DataType::Float64, true)], + std::collections::HashMap::new(), + ) + .unwrap(), + ), })); let plan = LogicalPlan::Projection(Projection::try_new( vec![expr.clone().or(expr)], @@ -230,8 +256,11 @@ mod test { let rule = TypeCoercion::new(); let mut config = OptimizerConfig::default(); let plan = rule.optimize(&plan, &mut config)?; - assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR Float64(1.2) < CAST(UInt32(2) AS Float64)\ - \n EmptyRelation", &format!("{:?}", plan)); + assert_eq!( + "Projection: #a < Float64(2) OR #a < Float64(2)\ + \n EmptyRelation", + &format!("{:?}", plan) + ); Ok(()) } @@ -240,7 +269,11 @@ mod test { let empty = empty(); let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); - let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); + let fun: ScalarFunctionImplementation = Arc::new(move |_| { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "a".to_string(), + )))) + }); let udf = Expr::ScalarUDF { fun: Arc::new(ScalarUDF::new( "TestScalarUDF", @@ -255,7 +288,7 @@ mod test { let mut config = OptimizerConfig::default(); let plan = rule.optimize(&plan, &mut config)?; assert_eq!( - "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation", + "Projection: Utf8(\"a\")\n EmptyRelation", &format!("{:?}", plan) ); Ok(()) diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 87a0bab68a40..cb31600b97f6 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -83,7 +83,7 @@ fn between_date32_plus_interval() -> Result<()> { let plan = test_sql(sql)?; let expected = "Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n Filter: #test.col_date32 >= CAST(Utf8(\"1998-03-18\") AS Date32) AND #test.col_date32 <= Date32(\"10393\")\ + \n Filter: #test.col_date32 >= Date32(\"10303\") AND #test.col_date32 <= Date32(\"10393\")\ \n TableScan: test projection=[col_date32]"; assert_eq!(expected, format!("{:?}", plan)); Ok(()) @@ -92,11 +92,11 @@ fn between_date32_plus_interval() -> Result<()> { #[test] fn between_date64_plus_interval() -> Result<()> { let sql = "SELECT count(1) FROM test \ - WHERE col_date64 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; + WHERE col_date64 between '1998-03-18T00:00:00' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; let plan = test_sql(sql)?; let expected = "Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n Filter: #test.col_date64 >= CAST(Utf8(\"1998-03-18\") AS Date64) AND #test.col_date64 <= CAST(Date32(\"10393\") AS Date64)\ + \n Filter: #test.col_date64 >= Date64(\"890179200000\") AND #test.col_date64 <= Date64(\"897955200000\")\ \n TableScan: test projection=[col_date64]"; assert_eq!(expected, format!("{:?}", plan)); Ok(())