Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add builder methods to exp.Update and add with_ arg to exp.update #4217

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 208 additions & 11 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3284,6 +3284,191 @@ class Update(Expression):
"limit": False,
}

def table(
self, expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts
) -> Update:
"""
Set the table to update.

Example:
>>> Update().table("my_table").set_("x = 1").sql()
'UPDATE my_table SET x = 1'

Args:
expression : the SQL code strings to parse.
If a `Table` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Table`.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.

Returns:
The modified Update expression.
"""
return _apply_builder(
expression=expression,
instance=self,
arg="this",
into=Table,
prefix=None,
dialect=dialect,
copy=copy,
**opts,
)

def set_(
self, *expressions: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts
) -> Update:
brdbry marked this conversation as resolved.
Show resolved Hide resolved
"""
Set the SET expressions.

Example:
>>> Update().table("my_table").set_("x = 1").sql()
'UPDATE my_table SET x = 1'

Args:
*expressions: the SQL code strings to parse.
If an `Expression` instances are passed, they will be used as-is.
dialect: the dialect used to parse the input expressions.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
"""
return _apply_list_builder(
*expressions,
instance=self,
arg="expressions",
into=Expression,
prefix=None,
dialect=dialect,
copy=copy,
**opts,
)

def where(
self,
*expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Select:
"""
Append to or set the WHERE expressions.

Example:
>>> Update().table("tbl").set_("x = 1").where("x = 'a' OR x < 'b'").sql()
"UPDATE tbl SET x = 1 WHERE x = 'a' OR x < 'b'"

Args:
*expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
Multiple expressions are combined with an AND operator.
append: if `True`, AND the new expressions to any existing expression.
Otherwise, this resets the expression.
dialect: the dialect used to parse the input expressions.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.

Returns:
Select: the modified expression.
"""
return _apply_conjunction_builder(
*expressions,
instance=self,
arg="where",
append=append,
into=Where,
dialect=dialect,
copy=copy,
**opts,
)

def from_(
self,
expression: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Update:
"""
Set the FROM expression.

Example:
>>> Update().table("my_table").set_("x = 1").from_("baz").sql()
'UPDATE my_table SET x = 1 FROM baz'

Args:
expression : the SQL code strings to parse.
If a `From` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `From`.
If nothing is passed in then a from is not applied to the expression
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.

Returns:
The modified Update expression.
"""
if not expression:
return self if not copy else self.copy()
brdbry marked this conversation as resolved.
Show resolved Hide resolved

return _apply_builder(
expression=expression,
instance=self,
arg="from",
into=From,
prefix="FROM",
dialect=dialect,
copy=copy,
**opts,
)

def with_(
self,
alias: ExpOrStr,
as_: ExpOrStr,
recursive: t.Optional[bool] = None,
materialized: t.Optional[bool] = None,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Update:
"""
Append to or set the common table expressions.

Example:
>>> Update().table("my_table").set_("x = 1").from_("baz").with_("baz", "SELECT id FROM foo").sql()
'WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz'

Args:
alias: the SQL code string to parse as the table name.
If an `Expression` instance is passed, this is used as-is.
as_: the SQL code string to parse as the table expression.
If an `Expression` instance is passed, it will be used as-is.
recursive: set the RECURSIVE part of the expression. Defaults to `False`.
materialized: set the MATERIALIZED part of the expression.
append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.

Returns:
The modified expression.
"""
return _apply_cte_builder(
self,
alias,
as_,
recursive=recursive,
materialized=materialized,
append=append,
dialect=dialect,
copy=copy,
**opts,
)


class Values(UDTF):
arg_types = {"expressions": True, "alias": False}
Expand Down Expand Up @@ -6803,38 +6988,41 @@ def from_(expression: ExpOrStr, dialect: DialectType = None, **opts) -> Select:

def update(
table: str | Table,
properties: dict,
properties: t.Optional[dict] = None,
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
where: t.Optional[ExpOrStr] = None,
from_: t.Optional[ExpOrStr] = None,
with_: t.Optional[t.Dict[str, ExpOrStr]] = None,
dialect: DialectType = None,
**opts,
) -> Update:
"""
Creates an update statement.

Example:
>>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz", where="id > 1").sql()
"UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1"
>>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz_cte", where="baz_cte.id > 1 and my_table.id = baz_cte.id", with_={"baz_cte": "SELECT id FROM foo"}).sql()
"WITH baz_cte AS (SELECT id FROM foo) UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz_cte WHERE baz_cte.id > 1 AND my_table.id = baz_cte.id"

Args:
*properties: dictionary of properties to set which are
properties: dictionary of properties to SET which are
auto converted to sql objects eg None -> NULL
where: sql conditional parsed into a WHERE statement
from_: sql statement parsed into a FROM statement
with_: dictionary of CTE aliases / select statements to include in a WITH clause.
dialect: the dialect used to parse the input expressions.
**opts: other options to use to parse the input expressions.

Returns:
Update: the syntax tree for the UPDATE statement.
"""
update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update_expr.set(
"expressions",
[
EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
for k, v in properties.items()
],
)
if properties:
update_expr.set(
"expressions",
[
EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
for k, v in properties.items()
],
)
if from_:
update_expr.set(
"from",
Expand All @@ -6847,6 +7035,15 @@ def update(
"where",
maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
)
if with_:
cte_list = [
CTE(this=maybe_parse(qry, into=Select, dialect=dialect, **opts), alias=alias)
brdbry marked this conversation as resolved.
Show resolved Hide resolved
for alias, qry in with_.items()
]
update_expr.set(
"with",
With(expressions=cte_list),
)
return update_expr


Expand Down
26 changes: 26 additions & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,32 @@ def test_build(self):
lambda: exp.update("tbl", {"x": 1}, from_="tbl2 cross join tbl3"),
"UPDATE tbl SET x = 1 FROM tbl2 CROSS JOIN tbl3",
),
(
lambda: exp.update(
"my_table",
{"x": 1},
from_="baz",
where="my_table.id = baz.id",
with_={"baz": "SELECT id FROM foo"},
),
"WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id",
),
(
lambda: exp.update("my_table").set_("x = 1"),
"UPDATE my_table SET x = 1",
),
(
lambda: exp.update("my_table").set_("x = 1").where("y = 2"),
"UPDATE my_table SET x = 1 WHERE y = 2",
),
(
lambda: exp.update("my_table")
.set_("x = 1")
.where("my_table.id = baz.id")
.from_("baz")
.with_("baz", "SELECT id FROM foo"),
"WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id",
),
(
lambda: union("SELECT * FROM foo", "SELECT * FROM bla"),
"SELECT * FROM foo UNION SELECT * FROM bla",
Expand Down
Loading