diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d63a999f7ea34..897ee78b18ea6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -455,6 +455,7 @@ repos: ^airflow-core/tests/unit/utils/test_cli_util.py$| ^airflow-core/tests/unit/timetables/test_assets_timetable.py$| ^airflow-core/tests/unit/assets/test_manager.py$| + ^airflow-core/tests/unit/jobs/test_scheduler_job.py$| ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py$| ^airflow-core/tests/unit/ti_deps/deps/test_runnable_exec_date_dep.py$| ^airflow-core/tests/unit/models/test_dagwarning.py$| diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 9c0fb1943994c..4340693d99d42 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -35,7 +35,7 @@ import pytest import time_machine from pytest import param -from sqlalchemy import func, select, update +from sqlalchemy import delete, func, select, update from sqlalchemy.orm import joinedload from airflow import settings @@ -113,7 +113,6 @@ ) from tests_common.test_utils.mock_executor import MockExecutor from tests_common.test_utils.mock_operators import CustomOperator -from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4, SQLALCHEMY_V_2_0 from unit.listeners import dag_listener from unit.listeners.test_listeners import get_listener_manager from unit.models import TEST_DAGS_FOLDER @@ -1477,15 +1476,19 @@ def test_not_enough_pool_slots(self, caplog, dag_maker): ) assert ( - session.query(TaskInstance) - .filter(TaskInstance.dag_id == dag_id, TaskInstance.state == State.SCHEDULED) - .count() + session.scalar( + select(func.count()) + .select_from(TaskInstance) + .where(TaskInstance.dag_id == dag_id, TaskInstance.state == State.SCHEDULED) + ) == 1 ) assert ( - session.query(TaskInstance) - .filter(TaskInstance.dag_id == dag_id, TaskInstance.state == State.QUEUED) - .count() + session.scalar( + select(func.count()) + .select_from(TaskInstance) + .where(TaskInstance.dag_id == dag_id, TaskInstance.state == State.QUEUED) + ) == 1 ) @@ -1595,7 +1598,7 @@ def test_find_executable_task_instances_concurrency(self, dag_maker, active_stat assert queued_runs["run_3"] == 2 session.commit() - session.query(TaskInstance).all() + session.scalars(select(TaskInstance)).all() # now we still have max tis running so no more will be queued queued_tis = self.job_runner._executable_task_instances_to_queued(max_tis=32, session=session) @@ -2775,19 +2778,19 @@ def test_queued_dagruns_stops_creating_when_max_active_is_reached(self, dag_make assert orm_dag is not None for _ in range(20): self.job_runner._create_dag_runs([orm_dag], session) - drs = session.query(DagRun).all() + drs = session.scalars(select(DagRun)).all() assert len(drs) == 10 for dr in drs: dr.state = State.RUNNING session.merge(dr) session.commit() - assert session.query(DagRun.state).filter(DagRun.state == State.RUNNING).count() == 10 + assert session.scalar(select(func.count(DagRun.state)).where(DagRun.state == State.RUNNING)) == 10 for _ in range(20): self.job_runner._create_dag_runs([orm_dag], session) - assert session.query(DagRun).count() == 10 - assert session.query(DagRun.state).filter(DagRun.state == State.RUNNING).count() == 10 - assert session.query(DagRun.state).filter(DagRun.state == State.QUEUED).count() == 0 + assert session.scalar(select(func.count()).select_from(DagRun)) == 10 + assert session.scalar(select(func.count(DagRun.state)).where(DagRun.state == State.RUNNING)) == 10 + assert session.scalar(select(func.count(DagRun.state)).where(DagRun.state == State.QUEUED)) == 0 assert orm_dag.next_dagrun_create_after is None def test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, session): @@ -3133,7 +3136,11 @@ def test_dagrun_timeout_fails_run_and_update_next_dagrun(self, dag_maker): assert dag_maker.dag_model.next_dagrun_create_after == dr.logical_date + timedelta(days=1) # check that no running/queued runs yet assert ( - session.query(DagRun).filter(DagRun.state.in_([DagRunState.RUNNING, DagRunState.QUEUED])).count() + session.scalar( + select(func.count()) + .select_from(DagRun) + .where(DagRun.state.in_([DagRunState.RUNNING, DagRunState.QUEUED])) + ) == 0 ) @@ -3238,8 +3245,7 @@ def test_dagrun_timeout_callbacks_are_stored_in_database(self, dag_maker, sessio self.job_runner._do_scheduling(session) callback = ( - session.query(DbCallbackRequest) - .order_by(DbCallbackRequest.id.desc()) + session.scalars(select(DbCallbackRequest).order_by(DbCallbackRequest.id.desc())) .first() .get_callback_request() ) @@ -3419,11 +3425,9 @@ def test_do_not_schedule_removed_task(self, dag_maker, session): assert dr is not None # Verify the task instance was created - initial_tis = ( - session.query(TaskInstance) - .filter(TaskInstance.dag_id == dag_id, TaskInstance.task_id == "dummy") - .all() - ) + initial_tis = session.scalars( + select(TaskInstance).where(TaskInstance.dag_id == dag_id, TaskInstance.task_id == "dummy") + ).all() assert len(initial_tis) == 1 # Update the DAG to remove the task (simulate DAG file change) @@ -3448,15 +3452,13 @@ def test_do_not_schedule_removed_task(self, dag_maker, session): assert res == [] # Verify no new task instances were created for the removed task in the new dagrun - new_tis = ( - session.query(TaskInstance) - .filter( + new_tis = session.scalars( + select(TaskInstance).where( TaskInstance.dag_id == dag_id, TaskInstance.task_id == "dummy", TaskInstance.run_id == "test_run_2", ) - .all() - ) + ).all() assert len(new_tis) == 0 @pytest.mark.parametrize( @@ -3486,7 +3488,9 @@ def my_task(): ... self.job_runner._do_scheduling(session) assert ( - session.query(DagRun).filter(DagRun.dag_id == dr.dag_id, DagRun.run_id == dr.run_id).one().state + session.scalars(select(DagRun).where(DagRun.dag_id == dr.dag_id, DagRun.run_id == dr.run_id)) + .one() + .state == run_state ) @@ -3532,7 +3536,7 @@ def test_scheduler_start_date(self, testing_dag_bundle): run_job(scheduler_job, execute_callable=self.job_runner._execute) # zero tasks ran - assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 0 + assert len(session.scalars(select(TaskInstance).where(TaskInstance.dag_id == dag_id)).all()) == 0 session.commit() assert self.null_exec.sorted_tasks == [] @@ -3551,7 +3555,7 @@ def test_scheduler_start_date(self, testing_dag_bundle): run_after=data_interval_end, ) # one task "ran" - assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 1 + assert len(session.scalars(select(TaskInstance).where(TaskInstance.dag_id == dag_id)).all()) == 1 session.commit() scheduler_job = Job(executor=self.null_exec) @@ -3559,7 +3563,7 @@ def test_scheduler_start_date(self, testing_dag_bundle): run_job(scheduler_job, execute_callable=self.job_runner._execute) # still one task - assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 1 + assert len(session.scalars(select(TaskInstance).where(TaskInstance.dag_id == dag_id)).all()) == 1 session.commit() assert self.null_exec.sorted_tasks == [] @@ -3590,9 +3594,12 @@ def test_scheduler_task_start_date_catchup_true(self, testing_dag_bundle): run_job(scheduler_job, execute_callable=self.job_runner._execute) session = settings.Session() - tiq = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id) - ti1s = tiq.filter(TaskInstance.task_id == "dummy1").all() - ti2s = tiq.filter(TaskInstance.task_id == "dummy2").all() + ti1s = session.scalars( + select(TaskInstance).where(TaskInstance.dag_id == dag_id, TaskInstance.task_id == "dummy1") + ).all() + ti2s = session.scalars( + select(TaskInstance).where(TaskInstance.dag_id == dag_id, TaskInstance.task_id == "dummy2") + ).all() # With catchup=True, future task start dates are respected assert len(ti1s) == 0, "Expected no instances for dummy1 (start date in future with catchup=True)" @@ -3626,9 +3633,12 @@ def test_scheduler_task_start_date_catchup_false(self, testing_dag_bundle): run_job(scheduler_job, execute_callable=self.job_runner._execute) session = settings.Session() - tiq = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id) - ti1s = tiq.filter(TaskInstance.task_id == "dummy1").all() - ti2s = tiq.filter(TaskInstance.task_id == "dummy2").all() + ti1s = session.scalars( + select(TaskInstance).where(TaskInstance.dag_id == dag_id, TaskInstance.task_id == "dummy1") + ).all() + ti2s = session.scalars( + select(TaskInstance).where(TaskInstance.dag_id == dag_id, TaskInstance.task_id == "dummy2") + ).all() # With catchup=False, future task start dates are ignored assert len(ti1s) >= 1, "Expected instances for dummy1 (ignoring future start date with catchup=False)" @@ -3664,7 +3674,7 @@ def test_scheduler_multiprocessing(self): # zero tasks ran dag_id = "test_start_date_scheduling" session = settings.Session() - assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 0 + assert len(session.scalars(select(TaskInstance).where(TaskInstance.dag_id == dag_id)).all()) == 0 def test_scheduler_verify_pool_full(self, dag_maker, mock_executor): """ @@ -3869,32 +3879,29 @@ def test_scheduler_verify_priority_and_slots(self, dag_maker, mock_executor): # Only second and third assert len(task_instances_list) == 2 - ti0 = ( - session.query(TaskInstance) - .filter(TaskInstance.task_id == "test_scheduler_verify_priority_and_slots_t0") - .first() - ) + ti0 = session.scalars( + select(TaskInstance).where(TaskInstance.task_id == "test_scheduler_verify_priority_and_slots_t0") + ).first() assert ti0.state == State.SCHEDULED - ti1 = ( - session.query(TaskInstance) - .filter(TaskInstance.task_id == "test_scheduler_verify_priority_and_slots_t1") - .first() - ) + ti1 = session.scalars( + select(TaskInstance).where(TaskInstance.task_id == "test_scheduler_verify_priority_and_slots_t1") + ).first() assert ti1.state == State.QUEUED - ti2 = ( - session.query(TaskInstance) - .filter(TaskInstance.task_id == "test_scheduler_verify_priority_and_slots_t2") - .first() - ) + ti2 = session.scalars( + select(TaskInstance).where(TaskInstance.task_id == "test_scheduler_verify_priority_and_slots_t2") + ).first() assert ti2.state == State.QUEUED def test_verify_integrity_if_dag_not_changed(self, dag_maker, session): # CleanUp - session.query(SerializedDagModel).filter( - SerializedDagModel.dag_id == "test_verify_integrity_if_dag_not_changed" - ).delete(synchronize_session=False) + session.execute( + delete(SerializedDagModel).where( + SerializedDagModel.dag_id == "test_verify_integrity_if_dag_not_changed" + ), + execution_options={"synchronize_session": False}, + ) with dag_maker(dag_id="test_verify_integrity_if_dag_not_changed") as dag: BashOperator(task_id="dummy", bash_command="echo hi") @@ -3921,15 +3928,13 @@ def test_verify_integrity_if_dag_not_changed(self, dag_maker, session): mock_verify_integrity.assert_not_called() session.flush() - tis_count = ( - session.query(func.count(TaskInstance.task_id)) - .filter( + tis_count = session.scalar( + select(func.count(TaskInstance.task_id)).where( TaskInstance.dag_id == dr.dag_id, TaskInstance.logical_date == dr.logical_date, TaskInstance.task_id == dr.dag.tasks[0].task_id, TaskInstance.state == State.SCHEDULED, ) - .scalar() ) assert tis_count == 1 @@ -3943,9 +3948,12 @@ def test_verify_integrity_if_dag_not_changed(self, dag_maker, session): def test_verify_integrity_if_dag_changed(self, dag_maker): # CleanUp with create_session() as session: - session.query(SerializedDagModel).filter( - SerializedDagModel.dag_id == "test_verify_integrity_if_dag_changed" - ).delete(synchronize_session=False) + session.execute( + delete(SerializedDagModel).where( + SerializedDagModel.dag_id == "test_verify_integrity_if_dag_changed" + ), + execution_options={"synchronize_session": False}, + ) with dag_maker(dag_id="test_verify_integrity_if_dag_changed", serialized=False) as dag: BashOperator(task_id="dummy", bash_command="echo hi") @@ -3990,24 +3998,13 @@ def test_verify_integrity_if_dag_changed(self, dag_maker): assert dr.dag_versions[-1].id == dag_version_2.id assert len(self.job_runner.scheduler_dag_bag.get_dag_for_run(dr, session).tasks) == 2 - if SQLALCHEMY_V_1_4: - tis_count = ( - session.query(func.count(TaskInstance.task_id)) - .filter( - TaskInstance.dag_id == dr.dag_id, - TaskInstance.logical_date == dr.logical_date, - TaskInstance.state == State.SCHEDULED, - ) - .scalar() - ) - if SQLALCHEMY_V_2_0: - tis_count = session.scalar( - select(func.count(TaskInstance.task_id)).where( - TaskInstance.dag_id == dr.dag_id, - TaskInstance.logical_date == dr.logical_date, - TaskInstance.state == State.SCHEDULED, - ) + tis_count = session.scalar( + select(func.count(TaskInstance.task_id)).where( + TaskInstance.dag_id == dr.dag_id, + TaskInstance.logical_date == dr.logical_date, + TaskInstance.state == State.SCHEDULED, ) + ) assert tis_count == 2 latest_dag_version = DagVersion.get_latest_version(dr.dag_id, session=session) @@ -4033,7 +4030,7 @@ def test_verify_integrity_not_called_for_versioned_bundles(self, dag_maker, sess dr.bundle_version = "1" session.merge(dr) session.commit() - drs = session.query(DagRun).options(joinedload(DagRun.task_instances)).all() + drs = session.scalars(select(DagRun).options(joinedload(DagRun.task_instances))).unique().all() dr = drs[0] assert dr.bundle_version == "1" dag_version_1 = DagVersion.get_latest_version(dr.dag_id, session=session) @@ -4090,14 +4087,12 @@ def do_schedule(session): do_schedule() with create_session() as session: - ti = ( - session.query(TaskInstance) - .filter( + ti = session.scalars( + select(TaskInstance).where( TaskInstance.dag_id == "test_retry_still_in_executor", TaskInstance.task_id == "test_retry_handling_op", ) - .first() - ) + ).first() assert ti is not None, "Task not created by scheduler" ti.task = dag_task1 @@ -4523,7 +4518,7 @@ def test_create_dag_runs(self, dag_maker): with create_session() as session: self.job_runner._create_dag_runs([dag_model], session) - dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).one() + dr = session.scalars(select(DagRun).where(DagRun.dag_id == dag.dag_id)).one() assert dr.state == State.QUEUED assert dr.start_date is None assert dr.creating_job_id == scheduler_job.id @@ -4548,7 +4543,7 @@ def test_create_dag_runs_assets(self, session, dag_maker): data_interval=(DEFAULT_DATE + timedelta(days=10), DEFAULT_DATE + timedelta(days=11)), ) - asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar() + asset1_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset1.uri)) event1 = AssetEvent( asset_id=asset1_id, @@ -4602,7 +4597,7 @@ def dict_from_obj(obj): return {k.key: obj.__dict__.get(k) for k in obj.__mapper__.column_attrs} # dag3 should be triggered since it only depends on asset1, and it's been queued - created_run = session.query(DagRun).filter(DagRun.dag_id == dag3.dag_id).one() + created_run = session.scalars(select(DagRun).where(DagRun.dag_id == dag3.dag_id)).one() assert created_run.state == State.QUEUED assert created_run.start_date is None @@ -4614,11 +4609,21 @@ def dict_from_obj(obj): assert created_run.data_interval_start is None assert created_run.data_interval_end is None # dag2 ADRQ record should still be there since the dag run was *not* triggered - assert session.query(AssetDagRunQueue).filter_by(target_dag_id=dag2.dag_id).one() is not None + assert ( + session.scalars( + select(AssetDagRunQueue).where(AssetDagRunQueue.target_dag_id == dag2.dag_id) + ).one() + is not None + ) # dag2 should not be triggered since it depends on both asset 1 and 2 - assert session.query(DagRun).filter(DagRun.dag_id == dag2.dag_id).one_or_none() is None + assert session.scalars(select(DagRun).where(DagRun.dag_id == dag2.dag_id)).one_or_none() is None # dag3 ADRQ record should be deleted since the dag run was triggered - assert session.query(AssetDagRunQueue).filter_by(target_dag_id=dag3.dag_id).one_or_none() is None + assert ( + session.scalars( + select(AssetDagRunQueue).where(AssetDagRunQueue.target_dag_id == dag3.dag_id) + ).one_or_none() + is None + ) assert created_run.creating_job_id == scheduler_job.id @@ -4645,7 +4650,7 @@ def test_create_dag_runs_asset_alias_with_asset_event_attached(self, session, da BashOperator(task_id="simulate-asset-alias-outlet", bash_command="echo 1") dr = dag_maker.create_dagrun(run_id="asset-alias-producer-run") - asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar() + asset1_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset1.uri)) # Create an AssetEvent, which is associated with the Asset, and it is attached to the AssetAlias event = AssetEvent( @@ -4685,7 +4690,7 @@ def dict_from_obj(obj): """Get dict of column attrs from SqlAlchemy object.""" return {k.key: obj.__dict__.get(k) for k in obj.__mapper__.column_attrs} - created_run = session.query(DagRun).filter(DagRun.dag_id == consumer_dag.dag_id).one() + created_run = session.scalars(select(DagRun).where(DagRun.dag_id == consumer_dag.dag_id)).one() assert created_run.state == State.QUEUED assert created_run.start_date is None @@ -4714,7 +4719,9 @@ def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, BashOperator(task_id="task", bash_command="echo 1", outlets=asset) asset_manger = AssetManager() - asset_id = session.scalars(select(AssetModel.id).filter_by(uri=asset.uri, name=asset.name)).one() + asset_id = session.scalars( + select(AssetModel.id).where(AssetModel.uri == asset.uri, AssetModel.name == asset.name) + ).one() ase_q = select(AssetEvent).where(AssetEvent.asset_id == asset_id).order_by(AssetEvent.timestamp) adrq_q = select(AssetDagRunQueue).where( AssetDagRunQueue.asset_id == asset_id, AssetDagRunQueue.target_dag_id == "consumer" @@ -4770,7 +4777,7 @@ def test_start_dagruns(self, stats_timing, dag_maker, session): self.job_runner._create_dag_runs([dag_model], session) self.job_runner._start_queued_dagruns(session) - dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first() + dr = session.scalars(select(DagRun).where(DagRun.dag_id == dag.dag_id)).first() # Assert dr state is running assert dr.state == State.RUNNING @@ -5079,24 +5086,24 @@ def test_more_runs_are_not_created_when_max_active_runs_is_reached(self, dag_mak self.job_runner = SchedulerJobRunner(job=scheduler_job) session = settings.Session() - assert session.query(DagRun).count() == 0 + assert session.scalar(select(func.count()).select_from(DagRun)) == 0 query, _ = DagModel.dags_needing_dagruns(session) dag_models = query.all() self.job_runner._create_dag_runs(dag_models, session) - dr = session.query(DagRun).one() + dr = session.scalars(select(DagRun)).one() dr.state == DagRunState.QUEUED - assert session.query(DagRun).count() == 1 + assert session.scalar(select(func.count()).select_from(DagRun)) == 1 assert dag_maker.dag_model.next_dagrun_create_after is None session.flush() # dags_needing_dagruns query should not return any value query, _ = DagModel.dags_needing_dagruns(session) assert len(query.all()) == 0 self.job_runner._create_dag_runs(dag_models, session) - assert session.query(DagRun).count() == 1 + assert session.scalar(select(func.count()).select_from(DagRun)) == 1 assert dag_maker.dag_model.next_dagrun_create_after is None assert dag_maker.dag_model.next_dagrun == DEFAULT_DATE # set dagrun to success - dr = session.query(DagRun).one() + dr = session.scalars(select(DagRun)).one() dr.state = DagRunState.SUCCESS ti = dr.get_task_instance("task", session) ti.state = TaskInstanceState.SUCCESS @@ -5112,7 +5119,11 @@ def test_more_runs_are_not_created_when_max_active_runs_is_reached(self, dag_mak assert dag_maker.dag_model.next_dagrun == DEFAULT_DATE + timedelta(days=1) # assert no dagruns is created yet assert ( - session.query(DagRun).filter(DagRun.state.in_([DagRunState.RUNNING, DagRunState.QUEUED])).count() + session.scalar( + select(func.count()) + .select_from(DagRun) + .where(DagRun.state.in_([DagRunState.RUNNING, DagRunState.QUEUED])) + ) == 0 ) @@ -5123,13 +5134,12 @@ def test_max_active_runs_creation_phasing(self, dag_maker, session): """ def complete_one_dagrun(): - ti = ( - session.query(TaskInstance) + ti = session.scalars( + select(TaskInstance) .join(TaskInstance.dag_run) - .filter(TaskInstance.state != State.SUCCESS) + .where(TaskInstance.state != State.SUCCESS) .order_by(DagRun.logical_date) - .first() - ) + ).first() if ti: ti.state = State.SUCCESS session.flush() @@ -5171,7 +5181,7 @@ def complete_one_dagrun(): expected_logical_dates = [datetime.datetime(2016, 1, d, tzinfo=timezone.utc) for d in range(1, 6)] dagrun_logical_dates = [ - dr.logical_date for dr in session.query(DagRun).order_by(DagRun.logical_date).all() + dr.logical_date for dr in session.scalars(select(DagRun).order_by(DagRun.logical_date)).all() ] assert dagrun_logical_dates == expected_logical_dates @@ -5263,12 +5273,10 @@ def test_max_active_runs_in_a_dag_doesnt_stop_running_dag_runs_in_other_dags(sel self.job_runner._start_queued_dagruns(session) session.flush() - dag1_running_count = ( - session.query(func.count(DagRun.id)) - .filter(DagRun.dag_id == "test_dag1", DagRun.state == State.RUNNING) - .scalar() + dag1_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.dag_id == "test_dag1", DagRun.state == State.RUNNING) ) - running_count = session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + running_count = session.scalar(select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING)) assert dag1_running_count == 1 assert running_count == 11 @@ -5306,12 +5314,10 @@ def test_max_active_runs_in_a_dag_doesnt_prevent_backfill_from_running_catchup_t self.job_runner._start_queued_dagruns(session) session.flush() - dag1_running_count = ( - session.query(func.count(DagRun.id)) - .filter(DagRun.dag_id == "test_dag1", DagRun.state == State.RUNNING) - .scalar() + dag1_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.dag_id == "test_dag1", DagRun.state == State.RUNNING) ) - running_count = session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + running_count = session.scalar(select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING)) assert dag1_running_count == 1 assert running_count == 11 @@ -5326,48 +5332,42 @@ def test_max_active_runs_in_a_dag_doesnt_prevent_backfill_from_running_catchup_t triggering_user_name="test_user", dag_run_conf={}, ) - dag1_running_count = ( - session.query(func.count(DagRun.id)) - .filter(DagRun.dag_id == "test_dag1", DagRun.state == State.RUNNING) - .scalar() + dag1_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.dag_id == "test_dag1", DagRun.state == State.RUNNING) ) assert dag1_running_count == 1 - total_running_count = ( - session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + total_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING) ) assert total_running_count == 11 # scheduler will now mark backfill runs as running self.job_runner._start_queued_dagruns(session) session.flush() - dag1_running_count = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_running_count = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, ) - .scalar() ) assert dag1_running_count == 4 - total_running_count = ( - session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + total_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING) ) assert total_running_count == 14 # and doing it again does not change anything self.job_runner._start_queued_dagruns(session) session.flush() - dag1_running_count = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_running_count = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, ) - .scalar() ) assert dag1_running_count == 4 - total_running_count = ( - session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + total_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING) ) assert total_running_count == 14 @@ -5407,12 +5407,10 @@ def test_max_active_runs_in_a_dag_doesnt_prevent_backfill_from_running_catchup_f self.job_runner._start_queued_dagruns(session) session.flush() - dag1_running_count = ( - session.query(func.count(DagRun.id)) - .filter(DagRun.dag_id == "test_dag1", DagRun.state == State.RUNNING) - .scalar() + dag1_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.dag_id == "test_dag1", DagRun.state == State.RUNNING) ) - running_count = session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + running_count = session.scalar(select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING)) assert dag1_running_count == 1 # With catchup=False, only the most recent interval is scheduled for each DAG assert ( @@ -5435,18 +5433,16 @@ def test_max_active_runs_in_a_dag_doesnt_prevent_backfill_from_running_catchup_f # scheduler will now mark backfill runs as running self.job_runner._start_queued_dagruns(session) session.flush() - dag1_running_count = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_running_count = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, ) - .scalar() ) # Even with catchup=False, backfill runs should start assert dag1_running_count == 4 - total_running_count = ( - session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + total_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING) ) assert ( total_running_count == 5 @@ -5470,26 +5466,22 @@ def test_backfill_runs_are_started_with_lower_priority_catchup_true(self, dag_ma EmptyOperator(task_id="mytask") def _running_counts(): - dag1_non_b_running = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_non_b_running = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, DagRun.run_type != DagRunType.BACKFILL_JOB, ) - .scalar() ) - dag1_b_running = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_b_running = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, DagRun.run_type == DagRunType.BACKFILL_JOB, ) - .scalar() ) - total_running_count = ( - session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + total_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING) ) return dag1_non_b_running, dag1_b_running, total_running_count @@ -5529,7 +5521,7 @@ def _running_counts(): assert dag1_non_b_running == 0 assert dag1_b_running == 0 assert total_running == 0 - assert session.query(func.count(DagRun.id)).scalar() == 46 + assert session.scalar(select(func.count(DagRun.id))) == 46 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 # now let's run it once @@ -5586,26 +5578,22 @@ def test_backfill_runs_are_started_with_lower_priority_catchup_false(self, dag_m EmptyOperator(task_id="mytask") def _running_counts(): - dag1_non_b_running = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_non_b_running = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, DagRun.run_type != DagRunType.BACKFILL_JOB, ) - .scalar() ) - dag1_b_running = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_b_running = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, DagRun.run_type == DagRunType.BACKFILL_JOB, ) - .scalar() ) - total_running_count = ( - session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + total_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING) ) return dag1_non_b_running, dag1_b_running, total_running_count @@ -5649,7 +5637,7 @@ def _running_counts(): assert dag1_b_running == 0 assert total_running == 0 # Total 14 runs: 5 for dag1 + 3 for dag2 + 6 backfill runs (Jan 1-6 inclusive) - assert session.query(func.count(DagRun.id)).scalar() == 14 + assert session.scalar(select(func.count(DagRun.id))) == 14 # now let's run it once self.job_runner._start_queued_dagruns(session) @@ -5678,7 +5666,7 @@ def _running_counts(): assert total_running == 5 # Total runs remain the same - assert session.query(func.count(DagRun.id)).scalar() == 14 + assert session.scalar(select(func.count(DagRun.id))) == 14 def test_backfill_maxed_out_no_prevent_non_backfill_max_out(self, dag_maker): session = settings.Session() @@ -5693,26 +5681,22 @@ def test_backfill_maxed_out_no_prevent_non_backfill_max_out(self, dag_maker): EmptyOperator(task_id="mytask") def _running_counts(): - dag1_non_b_running = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_non_b_running = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, DagRun.run_type != DagRunType.BACKFILL_JOB, ) - .scalar() ) - dag1_b_running = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_b_running = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, DagRun.run_type == DagRunType.BACKFILL_JOB, ) - .scalar() ) - total_running_count = ( - session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + total_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING) ) return dag1_non_b_running, dag1_b_running, total_running_count @@ -5734,7 +5718,7 @@ def _running_counts(): assert dag1_non_b_running == 0 assert dag1_b_running == 0 assert total_running == 0 - assert session.query(func.count(DagRun.id)).scalar() == 6 + assert session.scalar(select(func.count(DagRun.id))) == 6 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 6 # scheduler will now mark backfill runs as running @@ -5842,26 +5826,22 @@ def test_backfill_runs_not_started_when_backfill_paused( EmptyOperator(task_id="mytask") def _running_counts(): - dag1_non_b_running = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_non_b_running = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, DagRun.run_type != DagRunType.BACKFILL_JOB, ) - .scalar() ) - dag1_b_running = ( - session.query(func.count(DagRun.id)) - .filter( + dag1_b_running = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag1_dag_id, DagRun.state == State.RUNNING, DagRun.run_type == DagRunType.BACKFILL_JOB, ) - .scalar() ) - total_running_count = ( - session.query(func.count(DagRun.id)).filter(DagRun.state == State.RUNNING).scalar() + total_running_count = session.scalar( + select(func.count(DagRun.id)).where(DagRun.state == State.RUNNING) ) return dag1_non_b_running, dag1_b_running, total_running_count @@ -5885,7 +5865,7 @@ def _running_counts(): assert dag1_non_b_running == 0 assert dag1_b_running == 0 assert total_running == 0 - assert session.query(func.count(DagRun.id)).scalar() == 6 + assert session.scalar(select(func.count(DagRun.id))) == 6 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 6 if pause_it: @@ -5933,14 +5913,12 @@ def test_backfill_runs_skipped_when_lock_held_by_another_scheduler(self, dag_mak dag_run_conf={}, ) - queued_count = ( - session.query(func.count(DagRun.id)) - .filter( + queued_count = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag_id, DagRun.state == State.QUEUED, DagRun.run_type == DagRunType.BACKFILL_JOB, ) - .scalar() ) assert queued_count == 5 @@ -5953,38 +5931,32 @@ def test_backfill_runs_skipped_when_lock_held_by_another_scheduler(self, dag_mak session.flush() # No runs should be started because we couldn't acquire the lock - running_count = ( - session.query(func.count(DagRun.id)) - .filter( + running_count = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag_id, DagRun.state == State.RUNNING, DagRun.run_type == DagRunType.BACKFILL_JOB, ) - .scalar() ) assert running_count == 0, f"Expected 0 running when lock not acquired, but got {running_count}. " # no locks now: job_runner._start_queued_dagruns(session) session.flush() - running_count = ( - session.query(func.count(DagRun.id)) - .filter( + running_count = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag_id, DagRun.state == State.RUNNING, DagRun.run_type == DagRunType.BACKFILL_JOB, ) - .scalar() ) assert running_count == backfill_max_active_runs - queued_count = ( - session.query(func.count(DagRun.id)) - .filter( + queued_count = session.scalar( + select(func.count(DagRun.id)).where( DagRun.dag_id == dag_id, DagRun.state == State.QUEUED, DagRun.run_type == DagRunType.BACKFILL_JOB, ) - .scalar() ) # 2 runs are still queued assert queued_count == 2 @@ -6020,7 +5992,7 @@ def test_start_queued_dagruns_do_follow_logical_date_order(self, dag_maker): self.job_runner._start_queued_dagruns(session) session.flush() dr = DagRun.find(run_id="dagrun_1") - assert len(session.query(DagRun).filter(DagRun.state == State.RUNNING).all()) == 1 + assert len(session.scalars(select(DagRun).where(DagRun.state == State.RUNNING)).all()) == 1 assert dr[0].state == State.RUNNING @@ -6144,7 +6116,14 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ ti.end_date = end_date self.job_runner._schedule_dag_run(dr, session) - assert session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 1 + assert ( + session.scalar( + select(func.count()) + .select_from(TaskInstance) + .where(TaskInstance.state == State.SCHEDULED) + ) + == 1 + ) session.refresh(ti) assert ti.state == State.SCHEDULED @@ -6190,7 +6169,14 @@ def test_dag_file_processor_process_task_instances_with_max_active_tis_per_dag( ti.end_date = end_date self.job_runner._schedule_dag_run(dr, session) - assert session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 1 + assert ( + session.scalar( + select(func.count()) + .select_from(TaskInstance) + .where(TaskInstance.state == State.SCHEDULED) + ) + == 1 + ) session.refresh(ti) assert ti.state == State.SCHEDULED @@ -6236,7 +6222,14 @@ def test_dag_file_processor_process_task_instances_with_max_active_tis_per_dagru ti.end_date = end_date self.job_runner._schedule_dag_run(dr, session) - assert session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 1 + assert ( + session.scalar( + select(func.count()) + .select_from(TaskInstance) + .where(TaskInstance.state == State.SCHEDULED) + ) + == 1 + ) session.refresh(ti) assert ti.state == State.SCHEDULED @@ -6289,7 +6282,14 @@ def test_dag_file_processor_process_task_instances_depends_on_past( ti.end_date = end_date self.job_runner._schedule_dag_run(dr, session) - assert session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 2 + assert ( + session.scalar( + select(func.count()) + .select_from(TaskInstance) + .where(TaskInstance.state == State.SCHEDULED) + ) + == 2 + ) session.refresh(tis[0]) session.refresh(tis[1]) @@ -6316,8 +6316,12 @@ def test_scheduler_job_add_new_task(self, dag_maker): self.job_runner._create_dag_runs([orm_dag], session) drs = ( - session.query(DagRun) - .options(joinedload(DagRun.task_instances).joinedload(TaskInstance.dag_version)) + session.scalars( + select(DagRun).options(joinedload(DagRun.task_instances).joinedload(TaskInstance.dag_version)) + ) + # The unique() method must be invoked on this Result, as it contains results that include + # joined eager loads against collections + .unique() .all() ) assert len(drs) == 1 @@ -6331,7 +6335,12 @@ def test_scheduler_job_add_new_task(self, dag_maker): session.commit() self.job_runner._schedule_dag_run(dr, session) session.expunge_all() - assert session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 2 + assert ( + session.scalar( + select(func.count()).select_from(TaskInstance).where(TaskInstance.state == State.SCHEDULED) + ) + == 2 + ) session.flush() drs = DagRun.find(dag_id=dag.dag_id, session=session) @@ -6574,7 +6583,7 @@ def test_task_instance_heartbeat_timeout_message(self, session, create_dagrun): dagbag = DagBag(dagfile) dag = dagbag.get_dag("example_branch_operator") scheduler_dag = sync_dag_to_db(dag, session=session) - session.query(Job).delete() + session.execute(delete(Job)) data_interval = infer_automated_data_interval(scheduler_dag.timetable, DEFAULT_LOGICAL_DATE) dag_run = create_dagrun( @@ -6746,7 +6755,7 @@ def test_should_mark_empty_task_as_success(self, testing_dag_bundle): self.job_runner._schedule_dag_run(dr, session) session.expunge_all() with create_session() as session: - tis = session.query(TaskInstance).all() + tis = session.scalars(select(TaskInstance)).all() dags = self.job_runner.scheduler_dag_bag._dags.values() assert [dag.dag_id for dag in dags] == ["test_only_empty_tasks"] @@ -6774,7 +6783,7 @@ def test_should_mark_empty_task_as_success(self, testing_dag_bundle): self.job_runner._schedule_dag_run(dr, session) session.expunge_all() with create_session() as session: - tis = session.query(TaskInstance).all() + tis = session.scalars(select(TaskInstance)).all() assert len(tis) == 6 assert { @@ -6840,9 +6849,11 @@ def test_catchup_works_correctly(self, dag_maker, testing_dag_bundle): # Check catchup worked correctly by ensuring logical_date is quite new # Our dag is a daily dag assert ( - session.query(DagRun.logical_date) - .filter(DagRun.logical_date != DEFAULT_DATE) # exclude the first run - .scalar() + session.scalar( + select(DagRun.logical_date).where( + DagRun.logical_date != DEFAULT_DATE + ) # exclude the first run + ) ) > (timezone.utcnow() - timedelta(days=2)) def test_update_dagrun_state_for_paused_dag(self, dag_maker, session): @@ -7071,7 +7082,7 @@ def test_misconfigured_dags_doesnt_crash_scheduler(self, mock_create, session, d assert "Failed creating DagRun for testdag1" in caplog.text def test_activate_referenced_assets_with_no_existing_warning(self, session, testing_dag_bundle): - dag_warnings = session.query(DagWarning).all() + dag_warnings = session.scalars(select(DagWarning)).all() assert dag_warnings == [] dag_id1 = "test_asset_dag1" @@ -7996,9 +8007,9 @@ def test_start_queued_dagruns_uses_latest_max_active_runs_from_dag_model(self, d self.job_runner._create_dag_runs([dag_model], session) # Verify SerializedDAG has max_active_runs=1 - dag_run_1 = ( - session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).order_by(DagRun.logical_date).first() - ) + dag_run_1 = session.scalars( + select(DagRun).where(DagRun.dag_id == dag.dag_id).order_by(DagRun.logical_date) + ).first() assert dag_run_1 is not None serialized_dag = self.job_runner.scheduler_dag_bag.get_dag_for_run(dag_run_1, session=session) assert serialized_dag is not None @@ -8028,15 +8039,15 @@ def test_start_queued_dagruns_uses_latest_max_active_runs_from_dag_model(self, d session.refresh(dag_run_2) # Verify we have 1 running and 1 queued - running_count = ( - session.query(DagRun) - .filter(DagRun.dag_id == dag.dag_id, DagRun.state == DagRunState.RUNNING) - .count() + running_count = session.scalar( + select(func.count()) + .select_from(DagRun) + .where(DagRun.dag_id == dag.dag_id, DagRun.state == DagRunState.RUNNING) ) - queued_count = ( - session.query(DagRun) - .filter(DagRun.dag_id == dag.dag_id, DagRun.state == DagRunState.QUEUED) - .count() + queued_count = session.scalar( + select(func.count()) + .select_from(DagRun) + .where(DagRun.dag_id == dag.dag_id, DagRun.state == DagRunState.QUEUED) ) assert running_count == 1 assert queued_count == 1 @@ -8061,10 +8072,10 @@ def test_start_queued_dagruns_uses_latest_max_active_runs_from_dag_model(self, d ) # Verify we now have 2 running dag runs - running_count = ( - session.query(DagRun) - .filter(DagRun.dag_id == dag.dag_id, DagRun.state == DagRunState.RUNNING) - .count() + running_count = session.scalar( + select(func.count()) + .select_from(DagRun) + .where(DagRun.dag_id == dag.dag_id, DagRun.state == DagRunState.RUNNING) ) assert running_count == 2