diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py
index a339e20cf31..f7a7c538a23 100644
--- a/src/ert/cli/main.py
+++ b/src/ert/cli/main.py
@@ -2,6 +2,7 @@
 import contextlib
 import logging
 import os
+import queue
 import sys
 import threading
 from typing import Any, TextIO
@@ -19,7 +20,7 @@
 from ert.cli.workflow import execute_workflow
 from ert.config import ErtConfig, QueueSystem
 from ert.enkf_main import EnKFMain
-from ert.ensemble_evaluator import EvaluatorServerConfig, EvaluatorTracker
+from ert.ensemble_evaluator import EvaluatorServerConfig
 from ert.namespace import Namespace
 from ert.shared.feature_toggling import FeatureToggling
 from ert.storage import open_storage
@@ -82,11 +83,13 @@ def run_cli(args: Namespace, _: Any = None) -> None:
         execute_workflow(ert, storage, args.name)
         return
 
+    status_queue = queue.SimpleQueue()
     try:
         model = create_model(
             ert_config,
             storage,
             args,
+            status_queue,
         )
     except ValueError as e:
         raise ErtCliError(e) from e
@@ -112,11 +115,6 @@ def run_cli(args: Namespace, _: Any = None) -> None:
         target=model.start_simulations_thread,
         args=(evaluator_server_config,),
     )
-    thread.start()
-
-    tracker = EvaluatorTracker(
-        model, ee_con_info=evaluator_server_config.get_connection_info()
-    )
 
     with contextlib.ExitStack() as exit_stack:
         out: TextIO
@@ -127,13 +125,12 @@ def run_cli(args: Namespace, _: Any = None) -> None:
         else:
             out = sys.stderr
         monitor = Monitor(out=out, color_always=args.color_always)
-
+        thread.start()
         try:
-            monitor.monitor(tracker.track())
+            monitor.monitor(status_queue)
         except (SystemExit, KeyboardInterrupt, OSError):
-            # _base_service.py translates CTRL-c to OSError
             print("\nKilling simulations...")
-            tracker.request_termination()
+            model.cancel()
 
     thread.join()
     storage.close()
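Review note: the heart of this patch is already visible in this first file — the pull-based `EvaluatorTracker` (a generator the CLI iterated over) is replaced by a push model in which the run model puts typed status events on a `queue.SimpleQueue` and consumers drain it. Below is a minimal, self-contained sketch of that handover; `EndEvent` here is a stand-in dataclass for ert's real event type, and both function bodies are illustrative only:

```python
import queue
import threading
from dataclasses import dataclass


@dataclass
class EndEvent:
    failed: bool
    failed_msg: str


def run_model(status_queue: queue.SimpleQueue) -> None:
    # The run model is the sole producer; it must always terminate the
    # stream with an EndEvent so the consumer knows when to exit.
    status_queue.put("snapshot event (stand-in)")
    status_queue.put(EndEvent(failed=False, failed_msg=""))


def monitor(status_queue: queue.SimpleQueue) -> None:
    # The consumer simply blocks on get(); no drainer thread, no sentinel
    # strings, no generator plumbing.
    while True:
        event = status_queue.get()
        print(f"got {event!r}")
        if isinstance(event, EndEvent):
            break


status_queue: queue.SimpleQueue = queue.SimpleQueue()
thread = threading.Thread(target=run_model, args=(status_queue,))
thread.start()
monitor(status_queue)
thread.join()
```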
diff --git a/src/ert/cli/model_factory.py b/src/ert/cli/model_factory.py
index 6e24c6340f1..a34782922a3 100644
--- a/src/ert/cli/model_factory.py
+++ b/src/ert/cli/model_factory.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+from queue import SimpleQueue
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -53,6 +54,7 @@ def create_model(
     config: ErtConfig,
     storage: StorageAccessor,
     args: Namespace,
+    status_queue: SimpleQueue,
 ) -> BaseRunModel:
     logger = logging.getLogger(__name__)
     logger.info(
@@ -73,16 +75,20 @@ def create_model(
     )
 
     if args.mode == TEST_RUN_MODE:
-        return _setup_single_test_run(config, storage, args)
+        return _setup_single_test_run(config, storage, args, status_queue)
     elif args.mode == ENSEMBLE_EXPERIMENT_MODE:
-        return _setup_ensemble_experiment(config, storage, args)
+        return _setup_ensemble_experiment(config, storage, args, status_queue)
     elif args.mode == ENSEMBLE_SMOOTHER_MODE:
-        return _setup_ensemble_smoother(config, storage, args, update_settings)
+        return _setup_ensemble_smoother(
+            config, storage, args, update_settings, status_queue
+        )
     elif args.mode == ES_MDA_MODE:
-        return _setup_multiple_data_assimilation(config, storage, args, update_settings)
+        return _setup_multiple_data_assimilation(
+            config, storage, args, update_settings, status_queue
+        )
     elif args.mode == ITERATIVE_ENSEMBLE_SMOOTHER_MODE:
         return _setup_iterative_ensemble_smoother(
-            config, storage, args, update_settings
+            config, storage, args, update_settings, status_queue
         )
 
     else:
@@ -90,7 +96,10 @@ def create_model(
 
 
 def _setup_single_test_run(
-    config: ErtConfig, storage: StorageAccessor, args: Namespace
+    config: ErtConfig,
+    storage: StorageAccessor,
+    args: Namespace,
+    status_queue: SimpleQueue,
 ) -> SingleTestRun:
     return SingleTestRun(
         SingleTestRunArguments(
@@ -102,11 +111,15 @@ def _setup_single_test_run(
         ),
         config,
         storage,
+        status_queue,
     )
 
 
 def _setup_ensemble_experiment(
-    config: ErtConfig, storage: StorageAccessor, args: Namespace
+    config: ErtConfig,
+    storage: StorageAccessor,
+    args: Namespace,
+    status_queue: SimpleQueue,
 ) -> EnsembleExperiment:
     min_realizations_count = config.analysis_config.minimum_required_realizations
     active_realizations = _realizations(args, config.model_config.num_realizations)
@@ -132,6 +145,7 @@ def _setup_ensemble_experiment(
         config,
         storage,
         config.queue_config,
+        status_queue,
     )
 
 
@@ -140,6 +154,7 @@ def _setup_ensemble_smoother(
     storage: StorageAccessor,
     args: Namespace,
     update_settings: UpdateSettings,
+    status_queue: SimpleQueue,
 ) -> EnsembleSmoother:
     return EnsembleSmoother(
         ESRunArguments(
@@ -158,6 +173,7 @@ def _setup_ensemble_smoother(
         config.queue_config,
         es_settings=config.analysis_config.es_module,
         update_settings=update_settings,
+        status_queue=status_queue,
     )
 
 
@@ -166,6 +182,7 @@ def _setup_multiple_data_assimilation(
     storage: StorageAccessor,
     args: Namespace,
     update_settings: UpdateSettings,
+    status_queue: SimpleQueue,
 ) -> MultipleDataAssimilation:
     # Because the configuration of the CLI is different from the gui, we
     # have a different way to get the restart information.
@@ -195,6 +212,7 @@ def _setup_multiple_data_assimilation(
         prior_ensemble,
         es_settings=config.analysis_config.es_module,
         update_settings=update_settings,
+        status_queue=status_queue,
     )
 
 
@@ -203,6 +221,7 @@ def _setup_iterative_ensemble_smoother(
     storage: StorageAccessor,
    args: Namespace,
     update_settings: UpdateSettings,
+    status_queue: SimpleQueue,
 ) -> IteratedEnsembleSmoother:
     return IteratedEnsembleSmoother(
         SIESRunArguments(
@@ -223,6 +242,7 @@ def _setup_iterative_ensemble_smoother(
         config.queue_config,
         config.analysis_config.ies_module,
         update_settings=update_settings,
+        status_queue=status_queue,
     )
diff --git a/src/ert/cli/monitor.py b/src/ert/cli/monitor.py
index f051870079e..e5afeb82689 100644
--- a/src/ert/cli/monitor.py
+++ b/src/ert/cli/monitor.py
@@ -1,7 +1,8 @@
 # -*- coding: utf-8 -*-
 import sys
 from datetime import datetime, timedelta
-from typing import Dict, Iterator, Optional, TextIO, Tuple, Union
+from queue import SimpleQueue
+from typing import Dict, Optional, TextIO, Tuple
 
 from tqdm import tqdm
 
@@ -57,13 +58,15 @@ def __init__(self, out: TextIO = sys.stdout, color_always: bool = False) -> None
             # The dot adds no value without color, so remove it.
             self.dot = ""
+        self.done = False
 
     def monitor(
         self,
-        events: Iterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]],
+        event_queue: SimpleQueue,
     ) -> None:
         self._start_time = datetime.now()
-        for event in events:
+        while True:
+            event = event_queue.get()
             if isinstance(event, FullSnapshotEvent):
                 if event.snapshot is not None:
                     self._snapshots[event.iteration] = event.snapshot
diff --git a/src/ert/ensemble_evaluator/__init__.py b/src/ert/ensemble_evaluator/__init__.py
index ea093217a94..46b711f38c1 100644
--- a/src/ert/ensemble_evaluator/__init__.py
+++ b/src/ert/ensemble_evaluator/__init__.py
@@ -7,7 +7,6 @@
 )
 from .config import EvaluatorServerConfig
 from .evaluator import EnsembleEvaluator
-from .evaluator_tracker import EvaluatorTracker
 from .event import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent
 from .monitor import Monitor
 from .snapshot import PartialSnapshot, Snapshot
@@ -18,7 +17,6 @@
     "EnsembleBuilder",
     "EnsembleEvaluator",
     "EvaluatorServerConfig",
-    "EvaluatorTracker",
     "ForwardModel",
     "FullSnapshotEvent",
     "Monitor",
diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py
index bdd22697d59..724089bb563 100644
--- a/src/ert/ensemble_evaluator/evaluator.py
+++ b/src/ert/ensemble_evaluator/evaluator.py
@@ -390,9 +390,10 @@ def _run_server(self, loop: asyncio.AbstractEventLoop) -> None:
         loop.run_until_complete(self.evaluator_server())
         logger.debug("Server thread exiting.")
 
-    def _start_running(self) -> None:
+    def start_running(self) -> None:
         self._ws_thread.start()
         self._ensemble.evaluate(self._config)
+        logger.debug("Started evaluator")
 
     def _stop(self) -> None:
         if not self._done.done():
@@ -418,11 +419,11 @@ def _signal_cancel(self) -> None:
             logger.debug("Stopping current ensemble")
             self._loop.call_soon_threadsafe(self._stop)
 
-    def run_and_get_successful_realizations(self) -> List[int]:
-        self._start_running()
-        logger.debug("Started evaluator, joining until shutdown")
+    def join(self) -> None:
         self._ws_thread.join()
         logger.debug("Evaluator is done")
+
+    def get_successful_realizations(self) -> List[int]:
         return self._ensemble.get_successful_realizations()
 
     @staticmethod
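Review note: splitting the old blocking `run_and_get_successful_realizations()` into `start_running()` / `join()` / `get_successful_realizations()` is what lets the caller pump monitor events between start and shutdown — exactly what `BaseRunModel.run_ensemble_evaluator()` does later in this patch. A toy illustration of the lifecycle split, with invented threading internals (this is not ert's implementation):

```python
import threading
import time
from typing import List


class ToyEvaluator:
    def __init__(self) -> None:
        self._thread = threading.Thread(target=self._run)
        self._successful: List[int] = []

    def _run(self) -> None:
        time.sleep(0.1)  # stand-in for evaluating the ensemble
        self._successful = [0, 1, 2]

    def start_running(self) -> None:
        # Non-blocking: kick off the work and return to the caller.
        self._thread.start()

    def join(self) -> None:
        # Blocks until the evaluation thread has shut down.
        self._thread.join()

    def get_successful_realizations(self) -> List[int]:
        return self._successful


evaluator = ToyEvaluator()
evaluator.start_running()
# ...the caller is free to track and forward monitor events here...
evaluator.join()
print(evaluator.get_successful_realizations())  # [0, 1, 2]
```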
diff --git a/src/ert/ensemble_evaluator/evaluator_tracker.py b/src/ert/ensemble_evaluator/evaluator_tracker.py
deleted file mode 100644
index 5dbd239420d..00000000000
--- a/src/ert/ensemble_evaluator/evaluator_tracker.py
+++ /dev/null
@@ -1,240 +0,0 @@
-import asyncio
-import contextlib
-import copy
-import logging
-import queue
-import threading
-import time
-from typing import TYPE_CHECKING, Dict, Iterator, Union
-
-from aiohttp import ClientError
-from websockets.exceptions import ConnectionClosedError
-
-from ert.async_utils import get_event_loop, new_event_loop
-from ert.ensemble_evaluator.identifiers import (
-    EVTYPE_EE_SNAPSHOT,
-    EVTYPE_EE_SNAPSHOT_UPDATE,
-    EVTYPE_EE_TERMINATED,
-    STATUS,
-)
-from ert.ensemble_evaluator.state import (
-    ENSEMBLE_STATE_CANCELLED,
-    ENSEMBLE_STATE_FAILED,
-    ENSEMBLE_STATE_STOPPED,
-    REALIZATION_STATE_FAILED,
-    REALIZATION_STATE_FINISHED,
-)
-
-from ._wait_for_evaluator import wait_for_evaluator
-from .evaluator_connection_info import EvaluatorConnectionInfo
-from .event import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent
-from .monitor import Monitor
-from .snapshot import PartialSnapshot, Snapshot
-
-if TYPE_CHECKING:
-    from cloudevents.http.event import CloudEvent
-
-    from ert.run_models import BaseRunModel
-
-
-class OutOfOrderSnapshotUpdateException(ValueError):
-    pass
-
-
-class EvaluatorTracker:
-    DONE = "done"
-
-    def __init__(
-        self,
-        model: "BaseRunModel",
-        ee_con_info: EvaluatorConnectionInfo,
-        next_ensemble_evaluator_wait_time: int = 5,
-    ):
-        self._model = model
-        self._ee_con_info = ee_con_info
-        self._next_ensemble_evaluator_wait_time = next_ensemble_evaluator_wait_time
-        self._work_queue: "queue.Queue[Union[str, CloudEvent]]" = queue.Queue()
-        self._drainer_thread = threading.Thread(
-            target=self._drain_monitor,
-            name="DrainerThread",
-        )
-        self._drainer_thread.start()
-        self._iter_snapshot: Dict[int, Snapshot] = {}
-
-    def _drain_monitor(self) -> None:
-        asyncio.set_event_loop(new_event_loop())
-        drainer_logger = logging.getLogger("ert.ensemble_evaluator.drainer")
-        while not self._model.isFinished():
-            try:
-                drainer_logger.debug("connecting to new monitor...")
-                with Monitor(self._ee_con_info) as monitor:
-                    drainer_logger.debug("connected")
-                    for event in monitor.track():
-                        if event["type"] in (
-                            EVTYPE_EE_SNAPSHOT,
-                            EVTYPE_EE_SNAPSHOT_UPDATE,
-                        ):
-                            self._work_queue.put(event)
-                            if event.data.get(STATUS) in [
-                                ENSEMBLE_STATE_STOPPED,
-                                ENSEMBLE_STATE_FAILED,
-                            ]:
-                                drainer_logger.debug(
-                                    "observed evaluation stopped event, signal done"
-                                )
-                                monitor.signal_done()
-                            if event.data.get(STATUS) == ENSEMBLE_STATE_CANCELLED:
-                                drainer_logger.debug(
-                                    "observed evaluation cancelled event, exit drainer"
-                                )
-                                # Allow track() to emit an EndEvent.
-                                self._work_queue.put(EvaluatorTracker.DONE)
-                                return
-                        elif event["type"] == EVTYPE_EE_TERMINATED:
-                            drainer_logger.debug("got terminator event")
-                # This sleep needs to be there. Refer to issue #1250: `Authority
-                # on information about evaluations/experiments`
-                time.sleep(self._next_ensemble_evaluator_wait_time)
-            except (ConnectionRefusedError, ClientError) as e:
-                if not self._model.isFinished():
-                    drainer_logger.debug(f"connection refused: {e}")
-            except ConnectionClosedError as e:
-                # The monitor connection closed unexpectedly
-                drainer_logger.debug(f"connection closed error: {e}")
-            except BaseException:
-                drainer_logger.exception("unexpected error: ")
-                # We really don't know what happened... shut down
-                # the thread and get out of here. The monitor has
-                # been stopped by the ctx-mgr
-                self._work_queue.put(EvaluatorTracker.DONE)
-                self._work_queue.join()
-                return
-        drainer_logger.debug(
-            "observed that model was finished, waiting tasks completion..."
-        )
-        # The model has finished, we indicate this by sending a DONE
-        self._work_queue.put(EvaluatorTracker.DONE)
-        self._work_queue.join()
-        drainer_logger.debug("tasks complete")
-
-    def track(
-        self,
-    ) -> Iterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]]:
-        while True:
-            event = self._work_queue.get()
-            if isinstance(event, str):
-                with contextlib.suppress(GeneratorExit):
-                    # consumers may exit at this point, make sure the last
-                    # task is marked as done
-                    if event == EvaluatorTracker.DONE:
-                        yield EndEvent(
-                            failed=self._model.hasRunFailed(),
-                            failed_msg=self._model.getFailMessage(),
-                        )
-                self._work_queue.task_done()
-                break
-            if event["type"] == EVTYPE_EE_SNAPSHOT:
-                iter_ = event.data["iter"]
-                snapshot = Snapshot(event.data)
-                self._iter_snapshot[iter_] = snapshot
-                yield FullSnapshotEvent(
-                    phase_name=self._model.getPhaseName(),
-                    current_phase=self._model.currentPhase(),
-                    total_phases=self._model.phaseCount(),
-                    indeterminate=self._model.isIndeterminate(),
-                    progress=self._progress(),
-                    iteration=iter_,
-                    snapshot=copy.deepcopy(snapshot),
-                )
-            elif event["type"] == EVTYPE_EE_SNAPSHOT_UPDATE:
-                iter_ = event.data["iter"]
-                if iter_ not in self._iter_snapshot:
-                    raise OutOfOrderSnapshotUpdateException(
-                        f"got {EVTYPE_EE_SNAPSHOT_UPDATE} without having stored "
-                        f"snapshot for iter {iter_}"
-                    )
-                partial = PartialSnapshot(self._iter_snapshot[iter_]).from_cloudevent(
-                    event
-                )
-                self._iter_snapshot[iter_].merge_event(partial)
-                yield SnapshotUpdateEvent(
-                    phase_name=self._model.getPhaseName(),
-                    current_phase=self._model.currentPhase(),
-                    total_phases=self._model.phaseCount(),
-                    indeterminate=self._model.isIndeterminate(),
-                    progress=self._progress(),
-                    iteration=iter_,
-                    partial_snapshot=partial,
-                )
-            self._work_queue.task_done()
-
-    def is_finished(self) -> bool:
-        return not self._drainer_thread.is_alive()
-
-    def _progress(self) -> float:
-        """Fraction of completed iterations over total iterations"""
-
-        if self.is_finished():
-            return 1.0
-        elif not self._iter_snapshot:
-            return 0.0
-        else:
-            # Calculate completed realizations
-            current_iter = max(list(self._iter_snapshot.keys()))
-            done_reals = 0
-            all_reals = self._iter_snapshot[current_iter].reals
-            if not all_reals:
-                # Empty ensemble or all realizations deactivated
-                return 1.0
-            for real in all_reals.values():
-                if real.status in [
-                    REALIZATION_STATE_FINISHED,
-                    REALIZATION_STATE_FAILED,
-                ]:
-                    done_reals += 1
-            real_progress = float(done_reals) / len(all_reals)
-
-            return (
-                (current_iter + real_progress) / self._model.phaseCount()
-                if self._model.phaseCount() != 1
-                else real_progress
-            )
-
-    def _clear_work_queue(self) -> None:
-        with contextlib.suppress(queue.Empty):
-            while True:
-                self._work_queue.get_nowait()
-                self._work_queue.task_done()
-
-    def request_termination(self) -> None:
-        logger = logging.getLogger("ert.ensemble_evaluator.tracker")
-        # There might be some situations where the
-        # evaluation is finished or the evaluation
-        # is yet to start when calling this function.
-        # In these cases the monitor is not started
-        #
-        # To avoid waiting too long we exit if we are not
-        # able to connect to the monitor after 2 tries
-        #
-        # See issue: https://github.com/equinor/ert/issues/1250
-        #
-        try:
-            logger.debug("requesting termination...")
-            get_event_loop().run_until_complete(
-                wait_for_evaluator(
-                    base_url=self._ee_con_info.url,
-                    token=self._ee_con_info.token,
-                    cert=self._ee_con_info.cert,
-                    timeout=5,
-                )
-            )
-            logger.debug("requested termination")
-        except ClientError as e:
-            logger.warning(f"{__name__} - exception {e}")
-            return
-
-        with Monitor(self._ee_con_info) as monitor:
-            monitor.signal_cancel()
-        while self._drainer_thread.is_alive():
-            self._clear_work_queue()
-            time.sleep(1)
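Review note: nothing from the deleted module is lost — its responsibilities are redistributed elsewhere in this patch. The drainer thread and the `DONE` sentinel disappear entirely; `track()`'s snapshot bookkeeping and `_progress()` reappear as `BaseRunModel.send_snapshot_event()` and `BaseRunModel._progress()`; and `request_termination()` shrinks to `BaseRunModel.cancel()`, which posts to an internal end-queue polled by the evaluation loop instead of reconnecting to the evaluator just to signal cancel.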
diff --git a/src/ert/gui/simulation/tracker_worker.py b/src/ert/gui/simulation/queue_emitter.py
similarity index 68%
rename from src/ert/gui/simulation/tracker_worker.py
rename to src/ert/gui/simulation/queue_emitter.py
index 00a3f91dac0..6d1a996dce8 100644
--- a/src/ert/gui/simulation/tracker_worker.py
+++ b/src/ert/gui/simulation/queue_emitter.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Iterator, Union
+from queue import SimpleQueue
 
 from qtpy.QtCore import QObject, Signal, Slot
 
@@ -9,40 +9,39 @@
 logger = logging.getLogger(__name__)
 
 
-class TrackerWorker(QObject):
-    """A worker that consumes events produced by a tracker and emits them to qt
-    subscribers."""
+class QueueEmitter(QObject):
+    """A worker that emits items put on a queue to qt subscribers."""
 
-    consumed_event = Signal(object)
+    new_event = Signal(object)
     done = Signal()
 
     def __init__(
         self,
-        event_generator_factory: Callable[
-            [], Iterator[Union[FullSnapshotEvent, SnapshotUpdateEvent, EndEvent]]
-        ],
+        event_queue: SimpleQueue,
         parent=None,
     ):
         super().__init__(parent)
-        logger.debug("init trackerworker")
+        logger.debug("init QueueEmitter")
-        self._tracker = event_generator_factory
+        self._event_queue = event_queue
         self._stopped = False
 
     @Slot()
     def consume_and_emit(self):
         logger.debug("tracking...")
-        for event in self._tracker():
+        while True:
+            event = self._event_queue.get()
             if self._stopped:
                 logger.debug("stopped")
                 break
+
+            # pre-rendering in this thread to avoid work in main rendering thread
             if isinstance(event, FullSnapshotEvent) and event.snapshot:
                 SnapshotModel.prerender(event.snapshot)
             elif isinstance(event, SnapshotUpdateEvent) and event.partial_snapshot:
                 SnapshotModel.prerender(event.partial_snapshot)
 
             logger.debug(f"emit {event}")
-            self.consumed_event.emit(event)
+            self.new_event.emit(event)
 
             if isinstance(event, EndEvent):
                 logger.debug("got end event")
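One subtlety worth flagging in `consume_and_emit()`: `self._stopped` is only checked after a blocking `get()`, so a `stop()` issued while the queue is empty takes effect only when the next event — normally the final `EndEvent` — arrives. The patch relies on the run model always sending that `EndEvent`. If an eager stop were ever needed, a sentinel could unblock the `get()`; this is a hypothetical variant, not part of this patch:

```python
# Hypothetical helper, not in this patch: wake a QueueEmitter blocked on
# get() so that its _stopped flag is observed immediately.
_SENTINEL = object()


def stop_eagerly(emitter, event_queue) -> None:
    emitter._stopped = True
    event_queue.put(_SENTINEL)  # unblocks get(); the loop sees _stopped and breaks
```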
diff --git a/src/ert/gui/simulation/run_dialog.py b/src/ert/gui/simulation/run_dialog.py
index 723e588a3dd..2e506164c9c 100644
--- a/src/ert/gui/simulation/run_dialog.py
+++ b/src/ert/gui/simulation/run_dialog.py
@@ -1,4 +1,5 @@
 import logging
+from queue import SimpleQueue
 from threading import Thread
 from typing import Optional
 
@@ -23,7 +24,6 @@
 from ert.ensemble_evaluator import (
     EndEvent,
     EvaluatorServerConfig,
-    EvaluatorTracker,
     FullSnapshotEvent,
     SnapshotUpdateEvent,
 )
@@ -43,7 +43,7 @@
 )
 from ert.shared.status.utils import format_running_time
 
-from .tracker_worker import TrackerWorker
+from .queue_emitter import QueueEmitter
 from .view import LegendView, ProgressView, RealizationWidget, UpdateWidget
 
 _TOTAL_PROGRESS_TEMPLATE = "Total progress {total_progress}% — {phase_name}"
@@ -57,6 +57,7 @@ def __init__(
         self,
         config_file: str,
         run_model: BaseRunModel,
+        event_queue: SimpleQueue,
         notifier: ErtNotifier,
         parent=None,
     ):
@@ -68,6 +69,7 @@ def __init__(
         self._snapshot_model = SnapshotModel(self)
         self._run_model = run_model
+        self._event_queue = event_queue
         self._notifier = notifier
 
         self._isDetailedDialog = False
@@ -176,7 +178,6 @@ def __init__(
         self._setSimpleDialog()
 
         self.finished.connect(self._on_finished)
-        self._run_model.add_send_event_callback(self.on_run_model_event.emit)
         self.on_run_model_event.connect(self._on_event)
 
     def _current_tab_changed(self, index: int) -> None:
@@ -269,37 +270,27 @@ def startSimulation(self):
         port_range = range(49152, 51819)
         evaluator_server_config = EvaluatorServerConfig(custom_port_range=port_range)
 
-        def run():
-            self._run_model.startSimulations(
-                evaluator_server_config=evaluator_server_config,
-            )
-
         simulation_thread = Thread(
-            name="ert_gui_simulation_thread", target=run, daemon=True
-        )
-        simulation_thread.start()
-
-        self._ticker.start(1000)
-
-        self._tracker = EvaluatorTracker(
-            self._run_model,
-            ee_con_info=evaluator_server_config.get_connection_info(),
+            name="ert_gui_simulation_thread",
+            target=self._run_model.start_simulations_thread,
+            daemon=True,
+            args=(evaluator_server_config,),
         )
 
-        worker = TrackerWorker(self._tracker.track)
+        worker = QueueEmitter(self._event_queue)
         worker_thread = QThread()
+        self._worker = worker
+        self._worker_thread = worker_thread
+
         worker.done.connect(worker_thread.quit)
-        worker.consumed_event.connect(self._on_event)
+        worker.new_event.connect(self._on_event)
         worker.moveToThread(worker_thread)
         self.simulation_done.connect(worker.stop)
-        self._worker = worker
-        self._worker_thread = worker_thread
         worker_thread.started.connect(worker.consume_and_emit)
-        # _worker_thread is finished once everything has stopped. We wait to
-        # show the done button to this point to avoid destroying the QThread
-        # while it is running (which would sigabrt)
-        self._worker_thread.finished.connect(self._show_done_button)
+
+        self._ticker.start(1000)
         self._worker_thread.start()
+        simulation_thread.start()
         self._notifier.set_is_simulation_running(True)
 
     def killJobs(self):
@@ -311,9 +302,7 @@ def killJobs(self):
         if kill_job == QMessageBox.Yes:
             # Normally this slot would be invoked by the signal/slot system,
             # but the worker is busy tracking the evaluation.
-            self._tracker.request_termination()
-            self._worker_thread.quit()
-            self._worker_thread.wait()
+            self._run_model.cancel()
             self._on_finished()
             self.finished.emit(-1)
         return kill_job
@@ -349,9 +338,8 @@ def _on_ticker(self):
     def _on_event(self, event: object):
         if isinstance(event, EndEvent):
             self.simulation_done.emit(event.failed, event.failed_msg)
-            self._worker.stop()
             self._ticker.stop()
-
+            self._show_done_button()
         elif isinstance(event, FullSnapshotEvent):
             if event.snapshot is not None:
                 self._snapshot_model._add_snapshot(event.snapshot, event.iteration)
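Review note on ordering: `startSimulation()` now starts the `QueueEmitter`'s `QThread` before the simulation thread, so the consumer appears to be in place before the first event can be produced, and the Done button is shown when the `EndEvent` arrives in `_on_event` rather than when the worker thread finishes. `killJobs()` likewise reduces to `self._run_model.cancel()`; the worker keeps draining the queue until the run model sends its final `EndEvent`.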
diff --git a/src/ert/gui/simulation/simulation_panel.py b/src/ert/gui/simulation/simulation_panel.py
index a544b5af6d2..ef40afd722c 100644
--- a/src/ert/gui/simulation/simulation_panel.py
+++ b/src/ert/gui/simulation/simulation_panel.py
@@ -1,4 +1,5 @@
 from collections import OrderedDict
+from queue import SimpleQueue
 from typing import Any, Dict
 
 from qtpy.QtCore import QSize, Qt
@@ -181,11 +182,13 @@ def runSimulation(self):
         abort = False
         QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor)
         config = self.facade.config
+        event_queue = SimpleQueue()
         try:
             model = create_model(
                 config,
                 self._notifier.storage,
                 args,
+                event_queue,
             )
 
         except ValueError as e:
@@ -245,7 +248,7 @@ def runSimulation(self):
             QApplication.restoreOverrideCursor()
 
             dialog = RunDialog(
-                self._config_file, model, self._notifier, self.parent()
+                self._config_file, model, event_queue, self._notifier, self.parent()
             )
             self.run_button.setEnabled(False)
             self.run_button.setText(EXPERIMENT_IS_RUNNING_BUTTON_MESSAGE)
diff --git a/src/ert/gui/tools/run_analysis/run_analysis_tool.py b/src/ert/gui/tools/run_analysis/run_analysis_tool.py
index 4bf823d2bf9..3f0b1f932e4 100644
--- a/src/ert/gui/tools/run_analysis/run_analysis_tool.py
+++ b/src/ert/gui/tools/run_analysis/run_analysis_tool.py
@@ -63,7 +63,7 @@ def run(self):
                 update_settings,
                 config.analysis_config.es_module,
                 rng,
-                self.smoother_event_callback,
+                self.send_smoother_event,
                 log_path=config.analysis_config.log_path,
             )
         except ErtAnalysisError as e:
@@ -73,7 +73,7 @@ def run(self):
 
         self.finished.emit(error, self._source_fs.name)
 
-    def smoother_event_callback(self, event: AnalysisEvent) -> None:
+    def send_smoother_event(self, event: AnalysisEvent) -> None:
         if isinstance(event, AnalysisStatusEvent):
             self.progress_update.emit(RunModelStatusEvent(iteration=0, msg=event.msg))
         elif isinstance(event, AnalysisTimeEvent):
diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py
index b42636dc449..c86f06c0d45 100644
--- a/src/ert/run_models/base_run_model.py
+++ b/src/ert/run_models/base_run_model.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import logging
 import os
 import shutil
@@ -7,9 +8,9 @@
 import uuid
 from contextlib import contextmanager
 from pathlib import Path
+from queue import SimpleQueue
 from typing import (
     TYPE_CHECKING,
-    Callable,
     Dict,
     Generator,
     List,
@@ -20,6 +21,9 @@
 )
 
 import numpy as np
+from aiohttp import ClientError
+from cloudevents.http import CloudEvent
+from websockets.exceptions import ConnectionClosedError
 
 from ert.analysis import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent
 from ert.cli import MODULE_MODE
@@ -30,8 +34,28 @@
     EnsembleBuilder,
     EnsembleEvaluator,
     EvaluatorServerConfig,
+    Monitor,
     RealizationBuilder,
 )
+from ert.ensemble_evaluator.event import (
+    EndEvent,
+    FullSnapshotEvent,
+    SnapshotUpdateEvent,
+)
+from ert.ensemble_evaluator.identifiers import (
+    EVTYPE_EE_SNAPSHOT,
+    EVTYPE_EE_SNAPSHOT_UPDATE,
+    EVTYPE_EE_TERMINATED,
+    STATUS,
+)
+from ert.ensemble_evaluator.snapshot import PartialSnapshot, Snapshot
+from ert.ensemble_evaluator.state import (
+    ENSEMBLE_STATE_CANCELLED,
+    ENSEMBLE_STATE_FAILED,
+    ENSEMBLE_STATE_STOPPED,
+    REALIZATION_STATE_FAILED,
+    REALIZATION_STATE_FINISHED,
+)
 from ert.libres_facade import LibresFacade
 from ert.run_context import RunContext
 from ert.runpaths import Runpaths
@@ -40,6 +64,8 @@
 from .event import (
     RunModelStatusEvent,
     RunModelTimeEvent,
+    RunModelUpdateBeginEvent,
+    RunModelUpdateEndEvent,
 )
 
 event_logger = logging.getLogger("ert.event_log")
@@ -48,6 +74,23 @@
     from ert.config import QueueConfig
     from ert.run_models.run_arguments import RunArgumentsType
 
+StatusEvents = Union[
+    FullSnapshotEvent,
+    SnapshotUpdateEvent,
+    EndEvent,
+    AnalysisEvent,
+    AnalysisStatusEvent,
+    AnalysisTimeEvent,
+    RunModelStatusEvent,
+    RunModelTimeEvent,
+    RunModelUpdateBeginEvent,
+    RunModelUpdateEndEvent,
+]
+
+
+class OutOfOrderSnapshotUpdateException(ValueError):
+    pass
+
 
 class ErtRunError(Exception):
     pass
@@ -87,6 +130,7 @@ def __init__(
         config: ErtConfig,
         storage: StorageAccessor,
         queue_config: QueueConfig,
+        status_queue: SimpleQueue,
         phase_count: int = 1,
     ):
         """
@@ -134,16 +178,14 @@ def __init__(
             filename=str(config.runpath_file),
             substitute=self.substitution_list.substitute_real_iter,
         )
-        self._send_event_callback: Optional[Callable[[object], None]] = None
-
-    def add_send_event_callback(self, func: Callable[[object], None]) -> None:
-        self._send_event_callback = func
+        self._iter_snapshot: Dict[int, Snapshot] = {}
+        self._status_queue = status_queue
+        self._end_queue = SimpleQueue()
 
-    def send_event(self, event: object) -> None:
-        if self._send_event_callback:
-            self._send_event_callback(event)
+    def send_event(self, event: StatusEvents) -> None:
+        self._status_queue.put(event)
 
-    def smoother_event_callback(self, iteration: int, event: AnalysisEvent) -> None:
+    def send_smoother_event(self, iteration: int, event: AnalysisEvent) -> None:
        if isinstance(event, AnalysisStatusEvent):
             self.send_event(RunModelStatusEvent(iteration=iteration, msg=event.msg))
         elif isinstance(event, AnalysisTimeEvent):
@@ -167,6 +209,9 @@ def simulation_arguments(self) -> RunArgumentsType:
     def _ensemble_size(self) -> int:
         return len(self._initial_realizations_mask)
 
+    def cancel(self) -> None:
+        self._end_queue.put("END")
+
     def reset(self) -> None:
         self._failed = False
         self._error_messages = []
@@ -315,6 +360,7 @@ def reraise_exception(self, exctype: Type[Exception]) -> None:
     def _simulationEnded(self) -> None:
         self._clean_env_context()
         self._job_stop_time = int(time.time())
+        self.send_end_event()
 
     def setPhase(
         self, phase: int, phase_name: str, indeterminate: Optional[bool] = None
@@ -364,18 +410,159 @@ def checkHaveSufficientRealizations(
                 f"MIN_REALIZATIONS to allow (more) failures in your experiments."
             )
 
+    def _progress(self) -> float:
+        """Fraction of completed iterations over total iterations"""
+
+        if self.isFinished():
+            return 1.0
+        elif not self._iter_snapshot:
+            return 0.0
+        else:
+            # Calculate completed realizations
+            current_iter = max(list(self._iter_snapshot.keys()))
+            done_reals = 0
+            all_reals = self._iter_snapshot[current_iter].reals
+            if not all_reals:
+                # Empty ensemble or all realizations deactivated
+                return 1.0
+            for real in all_reals.values():
+                if real.status in [
+                    REALIZATION_STATE_FINISHED,
+                    REALIZATION_STATE_FAILED,
+                ]:
+                    done_reals += 1
+            real_progress = float(done_reals) / len(all_reals)
+
+            return (
+                (current_iter + real_progress) / self.phaseCount()
+                if self.phaseCount() != 1
+                else real_progress
+            )
+
+    def send_end_event(self) -> None:
+        self.send_event(
+            EndEvent(
+                failed=self.hasRunFailed(),
+                failed_msg=self.getFailMessage(),
+            )
+        )
+
+    def send_snapshot_event(self, event: CloudEvent) -> None:
+        if event["type"] == EVTYPE_EE_SNAPSHOT:
+            iter_ = event.data["iter"]
+            snapshot = Snapshot(event.data)
+            self._iter_snapshot[iter_] = snapshot
+            self.send_event(
+                FullSnapshotEvent(
+                    phase_name=self.getPhaseName(),
+                    current_phase=self.currentPhase(),
+                    total_phases=self.phaseCount(),
+                    indeterminate=self.isIndeterminate(),
+                    progress=self._progress(),
+                    iteration=iter_,
+                    snapshot=copy.deepcopy(snapshot),
+                )
+            )
+        elif event["type"] == EVTYPE_EE_SNAPSHOT_UPDATE:
+            iter_ = event.data["iter"]
+            if iter_ not in self._iter_snapshot:
+                raise OutOfOrderSnapshotUpdateException(
+                    f"got {EVTYPE_EE_SNAPSHOT_UPDATE} without having stored "
+                    f"snapshot for iter {iter_}"
+                )
+            partial = PartialSnapshot(self._iter_snapshot[iter_]).from_cloudevent(event)
+            self._iter_snapshot[iter_].merge_event(partial)
+            self.send_event(
+                SnapshotUpdateEvent(
+                    phase_name=self.getPhaseName(),
+                    current_phase=self.currentPhase(),
+                    total_phases=self.phaseCount(),
+                    indeterminate=self.isIndeterminate(),
+                    progress=self._progress(),
+                    iteration=iter_,
+                    partial_snapshot=partial,
+                )
+            )
+
     def run_ensemble_evaluator(
         self, run_context: RunContext, ee_config: EvaluatorServerConfig
     ) -> List[int]:
+        if not self._end_queue.empty():
+            event_logger.debug("Run model canceled - pre evaluation")
+            self._end_queue.get()
+            return []
         ensemble = self._build_ensemble(run_context)
-
-        successful_realizations = EnsembleEvaluator(
+        evaluator = EnsembleEvaluator(
             ensemble,
             ee_config,
             run_context.iteration,
-        ).run_and_get_successful_realizations()
-
-        return successful_realizations
+        )
+        evaluator.start_running()
+        should_exit = False
+        while not should_exit:
+            try:
+                event_logger.debug("connecting to new monitor...")
+                with Monitor(ee_config.get_connection_info()) as monitor:
+                    event_logger.debug("connected")
+                    for event in monitor.track():
+                        if event["type"] in (
+                            EVTYPE_EE_SNAPSHOT,
+                            EVTYPE_EE_SNAPSHOT_UPDATE,
+                        ):
+                            self.send_snapshot_event(event)
+                            if event.data.get(STATUS) in [
+                                ENSEMBLE_STATE_STOPPED,
+                                ENSEMBLE_STATE_FAILED,
+                            ]:
+                                event_logger.debug(
+                                    "observed evaluation stopped event, signal done"
+                                )
+                                monitor.signal_done()
+                                should_exit = True
+                            if event.data.get(STATUS) == ENSEMBLE_STATE_CANCELLED:
+                                event_logger.debug(
+                                    "observed evaluation cancelled event, exit drainer"
+                                )
+                                # The end event is sent by _simulationEnded().
+                                return []
+                        elif event["type"] == EVTYPE_EE_TERMINATED:
+                            event_logger.debug("got terminator event")
+
+                        if not self._end_queue.empty():
+                            event_logger.debug("Run model canceled - during evaluation")
+                            self._end_queue.get()
+                            monitor.signal_cancel()
+                            event_logger.debug(
+                                "Run model canceled - during evaluation - cancel sent"
+                            )
+                # The old tracker slept here between reconnection attempts;
+                # refer to issue #1250: `Authority on information about
+                # evaluations/experiments`
+            except (ConnectionRefusedError, ClientError) as e:
+                if not self.isFinished():
+                    event_logger.debug(f"connection refused: {e}")
+            except ConnectionClosedError as e:
+                # The monitor connection closed unexpectedly
+                event_logger.debug(f"connection closed error: {e}")
+            except BaseException:
+                event_logger.exception("unexpected error: ")
+                # We really don't know what happened... stop monitoring
+                # and get out of here. The monitor has been stopped by
+                # the ctx-mgr
+                return []
+
+        event_logger.debug(
+            "observed that model was finished, waiting for tasks to complete..."
+        )
+        event_logger.debug("tasks complete")
+
+        evaluator.join()
+        if not self._end_queue.empty():
+            event_logger.debug("Run model canceled - post evaluation")
+            self._end_queue.get()
+            return []
+        return evaluator.get_successful_realizations()
 
     def _build_ensemble(
         self,
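The cancel mechanism deserves a close look: `cancel()` only posts `"END"` to a private `SimpleQueue`, and `run_ensemble_evaluator()` polls that queue before evaluation, between monitor events, and after `join()` — so cancellation is honored at event boundaries rather than instantly. A self-contained sketch of the pattern (`ToyRunModel` is illustrative, not ert code):

```python
from queue import SimpleQueue
from typing import List


class ToyRunModel:
    def __init__(self) -> None:
        self._end_queue: SimpleQueue = SimpleQueue()

    def cancel(self) -> None:
        # Called from another thread (CLI signal handler, GUI button).
        self._end_queue.put("END")

    def run(self, events: List[str]) -> List[int]:
        for event in events:
            # Polled between events: a pending cancel is honored at the
            # next event boundary, not mid-event.
            if not self._end_queue.empty():
                self._end_queue.get()
                return []
            print(f"processing {event}")
        return [0]


model = ToyRunModel()
model.cancel()
print(model.run(["snapshot", "update"]))  # -> []
```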
diff --git a/src/ert/run_models/ensemble_experiment.py b/src/ert/run_models/ensemble_experiment.py
index 31d23964b5a..10bb59f3cc1 100644
--- a/src/ert/run_models/ensemble_experiment.py
+++ b/src/ert/run_models/ensemble_experiment.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
+from queue import SimpleQueue
 from typing import TYPE_CHECKING, Union
 
 import numpy as np
@@ -32,12 +33,14 @@ def __init__(
         config: ErtConfig,
         storage: StorageAccessor,
         queue_config: QueueConfig,
+        status_queue: SimpleQueue,
     ):
         super().__init__(
             simulation_arguments,
             config,
             storage,
             queue_config,
+            status_queue,
         )
 
     def runSimulations__(
diff --git a/src/ert/run_models/ensemble_smoother.py b/src/ert/run_models/ensemble_smoother.py
index f5fd431d145..f2b8b3cbf95 100644
--- a/src/ert/run_models/ensemble_smoother.py
+++ b/src/ert/run_models/ensemble_smoother.py
@@ -2,6 +2,7 @@
 
 import functools
 import logging
+from queue import SimpleQueue
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -35,12 +36,14 @@ def __init__(
         queue_config: QueueConfig,
         es_settings: ESSettings,
         update_settings: UpdateSettings,
+        status_queue: SimpleQueue,
     ):
         super().__init__(
             simulation_arguments,
             config,
             storage,
             queue_config,
+            status_queue,
             phase_count=2,
         )
         self.es_settings = es_settings
@@ -130,7 +133,7 @@ def run_experiment(
             es_settings=self.es_settings,
             updatestep=self.ert.update_configuration,
             rng=self.rng,
-            progress_callback=functools.partial(self.smoother_event_callback, 0),
+            progress_callback=functools.partial(self.send_smoother_event, 0),
             log_path=self.ert_config.analysis_config.log_path,
         )
diff --git a/src/ert/run_models/iterated_ensemble_smoother.py b/src/ert/run_models/iterated_ensemble_smoother.py
index b37e2e2da5d..bbabda0cb5d 100644
--- a/src/ert/run_models/iterated_ensemble_smoother.py
+++ b/src/ert/run_models/iterated_ensemble_smoother.py
@@ -2,6 +2,7 @@
 
 import functools
 import logging
+from queue import SimpleQueue
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -40,12 +41,14 @@ def __init__(
         queue_config: QueueConfig,
         analysis_config: IESSettings,
         update_settings: UpdateSettings,
+        status_queue: SimpleQueue,
     ):
         super().__init__(
             simulation_arguments,
             config,
             storage,
             queue_config,
+            status_queue,
             phase_count=2,
         )
         self.support_restart = False
@@ -96,7 +99,7 @@ def analyzeStep(
             initial_mask=initial_mask,
             rng=self.rng,
             progress_callback=functools.partial(
-                self.smoother_event_callback, iteration
+                self.send_smoother_event, iteration
             ),
             log_path=self.ert_config.analysis_config.log_path,
         )
diff --git a/src/ert/run_models/multiple_data_assimilation.py b/src/ert/run_models/multiple_data_assimilation.py
index cab0dc9ff27..9c0911605c1 100644
--- a/src/ert/run_models/multiple_data_assimilation.py
+++ b/src/ert/run_models/multiple_data_assimilation.py
@@ -2,6 +2,7 @@
 
 import functools
 import logging
+from queue import SimpleQueue
 from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
@@ -42,12 +43,14 @@ def __init__(
         prior_ensemble: Optional[EnsembleAccessor],
         es_settings: ESSettings,
         update_settings: UpdateSettings,
+        status_queue: SimpleQueue,
     ):
         super().__init__(
             simulation_arguments,
             config,
             storage,
             queue_config,
+            status_queue,
             phase_count=2,
         )
         self.weights = MultipleDataAssimilation.default_weights
@@ -208,7 +211,7 @@ def update(
             global_scaling=weight,
             rng=self.rng,
             progress_callback=functools.partial(
-                self.smoother_event_callback, prior_context.iteration
+                self.send_smoother_event, prior_context.iteration
             ),
             log_path=self.ert_config.analysis_config.log_path,
         )
diff --git a/src/ert/run_models/single_test_run.py b/src/ert/run_models/single_test_run.py
index fcde4b5b29f..13ea2299b3c 100644
--- a/src/ert/run_models/single_test_run.py
+++ b/src/ert/run_models/single_test_run.py
@@ -18,9 +18,12 @@ def __init__(
         simulation_arguments: SingleTestRunArguments,
         config: ErtConfig,
         storage: StorageAccessor,
+        status_queue,
     ):
         local_queue_config = config.queue_config.create_local_copy()
-        super().__init__(simulation_arguments, config, storage, local_queue_config)
+        super().__init__(
+            simulation_arguments, config, storage, local_queue_config, status_queue
+        )
 
     @staticmethod
     def checkHaveSufficientRealizations(
diff --git a/tests/integration_tests/cli/test_integration_cli.py b/tests/integration_tests/cli/test_integration_cli.py
index 4394d40a4f7..6d95a9f8dcd 100644
--- a/tests/integration_tests/cli/test_integration_cli.py
+++ b/tests/integration_tests/cli/test_integration_cli.py
@@ -3,17 +3,15 @@
 import fileinput
 import os
 import shutil
-import threading
 from argparse import ArgumentParser
 from pathlib import Path
 from textwrap import dedent
-from unittest.mock import Mock, call
+from unittest.mock import Mock
 
 import numpy as np
 import pandas as pd
 import pytest
 
-import ert.shared
 from ert import LibresFacade
 from ert.__main__ import ert_parser
 from ert.cli import (
@@ -30,17 +28,6 @@
 from ert.storage import open_storage
 
 
-@pytest.fixture(name="mock_cli_run")
-def fixture_mock_cli_run(monkeypatch):
-    mocked_monitor = Mock()
-    mocked_thread_start = Mock()
-    mocked_thread_join = Mock()
-    monkeypatch.setattr(threading.Thread, "start", mocked_thread_start)
-    monkeypatch.setattr(threading.Thread, "join", mocked_thread_join)
-    monkeypatch.setattr(ert.cli.monitor.Monitor, "monitor", mocked_monitor)
-    yield mocked_monitor, mocked_thread_join, mocked_thread_start
-
-
 @pytest.mark.scheduler
 @pytest.mark.integration_test
 def test_ensemble_evaluator(tmpdir, source_root, try_queue_and_scheduler, monkeypatch):
@@ -190,9 +177,7 @@ def test_ensemble_evaluator_disable_monitoring(
 
 @pytest.mark.scheduler
 @pytest.mark.integration_test
-def test_cli_test_run(
-    tmpdir, source_root, mock_cli_run, try_queue_and_scheduler, monkeypatch
-):
+def test_cli_test_run(tmpdir, source_root, try_queue_and_scheduler, monkeypatch):
     shutil.copytree(
         os.path.join(source_root, "test-data", "poly_example"),
         os.path.join(str(tmpdir), "poly_example"),
@@ -203,11 +188,6 @@ def test_cli_test_run(
     parsed = ert_parser(parser, [TEST_RUN_MODE, "poly_example/poly.ert"])
     run_cli(parsed)
 
-    monitor_mock, thread_join_mock, thread_start_mock = mock_cli_run
-    monitor_mock.assert_called_once()
-    thread_join_mock.assert_called_once()
-    thread_start_mock.assert_has_calls([[call(), call()]])
-
 
 @pytest.mark.scheduler
 @pytest.mark.integration_test
diff --git a/tests/integration_tests/status/test_tracking_integration.py b/tests/integration_tests/status/test_tracking_integration.py
index 596d6a65f9a..f3865a29b09 100644
--- a/tests/integration_tests/status/test_tracking_integration.py
+++ b/tests/integration_tests/status/test_tracking_integration.py
@@ -7,6 +7,7 @@
 from argparse import ArgumentParser
 from datetime import datetime
 from pathlib import Path
+from queue import SimpleQueue
 from textwrap import dedent
 
 import pytest
@@ -17,7 +18,6 @@
 from ert.cli import ENSEMBLE_EXPERIMENT_MODE, ENSEMBLE_SMOOTHER_MODE, TEST_RUN_MODE
 from ert.cli.model_factory import create_model
 from ert.config import ErtConfig
-from ert.ensemble_evaluator import EvaluatorTracker
 from ert.ensemble_evaluator.config import EvaluatorServerConfig
 from ert.ensemble_evaluator.event import (
     EndEvent,
@@ -33,6 +33,19 @@
 from ert.shared.feature_toggling import FeatureToggling
 
 
+class Events:
+    def __init__(self):
+        self.events = []
+        self.environment = []
+
+    def __iter__(self):
+        yield from self.events
+
+    def put(self, event):
+        self.events.append(event)
+        self.environment.append(os.environ.copy())
+
+
 def check_expression(original, path_expression, expected, msg_start):
     assert isinstance(original, dict), f"{msg_start}data is not a dict"
     jsonpath_expr = parse(path_expression)
@@ -185,10 +198,12 @@ def test_tracking(
     ert_config = ErtConfig.from_file(parsed.config)
     os.chdir(ert_config.config_path)
 
+    queue = SimpleQueue()
     model = create_model(
         ert_config,
         storage,
         parsed,
+        queue,
     )
 
     evaluator_server_config = EvaluatorServerConfig(
@@ -205,14 +220,12 @@ def test_tracking(
     )
     thread.start()
 
-    tracker = EvaluatorTracker(
-        model,
-        ee_con_info=evaluator_server_config.get_connection_info(),
-    )
-
     snapshots = {}
 
-    for event in tracker.track():
+    thread.join()
+
+    while not queue.empty():
+        event = queue.get()
         if isinstance(event, FullSnapshotEvent):
             snapshots[event.iteration] = event.snapshot
         if (
@@ -223,7 +236,7 @@ def test_tracking(
         if isinstance(event, EndEvent):
             pass
 
-    assert tracker._progress() == progress
+    # assert tracker._progress() == progress
 
     assert len(snapshots) == num_iters
     for snapshot in snapshots.values():
@@ -248,7 +261,6 @@ def test_tracking(
                 expected,
                 f"Snapshot {i} did not match:\n",
             )
-    thread.join()
     FeatureToggling.reset()
 
 
@@ -302,10 +314,12 @@ def test_setting_env_context_during_run(
     ert_config = ErtConfig.from_file(parsed.config)
     os.chdir(ert_config.config_path)
 
+    queue = Events()
     model = create_model(
         ert_config,
         storage,
         parsed,
+        queue,
     )
 
     evaluator_server_config = EvaluatorServerConfig(
@@ -321,22 +335,16 @@ def test_setting_env_context_during_run(
         args=(evaluator_server_config,),
     )
     thread.start()
-
-    tracker = EvaluatorTracker(
-        model,
-        ee_con_info=evaluator_server_config.get_connection_info(),
-    )
+    thread.join()
 
     expected = ["_ERT_SIMULATION_MODE", "_ERT_EXPERIMENT_ID", "_ERT_ENSEMBLE_ID"]
-    for event in tracker.track():
+    for event, environment in zip(queue.events, queue.environment):
         if isinstance(event, (FullSnapshotEvent, SnapshotUpdateEvent)):
-            assert model._context_env_keys == expected
             for key in expected:
-                assert key in os.environ
-            assert os.environ.get("_ERT_SIMULATION_MODE") == mode
+                assert key in environment
+            assert environment.get("_ERT_SIMULATION_MODE") == mode
         if isinstance(event, EndEvent):
             pass
-    thread.join()
 
     # Check environment is clean after the model run ends.
     assert not model._context_env_keys
@@ -389,11 +397,12 @@ def test_tracking_missing_ecl(
     ert_config = ErtConfig.from_file(parsed.config)
     os.chdir(ert_config.config_path)
-
+    events = Events()
     model = create_model(
         ert_config,
         storage,
         parsed,
+        events,
     )
 
     evaluator_server_config = EvaluatorServerConfig(
@@ -410,15 +419,10 @@ def test_tracking_missing_ecl(
     )
     with caplog.at_level(logging.ERROR):
         thread.start()
-
-        tracker = EvaluatorTracker(
-            model,
-            ee_con_info=evaluator_server_config.get_connection_info(),
-        )
-
+        thread.join()
         failures = []
 
-        for event in tracker.track():
+        for event in events:
             if isinstance(event, EndEvent):
                 failures.append(event)
     assert (
@@ -441,5 +445,4 @@ def test_tracking_missing_ecl(
         "iter-0/ECLIPSE_CASE"
     ) in failures[0].failed_msg
 
-    thread.join()
     FeatureToggling.reset()
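Review note: the `Events` helper above stands in for `SimpleQueue` purely by duck typing — `create_model()` only ever calls `put()`. Recording a copy of `os.environ` next to each event lets `test_setting_env_context_during_run` assert against the environment as it was when the event was produced, instead of inspecting `os.environ` after the now-joined simulation thread has already cleaned it up.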
diff --git a/tests/performance_tests/test_snapshot.py b/tests/performance_tests/test_snapshot.py
index 4ed27cf8a2a..3e0293620e7 100644
--- a/tests/performance_tests/test_snapshot.py
+++ b/tests/performance_tests/test_snapshot.py
@@ -17,8 +17,6 @@
 from ..unit_tests.gui.conftest import (  # noqa: F401
     active_realizations_fixture,
     large_snapshot,
-    mock_tracker,
-    runmodel,
 )
 from ..unit_tests.gui.simulation.test_run_dialog import test_large_snapshot
 
@@ -45,18 +43,14 @@ def test_snapshot_handling_of_forward_model_events(
 
 def test_gui_snapshot(
     benchmark,
-    runmodel,  # noqa: F811
     large_snapshot,  # noqa: F811
     qtbot,
-    mock_tracker,  # noqa: F811
 ):
     infinite_timeout = 100000
     benchmark(
         test_large_snapshot,
-        runmodel,
         large_snapshot,
         qtbot,
-        mock_tracker,
         timeout_per_iter=infinite_timeout,
     )
diff --git a/tests/unit_tests/cli/test_model_factory.py b/tests/unit_tests/cli/test_model_factory.py
index 375afa4102b..5e820c07ca0 100644
--- a/tests/unit_tests/cli/test_model_factory.py
+++ b/tests/unit_tests/cli/test_model_factory.py
@@ -45,13 +45,8 @@ def test_custom_realizations(poly_case):
     args = Namespace(realizations="0-4,7,8")
     ensemble_size = facade.get_ensemble_size()
     active_mask = [False] * ensemble_size
-    active_mask[0] = True
-    active_mask[1] = True
-    active_mask[2] = True
-    active_mask[3] = True
-    active_mask[4] = True
-    active_mask[7] = True
-    active_mask[8] = True
+    active_mask[0:5] = [True] * 5
+    active_mask[7:9] = [True] * 2
     assert model_factory._realizations(args, ensemble_size).tolist() == active_mask
 
 
@@ -60,6 +55,7 @@ def test_setup_single_test_run(poly_case, storage):
         poly_case,
         storage,
         Namespace(current_case="default", target_case=None, random_seed=None),
+        MagicMock(),
     )
     assert isinstance(model, SingleTestRun)
     assert model.simulation_arguments.current_case == "default"
@@ -79,6 +75,7 @@ def test_setup_ensemble_experiment(poly_case, storage):
         poly_case,
         storage,
         args,
+        MagicMock(),
     )
     assert isinstance(model, EnsembleExperiment)
 
@@ -94,7 +91,7 @@ def test_setup_ensemble_smoother(poly_case, storage):
     )
 
     model = model_factory._setup_ensemble_smoother(
-        poly_case, storage, args, MagicMock()
+        poly_case, storage, args, MagicMock(), MagicMock()
     )
     assert isinstance(model, EnsembleSmoother)
     assert model.simulation_arguments.current_case == "default"
@@ -115,7 +112,7 @@ def test_setup_multiple_data_assimilation(poly_case, storage):
     )
 
     model = model_factory._setup_multiple_data_assimilation(
-        poly_case, storage, args, MagicMock()
+        poly_case, storage, args, MagicMock(), MagicMock()
     )
     assert isinstance(model, MultipleDataAssimilation)
     assert model.simulation_arguments.weights == "6,4,2"
@@ -137,7 +134,7 @@ def test_setup_iterative_ensemble_smoother(poly_case, storage):
     )
 
     model = model_factory._setup_iterative_ensemble_smoother(
-        poly_case, storage, args, MagicMock()
+        poly_case, storage, args, MagicMock(), MagicMock()
     )
     assert isinstance(model, IteratedEnsembleSmoother)
     assert model.simulation_arguments.current_case == "default"
diff --git a/tests/unit_tests/cli/test_model_hook_order.py b/tests/unit_tests/cli/test_model_hook_order.py
index 6cf2fa1aab8..1849020ee87 100644
--- a/tests/unit_tests/cli/test_model_hook_order.py
+++ b/tests/unit_tests/cli/test_model_hook_order.py
@@ -64,6 +64,7 @@ def test_hook_call_order_ensemble_smoother(monkeypatch):
         MagicMock(),
         MagicMock(),
         MagicMock(),
+        MagicMock(),
     )
     test_class.ert = ert_mock
     test_class.run_ensemble_evaluator = MagicMock(return_value=[0])
@@ -110,6 +111,7 @@ def test_hook_call_order_es_mda(monkeypatch):
         prior_ensemble=None,
         es_settings=MagicMock(),
         update_settings=MagicMock(),
+        status_queue=MagicMock(),
     )
     ert_mock.runWorkflows = MagicMock()
     test_class.ert = ert_mock
@@ -150,6 +152,7 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch):
         MagicMock(),
         MagicMock(),
         MagicMock(),
+        MagicMock(),
     )
     test_class.run_ensemble_evaluator = MagicMock(return_value=[0])
     test_class.ert = ert_mock
diff --git a/tests/unit_tests/cli/test_run_context.py b/tests/unit_tests/cli/test_run_context.py
index be3eaacc4f5..3b9d513fb42 100644
--- a/tests/unit_tests/cli/test_run_context.py
+++ b/tests/unit_tests/cli/test_run_context.py
@@ -45,6 +45,7 @@ def test_that_all_iterations_gets_correct_name_and_iteration_number(
         prior_ensemble=None,
         es_settings=MagicMock(),
         update_settings=MagicMock(),
+        status_queue=MagicMock(),
     )
     test_class.run_ensemble_evaluator = MagicMock(return_value=[0])
     test_class.run_experiment(MagicMock())
diff --git a/tests/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py b/tests/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py
index 3c5efb70ec5..b8233fcd354 100644
--- a/tests/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py
+++ b/tests/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py
@@ -16,7 +16,7 @@
 
 def test_new_monitor_can_pick_up_where_we_left_off(evaluator):
-    evaluator._start_running()
+    evaluator.start_running()
     token = evaluator._config.token
     cert = evaluator._config.cert
     url = evaluator._config.url
@@ -121,7 +121,7 @@ def test_new_monitor_can_pick_up_where_we_left_off(evaluator):
 def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_evaluator(
     evaluator,
 ):
-    evaluator._start_running()
+    evaluator.start_running()
     conn_info = evaluator._config.get_connection_info()
     with Monitor(conn_info) as monitor:
         events = monitor.track()
@@ -216,7 +216,7 @@ def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_evaluat
 
 def test_ensure_multi_level_events_in_order(evaluator):
-    evaluator._start_running()
+    evaluator.start_running()
     config_info = evaluator._config.get_connection_info()
     with Monitor(config_info) as monitor:
         events = monitor.track()
@@ -277,7 +277,7 @@ def exploding_handler(events):
 
     evaluator._dispatcher.set_event_handler({"EXPLODING"}, exploding_handler)
 
-    evaluator._start_running()
+    evaluator.start_running()
     config_info = evaluator._config.get_connection_info()
     with Monitor(config_info) as monitor:
diff --git a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py
index 387a3559fb7..34d8b18dd5f 100644
--- a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py
+++ b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py
@@ -29,7 +29,7 @@ def test_run_legacy_ensemble(tmpdir, make_ensemble_builder, monkeypatch):
             generate_cert=False,
         )
         evaluator = EnsembleEvaluator(ensemble, config, 0)
-        evaluator._start_running()
+        evaluator.start_running()
         with Monitor(config) as monitor:
             for e in monitor.track():
                 if e["type"] in (
@@ -68,7 +68,7 @@ def test_run_and_cancel_legacy_ensemble(
 
         evaluator = EnsembleEvaluator(ensemble, config, 0)
 
-        evaluator._start_running()
+        evaluator.start_running()
         with Monitor(config) as mon:
             cancel = True
             with contextlib.suppress(
@@ -108,7 +108,7 @@ def test_run_legacy_ensemble_with_bare_exception(
         with patch.object(JobQueue, "add_realization") as faulty_queue:
             faulty_queue.side_effect = RuntimeError()
-            evaluator._start_running()
+            evaluator.start_running()
             with Monitor(config) as monitor:
                 for e in monitor.track():
                     if e.data is not None and e.data.get(identifiers.STATUS) in [
diff --git a/tests/unit_tests/gui/conftest.py b/tests/unit_tests/gui/conftest.py
index 3b597e85b4f..d6f52d70148 100644
--- a/tests/unit_tests/gui/conftest.py
+++ b/tests/unit_tests/gui/conftest.py
@@ -5,7 +5,6 @@
 import os.path
 import shutil
 import stat
-import time
 from datetime import datetime as dt
 from textwrap import dedent
 from typing import List, Type, TypeVar
@@ -337,45 +336,6 @@ def active_realizations_fixture() -> Mock:
     return active_reals
 
 
-@pytest.fixture
-def runmodel(active_realizations) -> Mock:
-    brm = Mock()
-    brm.get_runtime = Mock(return_value=100)
-    brm.hasRunFailed = Mock(return_value=False)
-    brm.getFailMessage = Mock(return_value="")
-    brm.support_restart = True
-    brm._simulation_arguments = {"active_realizations": active_realizations}
-    brm.has_failed_realizations = lambda: False
-    return brm
-
-
-class MockTracker:
-    def __init__(self, events) -> None:
-        self._events = events
-        self._is_running = True
-
-    def track(self):
-        for event in self._events:
-            if not self._is_running:
-                break
-            time.sleep(0.1)
-            yield event
-
-    def reset(self):
-        pass
-
-    def request_termination(self):
-        self._is_running = False
-
-
-@pytest.fixture
-def mock_tracker():
-    def _make_mock_tracker(events):
-        return MockTracker(events)
-
-    return _make_mock_tracker
-
-
 def load_results_manually(qtbot, gui, case_name="default"):
     def handle_load_results_dialog():
         dialog = wait_for_child(gui, qtbot, ClosableDialog)
diff --git a/tests/unit_tests/gui/simulation/test_run_dialog.py b/tests/unit_tests/gui/simulation/test_run_dialog.py
index 2878df44c5c..f78ebf67e48 100644
--- a/tests/unit_tests/gui/simulation/test_run_dialog.py
+++ b/tests/unit_tests/gui/simulation/test_run_dialog.py
@@ -1,6 +1,7 @@
 import os
 from pathlib import Path
-from unittest.mock import Mock, patch
+from queue import SimpleQueue
+from unittest.mock import MagicMock, Mock
 
 import pytest
 from pytestqt.qtbot import QtBot
@@ -25,15 +26,15 @@
 from tests.unit_tests.gui.simulation.test_run_path_dialog import handle_run_path_dialog
 
 
-def test_success(runmodel, qtbot: QtBot, mock_tracker):
+def test_success(qtbot: QtBot):
     notifier = Mock()
-    widget = RunDialog("poly.ert", runmodel, notifier)
+    queue = SimpleQueue()
+    widget = RunDialog("mock.ert", MagicMock(), queue, notifier)
     widget.show()
     qtbot.addWidget(widget)
 
-    with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker:
-        tracker.return_value = mock_tracker([EndEvent(failed=False, failed_msg="")])
-        widget.startSimulation()
+    widget.startSimulation()
+    queue.put(EndEvent(failed=False, failed_msg=""))
 
     with qtbot.waitExposed(widget, timeout=30000):
         qtbot.waitUntil(lambda: widget._total_progress_bar.value() == 100)
@@ -41,15 +42,15 @@ def test_success(runmodel, qtbot: QtBot, mock_tracker):
         assert widget.done_button.text() == "Done"
 
 
-def test_kill_simulations(runmodel, qtbot: QtBot, mock_tracker):
+def test_kill_simulations(qtbot: QtBot):
     notifier = Mock()
-    widget = RunDialog("poly.ert", runmodel, notifier)
+    queue = SimpleQueue()
+    widget = RunDialog("mock.ert", MagicMock(), queue, notifier)
     widget.show()
     qtbot.addWidget(widget)
 
-    with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker:
-        tracker.return_value = mock_tracker([EndEvent(failed=False, failed_msg="")])
-        widget.startSimulation()
+    widget.startSimulation()
+    queue.put(EndEvent(failed=False, failed_msg=""))
 
     with qtbot.waitSignal(widget.finished, timeout=30000):
@@ -78,16 +79,15 @@ def handle_dialog():
         widget.killJobs()
 
 
-def test_large_snapshot(
-    runmodel, large_snapshot, qtbot: QtBot, mock_tracker, timeout_per_iter=5000
-):
+def test_large_snapshot(large_snapshot, qtbot: QtBot, timeout_per_iter=5000):
     notifier = Mock()
-    widget = RunDialog("poly.ert", runmodel, notifier)
+    queue = SimpleQueue()
+    widget = RunDialog("mock.ert", MagicMock(), queue, notifier)
     widget.show()
     qtbot.addWidget(widget)
 
-    with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker:
-        iter_0 = FullSnapshotEvent(
+    events = [
+        FullSnapshotEvent(
             snapshot=large_snapshot,
             phase_name="Foo",
             current_phase=0,
@@ -95,8 +95,8 @@ def test_large_snapshot(
             progress=0.5,
             iteration=0,
             indeterminate=False,
-        )
-        iter_1 = FullSnapshotEvent(
+        ),
+        FullSnapshotEvent(
             snapshot=large_snapshot,
             phase_name="Foo",
             current_phase=0,
@@ -104,11 +104,13 @@ def test_large_snapshot(
             progress=0.5,
             iteration=1,
             indeterminate=False,
-        )
-        tracker.return_value = mock_tracker(
-            [iter_0, iter_1, EndEvent(failed=False, failed_msg="")]
-        )
-        widget.startSimulation()
+        ),
+        EndEvent(failed=False, failed_msg=""),
+    ]
+
+    widget.startSimulation()
+    for event in events:
+        queue.put(event)
 
     with qtbot.waitExposed(widget, timeout=timeout_per_iter * 6):
         qtbot.waitUntil(
@@ -317,15 +319,16 @@
         ),
     ],
 )
-def test_run_dialog(events, tab_widget_count, runmodel, qtbot: QtBot, mock_tracker):
+def test_run_dialog(events, tab_widget_count, qtbot: QtBot):
     notifier = Mock()
-    widget = RunDialog("poly.ert", runmodel, notifier)
+    queue = SimpleQueue()
+    widget = RunDialog("mock.ert", MagicMock(), queue, notifier)
     widget.show()
     qtbot.addWidget(widget)
 
-    with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker:
-        tracker.return_value = mock_tracker(events)
-        widget.startSimulation()
+    widget.startSimulation()
+    for event in events:
+        queue.put(event)
 
     with qtbot.waitExposed(widget, timeout=30000):
         qtbot.mouseClick(widget.show_details_button, Qt.LeftButton)
@@ -477,17 +480,16 @@ def handle_dialog():
         ),
     ],
 )
-def test_run_dialog_memory_usage_showing(
-    events, tab_widget_count, runmodel, qtbot: QtBot, mock_tracker
-):
+def test_run_dialog_memory_usage_showing(events, tab_widget_count, qtbot: QtBot):
     notifier = Mock()
-    widget = RunDialog("poly.ert", runmodel, notifier)
+    queue = SimpleQueue()
+    widget = RunDialog("poly.ert", MagicMock(), queue, notifier)
     widget.show()
     qtbot.addWidget(widget)
 
-    with patch("ert.gui.simulation.run_dialog.EvaluatorTracker") as tracker:
-        tracker.return_value = mock_tracker(events)
-        widget.startSimulation()
+    widget.startSimulation()
+    for event in events:
+        queue.put(event)
 
     with qtbot.waitExposed(widget, timeout=30000):
         qtbot.mouseClick(widget.show_details_button, Qt.LeftButton)
diff --git a/tests/unit_tests/run_models/test_ensemble_experiment.py b/tests/unit_tests/run_models/test_ensemble_experiment.py
index a16154e6b18..78e9141b748 100644
--- a/tests/unit_tests/run_models/test_ensemble_experiment.py
+++ b/tests/unit_tests/run_models/test_ensemble_experiment.py
@@ -46,7 +46,7 @@ def get_run_path_mock(realizations, iteration=None):
     EnsembleExperiment.validate = MagicMock()
 
     ensemble_experiment = EnsembleExperiment(
-        simulation_arguments, MagicMock(), None, None
+        simulation_arguments, MagicMock(), None, None, MagicMock()
     )
     ensemble_experiment.run_paths.get_paths = get_run_path_mock
     ensemble_experiment.facade = MagicMock(