Skip to content

Commit

Permalink
Add event queue and cancel functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
sondreso committed Jan 25, 2024
1 parent a620ea3 commit bbf5aa7
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 52 deletions.
10 changes: 5 additions & 5 deletions src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import contextlib
import logging
import os
import queue
import sys
import threading
import time
from typing import Any, TextIO

from ert.cli import (
Expand Down Expand Up @@ -89,12 +89,15 @@ def run_cli(args: Namespace, _: Any = None) -> None:
observations=ert_config.observations,
)

status_queue = queue.SimpleQueue()

try:
model = create_model(
ert_config,
storage,
args,
experiment.id,
status_queue,
)
except ValueError as e:
raise ErtCliError(e) from e
Expand Down Expand Up @@ -132,12 +135,9 @@ def run_cli(args: Namespace, _: Any = None) -> None:
else:
out = sys.stderr
monitor = Monitor(out=out, color_always=args.color_always)
monitor.start()
model.add_send_event_callback(monitor.on_event)
thread.start()
try:
while not monitor.done:
time.sleep(0.5)
monitor.monitor(status_queue)
except (SystemExit, KeyboardInterrupt):
print("\nKilling simulations...")
# tracker.request_termination()
Expand Down
34 changes: 27 additions & 7 deletions src/ert/cli/model_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from queue import SimpleQueue
from typing import TYPE_CHECKING
from uuid import UUID

