From 3cabee14ca819dd80c89dbce6ae168d0df18e069 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Tue, 8 Feb 2022 23:06:03 +0800 Subject: [PATCH] split expr type and null info to be expr-schemable (#1784) --- datafusion/src/logical_plan/builder.rs | 1 + datafusion/src/logical_plan/expr.rs | 202 +-------------- datafusion/src/logical_plan/expr_rewriter.rs | 1 + datafusion/src/logical_plan/expr_schema.rs | 231 ++++++++++++++++++ datafusion/src/logical_plan/mod.rs | 2 + .../src/optimizer/common_subexpr_eliminate.rs | 2 +- .../src/optimizer/simplify_expressions.rs | 8 +- .../optimizer/single_distinct_to_groupby.rs | 1 + datafusion/tests/simplification.rs | 1 + 9 files changed, 245 insertions(+), 204 deletions(-) create mode 100644 datafusion/src/logical_plan/expr_schema.rs diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index d81fa9d2afa66..a722238059f50 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -25,6 +25,7 @@ use crate::datasource::{ MemTable, TableProvider, }; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::expr_schema::ExprSchemable; use crate::logical_plan::plan::{ Aggregate, Analyze, EmptyRelation, Explain, Filter, Join, Projection, Sort, TableScan, ToStringifiedPlan, Union, Window, diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 69da346aee8de..f19e9d8d6a35e 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,16 +20,13 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; -use crate::field_util::get_indexed_field; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::{window_frames, DFField, DFSchema}; use crate::physical_plan::functions::Volatility; -use crate::physical_plan::{ - aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, - window_functions, -}; +use crate::physical_plan::{aggregates, functions, udf::ScalarUDF, window_functions}; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; -use arrow::{compute::can_cast_types, datatypes::DataType}; +use arrow::datatypes::DataType; pub use datafusion_common::{Column, ExprSchema}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; use std::collections::HashSet; @@ -251,151 +248,6 @@ impl PartialOrd for Expr { } impl Expr { - /// Returns the [arrow::datatypes::DataType] of the expression - /// based on [ExprSchema] - /// - /// Note: [DFSchema] implements [ExprSchema]. - /// - /// # Errors - /// - /// This function errors when it is not possible to compute its - /// [arrow::datatypes::DataType]. This happens when e.g. the - /// expression refers to a column that does not exist in the - /// schema, or when the expression is incorrectly typed - /// (e.g. `[utf8] + [bool]`). - pub fn get_type(&self, schema: &S) -> Result { - match self { - Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => { - expr.get_type(schema) - } - Expr::Column(c) => Ok(schema.data_type(c)?.clone()), - Expr::ScalarVariable(_) => Ok(DataType::Utf8), - Expr::Literal(l) => Ok(l.get_datatype()), - Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), - Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => { - Ok(data_type.clone()) - } - Expr::ScalarUDF { fun, args } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) - } - Expr::ScalarFunction { fun, args } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - functions::return_type(fun, &data_types) - } - Expr::WindowFunction { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - window_functions::return_type(fun, &data_types) - } - Expr::AggregateFunction { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - aggregates::return_type(fun, &data_types) - } - Expr::AggregateUDF { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) - } - Expr::Not(_) - | Expr::IsNull(_) - | Expr::Between { .. } - | Expr::InList { .. } - | Expr::IsNotNull(_) => Ok(DataType::Boolean), - Expr::BinaryExpr { - ref left, - ref right, - ref op, - } => binary_operator_data_type( - &left.get_type(schema)?, - op, - &right.get_type(schema)?, - ), - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - Expr::GetIndexedField { ref expr, key } => { - let data_type = expr.get_type(schema)?; - - get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) - } - } - } - - /// Returns the nullability of the expression based on [ExprSchema]. - /// - /// Note: [DFSchema] implements [ExprSchema]. - /// - /// # Errors - /// - /// This function errors when it is not possible to compute its - /// nullability. This happens when the expression refers to a - /// column that does not exist in the schema. - pub fn nullable(&self, input_schema: &S) -> Result { - match self { - Expr::Alias(expr, _) - | Expr::Not(expr) - | Expr::Negative(expr) - | Expr::Sort { expr, .. } - | Expr::Between { expr, .. } - | Expr::InList { expr, .. } => expr.nullable(input_schema), - Expr::Column(c) => input_schema.nullable(c), - Expr::Literal(value) => Ok(value.is_null()), - Expr::Case { - when_then_expr, - else_expr, - .. - } => { - // this expression is nullable if any of the input expressions are nullable - let then_nullable = when_then_expr - .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { - Ok(true) - } else if let Some(e) = else_expr { - e.nullable(input_schema) - } else { - Ok(false) - } - } - Expr::Cast { expr, .. } => expr.nullable(input_schema), - Expr::ScalarVariable(_) - | Expr::TryCast { .. } - | Expr::ScalarFunction { .. } - | Expr::ScalarUDF { .. } - | Expr::WindowFunction { .. } - | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } => Ok(true), - Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false), - Expr::BinaryExpr { - ref left, - ref right, - .. - } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - Expr::GetIndexedField { ref expr, key } => { - let data_type = expr.get_type(input_schema)?; - get_indexed_field(&data_type, key).map(|x| x.is_nullable()) - } - } - } - /// Returns the name of this expression based on [crate::logical_plan::DFSchema]. /// /// This represents how a column with this expression is named when no alias is chosen @@ -403,54 +255,6 @@ impl Expr { create_name(self, input_schema) } - /// Returns a [arrow::datatypes::Field] compatible with this expression. - pub fn to_field(&self, input_schema: &DFSchema) -> Result { - match self { - Expr::Column(c) => Ok(DFField::new( - c.relation.as_deref(), - &c.name, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )), - _ => Ok(DFField::new( - None, - &self.name(input_schema)?, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )), - } - } - - /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. - /// - /// # Errors - /// - /// This function errors when it is impossible to cast the - /// expression to the target [arrow::datatypes::DataType]. - pub fn cast_to( - self, - cast_to_type: &DataType, - schema: &S, - ) -> Result { - // TODO(kszucs): most of the operations do not validate the type correctness - // like all of the binary expressions below. Perhaps Expr should track the - // type of the expression? - let this_type = self.get_type(schema)?; - if this_type == *cast_to_type { - Ok(self) - } else if can_cast_types(&this_type, cast_to_type) { - Ok(Expr::Cast { - expr: Box::new(self), - data_type: cast_to_type.clone(), - }) - } else { - Err(DataFusionError::Plan(format!( - "Cannot automatically convert {:?} to {:?}", - this_type, cast_to_type - ))) - } - } - /// Return `self == other` pub fn eq(self, other: Expr) -> Expr { binary_expr(self, Operator::Eq, other) diff --git a/datafusion/src/logical_plan/expr_rewriter.rs b/datafusion/src/logical_plan/expr_rewriter.rs index d452dcd4c4261..5062d5fce7ad7 100644 --- a/datafusion/src/logical_plan/expr_rewriter.rs +++ b/datafusion/src/logical_plan/expr_rewriter.rs @@ -20,6 +20,7 @@ use super::Expr; use crate::logical_plan::plan::Aggregate; use crate::logical_plan::DFSchema; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::LogicalPlan; use datafusion_common::Column; use datafusion_common::Result; diff --git a/datafusion/src/logical_plan/expr_schema.rs b/datafusion/src/logical_plan/expr_schema.rs new file mode 100644 index 0000000000000..2e44c72415c96 --- /dev/null +++ b/datafusion/src/logical_plan/expr_schema.rs @@ -0,0 +1,231 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::Expr; +use crate::field_util::get_indexed_field; +use crate::physical_plan::{ + aggregates, expressions::binary_operator_data_type, functions, window_functions, +}; +use arrow::compute::can_cast_types; +use arrow::datatypes::DataType; +use datafusion_common::{DFField, DFSchema, DataFusionError, ExprSchema, Result}; + +/// trait to allow expr to typable with respect to a schema +pub trait ExprSchemable { + /// given a schema, return the type of the expr + fn get_type(&self, schema: &S) -> Result; + + /// given a schema, return the nullability of the expr + fn nullable(&self, input_schema: &S) -> Result; + + /// convert to a field with respect to a schema + fn to_field(&self, input_schema: &DFSchema) -> Result; + + /// cast to a type with respect to a schema + fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result; +} + +impl ExprSchemable for Expr { + /// Returns the [arrow::datatypes::DataType] of the expression + /// based on [ExprSchema] + /// + /// Note: [DFSchema] implements [ExprSchema]. + /// + /// # Errors + /// + /// This function errors when it is not possible to compute its + /// [arrow::datatypes::DataType]. This happens when e.g. the + /// expression refers to a column that does not exist in the + /// schema, or when the expression is incorrectly typed + /// (e.g. `[utf8] + [bool]`). + fn get_type(&self, schema: &S) -> Result { + match self { + Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => { + expr.get_type(schema) + } + Expr::Column(c) => Ok(schema.data_type(c)?.clone()), + Expr::ScalarVariable(_) => Ok(DataType::Utf8), + Expr::Literal(l) => Ok(l.get_datatype()), + Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), + Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => { + Ok(data_type.clone()) + } + Expr::ScalarUDF { fun, args } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + Ok((fun.return_type)(&data_types)?.as_ref().clone()) + } + Expr::ScalarFunction { fun, args } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + functions::return_type(fun, &data_types) + } + Expr::WindowFunction { fun, args, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + window_functions::return_type(fun, &data_types) + } + Expr::AggregateFunction { fun, args, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + aggregates::return_type(fun, &data_types) + } + Expr::AggregateUDF { fun, args, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + Ok((fun.return_type)(&data_types)?.as_ref().clone()) + } + Expr::Not(_) + | Expr::IsNull(_) + | Expr::Between { .. } + | Expr::InList { .. } + | Expr::IsNotNull(_) => Ok(DataType::Boolean), + Expr::BinaryExpr { + ref left, + ref right, + ref op, + } => binary_operator_data_type( + &left.get_type(schema)?, + op, + &right.get_type(schema)?, + ), + Expr::Wildcard => Err(DataFusionError::Internal( + "Wildcard expressions are not valid in a logical query plan".to_owned(), + )), + Expr::GetIndexedField { ref expr, key } => { + let data_type = expr.get_type(schema)?; + + get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) + } + } + } + + /// Returns the nullability of the expression based on [ExprSchema]. + /// + /// Note: [DFSchema] implements [ExprSchema]. + /// + /// # Errors + /// + /// This function errors when it is not possible to compute its + /// nullability. This happens when the expression refers to a + /// column that does not exist in the schema. + fn nullable(&self, input_schema: &S) -> Result { + match self { + Expr::Alias(expr, _) + | Expr::Not(expr) + | Expr::Negative(expr) + | Expr::Sort { expr, .. } + | Expr::Between { expr, .. } + | Expr::InList { expr, .. } => expr.nullable(input_schema), + Expr::Column(c) => input_schema.nullable(c), + Expr::Literal(value) => Ok(value.is_null()), + Expr::Case { + when_then_expr, + else_expr, + .. + } => { + // this expression is nullable if any of the input expressions are nullable + let then_nullable = when_then_expr + .iter() + .map(|(_, t)| t.nullable(input_schema)) + .collect::>>()?; + if then_nullable.contains(&true) { + Ok(true) + } else if let Some(e) = else_expr { + e.nullable(input_schema) + } else { + Ok(false) + } + } + Expr::Cast { expr, .. } => expr.nullable(input_schema), + Expr::ScalarVariable(_) + | Expr::TryCast { .. } + | Expr::ScalarFunction { .. } + | Expr::ScalarUDF { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateFunction { .. } + | Expr::AggregateUDF { .. } => Ok(true), + Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false), + Expr::BinaryExpr { + ref left, + ref right, + .. + } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), + Expr::Wildcard => Err(DataFusionError::Internal( + "Wildcard expressions are not valid in a logical query plan".to_owned(), + )), + Expr::GetIndexedField { ref expr, key } => { + let data_type = expr.get_type(input_schema)?; + get_indexed_field(&data_type, key).map(|x| x.is_nullable()) + } + } + } + + /// Returns a [arrow::datatypes::Field] compatible with this expression. + fn to_field(&self, input_schema: &DFSchema) -> Result { + match self { + Expr::Column(c) => Ok(DFField::new( + c.relation.as_deref(), + &c.name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + )), + _ => Ok(DFField::new( + None, + &self.name(input_schema)?, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + )), + } + } + + /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. + /// + /// # Errors + /// + /// This function errors when it is impossible to cast the + /// expression to the target [arrow::datatypes::DataType]. + fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result { + // TODO(kszucs): most of the operations do not validate the type correctness + // like all of the binary expressions below. Perhaps Expr should track the + // type of the expression? + let this_type = self.get_type(schema)?; + if this_type == *cast_to_type { + Ok(self) + } else if can_cast_types(&this_type, cast_to_type) { + Ok(Expr::Cast { + expr: Box::new(self), + data_type: cast_to_type.clone(), + }) + } else { + Err(DataFusionError::Plan(format!( + "Cannot automatically convert {:?} to {:?}", + this_type, cast_to_type + ))) + } + } +} diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 085775a2eb8c5..f2ecb0f762788 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -26,6 +26,7 @@ mod dfschema; mod display; mod expr; mod expr_rewriter; +mod expr_schema; mod expr_simplier; mod expr_visitor; mod extension; @@ -54,6 +55,7 @@ pub use expr_rewriter::{ normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs, unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, RewriteRecursion, }; +pub use expr_schema::ExprSchemable; pub use expr_simplier::{ExprSimplifiable, SimplifyInfo}; pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; pub use extension::UserDefinedLogicalNode; diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index 5c2219b3d99a2..2ed45be25bc17 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -23,7 +23,7 @@ use crate::logical_plan::plan::{Filter, Projection, Window}; use crate::logical_plan::{ col, plan::{Aggregate, Sort}, - DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprVisitable, + DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprSchemable, ExprVisitable, ExpressionVisitor, LogicalPlan, Recursion, RewriteRecursion, }; use crate::optimizer::optimizer::OptimizerRule; diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index f8f3df44b673e..4e9709bd9b5fe 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -17,12 +17,9 @@ //! Simplify expressions optimizer rule -use arrow::array::new_null_array; -use arrow::datatypes::{DataType, Field, Schema}; -use arrow::record_batch::RecordBatch; - use crate::error::DataFusionError; use crate::execution::context::ExecutionProps; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::{ lit, DFSchema, DFSchemaRef, Expr, ExprRewritable, ExprRewriter, ExprSimplifiable, LogicalPlan, RewriteRecursion, SimplifyInfo, @@ -33,6 +30,9 @@ use crate::physical_plan::functions::Volatility; use crate::physical_plan::planner::create_physical_expr; use crate::scalar::ScalarValue; use crate::{error::Result, logical_plan::Operator}; +use arrow::array::new_null_array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; /// Provides simplification information based on schema and properties struct SimplifyContext<'a, 'b> { diff --git a/datafusion/src/optimizer/single_distinct_to_groupby.rs b/datafusion/src/optimizer/single_distinct_to_groupby.rs index 02a24e2144958..2e0bd5ff0549a 100644 --- a/datafusion/src/optimizer/single_distinct_to_groupby.rs +++ b/datafusion/src/optimizer/single_distinct_to_groupby.rs @@ -20,6 +20,7 @@ use crate::error::Result; use crate::execution::context::ExecutionProps; use crate::logical_plan::plan::{Aggregate, Projection}; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::{col, columnize_expr, DFSchema, Expr, LogicalPlan}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; diff --git a/datafusion/tests/simplification.rs b/datafusion/tests/simplification.rs index 0ce8e7685b83a..fe5f5e254b523 100644 --- a/datafusion/tests/simplification.rs +++ b/datafusion/tests/simplification.rs @@ -18,6 +18,7 @@ //! This program demonstrates the DataFusion expression simplification API. use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::logical_plan::ExprSchemable; use datafusion::logical_plan::ExprSimplifiable; use datafusion::{ error::Result,