Skip to content

Commit 45d5048

Browse files
Communication: Use zmq.Poller() rather than waiting infinitely (#789)
* Communication: Use zmq.Poller() rather than waiting infinitely
* Format black
* fix spelling
* Add test

---------

Co-authored-by: pyiron-runner <pyiron@mpie.de>
1 parent ca65c40 commit 45d5048

File tree

2 files changed: 51 additions (+51) and 5 deletions (-5)

executorlib/standalone/interactive/communication.py

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,35 +1,46 @@
11
import logging
22
import sys
33
from socket import gethostname
4-
from typing import Optional
4+
from typing import Any, Optional
55

66
import cloudpickle
77
import zmq
88

99

10+
class ExecutorlibSocketError(RuntimeError):
11+
pass
12+
13+
1014
class SocketInterface:
1115
"""
1216
The SocketInterface is an abstraction layer on top of the zero message queue.
1317
1418
Args:
1519
spawner (executorlib.shared.spawner.BaseSpawner): Interface for starting the parallel process
1620
log_obj_size (boolean): Enable debug mode which reports the size of the communicated objects.
21+
time_out_ms (int): Time out for waiting for a message on socket in milliseconds.
1722
"""
1823

19-
def __init__(self, spawner=None, log_obj_size=False):
24+
def __init__(
25+
self, spawner=None, log_obj_size: bool = False, time_out_ms: int = 1000
26+
):
2027
"""
2128
Initialize the SocketInterface.
2229
2330
Args:
2431
spawner (executorlib.shared.spawner.BaseSpawner): Interface for starting the parallel process
32+
log_obj_size (boolean): Enable debug mode which reports the size of the communicated objects.
33+
time_out_ms (int): Time out for waiting for a message on socket in milliseconds.
2534
"""
2635
self._context = zmq.Context()
2736
self._socket = self._context.socket(zmq.PAIR)
37+
self._poller = zmq.Poller()
38+
self._poller.register(self._socket, zmq.POLLIN)
2839
self._process = None
40+
self._time_out_ms = time_out_ms
41+
self._logger: Optional[logging.Logger] = None
2942
if log_obj_size:
3043
self._logger = logging.getLogger("executorlib")
31-
else:
32-
self._logger = None
3344
self._spawner = spawner
3445

3546
def send_dict(self, input_dict: dict):
@@ -52,7 +63,12 @@ def receive_dict(self) -> dict:
5263
Returns:
5364
dict: dictionary with response received from the connected client
5465
"""
55-
data = self._socket.recv()
66+
response_lst: list[tuple[Any, int]] = []
67+
while len(response_lst) == 0:
68+
response_lst = self._poller.poll(self._time_out_ms)
69+
if not self._spawner.poll():
70+
raise ExecutorlibSocketError()
71+
data = self._socket.recv(zmq.NOBLOCK)
5672
if self._logger is not None:
5773
self._logger.warning(
5874
"Received dictionary of size: " + str(sys.getsizeof(data))

tests/test_standalone_interactive_communication.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
interface_send,
1313
interface_receive,
1414
SocketInterface,
15+
ExecutorlibSocketError,
1516
)
1617
from executorlib.standalone.serialize import cloudpickle_register
1718
from executorlib.standalone.interactive.spawner import MpiExecSpawner
@@ -114,6 +115,35 @@ def test_interface_serial_with_debug(self):
114115
)
115116
interface.shutdown(wait=True)
116117

118+
def test_interface_serial_with_stopped_process(self):
119+
cloudpickle_register(ind=1)
120+
task_dict = {"fn": calc, "args": (), "kwargs": {"i": 2}}
121+
interface = SocketInterface(
122+
spawner=MpiExecSpawner(cwd=None, cores=1, openmpi_oversubscribe=False),
123+
log_obj_size=True,
124+
)
125+
interface.bootup(
126+
command_lst=[
127+
sys.executable,
128+
os.path.abspath(
129+
os.path.join(
130+
__file__,
131+
"..",
132+
"..",
133+
"executorlib",
134+
"backend",
135+
"interactive_serial.py",
136+
)
137+
),
138+
"--zmqport",
139+
str(interface.bind_to_random_port()),
140+
]
141+
)
142+
interface.send_dict(input_dict=task_dict)
143+
interface._spawner._process.terminate()
144+
with self.assertRaises(ExecutorlibSocketError):
145+
interface.receive_dict()
146+
117147

118148
class TestZMQ(unittest.TestCase):
119149
def test_interface_receive(self):

0 commit comments

Comments
 (0)