From 9a89ae3b0d33da721d942ba604bcd17684be473e Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Wed, 27 Nov 2024 13:59:35 +0100 Subject: [PATCH] Have fm_runner's event reporter shut down gracefully This commit fixes the issue where the logs would be spammed with errors related to the websocket client being forcefully shut down before closing the connection. It also fixes the issue where the fm_runner was not killing the running forward models when SIGTERM was signaled. --- src/_ert/forward_model_runner/cli.py | 20 ++++++++++++++----- .../forward_model_runner/reporting/event.py | 17 +++++++++------- .../forward_model_runner/test_job_dispatch.py | 9 ++++++--- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/_ert/forward_model_runner/cli.py b/src/_ert/forward_model_runner/cli.py index a41b0ca4b16..0398ca68b73 100644 --- a/src/_ert/forward_model_runner/cli.py +++ b/src/_ert/forward_model_runner/cli.py @@ -137,7 +137,7 @@ def main(args): ) job_runner = ForwardModelRunner(jobs_data) - + signal.signal(signal.SIGTERM, lambda _, __: _stop_reporters_and_sigkill(reporters)) for job_status in job_runner.run(parsed_args.job): logger.info(f"Job status: {job_status}") for reporter in reporters: @@ -147,9 +147,19 @@ def main(args): print( f"job_dispatch failed due to {oserror}. Stopping and cleaning up." 
) - pgid = os.getpgid(os.getpid()) - os.killpg(pgid, signal.SIGKILL) + _stop_reporters_and_sigkill(reporters) if isinstance(job_status, Finish) and not job_status.success(): - pgid = os.getpgid(os.getpid()) - os.killpg(pgid, signal.SIGKILL) + _stop_reporters_and_sigkill(reporters) + + +def _stop_reporters_and_sigkill(reporters): + _stop_reporters(reporters) + pgid = os.getpgid(os.getpid()) + os.killpg(pgid, signal.SIGKILL) + + +def _stop_reporters(reporters: typing.Iterable[reporting.Reporter]) -> None: + for reporter in reporters: + if isinstance(reporter, reporting.Event): + reporter.stop() diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 8bf13dee238..f4f140232e1 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -83,6 +83,15 @@ def __init__(self, evaluator_url, token=None, cert_path=None): # seconds to timeout the reporter the thread after Finish() was received self._reporter_timeout = 60 + def stop(self) -> None: + self._event_queue.put(Event._sentinel) + with self._timestamp_lock: + self._timeout_timestamp = datetime.now() + timedelta( + seconds=self._reporter_timeout + ) + if self._event_publisher_thread.is_alive(): + self._event_publisher_thread.join() + def _event_publisher(self): logger.debug("Publishing event.") with Client( @@ -178,13 +187,7 @@ def _job_handler(self, msg: Union[Start, Running, Exited]): self._dump_event(event) def _finished_handler(self, _): - self._event_queue.put(Event._sentinel) - with self._timestamp_lock: - self._timeout_timestamp = datetime.now() + timedelta( - seconds=self._reporter_timeout - ) - if self._event_publisher_thread.is_alive(): - self._event_publisher_thread.join() + self.stop() def _checksum_handler(self, msg: Checksum): fm_checksum = ForwardModelStepChecksum( diff --git a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py 
b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py index 0befe45c5a9..345f1f82675 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py +++ b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py @@ -306,9 +306,7 @@ def test_retry_of_jobs_json_file_read(unused_tcp_port, tmp_path, monkeypatch, ca } ) - with _mock_ws_thread("localhost", unused_tcp_port, []): - thread = ErtThread(target=main, args=[["script.py", str(tmp_path)]]) - thread.start() + def create_jobs_file_after_lock(): _wait_until( lambda: f"Could not find file {JOBS_FILE}, retrying" in caplog.text, 2, @@ -316,6 +314,11 @@ def test_retry_of_jobs_json_file_read(unused_tcp_port, tmp_path, monkeypatch, ca ) (tmp_path / JOBS_FILE).write_text(jobs_json) lock.release() + + with _mock_ws_thread("localhost", unused_tcp_port, []): + thread = ErtThread(target=create_jobs_file_after_lock) + thread.start() + main(args=["script.py", str(tmp_path)]) thread.join()