diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py index 6caa2000059e9..2a85e462cf14e 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py @@ -1359,7 +1359,7 @@ def column_names(self) -> list[str]: return self.columns def _process_rows(self, context: Context): - return self._rows_processor(context, self.rows) # type: ignore + return self._rows_processor(self.rows, **context) # type: ignore def execute(self, context: Context) -> Any: if not self.rows: diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index c395a78b3618d..a5c1d3819f06b 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -43,6 +43,7 @@ SQLCheckOperator, SQLColumnCheckOperator, SQLExecuteQueryOperator, + SQLInsertRowsOperator, SQLIntervalCheckOperator, SQLTableCheckOperator, SQLThresholdCheckOperator, @@ -1577,3 +1578,32 @@ def test_new_style_subclass(self, mock_get_connection, operator_class): mock_get_connection.return_value.get_hook.return_value = MagicMock(spec=DbApiHook) op.get_db_hook() mock_get_connection.assert_called_once_with("test_conn") + + +class TestSQLInsertRowsOperator: + @mock.patch.object(SQLInsertRowsOperator, "get_db_hook") + def test_rows_processor(self, mock_get_db_hook): + operator = SQLInsertRowsOperator( + task_id="test_task", + conn_id="default_conn", + schema="hollywood", + table_name="actors", + rows=[ + {"index": 1, "name": "Stallone", "firstname": "Sylvester", "age": 78}, + {"index": 2, "name": "Statham", "firstname": "Jason", "age": 57}, + {"index": 3, "name": "Li", "firstname": "Jet", "age": 61}, + {"index": 4, "name": "Lundgren", "firstname": "Dolph", "age": 66}, + {"index": 5, "name": "Norris", "firstname": "Chuck", "age": 84}, + ], + rows_processor=lambda rows, **context: map(lambda row: tuple(row.values()), rows), + ) + + processed_rows = list(operator._process_rows({})) + + assert processed_rows == [ + (1, "Stallone", "Sylvester", 78), + (2, "Statham", "Jason", 57), + (3, "Li", "Jet", 61), + (4, "Lundgren", "Dolph", 66), + (5, "Norris", "Chuck", 84), + ]