diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 911bbf34b06f..f07dac649df9 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -322,7 +322,7 @@ pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { fn call(&self, exprs: &[Expr]) -> Result> { let filename = match exprs.first() { - Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") _ => { return plan_err!( diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 089b8db6a5a0..92cf33f4fdf6 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -65,7 +65,7 @@ async fn main() -> Result<()> { let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, - Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(5)), None)), )); assert_eq!(expr, expr2); diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index 63f17484809e..176b1a69808c 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -171,7 +171,7 @@ fn is_binary_eq(binary_expr: &BinaryExpr) -> bool { /// Return true if the expression is a literal or column reference fn is_lit_or_col(expr: &Expr) -> bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) } /// A simple user defined filter function diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index d2b2d1bf9655..b65ffb8d7174 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -133,7 +133,8 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)), _)) = exprs.first() + else { return plan_err!("read_csv requires at least one string argument"); }; @@ -145,7 +146,7 @@ impl TableFunctionImpl for LocalCsvTableFunc { let info = SimplifyContext::new(&execution_props); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; - if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + if let Expr::Literal(ScalarValue::Int64(Some(limit)), _) = expr { Ok(limit as usize) } else { plan_err!("Limit must be an integer") diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 037c69cebd57..00e9c71df348 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -61,7 +61,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { Ok(TreeNodeRecursion::Stop) } } - Expr::Literal(_) + Expr::Literal(_, _) | Expr::Alias(_) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) @@ -346,8 +346,8 @@ fn populate_partition_values<'a>( { match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(Column { ref name, .. }), Expr::Literal(val)) - | (Expr::Literal(val), Expr::Column(Column { ref name, .. })) => { + (Expr::Column(Column { ref name, .. }), Expr::Literal(val, _)) + | (Expr::Literal(val, _), Expr::Column(Column { ref name, .. })) => { if partition_values .insert(name, PartitionValue::Single(val.to_string())) .is_some() @@ -984,7 +984,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))], + &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3)), None))], ), Some(Path::from("a=1970-01-04")), ); @@ -993,9 +993,10 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date64(Some( - 4 * 24 * 60 * 60 * 1000 - )))),], + &[col("a").eq(Expr::Literal( + ScalarValue::Date64(Some(4 * 24 * 60 * 60 * 1000)), + None + )),], ), Some(Path::from("a=1970-01-05")), ); diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index 97d47fc3b907..063b8e6c86bb 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -71,8 +71,11 @@ fn criterion_benchmark(c: &mut Criterion) { let mut value_buffer = Vec::new(); for i in 0..1000 { - key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + key_buffer.push(Expr::Literal( + ScalarValue::Utf8(Some(keys[i].clone())), + None, + )); + value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } c.bench_function("map_1000_1", |b| { b.iter(|| { diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 69992e57ca7d..02a18f22c916 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1337,7 +1337,10 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? + .aggregate( + vec![], + vec![count(Expr::Literal(COUNT_STAR_EXPANSION, None))], + )? .collect() .await?; let len = *rows diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 3c87d3ee2329..0dd6cd38ba53 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -2230,7 +2230,7 @@ mod tests { let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( Box::new(Expr::Column("column1".into())), Operator::GtEq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(0)), None)), )); // Create a new batch of data to insert into the table diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5ef666b61e54..dbe5c2c00f17 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1214,7 +1214,7 @@ impl SessionContext { let mut params: Vec = parameters .into_iter() .map(|e| match e { - Expr::Literal(scalar) => Ok(scalar), + Expr::Literal(scalar, _) => Ok(scalar), _ => not_impl_err!("Unsupported parameter type: {}", e), }) .collect::>()?; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 61d1fee79472..c65fcb4c4c93 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2257,7 +2257,8 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }"; + let expected = r#"BinaryExpr { left: Column { name: "c7", index: 2 }, op: Lt, right: Literal { value: Int64(5), field: Field { name: "5", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#; + assert!(format!("{exec_plan:?}").contains(expected)); Ok(()) } @@ -2282,7 +2283,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), field: Field { name: "NULL", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c1"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c2"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; assert_eq!(format!("{cube:?}"), expected); @@ -2309,7 +2310,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), field: Field { name: "NULL", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c1"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c2"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; assert_eq!(format!("{rollup:?}"), expected); @@ -2493,7 +2494,7 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; + let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\"), field: Field { name: \"1\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, fail_on_overflow: false }"; let actual = format!("{execution_plan:?}"); assert!(actual.contains(expected), "{}", actual); diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index f68bcfaf1550..c80c0b4bf54b 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -179,12 +179,12 @@ impl TableProvider for CustomProvider { match &filters[0] { Expr::BinaryExpr(BinaryExpr { right, .. }) => { let int_value = match &**right { - Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int64(Some(i))) => *i, + Expr::Literal(ScalarValue::Int8(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { - Expr::Literal(lit_value) => match lit_value { + Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, ScalarValue::Int32(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index aa36de1e555f..c737d0f9c3b0 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1210,7 +1210,7 @@ async fn join_on_filter_datatype() -> Result<()> { let join = left.clone().join_on( right.clone(), JoinType::Inner, - Some(Expr::Literal(ScalarValue::Null)), + Some(Expr::Literal(ScalarValue::Null, None)), )?; assert_snapshot!(join.into_optimized_plan().unwrap(), @"EmptyRelation"); @@ -4527,7 +4527,10 @@ async fn consecutive_projection_same_schema() -> Result<()> { // Add `t` column full of nulls let df = df - .with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32)) + .with_column( + "t", + cast(Expr::Literal(ScalarValue::Null, None), DataType::Int32), + ) .unwrap(); df.clone().show().await.unwrap(); diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index 97bb2a727bbf..f5a8a30e0130 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -47,9 +47,9 @@ async fn count_only_nulls() -> Result<()> { let input = Arc::new(LogicalPlan::Values(Values { schema: input_schema, values: vec![ - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], ], })); let input_col_ref = Expr::Column(Column { diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 34e0487f312f..91a507bdf7f0 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -282,10 +282,13 @@ fn select_date_plus_interval() -> Result<()> { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema)? - + Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 123, - milliseconds: 0, - }))); + + Expr::Literal( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 123, + milliseconds: 0, + })), + None, + ); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![date_plus_interval_expr])? diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index 1fc6d14c5b22..07d289cab06c 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -56,7 +56,7 @@ impl ExprPlanner for MyCustomPlanner { } BinaryOperator::Question => { Ok(PlannerResult::Planned(Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), + Expr::Literal(ScalarValue::Boolean(Some(true)), None), None::<&str>, format!("{} ? {}", expr.left, expr.right), )))) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index b68ef6aca093..4d3916c1760e 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -912,11 +912,12 @@ impl MyAnalyzerRule { .map(|e| { e.transform(|e| { Ok(match e { - Expr::Literal(ScalarValue::Int64(i)) => { + Expr::Literal(ScalarValue::Int64(i), _) => { // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) + Transformed::yes(Expr::Literal( + ScalarValue::UInt64(i.map(|i| i as u64)), + None, + )) } _ => Transformed::no(e), }) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 25458efa4fa5..3e8fafc7a636 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -20,7 +20,7 @@ use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use arrow::array::{as_string_array, record_batch, Int8Array, UInt64Array}; +use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, @@ -42,9 +42,9 @@ use datafusion_common::{ }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, - OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, + lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, + LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -1529,6 +1529,65 @@ async fn test_metadata_based_udf() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_metadata_based_udf_with_literal() -> Result<()> { + let ctx = SessionContext::new(); + let input_metadata: HashMap = + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(); + let df = ctx.sql("select 0;").await?.select(vec![ + lit(5u64).alias_with_metadata("lit_with_doubling", Some(input_metadata.clone())), + lit(5u64).alias("lit_no_doubling"), + lit_with_metadata(5u64, Some(input_metadata)) + .alias("lit_with_double_no_alias_metadata"), + ])?; + + let output_metadata: HashMap = + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(); + let custom_udf = ScalarUDF::from(MetadataBasedUdf::new(output_metadata.clone())); + + let plan = LogicalPlanBuilder::from(df.into_optimized_plan()?) + .project(vec![ + custom_udf + .call(vec![col("lit_with_doubling")]) + .alias("doubled_output"), + custom_udf + .call(vec![col("lit_no_doubling")]) + .alias("not_doubled_output"), + custom_udf + .call(vec![col("lit_with_double_no_alias_metadata")]) + .alias("double_without_alias_metadata"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + let schema = Arc::new(Schema::new(vec![ + Field::new("doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("not_doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("double_without_alias_metadata", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + ])); + + let expected = RecordBatch::try_new( + schema, + vec![ + create_array!(UInt64, [10]), + create_array!(UInt64, [5]), + create_array!(UInt64, [10]), + ], + )?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + /// This UDF is to test extension handling, both on the input and output /// sides. For the input, we will handle the data differently if there is /// the canonical extension type Bool8. For the output we will add a diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index e4aff0b00705..2c6611f382ce 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -205,7 +205,7 @@ impl TableFunctionImpl for SimpleCsvTableFunc { let mut filepath = String::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + Expr::Literal(ScalarValue::Utf8(Some(ref path)), _) => { filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index cde9e56c9280..db455fed6160 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -557,6 +557,7 @@ mod test { // Test all should fail let expr = col("timestamp_col").lt(Expr::Literal( ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))), + None, )); let expr = logical2physical(&expr, &table_schema); let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); @@ -597,6 +598,7 @@ mod test { // Test all should pass let expr = col("timestamp_col").gt(Expr::Literal( ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))), + None, )); let expr = logical2physical(&expr, &table_schema); let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); @@ -660,7 +662,7 @@ mod test { let expr = col("string_col") .is_not_null() - .or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5))))); + .or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5)), None))); let expr = logical2physical(&expr, &table_schema); assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index d44fa1684320..f9fb9214429d 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -1242,12 +1242,16 @@ mod tests { .run( lit("1").eq(lit("1")).and( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( - "Hello_Not_Exists", - ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("Hello_Not_Exists2")), - )))), + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello_Not_Exists"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from( + "Hello_Not_Exists2", + ))), + None, + ))), ), ) .await @@ -1327,15 +1331,18 @@ mod tests { // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` .run( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( - "Hello", - ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("the quick")), - )))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("are you")), - )))), + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("the quick"))), + None, + ))) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("are you"))), + None, + ))), ) .await } diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 9cb51612d0ca..69525ea52137 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -72,7 +72,7 @@ impl CaseBuilder { let then_types: Vec = then_expr .iter() .map(|e| match e { - Expr::Literal(_) => e.get_type(&DFSchema::empty()), + Expr::Literal(_, _) => e.get_type(&DFSchema::empty()), _ => Ok(DataType::Null), }) .collect::>>()?; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index dcd5380b4859..f379edf10584 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,7 +17,8 @@ //! Logical Expressions: [`Expr`] -use std::collections::HashSet; +use std::cmp::Ordering; +use std::collections::{BTreeMap, HashSet}; use std::fmt::{self, Display, Formatter, Write}; use std::hash::{Hash, Hasher}; use std::mem; @@ -51,7 +52,7 @@ use sqlparser::ast::{ /// BinaryExpr { /// left: Expr::Column("A"), /// op: Operator::Plus, -/// right: Expr::Literal(ScalarValue::Int32(Some(1))) +/// right: Expr::Literal(ScalarValue::Int32(Some(1)), None) /// } /// ``` /// @@ -113,10 +114,10 @@ use sqlparser::ast::{ /// # use datafusion_expr::{lit, col, Expr}; /// // All literals are strongly typed in DataFusion. To make an `i64` 42: /// let expr = lit(42i64); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); +/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)), None)); +/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)), None)); /// // To make a (typed) NULL: -/// let expr = Expr::Literal(ScalarValue::Int64(None)); +/// let expr = Expr::Literal(ScalarValue::Int64(None), None); /// // to make an (untyped) NULL (the optimizer will coerce this to the correct type): /// let expr = lit(ScalarValue::Null); /// ``` @@ -150,7 +151,7 @@ use sqlparser::ast::{ /// if let Expr::BinaryExpr(binary_expr) = expr { /// assert_eq!(*binary_expr.left, col("c1")); /// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*binary_expr.right, Expr::Literal(scalar)); +/// assert_eq!(*binary_expr.right, Expr::Literal(scalar, None)); /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` @@ -194,7 +195,7 @@ use sqlparser::ast::{ /// ``` /// # use datafusion_expr::{lit, col}; /// let expr = col("c1") + lit(42); -/// assert_eq!(format!("{expr:?}"), "BinaryExpr(BinaryExpr { left: Column(Column { relation: None, name: \"c1\" }), op: Plus, right: Literal(Int32(42)) })"); +/// assert_eq!(format!("{expr:?}"), "BinaryExpr(BinaryExpr { left: Column(Column { relation: None, name: \"c1\" }), op: Plus, right: Literal(Int32(42), None) })"); /// ``` /// /// ## Use the `Display` trait (detailed expression) @@ -240,7 +241,7 @@ use sqlparser::ast::{ /// let mut scalars = HashSet::new(); /// // apply recursively visits all nodes in the expression tree /// expr.apply(|e| { -/// if let Expr::Literal(scalar) = e { +/// if let Expr::Literal(scalar, _) = e { /// scalars.insert(scalar); /// } /// // The return value controls whether to continue visiting the tree @@ -275,7 +276,7 @@ use sqlparser::ast::{ /// assert!(rewritten.transformed); /// // to 42 = 5 AND b = 6 /// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub enum Expr { /// An expression with a specific name. Alias(Alias), @@ -283,8 +284,8 @@ pub enum Expr { Column(Column), /// A named reference to a variable in a registry. ScalarVariable(DataType, Vec), - /// A constant value. - Literal(ScalarValue), + /// A constant value along with associated metadata + Literal(ScalarValue, Option>), /// A binary expression such as "age > 21" BinaryExpr(BinaryExpr), /// LIKE expression @@ -368,7 +369,7 @@ pub enum Expr { impl Default for Expr { fn default() -> Self { - Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null, None) } } @@ -450,13 +451,13 @@ impl Hash for Alias { } impl PartialOrd for Alias { - fn partial_cmp(&self, other: &Self) -> Option { + fn partial_cmp(&self, other: &Self) -> Option { let cmp = self.expr.partial_cmp(&other.expr); - let Some(std::cmp::Ordering::Equal) = cmp else { + let Some(Ordering::Equal) = cmp else { return cmp; }; let cmp = self.relation.partial_cmp(&other.relation); - let Some(std::cmp::Ordering::Equal) = cmp else { + let Some(Ordering::Equal) = cmp else { return cmp; }; self.name.partial_cmp(&other.name) @@ -1537,8 +1538,16 @@ impl Expr { |expr| { // f_up: unalias on up so we can remove nested aliases like // `(x as foo) as bar` - if let Expr::Alias(Alias { expr, .. }) = expr { - Ok(Transformed::yes(*expr)) + if let Expr::Alias(alias) = expr { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } } else { Ok(Transformed::no(expr)) } @@ -2299,7 +2308,7 @@ impl HashNode for Expr { data_type.hash(state); name.hash(state); } - Expr::Literal(scalar_value) => { + Expr::Literal(scalar_value, _) => { scalar_value.hash(state); } Expr::BinaryExpr(BinaryExpr { @@ -2479,7 +2488,7 @@ impl Display for SchemaDisplay<'_> { // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Column(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::ScalarVariable(..) | Expr::OuterReferenceColumn(..) | Expr::Placeholder(_) @@ -2738,7 +2747,7 @@ struct SqlDisplay<'a>(&'a Expr); impl Display for SqlDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self.0 { - Expr::Literal(scalar) => scalar.fmt(f), + Expr::Literal(scalar, _) => scalar.fmt(f), Expr::Alias(Alias { name, .. }) => write!(f, "{name}"), Expr::Between(Between { expr, @@ -3005,7 +3014,12 @@ impl Display for Expr { write!(f, "{OUTER_REFERENCE_COLUMN_PREFIX}({c})") } Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), - Expr::Literal(v) => write!(f, "{v:?}"), + Expr::Literal(v, metadata) => { + match metadata.as_ref().map(|m| m.is_empty()).unwrap_or(true) { + false => write!(f, "{v:?} {:?}", metadata.as_ref().unwrap()), + true => write!(f, "{v:?}"), + } + } Expr::Case(case) => { write!(f, "CASE ")?; if let Some(e) = &case.expr { @@ -3376,7 +3390,7 @@ mod test { #[allow(deprecated)] fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), + expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), data_type: DataType::Utf8, }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 5182ccb15c0a..e8885ed6b724 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -690,17 +690,17 @@ impl WindowUDFImpl for SimpleWindowUDF { pub fn interval_year_month_lit(value: &str) -> Expr { let interval = parse_interval_year_month(value).ok(); - Expr::Literal(ScalarValue::IntervalYearMonth(interval)) + Expr::Literal(ScalarValue::IntervalYearMonth(interval), None) } pub fn interval_datetime_lit(value: &str) -> Expr { let interval = parse_interval_day_time(value).ok(); - Expr::Literal(ScalarValue::IntervalDayTime(interval)) + Expr::Literal(ScalarValue::IntervalDayTime(interval), None) } pub fn interval_month_day_nano_lit(value: &str) -> Expr { let interval = parse_interval_month_day_nano(value).ok(); - Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) + Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None) } /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 90dcbce46b01..f80b8e5a7759 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -354,6 +354,7 @@ mod test { use std::ops::Add; use super::*; + use crate::literal::lit_with_metadata; use crate::{col, lit, Cast}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeRewriter; @@ -383,13 +384,17 @@ mod test { // rewrites all "foo" string literals to "bar" let transformer = |expr: Expr| -> Result> { match expr { - Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { + Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), metadata) => { let utf8_val = if utf8_val == "foo" { "bar".to_string() } else { utf8_val }; - Ok(Transformed::yes(lit(utf8_val))) + Ok(Transformed::yes(lit_with_metadata( + utf8_val, + metadata + .map(|m| m.into_iter().collect::>()), + ))) } // otherwise, return None _ => Ok(Transformed::no(expr)), diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index bdf9911b006c..1973a00a67df 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -115,7 +115,7 @@ impl ExprSchemable for Expr { Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), - Expr::Literal(l) => Ok(l.data_type()), + Expr::Literal(l, _) => Ok(l.data_type()), Expr::Case(case) => { for (_, then_expr) in &case.when_then_expr { let then_type = then_expr.get_type(schema)?; @@ -278,7 +278,7 @@ impl ExprSchemable for Expr { Expr::Column(c) => input_schema.nullable(c), Expr::OuterReferenceColumn(_, _) => Ok(true), - Expr::Literal(value) => Ok(value.is_null()), + Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { // This expression is nullable if any of the input expressions are nullable let then_nullable = case @@ -420,11 +420,18 @@ impl ExprSchemable for Expr { Expr::ScalarVariable(ty, _) => { Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) } - Expr::Literal(l) => Ok(Arc::new(Field::new( - &schema_name, - l.data_type(), - l.is_null(), - ))), + Expr::Literal(l, metadata) => { + let mut field = Field::new(&schema_name, l.data_type(), l.is_null()); + if let Some(metadata) = metadata { + field = field.with_metadata( + metadata + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(), + ); + } + Ok(Arc::new(field)) + } Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -533,7 +540,7 @@ impl ExprSchemable for Expr { let arguments = args .iter() .map(|e| match e { - Expr::Literal(sv) => Some(sv), + Expr::Literal(sv, _) => Some(sv), _ => None, }) .collect::>(); diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 48931d6525af..1f44f755b214 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -94,7 +94,9 @@ pub use function::{ AccumulatorFactoryFunction, PartitionEvaluatorFactory, ReturnTypeFunction, ScalarFunctionImplementation, StateTypeFunction, }; -pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; +pub use literal::{ + lit, lit_timestamp_nano, lit_with_metadata, Literal, TimestampLiteral, +}; pub use logical_plan::*; pub use partition_evaluator::PartitionEvaluator; pub use sqlparser; diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index 90ba5a9a693c..48e058b8b7b1 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -19,12 +19,37 @@ use crate::Expr; use datafusion_common::ScalarValue; +use std::collections::HashMap; /// Create a literal expression pub fn lit(n: T) -> Expr { n.lit() } +pub fn lit_with_metadata( + n: T, + metadata: impl Into>>, +) -> Expr { + let metadata = metadata.into(); + let Some(metadata) = metadata else { + return n.lit(); + }; + + let Expr::Literal(sv, prior_metadata) = n.lit() else { + unreachable!(); + }; + + let new_metadata = match prior_metadata { + Some(mut prior) => { + prior.extend(metadata); + prior + } + None => metadata.into_iter().collect(), + }; + + Expr::Literal(sv, Some(new_metadata)) +} + /// Create a literal timestamp expression pub fn lit_timestamp_nano(n: T) -> Expr { n.lit_timestamp_nano() @@ -43,37 +68,37 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(*self)) + Expr::Literal(ScalarValue::from(*self), None) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::Literal(ScalarValue::from(self.as_ref()), None) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::Literal(ScalarValue::from(self.as_ref()), None) } } impl Literal for Vec { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())), None) } } impl Literal for &[u8] { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())), None) } } impl Literal for ScalarValue { fn lit(&self) -> Expr { - Expr::Literal(self.clone()) + Expr::Literal(self.clone(), None) } } @@ -82,7 +107,7 @@ macro_rules! make_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) + Expr::Literal(ScalarValue::$SCALAR(Some(self.clone())), None) } } }; @@ -93,7 +118,7 @@ macro_rules! make_nonzero_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.get()))) + Expr::Literal(ScalarValue::$SCALAR(Some(self.get())), None) } } }; @@ -104,10 +129,10 @@ macro_rules! make_timestamp_literal { #[doc = $DOC] impl TimestampLiteral for $TYPE { fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond( - Some((self.clone()).into()), + Expr::Literal( + ScalarValue::TimestampNanosecond(Some((self.clone()).into()), None), None, - )) + ) } } }; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index f75e79cd6672..533e81e64f29 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -341,8 +341,11 @@ impl LogicalPlanBuilder { // wrap cast if data type is not same as common type. for row in &mut values { for (j, field_type) in fields.iter().map(|f| f.data_type()).enumerate() { - if let Expr::Literal(ScalarValue::Null) = row[j] { - row[j] = Expr::Literal(ScalarValue::try_from(field_type)?); + if let Expr::Literal(ScalarValue::Null, metadata) = &row[j] { + row[j] = Expr::Literal( + ScalarValue::try_from(field_type)?, + metadata.clone(), + ); } else { row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?; } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 691f5684a11c..5bc07cf6213e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1305,7 +1305,7 @@ impl LogicalPlan { // Empty group_expr will return Some(1) if group_expr .iter() - .all(|expr| matches!(expr, Expr::Literal(_))) + .all(|expr| matches!(expr, Expr::Literal(_, _))) { Some(1) } else { @@ -1455,7 +1455,7 @@ impl LogicalPlan { let transformed_expr = e.transform_up(|e| { if let Expr::Placeholder(Placeholder { id, .. }) = e { let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::Literal(value, None))) } else { Ok(Transformed::no(e)) } @@ -2698,7 +2698,9 @@ impl Union { { expr.push(Expr::Column(column)); } else { - expr.push(Expr::Literal(ScalarValue::Null).alias(column.name())); + expr.push( + Expr::Literal(ScalarValue::Null, None).alias(column.name()), + ); } } wrapped_inputs.push(Arc::new(LogicalPlan::Projection( @@ -3224,7 +3226,7 @@ impl Limit { pub fn get_skip_type(&self) -> Result { match self.skip.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(s)) => { + Expr::Literal(ScalarValue::Int64(s), _) => { // `skip = NULL` is equivalent to `skip = 0` let s = s.unwrap_or(0); if s >= 0 { @@ -3244,14 +3246,16 @@ impl Limit { pub fn get_fetch_type(&self) -> Result { match self.fetch.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => { + Expr::Literal(ScalarValue::Int64(Some(s)), _) => { if s >= 0 { Ok(FetchType::Literal(Some(s as usize))) } else { plan_err!("LIMIT must be >= 0, '{}' was provided", s) } } - Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)), + Expr::Literal(ScalarValue::Int64(None), _) => { + Ok(FetchType::Literal(None)) + } _ => Ok(FetchType::UnsupportedExpr), }, None => Ok(FetchType::Literal(None)), @@ -4539,7 +4543,7 @@ mod tests { let col = schema.field_names()[0].clone(); let filter = Filter::try_new( - Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)), None)), scan, ) .unwrap(); @@ -4666,12 +4670,14 @@ mod tests { skip: None, fetch: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), input: Arc::clone(&input), }), LogicalPlan::Limit(Limit { skip: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), fetch: None, input: Arc::clone(&input), @@ -4679,9 +4685,11 @@ mod tests { LogicalPlan::Limit(Limit { skip: Some(Box::new(Expr::Literal( ScalarValue::new_one(&DataType::UInt32).unwrap(), + None, ))), fetch: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), input, }), diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index bfdc19394576..f953aec5a1e3 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -73,7 +73,7 @@ impl TreeNode for Expr { // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } @@ -126,7 +126,7 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) => Transformed::no(self), + | Expr::Literal(_, _) => Transformed::no(self), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 6f44e37d0523..b7851e530099 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -276,7 +276,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { Expr::Unnest(_) | Expr::ScalarVariable(_, _) | Expr::Alias(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::BinaryExpr { .. } | Expr::Like { .. } | Expr::SimilarTo { .. } @@ -785,7 +785,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( indexes.push(idx); } } - Expr::Literal(_) => { + Expr::Literal(_, _) => { indexes.push(usize::MAX); } _ => {} diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs index 08bc4d0cd83b..ceedec2599a2 100644 --- a/datafusion/ffi/src/udtf.rs +++ b/datafusion/ffi/src/udtf.rs @@ -214,7 +214,7 @@ mod tests { let args = args .iter() .map(|arg| { - if let Expr::Literal(scalar) = arg { + if let Expr::Literal(scalar, _) = arg { Ok(scalar) } else { exec_err!("Expected only literal arguments to table udf") diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index f375a68d9458..6b7199c44b32 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -101,7 +101,7 @@ pub fn count_distinct(expr: Expr) -> Expr { /// let expr = col(expr.schema_name().to_string()); /// ``` pub fn count_all() -> Expr { - count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)") + count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)") } /// Creates window aggregation to count all rows. @@ -126,7 +126,7 @@ pub fn count_all() -> Expr { pub fn count_all_window() -> Expr { Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) } diff --git a/datafusion/functions-aggregate/src/planner.rs b/datafusion/functions-aggregate/src/planner.rs index c8cb84118995..f0e37f6b1dbe 100644 --- a/datafusion/functions-aggregate/src/planner.rs +++ b/datafusion/functions-aggregate/src/planner.rs @@ -100,7 +100,7 @@ impl ExprPlanner for AggregateFunctionPlanner { let new_expr = Expr::AggregateFunction(AggregateFunction::new_udf( func, - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], distinct, filter, order_by, diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index a752a47bcbaa..55dd7ad14460 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -58,8 +58,11 @@ fn criterion_benchmark(c: &mut Criterion) { let values = values(&mut rng); let mut buffer = Vec::new(); for i in 0..1000 { - buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + buffer.push(Expr::Literal( + ScalarValue::Utf8(Some(keys[i].clone())), + None, + )); + buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } let planner = NestedFunctionPlanner {}; diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 5ef1491313b1..3b9b705e72c5 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -133,7 +133,7 @@ impl ScalarUDFImpl for ArrayHas { // if the haystack is a constant list, we can use an inlist expression which is more // efficient because the haystack is not varying per-row - if let Expr::Literal(ScalarValue::List(array)) = haystack { + if let Expr::Literal(ScalarValue::List(array), _) = haystack { // TODO: support LargeList // (not supported by `convert_array_to_scalar_vec`) // (FixedSizeList not supported either, but seems to have worked fine when attempting to @@ -147,7 +147,7 @@ impl ScalarUDFImpl for ArrayHas { let list = scalar_values .into_iter() .flatten() - .map(Expr::Literal) + .map(|v| Expr::Literal(v, None)) .collect(); return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs index ee95567ab73d..ffb93cf59b16 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -199,8 +199,8 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { let mut normalize_args = Vec::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Null) => {} - Expr::Literal(ScalarValue::Int64(Some(n))) => normalize_args.push(*n), + Expr::Literal(ScalarValue::Null, _) => {} + Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n), _ => return plan_err!("First argument must be an integer literal"), }; } diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 8fca0114f65e..091737bb9c15 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -97,7 +97,7 @@ impl ExprPlanner for WindowFunctionPlanner { let new_expr = Expr::from(WindowFunction::new( func_def, - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) .partition_by(partition_by) .order_by(order_by) diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 2d769dfa5657..e9dee09e74bf 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -177,7 +177,7 @@ impl ScalarUDFImpl for ArrowCastFunc { fn data_type_from_args(args: &[Expr]) -> Result { let [_, type_arg] = take_function_args("arrow_cast", args)?; - let Expr::Literal(ScalarValue::Utf8(Some(val))) = type_arg else { + let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else { return exec_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", type_arg diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index de87308ef3c4..2f39132871bb 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -108,7 +108,7 @@ impl ScalarUDFImpl for GetFieldFunc { let [base, field_name] = take_function_args(self.name(), args)?; let name = match field_name { - Expr::Literal(name) => name, + Expr::Literal(name, _) => name, other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; @@ -118,7 +118,7 @@ impl ScalarUDFImpl for GetFieldFunc { fn schema_name(&self, args: &[Expr]) -> Result { let [base, field_name] = take_function_args(self.name(), args)?; let name = match field_name { - Expr::Literal(name) => name, + Expr::Literal(name, _) => name, other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 9998e7d3758e..2bda1f262abe 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -108,6 +108,7 @@ impl ScalarUDFImpl for CurrentDateFunc { ); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Date32(days), + None, ))) } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index c416d0240b13..9b9d3997e9d7 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -96,6 +96,7 @@ impl ScalarUDFImpl for CurrentTimeFunc { let nano = now_ts.timestamp_nanos_opt().map(|ts| ts % 86400000000000); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Time64Nanosecond(nano), + None, ))) } diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 30b4d4ca9c76..ffb3aed5a960 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -108,6 +108,7 @@ impl ScalarUDFImpl for NowFunc { .timestamp_nanos_opt(); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), + None, ))) } diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index ee52c035ac81..23e267a323b9 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -210,7 +210,9 @@ impl ScalarUDFImpl for LogFunc { }; match number { - Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => { + Expr::Literal(value, _) + if value == ScalarValue::new_one(&number_datatype)? => + { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( &info.get_data_type(&base)?, )?))) diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index bd1ae7c316c1..465844704f59 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -156,12 +156,15 @@ impl ScalarUDFImpl for PowerFunc { let exponent_type = info.get_data_type(&exponent)?; match exponent { - Expr::Literal(value) if value == ScalarValue::new_zero(&exponent_type)? => { + Expr::Literal(value, _) + if value == ScalarValue::new_zero(&exponent_type)? => + { Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::new_one(&info.get_data_type(&base)?)?, + None, ))) } - Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => { + Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(base)) } Expr::ScalarFunction(ScalarFunction { func, mut args }) diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 773c316422b7..64a527eac198 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -295,7 +295,7 @@ pub fn simplify_concat(args: Vec) -> Result { let data_types: Vec<_> = args .iter() .filter_map(|expr| match expr { - Expr::Literal(l) => Some(l.data_type()), + Expr::Literal(l, _) => Some(l.data_type()), _ => None, }) .collect(); @@ -304,25 +304,25 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { - Expr::Literal(ScalarValue::Utf8(None)) => {} - Expr::Literal(ScalarValue::LargeUtf8(None)) => { + Expr::Literal(ScalarValue::Utf8(None), _) => {} + Expr::Literal(ScalarValue::LargeUtf8(None), _) => { } - Expr::Literal(ScalarValue::Utf8View(None)) => { } + Expr::Literal(ScalarValue::Utf8View(None), _) => { } // filter out `null` args // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. - Expr::Literal(ScalarValue::Utf8(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { + Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(x) => { + Expr::Literal(x, _) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." ) diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 2a2f9429f8fc..1f45f8501e1f 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -312,6 +312,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { // when the delimiter is an empty string, @@ -336,8 +337,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None), _) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), _) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { @@ -347,7 +348,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result return internal_err!("The scalar {s} should be casted to string type during the type coercion."), + Expr::Literal(s, _) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` and reset it to None. // Then pushing this arg to the `new_args`. @@ -374,10 +375,11 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Utf8(None), + None, ))), } } - Expr::Literal(d) => internal_err!( + Expr::Literal(d, _) => internal_err!( "The scalar {d} should be casted to string type during the type coercion." ), _ => { @@ -394,7 +396,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index b74be1546626..215f8f7a25b9 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -191,8 +191,11 @@ mod test { #[test] fn test_contains_api() { let expr = contains( - Expr::Literal(ScalarValue::Utf8(Some("the quick brown fox".to_string()))), - Expr::Literal(ScalarValue::Utf8(Some("row".to_string()))), + Expr::Literal( + ScalarValue::Utf8(Some("the quick brown fox".to_string())), + None, + ), + Expr::Literal(ScalarValue::Utf8(Some("row".to_string())), None), ); assert_eq!( expr.to_string(), diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index a59d7080a580..ecab1af132e0 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -130,7 +130,7 @@ impl ScalarUDFImpl for StartsWithFunc { args: Vec, info: &dyn SimplifyInfo, ) -> Result { - if let Expr::Literal(scalar_value) = &args[1] { + if let Expr::Literal(scalar_value, _) = &args[1] { // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping // Example: starts_with(col, 'ja%') -> col LIKE 'ja\%%' // 1. 'ja%' (input pattern) @@ -142,7 +142,7 @@ impl ScalarUDFImpl for StartsWithFunc { | ScalarValue::Utf8View(Some(pattern)) => { let escaped_pattern = pattern.replace("%", "\\%"); let like_pattern = format!("{escaped_pattern}%"); - Expr::Literal(ScalarValue::Utf8(Some(like_pattern))) + Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None) } _ => return Ok(ExprSimplifyResult::Original(args)), }; diff --git a/datafusion/optimizer/benches/projection_unnecessary.rs b/datafusion/optimizer/benches/projection_unnecessary.rs index ee7889eb3321..c9f248fe49b5 100644 --- a/datafusion/optimizer/benches/projection_unnecessary.rs +++ b/datafusion/optimizer/benches/projection_unnecessary.rs @@ -30,7 +30,7 @@ fn is_projection_unnecessary_old( // First check if all expressions are trivial (cheaper operation than `projection_schema`) if !proj_exprs .iter() - .all(|expr| matches!(expr, Expr::Column(_) | Expr::Literal(_))) + .all(|expr| matches!(expr, Expr::Column(_) | Expr::Literal(_, _))) { return Ok(false); } diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index f8a818563609..fa7ff1b8b19d 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -189,19 +189,19 @@ fn grouping_function_on_id( // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 if !is_grouping_set { - return Ok(Expr::Literal(ScalarValue::from(0i32))); + return Ok(Expr::Literal(ScalarValue::from(0i32), None)); } let group_by_expr_count = group_by_expr.len(); let literal = |value: usize| { if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8)) + Expr::Literal(ScalarValue::from(value as u8), None) } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16)) + Expr::Literal(ScalarValue::from(value as u16), None) } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32)) + Expr::Literal(ScalarValue::from(value as u32), None) } else { - Expr::Literal(ScalarValue::from(value as u64)) + Expr::Literal(ScalarValue::from(value as u64), None) } }; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 7034982956ae..b5a3e9a2d585 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -579,7 +579,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Expr::Alias(_) | Expr::Column(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::SimilarTo(_) | Expr::IsNotNull(_) | Expr::IsNull(_) diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 1378b53fa73f..63236787743a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -494,9 +494,12 @@ fn agg_exprs_evaluation_result_on_empty_batch( let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { if func.name() == "count" { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) + Transformed::yes(Expr::Literal( + ScalarValue::Int64(Some(0)), + None, + )) } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null, None)) } } _ => Transformed::no(expr), @@ -587,10 +590,10 @@ fn filter_exprs_evaluation_result_on_empty_batch( let result_expr = simplifier.simplify(result_expr)?; match &result_expr { // evaluate to false or null on empty batch, no need to pull up - Expr::Literal(ScalarValue::Null) - | Expr::Literal(ScalarValue::Boolean(Some(false))) => None, + Expr::Literal(ScalarValue::Null, _) + | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None, // evaluate to true on empty batch, need to pull up the expr - Expr::Literal(ScalarValue::Boolean(Some(true))) => { + Expr::Literal(ScalarValue::Boolean(Some(true)), _) => { for (name, exprs) in input_expr_result_map_for_count_bug { expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); } @@ -605,7 +608,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( Box::new(result_expr.clone()), Box::new(input_expr.clone()), )], - else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), + else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null, None))), }); let expr_key = new_expr.schema_name().to_string(); expr_result_map_for_count_bug.insert(expr_key, new_expr); diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 452df6e8331f..e28771be548b 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -60,7 +60,7 @@ impl OptimizerRule for EliminateFilter { ) -> Result> { match plan { LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(v)), + predicate: Expr::Literal(ScalarValue::Boolean(v), _), input, .. }) => match v { @@ -122,7 +122,7 @@ mod tests { #[test] fn filter_null() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(None)); + let filter_expr = Expr::Literal(ScalarValue::Boolean(None), None); let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 604f083b3709..9c47ce024f91 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -101,7 +101,7 @@ fn is_constant_expression(expr: &Expr) -> bool { Expr::BinaryExpr(e) => { is_constant_expression(&e.left) && is_constant_expression(&e.right) } - Expr::Literal(_) => true, + Expr::Literal(_, _) => true, Expr::ScalarFunction(e) => { matches!( e.func.signature().volatility, diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 2aad889b2fcb..dfc3a220d0f9 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -54,7 +54,7 @@ impl OptimizerRule for EliminateJoin { match plan { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( + Some(Expr::Literal(ScalarValue::Boolean(Some(false)), _)) => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: join.schema, diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index ba583a8d7123..d0457e709026 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -533,7 +533,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) } /// Rewrites a projection expression using the projection before it (i.e. its input) @@ -583,8 +583,18 @@ fn is_expr_trivial(expr: &Expr) -> bool { fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { expr.transform_up(|expr| { match expr { - // remove any intermediate aliases - Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + // remove any intermediate aliases if they do not carry metadata + Expr::Alias(alias) => { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } + } Expr::Column(col) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7c352031bce6..1c1996d6a241 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -254,7 +254,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { let mut is_evaluate = true; predicate.apply(|expr| match expr { Expr::Column(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::Placeholder(_) | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 897e07cb987e..2f9a2f6bb9ed 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -335,7 +335,7 @@ fn build_join( .join_on( sub_query_alias, JoinType::Left, - vec![Expr::Literal(ScalarValue::Boolean(Some(true)))], + vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)], )? .build()? } @@ -365,7 +365,7 @@ fn build_join( ), ( Box::new(Expr::Not(Box::new(filter.clone()))), - Box::new(Expr::Literal(ScalarValue::Null)), + Box::new(Expr::Literal(ScalarValue::Null, None)), ), ], else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fa565a973f6b..e91aea3305be 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -18,7 +18,7 @@ //! Expression simplification API use std::borrow::Cow; -use std::collections::HashSet; +use std::collections::{BTreeMap, HashSet}; use std::ops::Not; use arrow::{ @@ -477,7 +477,7 @@ impl TreeNodeRewriter for Canonicalizer { }))) } // - (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { + (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, @@ -523,9 +523,9 @@ struct ConstEvaluator<'a> { #[allow(clippy::large_enum_variant)] enum ConstSimplifyResult { // Expr was simplified and contains the new expression - Simplified(ScalarValue), + Simplified(ScalarValue, Option>), // Expr was not simplified and original value is returned - NotSimplified(ScalarValue), + NotSimplified(ScalarValue, Option>), // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } @@ -567,11 +567,11 @@ impl TreeNodeRewriter for ConstEvaluator<'_> { // any error is countered during simplification, return the original // so that normal evaluation can occur Some(true) => match self.evaluate_to_scalar(expr) { - ConstSimplifyResult::Simplified(s) => { - Ok(Transformed::yes(Expr::Literal(s))) + ConstSimplifyResult::Simplified(s, m) => { + Ok(Transformed::yes(Expr::Literal(s, m))) } - ConstSimplifyResult::NotSimplified(s) => { - Ok(Transformed::no(Expr::Literal(s))) + ConstSimplifyResult::NotSimplified(s, m) => { + Ok(Transformed::no(Expr::Literal(s, m))) } ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { Ok(Transformed::yes(expr)) @@ -640,7 +640,7 @@ impl<'a> ConstEvaluator<'a> { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } - Expr::Literal(_) + Expr::Literal(_, _) | Expr::Alias(..) | Expr::Unnest(_) | Expr::BinaryExpr { .. } @@ -666,8 +666,8 @@ impl<'a> ConstEvaluator<'a> { /// Internal helper to evaluates an Expr pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult { - if let Expr::Literal(s) = expr { - return ConstSimplifyResult::NotSimplified(s); + if let Expr::Literal(s, m) = expr { + return ConstSimplifyResult::NotSimplified(s, m); } let phys_expr = @@ -675,6 +675,18 @@ impl<'a> ConstEvaluator<'a> { Ok(e) => e, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; + let metadata = phys_expr + .return_field(self.input_batch.schema_ref()) + .ok() + .and_then(|f| { + let m = f.metadata(); + match m.is_empty() { + true => None, + false => { + Some(m.iter().map(|(k, v)| (k.clone(), v.clone())).collect()) + } + } + }); let col_val = match phys_expr.evaluate(&self.input_batch) { Ok(v) => v, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), @@ -687,13 +699,15 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else if as_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::List( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::List(a.as_list::().to_owned().into()), + metadata, + ) } else if as_large_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::LargeList( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::LargeList(a.as_list::().to_owned().into()), + metadata, + ) } else { // Non-ListArray match ScalarValue::try_from_array(&a, 0) { @@ -705,7 +719,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s, metadata) } } Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr), @@ -723,7 +737,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s, metadata) } } } @@ -1138,9 +1152,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // @@ -1181,9 +1196,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // A & !A -> 0 (if A not nullable) @@ -1192,9 +1208,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) & A --> (..A..) @@ -1267,9 +1284,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // A | !A -> -1 (if A not nullable) @@ -1278,9 +1296,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) | A --> (..A..) @@ -1353,9 +1372,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // A ^ !A -> -1 (if A not nullable) @@ -1364,9 +1384,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) @@ -1377,7 +1398,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); Transformed::yes(if expr == *right { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) + Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&right)?)?, + None, + ) } else { expr }) @@ -1391,7 +1415,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); Transformed::yes(if expr == *left { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + ) } else { expr }) @@ -1642,7 +1669,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { expr, list, negated, - }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { + }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null, None) => { Transformed::yes(lit(negated)) } @@ -1868,7 +1895,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { .into_iter() .map(|right| { match right { - Expr::Literal(right_lit_value) => { + Expr::Literal(right_lit_value, _) => { // if the right_lit_value can be casted to the type of internal_left_expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else { @@ -1902,18 +1929,18 @@ impl TreeNodeRewriter for Simplifier<'_, S> { fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { match expr { - Expr::Literal(ScalarValue::Utf8(s)) => Some((DataType::Utf8, s)), - Expr::Literal(ScalarValue::LargeUtf8(s)) => Some((DataType::LargeUtf8, s)), - Expr::Literal(ScalarValue::Utf8View(s)) => Some((DataType::Utf8View, s)), + Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), + Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)), + Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)), _ => None, } } fn to_string_scalar(data_type: DataType, value: Option) -> Expr { match data_type { - DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value)), - DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value)), - DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value)), + DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None), + DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None), + DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None), _ => unreachable!(), } } @@ -1959,12 +1986,12 @@ fn as_inlist(expr: &Expr) -> Option> { Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList { + (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList { expr: left.clone(), list: vec![*right.clone()], negated: false, })), - (Expr::Literal(_), Expr::Column(_)) => Some(Cow::Owned(InList { + (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList { expr: right.clone(), list: vec![*left.clone()], negated: false, @@ -1984,12 +2011,12 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(InList { + (Expr::Column(_), Expr::Literal(_, _)) => Some(InList { expr: left, list: vec![*right], negated: false, }), - (Expr::Literal(_), Expr::Column(_)) => Some(InList { + (Expr::Literal(_, _), Expr::Column(_)) => Some(InList { expr: right, list: vec![*left], negated: false, @@ -2408,7 +2435,7 @@ mod tests { #[test] fn test_simplify_multiply_by_null() { - let null = Expr::Literal(ScalarValue::Null); + let null = Expr::Literal(ScalarValue::Null, None); // A * null --> null { let expr = col("c2") * null.clone(); @@ -4543,10 +4570,10 @@ mod tests { // The simplifier removes the cast. assert_eq!( simplify(coerced), - col("c5").eq(Expr::Literal(ScalarValue::FixedSizeBinary( - 3, - Some(bytes.to_vec()), - ))) + col("c5").eq(Expr::Literal( + ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec()),), + None + )) ); } diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 2c11632ad6d2..bbb023cfbad9 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -84,7 +84,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { low, high, }) => { - if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( + if let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( self.guarantees.get(inner.as_ref()), low.as_ref(), high.as_ref(), @@ -115,7 +115,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { .get(left.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = left.as_ref() { + if let Expr::Literal(value, _) = left.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -126,7 +126,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { .get(right.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = right.as_ref() { + if let Expr::Literal(value, _) = right.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -168,7 +168,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { let new_list: Vec = list .iter() .filter_map(|expr| { - if let Expr::Literal(item) = expr { + if let Expr::Literal(item, _) = expr { match interval .contains(NullableInterval::from(item.clone())) { @@ -415,7 +415,7 @@ mod tests { let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); let output = col("x").rewrite(&mut rewriter).data().unwrap(); - assert_eq!(output, Expr::Literal(scalar.clone())); + assert_eq!(output, Expr::Literal(scalar.clone(), None)); } } diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index ec6485bf4b44..82c5ea3d8d82 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -46,7 +46,7 @@ pub fn simplify_regex_expr( ) -> Result { let mode = OperatorMode::new(&op); - if let Expr::Literal(ScalarValue::Utf8(Some(pattern))) = right.as_ref() { + if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = right.as_ref() { // Handle the special case for ".*" pattern if pattern == ANY_CHAR_REGEX_PATTERN { let new_expr = if mode.not { @@ -121,7 +121,7 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), + pattern: Box::new(Expr::Literal(ScalarValue::from(pattern), None)), escape_char: None, case_insensitive: self.i, }; diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index b70b19bae6df..7c8ff8305e84 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -76,7 +76,7 @@ pub(super) fn unwrap_cast_in_comparison_for_binary( match (cast_expr, literal) { ( Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }), - Expr::Literal(lit_value), + Expr::Literal(lit_value, _), ) => { let Ok(expr_type) = info.get_data_type(&expr) else { return internal_err!("Can't get the data type of the expr {:?}", &expr); @@ -126,7 +126,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< | Expr::Cast(Cast { expr: left_expr, .. }), - Expr::Literal(lit_val), + Expr::Literal(lit_val, _), ) => { let Ok(expr_type) = info.get_data_type(left_expr) else { return false; @@ -183,7 +183,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< } match right { - Expr::Literal(lit_val) + Expr::Literal(lit_val, _) if try_cast_literal_to_type(lit_val, &expr_type).is_some() => {} _ => return false, } diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index cf182175e48e..4df0e125eb18 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -139,34 +139,34 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> pub fn is_zero(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(0))) - | Expr::Literal(ScalarValue::Int16(Some(0))) - | Expr::Literal(ScalarValue::Int32(Some(0))) - | Expr::Literal(ScalarValue::Int64(Some(0))) - | Expr::Literal(ScalarValue::UInt8(Some(0))) - | Expr::Literal(ScalarValue::UInt16(Some(0))) - | Expr::Literal(ScalarValue::UInt32(Some(0))) - | Expr::Literal(ScalarValue::UInt64(Some(0))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) if *v == 0 => true, + Expr::Literal(ScalarValue::Int8(Some(0)), _) + | Expr::Literal(ScalarValue::Int16(Some(0)), _) + | Expr::Literal(ScalarValue::Int32(Some(0)), _) + | Expr::Literal(ScalarValue::Int64(Some(0)), _) + | Expr::Literal(ScalarValue::UInt8(Some(0)), _) + | Expr::Literal(ScalarValue::UInt16(Some(0)), _) + | Expr::Literal(ScalarValue::UInt32(Some(0)), _) + | Expr::Literal(ScalarValue::UInt64(Some(0)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0 => true, _ => false, } } pub fn is_one(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(1))) - | Expr::Literal(ScalarValue::Int16(Some(1))) - | Expr::Literal(ScalarValue::Int32(Some(1))) - | Expr::Literal(ScalarValue::Int64(Some(1))) - | Expr::Literal(ScalarValue::UInt8(Some(1))) - | Expr::Literal(ScalarValue::UInt16(Some(1))) - | Expr::Literal(ScalarValue::UInt32(Some(1))) - | Expr::Literal(ScalarValue::UInt64(Some(1))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s)) => { + Expr::Literal(ScalarValue::Int8(Some(1)), _) + | Expr::Literal(ScalarValue::Int16(Some(1)), _) + | Expr::Literal(ScalarValue::Int32(Some(1)), _) + | Expr::Literal(ScalarValue::Int64(Some(1)), _) + | Expr::Literal(ScalarValue::UInt8(Some(1)), _) + | Expr::Literal(ScalarValue::UInt16(Some(1)), _) + | Expr::Literal(ScalarValue::UInt32(Some(1)), _) + | Expr::Literal(ScalarValue::UInt64(Some(1)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 1. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 1. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s), _) => { *s >= 0 && POWS_OF_TEN .get(*s as usize) @@ -179,7 +179,7 @@ pub fn is_one(s: &Expr) -> bool { pub fn is_true(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => *v, + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => *v, _ => false, } } @@ -187,24 +187,24 @@ pub fn is_true(expr: &Expr) -> bool { /// returns true if expr is a /// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise pub fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) + matches!(expr, Expr::Literal(ScalarValue::Boolean(_), _)) } /// Return a literal NULL value of Boolean data type pub fn lit_bool_null() -> Expr { - Expr::Literal(ScalarValue::Boolean(None)) + Expr::Literal(ScalarValue::Boolean(None), None) } pub fn is_null(expr: &Expr) -> bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } pub fn is_false(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => !(*v), + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => !(*v), _ => false, } } @@ -247,7 +247,7 @@ pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { /// `Expr::Literal(ScalarValue::Boolean(v))`. pub fn as_bool_lit(expr: &Expr) -> Result> { match expr { - Expr::Literal(ScalarValue::Boolean(v)) => Ok(*v), + Expr::Literal(ScalarValue::Boolean(v), _) => Ok(*v), _ => internal_err!("Expected boolean literal, got {expr:?}"), } } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 135f37dd9883..0aa0bf3ea430 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -163,7 +163,11 @@ mod tests { (Expr::IsNotNull(Box::new(col("a"))), true), // a = NULL ( - binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)), + binary_expr( + col("a"), + Operator::Eq, + Expr::Literal(ScalarValue::Null, None), + ), true, ), // a > 8 @@ -226,12 +230,16 @@ mod tests { ), // a IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false), + in_list( + col("a"), + vec![Expr::Literal(ScalarValue::Null, None)], + false, + ), true, ), // a NOT IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true), + in_list(col("a"), vec![Expr::Literal(ScalarValue::Null, None)], true), true, ), ]; diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index 9785203a7020..756fb638af2b 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -342,7 +342,7 @@ mod test { ) .unwrap(); let snap = dynamic_filter_1.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "42", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); let dynamic_filter_2 = reassign_predicate_columns( Arc::clone(&dynamic_filter) as Arc, &filter_schema_2, @@ -350,7 +350,7 @@ mod test { ) .unwrap(); let snap = dynamic_filter_2.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "42", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); // Both filters allow evaluating the same expression let batch_1 = RecordBatch::try_new( Arc::clone(&filter_schema_1), diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 469f7bbee317..a1a14b2f30ff 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -1451,7 +1451,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a IN (a, b)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); // Test: a NOT IN ('a', 'b') let list = vec![lit("a"), lit("b")]; @@ -1459,7 +1459,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a NOT IN (a, b)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); // Test: a IN ('a', 'b', NULL) let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; @@ -1467,7 +1467,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a IN (a, b, NULL)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); + assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(NULL), field: Field { name: \"NULL\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); // Test: a NOT IN ('a', 'b', NULL) let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; @@ -1475,7 +1475,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a NOT IN (a, b, NULL)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); + assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(NULL), field: Field { name: \"NULL\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 6f7caaea8d45..0d4d62ef4719 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,11 +18,13 @@ //! Literal expressions for physical operations use std::any::Any; +use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::{Field, FieldRef}; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -34,15 +36,48 @@ use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq)] pub struct Literal { value: ScalarValue, + field: FieldRef, +} + +impl Hash for Literal { + fn hash(&self, state: &mut H) { + self.value.hash(state); + let metadata = self.field.metadata(); + let mut keys = metadata.keys().collect::>(); + keys.sort(); + for key in keys { + key.hash(state); + metadata.get(key).unwrap().hash(state); + } + } } impl Literal { /// Create a literal value expression pub fn new(value: ScalarValue) -> Self { - Self { value } + Self::new_with_metadata(value, None) + } + + /// Create a literal value expression + pub fn new_with_metadata( + value: ScalarValue, + metadata: impl Into>>, + ) -> Self { + let metadata = metadata.into(); + let mut field = + Field::new(format!("{value}"), value.data_type(), value.is_null()); + + if let Some(metadata) = metadata { + field = field.with_metadata(metadata); + } + + Self { + value, + field: field.into(), + } } /// Get the scalar value @@ -71,6 +106,10 @@ impl PhysicalExpr for Literal { Ok(self.value.is_null()) } + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + fn evaluate(&self, _batch: &RecordBatch) -> Result { Ok(ColumnarValue::Scalar(self.value.clone())) } @@ -102,7 +141,7 @@ impl PhysicalExpr for Literal { /// Create a literal expression pub fn lit(value: T) -> Arc { match value.lit() { - Expr::Literal(v) => Arc::new(Literal::new(v)), + Expr::Literal(v, _) => Arc::new(Literal::new(v)), _ => unreachable!(), } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8660bff796d5..6f1417ec23bf 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::sync::Arc; use crate::ScalarFunctionExpr; @@ -111,14 +112,42 @@ pub fn create_physical_expr( let input_schema: &Schema = &input_dfschema.into(); match e { - Expr::Alias(Alias { expr, .. }) => { - Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + Expr::Alias(Alias { expr, metadata, .. }) => { + if let Expr::Literal(v, prior_metadata) = expr.as_ref() { + let mut new_metadata = prior_metadata + .as_ref() + .map(|m| { + m.iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>() + }) + .unwrap_or_default(); + if let Some(metadata) = metadata { + new_metadata.extend(metadata.clone()); + } + let new_metadata = match new_metadata.is_empty() { + true => None, + false => Some(new_metadata), + }; + + Ok(Arc::new(Literal::new_with_metadata( + v.clone(), + new_metadata, + ))) + } else { + Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + } } Expr::Column(c) => { let idx = input_dfschema.index_of_column(c)?; Ok(Arc::new(Column::new(&c.name, idx))) } - Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), + Expr::Literal(value, metadata) => Ok(Arc::new(Literal::new_with_metadata( + value.clone(), + metadata + .as_ref() + .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect()), + ))), Expr::ScalarVariable(_, variable_names) => { if is_system_variables(variable_names) { match execution_props.get_var_provider(VarType::System) { @@ -168,7 +197,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::Literal(ScalarValue::Boolean(None), None), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -176,7 +205,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::Literal(ScalarValue::Boolean(None), None), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -347,7 +376,7 @@ pub fn create_physical_expr( list, negated, }) => match expr.as_ref() { - Expr::Literal(ScalarValue::Utf8(None)) => { + Expr::Literal(ScalarValue::Utf8(None), _) => { Ok(expressions::lit(ScalarValue::Boolean(None))) } _ => { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 38546fa38064..1b5527c14a49 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -268,7 +268,7 @@ pub fn parse_expr( ExprType::Column(column) => Ok(Expr::Column(column.into())), ExprType::Literal(literal) => { let scalar_value: ScalarValue = literal.try_into()?; - Ok(Expr::Literal(scalar_value)) + Ok(Expr::Literal(scalar_value, None)) } ExprType::WindowExpr(expr) => { let window_function = expr diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 18073516610c..7f089b1c8467 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -217,7 +217,7 @@ pub fn serialize_expr( expr_type: Some(ExprType::Alias(alias)), } } - Expr::Literal(value) => { + Expr::Literal(value, _) => { let pb_value: protobuf::ScalarValue = value.try_into()?; protobuf::LogicalExprNode { expr_type: Some(ExprType::Literal(pb_value)), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3edf152f4c71..993cc6f87ca3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1968,7 +1968,7 @@ fn roundtrip_case_with_null() { let test_expr = Expr::Case(Case::new( Some(Box::new(lit(1.0_f32))), vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(Expr::Literal(ScalarValue::Null))), + Some(Box::new(Expr::Literal(ScalarValue::Null, None))), )); let ctx = SessionContext::new(); @@ -1977,7 +1977,7 @@ fn roundtrip_case_with_null() { #[test] fn roundtrip_null_literal() { - let test_expr = Expr::Literal(ScalarValue::Null); + let test_expr = Expr::Literal(ScalarValue::Null, None); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index ed99150831e7..c9ef4377d43b 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -256,7 +256,7 @@ fn test_expression_serialization_roundtrip() { use datafusion_proto::logical_plan::from_proto::parse_expr; let ctx = SessionContext::new(); - let lit = Expr::Literal(ScalarValue::Utf8(None)); + let lit = Expr::Literal(ScalarValue::Utf8(None), None); for function in string::functions() { // default to 4 args (though some exprs like substr have error checking) let num_args = 4; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index eadf66a91ef3..e92869873731 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -215,7 +215,7 @@ impl SqlToRel<'_, S> { } SQLExpr::Extract { field, expr, .. } => { let mut extract_args = vec![ - Expr::Literal(ScalarValue::from(format!("{field}"))), + Expr::Literal(ScalarValue::from(format!("{field}")), None), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ]; diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index 59c78bc713cc..8f6e77e035c1 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -51,7 +51,7 @@ impl SqlToRel<'_, S> { (None, Some(for_expr)) => { let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let from_logic = Expr::Literal(ScalarValue::Int64(Some(1))); + let from_logic = Expr::Literal(ScalarValue::Int64(Some(1)), None); let for_logic = self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; vec![arg, from_logic, for_logic] diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index b77f5eaf45da..7075a1afd9dd 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -50,7 +50,7 @@ impl SqlToRel<'_, S> { match value { Value::Number(n, _) => self.parse_sql_number(&n, false), Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(lit(s)), - Value::Null => Ok(Expr::Literal(ScalarValue::Null)), + Value::Null => Ok(Expr::Literal(ScalarValue::Null, None)), Value::Boolean(n) => Ok(lit(n)), Value::Placeholder(param) => { Self::create_placeholder_expr(param, param_data_types) @@ -380,11 +380,10 @@ fn parse_decimal(unsigned_number: &str, negative: bool) -> Result { int_val ) })?; - Ok(Expr::Literal(ScalarValue::Decimal128( - Some(val), - precision as u8, - scale as i8, - ))) + Ok(Expr::Literal( + ScalarValue::Decimal128(Some(val), precision as u8, scale as i8), + None, + )) } else if precision <= DECIMAL256_MAX_PRECISION as u64 { let val = bigint_to_i256(&int_val).ok_or_else(|| { // Failures are unexpected here as we have already checked the precision @@ -393,11 +392,10 @@ fn parse_decimal(unsigned_number: &str, negative: bool) -> Result { int_val ) })?; - Ok(Expr::Literal(ScalarValue::Decimal256( - Some(val), - precision as u8, - scale as i8, - ))) + Ok(Expr::Literal( + ScalarValue::Decimal256(Some(val), precision as u8, scale as i8), + None, + )) } else { not_impl_err!( "Decimal precision {} exceeds the maximum supported precision: {}", @@ -483,10 +481,13 @@ mod tests { ]; for (input, expect) in cases { let output = parse_decimal(input, true).unwrap(); - assert_eq!(output, Expr::Literal(expect.arithmetic_negate().unwrap())); + assert_eq!( + output, + Expr::Literal(expect.arithmetic_negate().unwrap(), None) + ); let output = parse_decimal(input, false).unwrap(); - assert_eq!(output, Expr::Literal(expect)); + assert_eq!(output, Expr::Literal(expect, None)); } // scale < i8::MIN diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 458b3ac13217..dafb0346485e 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -2065,7 +2065,7 @@ impl SqlToRel<'_, S> { .cloned() .unwrap_or_else(|| { // If there is no default for the column, then the default is NULL - Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null, None) }) .cast_to(target_field.data_type(), &DFSchema::empty())?, }; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 661e8581ac06..cce14894acaf 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -187,7 +187,7 @@ impl Unparser<'_> { Expr::Cast(Cast { expr, data_type }) => { Ok(self.cast_to_sql(expr, data_type)?) } - Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), + Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), Expr::WindowFunction(window_fun) => { let WindowFunction { @@ -602,7 +602,7 @@ impl Unparser<'_> { .chunks_exact(2) .map(|chunk| { let key = match &chunk[0] { - Expr::Literal(ScalarValue::Utf8(Some(s))) => self.new_ident_quoted_if_needs(s.to_string()), + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => self.new_ident_quoted_if_needs(s.to_string()), _ => return internal_err!("named_struct expects even arguments to be strings, but received: {:?}", &chunk[0]) }; @@ -631,7 +631,7 @@ impl Unparser<'_> { }; let field = match &args[1] { - Expr::Literal(lit) => self.new_ident_quoted_if_needs(lit.to_string()), + Expr::Literal(lit, _) => self.new_ident_quoted_if_needs(lit.to_string()), _ => { return internal_err!( "get_field expects second argument to be a string, but received: {:?}", @@ -1911,87 +1911,87 @@ mod tests { r#"a LIKE 'foo' ESCAPE 'o'"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(0))), + Expr::Literal(ScalarValue::Date64(Some(0)), None), r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(10000))), + Expr::Literal(ScalarValue::Date64(Some(10000)), None), r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(-10000))), + Expr::Literal(ScalarValue::Date64(Some(-10000)), None), r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(0))), + Expr::Literal(ScalarValue::Date32(Some(0)), None), r#"CAST('1970-01-01' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(10))), + Expr::Literal(ScalarValue::Date32(Some(10)), None), r#"CAST('1970-01-11' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(-1))), + Expr::Literal(ScalarValue::Date32(Some(-1)), None), r#"CAST('1969-12-31' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None), None), r#"CAST('1970-01-01 02:46:41' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampSecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 10:46:41 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:10.001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampMillisecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:10.001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:00.010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampMicrosecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:00.010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:00.000010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampNanosecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:00.000010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::Time32Second(Some(10001))), + Expr::Literal(ScalarValue::Time32Second(Some(10001)), None), r#"CAST('02:46:41' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time32Millisecond(Some(10001))), + Expr::Literal(ScalarValue::Time32Millisecond(Some(10001)), None), r#"CAST('00:00:10.001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Microsecond(Some(10001))), + Expr::Literal(ScalarValue::Time64Microsecond(Some(10001)), None), r#"CAST('00:00:00.010001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001))), + Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001)), None), r#"CAST('00:00:00.000010001' AS TIME)"#, ), (sum(col("a")), r#"sum(a)"#), @@ -2136,19 +2136,17 @@ mod tests { (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), // See test_interval_scalar_to_expr for interval literals ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128( - Some(100123), - 28, - 3, - ))), + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal128(Some(100123), 28, 3), + None, + )), r#"((a + b) > 100.123)"#, ), ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( - Some(100123.into()), - 28, - 3, - ))), + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal256(Some(100123.into()), 28, 3), + None, + )), r#"((a + b) > 100.123)"#, ), ( @@ -2184,28 +2182,39 @@ mod tests { "MAP {'a': 1, 'b': 2}", ), ( - Expr::Literal(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), - )), + Expr::Literal( + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(Some("foo".into()))), + ), + None, + ), "'foo'", ), ( - Expr::Literal(ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ + Expr::Literal( + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< + Int32Type, + _, + _, + >(vec![Some(vec![ Some(1), Some(2), Some(3), - ])]), - ))), + ])]))), + None, + ), "[1, 2, 3]", ), ( - Expr::Literal(ScalarValue::LargeList(Arc::new( - LargeListArray::from_iter_primitive::(vec![Some( - vec![Some(1), Some(2), Some(3)], - )]), - ))), + Expr::Literal( + ScalarValue::LargeList(Arc::new( + LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + ]), + )), + None, + ), "[1, 2, 3]", ), ( @@ -2510,11 +2519,17 @@ mod tests { #[test] fn test_float_scalar_to_expr() { let tests = [ - (Expr::Literal(ScalarValue::Float64(Some(3f64))), "3.0"), - (Expr::Literal(ScalarValue::Float64(Some(3.1f64))), "3.1"), - (Expr::Literal(ScalarValue::Float32(Some(-2f32))), "-2.0"), + (Expr::Literal(ScalarValue::Float64(Some(3f64)), None), "3.0"), ( - Expr::Literal(ScalarValue::Float32(Some(-2.989f32))), + Expr::Literal(ScalarValue::Float64(Some(3.1f64)), None), + "3.1", + ), + ( + Expr::Literal(ScalarValue::Float32(Some(-2f32)), None), + "-2.0", + ), + ( + Expr::Literal(ScalarValue::Float32(Some(-2.989f32)), None), "-2.989", ), ]; @@ -2534,18 +2549,20 @@ mod tests { let tests = [ ( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("blah".to_string())), + None, + )), data_type: DataType::Binary, }), "'blah'", ), ( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("blah".to_string())), + None, + )), data_type: DataType::BinaryView, }), "'blah'", @@ -2637,7 +2654,10 @@ mod tests { let expr = ScalarUDF::new_from_impl( datafusion_functions::datetime::date_part::DatePartFunc::new(), ) - .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]); + .call(vec![ + Expr::Literal(ScalarValue::new_utf8(unit), None), + col("x"), + ]); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2757,10 +2777,10 @@ mod tests { (&mysql_dialect, "DATETIME"), ] { let unparser = Unparser::new(dialect); - let expr = Expr::Literal(ScalarValue::TimestampMillisecond( - Some(1738285549123), + let expr = Expr::Literal( + ScalarValue::TimestampMillisecond(Some(1738285549123), None), None, - )); + ); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2828,9 +2848,10 @@ mod tests { fn test_cast_value_to_dict_expr() { let tests = [( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "variation".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("variation".to_string())), + None, + )), data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), }), "'variation'", @@ -2868,7 +2889,7 @@ mod tests { expr: Box::new(col("a")), data_type: DataType::Float64, }), - Expr::Literal(ScalarValue::Int64(Some(2))), + Expr::Literal(ScalarValue::Int64(Some(2)), None), ], }); let ast = unparser.expr_to_sql(&expr)?; @@ -3008,7 +3029,7 @@ mod tests { datafusion_functions::datetime::date_trunc::DateTruncFunc::new(), )), args: vec![ - Expr::Literal(ScalarValue::Utf8(Some(precision.to_string()))), + Expr::Literal(ScalarValue::Utf8(Some(precision.to_string())), None), col("date_col"), ], }); diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index e89e25ddb15a..f6677617031f 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -1078,6 +1078,7 @@ impl Unparser<'_> { if project_vec.is_empty() { builder = builder.project(vec![Expr::Literal( ScalarValue::Int64(Some(1)), + None, )])?; } else { let project_columns = project_vec diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index c36ffbfe5ecf..89fa392c183f 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -422,7 +422,7 @@ pub(crate) fn date_part_to_sql( match (style, date_part_args.len()) { (DateFieldExtractStyle::Extract, 2) => { let date_expr = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => ast::DateTimeField::Year, "month" => ast::DateTimeField::Month, @@ -443,7 +443,7 @@ pub(crate) fn date_part_to_sql( (DateFieldExtractStyle::Strftime, 2) => { let column = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => "%Y", "month" => "%m", @@ -531,7 +531,7 @@ pub(crate) fn sqlite_from_unixtime_to_sql( "datetime", &[ from_unixtime_args[0].clone(), - Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string())), None), ], )?)) } @@ -554,7 +554,7 @@ pub(crate) fn sqlite_date_trunc_to_sql( ); } - if let Expr::Literal(ScalarValue::Utf8(Some(unit))) = &date_trunc_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(unit)), _) = &date_trunc_args[0] { let format = match unit.to_lowercase().as_str() { "year" => "%Y", "month" => "%Y-%m", @@ -568,7 +568,7 @@ pub(crate) fn sqlite_date_trunc_to_sql( return Ok(Some(unparser.scalar_function_to_sql( "strftime", &[ - Expr::Literal(ScalarValue::Utf8(Some(format.to_string()))), + Expr::Literal(ScalarValue::Utf8(Some(format.to_string())), None), date_trunc_args[1].clone(), ], )?)); diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 067da40cf9a8..52832e1324be 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -198,7 +198,7 @@ pub(crate) fn resolve_positions_to_exprs( match expr { // sql_expr_to_logical_expr maps number to i64 // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887 - Expr::Literal(ScalarValue::Int64(Some(position))) + Expr::Literal(ScalarValue::Int64(Some(position)), _) if position > 0_i64 && position <= select_exprs.len() as i64 => { let index = (position - 1) as usize; @@ -208,7 +208,7 @@ pub(crate) fn resolve_positions_to_exprs( _ => select_expr.clone(), }) } - Expr::Literal(ScalarValue::Int64(Some(position))) => plan_err!( + Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!( "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}", position, select_exprs.len() ), diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index d89ba600d7a6..ac96daed0d44 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -6032,7 +6032,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "7f4b18de3cfeb9b4ac78c381ee2ad278", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "a", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "b", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "c", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6061,7 +6061,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "7f4b18de3cfeb9b4ac78c381ee2ad278", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "a", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "b", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "c", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6090,7 +6090,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "7f4b18de3cfeb9b4ac78c381ee2ad278", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "a", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "b", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "c", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6150,7 +6150,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "7f4b18de3cfeb9b4ac78c381ee2ad278", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "a", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "b", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "c", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part index edc452284cf9..cd2f407387ed 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part @@ -88,7 +88,7 @@ physical_plan 21)----------------------------------CoalesceBatchesExec: target_batch_size=8192 22)------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 23)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) +24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49), field: Field { name: "49", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(14), field: Field { name: "14", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(23), field: Field { name: "23", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(45), field: Field { name: "45", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(19), field: Field { name: "19", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(3), field: Field { name: "3", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(36), field: Field { name: "36", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(9), field: Field { name: "9", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 25)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 26)--------------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false 27)--------------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part index 3b15fb3d8e53..ace2081eb18f 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part @@ -69,7 +69,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN ([Literal { value: Utf8View("SM CASE") }, Literal { value: Utf8View("SM BOX") }, Literal { value: Utf8View("SM PACK") }, Literal { value: Utf8View("SM PKG") }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([Literal { value: Utf8View("MED BAG") }, Literal { value: Utf8View("MED BOX") }, Literal { value: Utf8View("MED PKG") }, Literal { value: Utf8View("MED PACK") }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([Literal { value: Utf8View("LG CASE") }, Literal { value: Utf8View("LG BOX") }, Literal { value: Utf8View("LG PACK") }, Literal { value: Utf8View("LG PKG") }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] +06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN ([Literal { value: Utf8View("SM CASE"), field: Field { name: "SM CASE", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM BOX"), field: Field { name: "SM BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PACK"), field: Field { name: "SM PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PKG"), field: Field { name: "SM PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([Literal { value: Utf8View("MED BAG"), field: Field { name: "MED BAG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED BOX"), field: Field { name: "MED BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PKG"), field: Field { name: "MED PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PACK"), field: Field { name: "MED PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([Literal { value: Utf8View("LG CASE"), field: Field { name: "LG CASE", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG BOX"), field: Field { name: "LG BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PACK"), field: Field { name: "LG PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PKG"), field: Field { name: "LG PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 09)----------------CoalesceBatchesExec: target_batch_size=8192 @@ -78,6 +78,6 @@ physical_plan 12)------------CoalesceBatchesExec: target_batch_size=8192 13)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 14)----------------CoalesceBatchesExec: target_batch_size=8192 -15)------------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN ([Literal { value: Utf8View("SM CASE") }, Literal { value: Utf8View("SM BOX") }, Literal { value: Utf8View("SM PACK") }, Literal { value: Utf8View("SM PKG") }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([Literal { value: Utf8View("MED BAG") }, Literal { value: Utf8View("MED BOX") }, Literal { value: Utf8View("MED PKG") }, Literal { value: Utf8View("MED PACK") }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([Literal { value: Utf8View("LG CASE") }, Literal { value: Utf8View("LG BOX") }, Literal { value: Utf8View("LG PACK") }, Literal { value: Utf8View("LG PKG") }]) AND p_size@2 <= 15) AND p_size@2 >= 1 +15)------------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN ([Literal { value: Utf8View("SM CASE"), field: Field { name: "SM CASE", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM BOX"), field: Field { name: "SM BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PACK"), field: Field { name: "SM PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PKG"), field: Field { name: "SM PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([Literal { value: Utf8View("MED BAG"), field: Field { name: "MED BAG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED BOX"), field: Field { name: "MED BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PKG"), field: Field { name: "MED PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PACK"), field: Field { name: "MED PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([Literal { value: Utf8View("LG CASE"), field: Field { name: "LG CASE", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG BOX"), field: Field { name: "LG BOX", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PACK"), field: Field { name: "LG PACK", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PKG"), field: Field { name: "LG PKG", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND p_size@2 <= 15) AND p_size@2 >= 1 16)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 17)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index 828bf967d8f4..6af91b4aaa42 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -90,7 +90,7 @@ physical_plan 14)--------------------------CoalesceBatchesExec: target_batch_size=8192 15)----------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 16)------------------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]) +17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13"), field: Field { name: "13", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("31"), field: Field { name: "31", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("23"), field: Field { name: "23", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("29"), field: Field { name: "29", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("30"), field: Field { name: "30", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("18"), field: Field { name: "18", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("17"), field: Field { name: "17", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 18)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 19)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false 20)--------------------------CoalesceBatchesExec: target_batch_size=8192 @@ -100,6 +100,6 @@ physical_plan 24)----------------------CoalescePartitionsExec 25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] 26)--------------------------CoalesceBatchesExec: target_batch_size=8192 -27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]), projection=[c_acctbal@1] +27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: Utf8View("13"), field: Field { name: "13", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("31"), field: Field { name: "31", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("23"), field: Field { name: "23", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("29"), field: Field { name: "29", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("30"), field: Field { name: "30", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("18"), field: Field { name: "18", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("17"), field: Field { name: "17", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]), projection=[c_acctbal@1] 28)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 29)--------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs index 7687d9f7642a..114fe1e7aecd 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs @@ -60,7 +60,7 @@ pub async fn from_substrait_agg_func( // we inject a dummy argument that does not affect the query, but allows // us to bypass this limitation. let args = if udaf.name() == "count" && args.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)] } else { args }; diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs index 5adc137d9a43..d054e5267554 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs @@ -51,7 +51,7 @@ pub async fn from_literal( expr: &Literal, ) -> datafusion::common::Result { let scalar_value = from_substrait_literal_without_names(consumer, expr)?; - Ok(Expr::Literal(scalar_value)) + Ok(Expr::Literal(scalar_value, None)) } pub(crate) fn from_substrait_literal_without_names( diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs index 027b61124ead..7797c935211f 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -261,7 +261,7 @@ impl BuiltinExprBuilder { .await?; match escape_char_expr { - Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { + Expr::Literal(ScalarValue::Utf8(escape_char_string), _) => { // Convert Option to Option escape_char_string.and_then(|s| s.chars().next()) } @@ -337,7 +337,7 @@ mod tests { fn int64_literals(integers: &[i64]) -> Vec { integers .iter() - .map(|value| Expr::Literal(ScalarValue::Int64(Some(*value)))) + .map(|value| Expr::Literal(ScalarValue::Int64(Some(*value)), None)) .collect() } diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs index 4a7fde256b6c..80b643a547ee 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -94,7 +94,7 @@ pub async fn from_window_function( // we inject a dummy argument that does not affect the query, but allows // us to bypass this limitation. let args = if fun.name() == "count" && window.arguments.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)] } else { from_substrait_func_args(consumer, &window.arguments, input_schema).await? }; diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs index 47af44c692ae..f1cbd16d2d8f 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs @@ -136,7 +136,7 @@ pub async fn from_read_rel( lit, &named_struct.names, &mut name_idx, - )?)) + )?, None)) }) .collect::>()?; if name_idx != named_struct.names.len() { diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index b69474f09ee4..9741dcdd1095 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -31,7 +31,7 @@ pub fn from_cast( ) -> datafusion::common::Result { let Cast { expr, data_type } = cast; // since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null - if let Expr::Literal(lit) = expr.as_ref() { + if let Expr::Literal(lit, _) = expr.as_ref() { // only the untyped(a null scalar value) null literal need this special handling // since all other kind of nulls are already typed and can be handled by substrait // e.g. null:: or null:: @@ -92,7 +92,7 @@ mod tests { let empty_schema = DFSchemaRef::new(DFSchema::empty()); let field = Field::new("out", DataType::Int32, false); - let expr = Expr::Literal(ScalarValue::Null) + let expr = Expr::Literal(ScalarValue::Null, None) .cast_to(&DataType::Int32, &empty_schema) .unwrap(); @@ -119,7 +119,7 @@ mod tests { } // a typed null should not be folded - let expr = Expr::Literal(ScalarValue::Int64(None)) + let expr = Expr::Literal(ScalarValue::Int64(None), None) .cast_to(&DataType::Int32, &empty_schema) .unwrap(); diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index fbc4d3754df0..42e1f962f1d1 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -109,7 +109,7 @@ pub fn to_substrait_rex( Expr::ScalarVariable(_, _) => { not_impl_err!("Cannot convert {expr:?} to Substrait") } - Expr::Literal(expr) => producer.handle_literal(expr), + Expr::Literal(expr, _) => producer.handle_literal(expr), Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), Expr::Like(expr) => producer.handle_like(expr, schema), Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"), @@ -172,7 +172,7 @@ mod tests { let state = SessionStateBuilder::default().build(); // One expression, empty input schema - let expr = Expr::Literal(ScalarValue::Int32(Some(42))); + let expr = Expr::Literal(ScalarValue::Int32(Some(42)), None); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); let substrait = diff --git a/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs index e4e0ab11c65a..212874e7913b 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs @@ -115,10 +115,10 @@ pub fn from_values( let fields = row .iter() .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(producer, sv), + Expr::Literal(sv, _) => to_substrait_literal(producer, sv), Expr::Alias(alias) => match alias.expr.as_ref() { // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(producer, sv), + Expr::Literal(sv, _) => to_substrait_literal(producer, sv), _ => Err(substrait_datafusion_err!( "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() )), diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 8fb8a59fb860..cd40e664239a 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -1076,7 +1076,7 @@ pub struct EchoFunction {} impl TableFunctionImpl for EchoFunction { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { + let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else { return plan_err!("First argument must be an integer"); }; @@ -1117,7 +1117,7 @@ With the UDTF implemented, you can register it with the `SessionContext`: # # impl TableFunctionImpl for EchoFunction { # fn call(&self, exprs: &[Expr]) -> Result> { -# let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { +# let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else { # return plan_err!("First argument must be an integer"); # }; #