diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81e24f2e8a3b3..e87f89c8f43c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -413,6 +413,7 @@ repos: ^airflow-ctl.*\.py$| ^airflow-core/src/airflow/models/.*\.py$| ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| + ^airflow-core/tests/integration/otel/test_otel.py$| ^task_sdk.*\.py$ pass_filenames: true - id: update-supported-versions diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py index 81ab540c501dc..8d6f665811570 100644 --- a/airflow-core/tests/integration/otel/test_otel.py +++ b/airflow-core/tests/integration/otel/test_otel.py @@ -24,7 +24,7 @@ import time import pytest -from sqlalchemy import select +from sqlalchemy import func, select from airflow._shared.timezones import timezone from airflow.dag_processing.bundles.manager import DagBundlesManager @@ -87,13 +87,11 @@ def wait_for_dag_run_and_check_span_status( while timezone.utcnow().timestamp() - start_time < max_wait_time: with create_session() as session: - dag_run = ( - session.query(DagRun) - .filter( + dag_run = session.scalar( + select(DagRun).where( DagRun.dag_id == dag_id, DagRun.run_id == run_id, ) - .first() ) if dag_run is None: @@ -121,13 +119,11 @@ def wait_for_dag_run_and_check_span_status( def check_dag_run_state_and_span_status(dag_id: str, run_id: str, state: str, span_status: str): with create_session() as session: - dag_run = ( - session.query(DagRun) - .filter( + dag_run = session.scalar( + select(DagRun).where( DagRun.dag_id == dag_id, DagRun.run_id == run_id, ) - .first() ) assert dag_run is not None @@ -139,13 +135,11 @@ def check_dag_run_state_and_span_status(dag_id: str, run_id: str, state: str, sp def check_ti_state_and_span_status(task_id: str, run_id: str, state: str, span_status: str | None): with create_session() as session: - ti = ( - session.query(TaskInstance) - .filter( + ti = session.scalar( + select(TaskInstance).where( TaskInstance.task_id == task_id, TaskInstance.run_id == run_id, ) - .first() ) assert ti is not None @@ -668,7 +662,12 @@ def serialize_and_get_dags(cls) -> dict[str, SerializedDAG]: if AIRFLOW_V_3_0_PLUS: from airflow.models.dagbundle import DagBundleModel - if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + count = session.scalar( + select(func.count()) + .select_from(DagBundleModel) + .where(DagBundleModel.name == "testing") + ) + if count == 0: session.add(DagBundleModel(name="testing")) session.commit() SerializedDAG.bulk_write_to_db(