diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py index 7e02a47c25b39..ed1e30674c36f 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py @@ -224,6 +224,9 @@ class SQLExecuteQueryOperator(BaseSQLOperator): :param return_last: (optional) return the result of only last statement (default: True). :param show_return_value_in_logs: (optional) if true operator output will be printed to the task log. Use with caution. It's not recommended to dump large datasets to the log. (default: False). + :param requires_result_fetch: (optional) if True, ensures that query results are fetched before + completing execution. If `do_xcom_push` is True, results are fetched automatically, + making this parameter redundant. (default: False). .. seealso:: For more information on how to use this operator, take a look at the guide: @@ -254,6 +257,7 @@ def __init__( split_statements: bool | None = None, return_last: bool = True, show_return_value_in_logs: bool = False, + requires_result_fetch: bool = False, **kwargs, ) -> None: super().__init__(conn_id=conn_id, database=database, **kwargs) @@ -265,6 +269,7 @@ def __init__( self.split_statements = split_statements self.return_last = return_last self.show_return_value_in_logs = show_return_value_in_logs + self.requires_result_fetch = requires_result_fetch def _process_output( self, results: list[Any], descriptions: list[Sequence[Sequence] | None] @@ -303,7 +308,9 @@ def execute(self, context): sql=self.sql, autocommit=self.autocommit, parameters=self.parameters, - handler=self.handler if self._should_run_output_processing() else None, + handler=self.handler + if self._should_run_output_processing() or self.requires_result_fetch + else None, return_last=self.return_last, **extra_kwargs, ) diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index 28c18a32cc55d..2e93686ac8908 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -124,8 +124,11 @@ def _construct_operator(self, sql, **kwargs): @mock.patch.object(SQLExecuteQueryOperator, "_process_output") @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook") - def test_do_xcom_push(self, mock_get_db_hook, mock_process_output): - operator = self._construct_operator("SELECT 1;", do_xcom_push=True) + @pytest.mark.parametrize("requires_result_fetch", [True, False]) + def test_do_xcom_push(self, mock_get_db_hook, mock_process_output, requires_result_fetch): + operator = self._construct_operator( + "SELECT 1;", do_xcom_push=True, requires_result_fetch=requires_result_fetch + ) operator.execute(context=MagicMock()) mock_get_db_hook.return_value.run.assert_called_once_with( @@ -152,6 +155,21 @@ def test_dont_xcom_push(self, mock_get_db_hook, mock_process_output): ) mock_process_output.assert_not_called() + @mock.patch.object(SQLExecuteQueryOperator, "_process_output") + @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook") + def test_requires_result_fetch_dont_xcom_push(self, mock_get_db_hook, mock_process_output): + operator = self._construct_operator("SELECT 1;", requires_result_fetch=True, do_xcom_push=False) + operator.execute(context=MagicMock()) + + mock_get_db_hook.return_value.run.assert_called_once_with( + sql="SELECT 1;", + autocommit=False, + handler=fetch_all_handler, + parameters=None, + return_last=True, + ) + mock_process_output.assert_not_called() + @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook") def test_output_processor(self, mock_get_db_hook): data = [(1, "Alice"), (2, "Bob")] diff --git a/providers/trino/tests/system/trino/example_trino.py b/providers/trino/tests/system/trino/example_trino.py index db9fef4128b93..611ed68d96a11 100644 --- a/providers/trino/tests/system/trino/example_trino.py +++ b/providers/trino/tests/system/trino/example_trino.py @@ -36,35 +36,36 @@ with models.DAG( dag_id="example_trino", schedule="@once", # Override to match your needs - start_date=datetime(2022, 1, 1), + start_date=datetime(2025, 2, 24), catchup=False, tags=["example"], ) as dag: trino_create_schema = SQLExecuteQueryOperator( task_id="trino_create_schema", - sql=f"CREATE SCHEMA IF NOT EXISTS {SCHEMA} WITH (location = 's3://irisbkt/cities/');", + sql=f" CREATE SCHEMA IF NOT EXISTS {SCHEMA} WITH (location = 's3://irisbkt/cities/') ", handler=list, ) trino_create_table = SQLExecuteQueryOperator( task_id="trino_create_table", - sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE}( - cityid bigint, - cityname varchar - )""", + sql=f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE}( cityid bigint, cityname varchar) ", handler=list, ) trino_insert = SQLExecuteQueryOperator( task_id="trino_insert", - sql=f"""INSERT INTO {SCHEMA}.{TABLE} VALUES (1, 'San Francisco');""", + sql=f" INSERT INTO {SCHEMA}.{TABLE} VALUES (1, 'San Francisco') ", handler=list, + requires_result_fetch=True, ) trino_multiple_queries = SQLExecuteQueryOperator( task_id="trino_multiple_queries", - sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar); - INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose'); - CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar); - INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego');""", + sql=[ + f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar) ", + f" INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose') ", + f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar) ", + f" INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego') ", + ], handler=list, + requires_result_fetch=True, ) trino_templated_query = SQLExecuteQueryOperator( task_id="trino_templated_query", @@ -74,7 +75,7 @@ ) trino_parameterized_query = SQLExecuteQueryOperator( task_id="trino_parameterized_query", - sql=f"select * from {SCHEMA}.{TABLE2} where cityname = ?", + sql=f" SELECT * FROM {SCHEMA}.{TABLE2} WHERE cityname = ?", parameters=("San Diego",), handler=list, )