diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py index 45c36762f90bd..c7839b28aec02 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py @@ -162,7 +162,7 @@ def execute(self, context: Context): self.log.info("Extracting data from %s", self.source_conn_id) self.log.info("Executing: \n %s", self.sql) - results = self.destination_hook.get_records(self.sql) + results = self.source_hook.get_records(self.sql) self.log.info("Inserting rows into %s", self.destination_conn_id) self.destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args) diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py index 12bc6e9295049..f3db9fdd9ad45 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py @@ -209,6 +209,53 @@ def test_templated_fields(self): assert operator.preoperator == "my_preoperator" assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True} + def test_not_paginated_transfer(self): + + mocked_source_hook = mock.MagicMock(conn_name_attr='my_source_conn_id', spec=DbApiHook) + mocked_destination_hook = mock.MagicMock(conn_name_attr='my_destination_conn_id', spec=DbApiHook) + + def get_hook(conn_id: str, hook_params: dict | None = None): + return { + 'my_source_conn_id': mocked_source_hook, + 'my_destination_conn_id': mocked_destination_hook + }[conn_id] + + def get_connection(conn_id: str): + mocked_hook = get_hook(conn_id=conn_id) + mocked_conn = mock.MagicMock(conn_id=conn_id, spec=Connection) + mocked_conn.get_hook.return_value = mocked_hook + return mocked_conn + + sql_statement = "SELECT * FROM generic_transfer" + preoperator_statements = [ + "DROP TABLE IF EXISTS test_generic_transfer", + "CREATE TABLE test_generic_transfer(LIKE generic_transfer INCLUDING INDEXES)" + ] + destination_table = "test_generic_transfer" + operator = GenericTransfer( + task_id="transfer_table", + source_conn_id="my_source_conn_id", + destination_conn_id="my_destination_conn_id", + sql=sql_statement, + preoperator=preoperator_statements, + destination_table=destination_table + ) + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_connection): + with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=get_hook): + execute_operator(operator) + + assert mocked_destination_hook.run.call_count == 1 + assert mocked_destination_hook.run.call_args_list[0].args[0] == preoperator_statements + assert not mocked_source_hook.run.called + + assert mocked_source_hook.get_records.call_count == 1 + assert mocked_source_hook.get_records.call_args_list[0].args[0] == sql_statement + assert not mocked_destination_hook.get_records.called + + assert mocked_destination_hook.insert_rows.call_count == 1 + assert mocked_destination_hook.insert_rows.call_args_list[0].kwargs['table'] == destination_table + assert not mocked_source_hook.insert_rows.called + def test_paginated_read(self): """ This unit test is based on the example described in the medium article: @@ -228,12 +275,15 @@ def side_effect(sql: str): return side_effect - get_records_side_effect = create_get_records_side_effect() + mocked_source_hook = mock.MagicMock(conn_name_attr='my_source_conn_id', spec=DbApiHook) + mocked_source_hook.get_records.side_effect = create_get_records_side_effect() + mocked_destination_hook = mock.MagicMock(conn_name_attr='my_destination_conn_id', spec=DbApiHook) def get_hook(conn_id: str, hook_params: dict | None = None): - mocked_hook = MagicMock(conn_name_attr=conn_id, spec=DbApiHook) - mocked_hook.get_records.side_effect = get_records_side_effect - return mocked_hook + return { + 'my_source_conn_id': mocked_source_hook, + 'my_destination_conn_id': mocked_destination_hook + }[conn_id] def get_connection(conn_id: str): mocked_hook = get_hook(conn_id=conn_id) @@ -266,6 +316,8 @@ def get_connection(conn_id: str): assert events[0].payload["results"] == [[1, 2], [11, 12], [3, 4], [13, 14]] assert events[1].payload["results"] == [[3, 4], [13, 14]] assert not events[2].payload["results"] + assert mocked_source_hook.get_records.called + assert mocked_destination_hook.insert_rows.called def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self): """