diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 939039b0827fb..ced7a9ed668d0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -425,6 +425,7 @@ repos:
           ^airflow-ctl.*\.py$|
           ^airflow-core/src/airflow/models/.*\.py$|
           ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
+          ^airflow-core/tests/unit/models/test_dagrun.py$|
           ^airflow-core/tests/unit/utils/test_db_cleanup.py$|
           ^dev/airflow_perf/scheduler_dag_execution_timing.py$|
           ^providers/openlineage/.*\.py$|
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index daeff282424b6..a09aceccc029f 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -27,7 +27,7 @@
 
 import pendulum
 import pytest
-from sqlalchemy import exists, select
+from sqlalchemy import exists, func, select
 from sqlalchemy.orm import joinedload
 
 from airflow import settings
@@ -155,10 +155,10 @@ def test_clear_task_instances_for_backfill_running_dagrun(self, dag_maker, sessi
             EmptyOperator(task_id="backfill_task_0")
 
         self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
         clear_task_instances(qry, session)
         session.flush()
-        dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
+        dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now))
 
         assert dr0.state == state
         assert dr0.clear_number < 1
@@ -170,10 +170,10 @@ def test_clear_task_instances_for_backfill_finished_dagrun(self, dag_maker, stat
             EmptyOperator(task_id="backfill_task_0")
 
         self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
         clear_task_instances(qry, session)
         session.flush()
-        dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
+        dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now))
 
         assert dr0.state == DagRunState.QUEUED
         assert dr0.clear_number == 1
@@ -721,7 +721,7 @@ def test_dagrun_set_state_end_date(self, dag_maker, session):
         session.merge(dr)
         session.commit()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
 
@@ -729,14 +729,14 @@ def test_dagrun_set_state_end_date(self, dag_maker, session):
         session.merge(dr)
         session.commit()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr_database.end_date is None
 
         dr.set_state(DagRunState.FAILED)
         session.merge(dr)
         session.commit()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
 
@@ -764,7 +764,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session):
 
         dr.update_state()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
 
@@ -772,7 +772,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session):
         ti_op2.set_state(state=TaskInstanceState.RUNNING, session=session)
         dr.update_state()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr._state == DagRunState.RUNNING
         assert dr.end_date is None
 
@@ -782,7 +782,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session):
         ti_op2.set_state(state=TaskInstanceState.FAILED, session=session)
         dr.update_state()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
 
@@ -1216,7 +1216,7 @@ def test_dag_run_dag_versions_method(self, dag_maker, session):
             EmptyOperator(task_id="empty")
 
         dag_run = dag_maker.create_dagrun()
-        dm = session.query(DagModel).options(joinedload(DagModel.dag_versions)).one()
+        dm = session.scalar(select(DagModel).options(joinedload(DagModel.dag_versions)))
         assert dag_run.dag_versions[0].id == dm.dag_versions[0].id
 
     def test_dag_run_version_number(self, dag_maker, session):
@@ -1231,7 +1231,7 @@ def test_dag_run_version_number(self, dag_maker, session):
         tis[1].dag_version = dag_v
         session.merge(tis[1])
         session.flush()
-        dag_run = session.query(DagRun).filter(DagRun.run_id == dag_run.run_id).one()
+        dag_run = session.scalar(select(DagRun).where(DagRun.run_id == dag_run.run_id))
         # Check that dag_run.version_number returns the version number of
         # the latest task instance dag_version
         assert dag_run.version_number == dag_v.version_number
@@ -1337,14 +1337,14 @@ def test_dagrun_success_deadline_prune(self, dag_maker, session):
         dag_run1_deadline = exists().where(Deadline.dagrun_id == dag_run1.id)
         dag_run2_deadline = exists().where(Deadline.dagrun_id == dag_run2.id)
 
-        assert session.query(dag_run1_deadline).scalar()
-        assert session.query(dag_run2_deadline).scalar()
+        assert session.scalar(select(dag_run1_deadline))
+        assert session.scalar(select(dag_run2_deadline))
 
         session.add(dag_run1)
         dag_run1.update_state()
 
-        assert not session.query(dag_run1_deadline).scalar()
-        assert session.query(dag_run2_deadline).scalar()
+        assert not session.scalar(select(dag_run1_deadline))
+        assert session.scalar(select(dag_run2_deadline))
 
         assert dag_run1.state == DagRunState.SUCCESS
         assert dag_run2.state == DagRunState.RUNNING
@@ -1399,13 +1399,12 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session):
         mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal)
 
     dr = dag_maker.create_dagrun()
-    indices = (
-        session.query(TI.map_index)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+    indices = session.scalars(
+        select(TI.map_index)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
-        .all()
-    )
-    assert indices == [(0,), (1,), (2,), (3,)]
+    ).all()
+    assert indices == [0, 1, 2, 3]
 
 
 @pytest.mark.parametrize("is_noop", [True, False])
@@ -1422,13 +1421,12 @@ def mynameis(arg):
         mynameis.expand(arg=literal)
 
     dr = dag_maker.create_dagrun()
