diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 8aa812cc5258..0621522b07f9 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -1951,6 +1951,10 @@ impl SimplifyInfo for SessionSimplifyProvider<'_> { fn get_data_type(&self, expr: &Expr) -> datafusion_common::Result { expr.get_type(self.df_schema) } + + fn get_schema(&self) -> Option<&DFSchema> { + Some(self.df_schema) + } } #[derive(Debug)] diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 53199294709a..915553e22979 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2238,7 +2238,7 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }"; + let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5), metadata: None }, fail_on_overflow: false }"; assert!(format!("{exec_plan:?}").contains(expected)); Ok(()) } @@ -2263,7 +2263,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: 
"c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), metadata: None }, "c1"), (Literal { value: Int64(NULL), metadata: None }, "c2"), (Literal { value: Int64(NULL), metadata: None }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; assert_eq!(format!("{cube:?}"), expected); @@ -2290,7 +2290,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), metadata: None }, "c1"), (Literal { value: Int64(NULL), metadata: None }, "c2"), (Literal { value: Int64(NULL), metadata: None }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; assert_eq!(format!("{rollup:?}"), expected); @@ -2474,7 +2474,7 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. 
- let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; + let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\"), metadata: None }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\"), metadata: None }, fail_on_overflow: false }, fail_on_overflow: false }"; let actual = format!("{execution_plan:?}"); assert!(actual.contains(expected), "{}", actual); diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 34e0487f312f..75554a5694e1 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -23,7 +23,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, TimeZone, Utc}; use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*}; use datafusion_common::cast::as_int32_array; -use datafusion_common::ScalarValue; +use datafusion_common::{DFSchema, ScalarValue}; use datafusion_common::{DFSchemaRef, ToDFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::builder::table_scan_with_filters; @@ -71,6 +71,10 @@ impl SimplifyInfo for MyInfo { fn get_data_type(&self, expr: &Expr) -> Result { expr.get_type(self.schema.as_ref()) } + + fn get_schema(&self) -> Option<&DFSchema> { + Some(self.schema.as_ref()) + } } impl From for MyInfo { diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index b5960ae5bd8d..4cdee5248d46 100644 --- 
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -20,7 +20,7 @@ use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use arrow::array::{as_string_array, record_batch, Int8Array, UInt64Array}; +use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, @@ -1527,6 +1527,57 @@ async fn test_metadata_based_udf() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_metadata_based_udf_with_literal() -> Result<()> { + let ctx = SessionContext::new(); + let df = ctx.sql("select 0;").await?.select(vec![ + lit(5u64).alias_with_metadata( + "lit_with_doubling", + Some( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ), + lit(5u64).alias("lit_no_doubling"), + ])?; + + let output_metadata: HashMap = + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(); + let custom_udf = ScalarUDF::from(MetadataBasedUdf::new(output_metadata.clone())); + + let plan = LogicalPlanBuilder::from(df.into_optimized_plan()?) + .project(vec![ + custom_udf + .call(vec![col("lit_with_doubling")]) + .alias("doubled_output"), + custom_udf + .call(vec![col("lit_no_doubling")]) + .alias("not_doubled_output"), + ])? 
+ .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + let schema = Arc::new(Schema::new(vec![ + Field::new("doubled_output", DataType::UInt64, true) + .with_metadata(output_metadata.clone()), + Field::new("not_doubled_output", DataType::UInt64, true) + .with_metadata(output_metadata.clone()), + ])); + + let expected = RecordBatch::try_new( + schema, + vec![create_array!(UInt64, [10]), create_array!(UInt64, [5])], + )?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + /// This UDF is to test extension handling, both on the input and output /// sides. For the input, we will handle the data differently if there is /// the canonical extension type Bool8. For the output we will add a diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 92d4497918fa..8332629c9beb 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1512,8 +1512,16 @@ impl Expr { |expr| { // f_up: unalias on up so we can remove nested aliases like // `(x as foo) as bar` - if let Expr::Alias(Alias { expr, .. }) = expr { - Ok(Transformed::yes(*expr)) + if let Expr::Alias(alias) = expr { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } } else { Ok(Transformed::no(expr)) } diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 467ce8bf53e2..0a9aad1a49cb 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -18,7 +18,7 @@ //! Structs and traits to provide the information needed for expression simplification. 
use arrow::datatypes::DataType; -use datafusion_common::{DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; use crate::{execution_props::ExecutionProps, Expr, ExprSchemable}; @@ -40,6 +40,9 @@ pub trait SimplifyInfo { /// Returns data type of this expr needed for determining optimized int type of a value fn get_data_type(&self, expr: &Expr) -> Result; + + /// Returns the Schema which may be needed to evaluate an Expr + fn get_schema(&self) -> Option<&DFSchema>; } /// Provides simplification information based on DFSchema and @@ -106,6 +109,10 @@ impl SimplifyInfo for SimplifyContext<'_> { fn execution_props(&self) -> &ExecutionProps { self.props } + + fn get_schema(&self) -> Option<&DFSchema> { + self.schema.as_ref().map(|v| v.as_ref()) + } } /// Was the expression simplified? diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 1093c0b3cf69..a4bfb22d93d7 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -583,8 +583,18 @@ fn is_expr_trivial(expr: &Expr) -> bool { fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { expr.transform_up(|expr| { match expr { - // remove any intermediate aliases - Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + // remove any intermediate aliases if they do not carry metadata + Expr::Alias(alias) => { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } + } Expr::Column(col) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index ad4fbb7bde63..7e27c5ab0929 100644 ---
a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -33,8 +33,8 @@ use datafusion_common::{ }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, - Operator, Volatility, WindowFunctionDefinition, + and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, + ExprSchemable, Like, Operator, Volatility, WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ @@ -172,6 +172,9 @@ impl ExprSimplifier { /// fn get_data_type(&self, expr: &Expr) -> Result { /// Ok(DataType::Int32) /// } + /// fn get_schema(&self) -> Option<&DFSchema> { + /// None + /// } /// } /// /// // Create the simplifier @@ -227,12 +230,17 @@ impl ExprSimplifier { mut expr: Expr, ) -> Result<(Transformed, u32)> { let mut simplifier = Simplifier::new(&self.info); - let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; + + let empty_schema = DFSchema::empty(); + let mut const_evaluator = ConstEvaluator::try_new( + self.info.execution_props(), + self.info.get_schema().unwrap_or(&empty_schema), + )?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); if self.canonicalize { - expr = expr.rewrite(&mut Canonicalizer::new()).data()? 
+ expr = expr.rewrite(&mut Canonicalizer::new()).data()?; } // Evaluating constants can enable new simplifications and @@ -249,14 +257,17 @@ impl ExprSimplifier { .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; expr = data; num_cycles += 1; + // Track if any transformation occurred has_transformed = has_transformed || transformed; if !transformed || num_cycles >= self.max_simplifier_cycles { break; } } + // shorten inlist should be started after other inlist rules are applied expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?; + Ok(( Transformed::new_transformed(expr, has_transformed), num_cycles, @@ -540,7 +551,7 @@ impl TreeNodeRewriter for ConstEvaluator<'_> { // stack as not ok (as all parents have at least one child or // descendant that can not be evaluated - if !Self::can_evaluate(&expr) { + if !self.can_evaluate(&expr) { // walk back up stack, marking first parent that is not mutable let parent_iter = self.can_evaluate.iter_mut().rev(); for p in parent_iter { @@ -586,12 +597,16 @@ impl<'a> ConstEvaluator<'a> { /// Create a new `ConstantEvaluator`. Session constants (such as /// the time for `now()` are taken from the passed /// `execution_props`. 
- pub fn try_new(execution_props: &'a ExecutionProps) -> Result { + pub fn try_new( + execution_props: &'a ExecutionProps, + input_schema: &DFSchema, + ) -> Result { // The dummy column name is unused and doesn't matter as only // expressions without column references can be evaluated static DUMMY_COL_NAME: &str = "."; + let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]); - let input_schema = DFSchema::try_from(schema.clone())?; + let input_schema = input_schema.clone(); // Need a single "input" row to produce a single output row let col = new_null_array(&DataType::Null, 1); let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col])?; @@ -616,7 +631,17 @@ impl<'a> ConstEvaluator<'a> { /// Can the expression be evaluated at plan time, (assuming all of /// its children can also be evaluated)? - fn can_evaluate(expr: &Expr) -> bool { + fn can_evaluate(&self, expr: &Expr) -> bool { + // Any time we have metadata on the return value we cannot + // simplify to a literal because we will lose it. + let false = expr + .metadata(&self.input_schema) + .map(|metadata| !metadata.is_empty()) + .unwrap_or(false) + else { + return false; + }; + // check for reasons we can't evaluate this node // // NOTE all expr types are listed here so when new ones are @@ -640,8 +665,8 @@ impl<'a> ConstEvaluator<'a> { Self::volatility_ok(func.signature().volatility) } Expr::Literal(_) - | Expr::Alias(..) | Expr::Unnest(_) + | Expr::Alias(_) | Expr::BinaryExpr { .. 
} | Expr::Not(_) | Expr::IsNotNull(_) @@ -4326,16 +4351,23 @@ mod tests { #[derive(Debug, Clone)] struct SimplifyMockUdaf { simplify: bool, + signature: Signature, } impl SimplifyMockUdaf { /// make simplify method return new expression fn new_with_simplify() -> Self { - Self { simplify: true } + Self { + simplify: true, + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), + } } /// make simplify method return no change fn new_without_simplify() -> Self { - Self { simplify: false } + Self { + simplify: false, + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), + } } } @@ -4349,7 +4381,7 @@ mod tests { } fn signature(&self) -> &Signature { - unimplemented!() + &self.signature } fn return_type(&self, _arg_types: &[DataType]) -> Result { @@ -4409,16 +4441,24 @@ mod tests { #[derive(Debug, Clone)] struct SimplifyMockUdwf { simplify: bool, + signature: Signature, } impl SimplifyMockUdwf { /// make simplify method return new expression fn new_with_simplify() -> Self { - Self { simplify: true } + Self { + simplify: true, + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), + } } + /// make simplify method return no change fn new_without_simplify() -> Self { - Self { simplify: false } + Self { + simplify: false, + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), + } } } @@ -4432,7 +4472,7 @@ mod tests { } fn signature(&self) -> &Signature { - unimplemented!() + &self.signature } fn simplify(&self) -> Option { diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index 9785203a7020..0820e0009f86 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -342,7 +342,7 @@ mod test { ) .unwrap(); let snap = dynamic_filter_1.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: 
Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42), metadata: None }, fail_on_overflow: false }"#); let dynamic_filter_2 = reassign_predicate_columns( Arc::clone(&dynamic_filter) as Arc, &filter_schema_2, @@ -350,7 +350,7 @@ mod test { ) .unwrap(); let snap = dynamic_filter_2.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42), metadata: None }, fail_on_overflow: false }"#); // Both filters allow evaluating the same expression let batch_1 = RecordBatch::try_new( Arc::clone(&filter_schema_1), diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 469f7bbee317..8d96fbc8c02c 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -1451,7 +1451,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a IN (a, b)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), metadata: None }, Literal { value: Utf8(\"b\"), metadata: None }])"); // Test: a NOT IN ('a', 'b') let list = vec![lit("a"), lit("b")]; @@ -1459,7 +1459,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a NOT IN (a, b)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { 
value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), metadata: None }, Literal { value: Utf8(\"b\"), metadata: None }])"); // Test: a IN ('a', 'b', NULL) let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; @@ -1467,7 +1467,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a IN (a, b, NULL)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); + assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), metadata: None }, Literal { value: Utf8(\"b\"), metadata: None }, Literal { value: Utf8(NULL), metadata: None }])"); // Test: a NOT IN ('a', 'b', NULL) let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; @@ -1475,7 +1475,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a NOT IN (a, b, NULL)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); + assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), metadata: None }, Literal { value: Utf8(\"b\"), metadata: None }, Literal { value: Utf8(NULL), metadata: None }])"); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 6f7caaea8d45..cf469a678bd8 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,11 +18,13 @@ //! 
Literal expressions for physical operations use std::any::Any; +use std::collections::{BTreeMap, HashMap}; use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -34,15 +36,37 @@ use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq)] pub struct Literal { value: ScalarValue, + metadata: Option>, +} + +impl Hash for Literal { + fn hash(&self, state: &mut H) { + let ordered_metadata = self + .metadata + .as_ref() + .map(|m| m.iter().collect::>()); + self.value.hash(state); + ordered_metadata.hash(state); + } } impl Literal { /// Create a literal value expression pub fn new(value: ScalarValue) -> Self { - Self { value } + Self { + value, + metadata: None, + } + } + + pub fn new_with_metadata( + value: ScalarValue, + metadata: Option>, + ) -> Self { + Self { value, metadata } } /// Get the scalar value @@ -71,6 +95,20 @@ impl PhysicalExpr for Literal { Ok(self.value.is_null()) } + fn return_field(&self, _input_schema: &Schema) -> Result { + let mut field = Field::new( + format!("{self}"), + self.value.data_type(), + self.value.is_null(), + ); + + if let Some(metadata) = &self.metadata { + field = field.with_metadata(metadata.clone()); + } + + Ok(field) + } + fn evaluate(&self, _batch: &RecordBatch) -> Result { Ok(ColumnarValue::Scalar(self.value.clone())) } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8660bff796d5..b577b6e4434b 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -32,7 +32,7 @@ use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; use datafusion_expr::var_provider::is_system_variables; use 
datafusion_expr::var_provider::VarType; use datafusion_expr::{ - binary_expr, lit, Between, BinaryExpr, Expr, Like, Operator, TryCast, + binary_expr, lit, Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, }; /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 @@ -109,16 +109,29 @@ pub fn create_physical_expr( execution_props: &ExecutionProps, ) -> Result> { let input_schema: &Schema = &input_dfschema.into(); + // let expr_field = e.to_field(input_dfschema)?.1; + let metadata = e.metadata(input_dfschema)?; + let metadata = match metadata.is_empty() { + true => None, + false => Some(metadata), + }; match e { Expr::Alias(Alias { expr, .. }) => { - Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + if let Expr::Literal(v) = expr.as_ref() { + Ok(Arc::new(Literal::new_with_metadata(v.clone(), metadata))) + } else { + Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + } } Expr::Column(c) => { let idx = input_dfschema.index_of_column(c)?; Ok(Arc::new(Column::new(&c.name, idx))) } - Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), + Expr::Literal(value) => Ok(Arc::new(Literal::new_with_metadata( + value.clone(), + metadata.clone(), + ))), Expr::ScalarVariable(_, variable_names) => { if is_system_variables(variable_names) { match execution_props.get_var_provider(VarType::System) { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 5340ddef1e81..64c7ccaedccb 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -6019,7 +6019,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: 
Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), metadata: None }, Literal { value: Utf8View("a"), metadata: None }, Literal { value: Utf8View("b"), metadata: None }, Literal { value: Utf8View("c"), metadata: None }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6048,7 +6048,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), metadata: None }, Literal { value: Utf8View("a"), metadata: None }, Literal { value: Utf8View("b"), metadata: None }, Literal { value: Utf8View("c"), metadata: None }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6077,7 +6077,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) 
+07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), metadata: None }, Literal { value: Utf8View("a"), metadata: None }, Literal { value: Utf8View("b"), metadata: None }, Literal { value: Utf8View("c"), metadata: None }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6137,7 +6137,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), metadata: None }, Literal { value: Utf8View("a"), metadata: None }, Literal { value: Utf8View("b"), metadata: None }, Literal { value: Utf8View("c"), metadata: None }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index 828bf967d8f4..6c30e80cb948 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -90,7 +90,7 @@ physical_plan 14)--------------------------CoalesceBatchesExec: target_batch_size=8192 15)----------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 
16)------------------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]) +17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13"), metadata: None }, Literal { value: Utf8View("31"), metadata: None }, Literal { value: Utf8View("23"), metadata: None }, Literal { value: Utf8View("29"), metadata: None }, Literal { value: Utf8View("30"), metadata: None }, Literal { value: Utf8View("18"), metadata: None }, Literal { value: Utf8View("17"), metadata: None }]) 18)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 19)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false 20)--------------------------CoalesceBatchesExec: target_batch_size=8192 @@ -100,6 +100,6 @@ physical_plan 24)----------------------CoalescePartitionsExec 25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] 26)--------------------------CoalesceBatchesExec: target_batch_size=8192 -27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]), projection=[c_acctbal@1] +27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: 
Utf8View("13"), metadata: None }, Literal { value: Utf8View("31"), metadata: None }, Literal { value: Utf8View("23"), metadata: None }, Literal { value: Utf8View("29"), metadata: None }, Literal { value: Utf8View("30"), metadata: None }, Literal { value: Utf8View("18"), metadata: None }, Literal { value: Utf8View("17"), metadata: None }]), projection=[c_acctbal@1] 28)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 29)--------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false