diff --git a/providers/apache/druid/pyproject.toml b/providers/apache/druid/pyproject.toml
index 57ae28e3e7a37..95e4ab3fb6b06 100644
--- a/providers/apache/druid/pyproject.toml
+++ b/providers/apache/druid/pyproject.toml
@@ -77,7 +77,7 @@ dev = [
     "apache-airflow-providers-apache-hive",
     "apache-airflow-providers-common-sql",
     # Additional devel dependencies (do not remove this line and add extra development dependencies)
-    "apache-airflow-providers-common-sql[polars]",
+    "apache-airflow-providers-common-sql[pandas,polars]",
 ]
 
 # To build docs:
diff --git a/providers/apache/hive/README.rst b/providers/apache/hive/README.rst
index a13601bba83ff..3711d634df971 100644
--- a/providers/apache/hive/README.rst
+++ b/providers/apache/hive/README.rst
@@ -54,9 +54,8 @@ Requirements
 PIP package                              Version required
 ======================================== ==================
 ``apache-airflow``                       ``>=2.10.0``
-``apache-airflow-providers-common-sql``  ``>=1.20.0``
+``apache-airflow-providers-common-sql``  ``>=1.26.0``
 ``hmsclient``                            ``>=0.1.0``
-``pandas``                               ``>=2.1.2,<2.2``
 ``pyhive[hive_pure_sasl]``               ``>=0.7.0``
 ``thrift``                               ``>=0.11.0``
 ``jmespath``                             ``>=0.7.0``
diff --git a/providers/apache/hive/pyproject.toml b/providers/apache/hive/pyproject.toml
index 8a445bce85688..01106c9355748 100644
--- a/providers/apache/hive/pyproject.toml
+++ b/providers/apache/hive/pyproject.toml
@@ -58,13 +58,8 @@ requires-python = "~=3.9"
 # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
 dependencies = [
     "apache-airflow>=2.10.0",
-    "apache-airflow-providers-common-sql>=1.20.0",
+    "apache-airflow-providers-common-sql>=1.26.0",
     "hmsclient>=0.1.0",
-    # In pandas 2.2 minimal version of the sqlalchemy is 2.0
-    # https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies
-    # However Airflow not fully supports it yet: https://github.com/apache/airflow/issues/28723
-    # In addition FAB also limit sqlalchemy to < 2.0
-    "pandas>=2.1.2,<2.2",
     "pyhive[hive_pure_sasl]>=0.7.0",
     "thrift>=0.11.0",
     "jmespath>=0.7.0",
@@ -109,6 +104,7 @@ dev = [
     "apache-airflow-providers-samba",
     "apache-airflow-providers-vertica",
     # Additional devel dependencies (do not remove this line and add extra development dependencies)
+    "apache-airflow-providers-common-sql[pandas,polars]",
 ]
 
 # To build docs:
diff --git a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py
index 5c5b2ae2c669d..c95b2d09e386e 100644
--- a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py
+++ b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py
@@ -27,13 +27,17 @@
 from tempfile import NamedTemporaryFile, TemporaryDirectory
 from typing import TYPE_CHECKING, Any
 
+from deprecated import deprecated
+from typing_extensions import Literal
+
 if TYPE_CHECKING:
     import pandas as pd
+    import polars as pl
 
 import csv
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.providers.common.sql.hooks.sql import DbApiHook
@@ -1031,37 +1035,84 @@ def get_records(
         schema = kwargs["schema"] if "schema" in kwargs else "default"
         return self.get_results(sql, schema=schema, hive_conf=parameters)["data"]
 
-    def get_pandas_df(  # type: ignore
+    def _get_pandas_df(  # type: ignore
         self,
         sql: str,
         schema: str = "default",
         hive_conf: dict[Any, Any] | None = None,
         **kwargs,
     ) -> pd.DataFrame:
+        try:
+            import pandas as pd
+        except ImportError as e:
+            from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+            raise AirflowOptionalProviderFeatureException(e)
+
+        res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
+        df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
+        return df
+
+    def _get_polars_df(  # type: ignore
+        self,
+        sql: str,
+        schema: str = "default",
+        hive_conf: dict[Any, Any] | None = None,
+        **kwargs,
+    ) -> pl.DataFrame:
+        try:
+            import polars as pl
+        except ImportError as e:
+            from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+            raise AirflowOptionalProviderFeatureException(e)
+
+        res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
+        df = pl.DataFrame(res["data"], schema=[c[0] for c in res["header"]], orient="row", **kwargs)
+        return df
+
+    def get_df(  # type: ignore
+        self,
+        sql: str,
+        schema: str = "default",
+        hive_conf: dict[Any, Any] | None = None,
+        *,
+        df_type: Literal["pandas", "polars"] = "pandas",
+        **kwargs,
+    ) -> pd.DataFrame | pl.DataFrame:
         """
-        Get a pandas dataframe from a Hive query.
+        Get a pandas / polars dataframe from a Hive query.
 
         :param sql: hql to be executed.
         :param schema: target schema, default to 'default'.
         :param hive_conf: hive_conf to execute alone with the hql.
+        :param df_type: type of dataframe to return, either 'pandas' or 'polars'
         :param kwargs: (optional) passed into pandas.DataFrame constructor
         :return: result of hive execution
 
         >>> hh = HiveServer2Hook()
         >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
-        >>> df = hh.get_pandas_df(sql)
+        >>> df = hh.get_df(sql, df_type="pandas")
         >>> len(df.index)
         100
 
-        :return: pandas.DateFrame
+        :return: pandas.DataFrame | polars.DataFrame
         """
-        try:
-            import pandas as pd
-        except ImportError as e:
-            from airflow.exceptions import AirflowOptionalProviderFeatureException
-
-            raise AirflowOptionalProviderFeatureException(e)
-
-        res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
-        df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
-        return df
+        if df_type == "pandas":
+            return self._get_pandas_df(sql, schema=schema, hive_conf=hive_conf, **kwargs)
+        if df_type == "polars":
+            return self._get_polars_df(sql, schema=schema, hive_conf=hive_conf, **kwargs)
+
+    @deprecated(
+        reason="Replaced by function `get_df`.",
+        category=AirflowProviderDeprecationWarning,
+        action="ignore",
+    )
+    def get_pandas_df(  # type: ignore
+        self,
+        sql: str,
+        schema: str = "default",
+        hive_conf: dict[Any, Any] | None = None,
+        **kwargs,
+    ) -> pd.DataFrame:
+        return self._get_pandas_df(sql, schema=schema, hive_conf=hive_conf, **kwargs)
diff --git a/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py b/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py
index 0f239bbe9892a..39efcb71ebe8c 100644
--- a/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py
+++ b/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py
@@ -715,7 +715,8 @@ def test_get_records(self):
         hook.mock_cursor.execute.assert_any_call("set airflow.ctx.dag_owner=airflow")
         hook.mock_cursor.execute.assert_any_call("set airflow.ctx.dag_email=test@airflow.com")
 
-    def test_get_pandas_df(self):
+    @pytest.mark.parametrize("df_type", ["pandas", "polars"])
+    def test_get_df(self, df_type):
         hook = MockHiveServer2Hook()
         query = f"SELECT * FROM {self.table}"
 
@@ -731,10 +732,13 @@ def test_get_pandas_df(self):
                 "AIRFLOW_CTX_DAG_EMAIL": "test@airflow.com",
             },
         ):
-            df = hook.get_pandas_df(query, schema=self.database)
+            df = hook.get_df(query, schema=self.database, df_type=df_type)
             assert len(df) == 2
-            assert df["hive_server_hook.a"].values.tolist() == [1, 2]
+            if df_type == "pandas":
+                assert df["hive_server_hook.a"].values.tolist() == [1, 2]
+            elif df_type == "polars":
+                assert df["hive_server_hook.a"].to_list() == [1, 2]
 
         date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date"
         hook.get_conn.assert_called_with(self.database)
         hook.mock_cursor.execute.assert_any_call("set airflow.ctx.dag_id=test_dag_id")
@@ -747,7 +751,7 @@ def test_get_pandas_df(self):
         hook = MockHiveServer2Hook(connection_cursor=EmptyMockConnectionCursor())
         query = f"SELECT * FROM {self.table}"
 
-        df = hook.get_pandas_df(query, schema=self.database)
+        df = hook.get_df(query, schema=self.database, df_type=df_type)
 
         assert len(df) == 0
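
Usage sketch of the new API surface (a reviewer aid, not part of the diff; the schema and table are borrowed from the docstring example and assumed to exist):

    from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook

    hook = HiveServer2Hook()  # default connection, as in the docstring example
    sql = "SELECT * FROM airflow.static_babynames LIMIT 100"

    # Default path: returns a pandas DataFrame, same behaviour as before.
    pandas_df = hook.get_df(sql, schema="airflow")

    # Opt-in polars path; raises AirflowOptionalProviderFeatureException if polars
    # is not installed (pulled in via apache-airflow-providers-common-sql[polars]).
    polars_df = hook.get_df(sql, schema="airflow", df_type="polars")

    # The deprecated entry point keeps working and delegates to _get_pandas_df.
    legacy_df = hook.get_pandas_df(sql, schema="airflow")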