From c6b226edf5370c25cfcc4afeb789be22a6cf52fd Mon Sep 17 00:00:00 2001 From: jonahgao Date: Mon, 30 Oct 2023 23:31:56 +0800 Subject: [PATCH] fix: generate logical plan for `UPDATE SET FROM` statement --- datafusion/sql/src/statement.rs | 68 ++++++++++--------- datafusion/sqllogictest/test_files/update.slt | 36 ++++++++++ 2 files changed, 72 insertions(+), 32 deletions(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 80a27db6e63d..7302228f8f56 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -963,10 +963,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; let table_source = self.context_provider.get_table_source(table_name.clone())?; - let arrow_schema = (*table_source.schema()).clone(); let table_schema = Arc::new(DFSchema::try_from_qualified_schema( table_name.clone(), - &arrow_schema, + &table_source.schema(), )?); // Overwrite with assignment expressions @@ -985,21 +984,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::<Result<HashMap<String, Expr>>>()?; - let values_and_types = table_schema - .fields() - .iter() - .map(|f| { - let col_name = f.name(); - let val = assign_map.remove(col_name).unwrap_or_else(|| { - ast::Expr::Identifier(ast::Ident::from(col_name.as_str())) - }); - (col_name, val, f.data_type()) - }) - .collect::<Vec<_>>(); - - // Build scan - let from = from.unwrap_or(table); - let scan = self.plan_from_tables(vec![from], &mut planner_context)?; + // Build scan, join with from table if it exists. 
+ let mut input_tables = vec![table]; + input_tables.extend(from); + let scan = self.plan_from_tables(input_tables, &mut planner_context)?; // Filter let source = match predicate_expr { @@ -1007,33 +995,49 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Some(predicate_expr) => { let filter_expr = self.sql_to_expr( predicate_expr, - &table_schema, + scan.schema(), &mut planner_context, )?; let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[&table_schema]], + &[&[&scan.schema()]], &[using_columns], )?; LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?) } }; - // Projection - let mut exprs = vec![]; - for (col_name, expr, dt) in values_and_types.into_iter() { - let mut expr = self.sql_to_expr(expr, &table_schema, &mut planner_context)?; - // Update placeholder's datatype to the type of the target column - if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr { - placeholder.data_type = - placeholder.data_type.take().or_else(|| Some(dt.clone())); - } - // Cast to target column type, if necessary - let expr = expr.cast_to(dt, source.schema())?.alias(col_name); - exprs.push(expr); - } + // Build updated values for each column, using the previous value if not modified + let exprs = table_schema + .fields() + .iter() + .map(|field| { + let expr = match assign_map.remove(field.name()) { + Some(new_value) => { + let mut expr = self.sql_to_expr( + new_value, + source.schema(), + &mut planner_context, + )?; + // Update placeholder's datatype to the type of the target column + if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr + { + placeholder.data_type = placeholder + .data_type + .take() + .or_else(|| Some(field.data_type().clone())); + } + // Cast to target column type, if necessary + expr.cast_to(field.data_type(), source.schema())? 
+ } None => datafusion_expr::Expr::Column(field.qualified_column()), }; Ok(expr.alias(field.name())) }) .collect::<Result<Vec<_>>>()?; + let source = project(source, exprs)?; let plan = LogicalPlan::Dml(DmlStatement { diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 4542a262390c..cb8c6a4fac28 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -41,3 +41,39 @@ logical_plan Dml: op=[Update] table=[t1] --Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d ----TableScan: t1 + +statement ok +create table t2(a int, b varchar, c double, d int); + +## set from subquery +query TT +explain update t1 set b = (select max(b) from t2 where t1.a = t2.a) +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t1.a AS a, (<subquery>) AS b, t1.c AS c, t1.d AS d +----Subquery: +------Projection: MAX(t2.b) +--------Aggregate: groupBy=[[]], aggr=[[MAX(t2.b)]] +----------Filter: outer_ref(t1.a) = t2.a +------------TableScan: t2 +----TableScan: t1 + +# set from other table +query TT +explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1.b > 'foo' and t2.c > 1.0; +---- +logical_plan +Dml: op=[Update] table=[t1] +--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d +----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) +------CrossJoin: +--------TableScan: t1 +--------TableScan: t2 + +statement ok +create table t3(a int, b varchar, c double, d int); + +# set from multiple tables, sqlparser only supports from one table +query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\) +explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; \ No newline at end of file