From a07620423604e2f7440cdc26db3bc9a81d252f4d Mon Sep 17 00:00:00 2001
From: xjules <jparu@equinor.com>
Date: Fri, 4 Oct 2024 12:58:10 +0200
Subject: [PATCH] Make sure that the last evaluator event is EETerminated

This introduces EnsembleEvaluator._complete_batch event, which is set
every time a batch is produced. When closing the ee server, we make sure
that the _complete_batch is set and thus preventing an events to be
lost.

This behaviour is then tested in
test_ensure_multi_level_events_in_order, which was updated
correspondighly.
---
 src/ert/ensemble_evaluator/evaluator.py                  | 8 ++++++--
 .../ensemble_evaluator/test_ensemble_evaluator.py        | 9 +++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py
index 4afe8e97971..0d21c54a72b 100644
--- a/src/ert/ensemble_evaluator/evaluator.py
+++ b/src/ert/ensemble_evaluator/evaluator.py
@@ -87,6 +87,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
         ] = asyncio.Queue()
         self._max_batch_size: int = 500
         self._batching_interval: int = 2
+        self._complete_batch: asyncio.Event = asyncio.Event()
 
     async def _publisher(self) -> None:
         while True:
@@ -140,6 +141,7 @@ def set_event_handler(event_types: Set[Type[Event]], func: Any) -> None:
                 and asyncio.get_running_loop().time() - start_time
                 < self._batching_interval
             ):
+                self._complete_batch.clear()
                 try:
                     event = await asyncio.wait_for(self._events.get(), timeout=0.1)
                     function = event_handler[type(event)]
@@ -147,6 +149,7 @@ def set_event_handler(event_types: Set[Type[Event]], func: Any) -> None:
                     self._events.task_done()
                 except asyncio.TimeoutError:
                     continue
+            self._complete_batch.set()
             await self._batch_processing_queue.put(batch)
 
     async def _fm_handler(
@@ -329,10 +332,11 @@ async def _server(self) -> None:
 
             logger.debug("Sending termination-message to clients...")
 
-            event = EETerminated(ensemble=self._ensemble.id_)
-            await self._events_to_send.put(event)
             await self._events.join()
+            await self._complete_batch.wait()
             await self._batch_processing_queue.join()
+            event = EETerminated(ensemble=self._ensemble.id_)
+            await self._events_to_send.put(event)
             await self._events_to_send.join()
         logger.debug("Async server exiting.")
 
diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py
index 668527b6017..66df469f439 100644
--- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py
+++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py
@@ -29,6 +29,7 @@
 from ert.ensemble_evaluator.evaluator import detect_overspent_cpu
 from ert.ensemble_evaluator.state import (
     ENSEMBLE_STATE_STARTED,
+    ENSEMBLE_STATE_STOPPED,
     ENSEMBLE_STATE_UNKNOWN,
     FORWARD_MODEL_STATE_FAILURE,
     FORWARD_MODEL_STATE_FINISHED,
@@ -97,6 +98,7 @@ async def mock_done_prematurely(message, *args, **kwargs):
 async def evaluator_to_use_fixture(make_ee_config):
     ensemble = TestEnsemble(0, 2, 2, id_="0")
     evaluator = EnsembleEvaluator(ensemble, make_ee_config())
+    evaluator._batching_interval = 0.5  # batching can be faster for tests
     run_task = asyncio.create_task(evaluator.run_and_get_successful_realizations())
     await evaluator._server_started.wait()
     yield evaluator
@@ -469,11 +471,18 @@ async def test_ensure_multi_level_events_in_order(evaluator_to_use):
         # about realizations, the state of the ensemble up until that point
         # should be not final (i.e. not cancelled, stopped, failed).
         ensemble_state = snapshot_event.snapshot.get("status")
+        final_event_was_EETerminated = False
         async for event in monitor.track():
+            if isinstance(event, EETerminated):
+                final_event_was_EETerminated = True
+                assert ensemble_state == ENSEMBLE_STATE_STOPPED
             if type(event) in [EESnapshot, EESnapshotUpdate]:
+                # if we get an snapshot event than this need to be valid
+                assert final_event_was_EETerminated == False
                 if "reals" in event.snapshot:
                     assert ensemble_state == ENSEMBLE_STATE_STARTED
                 ensemble_state = event.snapshot.get("status", ensemble_state)
+        assert final_event_was_EETerminated == True
 
 
 @given(