diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
index 84bc367a94c6d..cadf722d3ae07 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py
@@ -75,7 +75,7 @@ class GenericTransfer(BaseOperator):
     def __init__(
         self,
         *,
-        sql: str,
+        sql: str | list[str],
         destination_table: str,
         source_conn_id: str,
         source_hook_params: dict | None = None,
@@ -156,13 +156,20 @@ def execute(self, context: Context):
                 method_name=self.execute_complete.__name__,
             )
         else:
+            # Work on a local list rather than rebinding ``self.sql``, so running the task
+            # does not change the type of the attribute the caller passed in.
+            statements = [self.sql] if isinstance(self.sql, str) else self.sql
+
             self.log.info("Extracting data from %s", self.source_conn_id)
-            self.log.info("Executing: \n %s", self.sql)
+            for sql in statements:
+                self.log.info("Executing: \n %s", sql)
 
-            results = self.source_hook.get_records(self.sql)
+                results = self.source_hook.get_records(sql)
 
-            self.log.info("Inserting rows into %s", self.destination_conn_id)
-            self.destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args)
+                self.log.info("Inserting rows into %s", self.destination_conn_id)
+                self.destination_hook.insert_rows(
+                    table=self.destination_table, rows=results, **self.insert_args
+                )
 
     def execute_complete(
         self,
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi
index 64606e44aa679..06daa77ee747c 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi
@@ -57,7 +57,7 @@ class GenericTransfer(BaseOperator):
     def __init__(
         self,
         *,
-        sql: str,
+        sql: str | list[str],
         destination_table: str,
         source_conn_id: str,
         source_hook_params: dict | None = None,
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
index ee82b5440162b..d6b16957f8730 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py
@@ -289,6 +289,38 @@ def test_non_paginated_read(self):
             **{"rows": [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
         }
 
+    def test_non_paginated_read_for_multiple_sql_statements(self):
+        with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection):
+            with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook):
+                operator = GenericTransfer(
+                    task_id="transfer_table",
+                    source_conn_id="my_source_conn_id",
+                    destination_conn_id="my_destination_conn_id",
+                    sql=["SELECT * FROM HR.EMPLOYEES", "SELECT * FROM HR.PEOPLE"],
+                    destination_table="NEW_HR.EMPLOYEES",
+                    insert_args=INSERT_ARGS,
+                    execution_timeout=timedelta(hours=1),
+                )
+
+                operator.execute(context=mock_context(task=operator))
+
+        assert self.mocked_source_hook.get_records.call_count == 2
+        assert [call.args[0] for call in self.mocked_source_hook.get_records.call_args_list] == [
+            "SELECT * FROM HR.EMPLOYEES",
+            "SELECT * FROM HR.PEOPLE",
+        ]
+        assert self.mocked_destination_hook.insert_rows.call_count == 2
+        assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
+            **INSERT_ARGS,
+            "rows": [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]],
+            "table": "NEW_HR.EMPLOYEES",
+        }
+        assert self.mocked_destination_hook.insert_rows.call_args_list[1].kwargs == {
+            **INSERT_ARGS,
+            "rows": [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]],
+            "table": "NEW_HR.EMPLOYEES",
+        }
+
     def test_paginated_read(self):
         """
         This unit test is based on the example described in the medium article: