From d956581ff2f2170952d7594ea693cbd55a3256d5 Mon Sep 17 00:00:00 2001 From: Ben Feifke Date: Thu, 28 Sep 2023 22:28:59 +0200 Subject: [PATCH 1/3] Added get_polars_df to DbApiHook. --- airflow/providers/common/sql/hooks/sql.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index c3907041928fd..bdfc783aebc54 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -218,6 +218,26 @@ def get_pandas_df(self, sql, parameters: Iterable | Mapping[str, Any] | None = N with closing(self.get_conn()) as conn: return psql.read_sql(sql, con=conn, params=parameters, **kwargs) + def get_polars_df(self, sql, **kwargs): + """ + Executes the sql and returns a polars dataframe. + + :param sql: the sql statement to be executed (str) or a list of + sql statements to execute + :param parameters: The parameters to render the SQL query with. + :param kwargs: (optional) passed into polars.read_database method + """ + try: + import polars as pl + except ImportError: + raise Exception( + "polars library not installed, run: pip install " + "'apache-airflow-providers-common-sql[polars]'." + ) + + with closing(self.get_conn()) as conn: + return pl.read_database(sql, connection=conn, **kwargs) + def get_pandas_df_by_chunks( self, sql, parameters: Iterable | Mapping[str, Any] | None = None, *, chunksize: int | None, **kwargs ): From 810c0667982c82b1a8d5627c58789ca491e3d7af Mon Sep 17 00:00:00 2001 From: Ben Feifke Date: Fri, 29 Sep 2023 07:21:40 +0200 Subject: [PATCH 2/3] Added get_polars_df to BigQueryHook. --- .../providers/google/cloud/hooks/bigquery.py | 32 +++++++++++++++++++ .../google/common/hooks/base_google.py | 5 +++ .../src/airflow_breeze/global_constants.py | 1 + docs/conf.py | 2 ++ setup.py | 3 ++ 5 files changed, 43 insertions(+) diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 9367582493512..88d6379cc5e9b 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -42,6 +42,7 @@ QueryJob, SchemaField, UnknownJob, + QueryJobConfig, ) from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference @@ -49,6 +50,7 @@ from googleapiclient.discovery import Resource, build from pandas_gbq import read_gbq from pandas_gbq.gbq import GbqConnector # noqa +import polars as pl from requests import Session from sqlalchemy import create_engine @@ -270,6 +272,36 @@ def get_pandas_df( return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) + def get_polars_df( + self, + sql: str, + parameters: Iterable | Mapping[str, Any] | None = None, + dialect: str | None = None, + **kwargs, + ) -> pd.DataFrame: + """Get a Polars DataFrame for the BigQuery results. + + :param sql: The BigQuery SQL to execute. + :param parameters: The parameters to render the SQL query with (not + used, leave to override superclass method) + :param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL + defaults to use `self.use_legacy_sql` if not specified + :param kwargs: (optional) passed into polars.from_arrow method + """ + if dialect is None: + dialect = "legacy" if self.use_legacy_sql else "standard" + + project_id = self.get_project_id() + client = self.get_client(project_id=project_id) + + job_config = QueryJobConfig(dialect=dialect) # Specify the SQL dialect here + query_job = client.query(sql, job_config=job_config) # API request + rows = query_job.result() # Waits for query to finish + + df = pl.from_arrow(rows.to_arrow(), **kwargs) + + return df + @GoogleBaseHook.fallback_to_default_project_id def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool: """Check if a table exists in Google BigQuery. diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 77fd9394cfc62..d9df5ea2890e9 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -296,6 +296,11 @@ def get_credentials(self) -> google.auth.credentials.Credentials: credentials, _ = self.get_credentials_and_project_id() return credentials + def get_project_id(self) -> str: + """Returns the project_id str for Google API.""" + _, project_id = self.get_credentials_and_project_id() + return project_id + def _get_access_token(self) -> str: """Returns a valid access token from Google API Credentials.""" credentials = self.get_credentials() diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index 7e4fff566078b..c20894f99f89f 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -380,6 +380,7 @@ def get_airflow_extras(): "odbc", "openlineage", "pandas", + "polars" "postgres", "redis", "sendgrid", diff --git a/docs/conf.py b/docs/conf.py index e9af597525f65..9e6da9d73bf00 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -622,6 +622,7 @@ def _get_params(root_schema: dict, prefix: str = "", default_section: str = "") "pandas_gbq", "paramiko", "pinotdb", + "polars" "psycopg2", "pydruid", "pyhive", @@ -668,6 +669,7 @@ def _get_params(root_schema: dict, prefix: str = "", default_section: str = "") "jinja2", "mongodb", "pandas", + "polars" "python", "requests", "sqlalchemy", diff --git a/setup.py b/setup.py index d5b8c333d0c93..f5c95cbfd526d 100644 --- a/setup.py +++ b/setup.py @@ -358,6 +358,7 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve "bcrypt>=2.0.0", "flask-bcrypt>=0.7.1", ] +polars = ["polars>=0.19.5"] rabbitmq = [ "amqp", ] @@ -525,6 +526,7 @@ def get_unique_dependency_list(req_list_iterable: Iterable[list[str]]): get_provider_dependencies("mysql"), pandas, password, + polars, ] ) @@ -569,6 +571,7 @@ def get_unique_dependency_list(req_list_iterable: Iterable[list[str]]): "otel": otel, "pandas": pandas, "password": password, + "polars": polars, "rabbitmq": rabbitmq, "sentry": sentry, "statsd": statsd, From bb1d61a7008559e8dabcc07a47129347804d6366 Mon Sep 17 00:00:00 2001 From: Ben Feifke Date: Sat, 30 Sep 2023 10:54:20 +0200 Subject: [PATCH 3/3] Removed polars from core dependencies. --- airflow/providers/google/provider.yaml | 3 +++ dev/breeze/src/airflow_breeze/global_constants.py | 1 - docs/conf.py | 2 -- setup.py | 3 --- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 6323c4832f448..cd157daef6645 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -1207,6 +1207,9 @@ additional-extras: - name: amazon dependencies: - apache-airflow-providers-amazon>=2.6.0 + - name: polars + dependencies: + - polars>=0.19.5 secrets-backends: - airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index c20894f99f89f..7e4fff566078b 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -380,7 +380,6 @@ def get_airflow_extras(): "odbc", "openlineage", "pandas", - "polars" "postgres", "redis", "sendgrid", diff --git a/docs/conf.py b/docs/conf.py index 9e6da9d73bf00..e9af597525f65 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -622,7 +622,6 @@ def _get_params(root_schema: dict, prefix: str = "", default_section: str = "") "pandas_gbq", "paramiko", "pinotdb", - "polars" "psycopg2", "pydruid", "pyhive", @@ -669,7 +668,6 @@ def _get_params(root_schema: dict, prefix: str = "", default_section: str = "") "jinja2", "mongodb", "pandas", - "polars" "python", "requests", "sqlalchemy", diff --git a/setup.py b/setup.py index f5c95cbfd526d..d5b8c333d0c93 100644 --- a/setup.py +++ b/setup.py @@ -358,7 +358,6 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve "bcrypt>=2.0.0", "flask-bcrypt>=0.7.1", ] -polars = ["polars>=0.19.5"] rabbitmq = [ "amqp", ] @@ -526,7 +525,6 @@ def get_unique_dependency_list(req_list_iterable: Iterable[list[str]]): get_provider_dependencies("mysql"), pandas, password, - polars, ] ) @@ -571,7 +569,6 @@ def get_unique_dependency_list(req_list_iterable: Iterable[list[str]]): "otel": otel, "pandas": pandas, "password": password, - "polars": polars, "rabbitmq": rabbitmq, "sentry": sentry, "statsd": statsd,