diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index e39d2ef04f..bcdbd375d1 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1479,6 +1479,42 @@ class Insert(Expression): "alternative": False, } + def with_( + self, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Insert: + """ + Append to or set the common table expressions. + + Example: + >>> insert("SELECT x FROM cte", "t").with_("cte", as_="SELECT * FROM tbl").sql() + 'WITH cte AS (SELECT * FROM tbl) INSERT INTO t SELECT x FROM cte' + + Args: + alias: the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_: the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified expression. + """ + return _apply_cte_builder( + self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts + ) + class OnConflict(Expression): arg_types = { @@ -2062,14 +2098,14 @@ def named_selects(self): def with_( self, - alias, - as_, - recursive=None, - append=True, - dialect=None, - copy=True, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, **opts, - ): + ) -> Subqueryable: """ Append to or set the common table expressions. @@ -2078,43 +2114,22 @@ def with_( 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' Args: - alias (str | Expression): the SQL code string to parse as the table name. + alias: the SQL code string to parse as the table name. If an `Expression` instance is passed, this is used as-is. - as_ (str | Expression): the SQL code string to parse as the table expression. + as_: the SQL code string to parse as the table expression. If an `Expression` instance is passed, it will be used as-is. - recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`. - append (bool): if `True`, add to any existing expressions. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + append: if `True`, add to any existing expressions. Otherwise, this resets the expressions. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified expression. """ - alias_expression = maybe_parse( - alias, - dialect=dialect, - into=TableAlias, - **opts, - ) - as_expression = maybe_parse( - as_, - dialect=dialect, - **opts, - ) - cte = CTE( - this=as_expression, - alias=alias_expression, - ) - return _apply_child_list_builder( - cte, - instance=self, - arg="with", - append=append, - copy=copy, - into=With, - properties={"recursive": recursive or False}, + return _apply_cte_builder( + self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts ) @@ -4525,6 +4540,30 @@ def _apply_conjunction_builder( return inst +def _apply_cte_builder( + instance: E, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> E: + alias_expression = maybe_parse(alias, dialect=dialect, into=TableAlias, **opts) + as_expression = maybe_parse(as_, dialect=dialect, **opts) + cte = CTE(this=as_expression, alias=alias_expression) + return _apply_child_list_builder( + cte, + instance=instance, + arg="with", + append=append, + copy=copy, + into=With, + properties={"recursive": recursive or False}, + ) + + def _combine(expressions, operator, dialect=None, copy=True, **opts): expressions = [ condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions @@ -4742,6 +4781,51 @@ def delete( return delete_expr +def insert( + expression: ExpOrStr, + into: ExpOrStr, + columns: t.Optional[t.Sequence[ExpOrStr]] = None, + overwrite: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Insert: + """ + Builds an INSERT statement. + + Example: + >>> insert("VALUES (1, 2, 3)", "tbl").sql() + 'INSERT INTO tbl VALUES (1, 2, 3)' + + Args: + expression: the sql string or expression of the INSERT statement + into: the tbl to insert data to. + columns: optionally the table's column names. + overwrite: whether to INSERT OVERWRITE or not. + dialect: the dialect used to parse the input expressions. + copy: whether or not to copy the expression. + **opts: other options to use to parse the input expressions. + + Returns: + Insert: the syntax tree for the INSERT statement. + """ + expr = maybe_parse(expression, dialect=dialect, copy=copy, **opts) + this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts) + + if columns: + this = _apply_list_builder( + *columns, + instance=Schema(this=this), + arg="expressions", + into=Identifier, + copy=False, + dialect=dialect, + **opts, + ) + + return Insert(this=this, expression=expr, overwrite=overwrite) + + def condition(expression, dialect=None, copy=True, **opts) -> Condition: """ Initialize a logical condition expression. diff --git a/tests/test_build.py b/tests/test_build.py index 8342de3e58..fc7e005ff1 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -594,6 +594,22 @@ def test_build(self): "DELETE FROM tbl WHERE x = 1 RETURNING *", "postgres", ), + ( + lambda: exp.insert("SELECT * FROM tbl2", "tbl"), + "INSERT INTO tbl SELECT * FROM tbl2", + ), + ( + lambda: exp.insert("SELECT * FROM tbl2", "tbl", overwrite=True), + "INSERT OVERWRITE TABLE tbl SELECT * FROM tbl2", + ), + ( + lambda: exp.insert("VALUES (1, 2), (3, 4)", "tbl", columns=["cola", "colb"]), + "INSERT INTO tbl (cola, colb) VALUES (1, 2), (3, 4)", + ), + ( + lambda: exp.insert("SELECT * FROM cte", "t").with_("cte", as_="SELECT x FROM tbl"), + "WITH cte AS (SELECT x FROM tbl) INSERT INTO t SELECT * FROM cte", + ), ( lambda: exp.convert((exp.column("x"), exp.column("y"))).isin((1, 2), (3, 4)), "(x, y) IN ((1, 2), (3, 4))",