diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
index 9c159deba7478..7a3be0ff4e314 100644
--- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
@@ -26,7 +26,8 @@
 import psycopg2
 import psycopg2.extensions
 import psycopg2.extras
-from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
+from more_itertools import chunked
+from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor, execute_batch
 from sqlalchemy.engine import URL
 
 from airflow.exceptions import (
@@ -574,3 +575,79 @@ def get_db_log_messages(self, conn) -> None:
 
         for output in conn.notices:
             self.log.info(output)
+
+    def insert_rows(
+        self,
+        table,
+        rows,
+        target_fields=None,
+        commit_every=1000,
+        replace=False,
+        *,
+        executemany=False,
+        fast_executemany=False,
+        autocommit=False,
+        **kwargs,
+    ):
+        """
+        Insert a collection of tuples into a table.
+
+        Rows are inserted in chunks; each chunk (of size ``commit_every``) is
+        done in a new transaction.
+
+        :param table: Name of the target table
+        :param rows: The rows to insert into the table
+        :param target_fields: The names of the columns to fill in the table
+        :param commit_every: The maximum number of rows to insert in one
+            transaction. Set to 0 to insert all rows in one transaction.
+        :param replace: Whether to replace instead of insert
+        :param executemany: If True, all rows are inserted at once in
+            chunks defined by the commit_every parameter. This only works if all
+            rows have the same number of columns, but leads to better performance.
+        :param fast_executemany: If True, rows will be inserted using an optimized
+            bulk execution strategy (``psycopg2.extras.execute_batch``). This can
+            significantly improve performance for large inserts. If set to False,
+            the method falls back to the default implementation from
+            ``DbApiHook.insert_rows``.
+        :param autocommit: What to set the connection's autocommit setting to
+            before executing the query.
+        """
+        # If fast_executemany is disabled, defer to the default implementation in DbApiHook.
+        if not fast_executemany:
+            return super().insert_rows(
+                table,
+                rows,
+                target_fields=target_fields,
+                commit_every=commit_every,
+                replace=replace,
+                executemany=executemany,
+                autocommit=autocommit,
+                **kwargs,
+            )
+
+        # If fast_executemany is enabled, use the optimized execute_batch from psycopg2.
+        nb_rows = 0
+        with self._create_autocommit_connection(autocommit) as conn:
+            conn.commit()
+            with closing(conn.cursor()) as cur:
+                for chunked_rows in chunked(rows, commit_every):
+                    values = list(
+                        map(
+                            lambda row: self._serialize_cells(row, conn),
+                            chunked_rows,
+                        )
+                    )
+                    sql = self._generate_insert_sql(table, values[0], target_fields, replace, **kwargs)
+                    self.log.debug("Generated sql: %s", sql)
+
+                    try:
+                        execute_batch(cur, sql, values, page_size=commit_every)
+                    except Exception as e:
+                        self.log.error("Generated sql: %s", sql)
+                        self.log.error("Parameters: %s", values)
+                        raise e
+
+                    conn.commit()
+                    nb_rows += len(chunked_rows)
+                    self.log.info("Loaded %s rows into %s so far", nb_rows, table)
+        self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
index 5a842666a5a2c..f394c561c2c53 100644
--- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -30,7 +30,7 @@
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.postgres.dialects.postgres import PostgresDialect
-from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.providers.postgres.hooks.postgres import CompatConnection, PostgresHook
 from airflow.utils.types import NOTSET
 
 from tests_common.test_utils.common_sql import mock_db_hook
@@ -65,7 +65,7 @@ def postgres_hook_setup():
     """Set up mock PostgresHook for testing."""
     table = "test_postgres_hook_table"
     cur = mock.MagicMock(rowcount=0)
-    conn = mock.MagicMock()
+    conn = mock.MagicMock(spec=CompatConnection)
     conn.cursor.return_value = cur
 
     class UnitTestPostgresHook(PostgresHook):
@@ -812,6 +812,30 @@ def test_insert_rows(self, postgres_hook_setup):
         sql = f"INSERT INTO {table} VALUES (%s)"
         setup.cur.executemany.assert_any_call(sql, rows)
 
+    @mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch")
+    def test_insert_rows_fast_executemany(self, mock_execute_batch, postgres_hook_setup):
+        setup = postgres_hook_setup
+        table = "table"
+        rows = [("hello",), ("world",)]
+
+        setup.db_hook.insert_rows(table, rows, fast_executemany=True)
+
+        assert setup.conn.close.call_count == 1
+        assert setup.cur.close.call_count == 1
+
+        commit_count = 2  # the initial commit plus one commit for the single chunk
+        assert setup.conn.commit.call_count == commit_count
+
+        mock_execute_batch.assert_called_once_with(
+            setup.cur,
+            f"INSERT INTO {table} VALUES (%s)",  # expected SQL
+            [("hello",), ("world",)],  # expected serialized values
+            page_size=1000,
+        )
+
+        # executemany should NOT be called in this mode
+        setup.cur.executemany.assert_not_called()
+
     def test_insert_rows_replace(self, postgres_hook_setup):
         setup = postgres_hook_setup
         table = "table"
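
For context, a minimal usage sketch of the new flag from a hook caller's point of view. The connection id, table name, and target fields below are illustrative assumptions, not part of this change; only `fast_executemany` and the existing `insert_rows` parameters come from the diff.

```python
# Sketch only: assumes a configured "postgres_default" connection and an
# existing "users(name, score)" table — both hypothetical names.
from airflow.providers.postgres.hooks.postgres import PostgresHook

hook = PostgresHook(postgres_conn_id="postgres_default")
rows = [("alice", 1), ("bob", 2)]

# Default path: falls back to DbApiHook.insert_rows.
hook.insert_rows("users", rows, target_fields=["name", "score"])

# New path: psycopg2.extras.execute_batch batches the INSERTs,
# committing after each chunk of `commit_every` rows.
hook.insert_rows(
    "users",
    rows,
    target_fields=["name", "score"],
    commit_every=1000,
    fast_executemany=True,
)
```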