diff --git a/providers/presto/src/airflow/providers/presto/hooks/presto.py b/providers/presto/src/airflow/providers/presto/hooks/presto.py index c41901080effd..a3c1ff058f3ba 100644 --- a/providers/presto/src/airflow/providers/presto/hooks/presto.py +++ b/providers/presto/src/airflow/providers/presto/hooks/presto.py @@ -20,7 +20,7 @@ import json import os from collections.abc import Iterable, Mapping -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, overload, Callable import prestodb from deprecated import deprecated @@ -178,6 +178,52 @@ def get_first( except DatabaseError as e: raise PrestoException(e) + @overload + def run( + self, + sql: str | Iterable[str], + autocommit: bool = ..., + parameters: Iterable | Mapping[str, Any] | None = ..., + handler: None = ..., + split_statements: bool = ..., + return_last: bool = ..., + ) -> None: ... + + @overload + def run( + self, + sql: str | Iterable[str], + autocommit: bool = ..., + parameters: Iterable | Mapping[str, Any] | None = ..., + handler: Callable[[Any], T] = ..., + split_statements: bool = ..., + return_last: bool = ..., + ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ... + + def run( + self, + sql: str | Iterable[str], + autocommit: bool = False, + parameters: Iterable | Mapping[str, Any] | None = None, + handler: Callable[[Any], T] | None = None, + split_statements: bool = True, + return_last: bool = True, + ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: + """ + Overwrite common run. Default split_statements = True + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :param autocommit: What to set the connection's autocommit setting to + before executing the query. + :param parameters: The parameters to render the SQL query with. + :param handler: The result handler which is called with the result of each statement. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: if handler provided, returns query results (may be list of results depending on params) + """ + return super().run(sql, autocommit, parameters, handler, split_statements, return_last) + def _get_pandas_df(self, sql: str = "", parameters=None, **kwargs): try: import pandas as pd diff --git a/providers/presto/tests/unit/presto/hooks/test_presto.py b/providers/presto/tests/unit/presto/hooks/test_presto.py index 75c61f697c19a..19ebb67d17b7d 100644 --- a/providers/presto/tests/unit/presto/hooks/test_presto.py +++ b/providers/presto/tests/unit/presto/hooks/test_presto.py @@ -297,3 +297,12 @@ def test_split_sql_string(self): def test_serialize_cell(self): assert self.db_hook._serialize_cell("foo", None) == "foo" assert self.db_hook._serialize_cell(1, None) == 1 + + @patch("airflow.providers.presto.hooks.presto.PrestoHook.run") + def test_run(self, mock_run): + sql = "SELECT 1" + autocommit = False + parameters = ("hello", "world") + handler = list + self.db_hook.run(sql, autocommit, parameters, list) + mock_run.assert_called_once_with(sql, autocommit, parameters, handler, split_statements=True) \ No newline at end of file diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py b/providers/trino/src/airflow/providers/trino/hooks/trino.py index 35bb83cce2a8a..47555908a87dd 100644 --- a/providers/trino/src/airflow/providers/trino/hooks/trino.py +++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py @@ -21,7 +21,7 @@ import os from collections.abc import Iterable, Mapping from pathlib import Path -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, overload, Callable from urllib.parse import quote_plus, urlencode import trino @@ -267,6 +267,52 @@ def _get_pandas_df(self, sql: str = "", parameters=None, **kwargs): df = pd.DataFrame(**kwargs) return df + @overload + def run( + self, + sql: str | Iterable[str], + autocommit: bool = ..., + parameters: Iterable | Mapping[str, Any] | None = ..., + handler: None = ..., + split_statements: bool = ..., + return_last: bool = ..., + ) -> None: ... + + @overload + def run( + self, + sql: str | Iterable[str], + autocommit: bool = ..., + parameters: Iterable | Mapping[str, Any] | None = ..., + handler: Callable[[Any], T] = ..., + split_statements: bool = ..., + return_last: bool = ..., + ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ... + + def run( + self, + sql: str | Iterable[str], + autocommit: bool = False, + parameters: Iterable | Mapping[str, Any] | None = None, + handler: Callable[[Any], T] | None = None, + split_statements: bool = True, + return_last: bool = True, + ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: + """ + Overwrite common run. Default split_statements = True + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :param autocommit: What to set the connection's autocommit setting to + before executing the query. + :param parameters: The parameters to render the SQL query with. + :param handler: The result handler which is called with the result of each statement. + :param split_statements: Whether to split a single SQL string into statements and run separately + :param return_last: Whether to return result for only last statement or for all after split + :return: if handler provided, returns query results (may be list of results depending on params) + """ + return super().run(sql, autocommit, parameters, handler, split_statements, return_last) + def _get_polars_df(self, sql: str = "", parameters=None, **kwargs): try: import polars as pl diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py b/providers/trino/tests/unit/trino/hooks/test_trino.py index 02966d2e919f3..5abd6f37252c6 100644 --- a/providers/trino/tests/unit/trino/hooks/test_trino.py +++ b/providers/trino/tests/unit/trino/hooks/test_trino.py @@ -396,7 +396,7 @@ def test_run(self, mock_run): parameters = ("hello", "world") handler = list self.db_hook.run(sql, autocommit, parameters, list) - mock_run.assert_called_once_with(sql, autocommit, parameters, handler) + mock_run.assert_called_once_with(sql, autocommit, parameters, handler, split_statements=True) def test_connection_success(self): status, msg = self.db_hook.test_connection()