diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 2aac1768ac63..8500a654fc8f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -5274,7 +5274,7 @@ async fn test_dataframe_placeholder_column_parameter() -> Result<()> { assert_snapshot!( actual, @r" - Projection: Int32(3) AS $1 [$1:Null;N] + Projection: Int32(3) AS $1 [$1:Int32] EmptyRelation: rows=1 [] " ); diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 8a0f62062738..1b4237077de2 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -18,7 +18,6 @@ use std::collections::HashMap; use super::*; -use datafusion::assert_batches_eq; use datafusion_common::{metadata::ScalarAndMetadata, ParamValues, ScalarValue}; use insta::assert_snapshot; @@ -343,26 +342,53 @@ async fn test_query_parameters_with_metadata() -> Result<()> { ])) .unwrap(); - // df_with_params_replaced.schema() is not correct here - // https://github.com/apache/datafusion/issues/18102 - let batches = df_with_params_replaced.clone().collect().await.unwrap(); - let schema = batches[0].schema(); - + let schema = df_with_params_replaced.schema(); assert_eq!(schema.field(0).data_type(), &DataType::UInt32); assert_eq!(schema.field(0).metadata(), &metadata1); assert_eq!(schema.field(1).data_type(), &DataType::Utf8); assert_eq!(schema.field(1).metadata(), &metadata2); - assert_batches_eq!( - [ - "+----+-----+", - "| $1 | $2 |", - "+----+-----+", - "| 1 | two |", - "+----+-----+", - ], - &batches - ); + let batches = df_with_params_replaced.collect().await.unwrap(); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-----+ + | $1 | $2 | + +----+-----+ + | 1 | two | + +----+-----+ + "); + + Ok(()) +} + +/// Test for https://github.com/apache/datafusion/issues/18102 +#[tokio::test] +async fn test_query_parameters_in_values_list_relation() -> Result<()> { + let ctx = 
SessionContext::new(); + + let df = ctx + .sql("SELECT a, b FROM (VALUES ($1, $2)) AS t(a, b)") + .await + .unwrap(); + + let df_with_params_replaced = df + .with_param_values(ParamValues::List(vec![ + ScalarAndMetadata::new(ScalarValue::UInt32(Some(1)), None), + ScalarAndMetadata::new(ScalarValue::Utf8(Some("two".to_string())), None), + ])) + .unwrap(); + + let schema = df_with_params_replaced.schema(); + assert_eq!(schema.field(0).data_type(), &DataType::UInt32); + assert_eq!(schema.field(1).data_type(), &DataType::Utf8); + + let batches = df_with_params_replaced.collect().await.unwrap(); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +---+-----+ + | a | b | + +---+-----+ + | 1 | two | + +---+-----+ + "); Ok(()) } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index b9afd894d77d..e4b786b5dde6 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -216,10 +216,9 @@ impl LogicalPlanBuilder { if values.is_empty() { return plan_err!("Values list cannot be empty"); } + + // values list can have no columns, see: https://github.com/apache/datafusion/pull/12339 let n_cols = values[0].len(); - if n_cols == 0 { - return plan_err!("Values list cannot be zero length"); - } for (i, row) in values.iter().enumerate() { if row.len() != n_cols { return plan_err!( diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 0f0d81186d68..0aafc375a64f 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -633,8 +633,46 @@ impl LogicalPlan { LogicalPlan::Dml(_) => Ok(self), LogicalPlan::Copy(_) => Ok(self), LogicalPlan::Values(Values { schema, values }) => { - // todo it isn't clear why the schema is not recomputed here - Ok(LogicalPlan::Values(Values { schema, values })) + // We cannot compute the correct schema if we only use values. 
+ // + // For example, given the following plan: + // Projection: col_1, col_2 + // Values: (Float32(1), Float32(10)), (Float32(100), Float32(10)) + // + // We wouldn't know about `col_1`, and `col_2` if we only relied on `values`. + // To correctly recompute the new schema, we also need to retain some information + // from the original schema. + let new_plan = LogicalPlanBuilder::values(values.clone())?.build()?; + + let qualified_fields = schema + .iter() + .zip(new_plan.schema().fields()) + .map(|((table_ref, old_field), new_field)| { + // `old_field`'s data type is unknown but `new_field`'s is known + if old_field.data_type().is_null() + && !new_field.data_type().is_null() + { + let field = old_field + .as_ref() + .clone() + .with_data_type(new_field.data_type().clone()); + (table_ref.cloned(), Arc::new(field)) + } else { + (table_ref.cloned(), Arc::clone(old_field)) + } + }) + .collect::<Vec<_>>(); + + let schema = DFSchema::new_with_metadata( + qualified_fields, + schema.metadata().clone(), + )? + .with_functional_dependencies(schema.functional_dependencies().clone())?; + + Ok(LogicalPlan::Values(Values { + schema: Arc::new(schema), + values, + })) } LogicalPlan::Filter(Filter { predicate, input }) => { Filter::try_new(predicate, input).map(LogicalPlan::Filter) @@ -1471,7 +1509,10 @@ impl LogicalPlan { // Preserve name to avoid breaking column references to this expression Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) } - }) + })?
+ // always recompute the schema to ensure that changes to the schema's fields are + // propagated to the plan's parent + .map_data(|plan| plan.recompute_schema()) }) .map(|res| res.data) } @@ -4247,6 +4288,7 @@ mod tests { use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; + use crate::select_expr::SelectExpr; use crate::test::function_stub::{count, count_udaf}; use crate::{ binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery, @@ -4825,6 +4867,82 @@ mod tests { .expect_err("prepared field metadata mismatch unexpectedly succeeded"); } + #[test] + fn test_replace_placeholder_values_list_relation_valid_schema() { + // SELECT a, b, c FROM (VALUES (1, $1, $2 + $3)) AS t(a, b, c); + let plan = LogicalPlanBuilder::values(vec![vec![ + lit(1), + placeholder("$1"), + binary_expr(placeholder("$2"), Operator::Plus, placeholder("$3")), + ]]) + .unwrap() + .project(vec![ + col("column1").alias("a"), + col("column2").alias("b"), + col("column3").alias("c"), + ]) + .unwrap() + .alias("t") + .unwrap() + .project(vec![col("a"), col("b"), col("c")]) + .unwrap() + .build() + .unwrap(); + + // original + assert_snapshot!(plan.display_indent_schema(), @r" + Projection: t.a, t.b, t.c [a:Int32;N, b:Null;N, c:Int64;N] + SubqueryAlias: t [a:Int32;N, b:Null;N, c:Int64;N] + Projection: column1 AS a, column2 AS b, column3 AS c [a:Int32;N, b:Null;N, c:Int64;N] + Values: (Int32(1), $1, $2 + $3) [column1:Int32;N, column2:Null;N, column3:Int64;N] + "); + + let plan = plan + .with_param_values(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ]) + .unwrap(); + + // replaced + assert_snapshot!(plan.display_indent_schema(), @r" + Projection: t.a, t.b, t.c [a:Int32;N, b:Int32;N, c:Int64;N] + SubqueryAlias: t [a:Int32;N, b:Int32;N, c:Int64;N] + Projection: column1 AS a, column2 AS b, column3 AS c [a:Int32;N, b:Int32;N, c:Int64;N] + Values: (Int32(1), Int32(1) AS $1, Int32(2) + Int32(3) AS $2 + $3)
[column1:Int32;N, column2:Int32;N, column3:Int64;N] + "); + } + + #[test] + fn test_replace_placeholder_empty_relation_valid_schema() { + // SELECT $1, $2; + let plan = LogicalPlanBuilder::empty(false) + .project(vec![ + SelectExpr::from(placeholder("$1")), + SelectExpr::from(placeholder("$2")), + ]) + .unwrap() + .build() + .unwrap(); + + // original + assert_snapshot!(plan.display_indent_schema(), @r" + Projection: $1, $2 [$1:Null;N, $2:Null;N] + EmptyRelation: rows=0 [] + "); + + let plan = plan + .with_param_values(vec![ScalarValue::from(1i32), ScalarValue::from("s")]) + .unwrap(); + + // replaced + assert_snapshot!(plan.display_indent_schema(), @r#" + Projection: Int32(1) AS $1, Utf8("s") AS $2 [$1:Int32, $2:Utf8] + EmptyRelation: rows=0 [] + "#); + } + #[test] fn test_nullable_schema_after_grouping_set() { let schema = Schema::new(vec![ diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index f14d4cbf1fcc..a6d0d620f505 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -32,8 +32,8 @@ use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{ - Extension, InvariantLevel, LogicalPlan, PartitionEvaluator, Repartition, - UserDefinedLogicalNode, Values, Volatility, + Extension, InvariantLevel, LogicalPlan, LogicalPlanBuilder, PartitionEvaluator, + Repartition, UserDefinedLogicalNode, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -1258,10 +1258,10 @@ async fn roundtrip_values() -> Result<()> { async fn roundtrip_values_no_columns() -> Result<()> { let ctx = create_context().await?; // "VALUES ()" is not yet supported by the SQL parser, so we construct the plan 
manually - let plan = LogicalPlan::Values(Values { - values: vec![vec![], vec![]], // two rows, no columns - schema: DFSchemaRef::new(DFSchema::empty()), - }); + let plan = LogicalPlanBuilder::values( + vec![vec![], vec![]], // two rows, no columns + )? + .build()?; roundtrip_logical_plan_with_ctx(plan, ctx).await?; Ok(()) }