Expand Down Expand Up @@ -55,6 +56,7 @@ def create_model(
storage: StorageAccessor,
args: Namespace,
experiment_id: UUID,
status_queue: SimpleQueue,
) -> BaseRunModel:
logger = logging.getLogger(__name__)
logger.info(
Expand All @@ -75,28 +77,36 @@ def create_model(
)

if args.mode == TEST_RUN_MODE:
return _setup_single_test_run(config, storage, args, experiment_id)
return _setup_single_test_run(
config, storage, args, experiment_id, status_queue
)
elif args.mode == ENSEMBLE_EXPERIMENT_MODE:
return _setup_ensemble_experiment(config, storage, args, experiment_id)
return _setup_ensemble_experiment(
config, storage, args, experiment_id, status_queue
)
elif args.mode == ENSEMBLE_SMOOTHER_MODE:
return _setup_ensemble_smoother(
config, storage, args, experiment_id, update_settings
config, storage, args, experiment_id, update_settings, status_queue
)
elif args.mode == ES_MDA_MODE:
return _setup_multiple_data_assimilation(
config, storage, args, experiment_id, update_settings
config, storage, args, experiment_id, update_settings, status_queue
)
elif args.mode == ITERATIVE_ENSEMBLE_SMOOTHER_MODE:
return _setup_iterative_ensemble_smoother(
config, storage, args, experiment_id, update_settings
config, storage, args, experiment_id, update_settings, status_queue
)

else:
raise NotImplementedError(f"Run type not supported {args.mode}")


def _setup_single_test_run(
config: ErtConfig, storage: StorageAccessor, args: Namespace, experiment_id: UUID
config: ErtConfig,
storage: StorageAccessor,
args: Namespace,
experiment_id: UUID,
status_queue: SimpleQueue,
) -> SingleTestRun:
return SingleTestRun(
SingleTestRunArguments(
Expand All @@ -113,7 +123,11 @@ def _setup_single_test_run(


def _setup_ensemble_experiment(
config: ErtConfig, storage: StorageAccessor, args: Namespace, experiment_id: UUID
config: ErtConfig,
storage: StorageAccessor,
args: Namespace,
experiment_id: UUID,
status_queue: SimpleQueue,
) -> EnsembleExperiment:
min_realizations_count = config.analysis_config.minimum_required_realizations
active_realizations = _realizations(args, config.model_config.num_realizations)
Expand All @@ -140,6 +154,7 @@ def _setup_ensemble_experiment(
storage,
config.queue_config,
experiment_id,
status_queue,
)


Expand All @@ -149,6 +164,7 @@ def _setup_ensemble_smoother(
args: Namespace,
experiment_id: UUID,
update_settings: UpdateSettings,
status_queue: SimpleQueue,
) -> EnsembleSmoother:
return EnsembleSmoother(
ESRunArguments(
Expand All @@ -168,6 +184,7 @@ def _setup_ensemble_smoother(
experiment_id,
es_settings=config.analysis_config.es_module,
update_settings=update_settings,
status_queue=status_queue,
)


Expand All @@ -177,6 +194,7 @@ def _setup_multiple_data_assimilation(
args: Namespace,
experiment_id: UUID,
update_settings: UpdateSettings,
status_queue: SimpleQueue,
) -> MultipleDataAssimilation:
# Because the configuration of the CLI is different from the gui, we
# have a different way to get the restart information.
Expand Down Expand Up @@ -207,6 +225,7 @@ def _setup_multiple_data_assimilation(
prior_ensemble,
es_settings=config.analysis_config.es_module,
update_settings=update_settings,
status_queue=status_queue,
)


Expand All @@ -216,6 +235,7 @@ def _setup_iterative_ensemble_smoother(
args: Namespace,
id_: UUID,
update_settings: UpdateSettings,
status_queue: SimpleQueue,
) -> IteratedEnsembleSmoother:
return IteratedEnsembleSmoother(
SIESRunArguments(
Expand Down
36 changes: 18 additions & 18 deletions src/ert/cli/monitor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import sys
from datetime import datetime, timedelta
from queue import SimpleQueue
from typing import Dict, Optional, TextIO, Tuple

from tqdm import tqdm
Expand All @@ -18,7 +19,6 @@
FORWARD_MODEL_STATE_FAILURE,
REAL_STATE_TO_COLOR,
)
from ert.run_models.base_run_model import StatusEvents
from ert.shared.status.utils import format_running_time

Color = Tuple[int, int, int]
Expand Down Expand Up @@ -60,25 +60,25 @@ def __init__(self, out: TextIO = sys.stdout, color_always: bool = False) -> None
self.dot = ""
self.done = False

def start(self) -> None:
self._start_time = datetime.now()

def on_event(
def monitor(
self,
event: StatusEvents,
event_queue: SimpleQueue,
) -> None:
if isinstance(event, FullSnapshotEvent):
if event.snapshot is not None:
self._snapshots[event.iteration] = event.snapshot
self._progress = event.progress
elif isinstance(event, SnapshotUpdateEvent):
if event.partial_snapshot is not None:
self._snapshots[event.iteration].merge_event(event.partial_snapshot)
self._print_progress(event)
if isinstance(event, EndEvent):
self._print_result(event.failed, event.failed_msg)
self._print_job_errors()
self.done = True
self._start_time = datetime.now()
while True:
event = event_queue.get()
if isinstance(event, FullSnapshotEvent):
if event.snapshot is not None:
self._snapshots[event.iteration] = event.snapshot
self._progress = event.progress
elif isinstance(event, SnapshotUpdateEvent):
if event.partial_snapshot is not None:
self._snapshots[event.iteration].merge_event(event.partial_snapshot)
self._print_progress(event)
if isinstance(event, EndEvent):
self._print_result(event.failed, event.failed_msg)
self._print_job_errors()
return

def _print_job_errors(self) -> None:
failed_jobs: Dict[Optional[str], int] = {}
Expand Down
56 changes: 56 additions & 0 deletions src/ert/gui/simulation/queue_emitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging
from queue import SimpleQueue

from qtpy.QtCore import QObject, Signal, Slot

from ert.ensemble_evaluator import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent
from ert.gui.model.snapshot import SnapshotModel

logger = logging.getLogger(__name__)


class QueueEmitter(QObject):
"""A worker that emits items put on a queue to qt subscribers."""

new_event = Signal(object)
done = Signal()

def __init__(
self,
event_queue: SimpleQueue,
parent=None,
):
super().__init__(parent)
logger.debug("init QueueEmitter")
self._event_queue = event_queue
self._stopped = False

@Slot()
def consume_and_emit(self):
logger.debug("tracking...")
while True:
event = self._event_queue.get()
if self._stopped:
logger.debug("stopped")
break

# pre-rendering in this thread to avoid work in main rendering thread
if isinstance(event, FullSnapshotEvent) and event.snapshot:
SnapshotModel.prerender(event.snapshot)
elif isinstance(event, SnapshotUpdateEvent) and event.partial_snapshot:
SnapshotModel.prerender(event.partial_snapshot)

logger.debug(f"emit {event}")
self.new_event.emit(event)

if isinstance(event, EndEvent):
logger.debug("got end event")
break

self.done.emit()
logger.debug("tracking done.")

@Slot()
def stop(self):
logger.debug("stopping...")
self._stopped = True
27 changes: 19 additions & 8 deletions src/ert/gui/simulation/run_dialog.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from queue import SimpleQueue
from threading import Thread
from typing import Optional

from PyQt5.QtWidgets import QAbstractItemView
from qtpy.QtCore import QModelIndex, QSize, Qt, QTimer, Signal, Slot
from qtpy.QtCore import QModelIndex, QSize, Qt, QThread, QTimer, Signal, Slot
from qtpy.QtGui import QMovie
from qtpy.QtWidgets import (
QDialog,
Expand Down Expand Up @@ -42,6 +43,7 @@
)
from ert.shared.status.utils import format_running_time

from .queue_emitter import QueueEmitter
from .view import LegendView, ProgressView, RealizationWidget, UpdateWidget

_TOTAL_PROGRESS_TEMPLATE = "Total progress {total_progress}% — {phase_name}"
Expand All @@ -55,6 +57,7 @@ def __init__(
self,
config_file: str,
run_model: BaseRunModel,
event_queue: SimpleQueue,
notifier: ErtNotifier,
parent=None,
):
Expand All @@ -66,6 +69,7 @@ def __init__(

self._snapshot_model = SnapshotModel(self)
self._run_model = run_model
self._event_queue = event_queue
self._notifier = notifier

self._isDetailedDialog = False
Expand Down Expand Up @@ -174,7 +178,6 @@ def __init__(
self._setSimpleDialog()
self.finished.connect(self._on_finished)

self._run_model.add_send_event_callback(self.on_run_model_event.emit)
self.on_run_model_event.connect(self._on_event)

def _current_tab_changed(self, index: int) -> None:
Expand Down Expand Up @@ -274,8 +277,20 @@ def startSimulation(self):
args=(evaluator_server_config,),
)

simulation_thread.start()
worker = QueueEmitter(self._event_queue)
worker_thread = QThread()
self._worker = worker
self._worker_thread = worker_thread

worker.done.connect(worker_thread.quit)
worker.new_event.connect(self._on_event)
worker.moveToThread(worker_thread)
self.simulation_done.connect(worker.stop)
worker_thread.started.connect(worker.consume_and_emit)

self._ticker.start(1000)
self._worker_thread.start()
simulation_thread.start()
self._notifier.set_is_simulation_running(True)

def killJobs(self):
Expand All @@ -287,9 +302,7 @@ def killJobs(self):
if kill_job == QMessageBox.Yes:
# Normally this slot would be invoked by the signal/slot system,
# but the worker is busy tracking the evaluation.
# self._tracker.request_termination()
# self._worker_thread.quit()
# self._worker_thread.wait()
self._run_model.cancel()
self._on_finished()
self.finished.emit(-1)
return kill_job
Expand Down Expand Up @@ -329,7 +342,6 @@ def _on_event(self, event: object):
self._show_done_button()
elif isinstance(event, FullSnapshotEvent):
if event.snapshot is not None:
SnapshotModel.prerender(event.snapshot)
self._snapshot_model._add_snapshot(event.snapshot, event.iteration)
self._progress_view.setIndeterminate(event.indeterminate)
progress = int(event.progress * 100)
Expand All @@ -343,7 +355,6 @@ def _on_event(self, event: object):

elif isinstance(event, SnapshotUpdateEvent):
if event.partial_snapshot is not None:
SnapshotModel.prerender(event.partial_snapshot)
self._snapshot_model._add_partial_snapshot(
event.partial_snapshot, event.iteration
)
Expand Down
Loading

0 comments on commit bbf5aa7

Please sign in to comment.