Skip to content

Commit

Permalink
Add support for semicolon stripping to DbApiHook (#34828)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Astankov committed Aug 31, 2024
1 parent 0f5c25b commit 0e97fa8
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ class DbApiHook(BaseHook):
conn_name_attr: str
# Override to have a default connection id for a particular dbHook
default_conn_name = "default_conn_id"
# Override if this db doesn't support semicolons in SQL queries
strip_semicolon = False
# Override if this db supports autocommit.
supports_autocommit = False
# Override if this db supports executemany.
Expand Down Expand Up @@ -336,14 +338,18 @@ def strip_sql_string(sql: str) -> str:
return sql.strip().rstrip(";")

@staticmethod
def split_sql_string(sql: str) -> list[str]:
def split_sql_string(sql: str, strip_semicolon: bool = False) -> list[str]:
"""
Split string into multiple SQL expressions.
:param sql: SQL string potentially consisting of multiple expressions
:param strip_semicolon: whether to strip semicolon from SQL string
:return: list of individual expressions
"""
splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
splits = sqlparse.split(
sql=sqlparse.format(sql, strip_comments=True),
strip_semicolon=strip_semicolon,
)
return [s for s in splits if s]

@property
Expand Down Expand Up @@ -438,7 +444,10 @@ def run(

if isinstance(sql, str):
if split_statements:
sql_list: Iterable[str] = self.split_sql_string(sql)
sql_list: Iterable[str] = self.split_sql_string(
sql=sql,
strip_semicolon=self.strip_semicolon,
)
else:
sql_list = [sql] if sql.strip() else []
else:
Expand Down

0 comments on commit 0e97fa8

Please sign in to comment.