From a07620423604e2f7440cdc26db3bc9a81d252f4d Mon Sep 17 00:00:00 2001 From: xjules <jparu@equinor.com> Date: Fri, 4 Oct 2024 12:58:10 +0200 Subject: [PATCH] Make sure that the last evaluator event is EETerminated This introduces EnsembleEvaluator._complete_batch event, which is set every time a batch is produced. When closing the ee server, we make sure that the _complete_batch is set and thus preventing an events to be lost. This behaviour is then tested in test_ensure_multi_level_events_in_order, which was updated correspondighly. --- src/ert/ensemble_evaluator/evaluator.py | 8 ++++++-- .../ensemble_evaluator/test_ensemble_evaluator.py | 9 +++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 4afe8e97971..0d21c54a72b 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -87,6 +87,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): ] = asyncio.Queue() self._max_batch_size: int = 500 self._batching_interval: int = 2 + self._complete_batch: asyncio.Event = asyncio.Event() async def _publisher(self) -> None: while True: @@ -140,6 +141,7 @@ def set_event_handler(event_types: Set[Type[Event]], func: Any) -> None: and asyncio.get_running_loop().time() - start_time < self._batching_interval ): + self._complete_batch.clear() try: event = await asyncio.wait_for(self._events.get(), timeout=0.1) function = event_handler[type(event)] @@ -147,6 +149,7 @@ def set_event_handler(event_types: Set[Type[Event]], func: Any) -> None: self._events.task_done() except asyncio.TimeoutError: continue + self._complete_batch.set() await self._batch_processing_queue.put(batch) async def _fm_handler( @@ -329,10 +332,11 @@ async def _server(self) -> None: logger.debug("Sending termination-message to clients...") - event = EETerminated(ensemble=self._ensemble.id_) - await self._events_to_send.put(event) await self._events.join() + await self._complete_batch.wait() await self._batch_processing_queue.join() + event = EETerminated(ensemble=self._ensemble.id_) + await self._events_to_send.put(event) await self._events_to_send.join() logger.debug("Async server exiting.") diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py index 668527b6017..66df469f439 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py @@ -29,6 +29,7 @@ from ert.ensemble_evaluator.evaluator import detect_overspent_cpu from ert.ensemble_evaluator.state import ( ENSEMBLE_STATE_STARTED, + ENSEMBLE_STATE_STOPPED, ENSEMBLE_STATE_UNKNOWN, FORWARD_MODEL_STATE_FAILURE, FORWARD_MODEL_STATE_FINISHED, @@ -97,6 +98,7 @@ async def mock_done_prematurely(message, *args, **kwargs): async def evaluator_to_use_fixture(make_ee_config): ensemble = TestEnsemble(0, 2, 2, id_="0") evaluator = EnsembleEvaluator(ensemble, make_ee_config()) + evaluator._batching_interval = 0.5 # batching can be faster for tests run_task = asyncio.create_task(evaluator.run_and_get_successful_realizations()) await evaluator._server_started.wait() yield evaluator @@ -469,11 +471,18 @@ async def test_ensure_multi_level_events_in_order(evaluator_to_use): # about realizations, the state of the ensemble up until that point # should be not final (i.e. not cancelled, stopped, failed). ensemble_state = snapshot_event.snapshot.get("status") + final_event_was_EETerminated = False async for event in monitor.track(): + if isinstance(event, EETerminated): + final_event_was_EETerminated = True + assert ensemble_state == ENSEMBLE_STATE_STOPPED if type(event) in [EESnapshot, EESnapshotUpdate]: + # if we get an snapshot event than this need to be valid + assert final_event_was_EETerminated == False if "reals" in event.snapshot: assert ensemble_state == ENSEMBLE_STATE_STARTED ensemble_state = event.snapshot.get("status", ensemble_state) + assert final_event_was_EETerminated == True @given(