diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py index 6c5e149cd3450..64ef443d06b71 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import cached_property from typing import TYPE_CHECKING, Any @@ -47,6 +47,9 @@ class GenericTransfer(BaseOperator): :param source_hook_params: source hook parameters. :param destination_conn_id: destination connection. (templated) :param destination_hook_params: destination hook parameters. + :param rows_processor: (optional) A callable applied once per batch of rows before insertion. + It receives the full list of rows and the task context, and must return a list of rows compatible with + the underlying hook. :param preoperator: sql statement or list of statements to be executed prior to loading the data. (templated) :param insert_args: extra params for `insert_rows` method. @@ -80,6 +83,9 @@ def __init__( source_hook_params: dict | None = None, destination_conn_id: str, destination_hook_params: dict | None = None, + rows_processor: Callable[..., list[Any]] | None = None, + # rows_processor is called as rows_processor(rows, **context); + # context keys vary, so Callable[..., list[Any]] is intentional. 
preoperator: str | list[str] | None = None, insert_args: dict | None = None, page_size: int | None = None, @@ -93,6 +99,7 @@ def __init__( self.source_hook_params = source_hook_params self.destination_conn_id = destination_conn_id self.destination_hook_params = destination_hook_params + self._rows_processor = rows_processor self.preoperator = preoperator self.insert_args = insert_args or {} self.page_size = page_size @@ -139,6 +146,14 @@ def render_template_fields( if isinstance(commit_every, str): self.insert_args["commit_every"] = int(commit_every) + def _insert_rows(self, rows: list[Any], context: Context): + if self._rows_processor: + rows = self._rows_processor(rows, **context) + + self.log.info("Inserting %d rows into %s", len(rows), self.destination_conn_id) + + self.destination_hook.insert_rows(table=self.destination_table, rows=rows, **self.insert_args) + def execute(self, context: Context): if self.preoperator: self.log.info("Running preoperator") @@ -162,12 +177,8 @@ def execute(self, context: Context): for sql in self.sql: self.log.info("Executing: \n %s", sql) - results = self.source_hook.get_records(sql) - - self.log.info("Inserting rows into %s", self.destination_conn_id) - self.destination_hook.insert_rows( - table=self.destination_table, rows=results, **self.insert_args - ) + rows = self.source_hook.get_records(sql) + self._insert_rows(rows=rows, context=context) def execute_complete( self, @@ -178,9 +189,9 @@ def execute_complete( if event.get("status") == "failure": raise AirflowException(event.get("message")) - results = event.get("results") + rows = event.get("results") - if results: + if rows: map_index = context["ti"].map_index offset = ( context["ti"].xcom_pull( @@ -196,15 +207,7 @@ def execute_complete( self.log.info("Offset increased to %d", offset) context["ti"].xcom_push(key="offset", value=offset) - self.log.info("Inserting %d rows into %s", len(results), self.destination_conn_id) - self.destination_hook.insert_rows( - 
table=self.destination_table, rows=results, **self.insert_args - ) - self.log.info( - "Inserting %d rows into %s done!", - len(results), - self.destination_conn_id, - ) + self._insert_rows(rows=rows, context=context) self.defer( trigger=SQLExecuteQueryTrigger( diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi index acd1bca2698a9..3194e3877fdc2 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi @@ -29,7 +29,7 @@ # """Definition of the public interface for airflow.providers.common.sql.operators.generic_transfer.""" -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import cached_property as cached_property from typing import Any, ClassVar @@ -51,6 +51,7 @@ class GenericTransfer(BaseOperator): source_hook_params: Incomplete destination_conn_id: Incomplete destination_hook_params: Incomplete + rows_processor: Incomplete preoperator: Incomplete insert_args: Incomplete page_size: Incomplete @@ -63,6 +64,7 @@ class GenericTransfer(BaseOperator): source_hook_params: dict | None = None, destination_conn_id: str, destination_hook_params: dict | None = None, + rows_processor: Callable[..., list[Any]] | None = None, preoperator: str | list[str] | None = None, insert_args: dict | None = None, page_size: int | None = None, diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py index ab9569ee18f9c..b032593459400 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py @@ -37,8 +37,6 @@ from airflow.utils.helpers import merge_dicts if TYPE_CHECKING: - import 
jinja2 - from airflow.providers.common.compat.sdk import Context from airflow.providers.openlineage.extractors import OperatorLineage @@ -1307,7 +1305,9 @@ class SQLInsertRowsOperator(BaseSQLOperator): :param rows: the rows to insert into the table. Rows can be a list of tuples or a list of dictionaries. When a list of dictionaries is provided, the column names are inferred from the dictionary keys and will be matched with the column names, ignored columns will be filtered out. - :rows_processor: (optional) a function that will be applied to the rows before inserting them into the table. + :param rows_processor: (optional) A callable applied once per batch of rows before insertion. + It receives the full list of rows and the task context, and must return a list of rows compatible with + the underlying hook. :param preoperator: sql statement or list of statements to be executed prior to loading the data. (templated) :param postoperator: sql statement or list of statements to be executed after loading the data. (templated) :param insert_args: (optional) dictionary of additional arguments passed to the underlying hook's @@ -1343,7 +1343,9 @@ def __init__( columns: Iterable[str] | None = None, ignored_columns: Iterable[str] | None = None, rows: list[Any] | XComArg | None = None, - rows_processor: Callable[[Any, Context], Any] = lambda rows, **context: rows, + rows_processor: Callable[..., list[Any]] | None = None, + # rows_processor is called as rows_processor(rows, **context); + # context keys vary, so Callable[..., list[Any]] is intentional. 
preoperator: str | list[str] | None = None, postoperator: str | list[str] | None = None, hook_params: dict | None = None, @@ -1367,16 +1369,6 @@ def __init__( self.insert_args = insert_args or {} self.do_xcom_push = False - def render_template_fields( - self, - context: Context, - jinja_env: jinja2.Environment | None = None, - ) -> None: - super().render_template_fields(context=context, jinja_env=jinja_env) - - if isinstance(self.rows, XComArg): - self.rows = self.rows.resolve(context=context) - @property def table_name_with_schema(self) -> str: if self.schema is not None: @@ -1395,11 +1387,23 @@ def column_names(self) -> list[str]: return [column for column in self.columns if column not in self.ignored_columns] return self.columns - def _process_rows(self, context: Context): - return self._rows_processor(self.rows, **context) # type: ignore + def _insert_rows(self, rows: list[Any], context: Context): + if self._rows_processor: + rows = self._rows_processor(rows, **context) + + self.log.info("Inserting %d rows into %s", len(rows), self.conn_id) + + self.get_db_hook().insert_rows( + table=self.table_name_with_schema, + rows=rows, + target_fields=self.column_names, + **self.insert_args, + ) def execute(self, context: Context) -> Any: - if not self.rows: + rows = self.rows.resolve(context=context) if isinstance(self.rows, XComArg) else self.rows + + if not rows: raise AirflowSkipException(f"Skipping task {self.task_id} because no rows.") self.log.debug("Table: %s", self.table_name_with_schema) @@ -1408,13 +1412,7 @@ def execute(self, context: Context) -> Any: self.log.debug("Running preoperator") self.log.debug(self.preoperator) self.get_db_hook().run(self.preoperator) - rows = self._process_rows(context=context) - self.get_db_hook().insert_rows( - table=self.table_name_with_schema, - rows=rows, - target_fields=self.column_names, - **self.insert_args, - ) + self._insert_rows(rows=rows, context=context) if self.postoperator: self.log.debug("Running postoperator") 
self.log.debug(self.postoperator) diff --git a/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.py b/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.py index c1188e0e1b348..f9d8980f332ad 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING from airflow.providers.common.compat.sdk import AirflowException, BaseHook +from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_2_PLUS from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -78,15 +79,23 @@ def get_hook(self) -> DbApiHook: ) return hook + async def _get_records(self) -> Any: + from asgiref.sync import sync_to_async + + hook = self.get_hook() + + if AIRFLOW_V_3_2_PLUS: + # This is only supported from Airflow 3.2 or higher due to added async support in CommsDecoder + return await sync_to_async(hook.get_records)(self.sql) + return hook.get_records(self.sql) + async def run(self) -> AsyncIterator[TriggerEvent]: try: - hook = self.get_hook() - self.log.info("Extracting data from %s", self.conn_id) self.log.info("Executing: \n %s", self.sql) self.log.info("Reading records from %s", self.conn_id) - results = hook.get_records(self.sql) + results = await self._get_records() self.log.info("Reading records from %s done!", self.conn_id) self.log.debug("results: %s", results) diff --git a/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.pyi b/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.pyi index 529da972c7dc2..3a14c65998d20 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.pyi +++ b/providers/common/sql/src/airflow/providers/common/sql/triggers/sql.pyi @@ -32,6 +32,7 @@ from collections.abc import AsyncIterator from typing import Any +from airflow.providers.common.sql.hooks.sql 
import DbApiHook as DbApiHook from airflow.triggers.base import BaseTrigger as BaseTrigger, TriggerEvent as TriggerEvent class SQLExecuteQueryTrigger(BaseTrigger): @@ -39,4 +40,5 @@ class SQLExecuteQueryTrigger(BaseTrigger): self, sql: str | list[str], conn_id: str, hook_params: dict | None = None, **kwargs ) -> None: ... def serialize(self) -> tuple[str, dict[str, Any]]: ... + def get_hook(self) -> DbApiHook: ... async def run(self) -> AsyncIterator[TriggerEvent]: ... # type: ignore diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py index d6b16957f8730..ea94fa4e3541e 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py @@ -206,6 +206,10 @@ def get_connection(cls, conn_id: str): mocked_conn.get_hook.return_value = mocked_hook return mocked_conn + @classmethod + def convert_to_tuples(cls, rows, **context): + return [tuple(row) for row in rows] + def setup_method(self): # Reset mock states before each test self.mocked_source_hook.reset_mock() @@ -289,6 +293,30 @@ def test_non_paginated_read(self): **{"rows": [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"}, } + def test_non_paginated_read_with_rows_processor(self): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook): + operator = GenericTransfer( + task_id="transfer_table", + source_conn_id="my_source_conn_id", + destination_conn_id="my_destination_conn_id", + sql="SELECT * FROM HR.EMPLOYEES", + destination_table="NEW_HR.EMPLOYEES", + insert_args=INSERT_ARGS, + execution_timeout=timedelta(hours=1), + rows_processor=self.convert_to_tuples, + ) + + operator.execute(context=mock_context(task=operator)) + + assert 
self.mocked_source_hook.get_records.call_count == 1 + assert self.mocked_source_hook.get_records.call_args_list[0].args[0] == "SELECT * FROM HR.EMPLOYEES" + assert self.mocked_destination_hook.insert_rows.call_count == 1 + assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == { + **INSERT_ARGS, + **{"rows": [(1, 2), (11, 12), (3, 4), (13, 14), (3, 4), (13, 14)], "table": "NEW_HR.EMPLOYEES"}, + } + def test_non_paginated_read_for_multiple_sql_statements(self): with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection): with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook): @@ -321,6 +349,39 @@ def test_non_paginated_read_for_multiple_sql_statements(self): "table": "NEW_HR.EMPLOYEES", } + def test_non_paginated_read_for_multiple_sql_statements_with_rows_processor(self): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook): + operator = GenericTransfer( + task_id="transfer_table", + source_conn_id="my_source_conn_id", + destination_conn_id="my_destination_conn_id", + sql=["SELECT * FROM HR.EMPLOYEES", "SELECT * FROM HR.PEOPLE"], + destination_table="NEW_HR.EMPLOYEES", + insert_args=INSERT_ARGS, + execution_timeout=timedelta(hours=1), + rows_processor=self.convert_to_tuples, + ) + + operator.execute(context=mock_context(task=operator)) + + assert self.mocked_source_hook.get_records.call_count == 2 + assert [call.args[0] for call in self.mocked_source_hook.get_records.call_args_list] == [ + "SELECT * FROM HR.EMPLOYEES", + "SELECT * FROM HR.PEOPLE", + ] + assert self.mocked_destination_hook.insert_rows.call_count == 2 + assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == { + **INSERT_ARGS, + "rows": [(1, 2), (11, 12), (3, 4), (13, 14), (3, 4), (13, 14)], + "table": "NEW_HR.EMPLOYEES", + } + assert 
self.mocked_destination_hook.insert_rows.call_args_list[1].kwargs == { + **INSERT_ARGS, + "rows": [(1, 2), (11, 12), (3, 4), (13, 14), (3, 4), (13, 14)], + "table": "NEW_HR.EMPLOYEES", + } + def test_paginated_read(self): """ This unit test is based on the example described in the medium article: diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index 03db9cf4ec139..c438a9b0709e3 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -30,6 +30,7 @@ from airflow.models import Connection from airflow.providers.common.compat.sdk import AirflowException from airflow.providers.common.sql.hooks.handlers import fetch_all_handler +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.operators.sql import ( BaseSQLOperator, BranchSQLOperator, @@ -1608,7 +1609,47 @@ def test_new_style_subclass(self, mock_get_connection, operator_class): class TestSQLInsertRowsOperator: @mock.patch.object(SQLInsertRowsOperator, "get_db_hook") - def test_rows_processor(self, mock_get_db_hook): + def test_insert_rows_operator_with_preoperator(self, mock_get_db_hook): + mock_hook = MagicMock(spec=DbApiHook) + mock_get_db_hook.return_value = mock_hook + + operator = SQLInsertRowsOperator( + task_id="test_task", + conn_id="default_conn", + schema="hollywood", + table_name="actors", + preoperator="TRUNCATE TABLE hollywood.actors", + rows=[ + (1, "Stallone", "Sylvester", 78), + (2, "Statham", "Jason", 57), + (3, "Li", "Jet", 61), + (4, "Lundgren", "Dolph", 66), + (5, "Norris", "Chuck", 84), + ], + ) + + operator.execute({}) + + mock_hook.run.assert_called_once() + args, _ = mock_hook.run.call_args + assert args[0] == "TRUNCATE TABLE hollywood.actors" + + mock_hook.insert_rows.assert_called_once() + _, kwargs = mock_hook.insert_rows.call_args + + assert kwargs["rows"] == 
[ + (1, "Stallone", "Sylvester", 78), + (2, "Statham", "Jason", 57), + (3, "Li", "Jet", 61), + (4, "Lundgren", "Dolph", 66), + (5, "Norris", "Chuck", 84), + ] + + @mock.patch.object(SQLInsertRowsOperator, "get_db_hook") + def test_insert_rows_operator_with_rows_processor(self, mock_get_db_hook): + mock_hook = MagicMock(spec=DbApiHook) + mock_get_db_hook.return_value = mock_hook + operator = SQLInsertRowsOperator( task_id="test_task", conn_id="default_conn", @@ -1621,12 +1662,15 @@ def test_rows_processor(self, mock_get_db_hook): {"index": 4, "name": "Lundgren", "firstname": "Dolph", "age": 66}, {"index": 5, "name": "Norris", "firstname": "Chuck", "age": 84}, ], - rows_processor=lambda rows, **context: map(lambda row: tuple(row.values()), rows), + rows_processor=lambda rows, **context: [tuple(row.values()) for row in rows], ) - processed_rows = list(operator._process_rows({})) + operator.execute({}) + + mock_hook.insert_rows.assert_called_once() + _, kwargs = mock_hook.insert_rows.call_args - assert processed_rows == [ + assert kwargs["rows"] == [ (1, "Stallone", "Sylvester", 78), (2, "Statham", "Jason", 57), (3, "Li", "Jet", 61),