Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions providers/presto/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ Requirements
PIP package Version required
======================================= ==================
``apache-airflow`` ``>=2.10.0``
``apache-airflow-providers-common-sql`` ``>=1.20.0``
``apache-airflow-providers-common-sql`` ``>=1.26.0``
``presto-python-client`` ``>=0.8.4``
``pandas`` ``>=2.1.2,<2.2``
======================================= ==================

Cross provider package dependencies
Expand Down
8 changes: 2 additions & 6 deletions providers/presto/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,8 @@ requires-python = "~=3.9"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=2.10.0",
"apache-airflow-providers-common-sql>=1.20.0",
"apache-airflow-providers-common-sql>=1.26.0",
"presto-python-client>=0.8.4",
# In pandas 2.2 minimal version of the sqlalchemy is 2.0
# https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies
# However Airflow not fully supports it yet: https://github.com/apache/airflow/issues/28723
# In addition FAB also limit sqlalchemy to < 2.0
"pandas>=2.1.2,<2.2",
]

# The optional dependencies should be modified in place in the generated file
Expand All @@ -82,6 +77,7 @@ dev = [
"apache-airflow-providers-common-sql",
"apache-airflow-providers-google",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"apache-airflow-providers-common-sql[pandas,polars]",
]

# To build docs:
Expand Down
50 changes: 47 additions & 3 deletions providers/presto/src/airflow/providers/presto/hooks/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@
from typing import TYPE_CHECKING, Any, TypeVar

