providers/apache/druid/pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -77,7 +77,7 @@ dev = [
"apache-airflow-providers-apache-hive",
"apache-airflow-providers-common-sql",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"apache-airflow-providers-common-sql[polars]",
"apache-airflow-providers-common-sql[pandas,polars]",
]

# To build docs:
Expand Down
providers/apache/hive/README.rst (3 changes: 1 addition & 2 deletions)
@@ -54,9 +54,8 @@ Requirements
 PIP package                             Version required
 ======================================= ==================
 ``apache-airflow``                      ``>=2.10.0``
-``apache-airflow-providers-common-sql`` ``>=1.20.0``
+``apache-airflow-providers-common-sql`` ``>=1.26.0``
 ``hmsclient``                           ``>=0.1.0``
-``pandas``                              ``>=2.1.2,<2.2``
 ``pyhive[hive_pure_sasl]``              ``>=0.7.0``
 ``thrift``                              ``>=0.11.0``
 ``jmespath``                            ``>=0.7.0``
providers/apache/hive/pyproject.toml (8 changes: 2 additions & 6 deletions)
@@ -58,13 +58,8 @@ requires-python = "~=3.9"
 # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
 dependencies = [
     "apache-airflow>=2.10.0",
-    "apache-airflow-providers-common-sql>=1.20.0",
+    "apache-airflow-providers-common-sql>=1.26.0",
     "hmsclient>=0.1.0",
-    # In pandas 2.2 minimal version of the sqlalchemy is 2.0
-    # https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies
-    # However Airflow not fully supports it yet: https://github.com/apache/airflow/issues/28723
-    # In addition FAB also limit sqlalchemy to < 2.0
-    "pandas>=2.1.2,<2.2",
     "pyhive[hive_pure_sasl]>=0.7.0",
     "thrift>=0.11.0",
     "jmespath>=0.7.0",
@@ -109,6 +104,7 @@ dev = [
"apache-airflow-providers-samba",
"apache-airflow-providers-vertica",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"apache-airflow-providers-common-sql[pandas,polars]",
]

# To build docs:
Expand Down
providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py
@@ -27,13 +27,17 @@
 from tempfile import NamedTemporaryFile, TemporaryDirectory
 from typing import TYPE_CHECKING, Any

+from deprecated import deprecated
+from typing_extensions import Literal
+
 if TYPE_CHECKING:
     import pandas as pd
+    import polars as pl

 import csv

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.providers.common.sql.hooks.sql import DbApiHook
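The new polars import joins pandas under TYPE_CHECKING, so neither dataframe library is imported at module load time; the hook defers the real imports to the helper methods below and raises AirflowOptionalProviderFeatureException when the optional dependency is missing. A minimal sketch of the pattern, with illustrative names rather than code from this diff:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import polars as pl  # seen only by type checkers, never executed at runtime

    def rows_to_df(rows: list[tuple]) -> pl.DataFrame:
        import polars as pl  # deferred: pays the import cost only when actually called

        return pl.DataFrame(rows, orient="row")

With `from __future__ import annotations` the `pl.DataFrame` annotation stays a string at runtime, so the module imports cleanly even when polars is absent.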
@@ -1031,37 +1035,84 @@ def get_records(
schema = kwargs["schema"] if "schema" in kwargs else "default"
return self.get_results(sql, schema=schema, hive_conf=parameters)["data"]

def get_pandas_df( # type: ignore
def _get_pandas_df( # type: ignore
self,
sql: str,
schema: str = "default",
hive_conf: dict[Any, Any] | None = None,
**kwargs,
) -> pd.DataFrame:
try:
import pandas as pd
except ImportError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(e)

res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
return df

+    def _get_polars_df(  # type: ignore
+        self,
+        sql: str,
+        schema: str = "default",
+        hive_conf: dict[Any, Any] | None = None,
+        **kwargs,
+    ) -> pl.DataFrame:
+        try:
+            import polars as pl
+        except ImportError as e:
+            from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+            raise AirflowOptionalProviderFeatureException(e)
+
+        res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
+        df = pl.DataFrame(res["data"], schema=[c[0] for c in res["header"]], orient="row", **kwargs)
+        return df

+    def get_df(  # type: ignore
+        self,
+        sql: str,
+        schema: str = "default",
+        hive_conf: dict[Any, Any] | None = None,
+        *,
+        df_type: Literal["pandas", "polars"] = "pandas",
+        **kwargs,
+    ) -> pd.DataFrame | pl.DataFrame:
"""
Get a pandas dataframe from a Hive query.
Get a pandas / polars dataframe from a Hive query.

:param sql: hql to be executed.
:param schema: target schema, default to 'default'.
:param hive_conf: hive_conf to execute alone with the hql.
:param df_type: type of dataframe to return, either 'pandas' or 'polars'
:param kwargs: (optional) passed into pandas.DataFrame constructor
:return: result of hive execution

>>> hh = HiveServer2Hook()
>>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
>>> df = hh.get_pandas_df(sql)
>>> df = hh.get_df(sql, df_type="pandas")
>>> len(df.index)
100

:return: pandas.DateFrame
:return: pandas.DateFrame | polars.DataFrame
"""
-        try:
-            import pandas as pd
-        except ImportError as e:
-            from airflow.exceptions import AirflowOptionalProviderFeatureException
-
-            raise AirflowOptionalProviderFeatureException(e)
-
-        res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
-        df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
-        return df
+        if df_type == "pandas":
+            return self._get_pandas_df(sql, schema=schema, hive_conf=hive_conf, **kwargs)
+        if df_type == "polars":
+            return self._get_polars_df(sql, schema=schema, hive_conf=hive_conf, **kwargs)

+    @deprecated(
+        reason="Replaced by function `get_df`.",
+        category=AirflowProviderDeprecationWarning,
+        action="ignore",
+    )
+    def get_pandas_df(  # type: ignore
+        self,
+        sql: str,
+        schema: str = "default",
+        hive_conf: dict[Any, Any] | None = None,
+        **kwargs,
+    ) -> pd.DataFrame:
+        return self._get_pandas_df(sql, schema=schema, hive_conf=hive_conf, **kwargs)
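Taken together, the hunk keeps get_pandas_df as a thin deprecated shim over _get_pandas_df while get_df becomes the public entry point. A usage sketch for reviewers trying the new surface (it assumes a live HiveServer2 connection registered under the hook's default hiveserver2_default connection id and both dataframe extras installed; the table name is illustrative):

    from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook

    hook = HiveServer2Hook()  # uses the hiveserver2_default connection
    sql = "SELECT * FROM airflow.static_babynames LIMIT 100"

    pdf = hook.get_df(sql, df_type="pandas")   # default; same result as the old API
    pldf = hook.get_df(sql, df_type="polars")  # new polars-backed path
    assert len(pdf.index) == pldf.height

    hook.get_pandas_df(sql)  # still works, now routed through _get_pandas_df

One observable edge of the new body: with any df_type other than "pandas" or "polars", get_df falls through both if branches and returns None; the Literal annotation pushes that error to type-check time rather than runtime.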
providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py (12 changes: 8 additions & 4 deletions)
@@ -715,7 +715,8 @@ def test_get_records(self):
         hook.mock_cursor.execute.assert_any_call("set airflow.ctx.dag_owner=airflow")
         hook.mock_cursor.execute.assert_any_call("set airflow.ctx.dag_email=test@airflow.com")

-    def test_get_pandas_df(self):
+    @pytest.mark.parametrize("df_type", ["pandas", "polars"])
+    def test_get_df(self, df_type):
         hook = MockHiveServer2Hook()
         query = f"SELECT * FROM {self.table}"
@@ -731,10 +732,13 @@ def test_get_pandas_df(self):
"AIRFLOW_CTX_DAG_EMAIL": "test@airflow.com",
},
):
df = hook.get_pandas_df(query, schema=self.database)
df = hook.get_df(query, schema=self.database, df_type=df_type)

assert len(df) == 2
assert df["hive_server_hook.a"].values.tolist() == [1, 2]
if df_type == "pandas":
assert df["hive_server_hook.a"].values.tolist() == [1, 2]
elif df_type == "polars":
assert df["hive_server_hook.a"].to_list() == [1, 2]
date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date"
hook.get_conn.assert_called_with(self.database)
hook.mock_cursor.execute.assert_any_call("set airflow.ctx.dag_id=test_dag_id")
@@ -747,7 +751,7 @@
         hook = MockHiveServer2Hook(connection_cursor=EmptyMockConnectionCursor())
         query = f"SELECT * FROM {self.table}"

-        df = hook.get_pandas_df(query, schema=self.database)
+        df = hook.get_df(query, schema=self.database, df_type=df_type)

         assert len(df) == 0

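The parametrized assertions differ only in how a column is materialized: pandas goes through the numpy-backed .values accessor, while polars exposes Series.to_list() directly. The same contrast in isolation (a standalone sketch, no Airflow required, both libraries installed):

    import pandas as pd
    import polars as pl

    rows = [(1, "x"), (2, "y")]
    cols = ["hive_server_hook.a", "hive_server_hook.b"]

    pdf = pd.DataFrame(rows, columns=cols)
    assert pdf["hive_server_hook.a"].values.tolist() == [1, 2]

    # mirrors _get_polars_df: row-oriented data with an explicit schema of column names
    pldf = pl.DataFrame(rows, schema=cols, orient="row")
    assert pldf["hive_server_hook.a"].to_list() == [1, 2]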