Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Have ensemble fail early when problems on initial connection to ensemble evaluator #9139

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,14 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None:
self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue()
self._connection: Optional[WebSocketClientProtocol] = None
self._receiver_task: Optional[asyncio.Task[None]] = None
self._connected: asyncio.Event = asyncio.Event()
self._connected: asyncio.Future[None] = asyncio.Future()
self._connection_timeout: float = 120.0
self._receiver_timeout: float = 60.0

async def __aenter__(self) -> "Monitor":
self._receiver_task = asyncio.create_task(self._receiver())
try:
await asyncio.wait_for(
self._connected.wait(), timeout=self._connection_timeout
)
await asyncio.wait_for(self._connected, timeout=self._connection_timeout)
except asyncio.TimeoutError as exc:
msg = "Couldn't establish connection with the ensemble evaluator!"
logger.error(msg)
Expand All @@ -64,7 +62,6 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None
self._receiver_task,
return_exceptions=True,
)

if self._connection:
await self._connection.close()

Expand Down Expand Up @@ -127,13 +124,16 @@ async def _receiver(self) -> None:
headers = Headers()
if self._ee_con_info.token:
headers["token"] = self._ee_con_info.token

await wait_for_evaluator(
base_url=self._ee_con_info.url,
token=self._ee_con_info.token,
cert=self._ee_con_info.cert,
timeout=5,
)
try:
await wait_for_evaluator(
base_url=self._ee_con_info.url,
token=self._ee_con_info.token,
cert=self._ee_con_info.cert,
timeout=5,
)
except Exception as e:
self._connected.set_exception(e)
return
async for conn in connect(
self._ee_con_info.client_uri,
ssl=tls,
Expand All @@ -147,13 +147,13 @@ async def _receiver(self) -> None:
):
try:
self._connection = conn
self._connected.set()
self._connected.set_result(None)
async for raw_msg in self._connection:
event = event_from_json(raw_msg)
await self._event_queue.put(event)
except (ConnectionRefusedError, ConnectionClosed, ClientError) as exc:
self._connection = None
self._connected.clear()
self._connected = asyncio.Future()
logger.debug(
f"Monitor connection to EnsembleEvaluator went down, reconnecting: {exc}"
)
22 changes: 22 additions & 0 deletions tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
import logging
from http import HTTPStatus
from typing import NoReturn
from urllib.parse import urlparse

import pytest
from websockets import server
from websockets.exceptions import ConnectionClosedOK

import ert
import ert.ensemble_evaluator
from _ert.events import EEUserCancel, EEUserDone, event_from_json
from ert.ensemble_evaluator import Monitor
from ert.ensemble_evaluator.config import EvaluatorConnectionInfo
Expand Down Expand Up @@ -135,3 +138,22 @@ async def test_that_monitor_can_emit_heartbeats(unused_tcp_port):

set_when_done.set() # shuts down websocket server
await websocket_server_task


@pytest.mark.timeout(10)
async def test_that_monitor_will_raise_exception_if_wait_for_evaluator_fails(
    monkeypatch,
):
    """Entering a ``Monitor`` re-raises a failure from ``wait_for_evaluator``.

    ``wait_for_evaluator`` in ``ert.ensemble_evaluator.monitor`` is replaced
    with a coroutine that always raises ``ValueError``. The test passes only
    if that exact error (identified by its message) surfaces from
    ``async with Monitor(...)``.
    """
    # A distinctive message lets the assertion below tell the mocked failure
    # apart from any unrelated ValueError raised while setting up the Monitor.
    failure_message = "mocked wait_for_evaluator failure"

    async def mock_failing_wait_for_evaluator(*args, **kwargs) -> NoReturn:
        raise ValueError(failure_message)

    monkeypatch.setattr(
        ert.ensemble_evaluator.monitor,
        "wait_for_evaluator",
        mock_failing_wait_for_evaluator,
    )
    ee_con_info = EvaluatorConnectionInfo("")

    with pytest.raises(ValueError, match=failure_message):
        async with Monitor(ee_con_info):
            pass