import prestodb
from deprecated import deprecated
from prestodb.exceptions import DatabaseError
from prestodb.transaction import IsolationLevel

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import (
AirflowException,
AirflowOptionalProviderFeatureException,
AirflowProviderDeprecationWarning,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.presto.version_compat import AIRFLOW_V_3_0_PLUS

Expand Down Expand Up @@ -173,8 +178,13 @@ def get_first(
except DatabaseError as e:
raise PrestoException(e)

def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
import pandas as pd
def _get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
try:
import pandas as pd
except ImportError:
raise AirflowOptionalProviderFeatureException(
"Pandas is not installed. Please install it with `pip install pandas`."
)

cursor = self.get_cursor()
try:
Expand All @@ -190,6 +200,40 @@ def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
df = pd.DataFrame(**kwargs)
return df

def _get_polars_df(self, sql: str = "", parameters=None, **kwargs):
try:
import polars as pl
except ImportError:
raise AirflowOptionalProviderFeatureException(
"Polars is not installed. Please install it with `pip install polars`."
)

cursor = self.get_cursor()
try:
cursor.execute(self.strip_sql_string(sql), parameters)
data = cursor.fetchall()
except DatabaseError as e:
raise PrestoException(e)
column_descriptions = cursor.description
if data:
df = pl.DataFrame(
data,
schema=[c[0] for c in column_descriptions],
orient="row",
**kwargs,
)
else:
df = pl.DataFrame(**kwargs)
return df

@deprecated(
reason="Replaced by function `get_df`.",
category=AirflowProviderDeprecationWarning,
action="ignore",
)
def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
return self._get_pandas_df(sql, parameters, **kwargs)

def insert_rows(
self,
table: str,
Expand Down
17 changes: 10 additions & 7 deletions providers/presto/tests/unit/presto/hooks/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,21 @@ def test_get_records(self):
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)

def test_get_pandas_df(self):
@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_df(self, df_type):
statement = "SQL"
column = "col"
result_sets = [("row1",), ("row2",)]
self.cur.description = [(column,)]
self.cur.description = [(column, None, None, None, None, None, None)]
self.cur.fetchall.return_value = result_sets
df = self.db_hook.get_pandas_df(statement)

df = self.db_hook.get_df(statement, df_type=df_type)
assert column == df.columns[0]

assert result_sets[0][0] == df.values.tolist()[0][0]
assert result_sets[1][0] == df.values.tolist()[1][0]
if df_type == "pandas":
assert result_sets[0][0] == df.values.tolist()[0][0]
assert result_sets[1][0] == df.values.tolist()[1][0]
else:
assert result_sets[0][0] == df.row(0)[0]
assert result_sets[1][0] == df.row(1)[0]

self.cur.execute.assert_called_once_with(statement, None)

Expand Down
3 changes: 1 addition & 2 deletions providers/trino/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ Requirements
PIP package Version required
======================================= ==================
``apache-airflow`` ``>=2.10.0``
``apache-airflow-providers-common-sql`` ``>=1.20.0``
``pandas`` ``>=2.1.2,<2.2``
``apache-airflow-providers-common-sql`` ``>=1.26.0``
``trino`` ``>=0.319.0``
======================================= ==================

Expand Down
8 changes: 2 additions & 6 deletions providers/trino/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,7 @@ requires-python = "~=3.9"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=2.10.0",
"apache-airflow-providers-common-sql>=1.20.0",
# In pandas 2.2 minimal version of the sqlalchemy is 2.0
# https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies
# However Airflow not fully supports it yet: https://github.com/apache/airflow/issues/28723
# In addition FAB also limit sqlalchemy to < 2.0
"pandas>=2.1.2,<2.2",
"apache-airflow-providers-common-sql>=1.26.0",
"trino>=0.319.0",
]

Expand All @@ -86,6 +81,7 @@ dev = [
"apache-airflow-providers-google",
"apache-airflow-providers-openlineage",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"apache-airflow-providers-common-sql[pandas,polars]",
]

# To build docs:
Expand Down
50 changes: 47 additions & 3 deletions providers/trino/src/airflow/providers/trino/hooks/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@
from urllib.parse import quote_plus, urlencode

import trino
from deprecated import deprecated
from trino.exceptions import DatabaseError
from trino.transaction import IsolationLevel

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import (
AirflowException,
AirflowOptionalProviderFeatureException,
AirflowProviderDeprecationWarning,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.trino.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils.helpers import exactly_one
Expand Down Expand Up @@ -238,8 +243,13 @@ def get_first(
except DatabaseError as e:
raise TrinoException(e)

def get_pandas_df(self, sql: str = "", parameters: Iterable | Mapping[str, Any] | None = None, **kwargs): # type: ignore[override]
import pandas as pd
def _get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
try:
import pandas as pd
except ImportError:
raise AirflowOptionalProviderFeatureException(
"Pandas is not installed. Please install it with `pip install pandas`."
)

cursor = self.get_cursor()
try:
Expand All @@ -255,6 +265,40 @@ def get_pandas_df(self, sql: str = "", parameters: Iterable | Mapping[str, Any]
df = pd.DataFrame(**kwargs)
return df

def _get_polars_df(self, sql: str = "", parameters=None, **kwargs):
try:
import polars as pl
except ImportError:
raise AirflowOptionalProviderFeatureException(
"Polars is not installed. Please install it with `pip install polars`."
)

cursor = self.get_cursor()
try:
cursor.execute(self.strip_sql_string(sql), parameters)
data = cursor.fetchall()
except DatabaseError as e:
raise TrinoException(e)
column_descriptions = cursor.description
if data:
df = pl.DataFrame(
data,
schema=[c[0] for c in column_descriptions],
orient="row",
**kwargs,
)
else:
df = pl.DataFrame(**kwargs)
return df

@deprecated(
reason="Replaced by function `get_df`.",
category=AirflowProviderDeprecationWarning,
action="ignore",
)
def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
return self._get_pandas_df(sql, parameters, **kwargs)

def insert_rows(
self,
table: str,
Expand Down
17 changes: 10 additions & 7 deletions providers/trino/tests/unit/trino/hooks/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,18 +341,21 @@ def test_get_records(self):
self.cur.close.assert_called_once_with()
self.cur.execute.assert_called_once_with(statement)

def test_get_pandas_df(self):
@pytest.mark.parametrize("df_type", ["pandas", "polars"])
def test_df(self, df_type):
statement = "SQL"
column = "col"
result_sets = [("row1",), ("row2",)]
self.cur.description = [(column,)]
self.cur.description = [(column, None, None, None, None, None, None)]
self.cur.fetchall.return_value = result_sets
df = self.db_hook.get_pandas_df(statement)

df = self.db_hook.get_df(statement, df_type=df_type)
assert column == df.columns[0]

assert result_sets[0][0] == df.values.tolist()[0][0]
assert result_sets[1][0] == df.values.tolist()[1][0]
if df_type == "pandas":
assert result_sets[0][0] == df.values.tolist()[0][0]
assert result_sets[1][0] == df.values.tolist()[1][0]
else:
assert result_sets[0][0] == df.row(0)[0]
assert result_sets[1][0] == df.row(1)[0]

self.cur.execute.assert_called_once_with(statement, None)

Expand Down
Loading