# Patch summary:
# - generic_transfer.py: in execute(), the records are now fetched with
#   self.source_hook.get_records(self.sql) instead of self.destination_hook.get_records(self.sql).
# - test_generic_transfer.py: adds a module-level INSERT_ARGS dict, moves the mocked source and
#   destination hooks plus the get_hook/get_connection helpers onto TestGenericTransfer (the mocks
#   are reset in setup_method), adds test_non_paginated_read, and extends test_paginated_read with
#   assertions on the get_records and insert_rows calls.
# NOTE(review): line breaks and indentation were reconstructed from a whitespace-collapsed copy
# using the hunk line counts; the position of blank context lines and the indentation depth of
# context lines are presumed - confirm against the original patch before applying.
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
index 45c36762f90bd..c7839b28aec02 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
@@ -162,7 +162,7 @@ def execute(self, context: Context):
             self.log.info("Extracting data from %s", self.source_conn_id)
             self.log.info("Executing: \n %s", self.sql)
 
-            results = self.destination_hook.get_records(self.sql)
+            results = self.source_hook.get_records(self.sql)
 
             self.log.info("Inserting rows into %s", self.destination_conn_id)
             self.destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args)
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
index 12bc6e9295049..92b118d703ae0 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
@@ -24,6 +24,7 @@
 from unittest.mock import MagicMock
 
 import pytest
+from more_itertools import flatten
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models.connection import Connection
@@ -34,7 +35,7 @@
 from airflow.utils import timezone
 
 from tests_common.test_utils.compat import GenericTransfer
-from tests_common.test_utils.operators.run_deferrable import execute_operator
+from tests_common.test_utils.operators.run_deferrable import execute_operator, mock_context
 from tests_common.test_utils.providers import get_provider_min_airflow_version
 
 pytestmark = pytest.mark.db_test
@@ -43,6 +44,12 @@
 DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
 DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
 TEST_DAG_ID = "unit_test_dag"
+INSERT_ARGS = {
+    "commit_every": 1000,  # Number of rows inserted in each batch
+    "executemany": True,  # Enable batch inserts
+    "fast_executemany": True,  # Boost performance for MSSQL inserts
+    "replace": True,  # Used for upserts/merges if needed
+}
 counter = 0
 
 
@@ -175,6 +182,44 @@ def test_postgres_to_postgres_replace(self, mock_insert, dag_maker):
 
 
 class TestGenericTransfer:
+    mocked_source_hook = MagicMock(conn_name_attr="my_source_conn_id", spec=DbApiHook)
+    mocked_destination_hook = MagicMock(conn_name_attr="my_destination_conn_id", spec=DbApiHook)
+    mocked_hooks = {
+        "my_source_conn_id": mocked_source_hook,
+        "my_destination_conn_id": mocked_destination_hook,
+    }
+
+    @classmethod
+    def get_hook(cls, conn_id: str, hook_params: dict | None = None):
+        return cls.mocked_hooks[conn_id]
+
+    @classmethod
+    def get_connection(cls, conn_id: str):
+        mocked_hook = cls.get_hook(conn_id=conn_id)
+        mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
+        mocked_conn.get_hook.return_value = mocked_hook
+        return mocked_conn
+
+    def setup_method(self):
+        # Reset mock states before each test
+        self.mocked_source_hook.reset_mock()
+        self.mocked_destination_hook.reset_mock()
+
+        # Set up the side effect for paginated read
+        records = [
+            [[1, 2], [11, 12], [3, 4], [13, 14]],
+            [[3, 4], [13, 14]],
+        ]
+
+        def get_records_side_effect(sql: str):
+            if records:
+                if "LIMIT" not in sql:
+                    return list(flatten(records))
+                return records.pop(0)
+            return []
+
+        self.mocked_source_hook.get_records.side_effect = get_records_side_effect
+
     def test_templated_fields(self):
         dag = DAG(
             "test_dag",
@@ -209,40 +254,37 @@ def test_templated_fields(self):
         assert operator.preoperator == "my_preoperator"
         assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True}
 
+    def test_non_paginated_read(self):
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection):
+            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook):
+                operator = GenericTransfer(
+                    task_id="transfer_table",
+                    source_conn_id="my_source_conn_id",
+                    destination_conn_id="my_destination_conn_id",
+                    sql="SELECT * FROM HR.EMPLOYEES",
+                    destination_table="NEW_HR.EMPLOYEES",
+                    insert_args=INSERT_ARGS,
+                    execution_timeout=timedelta(hours=1),
+                )
+
+                operator.execute(context=mock_context(task=operator))
+
+        assert self.mocked_source_hook.get_records.call_count == 1
+        assert self.mocked_source_hook.get_records.call_args_list[0].args[0] == "SELECT * FROM HR.EMPLOYEES"
+        assert self.mocked_destination_hook.insert_rows.call_count == 1
+        assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
+            **INSERT_ARGS,
+            **{"rows": [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
+        }
+
     def test_paginated_read(self):
         """
         This unit test is based on the example described in the medium article:
         https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f
         """
 
-        def create_get_records_side_effect():
-            records = [
-                [[1, 2], [11, 12], [3, 4], [13, 14]],
-                [[3, 4], [13, 14]],
-            ]
-
-            def side_effect(sql: str):
-                if records:
-                    return records.pop(0)
-                return []
-
-            return side_effect
-
-        get_records_side_effect = create_get_records_side_effect()
-
-        def get_hook(conn_id: str, hook_params: dict | None = None):
-            mocked_hook = MagicMock(conn_name_attr=conn_id, spec=DbApiHook)
-            mocked_hook.get_records.side_effect = get_records_side_effect
-            return mocked_hook
-
-        def get_connection(conn_id: str):
-            mocked_hook = get_hook(conn_id=conn_id)
-            mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
-            mocked_conn.get_hook.return_value = mocked_hook
-            return mocked_conn
-
-        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_connection):
-            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=get_hook):
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection):
+            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook):
                 operator = GenericTransfer(
                     task_id="transfer_table",
                     source_conn_id="my_source_conn_id",
@@ -250,12 +292,7 @@ def get_connection(conn_id: str):
                     sql="SELECT * FROM HR.EMPLOYEES",
                     destination_table="NEW_HR.EMPLOYEES",
                     page_size=1000,  # Fetch data in chunks of 1000 rows for pagination
-                    insert_args={
-                        "commit_every": 1000,  # Number of rows inserted in each batch
-                        "executemany": True,  # Enable batch inserts
-                        "fast_executemany": True,  # Boost performance for MSSQL inserts
-                        "replace": True,  # Used for upserts/merges if needed
-                    },
+                    insert_args=INSERT_ARGS,
                     execution_timeout=timedelta(hours=1),
                 )
 
@@ -267,6 +304,21 @@ def get_connection(conn_id: str):
                 assert events[1].payload["results"] == [[3, 4], [13, 14]]
                 assert not events[2].payload["results"]
 
+        assert self.mocked_source_hook.get_records.call_count == 3
+        assert (
+            self.mocked_source_hook.get_records.call_args_list[0].args[0]
+            == "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0"
+        )
+        assert self.mocked_destination_hook.insert_rows.call_count == 2
+        assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
+            **INSERT_ARGS,
+            **{"rows": [[1, 2], [11, 12], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
+        }
+        assert self.mocked_destination_hook.insert_rows.call_args_list[1].kwargs == {
+            **INSERT_ARGS,
+            **{"rows": [[3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
+        }
+
     def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self):
         """
         Once this test starts failing due to the fact that the minimum Airflow version is now 3.0.0 or higher