Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions providers/trino/src/airflow/providers/trino/hooks/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import json
import os
from collections.abc import Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar, overload
from urllib.parse import quote_plus, urlencode

import trino
Expand Down Expand Up @@ -277,6 +277,50 @@ def _get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
df = pd.DataFrame(**kwargs)
return df

@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: None = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> None: ...

@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ...

def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
parameters: Iterable | Mapping[str, Any] | None = None,
handler: Callable[[Any], T] | None = None,
split_statements: bool = True,
return_last: bool = True,
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
"""
Override common run to set split_statements=True by default.

:param sql: SQL statement or list of statements to execute.
:param autocommit: Set autocommit mode before query execution.
:param parameters: Parameters to render the SQL query with.
:param handler: Optional callable to process each statement result.
:param split_statements: Split single SQL string into statements if True.
:param return_last: Return only last statement result if True.
:return: Query result or list of results.
"""
return super().run(sql, autocommit, parameters, handler, split_statements, return_last)

def _get_polars_df(self, sql: str = "", parameters=None, **kwargs):
try:
import polars as pl
Expand Down
32 changes: 32 additions & 0 deletions providers/trino/tests/unit/trino/hooks/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,38 @@ def test_run(self, mock_run):
self.db_hook.run(sql, autocommit, parameters, list)
mock_run.assert_called_once_with(sql, autocommit, parameters, handler)

@patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run")
def test_run_defaults_no_handler(self, super_run):
super_run.return_value = None
sql = "SELECT 1"
result = self.db_hook.run(sql)
assert result is None
super_run.assert_called_once_with(sql, False, None, None, True, True)

@patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run")
def test_run_with_handler_and_params(self, super_run):
super_run.return_value = [("ok",)]
sql = "SELECT 1"
autocommit = True
parameters = ("hello", "world")
handler = list
res = self.db_hook.run(
sql,
autocommit=autocommit,
parameters=parameters,
handler=handler,
split_statements=False,
return_last=False,
)
assert res == [("ok",)]
super_run.assert_called_once_with(sql, True, parameters, handler, False, False)

@patch("airflow.providers.common.sql.hooks.sql.DbApiHook.run")
def test_run_multistatement_defaults_to_split(self, super_run):
sql = "SELECT 1; SELECT 2"
self.db_hook.run(sql)
super_run.assert_called_once_with(sql, False, None, None, True, True)

def test_connection_success(self):
status, msg = self.db_hook.test_connection()
assert status is True
Expand Down