diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 53679ffbf241a..76ece6d90758a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -491,6 +491,11 @@ repos: ^providers/edge3/.*\.py$| ^providers/mysql/.*\.py$| ^providers/openlineage/.*\.py$| + ^providers/google/src/airflow/providers/google/cloud/triggers/dataproc\.py$| + ^providers/google/src/airflow/providers/google/cloud/triggers/bigquery\.py$| + ^providers/google/tests/unit/google/cloud/utils/gcp_authenticator\.py$| + ^providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager\.py$| + ^providers/google/tests/system/google/gcp_api_client_helpers\.py$| ^providers/standard/tests/unit/standard/operators/test_latest_only_operator\.py$| ^providers/standard/tests/unit/standard/operators/test_trigger_dagrun\.py$| ^providers/standard/tests/unit/standard/operators/test_weekday\.py$| diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py index b193edaf772e5..70419e5b08d3f 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py @@ -34,6 +34,8 @@ from sqlalchemy.orm.session import Session if not AIRFLOW_V_3_0_PLUS: + from sqlalchemy import select + from airflow.models.taskinstance import TaskInstance from airflow.utils.session import provide_session @@ -105,13 +107,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]: @provide_session def get_task_instance(self, session: Session) -> TaskInstance: - query = session.query(TaskInstance).filter( - TaskInstance.dag_id == self.task_instance.dag_id, - TaskInstance.task_id == self.task_instance.task_id, - TaskInstance.run_id == self.task_instance.run_id, - TaskInstance.map_index == self.task_instance.map_index, + task_instance = session.scalar( + select(TaskInstance).where( + TaskInstance.dag_id == self.task_instance.dag_id, + TaskInstance.task_id == self.task_instance.task_id, + TaskInstance.run_id == self.task_instance.run_id, + TaskInstance.map_index == self.task_instance.map_index, + ) ) - task_instance = query.one_or_none() if task_instance is None: raise AirflowException( "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found", diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py index f16faae9a18b9..ffe9b9aeaa6d4 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py @@ -41,6 +41,8 @@ from sqlalchemy.orm.session import Session if not AIRFLOW_V_3_0_PLUS: + from sqlalchemy import select + from airflow.models.taskinstance import TaskInstance from airflow.utils.session import provide_session @@ -130,13 +132,14 @@ def get_task_instance(self, session: Session) -> TaskInstance: :param session: Sqlalchemy session """ - query = session.query(TaskInstance).filter( - TaskInstance.dag_id == self.task_instance.dag_id, - TaskInstance.task_id == self.task_instance.task_id, - TaskInstance.run_id == self.task_instance.run_id, - TaskInstance.map_index == self.task_instance.map_index, + task_instance = session.scalar( + select(TaskInstance).where( + TaskInstance.dag_id == self.task_instance.dag_id, + TaskInstance.task_id == self.task_instance.task_id, + TaskInstance.run_id == self.task_instance.run_id, + TaskInstance.map_index == self.task_instance.map_index, + ) ) - task_instance = query.one_or_none() if task_instance is None: raise AirflowException( "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found", @@ -266,13 +269,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]: @provide_session def get_task_instance(self, session: Session) -> TaskInstance: - query = session.query(TaskInstance).filter( - TaskInstance.dag_id == self.task_instance.dag_id, - TaskInstance.task_id == self.task_instance.task_id, - TaskInstance.run_id == self.task_instance.run_id, - TaskInstance.map_index == self.task_instance.map_index, + task_instance = session.scalar( + select(TaskInstance).where( + TaskInstance.dag_id == self.task_instance.dag_id, + TaskInstance.task_id == self.task_instance.task_id, + TaskInstance.run_id == self.task_instance.run_id, + TaskInstance.map_index == self.task_instance.map_index, + ) ) - task_instance = query.one_or_none() if task_instance is None: raise AirflowException( "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.", diff --git a/providers/google/tests/system/google/gcp_api_client_helpers.py b/providers/google/tests/system/google/gcp_api_client_helpers.py index cbd84c388c07e..d860cacc79c2f 100644 --- a/providers/google/tests/system/google/gcp_api_client_helpers.py +++ b/providers/google/tests/system/google/gcp_api_client_helpers.py @@ -93,14 +93,15 @@ def create_airflow_connection( connection_id=connection_id, connection=connection_conf, is_composer=is_composer ) else: + from sqlalchemy import delete + from airflow.models import Connection from airflow.settings import Session if Session is None: raise RuntimeError("Session not configured. Call configure_orm() first.") session = Session() - query = session.query(Connection).filter(Connection.conn_id == connection_id) - query.delete() + session.execute(delete(Connection).where(Connection.conn_id == connection_id)) connection = Connection(conn_id=connection_id, **connection_conf) session.add(connection) session.commit() @@ -114,12 +115,13 @@ def delete_airflow_connection(connection_id: str, is_composer: bool = False) -> if AIRFLOW_V_3_0_PLUS: delete_connection_request(connection_id=connection_id, is_composer=is_composer) else: + from sqlalchemy import delete + from airflow.models import Connection from airflow.settings import Session if Session is None: raise RuntimeError("Session not configured. Call configure_orm() first.") session = Session() - query = session.query(Connection).filter(Connection.conn_id == connection_id) - query.delete() + session.execute(delete(Connection).where(Connection.conn_id == connection_id)) session.commit() diff --git a/providers/google/tests/unit/google/cloud/utils/gcp_authenticator.py b/providers/google/tests/unit/google/cloud/utils/gcp_authenticator.py index 976cd41ef5dda..f442de83281f9 100644 --- a/providers/google/tests/unit/google/cloud/utils/gcp_authenticator.py +++ b/providers/google/tests/unit/google/cloud/utils/gcp_authenticator.py @@ -21,6 +21,8 @@ import os import subprocess +from sqlalchemy import select + from airflow import settings from airflow.models import Connection from airflow.providers.common.compat.sdk import AirflowException @@ -97,7 +99,7 @@ def set_key_path_in_airflow_connection(self): :return: None """ with settings.Session() as session: - conn = session.query(Connection).filter(Connection.conn_id == "google_cloud_default")[0] + conn = session.scalar(select(Connection).where(Connection.conn_id == "google_cloud_default")) extras = conn.extra_dejson extras[KEYPATH_EXTRA] = self.full_key_path if extras.get(KEYFILE_DICT_EXTRA): @@ -113,7 +115,7 @@ def set_dictionary_in_airflow_connection(self): :return: None """ with settings.Session() as session: - conn = session.query(Connection).filter(Connection.conn_id == "google_cloud_default")[0] + conn = session.scalar(select(Connection).where(Connection.conn_id == "google_cloud_default")) extras = conn.extra_dejson with open(self.full_key_path) as path_file: content = json.load(path_file) diff --git a/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py b/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py index 20b71765441da..363622feeb714 100644 --- a/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py +++ b/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py @@ -22,6 +22,7 @@ from unittest import mock import pytest +from sqlalchemy import delete from airflow.models import TaskInstance as TI from airflow.providers.google.marketing_platform.operators.campaign_manager import ( @@ -91,11 +92,11 @@ def test_execute(self, mock_base_op, hook_mock): class TestGoogleCampaignManagerDownloadReportOperator: def setup_method(self): with create_session() as session: - session.query(TI).delete() + session.execute(delete(TI)) def teardown_method(self): with create_session() as session: - session.query(TI).delete() + session.execute(delete(TI)) @mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.http") @mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.tempfile")