diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 59de679f3..3528d1106 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3290,6 +3290,200 @@ class Update(Expression): "limit": False, } + def table( + self, expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts + ) -> Update: + """ + Set the table to update. + + Example: + >>> Update().table("my_table").set_("x = 1").sql() + 'UPDATE my_table SET x = 1' + + Args: + expression : the SQL code strings to parse. + If a `Table` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Table`. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Update expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="this", + into=Table, + prefix=None, + dialect=dialect, + copy=copy, + **opts, + ) + + def set_( + self, + *expressions: ExpOrStr, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Update: + """ + Append to or set the SET expressions. + + Example: + >>> Update().table("my_table").set_("x = 1").sql() + 'UPDATE my_table SET x = 1' + + Args: + *expressions: the SQL code strings to parse. + If `Expression` instance(s) are passed, they will be used as-is. + Multiple expressions are combined with a comma. + append: if `True`, add the new expressions to any existing SET expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + """ + return _apply_list_builder( + *expressions, + instance=self, + arg="expressions", + append=append, + into=Expression, + prefix=None, + dialect=dialect, + copy=copy, + **opts, + ) + + def where( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: + """ + Append to or set the WHERE expressions. + + Example: + >>> Update().table("tbl").set_("x = 1").where("x = 'a' OR x < 'b'").sql() + "UPDATE tbl SET x = 1 WHERE x = 'a' OR x < 'b'" + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append: if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="where", + append=append, + into=Where, + dialect=dialect, + copy=copy, + **opts, + ) + + def from_( + self, + expression: t.Optional[ExpOrStr] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Update: + """ + Set the FROM expression. + + Example: + >>> Update().table("my_table").set_("x = 1").from_("baz").sql() + 'UPDATE my_table SET x = 1 FROM baz' + + Args: + expression : the SQL code strings to parse. + If a `From` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `From`. + If nothing is passed in then a from is not applied to the expression + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Update expression. + """ + if not expression: + return maybe_copy(self, copy) + + return _apply_builder( + expression=expression, + instance=self, + arg="from", + into=From, + prefix="FROM", + dialect=dialect, + copy=copy, + **opts, + ) + + def with_( + self, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + materialized: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Update: + """ + Append to or set the common table expressions. + + Example: + >>> Update().table("my_table").set_("x = 1").from_("baz").with_("baz", "SELECT id FROM foo").sql() + 'WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz' + + Args: + alias: the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_: the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + materialized: set the MATERIALIZED part of the expression. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified expression. + """ + return _apply_cte_builder( + self, + alias, + as_, + recursive=recursive, + materialized=materialized, + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + class Values(UDTF): arg_types = {"expressions": True, "alias": False} @@ -6811,9 +7005,10 @@ def from_(expression: ExpOrStr, dialect: DialectType = None, **opts) -> Select: def update( table: str | Table, - properties: dict, + properties: t.Optional[dict] = None, where: t.Optional[ExpOrStr] = None, from_: t.Optional[ExpOrStr] = None, + with_: t.Optional[t.Dict[str, ExpOrStr]] = None, dialect: DialectType = None, **opts, ) -> Update: @@ -6821,14 +7016,15 @@ def update( Creates an update statement. Example: - >>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz", where="id > 1").sql() - "UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1" + >>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz_cte", where="baz_cte.id > 1 and my_table.id = baz_cte.id", with_={"baz_cte": "SELECT id FROM foo"}).sql() + "WITH baz_cte AS (SELECT id FROM foo) UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz_cte WHERE baz_cte.id > 1 AND my_table.id = baz_cte.id" Args: - *properties: dictionary of properties to set which are + properties: dictionary of properties to SET which are auto converted to sql objects eg None -> NULL where: sql conditional parsed into a WHERE statement from_: sql statement parsed into a FROM statement + with_: dictionary of CTE aliases / select statements to include in a WITH clause. dialect: the dialect used to parse the input expressions. **opts: other options to use to parse the input expressions. @@ -6836,13 +7032,14 @@ def update( Update: the syntax tree for the UPDATE statement. """ update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect)) - update_expr.set( - "expressions", - [ - EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) - for k, v in properties.items() - ], - ) + if properties: + update_expr.set( + "expressions", + [ + EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) + for k, v in properties.items() + ], + ) if from_: update_expr.set( "from", @@ -6855,6 +7052,15 @@ def update( "where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), ) + if with_: + cte_list = [ + CTE(this=maybe_parse(qry, dialect=dialect, **opts), alias=alias) + for alias, qry in with_.items() + ] + update_expr.set( + "with", + With(expressions=cte_list), + ) return update_expr diff --git a/tests/test_build.py b/tests/test_build.py index 7518b72a2..5d383ad00 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -577,6 +577,36 @@ def test_build(self): lambda: exp.update("tbl", {"x": 1}, from_="tbl2 cross join tbl3"), "UPDATE tbl SET x = 1 FROM tbl2 CROSS JOIN tbl3", ), + ( + lambda: exp.update( + "my_table", + {"x": 1}, + from_="baz", + where="my_table.id = baz.id", + with_={"baz": "SELECT id FROM foo UNION SELECT id FROM bar"}, + ), + "WITH baz AS (SELECT id FROM foo UNION SELECT id FROM bar) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id", + ), + ( + lambda: exp.update("my_table").set_("x = 1"), + "UPDATE my_table SET x = 1", + ), + ( + lambda: exp.update("my_table").set_("x = 1").where("y = 2"), + "UPDATE my_table SET x = 1 WHERE y = 2", + ), + ( + lambda: exp.update("my_table").set_("a = 1").set_("b = 2"), + "UPDATE my_table SET a = 1, b = 2", + ), + ( + lambda: exp.update("my_table") + .set_("x = 1") + .where("my_table.id = baz.id") + .from_("baz") + .with_("baz", "SELECT id FROM foo"), + "WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id", + ), ( lambda: union("SELECT * FROM foo", "SELECT * FROM bla"), "SELECT * FROM foo UNION SELECT * FROM bla",