diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 939039b0827fb..2189e6e26d365 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -427,6 +427,10 @@ repos: ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| ^airflow-core/tests/unit/utils/test_db_cleanup.py$| ^dev/airflow_perf/scheduler_dag_execution_timing.py$| + ^providers/celery/.*\.py$| + ^providers/cncf/kubernetes/.*\.py$| + ^providers/databricks/.*\.py$| + ^providers/mysql/.*\.py$| ^providers/openlineage/.*\.py$| ^task_sdk.*\.py$ pass_filenames: true diff --git a/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py b/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py index 8584217c664f3..13d7765b38b35 100644 --- a/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py +++ b/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py @@ -22,6 +22,7 @@ from unittest import mock import pytest +from sqlalchemy import delete from airflow.executors import executor_loader from airflow.models.dagrun import DagRun @@ -50,8 +51,8 @@ class TestFileTaskLogHandler: def clean_up(self): with create_session() as session: - session.query(DagRun).delete() - session.query(TaskInstance).delete() + session.execute(delete(DagRun)) + session.execute(delete(TaskInstance)) def setup_method(self): logging.root.disabled = False diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py index 2013e1c05813f..3646ff8b75c31 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py @@ -27,6 +27,7 @@ from kubernetes.client.rest import ApiException as SyncApiException from kubernetes_asyncio.client.exceptions import ApiException as AsyncApiException from slugify import slugify +from sqlalchemy import select from urllib3.exceptions import HTTPError from airflow.configuration import conf @@ -175,15 +176,14 @@ def annotations_to_key(annotations: dict[str, str]) -> TaskInstanceKey: raise RuntimeError("Session not configured. Call configure_orm() first.") session = Session() - task_instance_run_id = ( - session.query(TaskInstance.run_id) + task_instance_run_id = session.scalar( + select(TaskInstance.run_id) .join(TaskInstance.dag_run) - .filter( + .where( TaskInstance.dag_id == dag_id, TaskInstance.task_id == task_id, getattr(DagRun, logical_date_key) == logical_date, ) - .scalar() ) else: task_instance_run_id = annotation_run_id diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py index 57bcacffbda23..905e79ec59bd1 100644 --- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py @@ -20,6 +20,8 @@ from typing import TYPE_CHECKING, Any from urllib.parse import unquote +from sqlalchemy import select + from airflow.exceptions import TaskInstanceNotFound from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance, TaskInstanceKey, clear_task_instances @@ -143,7 +145,9 @@ def _get_dagrun(dag, run_id: str, session: Session) -> DagRun: if not session: raise AirflowException("Session not provided.") - return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).one() + return session.scalars( + select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id) + ).one() @provide_session def _clear_task_instances( @@ -162,15 +166,13 @@ def get_task_instance(operator: BaseOperator, dttm, session: Session = NEW_SESSI dag_run = DagRun.find(dag_id, execution_date=dttm)[0] # type: ignore[call-arg] else: dag_run = DagRun.find(dag_id, logical_date=dttm)[0] - ti = ( - session.query(TaskInstance) - .filter( + ti = session.scalars( + select(TaskInstance).where( TaskInstance.dag_id == dag_id, TaskInstance.run_id == dag_run.run_id, TaskInstance.task_id == operator.task_id, ) - .one_or_none() - ) + ).one_or_none() if not ti: raise TaskInstanceNotFound("Task instance not found") return ti diff --git a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py index 3469177ebfdb0..2058ae41a43c6 100644 --- a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py @@ -82,7 +82,7 @@ def test_get_dagrun_airflow2(): session = MagicMock() dag = MagicMock(dag_id=DAG_ID) - session.query.return_value.filter.return_value.one.return_value = DagRun() + session.scalars.return_value.one.return_value = DagRun() result = _get_dagrun(dag, RUN_ID, session=session) @@ -168,7 +168,7 @@ def test_get_task_instance_airflow2(): dttm = "2022-01-01T00:00:00Z" session = Mock() dag_run = Mock() - session.query().filter().one_or_none.return_value = dag_run + session.scalars().one_or_none.return_value = dag_run with patch( "airflow.providers.databricks.plugins.databricks_workflow.DagRun.find", return_value=[dag_run] diff --git a/providers/mysql/tests/unit/mysql/transfers/test_s3_to_mysql.py b/providers/mysql/tests/unit/mysql/transfers/test_s3_to_mysql.py index dca4d4c55088b..1912e2c6a75f5 100644 --- a/providers/mysql/tests/unit/mysql/transfers/test_s3_to_mysql.py +++ b/providers/mysql/tests/unit/mysql/transfers/test_s3_to_mysql.py @@ -19,7 +19,7 @@ from unittest.mock import patch import pytest -from sqlalchemy import or_ +from sqlalchemy import delete, or_ from airflow import models from airflow.providers.mysql.transfers.s3_to_mysql import S3ToMySqlOperator @@ -101,10 +101,8 @@ def test_execute_exception(self, mock_remove, mock_bulk_load_custom, mock_downlo def teardown_method(self): with create_session() as session: - ( - session.query(models.Connection) - .filter( + session.execute( + delete(models.Connection).where( or_(models.Connection.conn_id == "s3_test", models.Connection.conn_id == "mysql_test") ) - .delete() )