Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: create builders for the INSERT statement #1630

Merged
merged 4 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 122 additions & 38 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,6 +1479,42 @@ class Insert(Expression):
"alternative": False,
}

def with_(
tobymao marked this conversation as resolved.
Show resolved Hide resolved
self,
alias: ExpOrStr,
as_: ExpOrStr,
recursive: t.Optional[bool] = None,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Insert:
"""
Append to or set the common table expressions.

Example:
>>> insert("SELECT x FROM cte", "t").with_("cte", as_="SELECT * FROM tbl").sql()
'WITH cte AS (SELECT * FROM tbl) INSERT INTO t SELECT x FROM cte'

Args:
alias: the SQL code string to parse as the table name.
If an `Expression` instance is passed, this is used as-is.
as_: the SQL code string to parse as the table expression.
If an `Expression` instance is passed, it will be used as-is.
recursive: set the RECURSIVE part of the expression. Defaults to `False`.
append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.

Returns:
The modified expression.
"""
return _apply_cte_builder(
self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
)


class OnConflict(Expression):
arg_types = {
Expand Down Expand Up @@ -2062,14 +2098,14 @@ def named_selects(self):

def with_(
self,
alias,
as_,
recursive=None,
append=True,
dialect=None,
copy=True,
alias: ExpOrStr,
as_: ExpOrStr,
recursive: t.Optional[bool] = None,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
):
) -> Subqueryable:
"""
Append to or set the common table expressions.

Expand All @@ -2078,43 +2114,22 @@ def with_(
'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2'

Args:
alias (str | Expression): the SQL code string to parse as the table name.
alias: the SQL code string to parse as the table name.
If an `Expression` instance is passed, this is used as-is.
as_ (str | Expression): the SQL code string to parse as the table expression.
as_: the SQL code string to parse as the table expression.
If an `Expression` instance is passed, it will be used as-is.
recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`.
append (bool): if `True`, add to any existing expressions.
recursive: set the RECURSIVE part of the expression. Defaults to `False`.
append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
dialect (str): the dialect used to parse the input expression.
copy (bool): if `False`, modify this expression instance in-place.
opts (kwargs): other options to use to parse the input expressions.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.

Returns:
Select: the modified expression.
The modified expression.
"""
alias_expression = maybe_parse(
alias,
dialect=dialect,
into=TableAlias,
**opts,
)
as_expression = maybe_parse(
as_,
dialect=dialect,
**opts,
)
cte = CTE(
this=as_expression,
alias=alias_expression,
)
return _apply_child_list_builder(
cte,
instance=self,
arg="with",
append=append,
copy=copy,
into=With,
properties={"recursive": recursive or False},
return _apply_cte_builder(
self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
)


Expand Down Expand Up @@ -4525,6 +4540,30 @@ def _apply_conjunction_builder(
return inst


def _apply_cte_builder(
instance: E,
alias: ExpOrStr,
as_: ExpOrStr,
recursive: t.Optional[bool] = None,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> E:
alias_expression = maybe_parse(alias, dialect=dialect, into=TableAlias, **opts)
as_expression = maybe_parse(as_, dialect=dialect, **opts)
cte = CTE(this=as_expression, alias=alias_expression)
return _apply_child_list_builder(
cte,
instance=instance,
arg="with",
append=append,
copy=copy,
into=With,
properties={"recursive": recursive or False},
)


def _combine(expressions, operator, dialect=None, copy=True, **opts):
expressions = [
condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions
Expand Down Expand Up @@ -4742,6 +4781,51 @@ def delete(
return delete_expr


def insert(
expression: ExpOrStr,
into: ExpOrStr,
columns: t.Optional[t.Sequence[ExpOrStr]] = None,
overwrite: t.Optional[bool] = None,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Insert:
"""
Builds an INSERT statement.

Example:
>>> insert("VALUES (1, 2, 3)", "tbl").sql()
'INSERT INTO tbl VALUES (1, 2, 3)'

Args:
expression: the sql string or expression of the INSERT statement
into: the tbl to insert data to.
columns: optionally the table's column names.
overwrite: whether to INSERT OVERWRITE or not.
dialect: the dialect used to parse the input expressions.
copy: whether or not to copy the expression.
**opts: other options to use to parse the input expressions.

Returns:
Insert: the syntax tree for the INSERT statement.
"""
expr = maybe_parse(expression, dialect=dialect, copy=copy, **opts)
this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts)

if columns:
this = _apply_list_builder(
*columns,
instance=Schema(this=this),
arg="expressions",
into=Identifier,
copy=False,
dialect=dialect,
**opts,
)

return Insert(this=this, expression=expr, overwrite=overwrite)


def condition(expression, dialect=None, copy=True, **opts) -> Condition:
"""
Initialize a logical condition expression.
Expand Down
16 changes: 16 additions & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,22 @@ def test_build(self):
"DELETE FROM tbl WHERE x = 1 RETURNING *",
"postgres",
),
(
lambda: exp.insert("SELECT * FROM tbl2", "tbl"),
"INSERT INTO tbl SELECT * FROM tbl2",
),
(
lambda: exp.insert("SELECT * FROM tbl2", "tbl", overwrite=True),
"INSERT OVERWRITE TABLE tbl SELECT * FROM tbl2",
),
(
lambda: exp.insert("VALUES (1, 2), (3, 4)", "tbl", columns=["cola", "colb"]),
"INSERT INTO tbl (cola, colb) VALUES (1, 2), (3, 4)",
),
(
lambda: exp.insert("SELECT * FROM cte", "t").with_("cte", as_="SELECT x FROM tbl"),
"WITH cte AS (SELECT x FROM tbl) INSERT INTO t SELECT * FROM cte",
),
(
lambda: exp.convert((exp.column("x"), exp.column("y"))).isin((1, 2), (3, 4)),
"(x, y) IN ((1, 2), (3, 4))",
Expand Down