diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py b/providers/trino/src/airflow/providers/trino/hooks/trino.py index b59ba32cb59f7..28d23423f6cb1 100644 --- a/providers/trino/src/airflow/providers/trino/hooks/trino.py +++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py @@ -19,9 +19,9 @@ import json import os -from collections.abc import Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping from pathlib import Path -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, overload from urllib.parse import quote_plus, urlencode import trino @@ -277,6 +277,50 @@ def _get_pandas_df(self, sql: str = "", parameters=None, **kwargs): df = pd.DataFrame(**kwargs) return df + @overload + def run( + self, + sql: str | Iterable[str], + autocommit: bool = ..., + parameters: Iterable | Mapping[str, Any] | None = ..., + handler: None = ..., + split_statements: bool = ..., + return_last: bool = ..., + ) -> None: ... + + @overload + def run( + self, + sql: str | Iterable[str], + autocommit: bool = ..., + parameters: Iterable | Mapping[str, Any] | None = ..., + handler: Callable[[Any], T] = ..., + split_statements: bool = ..., + return_last: bool = ..., + ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ... + + def run( + self, + sql: str | Iterable[str], + autocommit: bool = False, + parameters: Iterable | Mapping[str, Any] | None = None, + handler: Callable[[Any], T] | None = None, + split_statements: bool = True, + return_last: bool = True, + ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: + """ + Override common run to set split_statements=True by default. + + :param sql: SQL statement or list of statements to execute. + :param autocommit: Set autocommit mode before query execution. + :param parameters: Parameters to render the SQL query with. + :param handler: Optional callable to process each statement result. + :param split_statements: Split single SQL string into statements if True. + :param return_last: Return only last statement result if True. + :return: Query result or list of results. + """ + return super().run(sql, autocommit, parameters, handler, split_statements, return_last) + def _get_polars_df(self, sql: str = "", parameters=None, **kwargs): try: import polars as pl diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py b/providers/trino/tests/unit/trino/hooks/test_trino.py index 75bcd6c0e8684..3623a4e971443 100644 --- a/providers/trino/tests/unit/trino/hooks/test_trino.py +++ b/providers/trino/tests/unit/trino/hooks/test_trino.py @@ -401,6 +401,38 @@ def test_run(self, mock_run): self.db_hook.run(sql, autocommit, parameters, list) mock_run.assert_called_once_with(sql, autocommit, parameters, handler) + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run") + def test_run_defaults_no_handler(self, super_run): + super_run.return_value = None + sql = "SELECT 1" + result = self.db_hook.run(sql) + assert result is None + super_run.assert_called_once_with(sql, False, None, None, True, True) + + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run") + def test_run_with_handler_and_params(self, super_run): + super_run.return_value = [("ok",)] + sql = "SELECT 1" + autocommit = True + parameters = ("hello", "world") + handler = list + res = self.db_hook.run( + sql, + autocommit=autocommit, + parameters=parameters, + handler=handler, + split_statements=False, + return_last=False, + ) + assert res == [("ok",)] + super_run.assert_called_once_with(sql, True, parameters, handler, False, False) + + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run") + def test_run_multistatement_defaults_to_split(self, super_run): + sql = "SELECT 1; SELECT 2" + self.db_hook.run(sql) + super_run.assert_called_once_with(sql, False, None, None, True, True) + def test_connection_success(self): status, msg = self.db_hook.test_connection() assert status is True