diff --git a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py index c95b2d09e386e..b3afc1a60c495 100644 --- a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py +++ b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any from deprecated import deprecated -from typing_extensions import Literal +from typing_extensions import Literal, overload if TYPE_CHECKING: import pandas as pd @@ -1071,6 +1071,28 @@ def _get_polars_df( # type: ignore df = pl.DataFrame(res["data"], schema=[c[0] for c in res["header"]], orient="row", **kwargs) return df + @overload # type: ignore[override] + def get_df( + self, + sql: str, + schema: str = "default", + hive_conf: dict[Any, Any] | None = None, + *, + df_type: Literal["pandas"] = "pandas", + **kwargs: Any, + ) -> pd.DataFrame: ... + + @overload # type: ignore[override] + def get_df( + self, + sql: str, + schema: str = "default", + hive_conf: dict[Any, Any] | None = None, + *, + df_type: Literal["polars"], + **kwargs: Any, + ) -> pl.DataFrame: ... + def get_df( # type: ignore self, sql: str, diff --git a/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py b/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py index 39efcb71ebe8c..b7d87f49046c3 100644 --- a/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py +++ b/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py @@ -23,6 +23,7 @@ from unittest import mock import pandas as pd +import polars as pl import pytest from hmsclient import HMSClient @@ -737,8 +738,10 @@ def test_get_df(self, df_type): assert len(df) == 2 if df_type == "pandas": assert df["hive_server_hook.a"].values.tolist() == [1, 2] + assert isinstance(df, pd.DataFrame) elif df_type == "polars": assert df["hive_server_hook.a"].to_list() == [1, 2] + assert isinstance(df, pl.DataFrame) date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" hook.get_conn.assert_called_with(self.database) hook.mock_cursor.execute.assert_any_call("set airflow.ctx.dag_id=test_dag_id")