From dbe80c89b2a99d6ab737f2c4146bf8f918034f0f Mon Sep 17 00:00:00 2001 From: Bowrna Date: Sun, 5 Jun 2022 20:01:03 +0530 Subject: [PATCH] Fix xfail test in test_scheduler.py (#23731) --- tests/jobs/test_scheduler_job.py | 46 ++++++++++++++++---------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index fd32e6dd7dd6d..5837623bdacf2 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -4099,42 +4099,42 @@ def test_catchup_works_correctly(self, dag_maker): ) > (timezone.utcnow() - timedelta(days=2)) -@pytest.mark.xfail(reason="Work out where this goes") -def test_task_with_upstream_skip_process_task_instances(): +@pytest.mark.need_serialized_dag +def test_schedule_dag_run_with_upstream_skip(dag_maker, session): """ - Test if _process_task_instances puts a task instance into SKIPPED state if any of its + Test if _schedule_dag_run puts a task instance into SKIPPED state if any of its upstream tasks are skipped according to TriggerRuleDep. """ - clear_db_runs() - with DAG( - dag_id='test_task_with_upstream_skip_dag', start_date=DEFAULT_DATE, schedule_interval=None - ) as dag: + with dag_maker( + dag_id='test_task_with_upstream_skip_process_task_instances', + start_date=DEFAULT_DATE, + session=session, + ): dummy1 = EmptyOperator(task_id='dummy1') dummy2 = EmptyOperator(task_id="dummy2") dummy3 = EmptyOperator(task_id="dummy3") [dummy1, dummy2] >> dummy3 - # dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - dr = dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) + dr = dag_maker.create_dagrun(state=State.RUNNING) assert dr is not None - with create_session() as session: - tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} - # Set dummy1 to skipped and dummy2 to success. dummy3 remains as none. - tis[dummy1.task_id].state = State.SKIPPED - tis[dummy2.task_id].state = State.SUCCESS - assert tis[dummy3.task_id].state == State.NONE + tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} + # Set dummy1 to skipped and dummy2 to success. dummy3 remains as none. + tis[dummy1.task_id].state = State.SKIPPED + tis[dummy2.task_id].state = State.SUCCESS + assert tis[dummy3.task_id].state == State.NONE + session.flush() # dag_runs = DagRun.find(dag_id='test_task_with_upstream_skip_dag') # dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) - - with create_session() as session: - tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} - assert tis[dummy1.task_id].state == State.SKIPPED - assert tis[dummy2.task_id].state == State.SUCCESS - # dummy3 should be skipped because dummy1 is skipped. - assert tis[dummy3.task_id].state == State.SKIPPED + scheduler_job = SchedulerJob(subdir=os.devnull) + scheduler_job._schedule_dag_run(dr, session) + session.flush() + tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} + assert tis[dummy1.task_id].state == State.SKIPPED + assert tis[dummy2.task_id].state == State.SUCCESS + # dummy3 should be skipped because dummy1 is skipped. + assert tis[dummy3.task_id].state == State.SKIPPED class TestSchedulerJobQueriesCount: