Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Added fast_executemany parameter to insert_rows of DbApiHook #43357

Merged
merged 6 commits into from
Oct 25, 2024
12 changes: 12 additions & 0 deletions providers/src/airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ def insert_rows(
replace=False,
*,
executemany=False,
fast_executemany=False,
autocommit=False,
**kwargs,
):
Expand All @@ -638,6 +639,8 @@ def insert_rows(
:param executemany: If True, all rows are inserted at once in
chunks defined by the commit_every parameter. This only works if all rows
have the same number of columns, but leads to better performance.
:param fast_executemany: If True, the `fast_executemany` attribute is set on the
cursor used by `executemany`, which improves performance if the driver supports it.
:param autocommit: What to set the connection's autocommit setting to
before executing the query.
"""
Expand All @@ -646,6 +649,15 @@ def insert_rows(
conn.commit()
with closing(conn.cursor()) as cur:
if self.supports_executemany or executemany:
if fast_executemany:
with contextlib.suppress(AttributeError):
# Try to set the fast_executemany attribute
cur.fast_executemany = True
self.log.info(
"Fast_executemany is enabled for conn_id '%s'!",
self.get_conn_id(),
)

for chunked_rows in chunked(rows, commit_every):
values = list(
map(
Expand Down
16 changes: 16 additions & 0 deletions providers/tests/common/sql/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class TestDbApiHook:
def setup_method(self, **kwargs):
self.cur = mock.MagicMock(
rowcount=0,
fast_executemany=False,
spec=Cursor,
)
self.conn = mock.MagicMock()
Expand Down Expand Up @@ -188,6 +189,21 @@ def test_insert_rows_executemany(self):
self.db_hook.insert_rows(table, rows, executemany=True)

assert self.conn.close.call_count == 1
assert not self.cur.fast_executemany
assert self.cur.close.call_count == 1
assert self.conn.commit.call_count == 2

sql = f"INSERT INTO {table} VALUES (%s)"
self.cur.executemany.assert_any_call(sql, rows)

def test_insert_rows_fast_executemany(self):
table = "table"
rows = [("hello",), ("world",)]

self.db_hook.insert_rows(table, rows, executemany=True, fast_executemany=True)

assert self.conn.close.call_count == 1
assert self.cur.fast_executemany
assert self.cur.close.call_count == 1
assert self.conn.commit.call_count == 2

Expand Down