Skip to content

Commit

Permalink
Make sure that the last evaluator event is EETerminated
Browse files Browse the repository at this point in the history
This introduces EnsembleEvaluator._complete_batch event, which is set
every time a batch is produced. When closing the ee server, we make sure
that the _complete_batch is set and thus preventing an events to be
lost.

This behaviour is then tested in
test_ensure_multi_level_events_in_order, which was updated
correspondighly.
  • Loading branch information
xjules committed Oct 8, 2024
1 parent c258c76 commit a076204
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
] = asyncio.Queue()
self._max_batch_size: int = 500
self._batching_interval: int = 2
self._complete_batch: asyncio.Event = asyncio.Event()

async def _publisher(self) -> None:
while True:
Expand Down Expand Up @@ -140,13 +141,15 @@ def set_event_handler(event_types: Set[Type[Event]], func: Any) -> None:
and asyncio.get_running_loop().time() - start_time
< self._batching_interval
):
self._complete_batch.clear()
try:
event = await asyncio.wait_for(self._events.get(), timeout=0.1)
function = event_handler[type(event)]
batch.append((function, event))
self._events.task_done()
except asyncio.TimeoutError:
continue
self._complete_batch.set()
await self._batch_processing_queue.put(batch)

async def _fm_handler(
Expand Down Expand Up @@ -329,10 +332,11 @@ async def _server(self) -> None:

logger.debug("Sending termination-message to clients...")

event = EETerminated(ensemble=self._ensemble.id_)
await self._events_to_send.put(event)
await self._events.join()
await self._complete_batch.wait()
await self._batch_processing_queue.join()
event = EETerminated(ensemble=self._ensemble.id_)
await self._events_to_send.put(event)
await self._events_to_send.join()
logger.debug("Async server exiting.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ert.ensemble_evaluator.evaluator import detect_overspent_cpu
from ert.ensemble_evaluator.state import (
ENSEMBLE_STATE_STARTED,
ENSEMBLE_STATE_STOPPED,
ENSEMBLE_STATE_UNKNOWN,
FORWARD_MODEL_STATE_FAILURE,
FORWARD_MODEL_STATE_FINISHED,
Expand Down Expand Up @@ -97,6 +98,7 @@ async def mock_done_prematurely(message, *args, **kwargs):
async def evaluator_to_use_fixture(make_ee_config):
ensemble = TestEnsemble(0, 2, 2, id_="0")
evaluator = EnsembleEvaluator(ensemble, make_ee_config())
evaluator._batching_interval = 0.5 # batching can be faster for tests
run_task = asyncio.create_task(evaluator.run_and_get_successful_realizations())
await evaluator._server_started.wait()
yield evaluator
Expand Down Expand Up @@ -469,11 +471,18 @@ async def test_ensure_multi_level_events_in_order(evaluator_to_use):
# about realizations, the state of the ensemble up until that point
# should be not final (i.e. not cancelled, stopped, failed).
ensemble_state = snapshot_event.snapshot.get("status")
final_event_was_EETerminated = False
async for event in monitor.track():
if isinstance(event, EETerminated):
final_event_was_EETerminated = True
assert ensemble_state == ENSEMBLE_STATE_STOPPED
if type(event) in [EESnapshot, EESnapshotUpdate]:
# if we get an snapshot event than this need to be valid
assert final_event_was_EETerminated == False
if "reals" in event.snapshot:
assert ensemble_state == ENSEMBLE_STATE_STARTED
ensemble_state = event.snapshot.get("status", ensemble_state)
assert final_event_was_EETerminated == True


@given(
Expand Down

0 comments on commit a076204

Please sign in to comment.