From 4fad2f9d38fceca71cdcda9d085e3fcd19df058a Mon Sep 17 00:00:00 2001 From: Sondre Sortland Date: Tue, 12 Mar 2024 15:02:45 +0100 Subject: [PATCH] Remove tracker and tracker_worker --- src/ert/cli/main.py | 20 +- src/ert/cli/model_factory.py | 43 ++- src/ert/cli/monitor.py | 12 +- src/ert/ensemble_evaluator/__init__.py | 2 - src/ert/ensemble_evaluator/evaluator.py | 8 +- .../ensemble_evaluator/evaluator_tracker.py | 240 ------------ .../{tracker_worker.py => queue_emitter.py} | 23 +- src/ert/gui/simulation/run_dialog.py | 38 +- src/ert/gui/simulation/simulation_panel.py | 5 +- .../tools/run_analysis/run_analysis_tool.py | 4 +- src/ert/run_models/base_run_model.py | 215 ++++++++++- src/ert/run_models/ensemble_experiment.py | 5 +- src/ert/run_models/ensemble_smoother.py | 7 +- src/ert/run_models/evaluate_ensemble.py | 6 + .../run_models/iterated_ensemble_smoother.py | 7 +- .../run_models/multiple_data_assimilation.py | 7 +- src/ert/run_models/single_test_run.py | 9 +- .../status/test_tracking_integration.py | 69 ++-- tests/performance_tests/test_snapshot.py | 6 - tests/unit_tests/cli/test_model_factory.py | 18 +- tests/unit_tests/cli/test_model_hook_order.py | 3 + tests/unit_tests/cli/test_run_context.py | 1 + .../test_ensemble_evaluator.py | 10 +- .../test_ensemble_legacy.py | 6 +- .../test_evaluator_tracker.py | 356 ------------------ .../gui/simulation/test_run_dialog.py | 74 ++-- .../run_models/test_ensemble_experiment.py | 2 +- 27 files changed, 415 insertions(+), 781 deletions(-) delete mode 100644 src/ert/ensemble_evaluator/evaluator_tracker.py rename src/ert/gui/simulation/{tracker_worker.py => queue_emitter.py} (68%) delete mode 100644 tests/unit_tests/ensemble_evaluator/test_evaluator_tracker.py diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py index 97a98386c1d..a5a64b313a6 100644 --- a/src/ert/cli/main.py +++ b/src/ert/cli/main.py @@ -1,7 +1,10 @@ #!/usr/bin/env python +from __future__ import annotations + import contextlib import logging import os +import queue import sys from typing import Any, TextIO @@ -19,8 +22,9 @@ from ert.cli.workflow import execute_workflow from ert.config import ErtConfig, QueueSystem from ert.enkf_main import EnKFMain -from ert.ensemble_evaluator import EvaluatorServerConfig, EvaluatorTracker +from ert.ensemble_evaluator import EvaluatorServerConfig from ert.namespace import Namespace +from ert.run_models.base_run_model import StatusEvents from ert.storage import open_storage from ert.storage.local_storage import local_storage_set_ert_config @@ -76,11 +80,13 @@ def run_cli(args: Namespace, _: Any = None) -> None: execute_workflow(ert, storage, args.name) return + status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue() try: model = create_model( ert_config, storage, args, + status_queue, ) except ValueError as e: raise ErtCliError(e) from e @@ -106,11 +112,6 @@ def run_cli(args: Namespace, _: Any = None) -> None: target=model.start_simulations_thread, args=(evaluator_server_config,), ) - thread.start() - - tracker = EvaluatorTracker( - model, ee_con_info=evaluator_server_config.get_connection_info() - ) with contextlib.ExitStack() as exit_stack: out: TextIO @@ -121,13 +122,12 @@ def run_cli(args: Namespace, _: Any = None) -> None: else: out = sys.stderr monitor = Monitor(out=out, color_always=args.color_always) - + thread.start() try: - monitor.monitor(tracker.track()) + monitor.monitor(status_queue) except (SystemExit, KeyboardInterrupt, OSError): - # _base_service.py translates CTRL-c to OSError print("\nKilling simulations...") - 
tracker.request_termination() + model.cancel() thread.join() storage.close() diff --git a/src/ert/cli/model_factory.py b/src/ert/cli/model_factory.py index b6a0c51ddcb..648dcab197f 100644 --- a/src/ert/cli/model_factory.py +++ b/src/ert/cli/model_factory.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from queue import SimpleQueue from typing import TYPE_CHECKING, Tuple import numpy as np @@ -38,6 +39,7 @@ import numpy.typing as npt from ert.namespace import Namespace + from ert.run_models.base_run_model import StatusEvents from ert.storage import Storage @@ -45,6 +47,7 @@ def create_model( config: ErtConfig, storage: Storage, args: Namespace, + status_queue: SimpleQueue[StatusEvents], ) -> BaseRunModel: logger = logging.getLogger(__name__) logger.info( @@ -57,18 +60,22 @@ def create_model( update_settings = config.analysis_config.observation_settings if args.mode == TEST_RUN_MODE: - return _setup_single_test_run(config, storage, args) + return _setup_single_test_run(config, storage, args, status_queue) elif args.mode == ENSEMBLE_EXPERIMENT_MODE: - return _setup_ensemble_experiment(config, storage, args) + return _setup_ensemble_experiment(config, storage, args, status_queue) elif args.mode == EVALUATE_ENSEMBLE_MODE: - return _setup_evaluate_ensemble(config, storage, args) + return _setup_evaluate_ensemble(config, storage, args, status_queue) elif args.mode == ENSEMBLE_SMOOTHER_MODE: - return _setup_ensemble_smoother(config, storage, args, update_settings) + return _setup_ensemble_smoother( + config, storage, args, update_settings, status_queue + ) elif args.mode == ES_MDA_MODE: - return _setup_multiple_data_assimilation(config, storage, args, update_settings) + return _setup_multiple_data_assimilation( + config, storage, args, update_settings, status_queue + ) elif args.mode == ITERATIVE_ENSEMBLE_SMOOTHER_MODE: return _setup_iterative_ensemble_smoother( - config, storage, args, update_settings + config, storage, args, update_settings, status_queue ) else: @@ -76,7 +83,10 @@ def create_model( def _setup_single_test_run( - config: ErtConfig, storage: Storage, args: Namespace + config: ErtConfig, + storage: Storage, + args: Namespace, + status_queue: SimpleQueue[StatusEvents], ) -> SingleTestRun: return SingleTestRun( SingleTestRunArguments( @@ -89,11 +99,15 @@ def _setup_single_test_run( ), config, storage, + status_queue, ) def _setup_ensemble_experiment( - config: ErtConfig, storage: Storage, args: Namespace + config: ErtConfig, + storage: Storage, + args: Namespace, + status_queue: SimpleQueue[StatusEvents], ) -> EnsembleExperiment: min_realizations_count = config.analysis_config.minimum_required_realizations active_realizations = _realizations(args, config.model_config.num_realizations) @@ -124,11 +138,15 @@ def _setup_ensemble_experiment( config, storage, config.queue_config, + status_queue=status_queue, ) def _setup_evaluate_ensemble( - config: ErtConfig, storage: Storage, args: Namespace + config: ErtConfig, + storage: Storage, + args: Namespace, + status_queue: SimpleQueue[StatusEvents], ) -> EvaluateEnsemble: min_realizations_count = config.analysis_config.minimum_required_realizations active_realizations = _realizations(args, config.model_config.num_realizations) @@ -154,6 +172,7 @@ def _setup_evaluate_ensemble( config, storage, config.queue_config, + status_queue=status_queue, ) @@ -162,6 +181,7 @@ def _setup_ensemble_smoother( storage: Storage, args: Namespace, update_settings: UpdateSettings, + status_queue: SimpleQueue[StatusEvents], ) -> 
EnsembleSmoother: return EnsembleSmoother( ESRunArguments( @@ -181,6 +201,7 @@ def _setup_ensemble_smoother( config.queue_config, es_settings=config.analysis_config.es_module, update_settings=update_settings, + status_queue=status_queue, ) @@ -208,6 +229,7 @@ def _setup_multiple_data_assimilation( storage: Storage, args: Namespace, update_settings: UpdateSettings, + status_queue: SimpleQueue[StatusEvents], ) -> MultipleDataAssimilation: restart_run, prior_ensemble = _determine_restart_info(args) @@ -231,6 +253,7 @@ def _setup_multiple_data_assimilation( config.queue_config, es_settings=config.analysis_config.es_module, update_settings=update_settings, + status_queue=status_queue, ) @@ -239,6 +262,7 @@ def _setup_iterative_ensemble_smoother( storage: Storage, args: Namespace, update_settings: UpdateSettings, + status_queue: SimpleQueue[StatusEvents], ) -> IteratedEnsembleSmoother: return IteratedEnsembleSmoother( SIESRunArguments( @@ -260,6 +284,7 @@ def _setup_iterative_ensemble_smoother( config.queue_config, config.analysis_config.ies_module, update_settings=update_settings, + status_queue=status_queue, ) diff --git a/src/ert/cli/monitor.py b/src/ert/cli/monitor.py index f051870079e..e256fde0e7e 100644 --- a/src/ert/cli/monitor.py +++ b/src/ert/cli/monitor.py @@ -1,7 +1,10 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import sys from datetime import datetime, timedelta -from typing import Dict, Iterator, Optional, TextIO, Tuple, Union +from queue import SimpleQueue +from typing import Dict, Optional, TextIO, Tuple from tqdm import tqdm @@ -18,6 +21,7 @@ FORWARD_MODEL_STATE_FAILURE, REAL_STATE_TO_COLOR, ) +from ert.run_models.base_run_model import StatusEvents from ert.shared.status.utils import format_running_time Color = Tuple[int, int, int] @@ -57,13 +61,15 @@ def __init__(self, out: TextIO = sys.stdout, color_always: bool = False) -> None # The dot adds no value without color, so remove it. 
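            # Note: monitor() below now blocks on SimpleQueue.get() until the
            # run model enqueues an EndEvent, replacing the old tracker
            # generator; the added self.done flag presumably marks its arrival.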
self.dot = "" + self.done = False def monitor( self, - events: Iterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]], + event_queue: SimpleQueue[StatusEvents], ) -> None: self._start_time = datetime.now() - for event in events: + while True: + event = event_queue.get() if isinstance(event, FullSnapshotEvent): if event.snapshot is not None: self._snapshots[event.iteration] = event.snapshot diff --git a/src/ert/ensemble_evaluator/__init__.py b/src/ert/ensemble_evaluator/__init__.py index ea093217a94..46b711f38c1 100644 --- a/src/ert/ensemble_evaluator/__init__.py +++ b/src/ert/ensemble_evaluator/__init__.py @@ -7,7 +7,6 @@ ) from .config import EvaluatorServerConfig from .evaluator import EnsembleEvaluator -from .evaluator_tracker import EvaluatorTracker from .event import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent from .monitor import Monitor from .snapshot import PartialSnapshot, Snapshot @@ -18,7 +17,6 @@ "EnsembleBuilder", "EnsembleEvaluator", "EvaluatorServerConfig", - "EvaluatorTracker", "ForwardModel", "FullSnapshotEvent", "Monitor", diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index a156ef946d0..b321221a46e 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -391,7 +391,7 @@ def _run_server(self, loop: asyncio.AbstractEventLoop) -> None: loop.run_until_complete(self.evaluator_server()) logger.debug("Server thread exiting.") - def _start_running(self) -> None: + def start_running(self) -> None: self._ws_thread.start() self._ensemble.evaluate(self._config) @@ -419,11 +419,11 @@ def _signal_cancel(self) -> None: logger.debug("Stopping current ensemble") self._loop.call_soon_threadsafe(self._stop) - def run_and_get_successful_realizations(self) -> List[int]: - self._start_running() - logger.debug("Started evaluator, joining until shutdown") + def join(self) -> None: self._ws_thread.join() logger.debug("Evaluator is done") + + def get_successful_realizations(self) -> List[int]: return self._ensemble.get_successful_realizations() @staticmethod diff --git a/src/ert/ensemble_evaluator/evaluator_tracker.py b/src/ert/ensemble_evaluator/evaluator_tracker.py deleted file mode 100644 index 56369724efb..00000000000 --- a/src/ert/ensemble_evaluator/evaluator_tracker.py +++ /dev/null @@ -1,240 +0,0 @@ -import asyncio -import contextlib -import copy -import logging -import queue -import time -from typing import TYPE_CHECKING, Dict, Iterator, Union - -from aiohttp import ClientError -from websockets.exceptions import ConnectionClosedError - -from _ert.async_utils import get_running_loop, new_event_loop -from _ert.threading import ErtThread -from ert.ensemble_evaluator.identifiers import ( - EVTYPE_EE_SNAPSHOT, - EVTYPE_EE_SNAPSHOT_UPDATE, - EVTYPE_EE_TERMINATED, - STATUS, -) -from ert.ensemble_evaluator.state import ( - ENSEMBLE_STATE_CANCELLED, - ENSEMBLE_STATE_FAILED, - ENSEMBLE_STATE_STOPPED, - REALIZATION_STATE_FAILED, - REALIZATION_STATE_FINISHED, -) - -from ._wait_for_evaluator import wait_for_evaluator -from .evaluator_connection_info import EvaluatorConnectionInfo -from .event import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent -from .monitor import Monitor -from .snapshot import PartialSnapshot, Snapshot - -if TYPE_CHECKING: - from cloudevents.http.event import CloudEvent - - from ert.run_models import BaseRunModel - - -class OutOfOrderSnapshotUpdateException(ValueError): - pass - - -class EvaluatorTracker: - DONE = "done" - - def __init__( - self, - model: "BaseRunModel", 
- ee_con_info: EvaluatorConnectionInfo, - next_ensemble_evaluator_wait_time: int = 5, - ): - self._model = model - self._ee_con_info = ee_con_info - self._next_ensemble_evaluator_wait_time = next_ensemble_evaluator_wait_time - self._work_queue: "queue.Queue[Union[str, CloudEvent]]" = queue.Queue() - self._drainer_thread = ErtThread( - target=self._drain_monitor, - name="DrainerThread", - ) - self._drainer_thread.start() - self._iter_snapshot: Dict[int, Snapshot] = {} - - def _drain_monitor(self) -> None: - asyncio.set_event_loop(new_event_loop()) - drainer_logger = logging.getLogger("ert.ensemble_evaluator.drainer") - while not self._model.isFinished(): - try: - drainer_logger.debug("connecting to new monitor...") - with Monitor(self._ee_con_info) as monitor: - drainer_logger.debug("connected") - for event in monitor.track(): - if event["type"] in ( - EVTYPE_EE_SNAPSHOT, - EVTYPE_EE_SNAPSHOT_UPDATE, - ): - self._work_queue.put(event) - if event.data.get(STATUS) in [ - ENSEMBLE_STATE_STOPPED, - ENSEMBLE_STATE_FAILED, - ]: - drainer_logger.debug( - "observed evaluation stopped event, signal done" - ) - monitor.signal_done() - if event.data.get(STATUS) == ENSEMBLE_STATE_CANCELLED: - drainer_logger.debug( - "observed evaluation cancelled event, exit drainer" - ) - # Allow track() to emit an EndEvent. - self._work_queue.put(EvaluatorTracker.DONE) - return - elif event["type"] == EVTYPE_EE_TERMINATED: - drainer_logger.debug("got terminator event") - # This sleep needs to be there. Refer to issue #1250: `Authority - # on information about evaluations/experiments` - time.sleep(self._next_ensemble_evaluator_wait_time) - except (ConnectionRefusedError, ClientError) as e: - if not self._model.isFinished(): - drainer_logger.debug(f"connection refused: {e}") - except ConnectionClosedError as e: - # The monitor connection closed unexpectedly - drainer_logger.debug(f"connection closed error: {e}") - except BaseException: - drainer_logger.exception("unexpected error: ") - # We really don't know what happened... shut down - # the thread and get out of here. The monitor has - # been stopped by the ctx-mgr - self._work_queue.put(EvaluatorTracker.DONE) - self._work_queue.join() - return - drainer_logger.debug( - "observed that model was finished, waiting tasks completion..." 
- ) - # The model has finished, we indicate this by sending a DONE - self._work_queue.put(EvaluatorTracker.DONE) - self._work_queue.join() - drainer_logger.debug("tasks complete") - - def track( - self, - ) -> Iterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]]: - while True: - event = self._work_queue.get() - if isinstance(event, str): - with contextlib.suppress(GeneratorExit): - # consumers may exit at this point, make sure the last - # task is marked as done - if event == EvaluatorTracker.DONE: - yield EndEvent( - failed=self._model.hasRunFailed(), - failed_msg=self._model.getFailMessage(), - ) - self._work_queue.task_done() - break - if event["type"] == EVTYPE_EE_SNAPSHOT: - iter_ = event.data["iter"] - snapshot = Snapshot(event.data) - self._iter_snapshot[iter_] = snapshot - yield FullSnapshotEvent( - phase_name=self._model.getPhaseName(), - current_phase=self._model.currentPhase(), - total_phases=self._model.phaseCount(), - indeterminate=self._model.isIndeterminate(), - progress=self._progress(), - iteration=iter_, - snapshot=copy.deepcopy(snapshot), - ) - elif event["type"] == EVTYPE_EE_SNAPSHOT_UPDATE: - iter_ = event.data["iter"] - if iter_ not in self._iter_snapshot: - raise OutOfOrderSnapshotUpdateException( - f"got {EVTYPE_EE_SNAPSHOT_UPDATE} without having stored " - f"snapshot for iter {iter_}" - ) - partial = PartialSnapshot(self._iter_snapshot[iter_]).from_cloudevent( - event - ) - self._iter_snapshot[iter_].merge_event(partial) - yield SnapshotUpdateEvent( - phase_name=self._model.getPhaseName(), - current_phase=self._model.currentPhase(), - total_phases=self._model.phaseCount(), - indeterminate=self._model.isIndeterminate(), - progress=self._progress(), - iteration=iter_, - partial_snapshot=partial, - ) - self._work_queue.task_done() - - def is_finished(self) -> bool: - return not self._drainer_thread.is_alive() - - def _progress(self) -> float: - """Fraction of completed iterations over total iterations""" - - if self.is_finished(): - return 1.0 - elif not self._iter_snapshot: - return 0.0 - else: - # Calculate completed realizations - current_iter = max(list(self._iter_snapshot.keys())) - done_reals = 0 - all_reals = self._iter_snapshot[current_iter].reals - if not all_reals: - # Empty ensemble or all realizations deactivated - return 1.0 - for real in all_reals.values(): - if real.status in [ - REALIZATION_STATE_FINISHED, - REALIZATION_STATE_FAILED, - ]: - done_reals += 1 - real_progress = float(done_reals) / len(all_reals) - - return ( - (current_iter + real_progress) / self._model.phaseCount() - if self._model.phaseCount() != 1 - else real_progress - ) - - def _clear_work_queue(self) -> None: - with contextlib.suppress(queue.Empty): - while True: - self._work_queue.get_nowait() - self._work_queue.task_done() - - def request_termination(self) -> None: - logger = logging.getLogger("ert.ensemble_evaluator.tracker") - # There might be some situations where the - # evaluation is finished or the evaluation - # is yet to start when calling this function. 
- # In these cases the monitor is not started - # - # To avoid waiting too long we exit if we are not - # able to connect to the monitor after 2 tries - # - # See issue: https://github.com/equinor/ert/issues/1250 - # - try: - logger.debug("requesting termination...") - get_running_loop().run_until_complete( - wait_for_evaluator( - base_url=self._ee_con_info.url, - token=self._ee_con_info.token, - cert=self._ee_con_info.cert, - timeout=5, - ) - ) - logger.debug("requested termination") - except ClientError as e: - logger.warning(f"{__name__} - exception {e}") - return - - with Monitor(self._ee_con_info) as monitor: - monitor.signal_cancel() - while self._drainer_thread.is_alive(): - self._clear_work_queue() - time.sleep(1) diff --git a/src/ert/gui/simulation/tracker_worker.py b/src/ert/gui/simulation/queue_emitter.py similarity index 68% rename from src/ert/gui/simulation/tracker_worker.py rename to src/ert/gui/simulation/queue_emitter.py index 00a3f91dac0..6d1a996dce8 100644 --- a/src/ert/gui/simulation/tracker_worker.py +++ b/src/ert/gui/simulation/queue_emitter.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Iterator, Union +from queue import SimpleQueue from qtpy.QtCore import QObject, Signal, Slot @@ -9,40 +9,39 @@ logger = logging.getLogger(__name__) -class TrackerWorker(QObject): - """A worker that consumes events produced by a tracker and emits them to qt - subscribers.""" +class QueueEmitter(QObject): + """A worker that emits items put on a queue to qt subscribers.""" - consumed_event = Signal(object) + new_event = Signal(object) done = Signal() def __init__( self, - event_generator_factory: Callable[ - [], Iterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]] - ], + event_queue: SimpleQueue, parent=None, ): super().__init__(parent) - logger.debug("init trackerworker") - self._tracker = event_generator_factory + logger.debug("init QueueEmitter") + self._event_queue = event_queue self._stopped = False @Slot() def consume_and_emit(self): logger.debug("tracking...") - for event in self._tracker(): + while True: + event = self._event_queue.get() if self._stopped: logger.debug("stopped") break + # pre-rendering in this thread to avoid work in main rendering thread if isinstance(event, FullSnapshotEvent) and event.snapshot: SnapshotModel.prerender(event.snapshot) elif isinstance(event, SnapshotUpdateEvent) and event.partial_snapshot: SnapshotModel.prerender(event.partial_snapshot) logger.debug(f"emit {event}") - self.consumed_event.emit(event) + self.new_event.emit(event) if isinstance(event, EndEvent): logger.debug("got end event") diff --git a/src/ert/gui/simulation/run_dialog.py b/src/ert/gui/simulation/run_dialog.py index 27d26f15255..59e4b76434f 100644 --- a/src/ert/gui/simulation/run_dialog.py +++ b/src/ert/gui/simulation/run_dialog.py @@ -1,4 +1,5 @@ import logging +from queue import SimpleQueue from typing import Optional from PyQt5.QtWidgets import QAbstractItemView @@ -23,7 +24,6 @@ from ert.ensemble_evaluator import ( EndEvent, EvaluatorServerConfig, - EvaluatorTracker, FullSnapshotEvent, SnapshotUpdateEvent, ) @@ -44,7 +44,7 @@ from ert.run_models.event import RunModelErrorEvent from ert.shared.status.utils import format_running_time -from .tracker_worker import TrackerWorker +from .queue_emitter import QueueEmitter from .view import LegendView, ProgressView, RealizationWidget, UpdateWidget _TOTAL_PROGRESS_TEMPLATE = "Total progress {total_progress}% — {phase_name}" @@ -58,6 +58,7 @@ def __init__( self, config_file: str, run_model: BaseRunModel, + 
event_queue: SimpleQueue, notifier: ErtNotifier, parent=None, ): @@ -69,6 +70,7 @@ def __init__( self._snapshot_model = SnapshotModel(self) self._run_model = run_model + self._event_queue = event_queue self._notifier = notifier self._isDetailedDialog = False @@ -181,7 +183,6 @@ def __init__( self._setSimpleDialog() self.finished.connect(self._on_finished) - self._run_model.add_send_event_callback(self.on_run_model_event.emit) self.on_run_model_event.connect(self._on_event) def _current_tab_changed(self, index: int) -> None: @@ -280,29 +281,21 @@ def run(): simulation_thread = ErtThread( name="ert_gui_simulation_thread", target=run, daemon=True ) - simulation_thread.start() - - self._ticker.start(1000) - self._tracker = EvaluatorTracker( - self._run_model, - ee_con_info=evaluator_server_config.get_connection_info(), - ) - - worker = TrackerWorker(self._tracker.track) + worker = QueueEmitter(self._event_queue) worker_thread = QThread() + self._worker = worker + self._worker_thread = worker_thread + worker.done.connect(worker_thread.quit) - worker.consumed_event.connect(self._on_event) + worker.new_event.connect(self._on_event) worker.moveToThread(worker_thread) self.simulation_done.connect(worker.stop) - self._worker = worker - self._worker_thread = worker_thread worker_thread.started.connect(worker.consume_and_emit) - # _worker_thread is finished once everything has stopped. We wait to - # show the done button to this point to avoid destroying the QThread - # while it is running (which would sigabrt) - self._worker_thread.finished.connect(self._show_done_button) + + self._ticker.start(1000) self._worker_thread.start() + simulation_thread.start() self._notifier.set_is_simulation_running(True) def killJobs(self): @@ -314,9 +307,7 @@ def killJobs(self): if kill_job == QMessageBox.Yes: # Normally this slot would be invoked by the signal/slot system, # but the worker is busy tracking the evaluation. 
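            # Cancellation is now cooperative: cancel() just puts "END" on the
            # run model's _end_queue, which run_ensemble_evaluator drains
            # between monitor events before calling monitor.signal_cancel().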
- self._tracker.request_termination() - self._worker_thread.quit() - self._worker_thread.wait() + self._run_model.cancel() self._on_finished() self.finished.emit(-1) return kill_job @@ -352,9 +343,8 @@ def _on_ticker(self): def _on_event(self, event: object): if isinstance(event, EndEvent): self.simulation_done.emit(event.failed, event.failed_msg) - self._worker.stop() self._ticker.stop() - + self._show_done_button() elif isinstance(event, FullSnapshotEvent): if event.snapshot is not None: self._snapshot_model._add_snapshot(event.snapshot, event.iteration) diff --git a/src/ert/gui/simulation/simulation_panel.py b/src/ert/gui/simulation/simulation_panel.py index f2520a69b19..49b4677f8a9 100644 --- a/src/ert/gui/simulation/simulation_panel.py +++ b/src/ert/gui/simulation/simulation_panel.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from queue import SimpleQueue from typing import Any, Dict from qtpy.QtCore import QSize, Qt @@ -195,11 +196,13 @@ def runSimulation(self): abort = False QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor) config = self.facade.config + event_queue = SimpleQueue() try: model = create_model( config, self._notifier.storage, args, + event_queue, ) except ValueError as e: @@ -275,7 +278,7 @@ def runSimulation(self): if not abort: dialog = RunDialog( - self._config_file, model, self._notifier, self.parent() + self._config_file, model, event_queue, self._notifier, self.parent() ) self.run_button.setEnabled(False) self.run_button.setText(EXPERIMENT_IS_RUNNING_BUTTON_MESSAGE) diff --git a/src/ert/gui/tools/run_analysis/run_analysis_tool.py b/src/ert/gui/tools/run_analysis/run_analysis_tool.py index 8353950a7a9..88bb4c2b8f7 100644 --- a/src/ert/gui/tools/run_analysis/run_analysis_tool.py +++ b/src/ert/gui/tools/run_analysis/run_analysis_tool.py @@ -58,7 +58,7 @@ def run(self): update_settings, config.analysis_config.es_module, rng, - self.smoother_event_callback, + self.send_smoother_event, log_path=config.analysis_config.log_path, ) except ErtAnalysisError as e: @@ -68,7 +68,7 @@ def run(self): self.finished.emit(error, self._source_ensemble.name) - def smoother_event_callback(self, event: AnalysisEvent) -> None: + def send_smoother_event(self, event: AnalysisEvent) -> None: if isinstance(event, AnalysisStatusEvent): self.progress_update.emit(RunModelStatusEvent(iteration=0, msg=event.msg)) elif isinstance(event, AnalysisTimeEvent): diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 1d46353b4d1..5ee53c5f36c 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import logging import os import shutil @@ -7,9 +8,9 @@ import uuid from contextlib import contextmanager from pathlib import Path +from queue import SimpleQueue from typing import ( TYPE_CHECKING, - Callable, Dict, Generator, List, @@ -20,6 +21,9 @@ ) import numpy as np +from aiohttp import ClientError +from cloudevents.http import CloudEvent +from websockets.exceptions import ConnectionClosedError from ert.analysis import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent from ert.analysis.event import AnalysisErrorEvent @@ -31,8 +35,28 @@ EnsembleBuilder, EnsembleEvaluator, EvaluatorServerConfig, + Monitor, RealizationBuilder, ) +from ert.ensemble_evaluator.event import ( + EndEvent, + FullSnapshotEvent, + SnapshotUpdateEvent, +) +from ert.ensemble_evaluator.identifiers import ( + EVTYPE_EE_SNAPSHOT, + EVTYPE_EE_SNAPSHOT_UPDATE, + 
EVTYPE_EE_TERMINATED, + STATUS, +) +from ert.ensemble_evaluator.snapshot import PartialSnapshot, Snapshot +from ert.ensemble_evaluator.state import ( + ENSEMBLE_STATE_CANCELLED, + ENSEMBLE_STATE_FAILED, + ENSEMBLE_STATE_STOPPED, + REALIZATION_STATE_FAILED, + REALIZATION_STATE_FINISHED, +) from ert.libres_facade import LibresFacade from ert.run_context import RunContext from ert.runpaths import Runpaths @@ -42,6 +66,8 @@ RunModelErrorEvent, RunModelStatusEvent, RunModelTimeEvent, + RunModelUpdateBeginEvent, + RunModelUpdateEndEvent, ) event_logger = logging.getLogger("ert.event_log") @@ -50,6 +76,24 @@ from ert.config import QueueConfig from ert.run_models.run_arguments import RunArgumentsType +StatusEvents = Union[ + FullSnapshotEvent, + SnapshotUpdateEvent, + EndEvent, + AnalysisEvent, + AnalysisStatusEvent, + AnalysisTimeEvent, + RunModelErrorEvent, + RunModelStatusEvent, + RunModelTimeEvent, + RunModelUpdateBeginEvent, + RunModelUpdateEndEvent, +] + + +class OutOfOrderSnapshotUpdateException(ValueError): + pass + class ErtRunError(Exception): pass @@ -89,6 +133,7 @@ def __init__( config: ErtConfig, storage: Storage, queue_config: QueueConfig, + status_queue: SimpleQueue[StatusEvents], phase_count: int = 1, ): """ @@ -140,16 +185,15 @@ def __init__( current_ensemble = self.simulation_arguments.current_ensemble if current_ensemble is not None: self.run_paths.set_ert_ensemble(current_ensemble) - self._send_event_callback: Optional[Callable[[object], None]] = None - def add_send_event_callback(self, func: Callable[[object], None]) -> None: - self._send_event_callback = func + self._iter_snapshot: Dict[int, Snapshot] = {} + self._status_queue = status_queue + self._end_queue: SimpleQueue[str] = SimpleQueue() - def send_event(self, event: object) -> None: - if self._send_event_callback: - self._send_event_callback(event) + def send_event(self, event: StatusEvents) -> None: + self._status_queue.put(event) - def smoother_event_callback(self, iteration: int, event: AnalysisEvent) -> None: + def send_smoother_event(self, iteration: int, event: AnalysisEvent) -> None: if isinstance(event, AnalysisStatusEvent): self.send_event(RunModelStatusEvent(iteration=iteration, msg=event.msg)) elif isinstance(event, AnalysisTimeEvent): @@ -181,6 +225,9 @@ def simulation_arguments(self) -> RunArgumentsType: def _ensemble_size(self) -> int: return len(self._initial_realizations_mask) + def cancel(self) -> None: + self._end_queue.put("END") + def reset(self) -> None: self._failed = False self._error_messages = [] @@ -324,6 +371,7 @@ def reraise_exception(self, exctype: Type[Exception]) -> None: def _simulationEnded(self) -> None: self._clean_env_context() self._job_stop_time = int(time.time()) + self.send_end_event() def setPhase( self, phase: int, phase_name: str, indeterminate: Optional[bool] = None @@ -373,18 +421,159 @@ def checkHaveSufficientRealizations( f"MIN_REALIZATIONS to allow (more) failures in your experiments." 
) + def _progress(self) -> float: + """Fraction of completed iterations over total iterations""" + + if self.isFinished(): + return 1.0 + elif not self._iter_snapshot: + return 0.0 + else: + # Calculate completed realizations + current_iter = max(list(self._iter_snapshot.keys())) + done_reals = 0 + all_reals = self._iter_snapshot[current_iter].reals + if not all_reals: + # Empty ensemble or all realizations deactivated + return 1.0 + for real in all_reals.values(): + if real.status in [ + REALIZATION_STATE_FINISHED, + REALIZATION_STATE_FAILED, + ]: + done_reals += 1 + real_progress = float(done_reals) / len(all_reals) + + return ( + (current_iter + real_progress) / self.phaseCount() + if self.phaseCount() != 1 + else real_progress + ) + + def send_end_event(self) -> None: + self.send_event( + EndEvent( + failed=self.hasRunFailed(), + failed_msg=self.getFailMessage(), + ) + ) + + def send_snapshot_event(self, event: CloudEvent) -> None: + if event["type"] == EVTYPE_EE_SNAPSHOT: + iter_ = event.data["iter"] + snapshot = Snapshot(event.data) + self._iter_snapshot[iter_] = snapshot + self.send_event( + FullSnapshotEvent( + phase_name=self.getPhaseName(), + current_phase=self.currentPhase(), + total_phases=self.phaseCount(), + indeterminate=self.isIndeterminate(), + progress=self._progress(), + iteration=iter_, + snapshot=copy.deepcopy(snapshot), + ) + ) + elif event["type"] == EVTYPE_EE_SNAPSHOT_UPDATE: + iter_ = event.data["iter"] + if iter_ not in self._iter_snapshot: + raise OutOfOrderSnapshotUpdateException( + f"got {EVTYPE_EE_SNAPSHOT_UPDATE} without having stored " + f"snapshot for iter {iter_}" + ) + partial = PartialSnapshot(self._iter_snapshot[iter_]).from_cloudevent(event) + self._iter_snapshot[iter_].merge_event(partial) + self.send_event( + SnapshotUpdateEvent( + phase_name=self.getPhaseName(), + current_phase=self.currentPhase(), + total_phases=self.phaseCount(), + indeterminate=self.isIndeterminate(), + progress=self._progress(), + iteration=iter_, + partial_snapshot=partial, + ) + ) + def run_ensemble_evaluator( self, run_context: RunContext, ee_config: EvaluatorServerConfig ) -> List[int]: + if not self._end_queue.empty(): + event_logger.debug("Run model canceled - pre evaluation") + self._end_queue.get() + return [] ensemble = self._build_ensemble(run_context) - - successful_realizations = EnsembleEvaluator( + evaluator = EnsembleEvaluator( ensemble, ee_config, run_context.iteration, - ).run_and_get_successful_realizations() - - return successful_realizations + ) + evaluator.start_running() + should_exit = False + while not should_exit: + try: + event_logger.debug("connecting to new monitor...") + with Monitor(ee_config.get_connection_info()) as monitor: + event_logger.debug("connected") + for event in monitor.track(): + if event["type"] in ( + EVTYPE_EE_SNAPSHOT, + EVTYPE_EE_SNAPSHOT_UPDATE, + ): + self.send_snapshot_event(event) + if event.data.get(STATUS) in [ + ENSEMBLE_STATE_STOPPED, + ENSEMBLE_STATE_FAILED, + ]: + event_logger.debug( + "observed evaluation stopped event, signal done" + ) + monitor.signal_done() + should_exit = True + if event.data.get(STATUS) == ENSEMBLE_STATE_CANCELLED: + event_logger.debug( + "observed evaluation cancelled event, exit drainer" + ) + # Allow track() to emit an EndEvent. 
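+                                # No successful realizations to report on
+                                # cancellation; the final EndEvent is sent by
+                                # send_end_event() once _simulationEnded() runs.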
+ return [] + elif event["type"] == EVTYPE_EE_TERMINATED: + event_logger.debug("got terminator event") + + if not self._end_queue.empty(): + event_logger.debug("Run model canceled - during evaluation") + self._end_queue.get() + monitor.signal_cancel() + event_logger.debug( + "Run model canceled - during evaluation - cancel sent" + ) + # This sleep needs to be there. Refer to issue #1250: `Authority + # on information about evaluations/experiments` + # time.sleep(self._next_ensemble_evaluator_wait_time) + except (ConnectionRefusedError, ClientError) as e: + if not self.isFinished(): + event_logger.debug(f"connection refused: {e}") + except ConnectionClosedError as e: + # The monitor connection closed unexpectedly + event_logger.debug(f"connection closed error: {e}") + except BaseException: + event_logger.exception("unexpected error: ") + # We really don't know what happened... shut down + # the thread and get out of here. The monitor has + # been stopped by the ctx-mgr + return [] + + event_logger.debug( + "observed that model was finished, waiting tasks completion..." + ) + # The model has finished, we indicate this by sending a DONE + event_logger.debug("tasks complete") + + evaluator.join() + if not self._end_queue.empty(): + event_logger.debug("Run model canceled - post evaluation") + self._end_queue.get() + return [] + return evaluator.get_successful_realizations() def _build_ensemble( self, diff --git a/src/ert/run_models/ensemble_experiment.py b/src/ert/run_models/ensemble_experiment.py index cad5063f44a..1bcc4f2ed0c 100644 --- a/src/ert/run_models/ensemble_experiment.py +++ b/src/ert/run_models/ensemble_experiment.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from queue import SimpleQueue from typing import TYPE_CHECKING, Union import numpy as np @@ -10,7 +11,7 @@ from ert.run_context import RunContext from ert.storage import Storage -from .base_run_model import BaseRunModel +from .base_run_model import BaseRunModel, StatusEvents if TYPE_CHECKING: from ert.config import ErtConfig, QueueConfig @@ -38,12 +39,14 @@ def __init__( config: ErtConfig, storage: Storage, queue_config: QueueConfig, + status_queue: SimpleQueue[StatusEvents], ): super().__init__( simulation_arguments, config, storage, queue_config, + status_queue, ) def run_experiment( diff --git a/src/ert/run_models/ensemble_smoother.py b/src/ert/run_models/ensemble_smoother.py index 076a6949ee6..9f943d8a1a2 100644 --- a/src/ert/run_models/ensemble_smoother.py +++ b/src/ert/run_models/ensemble_smoother.py @@ -2,6 +2,7 @@ import functools import logging +from queue import SimpleQueue from typing import TYPE_CHECKING import numpy as np @@ -16,7 +17,7 @@ from ..config.analysis_config import UpdateSettings from ..config.analysis_module import ESSettings -from .base_run_model import BaseRunModel, ErtRunError +from .base_run_model import BaseRunModel, ErtRunError, StatusEvents from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent if TYPE_CHECKING: @@ -35,12 +36,14 @@ def __init__( queue_config: QueueConfig, es_settings: ESSettings, update_settings: UpdateSettings, + status_queue: SimpleQueue[StatusEvents], ): super().__init__( simulation_arguments, config, storage, queue_config, + status_queue, phase_count=2, ) self.es_settings = es_settings @@ -137,7 +140,7 @@ def run_experiment( parameters=prior_context.ensemble.experiment.update_parameters, observations=prior_context.ensemble.experiment.observations.keys(), rng=self.rng, - 
progress_callback=functools.partial(self.smoother_event_callback, 0), + progress_callback=functools.partial(self.send_smoother_event, 0), log_path=self.ert_config.analysis_config.log_path, ) diff --git a/src/ert/run_models/evaluate_ensemble.py b/src/ert/run_models/evaluate_ensemble.py index acf6a6d3a7e..547d7a9538f 100644 --- a/src/ert/run_models/evaluate_ensemble.py +++ b/src/ert/run_models/evaluate_ensemble.py @@ -12,8 +12,12 @@ from . import BaseRunModel if TYPE_CHECKING: + from queue import SimpleQueue + from ert.config import ErtConfig, QueueConfig + from .base_run_model import StatusEvents + # pylint: disable=too-many-arguments class EvaluateEnsemble(BaseRunModel): @@ -32,12 +36,14 @@ def __init__( config: ErtConfig, storage: Storage, queue_config: QueueConfig, + status_queue: SimpleQueue[StatusEvents], ): super().__init__( simulation_arguments, config, storage, queue_config, + status_queue, ) def run_experiment( diff --git a/src/ert/run_models/iterated_ensemble_smoother.py b/src/ert/run_models/iterated_ensemble_smoother.py index f9a962d26fe..b4ec5bf9039 100644 --- a/src/ert/run_models/iterated_ensemble_smoother.py +++ b/src/ert/run_models/iterated_ensemble_smoother.py @@ -2,6 +2,7 @@ import functools import logging +from queue import SimpleQueue from typing import TYPE_CHECKING import numpy as np @@ -17,7 +18,7 @@ from ..config.analysis_config import UpdateSettings from ..config.analysis_module import IESSettings -from .base_run_model import BaseRunModel, ErtRunError +from .base_run_model import BaseRunModel, ErtRunError, StatusEvents from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent if TYPE_CHECKING: @@ -40,12 +41,14 @@ def __init__( queue_config: QueueConfig, analysis_config: IESSettings, update_settings: UpdateSettings, + status_queue: SimpleQueue[StatusEvents], ): super().__init__( simulation_arguments, config, storage, queue_config, + status_queue, phase_count=2, ) self.support_restart = False @@ -97,7 +100,7 @@ def analyzeStep( initial_mask=initial_mask, rng=self.rng, progress_callback=functools.partial( - self.smoother_event_callback, iteration + self.send_smoother_event, iteration ), log_path=self.ert_config.analysis_config.log_path, ) diff --git a/src/ert/run_models/multiple_data_assimilation.py b/src/ert/run_models/multiple_data_assimilation.py index 28e8dec1100..580385665e2 100644 --- a/src/ert/run_models/multiple_data_assimilation.py +++ b/src/ert/run_models/multiple_data_assimilation.py @@ -2,6 +2,7 @@ import functools import logging +from queue import SimpleQueue from typing import TYPE_CHECKING, List import numpy as np @@ -16,7 +17,7 @@ from ..config.analysis_config import UpdateSettings from ..config.analysis_module import ESSettings -from .base_run_model import BaseRunModel, ErtRunError +from .base_run_model import BaseRunModel, ErtRunError, StatusEvents from .event import RunModelStatusEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent if TYPE_CHECKING: @@ -41,12 +42,14 @@ def __init__( queue_config: QueueConfig, es_settings: ESSettings, update_settings: UpdateSettings, + status_queue: SimpleQueue[StatusEvents], ): super().__init__( simulation_arguments, config, storage, queue_config, + status_queue, phase_count=2, ) self.weights = MultipleDataAssimilation.default_weights @@ -212,7 +215,7 @@ def update( global_scaling=weight, rng=self.rng, progress_callback=functools.partial( - self.smoother_event_callback, prior_context.iteration + self.send_smoother_event, prior_context.iteration ), 
log_path=self.ert_config.analysis_config.log_path, ) diff --git a/src/ert/run_models/single_test_run.py b/src/ert/run_models/single_test_run.py index 8b8c8ef677c..3991db8eca8 100644 --- a/src/ert/run_models/single_test_run.py +++ b/src/ert/run_models/single_test_run.py @@ -8,9 +8,13 @@ from ert.run_models import EnsembleExperiment, ErtRunError if TYPE_CHECKING: + from queue import SimpleQueue + from ert.run_models.run_arguments import SingleTestRunArguments from ert.storage import Storage + from .base_run_model import StatusEvents + class SingleTestRun(EnsembleExperiment): def __init__( @@ -18,9 +22,12 @@ def __init__( simulation_arguments: SingleTestRunArguments, config: ErtConfig, storage: Storage, + status_queue: SimpleQueue[StatusEvents], ): local_queue_config = config.queue_config.create_local_copy() - super().__init__(simulation_arguments, config, storage, local_queue_config) + super().__init__( + simulation_arguments, config, storage, local_queue_config, status_queue + ) @staticmethod def checkHaveSufficientRealizations( diff --git a/tests/integration_tests/status/test_tracking_integration.py b/tests/integration_tests/status/test_tracking_integration.py index 99b8b5d88d2..26dc574bbec 100644 --- a/tests/integration_tests/status/test_tracking_integration.py +++ b/tests/integration_tests/status/test_tracking_integration.py @@ -16,7 +16,6 @@ from ert.cli import ENSEMBLE_EXPERIMENT_MODE, ENSEMBLE_SMOOTHER_MODE, TEST_RUN_MODE from ert.cli.model_factory import create_model from ert.config import ErtConfig -from ert.ensemble_evaluator import EvaluatorTracker from ert.ensemble_evaluator.config import EvaluatorServerConfig from ert.ensemble_evaluator.event import ( EndEvent, @@ -31,6 +30,19 @@ ) +class Events: + def __init__(self): + self.events = [] + self.environment = [] + + def __iter__(self): + yield from self.events + + def put(self, event): + self.events.append(event) + self.environment.append(os.environ.copy()) + + def check_expression(original, path_expression, expected, msg_start): assert isinstance(original, dict), f"{msg_start}data is not a dict" jsonpath_expr = parse(path_expression) @@ -171,10 +183,12 @@ def test_tracking( ert_config = ErtConfig.from_file(parsed.config) os.chdir(ert_config.config_path) + queue = Events() model = create_model( ert_config, storage, parsed, + queue, ) evaluator_server_config = EvaluatorServerConfig( @@ -191,14 +205,11 @@ def test_tracking( ) thread.start() - tracker = EvaluatorTracker( - model, - ee_con_info=evaluator_server_config.get_connection_info(), - ) - snapshots = {} - for event in tracker.track(): + thread.join() + + for event in queue: if isinstance(event, FullSnapshotEvent): snapshots[event.iteration] = event.snapshot if ( @@ -209,8 +220,6 @@ def test_tracking( if isinstance(event, EndEvent): pass - assert tracker._progress() == progress - assert len(snapshots) == num_iters for snapshot in snapshots.values(): successful_reals = list( @@ -234,7 +243,6 @@ def test_tracking( expected, f"Snapshot {i} did not match:\n", ) - thread.join() @pytest.mark.integration_test @@ -275,18 +283,19 @@ def test_setting_env_context_during_run( ert_config = ErtConfig.from_file(parsed.config) os.chdir(ert_config.config_path) - model = create_model( - ert_config, - storage, - parsed, - ) - evaluator_server_config = EvaluatorServerConfig( custom_port_range=range(1024, 65535), custom_host="127.0.0.1", use_token=False, generate_cert=False, ) + queue = Events() + model = create_model( + ert_config, + storage, + parsed, + queue, + ) thread = ErtThread( 
name="ert_cli_simulation_thread", @@ -294,22 +303,16 @@ def test_setting_env_context_during_run( args=(evaluator_server_config,), ) thread.start() - - tracker = EvaluatorTracker( - model, - ee_con_info=evaluator_server_config.get_connection_info(), - ) + thread.join() expected = ["_ERT_SIMULATION_MODE", "_ERT_EXPERIMENT_ID", "_ERT_ENSEMBLE_ID"] - for event in tracker.track(): + for event, environment in zip(queue.events, queue.environment): if isinstance(event, (FullSnapshotEvent, SnapshotUpdateEvent)): - assert model._context_env_keys == expected for key in expected: - assert key in os.environ - assert os.environ.get("_ERT_SIMULATION_MODE") == mode + assert key in environment + assert environment.get("_ERT_SIMULATION_MODE") == mode if isinstance(event, EndEvent): pass - thread.join() # Check environment is clean after the model run ends. assert not model._context_env_keys @@ -358,11 +361,12 @@ def test_tracking_missing_ecl(tmpdir, caplog, storage): ert_config = ErtConfig.from_file(parsed.config) os.chdir(ert_config.config_path) - + events = Events() model = create_model( ert_config, storage, parsed, + events, ) evaluator_server_config = EvaluatorServerConfig( @@ -379,15 +383,10 @@ def test_tracking_missing_ecl(tmpdir, caplog, storage): ) with caplog.at_level(logging.ERROR): thread.start() - - tracker = EvaluatorTracker( - model, - ee_con_info=evaluator_server_config.get_connection_info(), - ) - + thread.join() failures = [] - for event in tracker.track(): + for event in events: if isinstance(event, EndEvent): failures.append(event) assert ( @@ -409,5 +408,3 @@ def test_tracking_missing_ecl(tmpdir, caplog, storage): f"{Path().absolute()}/simulations/realization-0/" "iter-0/ECLIPSE_CASE" ) in failures[0].failed_msg - - thread.join() diff --git a/tests/performance_tests/test_snapshot.py b/tests/performance_tests/test_snapshot.py index 28a54f21a62..91f079ed995 100644 --- a/tests/performance_tests/test_snapshot.py +++ b/tests/performance_tests/test_snapshot.py @@ -17,8 +17,6 @@ from ..unit_tests.gui.conftest import ( # noqa: F401 active_realizations_fixture, large_snapshot, - mock_tracker, - runmodel, ) from ..unit_tests.gui.simulation.test_run_dialog import test_large_snapshot @@ -45,18 +43,14 @@ def test_snapshot_handling_of_forward_model_events( def test_gui_snapshot( benchmark, - runmodel, # noqa: F811 large_snapshot, # noqa: F811 qtbot, - mock_tracker, # noqa: F811 ): infinite_timeout = 100000 benchmark( test_large_snapshot, - runmodel, large_snapshot, qtbot, - mock_tracker, timeout_per_iter=infinite_timeout, ) diff --git a/tests/unit_tests/cli/test_model_factory.py b/tests/unit_tests/cli/test_model_factory.py index 322ddc75c24..cef4ac82754 100644 --- a/tests/unit_tests/cli/test_model_factory.py +++ b/tests/unit_tests/cli/test_model_factory.py @@ -42,13 +42,8 @@ def test_custom_realizations(poly_case): args = Namespace(realizations="0-4,7,8") ensemble_size = facade.get_ensemble_size() active_mask = [False] * ensemble_size - active_mask[0] = True - active_mask[1] = True - active_mask[2] = True - active_mask[3] = True - active_mask[4] = True - active_mask[7] = True - active_mask[8] = True + active_mask[0:5] = [True] * 5 + active_mask[7:9] = [True] * 2 assert model_factory._realizations(args, ensemble_size).tolist() == active_mask @@ -62,6 +57,7 @@ def test_setup_single_test_run(poly_case, storage): random_seed=None, experiment_name=None, ), + MagicMock(), ) assert isinstance(model, SingleTestRun) assert model.simulation_arguments.current_ensemble == "current-ensemble" @@ -80,6 +76,7 @@ def 
test_setup_single_test_run_with_ensemble(poly_case, storage): random_seed=None, experiment_name=None, ), + MagicMock(), ) assert isinstance(model, SingleTestRun) assert model.simulation_arguments.current_ensemble == "current-ensemble" @@ -100,6 +97,7 @@ def test_setup_ensemble_experiment(poly_case, storage): poly_case, storage, args, + MagicMock(), ) assert isinstance(model, EnsembleExperiment) @@ -116,7 +114,7 @@ def test_setup_ensemble_smoother(poly_case, storage): ) model = model_factory._setup_ensemble_smoother( - poly_case, storage, args, MagicMock() + poly_case, storage, args, MagicMock(), MagicMock() ) assert isinstance(model, EnsembleSmoother) assert model.simulation_arguments.current_ensemble == "default" @@ -138,7 +136,7 @@ def test_setup_multiple_data_assimilation(poly_case, storage): ) model = model_factory._setup_multiple_data_assimilation( - poly_case, storage, args, MagicMock() + poly_case, storage, args, MagicMock(), MagicMock() ) assert isinstance(model, MultipleDataAssimilation) assert model.simulation_arguments.weights == "6,4,2" @@ -161,7 +159,7 @@ def test_setup_iterative_ensemble_smoother(poly_case, storage): ) model = model_factory._setup_iterative_ensemble_smoother( - poly_case, storage, args, MagicMock() + poly_case, storage, args, MagicMock(), MagicMock() ) assert isinstance(model, IteratedEnsembleSmoother) assert model.simulation_arguments.current_ensemble == "default" diff --git a/tests/unit_tests/cli/test_model_hook_order.py b/tests/unit_tests/cli/test_model_hook_order.py index be26ed7e5c2..2cd11e337c1 100644 --- a/tests/unit_tests/cli/test_model_hook_order.py +++ b/tests/unit_tests/cli/test_model_hook_order.py @@ -65,6 +65,7 @@ def test_hook_call_order_ensemble_smoother(monkeypatch): MagicMock(), MagicMock(), MagicMock(), + MagicMock(), ) test_class.ert = ert_mock test_class.run_ensemble_evaluator = MagicMock(return_value=[0]) @@ -111,6 +112,7 @@ def test_hook_call_order_es_mda(monkeypatch): MagicMock(), es_settings=MagicMock(), update_settings=MagicMock(), + status_queue=MagicMock(), ) ert_mock.runWorkflows = MagicMock() test_class.ert = ert_mock @@ -152,6 +154,7 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch): MagicMock(), MagicMock(), MagicMock(), + MagicMock(), ) test_class.run_ensemble_evaluator = MagicMock(return_value=[0]) test_class.ert = ert_mock diff --git a/tests/unit_tests/cli/test_run_context.py b/tests/unit_tests/cli/test_run_context.py index f7660130f91..b1011e10901 100644 --- a/tests/unit_tests/cli/test_run_context.py +++ b/tests/unit_tests/cli/test_run_context.py @@ -45,6 +45,7 @@ def test_that_all_iterations_gets_correct_name_and_iteration_number( MagicMock(), es_settings=MagicMock(), update_settings=MagicMock(), + status_queue=MagicMock(), ) test_class.run_ensemble_evaluator = MagicMock(return_value=[0]) test_class.run_experiment(MagicMock()) diff --git a/tests/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py b/tests/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py index baeba164d95..db53dcebe2d 100644 --- a/tests/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py +++ b/tests/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py @@ -16,7 +16,7 @@ def test_restarted_jobs_do_not_have_error_msgs(evaluator): - evaluator._start_running() + evaluator.start_running() token = evaluator._config.token cert = evaluator._config.cert url = evaluator._config.url @@ -89,7 +89,7 @@ def test_restarted_jobs_do_not_have_error_msgs(evaluator): def test_new_monitor_can_pick_up_where_we_left_off(evaluator): - 
evaluator._start_running() + evaluator.start_running() token = evaluator._config.token cert = evaluator._config.cert url = evaluator._config.url @@ -193,7 +193,7 @@ def test_new_monitor_can_pick_up_where_we_left_off(evaluator): def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_evaluator( evaluator, ): - evaluator._start_running() + evaluator.start_running() conn_info = evaluator._config.get_connection_info() with Monitor(conn_info) as monitor: events = monitor.track() @@ -288,7 +288,7 @@ def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_evaluat def test_ensure_multi_level_events_in_order(evaluator): - evaluator._start_running() + evaluator.start_running() config_info = evaluator._config.get_connection_info() with Monitor(config_info) as monitor: events = monitor.track() @@ -349,7 +349,7 @@ def exploding_handler(events): evaluator._dispatcher.set_event_handler({"EXPLODING"}, exploding_handler) - evaluator._start_running() + evaluator.start_running() config_info = evaluator._config.get_connection_info() with Monitor(config_info) as monitor: diff --git a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py index d1b59e2bbae..e9f060b40df 100644 --- a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py +++ b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py @@ -29,7 +29,7 @@ def test_run_legacy_ensemble(tmpdir, make_ensemble_builder, monkeypatch): generate_cert=False, ) evaluator = EnsembleEvaluator(ensemble, config, 0) - evaluator._start_running() + evaluator.start_running() with Monitor(config) as monitor: for e in monitor.track(): if e["type"] in ( @@ -66,7 +66,7 @@ def test_run_and_cancel_legacy_ensemble(tmpdir, make_ensemble_builder, monkeypat evaluator = EnsembleEvaluator(ensemble, config, 0) - evaluator._start_running() + evaluator.start_running() with Monitor(config) as mon: cancel = True with contextlib.suppress( @@ -106,7 +106,7 @@ def test_run_legacy_ensemble_with_bare_exception( with patch.object(JobQueue, "add_realization") as faulty_queue: faulty_queue.side_effect = RuntimeError() - evaluator._start_running() + evaluator.start_running() with Monitor(config) as monitor: for e in monitor.track(): if e.data is not None and e.data.get(identifiers.STATUS) in [ diff --git a/tests/unit_tests/ensemble_evaluator/test_evaluator_tracker.py b/tests/unit_tests/ensemble_evaluator/test_evaluator_tracker.py deleted file mode 100644 index 677e1684b25..00000000000 --- a/tests/unit_tests/ensemble_evaluator/test_evaluator_tracker.py +++ /dev/null @@ -1,356 +0,0 @@ -import math -from typing import Any, List, Optional, Tuple -from unittest.mock import MagicMock, patch - -import pytest -from cloudevents.http.event import CloudEvent - -import ert.ensemble_evaluator.identifiers as ids -from ert.ensemble_evaluator import EvaluatorTracker, state -from ert.ensemble_evaluator.config import EvaluatorServerConfig -from ert.ensemble_evaluator.event import EndEvent, SnapshotUpdateEvent -from ert.ensemble_evaluator.snapshot import PartialSnapshot, SnapshotBuilder -from ert.run_models import BaseRunModel - - -def build_snapshot(real_list: Optional[List[str]] = None): - if real_list is None: - # passing ["0"] is required - real_list = ["0"] - return SnapshotBuilder().build(real_list, state.REALIZATION_STATE_UNKNOWN) - - -def build_partial(real_list: Optional[List[str]] = None): - if real_list is None: - real_list = ["0"] - return PartialSnapshot(build_snapshot(real_list)) - - 
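
A minimal sketch of the progress arithmetic the deleted parametrized cases
below were checking — the same formula now lives in BaseRunModel._progress
above (standalone, simplified names; the finished/empty-ensemble special
cases are omitted):

    def progress(
        current_iter: int, done_reals: int, all_reals: int, phase_count: int
    ) -> float:
        # Fraction of completed iterations over total iterations.
        real_progress = done_reals / all_reals
        return (
            (current_iter + real_progress) / phase_count
            if phase_count != 1
            else real_progress
        )

    # Matches the deleted expectations below: with phase_count=2, one of two
    # realizations finished in iteration 0 gives (0 + 0.5) / 2 == 0.25, and
    # the same state in iteration 1 gives (1 + 0.5) / 2 == 0.75.
    assert progress(0, 1, 2, 2) == 0.25
    assert progress(1, 1, 2, 2) == 0.75
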
-@pytest.fixture -def make_mock_ee_monitor(): - def _mock_ee_monitor(events): - def _track(): - while True: - try: - event = events.pop(0) - yield event - except IndexError: - return - - return MagicMock(track=MagicMock(side_effect=_track)) - - return _mock_ee_monitor - - -@pytest.mark.timeout(60) -@pytest.mark.parametrize( - "run_model, monitor_events,brm_mutations,expected_progress", - [ - pytest.param( - BaseRunModel, - [ - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT}, - data={ - **(build_snapshot(["0", "1"]).to_dict()), - "iter": 2, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT_UPDATE}, - data={ - **( - build_partial(["0", "1"]) - .update_realization( - "0", status=state.REALIZATION_STATE_FINISHED - ) - .to_dict() - ), - "iter": 2, - }, - ), - ], - [("_phase_count", 1)], - [0, 0.5], - id="ensemble_experiment_50", - ), - pytest.param( - BaseRunModel, - [ - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT}, - data={ - **(build_snapshot(["0", "1"]).to_dict()), - "iter": 0, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT_UPDATE}, - data={ - **( - build_partial(["0", "1"]) - .update_realization( - "0", status=state.REALIZATION_STATE_FINISHED - ) - .to_dict() - ), - "iter": 0, - }, - ), - ], - [("_phase_count", 1)], - [0, 0.5], - id="ensemble_experiment_50", - ), - pytest.param( - BaseRunModel, - [ - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT}, - data={ - **(build_snapshot(["0", "1"]).to_dict()), - "iter": 0, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT_UPDATE}, - data={ - **( - build_partial(["0", "1"]) - .update_realization( - "0", status=state.REALIZATION_STATE_FINISHED - ) - .to_dict() - ), - "iter": 0, - }, - ), - ], - [("_phase_count", 2)], - [0, 0.25], - id="ensemble_smoother_25", - ), - pytest.param( - BaseRunModel, - [ - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT}, - data={ - **(build_snapshot(["0", "1"]).to_dict()), - "iter": 0, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT_UPDATE}, - data={ - **( - build_partial(["0", "1"]) - .update_realization( - "0", status=state.REALIZATION_STATE_FINISHED - ) - .update_realization( - "1", status=state.REALIZATION_STATE_FINISHED - ) - .to_dict() - ), - "iter": 0, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT}, - data={ - **(build_snapshot(["0", "1"]).to_dict()), - "iter": 1, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT_UPDATE}, - data={ - **( - build_partial(["0", "1"]) - .update_realization( - "0", status=state.REALIZATION_STATE_FINISHED - ) - .to_dict() - ), - "iter": 1, - }, - ), - ], - [("_phase_count", 2)], - [ - 0, - 0.5, - 0.5, - 0.75, - ], - id="ensemble_smoother_75", - ), - pytest.param( - BaseRunModel, - [ - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT}, - data={ - **(build_snapshot(["0", "1"]).to_dict()), - "iter": 0, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT_UPDATE}, - data={ - **( - build_partial(["0", "1"]) - .update_realization( - "0", status=state.REALIZATION_STATE_FINISHED - ) - .update_realization( - "1", status=state.REALIZATION_STATE_FINISHED - ) - .to_dict() - ), - "iter": 0, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT}, - data={ - **(build_snapshot(["0", "1"]).to_dict()), - "iter": 1, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT_UPDATE}, - data={ - **( - build_partial(["0", "1"]) - 
.update_realization( - "0", status=state.REALIZATION_STATE_FINISHED - ) - .update_realization( - "1", status=state.REALIZATION_STATE_FINISHED - ) - .to_dict() - ), - "iter": 1, - }, - ), - ], - [("_phase_count", 2)], - [ - 0, - 0.5, - 0.5, - 1.0, - ], - id="ensemble_smoother_100", - ), - pytest.param( - BaseRunModel, - [ - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT}, - data={ - **(build_snapshot(["0", "1"]).to_dict()), - "iter": 1, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT_UPDATE}, - data={ - **( - build_partial(["0", "1"]) - .update_realization( - "0", status=state.REALIZATION_STATE_FINISHED - ) - .update_realization( - "1", status=state.REALIZATION_STATE_FINISHED - ) - .to_dict() - ), - "iter": 1, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT}, - data={ - **(build_snapshot(["0", "1"]).to_dict()), - "iter": 2, - }, - ), - CloudEvent( - {"source": "/", "type": ids.EVTYPE_EE_SNAPSHOT_UPDATE}, - data={ - **( - build_partial(["0", "1"]) - .update_realization( - "0", status=state.REALIZATION_STATE_FINISHED - ) - .update_realization( - "1", status=state.REALIZATION_STATE_FINISHED - ) - .to_dict() - ), - "iter": 2, - }, - ), - ], - [("_phase_count", 3)], - [ - 0.3333, - 0.6666, - 0.6666, - 1.0, - ], - id="ensemble_smoother_100", - ), - ], -) -def test_tracking_progress( - run_model: BaseRunModel, - monitor_events: List[CloudEvent], - brm_mutations: List[Tuple[str, Any]], - expected_progress: float, - make_mock_ee_monitor, -): - """Tests progress by providing a list of CloudEvent and a list of - arguments to apply to setattr(brm) where brm is an actual BaseRunModel - instance. - - The CloudEvent are provided to the tracker via mocking an Ensemble - Evaluator Monitor. - - PartialSnapshots allow realizations to progress, while iterating "iter" in - CloudEvents allows phases to progress. Such progress should happen - when events are yielded by the tracker. This combined progress is tested. 
- - The final update event and end event is also tested.""" - arg_mock = MagicMock() - arg_mock.random_seed = None - run_model.validate = MagicMock() - brm = run_model(arg_mock, MagicMock(), None, None, None) - ee_config = EvaluatorServerConfig( - custom_port_range=range(1024, 65535), - custom_host="127.0.0.1", - use_token=False, - generate_cert=False, - ) - with patch("ert.ensemble_evaluator.evaluator_tracker.Monitor") as mock_ee: - mock_ee.return_value.__enter__.return_value = make_mock_ee_monitor( - monitor_events.copy() - ) - tracker = EvaluatorTracker( - brm, ee_config.get_connection_info(), next_ensemble_evaluator_wait_time=0.1 - ) - for attr, val in brm_mutations: - setattr(brm, attr, val) - tracker_gen = tracker.track() - update_event = None - for i in range(len(monitor_events)): - update_event = next(tracker_gen) - assert math.isclose( - update_event.progress, expected_progress[i], rel_tol=0.0001 - ) - assert isinstance(update_event, SnapshotUpdateEvent) - brm._phase = brm._phase_count - assert isinstance(next(tracker_gen), EndEvent) diff --git a/tests/unit_tests/gui/simulation/test_run_dialog.py b/tests/unit_tests/gui/simulation/test_run_dialog.py index d5f232b2d65..59dd12bc729 100644 --- a/tests/unit_tests/gui/simulation/test_run_dialog.py +++ b/tests/unit_tests/gui/simulation/test_run_dialog.py @@ -1,6 +1,7 @@ import os from pathlib import Path -from unittest.mock import Mock, patch +from queue import SimpleQueue +from unittest.mock import MagicMock, Mock import pytest from pytestqt.qtbot import QtBot @@ -24,15 +25,15 @@ from ert.services import StorageService -def test_success(runmodel, qtbot: QtBot, mock_tracker): +def test_success(qtbot: QtBot): notifier = Mock() - widget = RunDialog("poly.ert", runmodel, notifier) + queue = SimpleQueue() + widget = RunDialog("mock.ert", MagicMock(), queue, notifier) widget.show() qtbot.addWidget(widget) - with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker: - tracker.return_value = mock_tracker([EndEvent(failed=False, failed_msg="")]) - widget.startSimulation() + widget.startSimulation() + queue.put(EndEvent(failed=False, failed_msg="")) with qtbot.waitExposed(widget, timeout=30000): qtbot.waitUntil(lambda: widget._total_progress_bar.value() == 100) @@ -40,15 +41,15 @@ def test_success(runmodel, qtbot: QtBot, mock_tracker): assert widget.done_button.text() == "Done" -def test_kill_simulations(runmodel, qtbot: QtBot, mock_tracker): +def test_kill_simulations(qtbot: QtBot): notifier = Mock() - widget = RunDialog("poly.ert", runmodel, notifier) + queue = SimpleQueue() + widget = RunDialog("mock.ert", MagicMock(), queue, notifier) widget.show() qtbot.addWidget(widget) - with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker: - tracker.return_value = mock_tracker([EndEvent(failed=False, failed_msg="")]) - widget.startSimulation() + widget.startSimulation() + queue.put(EndEvent(failed=False, failed_msg="")) with qtbot.waitSignal(widget.finished, timeout=30000): @@ -77,16 +78,15 @@ def handle_dialog(): widget.killJobs() -def test_large_snapshot( - runmodel, large_snapshot, qtbot: QtBot, mock_tracker, timeout_per_iter=5000 -): +def test_large_snapshot(large_snapshot, qtbot: QtBot, timeout_per_iter=5000): notifier = Mock() - widget = RunDialog("poly.ert", runmodel, notifier) + queue = SimpleQueue() + widget = RunDialog("mock.ert", MagicMock(), queue, notifier) widget.show() qtbot.addWidget(widget) - with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker: - iter_0 = FullSnapshotEvent( + events = [ 
+ FullSnapshotEvent( snapshot=large_snapshot, phase_name="Foo", current_phase=0, @@ -94,8 +94,8 @@ def test_large_snapshot( progress=0.5, iteration=0, indeterminate=False, - ) - iter_1 = FullSnapshotEvent( + ), + FullSnapshotEvent( snapshot=large_snapshot, phase_name="Foo", current_phase=0, @@ -103,11 +103,13 @@ def test_large_snapshot( progress=0.5, iteration=1, indeterminate=False, - ) - tracker.return_value = mock_tracker( - [iter_0, iter_1, EndEvent(failed=False, failed_msg="")] - ) - widget.startSimulation() + ), + EndEvent(failed=False, failed_msg=""), + ] + + widget.startSimulation() + for event in events: + queue.put(event) with qtbot.waitExposed(widget, timeout=timeout_per_iter * 6): qtbot.waitUntil( @@ -317,15 +319,16 @@ def test_large_snapshot( ), ], ) -def test_run_dialog(events, tab_widget_count, runmodel, qtbot: QtBot, mock_tracker): +def test_run_dialog(events, tab_widget_count, qtbot: QtBot): notifier = Mock() - widget = RunDialog("poly.ert", runmodel, notifier) + queue = SimpleQueue() + widget = RunDialog("mock.ert", MagicMock(), queue, notifier) widget.show() qtbot.addWidget(widget) - with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker: - tracker.return_value = mock_tracker(events) - widget.startSimulation() + widget.startSimulation() + for event in events: + queue.put(event) with qtbot.waitExposed(widget, timeout=30000): qtbot.mouseClick(widget.show_details_button, Qt.LeftButton) @@ -459,17 +462,16 @@ def test_that_run_dialog_can_be_closed_while_file_plot_is_open(qtbot: QtBot, sto ), ], ) -def test_run_dialog_memory_usage_showing( - events, tab_widget_count, runmodel, qtbot: QtBot, mock_tracker -): +def test_run_dialog_memory_usage_showing(events, tab_widget_count, qtbot: QtBot): notifier = Mock() - widget = RunDialog("poly.ert", runmodel, notifier) + queue = SimpleQueue() + widget = RunDialog("poly.ert", MagicMock(), queue, notifier) widget.show() qtbot.addWidget(widget) - with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker: - tracker.return_value = mock_tracker(events) - widget.startSimulation() + widget.startSimulation() + for event in events: + queue.put(event) with qtbot.waitExposed(widget, timeout=30000): qtbot.mouseClick(widget.show_details_button, Qt.LeftButton) diff --git a/tests/unit_tests/run_models/test_ensemble_experiment.py b/tests/unit_tests/run_models/test_ensemble_experiment.py index d972929fe01..73710c15ff5 100644 --- a/tests/unit_tests/run_models/test_ensemble_experiment.py +++ b/tests/unit_tests/run_models/test_ensemble_experiment.py @@ -47,7 +47,7 @@ def get_run_path_mock(realizations, iteration=None): EnsembleExperiment.validate = MagicMock() ensemble_experiment = EnsembleExperiment( - simulation_arguments, MagicMock(), None, None + simulation_arguments, MagicMock(), None, None, MagicMock() ) ensemble_experiment.run_paths.get_paths = get_run_path_mock ensemble_experiment.facade = MagicMock(