Remove websockets communication between scheduler and ensemble evaluator
It is replaced by two message queues, which reside in LegacyEnsemble (see the sketch after this list):
- scheduler_queue: provides CloudEvents (representing realization and driver events) to the evaluator
- manifest_queue: provides CloudEvents (representing manifest checksum notification events) to the scheduler
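For orientation, here is a minimal, self-contained sketch of the new in-process wiring, based on the diff below. Only scheduler_queue, manifest_queue, and the Scheduler constructor's new queue parameters come from this commit; SchedulerSketch, wire_up, the dict-shaped events, and the CHECKSUM_EVENT_TYPE string are illustrative stand-ins for the real CloudEvent-based classes and event-type constants.

from __future__ import annotations

import asyncio
from typing import Any, Dict

# Placeholder for EVTYPE_FORWARD_MODEL_CHECKSUM; the real string value differs.
CHECKSUM_EVENT_TYPE = "forward_model.checksum"


class SchedulerSketch:
    """Stand-in for ert.scheduler.Scheduler: reads manifest events, publishes job events."""

    def __init__(
        self,
        manifest_queue: asyncio.Queue[Any],  # checksum events coming from the evaluator
        ee_queue: asyncio.Queue[Any],        # realization/driver events going to the evaluator
    ) -> None:
        self._manifest_queue = manifest_queue
        self._ee_queue = ee_queue
        self.checksum: Dict[str, Dict[str, Any]] = {}

    async def publish(self, event: Dict[str, Any]) -> None:
        # Replaces the old websocket _publisher: hand the event straight to the evaluator.
        await self._ee_queue.put(event)

    async def consume_manifest(self) -> None:
        # Replaces the old websocket _checksum_consumer.
        while True:
            event = await self._manifest_queue.get()
            if event.get("type") == CHECKSUM_EVENT_TYPE:
                self.checksum.update(event.get("data", {}))
            self._manifest_queue.task_done()


async def wire_up() -> None:
    # The evaluator owns both queues and hands them to the ensemble, which passes
    # them on to the Scheduler (mirroring _start_running() in evaluator.py).
    scheduler_queue: asyncio.Queue[Any] = asyncio.Queue()  # evaluator's self._events
    manifest_queue: asyncio.Queue[Any] = asyncio.Queue()   # evaluator's self._manifest_queue
    scheduler = SchedulerSketch(manifest_queue, scheduler_queue)

    consumer = asyncio.create_task(scheduler.consume_manifest())
    await manifest_queue.put(
        {"type": CHECKSUM_EVENT_TYPE, "data": {"/some/run_path": {"file.txt": "sha256..."}}}
    )
    await manifest_queue.join()  # wait until the consumer has processed the event
    consumer.cancel()

    await scheduler.publish({"type": "realization.success", "iens": 0})
    print(await scheduler_queue.get())  # the evaluator would consume this event
    print(scheduler.checksum)           # {'/some/run_path': {'file.txt': 'sha256...'}}


asyncio.run(wire_up())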
xjules committed Aug 15, 2024
1 parent bb5b4fc commit ca4ea6a
Showing 7 changed files with 89 additions and 263 deletions.
12 changes: 6 additions & 6 deletions src/ert/ensemble_evaluator/_ensemble.py
@@ -29,12 +29,7 @@
from ._wait_for_evaluator import wait_for_evaluator
from .config import EvaluatorServerConfig
from .identifiers import EVTYPE_ENSEMBLE_FAILED, EVTYPE_ENSEMBLE_STARTED
from .snapshot import (
ForwardModel,
RealizationSnapshot,
Snapshot,
SnapshotDict,
)
from .snapshot import ForwardModel, RealizationSnapshot, Snapshot, SnapshotDict
from .state import (
ENSEMBLE_STATE_CANCELLED,
ENSEMBLE_STATE_FAILED,
@@ -125,6 +120,9 @@ def __post_init__(self) -> None:
else:
self._status_tracker = _EnsembleStateTracker()

self.scheduler_queue: Any = None
self.manifest_queue: Any = None

@property
def active_reals(self) -> Sequence[Realization]:
return list(filter(lambda real: real.active, self.reals))
@@ -243,6 +241,8 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
self._scheduler = Scheduler(
driver,
self.active_reals,
self.manifest_queue,
self.scheduler_queue,
max_submit=self._queue_config.max_submit,
max_running=self._queue_config.max_running,
submit_sleep=self._queue_config.submit_sleep,
19 changes: 18 additions & 1 deletion src/ert/ensemble_evaluator/evaluator.py
@@ -1,6 +1,7 @@
import asyncio
import logging
import pickle
import traceback
from contextlib import asynccontextmanager, contextmanager
from http import HTTPStatus
from typing import (
@@ -68,6 +69,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):

self._events: asyncio.Queue[CloudEvent] = asyncio.Queue()
self._messages_to_send: asyncio.Queue[str] = asyncio.Queue()
self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue()

self._result = None

@@ -311,9 +313,11 @@ async def forward_checksum(self, event: CloudEvent) -> None:
},
{event["run_path"]: event.data},
)
# currently clients still need to receive events via ws
await self._messages_to_send.put(
to_json(forward_event, data_marshaller=evaluator_marshaller).decode()
)
await self._manifest_queue.put(forward_event)

async def connection_handler(
self, websocket: WebSocketServerProtocol, path: str
@@ -417,6 +421,9 @@ async def _start_running(self) -> None:
]
# now we wait for the server to actually start
await self._server_started.wait()
# setup message queues
self._ensemble.scheduler_queue = self._events
self._ensemble.manifest_queue = self._manifest_queue
# let's run
self._ee_tasks.append(
asyncio.create_task(
@@ -433,7 +440,17 @@ async def _monitor_and_handle_tasks(self) -> None:
)
for task in done:
if task_exception := task.exception():
logger.error((f"Exception in evaluator task: {task_exception}"))
exc_traceback = "".join(
traceback.format_exception(
None, task_exception, task_exception.__traceback__
)
)
logger.error(
(
f"Exception in evaluator task {task.get_name()}: {task_exception}\n"
f"Traceback: {exc_traceback}"
)
)
raise task_exception
elif task.get_name() == "server_task":
return
7 changes: 3 additions & 4 deletions src/ert/scheduler/job.py
@@ -10,7 +10,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
from lxml import etree

@@ -164,7 +163,7 @@ async def run(self, sem: asyncio.BoundedSemaphore, max_submit: int = 1) -> None:
break

if self.returncode.result() == 0:
if self._scheduler.wait_for_checksum():
if self._scheduler._manifest_queue:
await self._verify_checksum()
await self._handle_finished_forward_model()
break
@@ -191,7 +190,7 @@ async def _max_runtime_task(self) -> None:
}
)
assert self._scheduler._events is not None
await self._scheduler._events.put(to_json(timeout_event))
await self._scheduler._events.put(timeout_event)
logger.error(
f"Realization {self.iens} stopped due to MAX_RUNTIME={self.real.max_runtime} seconds"
)
@@ -307,7 +306,7 @@ async def _send(self, state: State) -> None:
"queue_event_type": status,
},
)
await self._scheduler._events.put(to_json(event))
await self._scheduler._events.put(event)


def log_info_from_exit_file(exit_file_path: Path) -> None:
134 changes: 21 additions & 113 deletions src/ert/scheduler/scheduler.py
@@ -4,7 +4,6 @@
import json
import logging
import os
import ssl
import time
import traceback
from collections import defaultdict
@@ -21,12 +20,7 @@
Sequence,
)

from aiohttp import ClientError
from cloudevents.exceptions import DataUnmarshallerError
from cloudevents.http import from_json
from pydantic.dataclasses import dataclass
from websockets import ConnectionClosed, Headers
from websockets.client import connect

from _ert.async_utils import get_running_loop
from ert.constant_filenames import CERT_FILE
@@ -35,7 +29,6 @@
EVTYPE_ENSEMBLE_SUCCEEDED,
EVTYPE_FORWARD_MODEL_CHECKSUM,
)
from ert.serialization import evaluator_unmarshaller

from .driver import Driver
from .event import FinishedEvent
@@ -47,8 +40,6 @@

logger = logging.getLogger(__name__)

CLOSE_PUBLISHER_SENTINEL = object()


@dataclass
class _JobsJson:
@@ -82,6 +73,8 @@ def __init__(
self,
driver: Driver,
realizations: Optional[Sequence[Realization]] = None,
manifest_queue: Optional[asyncio.Queue[Any]] = None,
ee_queue: Optional[asyncio.Queue[Any]] = None,
*,
max_submit: int = 1,
max_running: int = 1,
@@ -92,6 +85,9 @@ def __init__(
ee_token: Optional[str] = None,
) -> None:
self.driver = driver
self._ee_queue = ee_queue
self._manifest_queue = manifest_queue

self._job_tasks: MutableMapping[int, asyncio.Task[None]] = {}

self.submit_sleep_state: Optional[SubmitSleeper] = None
@@ -116,30 +112,12 @@ def __init__(
)
self._max_submit = max_submit
self._max_running = max_running

self._ee_uri = ee_uri
self._ens_id = ens_id
self._ee_cert = ee_cert
self._ee_token = ee_token
self._publisher_done = asyncio.Event()
# this timeout makes sure we won't wait for the queue and the sentinel indefinitely
self._queue_timeout: float = 10.0
self._consumer_started = asyncio.Event()
self.checksum: Dict[str, Dict[str, Any]] = {}
self.checksum_listener: Optional[asyncio.Task[None]] = None

async def start_manifest_listener(self) -> Optional[asyncio.Task[None]]:
if self._ee_uri is None or "dispatch" not in self._ee_uri:
return None

self.checksum_listener = asyncio.create_task(
self._checksum_consumer(), name="consumer_task"
)
await self._consumer_started.wait()
return self.checksum_listener

def wait_for_checksum(self) -> bool:
return self._consumer_started.is_set()
self.checksum: Dict[str, Dict[str, Any]] = {}

def kill_all_jobs(self) -> None:
assert self._loop
@@ -209,79 +187,21 @@ def count_states(self) -> Dict[JobState, int]:
return counts

async def _checksum_consumer(self) -> None:
if not self._ee_uri:
if not self._manifest_queue:
return
tls: Optional[ssl.SSLContext] = None
if self._ee_cert:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self._ee_cert)
headers = Headers()
if self._ee_token:
headers["token"] = self._ee_token
event = None
async for conn in connect(
self._ee_uri.replace("dispatch", "client"),
ssl=tls,
extra_headers=headers,
max_size=2**26,
max_queue=500,
open_timeout=5,
ping_timeout=60,
ping_interval=60,
close_timeout=60,
):
try:
self._consumer_started.set()
async for message in conn:
try:
event = from_json(
str(message), data_unmarshaller=evaluator_unmarshaller
)
if event["type"] == EVTYPE_FORWARD_MODEL_CHECKSUM:
self.checksum.update(event.data)
except DataUnmarshallerError:
logger.error(
"Scheduler checksum consumer received unknown message"
)
except (ConnectionRefusedError, ConnectionClosed, ClientError) as exc:
self._consumer_started.clear()
logger.debug(
f"Scheduler connection to EnsembleEvaluator went down: {exc}"
)
while True:
event = await self._manifest_queue.get()
if event["type"] == EVTYPE_FORWARD_MODEL_CHECKSUM:
self.checksum.update(event.data)
self._manifest_queue.task_done()

async def _publisher(self) -> None:
if not self._ee_uri:
if not self._ee_queue:
return
tls: Optional[ssl.SSLContext] = None
if self._ee_cert:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self._ee_cert)
headers = Headers()
if self._ee_token:
headers["token"] = self._ee_token
event = None
async for conn in connect(
self._ee_uri,
ssl=tls,
extra_headers=headers,
open_timeout=60,
ping_timeout=60,
ping_interval=60,
close_timeout=60,
):
try:
while True:
if event is None:
event = await self._events.get()
if event == CLOSE_PUBLISHER_SENTINEL:
self._publisher_done.set()
else:
await conn.send(event)
event = None
self._events.task_done()
except ConnectionClosed:
logger.debug("Connection to EnsembleEvalutor went down, reconnecting.")
continue
while True:
event = await self._events.get()
await self._ee_queue.put(event)
self._events.task_done()

def add_dispatch_information_to_jobs_file(self) -> None:
for job in self._jobs.values():
@@ -319,19 +239,9 @@ async def _monitor_and_handle_tasks(
raise task_exception

if not self.is_active():
if self._ee_uri is not None:
try:
await self._events.put(CLOSE_PUBLISHER_SENTINEL)
await asyncio.wait_for(
self._publisher_done.wait(), timeout=self._queue_timeout
)
await asyncio.wait_for(
self._events.join(), timeout=self._queue_timeout
)
except asyncio.TimeoutError:
logger.error(
f"{self._events.qsize()} items left unprocessed in the queue!"
)
if self._ee_queue:
# only join queue if there is a consumer
await self._events.join()
for task in self._job_tasks.values():
if task.cancelled():
continue
@@ -343,16 +253,14 @@ async def execute(
self,
min_required_realizations: int = 0,
) -> str:
listener_task = await self.start_manifest_listener()
scheduling_tasks = [
asyncio.create_task(self._publisher(), name="publisher_task"),
asyncio.create_task(
self._process_event_queue(), name="process_event_queue_task"
),
asyncio.create_task(self.driver.poll(), name="poll_task"),
asyncio.create_task(self._checksum_consumer(), name="consumer_task"),
]
if listener_task is not None:
scheduling_tasks.append(listener_task)

if min_required_realizations > 0:
scheduling_tasks.append(