From 64167317c4db20f1cee6851d19a9ebb39f085da5 Mon Sep 17 00:00:00 2001 From: rich7420 Date: Mon, 15 Dec 2025 17:41:30 +0800 Subject: [PATCH 1/3] Refactor deprecated SQLA models/test_dagrun.py --- .pre-commit-config.yaml | 1 + airflow-core/tests/unit/models/test_dagrun.py | 98 +++++++++---------- 2 files changed, 48 insertions(+), 51 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81e24f2e8a3b3..e85ac1aca634e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -413,6 +413,7 @@ repos: ^airflow-ctl.*\.py$| ^airflow-core/src/airflow/models/.*\.py$| ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| + ^airflow-core/tests/unit/models/test_dagrun.py$| ^task_sdk.*\.py$ pass_filenames: true - id: update-supported-versions diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index daeff282424b6..e03d252ba7ca1 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -27,7 +27,7 @@ import pendulum import pytest -from sqlalchemy import exists, select +from sqlalchemy import exists, func, select from sqlalchemy.orm import joinedload from airflow import settings @@ -155,10 +155,10 @@ def test_clear_task_instances_for_backfill_running_dagrun(self, dag_maker, sessi EmptyOperator(task_id="backfill_task_0") self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session) - qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all() + qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all() clear_task_instances(qry, session) session.flush() - dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first() + dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now)) assert dr0.state == state assert dr0.clear_number < 1 @@ -170,10 +170,10 @@ def test_clear_task_instances_for_backfill_finished_dagrun(self, dag_maker, stat EmptyOperator(task_id="backfill_task_0") self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session) - qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all() + qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all() clear_task_instances(qry, session) session.flush() - dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first() + dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now)) assert dr0.state == DagRunState.QUEUED assert dr0.clear_number == 1 @@ -721,7 +721,7 @@ def test_dagrun_set_state_end_date(self, dag_maker, session): session.merge(dr) session.commit() - dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one() + dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id)) assert dr_database.end_date is not None assert dr.end_date == dr_database.end_date @@ -729,14 +729,14 @@ def test_dagrun_set_state_end_date(self, dag_maker, session): session.merge(dr) session.commit() - dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one() + dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id)) assert dr_database.end_date is None dr.set_state(DagRunState.FAILED) session.merge(dr) session.commit() - dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one() + dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id)) assert dr_database.end_date is not None assert dr.end_date == dr_database.end_date @@ -764,7 +764,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session): dr.update_state() - dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one() + dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id)) assert dr_database.end_date is not None assert dr.end_date == dr_database.end_date @@ -772,7 +772,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session): ti_op2.set_state(state=TaskInstanceState.RUNNING, session=session) dr.update_state() - dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one() + dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id)) assert dr._state == DagRunState.RUNNING assert dr.end_date is None @@ -782,7 +782,7 @@ def test_dagrun_update_state_end_date(self, dag_maker, session): ti_op2.set_state(state=TaskInstanceState.FAILED, session=session) dr.update_state() - dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one() + dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id)) assert dr_database.end_date is not None assert dr.end_date == dr_database.end_date @@ -1216,7 +1216,7 @@ def test_dag_run_dag_versions_method(self, dag_maker, session): EmptyOperator(task_id="empty") dag_run = dag_maker.create_dagrun() - dm = session.query(DagModel).options(joinedload(DagModel.dag_versions)).one() + dm = session.scalar(select(DagModel).options(joinedload(DagModel.dag_versions))) assert dag_run.dag_versions[0].id == dm.dag_versions[0].id def test_dag_run_version_number(self, dag_maker, session): @@ -1231,7 +1231,7 @@ def test_dag_run_version_number(self, dag_maker, session): tis[1].dag_version = dag_v session.merge(tis[1]) session.flush() - dag_run = session.query(DagRun).filter(DagRun.run_id == dag_run.run_id).one() + dag_run = session.scalar(select(DagRun).where(DagRun.run_id == dag_run.run_id)) # Check that dag_run.version_number returns the version number of # the latest task instance dag_version assert dag_run.version_number == dag_v.version_number @@ -1337,14 +1337,14 @@ def test_dagrun_success_deadline_prune(self, dag_maker, session): dag_run1_deadline = exists().where(Deadline.dagrun_id == dag_run1.id) dag_run2_deadline = exists().where(Deadline.dagrun_id == dag_run2.id) - assert session.query(dag_run1_deadline).scalar() - assert session.query(dag_run2_deadline).scalar() + assert session.scalar(select(dag_run1_deadline)) + assert session.scalar(select(dag_run2_deadline)) session.add(dag_run1) dag_run1.update_state() - assert not session.query(dag_run1_deadline).scalar() - assert session.query(dag_run2_deadline).scalar() + assert not session.scalar(select(dag_run1_deadline)) + assert session.scalar(select(dag_run2_deadline)) assert dag_run1.state == DagRunState.SUCCESS assert dag_run2.state == DagRunState.RUNNING @@ -1399,12 +1399,11 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session): mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal) dr = dag_maker.create_dagrun() - indices = ( - session.query(TI.map_index) - .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + indices = session.scalars( + select(TI.map_index) + .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) - .all() - ) + ).all() assert indices == [(0,), (1,), (2,), (3,)] @@ -1422,12 +1421,11 @@ def mynameis(arg): mynameis.expand(arg=literal) dr = dag_maker.create_dagrun() - indices = ( - session.query(TI.map_index) - .filter_by(task_id="mynameis", dag_id=dr.dag_id, run_id=dr.run_id) + indices = session.scalars( + select(TI.map_index) + .where(TI.task_id == "mynameis", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) - .all() - ) + ).all() assert indices == [(0,), (1,), (2,), (3,)] @@ -1483,12 +1481,11 @@ def task_2(arg2): ... dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id dr.verify_integrity(dag_version_id=dag_version_id, session=session) - indices = ( - session.query(TI.map_index, TI.state) - .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id) + indices = session.execute( + select(TI.map_index, TI.state) + .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) - .all() - ) + ).all() assert indices == [ (0, TaskInstanceState.REMOVED), @@ -1786,8 +1783,8 @@ def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session): dr = dag_maker.create_dagrun() query = ( - session.query(TI.map_index, TI.state) - .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + select(TI.map_index, TI.state) + .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) ) @@ -1823,8 +1820,8 @@ def tg(x): dr = dag_maker.create_dagrun() query = ( - session.query(TI.task_id, TI.map_index, TI.state) - .filter_by(dag_id=dr.dag_id, run_id=dr.run_id) + select(TI.task_id, TI.map_index, TI.state) + .where(TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.task_id, TI.map_index) ) assert query.all() == [ @@ -1904,12 +1901,11 @@ def test_ti_scheduling_mapped_zero_length(dag_maker, session): # expanded against a zero-length XCom. assert decision.finished_tis == [ti1, ti2] - indices = ( - session.query(TI.map_index, TI.state) - .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + indices = session.execute( + select(TI.map_index, TI.state) + .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) - .all() - ) + ).all() assert indices == [(-1, TaskInstanceState.SKIPPED)] @@ -2577,7 +2573,7 @@ def printx(x): dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) ti = dr1.get_task_instances()[0] filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index) - ti = session.query(TaskInstance).filter_by(**filter_kwargs).one() + ti = session.scalar(select(TaskInstance).filter_by(**filter_kwargs)) tr = TaskReschedule( ti_id=ti.id, @@ -2598,10 +2594,10 @@ def printx(x): XComModel.set(key="test", value="value", task_id=ti.task_id, dag_id=dag.dag_id, run_id=ti.run_id) session.commit() for table in [TaskInstanceNote, TaskReschedule, XComModel]: - assert session.query(table).count() == 1 + assert session.scalar(select(func.count()).select_from(table)) == 1 dr1.task_instance_scheduling_decisions(session) for table in [TaskInstanceNote, TaskReschedule, XComModel]: - assert session.query(table).count() == 0 + assert session.scalar(select(func.count()).select_from(table)) == 0 def test_dagrun_with_note(dag_maker, session): @@ -2619,14 +2615,14 @@ def the_task(): session.add(dr) session.commit() - dr_note = session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one() + dr_note = session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id)) assert dr_note.content == "dag run with note" session.delete(dr) session.commit() - assert session.query(DagRun).filter(DagRun.id == dr.id).one_or_none() is None - assert session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one_or_none() is None + assert session.scalar(select(DagRun).where(DagRun.id == dr.id)) is None + assert session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id)) is None @pytest.mark.parametrize( @@ -2655,7 +2651,7 @@ def mytask(): session.flush() dr.update_state() session.flush() - dr = session.query(DagRun).one() + dr = session.scalar(select(DagRun)) assert dr.state == dag_run_state @@ -2694,7 +2690,7 @@ def mytask(): session.flush() dr.update_state() session.flush() - dr = session.query(DagRun).one() + dr = session.scalar(select(DagRun)) assert dr.state == dag_run_state @@ -2729,7 +2725,7 @@ def mytask(): session.flush() dr.update_state() session.flush() - dr = session.query(DagRun).one() + dr = session.scalar(select(DagRun)) assert dr.state == DagRunState.FAILED @@ -2765,7 +2761,7 @@ def mytask(): session.flush() dr.update_state() session.flush() - dr = session.query(DagRun).one() + dr = session.scalar(select(DagRun)) assert dr.state == DagRunState.FAILED From bdc4d6ab41ba99bee697e462d9690f7d74c02560 Mon Sep 17 00:00:00 2001 From: rich7420 Date: Mon, 15 Dec 2025 19:26:03 +0800 Subject: [PATCH 2/3] fix error --- airflow-core/tests/unit/models/test_dagrun.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index e03d252ba7ca1..e4bc4b3e651cd 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -1404,7 +1404,7 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session): .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) ).all() - assert indices == [(0,), (1,), (2,), (3,)] + assert indices == [0, 1, 2, 3] @pytest.mark.parametrize("is_noop", [True, False]) @@ -1426,7 +1426,7 @@ def mynameis(arg): .where(TI.task_id == "mynameis", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) ).all() - assert indices == [(0,), (1,), (2,), (3,)] + assert indices == [0, 1, 2, 3] def test_mapped_literal_verify_integrity(dag_maker, session): @@ -1788,12 +1788,12 @@ def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session): .order_by(TI.map_index) ) - assert query.all() == [(-1, None)] + assert session.execute(query).all() == [(-1, None)] # Verify_integrity shouldn't change the result now that the TIs exist dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id dr.verify_integrity(dag_version_id=dag_version_id, session=session) - assert query.all() == [(-1, None)] + assert session.execute(query).all() == [(-1, None)] def test_mapped_task_group_expands_at_create(dag_maker, session): @@ -1824,7 +1824,7 @@ def tg(x): .where(TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.task_id, TI.map_index) ) - assert query.all() == [ + assert session.execute(query).all() == [ ("tg.t1", 0, None), ("tg.t1", 1, None), # ("tg.t2", 0, None), From 78ba8858f9b7537527e5c29d9a5f395d62324f50 Mon Sep 17 00:00:00 2001 From: rich7420 Date: Wed, 17 Dec 2025 09:28:10 +0800 Subject: [PATCH 3/3] change to where() --- airflow-core/tests/unit/models/test_dagrun.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index e4bc4b3e651cd..a09aceccc029f 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -1442,7 +1442,7 @@ def task_2(arg2): ... query = ( select(TI.map_index, TI.state) - .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id) + .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) ) indices = session.execute(query).all() @@ -1508,7 +1508,7 @@ def task_2(arg2): ... query = ( select(TI.map_index, TI.state) - .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id) + .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) ) indices = session.execute(query).all() @@ -1549,7 +1549,7 @@ def task_2(arg2): ... dr = dag_maker.create_dagrun() query = ( select(TI.map_index, TI.state) - .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id) + .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) ) indices = session.execute(query).all() @@ -1658,7 +1658,7 @@ def task_2(arg2): ... dr.task_instance_scheduling_decisions(session=session) query = ( select(TI.map_index, TI.state) - .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id) + .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) ) indices = session.execute(query).all() @@ -1748,7 +1748,7 @@ def task_2(arg2): ... query = ( select(TI.map_index, TI.state) - .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id) + .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id) .order_by(TI.map_index) ) indices = session.execute(query).all() @@ -2572,8 +2572,14 @@ def printx(x): dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) ti = dr1.get_task_instances()[0] - filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index) - ti = session.scalar(select(TaskInstance).filter_by(**filter_kwargs)) + ti = session.scalar( + select(TaskInstance).where( + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + TaskInstance.map_index == ti.map_index, + ) + ) tr = TaskReschedule( ti_id=ti.id,