-    indices = (
-        session.query(TI.map_index)
-        .filter_by(task_id="mynameis", dag_id=dr.dag_id, run_id=dr.run_id)
+    indices = session.scalars(
+        select(TI.map_index)
+        .where(TI.task_id == "mynameis", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
-        .all()
-    )
-    assert indices == [(0,), (1,), (2,), (3,)]
+    ).all()
+    assert indices == [0, 1, 2, 3]
 
 
 def test_mapped_literal_verify_integrity(dag_maker, session):
@@ -1444,7 +1442,7 @@ def task_2(arg2): ...
 
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1483,12 +1481,11 @@ def task_2(arg2): ...
     dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
     dr.verify_integrity(dag_version_id=dag_version_id, session=session)
 
-    indices = (
-        session.query(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+    indices = session.execute(
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
-        .all()
-    )
+    ).all()
 
     assert indices == [
         (0, TaskInstanceState.REMOVED),
@@ -1511,7 +1508,7 @@ def task_2(arg2): ...
 
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1552,7 +1549,7 @@ def task_2(arg2): ...
     dr = dag_maker.create_dagrun()
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1661,7 +1658,7 @@ def task_2(arg2): ...
     dr.task_instance_scheduling_decisions(session=session)
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1751,7 +1748,7 @@ def task_2(arg2): ...
 
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
        .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1786,17 +1783,17 @@ def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session):
     dr = dag_maker.create_dagrun()
 
     query = (
-        session.query(TI.map_index, TI.state)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
-    assert query.all() == [(-1, None)]
+    assert session.execute(query).all() == [(-1, None)]
 
     # Verify_integrity shouldn't change the result now that the TIs exist
     dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
     dr.verify_integrity(dag_version_id=dag_version_id, session=session)
 
-    assert query.all() == [(-1, None)]
+    assert session.execute(query).all() == [(-1, None)]
 
 
 def test_mapped_task_group_expands_at_create(dag_maker, session):
@@ -1823,11 +1820,11 @@ def tg(x):
     dr = dag_maker.create_dagrun()
 
     query = (
-        session.query(TI.task_id, TI.map_index, TI.state)
-        .filter_by(dag_id=dr.dag_id, run_id=dr.run_id)
+        select(TI.task_id, TI.map_index, TI.state)
+        .where(TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.task_id, TI.map_index)
     )
-    assert query.all() == [
+    assert session.execute(query).all() == [
         ("tg.t1", 0, None),
         ("tg.t1", 1, None),
         # ("tg.t2", 0, None),
@@ -1904,12 +1901,11 @@ def test_ti_scheduling_mapped_zero_length(dag_maker, session):
     # expanded against a zero-length XCom.
     assert decision.finished_tis == [ti1, ti2]
 
-    indices = (
-        session.query(TI.map_index, TI.state)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+    indices = session.execute(
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
-        .all()
-    )
+    ).all()
 
     assert indices == [(-1, TaskInstanceState.SKIPPED)]
 
@@ -2576,8 +2572,14 @@ def printx(x):
     dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
     ti = dr1.get_task_instances()[0]
 
-    filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
-    ti = session.query(TaskInstance).filter_by(**filter_kwargs).one()
+    ti = session.scalar(
+        select(TaskInstance).where(
+            TaskInstance.dag_id == ti.dag_id,
+            TaskInstance.task_id == ti.task_id,
+            TaskInstance.run_id == ti.run_id,
+            TaskInstance.map_index == ti.map_index,
+        )
+    )
 
     tr = TaskReschedule(
         ti_id=ti.id,
@@ -2598,10 +2600,10 @@ def printx(x):
     XComModel.set(key="test", value="value", task_id=ti.task_id, dag_id=dag.dag_id, run_id=ti.run_id)
     session.commit()
     for table in [TaskInstanceNote, TaskReschedule, XComModel]:
-        assert session.query(table).count() == 1
+        assert session.scalar(select(func.count()).select_from(table)) == 1
     dr1.task_instance_scheduling_decisions(session)
     for table in [TaskInstanceNote, TaskReschedule, XComModel]:
-        assert session.query(table).count() == 0
+        assert session.scalar(select(func.count()).select_from(table)) == 0
 
 
 def test_dagrun_with_note(dag_maker, session):
@@ -2619,14 +2621,14 @@ def the_task():
     session.add(dr)
     session.commit()
 
-    dr_note = session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one()
+    dr_note = session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id))
     assert dr_note.content == "dag run with note"
 
     session.delete(dr)
     session.commit()
 
-    assert session.query(DagRun).filter(DagRun.id == dr.id).one_or_none() is None
-    assert session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one_or_none() is None
+    assert session.scalar(select(DagRun).where(DagRun.id == dr.id)) is None
+    assert session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id)) is None
 
 
 @pytest.mark.parametrize(
@@ -2655,7 +2657,7 @@ def mytask():
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
 
     assert dr.state == dag_run_state
 
@@ -2694,7 +2696,7 @@ def mytask():
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
 
     assert dr.state == dag_run_state
 
@@ -2729,7 +2731,7 @@ def mytask():
     session.flush()
    dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
 
     assert dr.state == DagRunState.FAILED
 
@@ -2765,7 +2767,7 @@ def mytask():
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
 
     assert dr.state == DagRunState.FAILED