diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py index 5080602a23af0..0b916f10fd353 100644 --- a/airflow/executors/dask_executor.py +++ b/airflow/executors/dask_executor.py @@ -100,7 +100,7 @@ def _process_future(self, future: Future) -> None: self.futures.pop(future) def sync(self) -> None: - if not self.futures: + if self.futures is None: raise AirflowException(NOT_STARTED_MESSAGE) # make a copy so futures can be popped during iteration for future in self.futures.copy(): @@ -109,14 +109,14 @@ def sync(self) -> None: def end(self) -> None: if not self.client: raise AirflowException(NOT_STARTED_MESSAGE) - if not self.futures: + if self.futures is None: raise AirflowException(NOT_STARTED_MESSAGE) self.client.cancel(list(self.futures.keys())) for future in as_completed(self.futures.copy()): self._process_future(future) def terminate(self): - if not self.futures: + if self.futures is None: raise AirflowException(NOT_STARTED_MESSAGE) self.client.cancel(self.futures.keys()) self.end() diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 176a3ab87dadd..35dd8970e5bdf 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -97,11 +97,22 @@ def signal_handler(signum, frame): try: self.task_runner.start() - heartbeat_time_limit = conf.getint('scheduler', - 'scheduler_zombie_task_threshold') + heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold') + while True: - # Monitor the task to see if it's done - return_code = self.task_runner.return_code() + # Monitor the task to see if it's done. 
Wait in a syscall + # (`os.wait`) for as long as possible so we notice the + # subprocess finishing as quickly as we can + max_wait_time = max( + 0, # Make sure this value is never negative, + min( + (heartbeat_time_limit - + (timezone.utcnow() - self.latest_heartbeat).total_seconds() * 0.75), + self.heartrate, + ) + ) + + return_code = self.task_runner.return_code(timeout=max_wait_time) if return_code is not None: self.log.info("Task exited with return code %s", return_code) return diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py index daeb2675cbf7a..2ecbae0a71972 100644 --- a/airflow/task/task_runner/standard_task_runner.py +++ b/airflow/task/task_runner/standard_task_runner.py @@ -108,7 +108,10 @@ def terminate(self): if self.process is None: return - if self.process.is_running(): + # Reap the child process - it may already be finished + _ = self.return_code(timeout=0) + + if self.process and self.process.is_running(): rcs = reap_process_group(self.process.pid, self.log) self._rc = rcs.get(self.process.pid) diff --git a/tests/dags/test_heartbeat_failed_fast.py b/tests/dags/test_heartbeat_failed_fast.py index 27e720e460456..5c74d49a578d6 100644 --- a/tests/dags/test_heartbeat_failed_fast.py +++ b/tests/dags/test_heartbeat_failed_fast.py @@ -30,5 +30,5 @@ dag = DAG(dag_id='test_heartbeat_failed_fast', default_args=args) task = BashOperator( task_id='test_heartbeat_failed_fast_op', - bash_command='sleep 5', + bash_command='sleep 7', dag=dag) diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 104d3e233b66b..e113c0b30d2a6 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -132,12 +132,10 @@ def test_localtaskjob_heartbeat(self, mock_pid): mock_pid.return_value = 2 self.assertRaises(AirflowException, job1.heartbeat_callback) - @patch('os.getpid') - def test_heartbeat_failed_fast(self, mock_getpid): + def 
test_heartbeat_failed_fast(self): """ Test that task heartbeat will sleep when it fails fast """ - mock_getpid.return_value = 1 self.mock_base_job_sleep.side_effect = time.sleep with create_session() as session: