diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f417efb23332d..572189651b3e7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -425,6 +425,13 @@ repos:
           ^airflow-ctl.*\.py$|
           ^airflow-core/src/airflow/models/.*\.py$|
           ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
+          ^airflow-core/tests/unit/models/test_serialized_dag.py$|
+          ^airflow-core/tests/unit/models/test_pool.py$|
+          ^airflow-core/tests/unit/models/test_trigger.py$|
+          ^airflow-core/tests/unit/models/test_callback.py$|
+          ^airflow-core/tests/unit/models/test_cleartasks.py$|
+          ^airflow-core/tests/unit/models/test_xcom.py$|
+          ^airflow-core/tests/unit/models/test_dagrun.py$|
           ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_sources.py$|
           ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py$|
           ^airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py$|
@@ -444,6 +451,8 @@ repos:
           ^airflow-core/tests/unit/models/test_dagwarning.py$|
           ^airflow-core/tests/integration/otel/test_otel.py$|
           ^airflow-core/tests/unit/utils/test_db_cleanup.py$|
+          ^airflow-core/tests/unit/utils/test_state.py$|
+          ^airflow-core/tests/unit/utils/test_log_handlers.py$|
           ^dev/airflow_perf/scheduler_dag_execution_timing.py$|
           ^providers/celery/.*\.py$|
           ^providers/cncf/kubernetes/.*\.py$|
diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py
index 09d0931557de1..dfc19fc61a354 100644
--- a/airflow-core/tests/unit/models/test_callback.py
+++ b/airflow-core/tests/unit/models/test_callback.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import pytest
+from sqlalchemy import select
 
 from airflow.models import Trigger
 from airflow.models.callback import (
@@ -118,7 +119,7 @@ def test_polymorphic_serde(self, session):
         session.add(callback)
         session.commit()
 
-        retrieved = session.query(Callback).filter_by(id=callback.id).one()
+        retrieved = session.scalar(select(Callback).where(Callback.id == callback.id))
         assert isinstance(retrieved, TriggererCallback)
         assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH
         assert retrieved.data == TEST_ASYNC_CALLBACK.serialize()
@@ -188,7 +189,7 @@ def test_polymorphic_serde(self, session):
         session.add(callback)
         session.commit()
 
-        retrieved = session.query(Callback).filter_by(id=callback.id).one()
+        retrieved = session.scalar(select(Callback).where(Callback.id == callback.id))
         assert isinstance(retrieved, ExecutorCallback)
         assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH
         assert retrieved.data == TEST_SYNC_CALLBACK.serialize()
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py
index 4d8063f841982..e92f57e741d7b 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -21,7 +21,7 @@
 import random
 
 import pytest
-from sqlalchemy import select
+from sqlalchemy import func, select
 
 from airflow.models.dag_version import DagVersion
 from airflow.models.dagrun import DagRun
@@ -87,7 +87,7 @@ def test_clear_task_instances(self, dag_maker):
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
 
         ti0.refresh_from_db(session)
@@ -121,7 +121,7 @@ def test_clear_task_instances_external_executor_id(self, dag_maker):
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
 
         ti0.refresh_from_db()
@@ -186,12 +186,12 @@ def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        assert session.query(TaskInstanceHistory).count() == 0
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
+        assert session.scalar(select(func.count()).select_from(TaskInstanceHistory)) == 0
         clear_task_instances(qry, session, dag_run_state=state)
         session.flush()
         # 2 TIs were cleared so 2 history records should be created
-        assert session.query(TaskInstanceHistory).count() == 2
+        assert session.scalar(select(func.count()).select_from(TaskInstanceHistory)) == 2
 
         session.refresh(dr)
@@ -229,7 +229,7 @@ def test_clear_task_instances_on_running_dr(self, state, dag_maker):
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
         session.flush()
@@ -282,7 +282,7 @@ def test_clear_task_instances_on_finished_dr(self, state, last_scheduling, dag_m
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
         session.flush()
@@ -394,7 +394,7 @@ def test_clear_task_instances_without_dag_param(self, dag_maker, session):
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
 
         ti0.refresh_from_db(session=session)
@@ -477,7 +477,9 @@ def test_clear_task_instances_with_task_reschedule(self, dag_maker):
         with create_session() as session:
 
             def count_task_reschedule(ti):
-                return session.query(TaskReschedule).filter(TaskReschedule.ti_id == ti.id).count()
+                return session.scalar(
+                    select(func.count()).select_from(TaskReschedule).where(TaskReschedule.ti_id == ti.id)
+                )
 
             assert count_task_reschedule(ti0) == 1
             assert count_task_reschedule(ti1) == 1
@@ -485,12 +487,9 @@ def count_task_reschedule(ti):
             # this is equivalent to topological sort. It would not work in general case
             # but it works for our case because we specifically constructed test DAGS
             # in the way that those two sort methods are equivalent
-            qry = (
-                session.query(TI)
-                .filter(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id)
-                .order_by(TI.task_id)
-                .all()
-            )
+            qry = session.scalars(
+                select(TI).where(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id).order_by(TI.task_id)
+            ).all()
             clear_task_instances(qry, session)
             assert count_task_reschedule(ti0) == 0
             assert count_task_reschedule(ti1) == 1
@@ -531,7 +530,7 @@ def test_task_instance_history_record(self, state, state_recorded, dag_maker):
         ti1.state = state
         session = dag_maker.session
         session.flush()
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
         session.flush()
@@ -716,10 +715,10 @@ def test_clear_task_instances_with_run_on_latest_version(self, run_on_latest_ver
         new_dag_version = DagVersion.get_latest_version(dag.dag_id)
         assert old_dag_version.id != new_dag_version.id
 
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session, run_on_latest_version=run_on_latest_version)
         session.commit()
-        dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).one()
+        dr = session.scalar(select(DagRun).where(DagRun.dag_id == dag.dag_id))
         if run_on_latest_version:
             assert dr.created_dag_version_id == new_dag_version.id
             assert dr.bundle_version == new_dag_version.bundle_version
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index daeff282424b6..a09aceccc029f 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -27,7 +27,7 @@
 
 import pendulum
 import pytest
-from sqlalchemy import exists, select
+from sqlalchemy import exists, func, select
 from sqlalchemy.orm import joinedload
 
 from airflow import settings
@@ -155,10 +155,10 @@ def test_clear_task_instances_for_backfill_running_dagrun(self, dag_maker, sessi
             EmptyOperator(task_id="backfill_task_0")
 
         self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
         clear_task_instances(qry, session)
         session.flush()
-        dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
+        dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now))
         assert dr0.state == state
         assert dr0.clear_number < 1
@@ -170,10 +170,10 @@ def test_clear_task_instances_for_backfill_finished_dagrun(self, dag_maker, stat
             EmptyOperator(task_id="backfill_task_0")
 
         self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
         clear_task_instances(qry, session)
         session.flush()
-        dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
+        dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now))
         assert dr0.state == DagRunState.QUEUED
         assert dr0.clear_number == 1
@@ -721,7 +721,7 @@ def test_dagrun_set_state_end_date(self, dag_maker, session):
         session.merge(dr)
         session.commit()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
 
@@ -729,14 +729,14 @@ def test_dagrun_set_state_end_date(self, dag_maker, session):
         session.merge(dr)
         session.commit()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
 
         assert dr_database.end_date is None
 
         dr.set_state(DagRunState.FAILED)
         session.merge(dr)
         session.commit()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
 
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
@@ -764,7 +764,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session):
 
         dr.update_state()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
 
@@ -772,7 +772,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session):
         ti_op2.set_state(state=TaskInstanceState.RUNNING, session=session)
         dr.update_state()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
 
         assert dr._state == DagRunState.RUNNING
         assert dr.end_date is None
@@ -782,7 +782,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session):
         ti_op2.set_state(state=TaskInstanceState.FAILED, session=session)
         dr.update_state()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
 
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
@@ -1216,7 +1216,7 @@ def test_dag_run_dag_versions_method(self, dag_maker, session):
             EmptyOperator(task_id="empty")
         dag_run = dag_maker.create_dagrun()
 
-        dm = session.query(DagModel).options(joinedload(DagModel.dag_versions)).one()
+        dm = session.scalar(select(DagModel).options(joinedload(DagModel.dag_versions)))
         assert dag_run.dag_versions[0].id == dm.dag_versions[0].id
 
     def test_dag_run_version_number(self, dag_maker, session):
@@ -1231,7 +1231,7 @@ def test_dag_run_version_number(self, dag_maker, session):
         tis[1].dag_version = dag_v
         session.merge(tis[1])
         session.flush()
-        dag_run = session.query(DagRun).filter(DagRun.run_id == dag_run.run_id).one()
+        dag_run = session.scalar(select(DagRun).where(DagRun.run_id == dag_run.run_id))
         # Check that dag_run.version_number returns the version number of
         # the latest task instance dag_version
         assert dag_run.version_number == dag_v.version_number
@@ -1337,14 +1337,14 @@ def test_dagrun_success_deadline_prune(self, dag_maker, session):
         dag_run1_deadline = exists().where(Deadline.dagrun_id == dag_run1.id)
         dag_run2_deadline = exists().where(Deadline.dagrun_id == dag_run2.id)
 
-        assert session.query(dag_run1_deadline).scalar()
-        assert session.query(dag_run2_deadline).scalar()
+        assert session.scalar(select(dag_run1_deadline))
+        assert session.scalar(select(dag_run2_deadline))
 
         session.add(dag_run1)
         dag_run1.update_state()
 
-        assert not session.query(dag_run1_deadline).scalar()
-        assert session.query(dag_run2_deadline).scalar()
+        assert not session.scalar(select(dag_run1_deadline))
+        assert session.scalar(select(dag_run2_deadline))
 
         assert dag_run1.state == DagRunState.SUCCESS
         assert dag_run2.state == DagRunState.RUNNING
@@ -1399,13 +1399,12 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session):
         mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal)
 
     dr = dag_maker.create_dagrun()
-    indices = (
-        session.query(TI.map_index)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+    indices = session.scalars(
+        select(TI.map_index)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
-        .all()
-    )
-    assert indices == [(0,), (1,), (2,), (3,)]
+    ).all()
+    assert indices == [0, 1, 2, 3]
 
 
 @pytest.mark.parametrize("is_noop", [True, False])
@@ -1422,13 +1421,12 @@ def mynameis(arg):
         mynameis.expand(arg=literal)
 
     dr = dag_maker.create_dagrun()
-    indices = (
-        session.query(TI.map_index)
-        .filter_by(task_id="mynameis", dag_id=dr.dag_id, run_id=dr.run_id)
+    indices = session.scalars(
+        select(TI.map_index)
+        .where(TI.task_id == "mynameis", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
        .order_by(TI.map_index)
-        .all()
-    )
-    assert indices == [(0,), (1,), (2,), (3,)]
+    ).all()
+    assert indices == [0, 1, 2, 3]
 
 
 def test_mapped_literal_verify_integrity(dag_maker, session):
@@ -1444,7 +1442,7 @@ def task_2(arg2): ...
 
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1483,12 +1481,11 @@ def task_2(arg2): ...
     dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
     dr.verify_integrity(dag_version_id=dag_version_id, session=session)
 
-    indices = (
-        session.query(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+    indices = session.execute(
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
-        .all()
-    )
+    ).all()
 
     assert indices == [
         (0, TaskInstanceState.REMOVED),
@@ -1511,7 +1508,7 @@ def task_2(arg2): ...
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1552,7 +1549,7 @@ def task_2(arg2): ...
     dr = dag_maker.create_dagrun()
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1661,7 +1658,7 @@ def task_2(arg2): ...
     dr.task_instance_scheduling_decisions(session=session)
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1751,7 +1748,7 @@ def task_2(arg2): ...
 
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1786,17 +1783,17 @@ def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session):
 
     dr = dag_maker.create_dagrun()
     query = (
-        session.query(TI.map_index, TI.state)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
-    assert query.all() == [(-1, None)]
+    assert session.execute(query).all() == [(-1, None)]
 
     # Verify_integrity shouldn't change the result now that the TIs exist
     dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
     dr.verify_integrity(dag_version_id=dag_version_id, session=session)
-    assert query.all() == [(-1, None)]
+    assert session.execute(query).all() == [(-1, None)]
 
 
 def test_mapped_task_group_expands_at_create(dag_maker, session):
@@ -1823,11 +1820,11 @@ def tg(x):
 
     dr = dag_maker.create_dagrun()
     query = (
-        session.query(TI.task_id, TI.map_index, TI.state)
-        .filter_by(dag_id=dr.dag_id, run_id=dr.run_id)
+        select(TI.task_id, TI.map_index, TI.state)
+        .where(TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
        .order_by(TI.task_id, TI.map_index)
     )
-    assert query.all() == [
+    assert session.execute(query).all() == [
         ("tg.t1", 0, None),
         ("tg.t1", 1, None),
         # ("tg.t2", 0, None),
@@ -1904,12 +1901,11 @@ def test_ti_scheduling_mapped_zero_length(dag_maker, session):
     # expanded against a zero-length XCom.
     assert decision.finished_tis == [ti1, ti2]
 
-    indices = (
-        session.query(TI.map_index, TI.state)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+    indices = session.execute(
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
-        .all()
-    )
+    ).all()
 
     assert indices == [(-1, TaskInstanceState.SKIPPED)]
@@ -2576,8 +2572,14 @@ def printx(x):
     dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
     ti = dr1.get_task_instances()[0]
 
-    filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
-    ti = session.query(TaskInstance).filter_by(**filter_kwargs).one()
+    ti = session.scalar(
+        select(TaskInstance).where(
+            TaskInstance.dag_id == ti.dag_id,
+            TaskInstance.task_id == ti.task_id,
+            TaskInstance.run_id == ti.run_id,
+            TaskInstance.map_index == ti.map_index,
+        )
+    )
 
     tr = TaskReschedule(
         ti_id=ti.id,
@@ -2598,10 +2600,10 @@ def printx(x):
     XComModel.set(key="test", value="value", task_id=ti.task_id, dag_id=dag.dag_id, run_id=ti.run_id)
     session.commit()
     for table in [TaskInstanceNote, TaskReschedule, XComModel]:
-        assert session.query(table).count() == 1
+        assert session.scalar(select(func.count()).select_from(table)) == 1
     dr1.task_instance_scheduling_decisions(session)
     for table in [TaskInstanceNote, TaskReschedule, XComModel]:
-        assert session.query(table).count() == 0
+        assert session.scalar(select(func.count()).select_from(table)) == 0
 
 
 def test_dagrun_with_note(dag_maker, session):
@@ -2619,14 +2621,14 @@ def the_task():
     session.add(dr)
     session.commit()
 
-    dr_note = session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one()
+    dr_note = session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id))
     assert dr_note.content == "dag run with note"
 
     session.delete(dr)
     session.commit()
 
-    assert session.query(DagRun).filter(DagRun.id == dr.id).one_or_none() is None
-    assert session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one_or_none() is None
+    assert session.scalar(select(DagRun).where(DagRun.id == dr.id)) is None
+    assert session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id)) is None
 
 
 @pytest.mark.parametrize(
@@ -2655,7 +2657,7 @@ def mytask():
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
 
     assert dr.state == dag_run_state
@@ -2694,7 +2696,7 @@ def mytask():
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
 
     assert dr.state == dag_run_state
@@ -2729,7 +2731,7 @@ def mytask():
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
 
     assert dr.state == DagRunState.FAILED
@@ -2765,7 +2767,7 @@ def mytask():
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
 
     assert dr.state == DagRunState.FAILED
diff --git a/airflow-core/tests/unit/models/test_pool.py b/airflow-core/tests/unit/models/test_pool.py
index fa59d85ffe82d..a275bf2d4b246 100644
--- a/airflow-core/tests/unit/models/test_pool.py
+++ b/airflow-core/tests/unit/models/test_pool.py
@@ -21,6 +21,7 @@
 
 import pendulum
 import pytest
+from sqlalchemy import func, select
 
 from airflow import settings
 from airflow.exceptions import AirflowException, PoolNotFound
@@ -292,7 +293,7 @@ def test_create_pool(self, session):
         assert pool.slots == 5
         assert pool.description == ""
         assert pool.include_deferred is True
-        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT + 1
+        assert session.scalar(select(func.count()).select_from(Pool)) == self.TOTAL_POOL_COUNT + 1
 
     def test_create_pool_existing(self, session):
         self.add_pools()
@@ -303,13 +304,13 @@ def test_create_pool_existing(self, session):
         assert pool.slots == 5
         assert pool.description == ""
         assert pool.include_deferred is False
-        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT
+        assert session.scalar(select(func.count()).select_from(Pool)) == self.TOTAL_POOL_COUNT
 
     def test_delete_pool(self, session):
         self.add_pools()
         pool = Pool.delete_pool(name=self.pools[-1].pool)
         assert pool.pool == self.pools[-1].pool
-        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT - 1
+        assert session.scalar(select(func.count()).select_from(Pool)) == self.TOTAL_POOL_COUNT - 1
 
     def test_delete_pool_non_existing(self):
         with pytest.raises(PoolNotFound, match="^Pool 'test' doesn't exist$"):
diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py
index 2472c0720461e..ddc98f7206492 100644
--- a/airflow-core/tests/unit/models/test_serialized_dag.py
+++ b/airflow-core/tests/unit/models/test_serialized_dag.py
@@ -24,7 +24,7 @@
 
 import pendulum
 import pytest
-from sqlalchemy import func, select, update
+from sqlalchemy import delete, func, select, update
 
 import airflow.example_dags as example_dags_module
 from airflow.dag_processing.dagbag import DagBag
@@ -59,7 +59,12 @@ def make_example_dags(module):
     from airflow.utils.session import create_session
 
     with create_session() as session:
-        if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
+        if (
+            session.scalar(
+                select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing")
+            )
+            == 0
+        ):
             testing = DagBundleModel(name="testing")
             session.add(testing)
@@ -101,7 +106,7 @@ def test_write_dag(self, testing_dag_bundle):
         with create_session() as session:
             for dag in example_dags.values():
                 assert SDM.has_dag(dag.dag_id)
-                result = session.query(SDM).filter(SDM.dag_id == dag.dag_id).one()
+                result = session.scalar(select(SDM).where(SDM.dag_id == dag.dag_id))
 
                 assert result.dag_version.dag_code.fileloc == dag.fileloc
                 # Verifies JSON schema.
@@ -118,7 +123,7 @@ def my_callable():
         with dag_maker("dag1"):
             PythonOperator(task_id="task1", python_callable=lambda x: None)
         dag_maker.create_dagrun(run_id="test2", logical_date=pendulum.datetime(2025, 1, 1))
-        assert len(session.query(DagVersion).all()) == 2
+        assert len(session.scalars(select(DagVersion)).all()) == 2
 
         with dag_maker("dag2"):
@@ -136,7 +141,7 @@ def my_callable2():
                 pass
 
             my_callable2()
-        assert len(session.query(DagVersion).all()) == 4
+        assert len(session.scalars(select(DagVersion)).all()) == 4
 
     def test_serialized_dag_is_updated_if_dag_is_changed(self, testing_dag_bundle):
         """Test Serialized DAG is updated if DAG is changed"""
@@ -212,7 +217,7 @@ def test_read_all_dags_only_picks_the_latest_serdags(self, session):
             dag.doc_md = "new doc string"
             SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="testing")
         serialized_dags2 = SDM.read_all_dags()
-        sdags = session.query(SDM).all()
+        sdags = session.scalars(select(SDM)).all()
 
         # assert only the latest SDM is returned
         assert len(sdags) != len(serialized_dags2)
@@ -334,7 +339,7 @@ def test_get_latest_serdag_versions(self, dag_maker, session):
     def test_new_dag_versions_are_not_created_if_no_dagruns(self, dag_maker, session):
         with dag_maker("dag1") as dag:
             PythonOperator(task_id="task1", python_callable=lambda: None)
-        assert session.query(SDM).count() == 1
+        assert session.scalar(select(func.count()).select_from(SDM)) == 1
         sdm1 = SDM.get(dag.dag_id, session=session)
         dag_hash = sdm1.dag_hash
         created_at = sdm1.created_at
@@ -347,21 +352,21 @@ def test_new_dag_versions_are_not_created_if_no_dagruns(self, dag_maker, session
         assert sdm2.dag_hash != dag_hash  # first recorded serdag
         assert sdm2.created_at == created_at
         assert sdm2.last_updated != last_updated
-        assert session.query(DagVersion).count() == 1
-        assert session.query(SDM).count() == 1
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 1
+        assert session.scalar(select(func.count()).select_from(SDM)) == 1
 
     def test_new_dag_versions_are_created_if_there_is_a_dagrun(self, dag_maker, session):
         with dag_maker("dag1") as dag:
             PythonOperator(task_id="task1", python_callable=lambda: None)
         dag_maker.create_dagrun(run_id="test3", logical_date=pendulum.datetime(2025, 1, 2))
-        assert session.query(SDM).count() == 1
-        assert session.query(DagVersion).count() == 1
+        assert session.scalar(select(func.count()).select_from(SDM)) == 1
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 1
 
         # new task
         PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
         SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker")
-        assert session.query(DagVersion).count() == 2
-        assert session.query(SDM).count() == 2
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 2
+        assert session.scalar(select(func.count()).select_from(SDM)) == 2
 
     def test_example_dag_sorting_serialised_dag(self, session):
         """
@@ -517,14 +522,14 @@ def test_new_dag_version_created_when_bundle_name_changes_and_hash_unchanged(sel
         # Create TIs
         dag_maker.create_dagrun(run_id="test_run")
 
-        assert session.query(DagVersion).count() == 1
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 1
 
         # Write the same DAG (no changes, so hash is the same) with a new bundle_name
         new_bundle = "bundleB"
         SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=new_bundle)
 
         # There should now be two versions of the DAG
-        assert session.query(DagVersion).count() == 2
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 2
 
     def test_hash_method_removes_fileloc_and_remains_consistent(self):
         """Test that the hash method removes fileloc before hashing."""
@@ -632,7 +637,7 @@ def test_dynamic_dag_update_preserves_null_check(self, dag_maker, session):
         assert dag_version is not None
 
         # Manually delete SerializedDagModel (simulates edge case)
-        session.query(SDM).filter(SDM.dag_id == "test_missing_serdag").delete()
+        session.execute(delete(SDM).where(SDM.dag_id == "test_missing_serdag"))
         session.commit()
 
         # Verify no SerializedDagModel exists
@@ -709,7 +714,9 @@ def test_write_dag_atomicity_on_dagcode_failure(self, dag_maker, session):
             EmptyOperator(task_id="task1")
 
         dag = dag_maker.dag
-        initial_version_count = session.query(DagVersion).filter(DagVersion.dag_id == dag.dag_id).count()
+        initial_version_count = session.scalar(
+            select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id)
+        )
         assert initial_version_count == 1, "Should have one DagVersion after initial write"
         dag_maker.create_dagrun()  # ensure the second dag version is created
@@ -732,8 +739,8 @@ def test_write_dag_atomicity_on_dagcode_failure(self, dag_maker, session):
         # Verify that no new DagVersion was committed
         # Use a fresh session to ensure we're reading from committed data
         with create_session() as fresh_session:
-            final_version_count = (
-                fresh_session.query(DagVersion).filter(DagVersion.dag_id == dag.dag_id).count()
+            final_version_count = fresh_session.scalar(
+                select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id)
             )
             assert final_version_count == initial_version_count, (
                 "DagVersion should not be committed when DagCode.write_code fails"
diff --git a/airflow-core/tests/unit/models/test_trigger.py b/airflow-core/tests/unit/models/test_trigger.py
index fe2fbeb6b98f2..df169852cb52c 100644
--- a/airflow-core/tests/unit/models/test_trigger.py
+++ b/airflow-core/tests/unit/models/test_trigger.py
@@ -26,6 +26,7 @@
 import pytest
 import pytz
 from cryptography.fernet import Fernet
+from sqlalchemy import delete, func, select
 
 from airflow._shared.timezones import timezone
 from airflow.jobs.job import Job
@@ -61,21 +62,21 @@ def session():
 
 @pytest.fixture(autouse=True)
 def clear_db(session):
-    session.query(TaskInstance).delete()
-    session.query(AssetWatcherModel).delete()
-    session.query(Callback).delete()
-    session.query(Trigger).delete()
-    session.query(AssetModel).delete()
-    session.query(AssetEvent).delete()
-    session.query(Job).delete()
+    session.execute(delete(TaskInstance))
+    session.execute(delete(AssetWatcherModel))
+    session.execute(delete(Callback))
+    session.execute(delete(Trigger))
+    session.execute(delete(AssetModel))
+    session.execute(delete(AssetEvent))
+    session.execute(delete(Job))
     yield session
-    session.query(TaskInstance).delete()
-    session.query(AssetWatcherModel).delete()
-    session.query(Callback).delete()
-    session.query(Trigger).delete()
-    session.query(AssetModel).delete()
-    session.query(AssetEvent).delete()
-    session.query(Job).delete()
+    session.execute(delete(TaskInstance))
+    session.execute(delete(AssetWatcherModel))
+    session.execute(delete(Callback))
+    session.execute(delete(Trigger))
+    session.execute(delete(AssetModel))
+    session.execute(delete(AssetEvent))
+    session.execute(delete(Job))
     session.commit()
@@ -121,7 +122,7 @@ def test_clean_unused(session, create_task_instance):
     session.add(trigger5)
     session.add(trigger6)
     session.commit()
-    assert session.query(Trigger).count() == 6
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 6
     # Tie one to a fake TaskInstance that is not deferred, and one to one that is
     task_instance = create_task_instance(
         session=session, task_id="fake", state=State.DEFERRED, logical_date=timezone.utcnow()
     )
@@ -150,7 +151,7 @@ def test_clean_unused(session, create_task_instance):
     asset.add_trigger(trigger5, "test_asset_watcher2")
     session.add(asset)
     session.commit()
-    assert session.query(AssetModel).count() == 1
+    assert session.scalar(select(func.count()).select_from(AssetModel)) == 1
 
     # Create callback with trigger
     callback = TriggererCallback(
@@ -162,7 +163,7 @@ def test_clean_unused(session, create_task_instance):
 
     # Run clear operation
     Trigger.clean_unused()
-    results = session.query(Trigger).all()
+    results = session.scalars(select(Trigger)).all()
     assert len(results) == 4
     assert {result.id for result in results} == {trigger1.id, trigger4.id, trigger5.id, trigger6.id}
@@ -196,7 +197,10 @@ def test_submit_event(mock_callback_handle_event, session, create_task_instance)
     session.commit()
 
     # Check that the asset has 0 event prior to sending an event to the trigger
-    assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 0
+    assert (
+        session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asset.id))
+        == 0
+    )
 
     # Create event
     payload = "payload"
@@ -210,8 +214,11 @@ def test_submit_event(mock_callback_handle_event, session, create_task_instance)
     assert task_instance.state == State.SCHEDULED
     assert task_instance.next_kwargs == {"event": payload, "cheesecake": True}
     # Check that the asset has received an event
-    assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 1
-    asset_event = session.query(AssetEvent).filter_by(asset_id=asset.id).first()
+    assert (
+        session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asset.id))
+        == 1
+    )
+    asset_event = session.scalar(select(AssetEvent).where(AssetEvent.asset_id == asset.id))
     assert asset_event.extra == {"from_trigger": True, "payload": payload}
 
     # Check that the callback's handle_event was called
@@ -233,7 +240,7 @@ def test_submit_failure(session, create_task_instance):
     # Call submit_event
     Trigger.submit_failure(trigger.id, session=session)
     # Check that the task instance is now scheduled to fail
-    updated_task_instance = session.query(TaskInstance).one()
+    updated_task_instance = session.scalar(select(TaskInstance))
     assert updated_task_instance.state == State.SCHEDULED
     assert updated_task_instance.next_method == "__fail__"
@@ -272,7 +279,7 @@ def get_xcoms(ti):
 
     # now for the real test
     # first check initial state
-    ti: TaskInstance = session.query(TaskInstance).one()
+    ti: TaskInstance = session.scalar(select(TaskInstance))
     assert ti.state == "deferred"
     assert get_xcoms(ti) == []
@@ -285,7 +292,7 @@ def get_xcoms(ti):
     # commit changes made by submit event and expire all cache to read from db.
     session.flush()
     # Check that the task instance is now correct
-    ti = session.query(TaskInstance).one()
+    ti = session.scalar(select(TaskInstance))
     assert ti.state == expected
     assert ti.next_kwargs is None
     assert ti.end_date == now
@@ -370,26 +377,26 @@ def test_assign_unassigned(session, create_task_instance):
     session.add(ti_trigger_unassigned_to_triggerer)
     assert trigger_unassigned_to_triggerer.triggerer_id is None
     session.commit()
-    assert session.query(Trigger).count() == 4
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 4
     Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30)
     session.expire_all()
     # Check that trigger on killed triggerer and unassigned trigger are assigned to new triggerer
     assert (
-        session.query(Trigger).filter(Trigger.id == trigger_on_killed_triggerer.id).one().triggerer_id
+        session.scalar(select(Trigger).where(Trigger.id == trigger_on_killed_triggerer.id)).triggerer_id
         == new_triggerer.id
     )
     assert (
-        session.query(Trigger).filter(Trigger.id == trigger_unassigned_to_triggerer.id).one().triggerer_id
+        session.scalar(select(Trigger).where(Trigger.id == trigger_unassigned_to_triggerer.id)).triggerer_id
         == new_triggerer.id
     )
     # Check that trigger on healthy triggerer still assigned to existing triggerer
     assert (
-        session.query(Trigger).filter(Trigger.id == trigger_on_healthy_triggerer.id).one().triggerer_id
+        session.scalar(select(Trigger).where(Trigger.id == trigger_on_healthy_triggerer.id)).triggerer_id
        == healthy_triggerer.id
     )
     # Check that trigger on unhealthy triggerer is assigned to new triggerer
     assert (
-        session.query(Trigger).filter(Trigger.id == trigger_on_unhealthy_triggerer.id).one().triggerer_id
+        session.scalar(select(Trigger).where(Trigger.id == trigger_on_unhealthy_triggerer.id)).triggerer_id
         == new_triggerer.id
     )
@@ -453,7 +460,7 @@ def test_get_sorted_triggers_same_priority_weight(session, create_task_instance)
     )
     session.add(trigger_callback)
     session.commit()
-    assert session.query(Trigger).count() == 5
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 5
     # Create assets
     asset = AssetModel("test")
     asset.add_trigger(trigger_asset, "test_asset_watcher")
@@ -534,7 +541,7 @@ def test_get_sorted_triggers_different_priority_weights(session, create_task_ins
     session.add(TI_new)
     session.commit()
 
-    assert session.query(Trigger).count() == 5
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 5
 
     trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)
@@ -605,7 +612,7 @@ def test_get_sorted_triggers_dont_starve_for_ha(session, create_task_instance):
         asset_triggers.append(trigger)
 
     session.commit()
-    assert session.query(Trigger).count() == 60
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 60
 
     # Mock max_trigger_to_select_per_loop to 5 for testing
     with patch.object(Trigger, "max_trigger_to_select_per_loop", 5):
diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py
index 1bc7105cd8a6a..acf7ad752bfbd 100644
--- a/airflow-core/tests/unit/models/test_xcom.py
+++ b/airflow-core/tests/unit/models/test_xcom.py
@@ -23,6 +23,7 @@
 from unittest.mock import MagicMock
 
 import pytest
+from sqlalchemy import delete, func, select
 
 from airflow._shared.timezones import timezone
 from airflow.configuration import conf
@@ -88,7 +89,7 @@ def func(*, dag_id, task_id, logical_date, run_after=None):
     def cleanup_database():
         # This should also clear task instances by cascading.
-        session.query(DagRun).filter_by(id=run.id).delete()
+        session.execute(delete(DagRun).where(DagRun.id == run.id))
         session.commit()
 
     request.addfinalizer(cleanup_database)
@@ -384,7 +385,7 @@ def test_xcom_set(self, session, task_instance, key, value, expected_value):
             run_id=task_instance.run_id,
             session=session,
         )
-        stored_xcoms = session.query(XComModel).all()
+        stored_xcoms = session.scalars(select(XComModel)).all()
         assert stored_xcoms[0].key == key
         assert isinstance(stored_xcoms[0].value, type(json.dumps(expected_value)))
         assert stored_xcoms[0].value == json.dumps(expected_value)
@@ -398,7 +399,7 @@ def setup_for_xcom_set_again_replace(self, task_instance, push_simple_json_xcom)
 
     @pytest.mark.usefixtures("setup_for_xcom_set_again_replace")
     def test_xcom_set_again_replace(self, session, task_instance):
-        assert session.query(XComModel).one().value == json.dumps({"key1": "value1"})
+        assert session.scalar(select(XComModel)).value == json.dumps({"key1": "value1"})
         XComModel.set(
             key="xcom_1",
             value={"key2": "value2"},
@@ -407,7 +408,7 @@ def test_xcom_set_again_replace(self, session, task_instance):
             run_id=task_instance.run_id,
             session=session,
         )
-        assert session.query(XComModel).one().value == json.dumps({"key2": "value2"})
+        assert session.scalar(select(XComModel)).value == json.dumps({"key2": "value2"})
 
     def test_xcom_set_invalid_key(self, session, task_instance):
         """Test that setting an XCom with an invalid key raises a ValueError."""
@@ -440,14 +441,14 @@ def setup_for_xcom_clear(self, task_instance, push_simple_json_xcom):
     @pytest.mark.usefixtures("setup_for_xcom_clear")
     @mock.patch("airflow.sdk.execution_time.xcom.XCom.purge")
     def test_xcom_clear(self, mock_purge, session, task_instance):
-        assert session.query(XComModel).count() == 1
+        assert session.scalar(select(func.count()).select_from(XComModel)) == 1
         XComModel.clear(
             dag_id=task_instance.dag_id,
             task_id=task_instance.task_id,
             run_id=task_instance.run_id,
             session=session,
        )
-        assert session.query(XComModel).count() == 0
+        assert session.scalar(select(func.count()).select_from(XComModel)) == 0
 
         # purge will not be done when we clear, will be handled in task sdk
         assert mock_purge.call_count == 0
@@ -459,7 +460,7 @@ def test_xcom_clear_different_run(self, session, task_instance):
             run_id="different_run",
             session=session,
         )
-        assert session.query(XComModel).count() == 1
+        assert session.scalar(select(func.count()).select_from(XComModel)) == 1
 
 
 class TestXComRoundTrip:
diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py b/airflow-core/tests/unit/utils/test_log_handlers.py
index 30669ab05997f..cbe1a61d1c618 100644
--- a/airflow-core/tests/unit/utils/test_log_handlers.py
+++ b/airflow-core/tests/unit/utils/test_log_handlers.py
@@ -36,6 +36,7 @@
 from pydantic import TypeAdapter
 from pydantic.v1.utils import deep_update
 from requests.adapters import Response
+from sqlalchemy import delete, select
 
 from airflow import settings
 from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
@@ -98,8 +99,8 @@ def cleanup_tables():
 class TestFileTaskLogHandler:
     def clean_up(self):
         with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TaskInstance).delete()
+            session.execute(delete(DagRun))
+            session.execute(delete(TaskInstance))
 
     def setup_method(self):
         settings.configure_logging()
@@ -781,16 +782,14 @@ def test_jinja_id_in_template_for_history(
         )
         TaskInstanceHistory.record_ti(ti, session=session)
         session.flush()
-        tih = (
-            session.query(TaskInstanceHistory)
-            .filter_by(
-                dag_id=ti.dag_id,
-                task_id=ti.task_id,
-                run_id=ti.run_id,
-                map_index=ti.map_index,
-                try_number=ti.try_number,
+        tih = session.scalar(
+            select(TaskInstanceHistory).where(
+                TaskInstanceHistory.dag_id == ti.dag_id,
+                TaskInstanceHistory.task_id == ti.task_id,
+                TaskInstanceHistory.run_id == ti.run_id,
+                TaskInstanceHistory.map_index == ti.map_index,
+                TaskInstanceHistory.try_number == ti.try_number,
             )
-            .one()
         )
         fth = FileTaskHandler("")
         rendered_ti = fth._render_filename(ti, ti.try_number, session=session)
diff --git a/airflow-core/tests/unit/utils/test_state.py b/airflow-core/tests/unit/utils/test_state.py
index 463f943320462..88a1925842e8b 100644
--- a/airflow-core/tests/unit/utils/test_state.py
+++ b/airflow-core/tests/unit/utils/test_state.py
@@ -19,6 +19,7 @@
 from datetime import timedelta
 
 import pytest
+from sqlalchemy import select
 
 from airflow.models.dagrun import DagRun
 from airflow.sdk import DAG
@@ -58,22 +59,18 @@ def test_dagrun_state_enum_escape(testing_dag_bundle):
             triggered_by=DagRunTriggeredByType.TEST,
         )
 
-        query = session.query(
-            DagRun.dag_id,
-            DagRun.state,
-            DagRun.run_type,
-        ).filter(
+        stmt = select(DagRun.dag_id, DagRun.state, DagRun.run_type).where(
             DagRun.dag_id == dag.dag_id,
             # make sure enum value can be used in filter queries
             DagRun.state == DagRunState.QUEUED,
         )
-        assert str(query.statement.compile(compile_kwargs={"literal_binds": True})) == (
+        assert str(stmt.compile(compile_kwargs={"literal_binds": True})) == (
             "SELECT dag_run.dag_id, dag_run.state, dag_run.run_type \n"
             "FROM dag_run \n"
             "WHERE dag_run.dag_id = 'test_dagrun_state_enum_escape' AND dag_run.state = 'queued'"
        )
 
-        rows = query.all()
+        rows = session.execute(stmt).all()
         assert len(rows) == 1
         assert rows[0].dag_id == dag.dag_id
         # make sure value in db is stored as `queued`, not `DagRunType.QUEUED`
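Reviewer note, not part of the patch: every hunk above applies the same SQLAlchemy 1.x Query API to 2.0-style migration, so the recurring rewrites are worth stating once. Below is a minimal, self-contained sketch; the Widget model, the in-memory SQLite engine, and all variable names are hypothetical stand-ins for illustration, not Airflow code. It covers the four patterns used throughout: .one()/.first() lookups become session.scalar(select(...)), entity .all() becomes session.scalars(select(...)).all(), .count() becomes session.scalar(select(func.count()).select_from(...)), and bulk .delete() becomes session.execute(delete(...)).

Sketch (Python, SQLAlchemy 2.x):

from sqlalchemy import Integer, String, create_engine, delete, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Widget(Base):  # hypothetical stand-in for the Airflow models touched above
    __tablename__ = "widget"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String(50))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Widget(name="a"), Widget(name="b")])
    session.commit()

    # 1.x: session.query(Widget).filter(Widget.name == "a").one()
    # scalar() returns the first matching row or None; unlike .one() it does
    # not raise on zero or multiple matches.
    one = session.scalar(select(Widget).where(Widget.name == "a"))

    # 1.x: session.query(Widget).order_by(Widget.name).all()
    rows = session.scalars(select(Widget).order_by(Widget.name)).all()

    # 1.x: session.query(Widget).count()
    n = session.scalar(select(func.count()).select_from(Widget))

    # 1.x: session.query(Widget).filter_by(name="b").delete()
    session.execute(delete(Widget).where(Widget.name == "b"))
    session.commit()

    assert one is not None and one.name == "a"
    assert [w.name for w in rows] == ["a", "b"]
    assert n == 2

One behavioral difference worth keeping in mind when reviewing: session.scalar(select(...)) quietly returns None (or the first row) where .one() raised on zero or multiple results, so tests that relied on .one() for that check now carry it in explicit assertions (e.g. the "is None" assertions in test_dagrun_with_note above).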