From a76a0df969cd6e354768811886754eec6ed0d6ea Mon Sep 17 00:00:00 2001 From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> Date: Thu, 29 Jul 2021 11:05:02 -0600 Subject: [PATCH] Fix race condition with dagrun callbacks (#16741) Instead of immediately sending callbacks to be processed, wait until after we commit so the dagrun.end_date is guaranteed to be there when the callback runs. (cherry picked from commit fb3031acf51f95384154143553aac1a40e568ebf) --- airflow/jobs/scheduler_job.py | 18 ++++-- tests/dag_processing/test_processor.py | 20 ++++--- tests/jobs/test_scheduler_job.py | 80 ++++++++++++++++++++++---- 3 files changed, 94 insertions(+), 24 deletions(-) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 7a37b254219d1..18ec981acf6f7 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -888,6 +888,7 @@ def _do_scheduling(self, session) -> int: # Bulk fetch the currently active dag runs for the dags we are # examining, rather than making one query per DagRun + callback_tuples = [] for dag_run in dag_runs: # Use try_except to not stop the Scheduler when a Serialized DAG is not found # This takes care of Dynamic DAGs especially @@ -896,13 +897,18 @@ def _do_scheduling(self, session) -> int: # But this would take care of the scenario when the Scheduler is restarted after DagRun is # created and the DAG is deleted / renamed try: - self._schedule_dag_run(dag_run, session) + callback_to_run = self._schedule_dag_run(dag_run, session) + callback_tuples.append((dag_run, callback_to_run)) except SerializedDagNotFound: self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id) continue guard.commit() + # Send the callbacks after we commit to ensure the context is up to date when it gets run + for dag_run, callback_to_run in callback_tuples: + self._send_dag_callbacks_to_processor(dag_run, callback_to_run) + # Without this, the session has an invalid view of the DB 
session.expunge_all() # END: schedule TIs @@ -1064,12 +1070,12 @@ def _schedule_dag_run( self, dag_run: DagRun, session: Session, - ) -> int: + ) -> Optional[DagCallbackRequest]: """ Make scheduling decisions about an individual dag run :param dag_run: The DagRun to schedule - :return: Number of tasks scheduled + :return: Callback that needs to be executed """ dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) @@ -1116,13 +1122,13 @@ def _schedule_dag_run( # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False) - self._send_dag_callbacks_to_processor(dag_run, callback_to_run) - # This will do one query per dag run. We "could" build up a complex # query to update all the TIs across all the execution dates and dag # IDs in a single query, but it turns out that can be _very very slow_ # see #11147/commit ee90807ac for more details - return dag_run.schedule_tis(schedulable_tis, session) + dag_run.schedule_tis(schedulable_tis, session) + + return callback_to_run @provide_session def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None): diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py index feb34974c0824..b7f8e7d48b755 100644 --- a/tests/dag_processing/test_processor.py +++ b/tests/dag_processing/test_processor.py @@ -115,6 +115,10 @@ def setUpClass(cls): non_serialized_dagbag.sync_to_db() cls.dagbag = DagBag(read_dags_from_db=True) + @staticmethod + def assert_scheduled_ti_count(session, count): + assert count == session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() + def test_dag_file_processor_sla_miss_callback(self): """ Test that the dag file processor calls the sla miss callback @@ -387,8 +391,8 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ ti.start_date = start_date ti.end_date = end_date - count = 
self.scheduler_job._schedule_dag_run(dr, session) - assert count == 1 + self.scheduler_job._schedule_dag_run(dr, session) + self.assert_scheduled_ti_count(session, 1) session.refresh(ti) assert ti.state == State.SCHEDULED @@ -444,8 +448,8 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency( ti.start_date = start_date ti.end_date = end_date - count = self.scheduler_job._schedule_dag_run(dr, session) - assert count == 1 + self.scheduler_job._schedule_dag_run(dr, session) + self.assert_scheduled_ti_count(session, 1) session.refresh(ti) assert ti.state == State.SCHEDULED @@ -504,8 +508,8 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, ti.start_date = start_date ti.end_date = end_date - count = self.scheduler_job._schedule_dag_run(dr, session) - assert count == 2 + self.scheduler_job._schedule_dag_run(dr, session) + self.assert_scheduled_ti_count(session, 2) session.refresh(tis[0]) session.refresh(tis[1]) @@ -547,9 +551,9 @@ def test_scheduler_job_add_new_task(self): BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test') SerializedDagModel.write_dag(dag=dag) - scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session) + self.scheduler_job._schedule_dag_run(dr, session) + self.assert_scheduled_ti_count(session, 2) session.flush() - assert scheduled_tis == 2 drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 0ee6f5f456018..5de365ecd50af 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -1710,10 +1710,11 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg): ti = dr.get_task_instance('dummy') ti.set_state(state, session) - self.scheduler_job._schedule_dag_run(dr, session) + with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): + self.scheduler_job._do_scheduling(session) expected_callback = 
DagCallbackRequest( - full_filepath=dr.dag.fileloc, + full_filepath=dag.fileloc, dag_id=dr.dag_id, is_failure_callback=bool(state == State.FAILED), execution_date=dr.execution_date, @@ -1729,6 +1730,64 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg): session.rollback() session.close() + def test_dagrun_callbacks_commited_before_sent(self): + """ + Tests that before any callbacks are sent to the processor, the session is committed. This ensures + that the dagrun details are up to date when the callbacks are run. + """ + dag = DAG(dag_id='test_dagrun_callbacks_commited_before_sent', start_date=DEFAULT_DATE) + DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.Mock() + self.scheduler_job._send_dag_callbacks_to_processor = mock.Mock() + self.scheduler_job._schedule_dag_run = mock.Mock() + + # Sync DAG into DB + with mock.patch.object(settings, "STORE_DAG_CODE", False): + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + self.scheduler_job.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + # Create DagRun + self.scheduler_job._create_dag_runs([orm_dag], session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + ti = dr.get_task_instance('dummy') + ti.set_state(State.SUCCESS, session) + + with mock.patch.object(settings, "USE_JOB_SCHEDULE", False), mock.patch( + "airflow.jobs.scheduler_job.prohibit_commit" + ) as mock_gaurd: + mock_gaurd.return_value.__enter__.return_value.commit.side_effect = session.commit + + def mock_schedule_dag_run(*args, **kwargs): + mock_gaurd.reset_mock() + return None + + def mock_send_dag_callbacks_to_processor(*args, **kwargs): + mock_gaurd.return_value.__enter__.return_value.commit.assert_called_once() + + self.scheduler_job._send_dag_callbacks_to_processor.side_effect = 
( + mock_send_dag_callbacks_to_processor + ) + self.scheduler_job._schedule_dag_run.side_effect = mock_schedule_dag_run + + self.scheduler_job._do_scheduling(session) + + # Verify dag callback request is sent to file processor + self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once() + # and mock_send_dag_callbacks_to_processor has asserted the callback was sent after a commit + + session.rollback() + session.close() + @parameterized.expand([(State.SUCCESS,), (State.FAILED,)]) def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, state): """ @@ -1765,10 +1824,15 @@ def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, sta ti = dr.get_task_instance('test_task') ti.set_state(state, session) - self.scheduler_job._schedule_dag_run(dr, session) + with mock.patch.object(settings, "USE_JOB_SCHEDULE", False): + self.scheduler_job._do_scheduling(session) # Verify Callback is not set (i.e is None) when no callbacks are set on DAG - self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once_with(dr, None) + self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once() + call_args = self.scheduler_job._send_dag_callbacks_to_processor.call_args[0] + assert call_args[0].dag_id == dr.dag_id + assert call_args[0].execution_date == dr.execution_date + assert call_args[1] is None session.rollback() session.close() @@ -2411,12 +2475,10 @@ def test_verify_integrity_if_dag_not_changed(self): # Verify that DagRun.verify_integrity is not called with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity: - scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session) + self.scheduler_job._schedule_dag_run(dr, session) mock_verify_integrity.assert_not_called() session.flush() - assert scheduled_tis == 1 - tis_count = ( session.query(func.count(TaskInstance.task_id)) .filter( @@ -2474,11 +2536,9 @@ def test_verify_integrity_if_dag_changed(self): dag_version_2 
= SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) assert dag_version_2 != dag_version_1 - scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session) + self.scheduler_job._schedule_dag_run(dr, session) session.flush() - assert scheduled_tis == 2 - drs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(drs) == 1 dr = drs[0]