Skip to content

Commit

Permalink
Feat: add builder methods to exp.Update and add with_ arg to exp.upda…
Browse files Browse the repository at this point in the history
…te (#4217)

* Add builder methods to exp.Update and add with_ arg to exp.update

* Forgot to run style

* Review changes: add append mode to Update.set_, use maybe_copy, remove suggestion that CTE query in with_ must be a Select (can include set operators)
  • Loading branch information
brdbry authored Oct 7, 2024
1 parent 22a1684 commit 354cfff
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 11 deletions.
228 changes: 217 additions & 11 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3290,6 +3290,200 @@ class Update(Expression):
"limit": False,
}

def table(
self, expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts
) -> Update:
"""
Set the table to update.
Example:
>>> Update().table("my_table").set_("x = 1").sql()
'UPDATE my_table SET x = 1'
Args:
expression : the SQL code strings to parse.
If a `Table` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Table`.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
The modified Update expression.
"""
return _apply_builder(
expression=expression,
instance=self,
arg="this",
into=Table,
prefix=None,
dialect=dialect,
copy=copy,
**opts,
)

def set_(
self,
*expressions: ExpOrStr,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Update:
"""
Append to or set the SET expressions.
Example:
>>> Update().table("my_table").set_("x = 1").sql()
'UPDATE my_table SET x = 1'
Args:
*expressions: the SQL code strings to parse.
If `Expression` instance(s) are passed, they will be used as-is.
Multiple expressions are combined with a comma.
append: if `True`, add the new expressions to any existing SET expressions.
Otherwise, this resets the expressions.
dialect: the dialect used to parse the input expressions.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
"""
return _apply_list_builder(
*expressions,
instance=self,
arg="expressions",
append=append,
into=Expression,
prefix=None,
dialect=dialect,
copy=copy,
**opts,
)

def where(
self,
*expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Select:
"""
Append to or set the WHERE expressions.
Example:
>>> Update().table("tbl").set_("x = 1").where("x = 'a' OR x < 'b'").sql()
"UPDATE tbl SET x = 1 WHERE x = 'a' OR x < 'b'"
Args:
*expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
Multiple expressions are combined with an AND operator.
append: if `True`, AND the new expressions to any existing expression.
Otherwise, this resets the expression.
dialect: the dialect used to parse the input expressions.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
Select: the modified expression.
"""
return _apply_conjunction_builder(
*expressions,
instance=self,
arg="where",
append=append,
into=Where,
dialect=dialect,
copy=copy,
**opts,
)

def from_(
self,
expression: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Update:
"""
Set the FROM expression.
Example:
>>> Update().table("my_table").set_("x = 1").from_("baz").sql()
'UPDATE my_table SET x = 1 FROM baz'
Args:
expression : the SQL code strings to parse.
If a `From` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `From`.
If nothing is passed in then a from is not applied to the expression
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
The modified Update expression.
"""
if not expression:
return maybe_copy(self, copy)

return _apply_builder(
expression=expression,
instance=self,
arg="from",
into=From,
prefix="FROM",
dialect=dialect,
copy=copy,
**opts,
)

def with_(
self,
alias: ExpOrStr,
as_: ExpOrStr,
recursive: t.Optional[bool] = None,
materialized: t.Optional[bool] = None,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Update:
"""
Append to or set the common table expressions.
Example:
>>> Update().table("my_table").set_("x = 1").from_("baz").with_("baz", "SELECT id FROM foo").sql()
'WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz'
Args:
alias: the SQL code string to parse as the table name.
If an `Expression` instance is passed, this is used as-is.
as_: the SQL code string to parse as the table expression.
If an `Expression` instance is passed, it will be used as-is.
recursive: set the RECURSIVE part of the expression. Defaults to `False`.
materialized: set the MATERIALIZED part of the expression.
append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
The modified expression.
"""
return _apply_cte_builder(
self,
alias,
as_,
recursive=recursive,
materialized=materialized,
append=append,
dialect=dialect,
copy=copy,
**opts,
)


class Values(UDTF):
arg_types = {"expressions": True, "alias": False}
Expand Down Expand Up @@ -6811,38 +7005,41 @@ def from_(expression: ExpOrStr, dialect: DialectType = None, **opts) -> Select:

def update(
table: str | Table,
properties: dict,
properties: t.Optional[dict] = None,
where: t.Optional[ExpOrStr] = None,
from_: t.Optional[ExpOrStr] = None,
with_: t.Optional[t.Dict[str, ExpOrStr]] = None,
dialect: DialectType = None,
**opts,
) -> Update:
"""
Creates an update statement.
Example:
>>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz", where="id > 1").sql()
"UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1"
>>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz_cte", where="baz_cte.id > 1 and my_table.id = baz_cte.id", with_={"baz_cte": "SELECT id FROM foo"}).sql()
"WITH baz_cte AS (SELECT id FROM foo) UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz_cte WHERE baz_cte.id > 1 AND my_table.id = baz_cte.id"
Args:
*properties: dictionary of properties to set which are
properties: dictionary of properties to SET which are
auto converted to sql objects eg None -> NULL
where: sql conditional parsed into a WHERE statement
from_: sql statement parsed into a FROM statement
with_: dictionary of CTE aliases / select statements to include in a WITH clause.
dialect: the dialect used to parse the input expressions.
**opts: other options to use to parse the input expressions.
Returns:
Update: the syntax tree for the UPDATE statement.
"""
update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update_expr.set(
"expressions",
[
EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
for k, v in properties.items()
],
)
if properties:
update_expr.set(
"expressions",
[
EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
for k, v in properties.items()
],
)
if from_:
update_expr.set(
"from",
Expand All @@ -6855,6 +7052,15 @@ def update(
"where",
maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
)
if with_:
cte_list = [
CTE(this=maybe_parse(qry, dialect=dialect, **opts), alias=alias)
for alias, qry in with_.items()
]
update_expr.set(
"with",
With(expressions=cte_list),
)
return update_expr


Expand Down
30 changes: 30 additions & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,36 @@ def test_build(self):
lambda: exp.update("tbl", {"x": 1}, from_="tbl2 cross join tbl3"),
"UPDATE tbl SET x = 1 FROM tbl2 CROSS JOIN tbl3",
),
(
lambda: exp.update(
"my_table",
{"x": 1},
from_="baz",
where="my_table.id = baz.id",
with_={"baz": "SELECT id FROM foo UNION SELECT id FROM bar"},
),
"WITH baz AS (SELECT id FROM foo UNION SELECT id FROM bar) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id",
),
(
lambda: exp.update("my_table").set_("x = 1"),
"UPDATE my_table SET x = 1",
),
(
lambda: exp.update("my_table").set_("x = 1").where("y = 2"),
"UPDATE my_table SET x = 1 WHERE y = 2",
),
(
lambda: exp.update("my_table").set_("a = 1").set_("b = 2"),
"UPDATE my_table SET a = 1, b = 2",
),
(
lambda: exp.update("my_table")
.set_("x = 1")
.where("my_table.id = baz.id")
.from_("baz")
.with_("baz", "SELECT id FROM foo"),
"WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id",
),
(
lambda: union("SELECT * FROM foo", "SELECT * FROM bla"),
"SELECT * FROM foo UNION SELECT * FROM bla",
Expand Down

0 comments on commit 354cfff

Please sign in to comment.