Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: generate logical plan for UPDATE SET FROM statement #7984

Merged
merged 1 commit into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 36 additions & 32 deletions datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -963,10 +963,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// Do a table lookup to verify the table exists
let table_name = self.object_name_to_table_reference(table_name)?;
let table_source = self.context_provider.get_table_source(table_name.clone())?;
let arrow_schema = (*table_source.schema()).clone();
let table_schema = Arc::new(DFSchema::try_from_qualified_schema(
table_name.clone(),
&arrow_schema,
&table_source.schema(),
)?);

// Overwrite with assignment expressions
Expand All @@ -985,55 +984,60 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
})
.collect::<Result<HashMap<String, Expr>>>()?;

let values_and_types = table_schema
.fields()
.iter()
.map(|f| {
let col_name = f.name();
let val = assign_map.remove(col_name).unwrap_or_else(|| {
ast::Expr::Identifier(ast::Ident::from(col_name.as_str()))
});
(col_name, val, f.data_type())
})
.collect::<Vec<_>>();

// Build scan
let from = from.unwrap_or(table);
let scan = self.plan_from_tables(vec![from], &mut planner_context)?;
// Build scan, join with from table if it exists.
let mut input_tables = vec![table];
input_tables.extend(from);
let scan = self.plan_from_tables(input_tables, &mut planner_context)?;

// Filter
let source = match predicate_expr {
None => scan,
Some(predicate_expr) => {
let filter_expr = self.sql_to_expr(
predicate_expr,
&table_schema,
scan.schema(),
&mut planner_context,
)?;
let mut using_columns = HashSet::new();
expr_to_columns(&filter_expr, &mut using_columns)?;
let filter_expr = normalize_col_with_schemas_and_ambiguity_check(
filter_expr,
&[&[&table_schema]],
&[&[&scan.schema()]],
&[using_columns],
)?;
LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?)
}
};

// Projection
let mut exprs = vec![];
for (col_name, expr, dt) in values_and_types.into_iter() {
let mut expr = self.sql_to_expr(expr, &table_schema, &mut planner_context)?;
// Update placeholder's datatype to the type of the target column
if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr {
placeholder.data_type =
placeholder.data_type.take().or_else(|| Some(dt.clone()));
}
// Cast to target column type, if necessary
let expr = expr.cast_to(dt, source.schema())?.alias(col_name);
exprs.push(expr);
}
// Build updated values for each column, using the previous value if not modified
let exprs = table_schema
.fields()
.iter()
.map(|field| {
let expr = match assign_map.remove(field.name()) {
Some(new_value) => {
let mut expr = self.sql_to_expr(
new_value,
source.schema(),
&mut planner_context,
)?;
// Update placeholder's datatype to the type of the target column
if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr
{
placeholder.data_type = placeholder
.data_type
.take()
.or_else(|| Some(field.data_type().clone()));
}
// Cast to target column type, if necessary
expr.cast_to(field.data_type(), source.schema())?
}
None => datafusion_expr::Expr::Column(field.qualified_column()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this expression be cast as well (in case the column type in the source query/table doesn't match the type in the target)?

Copy link
Member Author

@jonahgao jonahgao Oct 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The columns in the None arm are not modified as they do not appear in the assignment map.
For these columns, the old values in the target table will be used, so the data types remain the same.

};
Ok(expr.alias(field.name()))
})
.collect::<Result<Vec<_>>>()?;

let source = project(source, exprs)?;

let plan = LogicalPlan::Dml(DmlStatement {
Expand Down
36 changes: 36 additions & 0 deletions datafusion/sqllogictest/test_files/update.slt
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,39 @@ logical_plan
Dml: op=[Update] table=[t1]
--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d
----TableScan: t1

statement ok
create table t2(a int, b varchar, c double, d int);

## set from subquery
query TT
explain update t1 set b = (select max(b) from t2 where t1.a = t2.a)
----
logical_plan
Dml: op=[Update] table=[t1]
--Projection: t1.a AS a, (<subquery>) AS b, t1.c AS c, t1.d AS d
----Subquery:
------Projection: MAX(t2.b)
--------Aggregate: groupBy=[[]], aggr=[[MAX(t2.b)]]
----------Filter: outer_ref(t1.a) = t2.a
------------TableScan: t2
----TableScan: t1

# set from other table
query TT
explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1.b > 'foo' and t2.c > 1.0;
----
logical_plan
Dml: op=[Update] table=[t1]
--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d
----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1)
------CrossJoin:
--------TableScan: t1
--------TableScan: t2

statement ok
create table t3(a int, b varchar, c double, d int);

# set from mutiple tables, sqlparser only supports from one table
query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\)
explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a;