diff --git a/providers/google/README.rst b/providers/google/README.rst index 3da44186487c8..abef199cc9c86 100644 --- a/providers/google/README.rst +++ b/providers/google/README.rst @@ -62,7 +62,7 @@ PIP package Version required =========================================== ====================================== ``apache-airflow`` ``>=2.10.0`` ``apache-airflow-providers-common-compat`` ``>=1.4.0`` -``apache-airflow-providers-common-sql`` ``>=1.20.0`` +``apache-airflow-providers-common-sql`` ``>=1.27.0`` ``asgiref`` ``>=3.5.2`` ``dill`` ``>=0.2.3`` ``gcloud-aio-auth`` ``>=5.2.0`` diff --git a/providers/google/pyproject.toml b/providers/google/pyproject.toml index 97d19cb1c574a..f28b6712f8c9d 100644 --- a/providers/google/pyproject.toml +++ b/providers/google/pyproject.toml @@ -59,7 +59,7 @@ requires-python = "~=3.9" dependencies = [ "apache-airflow>=2.10.0", "apache-airflow-providers-common-compat>=1.4.0", - "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow-providers-common-sql>=1.27.0", "asgiref>=3.5.2", "dill>=0.2.3", "gcloud-aio-auth>=5.2.0", @@ -129,11 +129,6 @@ dependencies = [ # See https://github.com/looker-open-source/sdk-codegen/issues/1518 "looker-sdk>=22.4.0,!=24.18.0", "pandas-gbq>=0.7.0", - # In pandas 2.2 minimal version of the sqlalchemy is 2.0 - # https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies - # However Airflow not fully supports it yet: https://github.com/apache/airflow/issues/28723 - # In addition FAB also limit sqlalchemy to < 2.0 - "pandas>=2.1.2,<2.2", # A transient dependency of google-cloud-bigquery-datatransfer, but we # further constrain it since older versions are buggy. 
"proto-plus>=1.19.6", @@ -233,6 +228,7 @@ dev = [ "apache-airflow-providers-trino", # Additional devel dependencies (do not remove this line and add extra development dependencies) "apache-airflow-providers-apache-kafka", + "apache-airflow-providers-common-sql[pandas,polars]", ] # To build docs: diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index a7ae255e5a9a8..e31571418a427 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -29,7 +29,7 @@ from collections.abc import Iterable, Mapping, Sequence from copy import deepcopy from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, NoReturn, Union, cast +from typing import TYPE_CHECKING, Any, NoReturn, Union, cast, overload from aiohttp import ClientSession as ClientSession from gcloud.aio.bigquery import Job, Table as Table_async @@ -57,8 +57,13 @@ from pandas_gbq import read_gbq from pandas_gbq.gbq import GbqConnector # noqa: F401 used in ``airflow.contrib.hooks.bigquery`` from sqlalchemy import create_engine +from typing_extensions import Literal -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import ( + AirflowException, + AirflowOptionalProviderFeatureException, + AirflowProviderDeprecationWarning, +) from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.cloud.utils.bigquery import bq_cast @@ -77,6 +82,7 @@ if TYPE_CHECKING: import pandas as pd + import polars as pl from google.api_core.page_iterator import HTTPIterator from google.api_core.retry import Retry from requests import Session @@ -275,15 +281,57 @@ def insert_rows( """ raise NotImplementedError() - def get_pandas_df( + def 
_get_pandas_df( self, sql: str, parameters: Iterable | Mapping[str, Any] | None = None, dialect: str | None = None, **kwargs, ) -> pd.DataFrame: + if dialect is None: + dialect = "legacy" if self.use_legacy_sql else "standard" + + credentials, project_id = self.get_credentials_and_project_id() + + return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) + + def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.DataFrame: + try: + import polars as pl + except ImportError as err: + raise AirflowOptionalProviderFeatureException( + "Polars is not installed. Please install it with `pip install polars`." + ) from err + + if dialect is None: + dialect = "legacy" if self.use_legacy_sql else "standard" + + credentials, project_id = self.get_credentials_and_project_id() + + pandas_df = read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) + return pl.from_pandas(pandas_df) + + @overload + def get_df( + self, sql, parameters=None, dialect=None, *, df_type: Literal["pandas"] = "pandas", **kwargs + ) -> pd.DataFrame: ... + + @overload + def get_df( + self, sql, parameters=None, dialect=None, *, df_type: Literal["polars"], **kwargs + ) -> pl.DataFrame: ... + + def get_df( + self, + sql, + parameters=None, + dialect=None, + *, + df_type: Literal["pandas", "polars"] = "pandas", + **kwargs, + ) -> pd.DataFrame | pl.DataFrame: """ - Get a Pandas DataFrame for the BigQuery results. + Get a DataFrame for the BigQuery results. The DbApiHook method must be overridden because Pandas doesn't support PEP 249 connections, except for SQLite. 
@@ -299,12 +347,22 @@ def get_pandas_df( defaults to use `self.use_legacy_sql` if not specified :param kwargs: (optional) passed into pandas_gbq.read_gbq method """ - if dialect is None: - dialect = "legacy" if self.use_legacy_sql else "standard" + if df_type == "polars": + return self._get_polars_df(sql, parameters, dialect, **kwargs) - credentials, project_id = self.get_credentials_and_project_id() + if df_type == "pandas": + return self._get_pandas_df(sql, parameters, dialect, **kwargs) + + # Fail loudly: without this, an unsupported df_type would silently return None. + raise ValueError(f"Unsupported df_type {df_type!r}; expected 'pandas' or 'polars'.") - return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) + @deprecated( + planned_removal_date="November 30, 2025", + use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_df", + category=AirflowProviderDeprecationWarning, + ) + def get_pandas_df(self, sql, parameters=None, dialect=None, **kwargs) -> pd.DataFrame: + return self._get_pandas_df(sql, parameters, dialect, **kwargs) @GoogleBaseHook.fallback_to_default_project_id def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool: diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index d7febaada8618..c6a10815fbdc6 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -153,9 +153,22 @@ def test_bigquery_table_partition_exists_false_no_partition(self, mock_client): assert result is False @mock.patch("airflow.providers.google.cloud.hooks.bigquery.read_gbq") - def test_get_pandas_df(self, mock_read_gbq): - self.hook.get_pandas_df("select 1") - + @pytest.mark.parametrize("df_type", ["pandas", "polars"]) + def test_get_df(self, mock_read_gbq, df_type): + import pandas as pd + import polars as pl + + mock_read_gbq.return_value = pd.DataFrame({"a": [1, 2, 3]}) + result = self.hook.get_df("select 1", df_type=df_type) + + expected_type = pd.DataFrame if df_type == "pandas" else 
pl.DataFrame + assert isinstance(result, expected_type) + assert result.shape == (3, 1) + assert list(result.columns) == ["a"] + if df_type == "pandas": + assert result["a"].tolist() == [1, 2, 3] + else: + assert result.to_series().to_list() == [1, 2, 3] mock_read_gbq.assert_called_once_with( "select 1", credentials=CREDENTIALS, dialect="legacy", project_id=PROJECT_ID ) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery_system.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery_system.py index 36dd11364ae67..e8cfb9b4e4699 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery_system.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery_system.py @@ -33,26 +33,52 @@ class TestBigQueryDataframeResultsSystem(GoogleSystemTest): def setup_method(self): self.instance = hook.BigQueryHook() - def test_output_is_dataframe_with_valid_query(self): - import pandas as pd + @pytest.mark.parametrize("df_type", ["pandas", "polars"]) + def test_output_is_dataframe_with_valid_query(self, df_type): + df = self.instance.get_df("select 1", df_type=df_type) + if df_type == "polars": + import polars as pl - df = self.instance.get_pandas_df("select 1") - assert isinstance(df, pd.DataFrame) + assert isinstance(df, pl.DataFrame) + else: + import pandas as pd - def test_throws_exception_with_invalid_query(self): + assert isinstance(df, pd.DataFrame) + + @pytest.mark.parametrize("df_type", ["pandas", "polars"]) + def test_throws_exception_with_invalid_query(self, df_type): with pytest.raises(Exception) as ctx: - self.instance.get_pandas_df("from `1`") + self.instance.get_df("from `1`", df_type=df_type) assert "Reason: " in str(ctx.value), "" - def test_succeeds_with_explicit_legacy_query(self): - df = self.instance.get_pandas_df("select 1", dialect="legacy") - assert df.iloc(0)[0][0] == 1 + @pytest.mark.parametrize("df_type", ["pandas", "polars"]) + def test_succeeds_with_explicit_legacy_query(self, df_type): + df = 
self.instance.get_df("select 1", dialect="legacy", df_type=df_type) + if df_type == "polars": + assert df.item(0, 0) == 1 + else: + assert df.iloc[0, 0] == 1 - def test_succeeds_with_explicit_std_query(self): - df = self.instance.get_pandas_df("select * except(b) from (select 1 a, 2 b)", dialect="standard") - assert df.iloc(0)[0][0] == 1 + @pytest.mark.parametrize("df_type", ["pandas", "polars"]) + def test_succeeds_with_explicit_std_query(self, df_type): + df = self.instance.get_df( + "select * except(b) from (select 1 a, 2 b)", + parameters=None, + dialect="standard", + df_type=df_type, + ) + if df_type == "polars": + assert df.item(0, 0) == 1 + else: + assert df.iloc[0, 0] == 1 - def test_throws_exception_with_incompatible_syntax(self): + @pytest.mark.parametrize("df_type", ["pandas", "polars"]) + def test_throws_exception_with_incompatible_syntax(self, df_type): with pytest.raises(Exception) as ctx: - self.instance.get_pandas_df("select * except(b) from (select 1 a, 2 b)", dialect="legacy") + self.instance.get_df( + "select * except(b) from (select 1 a, 2 b)", + parameters=None, + dialect="legacy", + df_type=df_type, + ) assert "Reason: " in str(ctx.value), ""