providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
@@ -26,7 +26,8 @@
import psycopg2
import psycopg2.extensions
import psycopg2.extras
-from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
+from more_itertools import chunked
+from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor, execute_batch
from sqlalchemy.engine import URL

from airflow.exceptions import (
@@ -574,3 +575,79 @@ def get_db_log_messages(self, conn) -> None:

for output in conn.notices:
self.log.info(output)

def insert_rows(
self,
table,
rows,
target_fields=None,
commit_every=1000,
replace=False,
*,
executemany=False,
fast_executemany=False,
autocommit=False,
**kwargs,
):
"""
Insert a collection of tuples into a table.

        Rows are inserted in chunks; each chunk (of size ``commit_every``) is
        committed in its own transaction.

:param table: Name of the target table
:param rows: The rows to insert into the table
:param target_fields: The names of the columns to fill in the table
:param commit_every: The maximum number of rows to insert in one
transaction. Set to 0 to insert all rows in one transaction.
:param replace: Whether to replace instead of insert
        :param executemany: If True, rows are inserted in bulk via ``executemany``
            in chunks of ``commit_every`` rows. This requires all rows to have the
            same number of columns, but gives better performance.
:param fast_executemany: If True, rows will be inserted using an optimized
bulk execution strategy (``psycopg2.extras.execute_batch``). This can
significantly improve performance for large inserts. If set to False,
the method falls back to the default implementation from
``DbApiHook.insert_rows``.
:param autocommit: What to set the connection's autocommit setting to
before executing the query.
"""
# if fast_executemany is disabled, defer to default implementation of insert_rows in DbApiHook
if not fast_executemany:
return super().insert_rows(
table,
rows,
target_fields=target_fields,
commit_every=commit_every,
replace=replace,
executemany=executemany,
autocommit=autocommit,
**kwargs,
)

        # if fast_executemany is enabled, use the optimized execute_batch from psycopg2
nb_rows = 0
with self._create_autocommit_connection(autocommit) as conn:
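            # flush any pending transaction so each chunk below commits only its own rows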
conn.commit()
with closing(conn.cursor()) as cur:
                # more_itertools.chunked yields the rows in lists of at most commit_every items
                for chunked_rows in chunked(rows, commit_every):
                    values = [self._serialize_cells(row, conn) for row in chunked_rows]
sql = self._generate_insert_sql(table, values[0], target_fields, replace, **kwargs)
self.log.debug("Generated sql: %s", sql)

                    try:
                        execute_batch(cur, sql, values, page_size=commit_every)
                    except Exception:
                        self.log.error("Generated sql: %s", sql)
                        self.log.error("Parameters: %s", values)
                        raise

conn.commit()
nb_rows += len(chunked_rows)
self.log.info("Loaded %s rows into %s so far", nb_rows, table)
self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
providers/postgres/tests/unit/postgres/hooks/test_postgres.py (26 additions, 2 deletions)
@@ -30,7 +30,7 @@
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.postgres.dialects.postgres import PostgresDialect
-from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.providers.postgres.hooks.postgres import CompatConnection, PostgresHook
from airflow.utils.types import NOTSET

from tests_common.test_utils.common_sql import mock_db_hook
@@ -65,7 +65,7 @@ def postgres_hook_setup():
"""Set up mock PostgresHook for testing."""
table = "test_postgres_hook_table"
cur = mock.MagicMock(rowcount=0)
-    conn = mock.MagicMock()
+    conn = mock.MagicMock(spec=CompatConnection)
conn.cursor.return_value = cur

class UnitTestPostgresHook(PostgresHook):
@@ -812,6 +812,30 @@ def test_insert_rows(self, postgres_hook_setup):
sql = f"INSERT INTO {table} VALUES (%s)"
setup.cur.executemany.assert_any_call(sql, rows)

@mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch")
def test_insert_rows_fast_executemany(self, mock_execute_batch, postgres_hook_setup):
setup = postgres_hook_setup
table = "table"
rows = [("hello",), ("world",)]

setup.db_hook.insert_rows(table, rows, fast_executemany=True)

assert setup.conn.close.call_count == 1
assert setup.cur.close.call_count == 1

commit_count = 2 # The first and last commit
assert setup.conn.commit.call_count == commit_count

mock_execute_batch.assert_called_once_with(
setup.cur,
f"INSERT INTO {table} VALUES (%s)", # expected SQL
[("hello",), ("world",)], # expected values
page_size=1000,
)

# executemany should NOT be called in this mode
setup.cur.executemany.assert_not_called()

def test_insert_rows_replace(self, postgres_hook_setup):
setup = postgres_hook_setup
table = "table"
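
For readers unfamiliar with `psycopg2.extras.execute_batch`: it groups many parameterized statements into far fewer client-server round trips than `cursor.executemany`, which is where the speedup comes from. A standalone sketch outside Airflow; the DSN and table are placeholders:

```python
# Standalone sketch of execute_batch; the DSN and table are placeholders.
import psycopg2
from psycopg2.extras import execute_batch

conn = psycopg2.connect("dbname=test user=postgres")  # placeholder DSN
with conn:  # commits the transaction on successful exit
    with conn.cursor() as cur:
        values = [(i, f"name-{i}") for i in range(5_000)]
        # Statements are sent to the server in pages of page_size rows,
        # instead of one network round trip per row.
        execute_batch(
            cur,
            "INSERT INTO my_table (id, name) VALUES (%s, %s)",
            values,
            page_size=1000,
        )
conn.close()
```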