diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py index 3d75be0965291..5c45e8b19fa19 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py @@ -54,7 +54,7 @@ from airflow.utils.module_loading import import_string if TYPE_CHECKING: - from pandas import DataFrame + from pandas import DataFrame as PandasDataFrame from polars import DataFrame as PolarsDataFrame from sqlalchemy.engine import URL, Engine, Inspector @@ -391,7 +391,7 @@ def get_pandas_df( sql, parameters: list | tuple | Mapping[str, Any] | None = None, **kwargs, - ) -> DataFrame: + ) -> PandasDataFrame: """ Execute the sql and returns a pandas dataframe. @@ -413,17 +413,37 @@ def get_pandas_df_by_chunks( *, chunksize: int, **kwargs, - ) -> Generator[DataFrame, None, None]: + ) -> Generator[PandasDataFrame, None, None]: return self._get_pandas_df_by_chunks(sql, parameters, chunksize=chunksize, **kwargs) + @overload def get_df( self, - sql, + sql: str | list[str], + parameters: list | tuple | Mapping[str, Any] | None = None, + *, + df_type: Literal["pandas"] = "pandas", + **kwargs: Any, + ) -> PandasDataFrame: ... + + @overload + def get_df( + self, + sql: str | list[str], + parameters: list | tuple | Mapping[str, Any] | None = None, + *, + df_type: Literal["polars"], + **kwargs: Any, + ) -> PolarsDataFrame: ... + + def get_df( + self, + sql: str | list[str], parameters: list | tuple | Mapping[str, Any] | None = None, *, df_type: Literal["pandas", "polars"] = "pandas", **kwargs, - ) -> DataFrame | PolarsDataFrame: + ) -> PandasDataFrame | PolarsDataFrame: """ Execute the sql and returns a dataframe. @@ -442,7 +462,7 @@ def _get_pandas_df( sql, parameters: list | tuple | Mapping[str, Any] | None = None, **kwargs, - ) -> DataFrame: + ) -> PandasDataFrame: """ Execute the sql and returns a pandas dataframe. @@ -492,15 +512,37 @@ def _get_polars_df( return pl.read_database(sql, connection=conn, execute_options=execute_options, **kwargs) + @overload def get_df_by_chunks( self, - sql, + sql: str | list[str], + parameters: list | tuple | Mapping[str, Any] | None = None, + *, + chunksize: int, + df_type: Literal["pandas"] = "pandas", + **kwargs, + ) -> Generator[PandasDataFrame, None, None]: ... + + @overload + def get_df_by_chunks( + self, + sql: str | list[str], + parameters: list | tuple | Mapping[str, Any] | None = None, + *, + chunksize: int, + df_type: Literal["polars"], + **kwargs, + ) -> Generator[PolarsDataFrame, None, None]: ... + + def get_df_by_chunks( + self, + sql: str | list[str], parameters: list | tuple | Mapping[str, Any] | None = None, *, chunksize: int, df_type: Literal["pandas", "polars"] = "pandas", **kwargs, - ) -> Generator[DataFrame | PolarsDataFrame, None, None]: + ) -> Generator[PandasDataFrame | PolarsDataFrame, None, None]: """ Execute the sql and return a generator. @@ -522,7 +564,7 @@ def _get_pandas_df_by_chunks( *, chunksize: int, **kwargs, - ) -> Generator[DataFrame, None, None]: + ) -> Generator[PandasDataFrame, None, None]: """ Execute the sql and return a generator. diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.pyi b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.pyi index 939c0bcc5d2f1..903444a29528f 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.pyi +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.pyi @@ -34,13 +34,11 @@ isort:skip_file from collections.abc import Generator, Iterable, Mapping, MutableMapping, Sequence from functools import cached_property as cached_property -from typing import Any, Callable, Protocol, TypeVar, overload +from typing import Any, Protocol, TypeVar from _typeshed import Incomplete as Incomplete from pandas import DataFrame as PandasDataFrame -from polars import DataFrame as PolarsDataFrame from sqlalchemy.engine import URL as URL, Engine as Engine, Inspector as Inspector -from typing_extensions import Literal from airflow.hooks.base import BaseHook as BaseHook from airflow.models import Connection as Connection @@ -111,61 +109,6 @@ class DbApiHook(BaseHook): def get_pandas_df_by_chunks( self, sql, parameters: list | tuple | Mapping[str, Any] | None = None, *, chunksize: int, **kwargs ) -> Generator[PandasDataFrame, None, None]: ... - @overload - def get_df( - self, - sql: str | list[str], - parameters: list | tuple | Mapping[str, Any] | None = None, - *, - df_type: Literal["pandas"] = "pandas", - **kwargs: Any, - ) -> PandasDataFrame: ... - @overload - def get_df( - self, - sql: str | list[str], - parameters: list | tuple | Mapping[str, Any] | None = None, - *, - df_type: Literal["polars"] = "polars", - **kwargs: Any, - ) -> PolarsDataFrame: ... - @overload - def get_df( # fallback overload - self, - sql: str | list[str], - parameters: list | tuple | Mapping[str, Any] | None = None, - *, - df_type: Literal["pandas", "polars"] = "pandas", - ) -> PandasDataFrame | PolarsDataFrame: ... - @overload - def get_df_by_chunks( - self, - sql, - parameters: list | tuple | Mapping[str, Any] | None = None, - *, - chunksize: int, - df_type: Literal["pandas"] = "pandas", - **kwargs, - ) -> Generator[PandasDataFrame, None, None]: ... - @overload - def get_df_by_chunks( - self, - sql, - parameters: list | tuple | Mapping[str, Any] | None = None, - *, - chunksize: int, - df_type: Literal["polars"], - **kwargs, - ) -> Generator[PolarsDataFrame, None, None]: ... - @overload - def get_df_by_chunks( # fallback overload - self, - sql, - parameters: list | tuple | Mapping[str, Any] | None = None, - *, - chunksize: int, - df_type: Literal["pandas", "polars"] = "pandas", - ) -> Generator[PandasDataFrame | PolarsDataFrame, None, None]: ... def get_records( self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None ) -> Any: ... @@ -178,26 +121,6 @@ class DbApiHook(BaseHook): def split_sql_string(sql: str, strip_semicolon: bool = False) -> list[str]: ... @property def last_description(self) -> Sequence[Sequence] | None: ... - @overload - def run( - self, - sql: str | Iterable[str], - autocommit: bool = ..., - parameters: Iterable | Mapping[str, Any] | None = ..., - handler: None = ..., - split_statements: bool = ..., - return_last: bool = ..., - ) -> None: ... - @overload - def run( - self, - sql: str | Iterable[str], - autocommit: bool = ..., - parameters: Iterable | Mapping[str, Any] | None = ..., - handler: Callable[[Any], T] = ..., - split_statements: bool = ..., - return_last: bool = ..., - ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ... def set_autocommit(self, conn, autocommit) -> None: ... def get_autocommit(self, conn) -> bool: ... def get_cursor(self) -> Any: ...