From c2b469479b5e49918428db07cd5d430905c1fc9c Mon Sep 17 00:00:00 2001 From: Piper Merriam Date: Fri, 6 Sep 2019 21:21:24 -0600 Subject: [PATCH 1/6] d --- p2p/trio_service.py | 139 +++++++++++++++++++++++++++++++------------- spawn.py | 81 ++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 41 deletions(-) create mode 100644 spawn.py diff --git a/p2p/trio_service.py b/p2p/trio_service.py index afa1c2553b..d8edf12330 100644 --- a/p2p/trio_service.py +++ b/p2p/trio_service.py @@ -212,7 +212,70 @@ async def run(self) -> None: return _Service -class Manager(ManagerAPI): +class BaseManager(ManagerAPI): + + def __init__(self, service: ServiceAPI) -> None: + if hasattr(service, 'manager'): + raise LifecycleError("Service already has a manager.") + else: + service.manager = self + + self._service = service + + # events + self._started = trio.Event() + self._cancelled = trio.Event() + self._stopped = trio.Event() + + # locks + self._run_lock = trio.Lock() + + # errors + self._errors = [] + + # + # Event API mirror + # + @property + def is_started(self) -> bool: + return self._started.is_set() + + @property + def is_running(self) -> bool: + return self.is_started and not self.is_stopped + + @property + def is_cancelled(self) -> bool: + return self._cancelled.is_set() + + @property + def is_stopped(self) -> bool: + return self._stopped.is_set() + + # + # Wait API + # + async def wait_started(self) -> None: + await self._started.wait() + + async def wait_cancelled(self) -> None: + await self._cancelled.wait() + + async def wait_stopped(self) -> None: + await self._stopped.wait() + + # + # Tasks + # + def run_daemon_task(self, + async_fn: Callable[..., Awaitable[Any]], + *args: Any, + name: str = None) -> None: + + self.run_task(async_fn, *args, daemon=True, name=name) + + +class Manager(BaseManager): logger = logging.getLogger('p2p.trio_service.Manager') _service: ServiceAPI @@ -250,6 +313,13 @@ def __init__(self, service: ServiceAPI) -> None: # errors self._errors = [] + # + # Error Handling + # + @property + def did_error(self) -> bool: + return len(self._errors) > 0 + # # System Tasks # @@ -352,29 +422,6 @@ async def run(self) -> None: in self._errors )) - # - # Event API mirror - # - @property - def is_started(self) -> bool: - return self._started.is_set() - - @property - def is_running(self) -> bool: - return self.is_started and not self.is_stopped - - @property - def is_cancelled(self) -> bool: - return self._cancelled.is_set() - - @property - def is_stopped(self) -> bool: - return self._stopped.is_set() - - @property - def did_error(self) -> bool: - return len(self._errors) > 0 - # # Control API # @@ -387,18 +434,6 @@ async def stop(self) -> None: self.cancel() await self.wait_stopped() - # - # Wait API - # - async def wait_started(self) -> None: - await self._started.wait() - - async def wait_cancelled(self) -> None: - await self._cancelled.wait() - - async def wait_stopped(self) -> None: - await self._stopped.wait() - async def _run_and_manage_task(self, async_fn: Callable[..., Awaitable[Any]], *args: Any, @@ -448,12 +483,34 @@ def run_task(self, name=name, ) - def run_daemon_task(self, - async_fn: Callable[..., Awaitable[Any]], - *args: Any, - name: str = None) -> None: - self.run_task(async_fn, *args, daemon=True, name=name) +class ProcessManager(BaseManager): + @property + def did_error(self) -> bool: + ... + + def cancel(self) -> None: + ... + + async def stop(self) -> None: + ... + + @classmethod + async def run_service(cls, service: ServiceAPI) -> None: + ... + + @abstractmethod + async def run(self) -> None: + ... + + @trio_typing.takes_callable_and_args + @abstractmethod + async def run_task(self, + async_fn: Callable[..., Awaitable[Any]], + *args: Any, + daemon: bool = False, + name: str = None) -> None: + raise NotImplementedError("The ProcessManager cannot be used to run tasks") @asynccontextmanager diff --git a/spawn.py b/spawn.py new file mode 100644 index 0000000000..2277ea6dbf --- /dev/null +++ b/spawn.py @@ -0,0 +1,81 @@ +import os +import sys +from typing import ( + Callable, + TypeVar, +) + +import trio + + +TReturn = TypeVar('TReturn') + + +class Process: + def __init__(self, + target: Callable[..., Any], + args: Sequence[Any]) -> None: + self._target = target + self._args = args + + def _run(self, child_r, child_w, parent_pid) -> None: + self._target(*self._args) + + def run_process(self) -> None: + self._spawn() + + def _spawn(self): + parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() + original_pid = os.getpid() + fork_pid = os.fork() + + if fork_pid == 0: + parent_pid = original_pid + os.close(parent_r) + os.close(parent_w) + code = self._run(child_r, child_w, parent_pid) + sys.exit(code) + else: + child_pid = fork_pid + os.close(child_r) + os.close(child_w) + handle_parent(parent_r, parent_w, child_pid) + + +async def run_in_process(async_fn: Callable[..., TReturn], *args) -> TReturn: + proc = Process(async_fn, args) + proc.start() + await proc.run_process() + + +def test_spawning_process(): + parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() + original_pid = os.getpid() + fork_pid = os.fork() + + if fork_pid == 0: + parent_pid = original_pid + os.close(parent_r) + os.close(parent_w) + code = handle_child(child_r, child_w, parent_pid) + sys.exit(code) + else: + child_pid = fork_pid + os.close(child_r) + os.close(child_w) + handle_parent(parent_r, parent_w, child_pid) + + +def handle_parent(parent_r, parent_w, child_pid): + print('Parent', os.getpid(), parent_r, parent_w, "Child pid: ", child_pid) + + +def handle_child(child_r, child_w, parent_pid): + print('Child:', os.getpid(), child_r, child_w, "Parent pid: ", parent_pid) + return 0 + + +if __name__ == '__main__': + test_spawning_process() From 23051345084a818216c672544174aaaba94ff678 Mon Sep 17 00:00:00 2001 From: Piper Merriam Date: Sat, 7 Sep 2019 09:01:20 -0600 Subject: [PATCH 2/6] d --- p2p/trio_run_in_process.py | 176 +++++++++++++++++++++++++ p2p/trio_service.py | 46 ++++++- spawn.py | 81 ------------ tests-trio/test_trio_run_in_process.py | 37 ++++++ 4 files changed, 257 insertions(+), 83 deletions(-) create mode 100644 p2p/trio_run_in_process.py delete mode 100644 spawn.py create mode 100644 tests-trio/test_trio_run_in_process.py diff --git a/p2p/trio_run_in_process.py b/p2p/trio_run_in_process.py new file mode 100644 index 0000000000..c922e7b1e6 --- /dev/null +++ b/p2p/trio_run_in_process.py @@ -0,0 +1,176 @@ +import argparse +import io +import logging +import os +import struct +import sys +from typing import ( + Any, + Callable, + Sequence, + TypeVar, +) + +import cloudpickle +import trio + + +TReturn = TypeVar('TReturn') + + +logger = logging.getLogger('trio.multiprocessing') + + +def get_subprocess_command(child_r, child_w, parent_pid): + return ( + sys.executable, + '-m', 'p2p.trio_run_in_process', + '--parent-pid', str(parent_pid), + '--fd-read', str(child_r), + '--fd-write', str(child_w), + ) + + +async def coro_read_exactly(stream: trio.abc.ReceiveStream, num_bytes: int) -> bytes: + buffer = io.BytesIO() + bytes_remaining = num_bytes + while bytes_remaining > 0: + data = await stream.read(bytes_remaining) + if data == b'': + raise Exception("End of stream...") + buffer.write(data) + bytes_remaining -= len(data) + + return buffer.getvalue() + + +async def coro_receive_pickled_value(stream: trio.abc.ReceiveStream) -> Any: + len_bytes = await coro_read_exactly(stream, 4) + serialized_len = int.from_bytes(len_bytes, 'big') + serialized_result = await coro_read_exactly(stream, serialized_len) + return cloudpickle.loads(serialized_result) + + +def pickle_value(value: Any) -> bytes: + serialized_value = cloudpickle.dumps(value) + return struct.pack('>I', len(serialized_value)) + serialized_value + + +class Process: + def __init__(self, + async_fn: Callable[..., Any], + args: Sequence[Any]) -> None: + self._async_fn = async_fn + self._args = args + + async def run_process(self): + parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() + parent_pid = os.getpid() + + command = get_subprocess_command( + child_r, + child_w, + parent_pid, + ) + + async with await trio.open_file(parent_w, 'wb', closefd=True) as to_child: + async with await trio.open_file(parent_r, 'rb', closefd=True) as from_child: + async with await trio.open_file(child_w, 'wb', closefd=False) as to_parent: + async with await trio.open_file(child_r, 'rb', closefd=False) as from_parent: + proc = await trio.open_process(command, stdin=from_parent, stdout=to_parent) + async with proc: + await to_child.write(pickle_value((self._async_fn, self._args))) + await to_child.flush() + + if proc.returncode == 0: + result = await coro_receive_pickled_value(from_child) + return result + else: + error = await coro_receive_pickled_value(from_child) + raise error + + +async def run_in_process(async_fn: Callable[..., TReturn], *args) -> TReturn: + proc = Process(async_fn, args) + # TODO: signal handling + return await proc.run_process() + + +# +# CLI invocation for subprocesses +# +parser = argparse.ArgumentParser(description='trio-run-in-process') +parser.add_argument( + '--parent-pid', + type=int, + required=True, + help="The PID of the parent process", +) +parser.add_argument( + '--fd-read', + type=int, + required=True, + help=( + "The file descriptor that the child process can use to read data that " + "has been written by the parent process" + ) +) +parser.add_argument( + '--fd-write', + type=int, + required=True, + help=( + "The file descriptor that the child process can use for writing data " + "meant to be read by the parent process" + ), +) + + +def read_exactly(stream: io.BytesIO, num_bytes: int) -> bytes: + buffer = io.BytesIO() + bytes_remaining = num_bytes + while bytes_remaining > 0: + data = stream.read(bytes_remaining) + if data == b'': + raise Exception("End of stream...") + buffer.write(data) + bytes_remaining -= len(data) + + return buffer.getvalue() + + +def receive_pickled_value(stream: io.BytesIO) -> Any: + len_bytes = read_exactly(stream, 4) + serialized_len = int.from_bytes(len_bytes, 'big') + serialized_result = read_exactly(stream, serialized_len) + return cloudpickle.loads(serialized_result) + + +def _run_process(parent_pid: int, + fd_read: int, + fd_write: int) -> None: + with os.fdopen(sys.stdin.fileno(), 'rb', closefd=True) as stdin_binary: + async_fn, args = receive_pickled_value(stdin_binary) + + # TODO: signal handling + try: + result = trio.run(async_fn, *args) + except BaseException as err: + with os.fdopen(sys.stdout.fileno(), 'wb', closefd=True) as stdout_binary: + stdout_binary.write(pickle_value(err)) + sys.exit(1) + else: + logger.debug("Ran successfully: %r", result) + with os.fdopen(sys.stdout.fileno(), 'wb', closefd=True) as stdout_binary: + stdout_binary.write(pickle_value(result)) + sys.exit(0) + + +if __name__ == "__main__": + args = parser.parse_args() + _run_process( + parent_pid=args.parent_pid, + fd_read=args.fd_read, + fd_write=args.fd_write, + ) diff --git a/p2p/trio_service.py b/p2p/trio_service.py index d8edf12330..845c6ffea2 100644 --- a/p2p/trio_service.py +++ b/p2p/trio_service.py @@ -485,6 +485,8 @@ def run_task(self, class ProcessManager(BaseManager): + _run_lock: trio.Lock + @property def did_error(self) -> bool: ... @@ -499,9 +501,49 @@ async def stop(self) -> None: async def run_service(cls, service: ServiceAPI) -> None: ... - @abstractmethod async def run(self) -> None: - ... + + if self._run_lock.locked(): + raise LifecycleError( + "Cannot run a service with the run lock already engaged. Already started?" + ) + elif self.is_started: + raise LifecycleError("Cannot run a service which is already started.") + + async with self._run_lock: + async with trio.open_nursery() as system_nursery: + try: + async with trio.open_nursery() as task_nursery: + self._task_nursery = task_nursery + + system_nursery.start_soon( + self._handle_cancelled, + task_nursery, + ) + system_nursery.start_soon( + self._handle_stopped, + system_nursery, + ) + + task_nursery.start_soon(self._handle_run) + + self._started.set() + + # ***BLOCKING HERE*** + # The code flow will block here until the background tasks have + # completed or cancellation occurs. + finally: + # Mark as having stopped + self._stopped.set() + self.logger.debug('%s stopped', self) + + # If an error occured, re-raise it here + if self.did_error: + raise trio.MultiError(tuple( + exc_value.with_traceback(exc_tb) + for _, exc_value, exc_tb + in self._errors + )) @trio_typing.takes_callable_and_args @abstractmethod diff --git a/spawn.py b/spawn.py deleted file mode 100644 index 2277ea6dbf..0000000000 --- a/spawn.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import sys -from typing import ( - Callable, - TypeVar, -) - -import trio - - -TReturn = TypeVar('TReturn') - - -class Process: - def __init__(self, - target: Callable[..., Any], - args: Sequence[Any]) -> None: - self._target = target - self._args = args - - def _run(self, child_r, child_w, parent_pid) -> None: - self._target(*self._args) - - def run_process(self) -> None: - self._spawn() - - def _spawn(self): - parent_r, child_w = os.pipe() - child_r, parent_w = os.pipe() - original_pid = os.getpid() - fork_pid = os.fork() - - if fork_pid == 0: - parent_pid = original_pid - os.close(parent_r) - os.close(parent_w) - code = self._run(child_r, child_w, parent_pid) - sys.exit(code) - else: - child_pid = fork_pid - os.close(child_r) - os.close(child_w) - handle_parent(parent_r, parent_w, child_pid) - - -async def run_in_process(async_fn: Callable[..., TReturn], *args) -> TReturn: - proc = Process(async_fn, args) - proc.start() - await proc.run_process() - - -def test_spawning_process(): - parent_r, child_w = os.pipe() - child_r, parent_w = os.pipe() - original_pid = os.getpid() - fork_pid = os.fork() - - if fork_pid == 0: - parent_pid = original_pid - os.close(parent_r) - os.close(parent_w) - code = handle_child(child_r, child_w, parent_pid) - sys.exit(code) - else: - child_pid = fork_pid - os.close(child_r) - os.close(child_w) - handle_parent(parent_r, parent_w, child_pid) - - -def handle_parent(parent_r, parent_w, child_pid): - print('Parent', os.getpid(), parent_r, parent_w, "Child pid: ", child_pid) - - -def handle_child(child_r, child_w, parent_pid): - print('Child:', os.getpid(), child_r, child_w, "Parent pid: ", parent_pid) - return 0 - - -if __name__ == '__main__': - test_spawning_process() diff --git a/tests-trio/test_trio_run_in_process.py b/tests-trio/test_trio_run_in_process.py new file mode 100644 index 0000000000..c696b3be99 --- /dev/null +++ b/tests-trio/test_trio_run_in_process.py @@ -0,0 +1,37 @@ +import tempfile + +import pytest + +import trio + +from p2p.trio_run_in_process import run_in_process + + +@pytest.mark.trio +async def test_run_in_process(): + async def touch_file(path: trio.Path): + await path.touch() + + with tempfile.TemporaryDirectory() as base_dir: + path = trio.Path(base_dir) / 'test.txt' + assert not await path.exists() + await run_in_process(touch_file, path) + assert await path.exists() + + +@pytest.mark.trio +async def test_run_in_process_with_result(): + async def return7(): + return 7 + + result = await run_in_process(return7) + assert result == 7 + + +@pytest.mark.trio +async def test_run_in_process_with_error(): + async def raise_err(): + raise ValueError("Some err") + + with pytest.raises(ValueError, match="Some err"): + await run_in_process(raise_err) From 2e1915dfbbdbe3c528624b21e7f09ca391f882fa Mon Sep 17 00:00:00 2001 From: Piper Merriam Date: Sun, 8 Sep 2019 17:16:56 -0600 Subject: [PATCH 3/6] convert back to class based --- p2p/trio_run_in_process.py | 208 ++++++++++++++++++++----- tests-trio/test_trio_run_in_process.py | 18 ++- 2 files changed, 185 insertions(+), 41 deletions(-) diff --git a/p2p/trio_run_in_process.py b/p2p/trio_run_in_process.py index c922e7b1e6..c7eb9526c4 100644 --- a/p2p/trio_run_in_process.py +++ b/p2p/trio_run_in_process.py @@ -4,9 +4,15 @@ import os import struct import sys + +from async_generator import asynccontextmanager +import trio_typing from typing import ( Any, + AsyncIterator, + Awaitable, Callable, + Optional, Sequence, TypeVar, ) @@ -31,6 +37,11 @@ def get_subprocess_command(child_r, child_w, parent_pid): ) +def pickle_value(value: Any) -> bytes: + serialized_value = cloudpickle.dumps(value) + return struct.pack('>I', len(serialized_value)) + serialized_value + + async def coro_read_exactly(stream: trio.abc.ReceiveStream, num_bytes: int) -> bytes: buffer = io.BytesIO() bytes_remaining = num_bytes @@ -51,50 +62,171 @@ async def coro_receive_pickled_value(stream: trio.abc.ReceiveStream) -> Any: return cloudpickle.loads(serialized_result) -def pickle_value(value: Any) -> bytes: - serialized_value = cloudpickle.dumps(value) - return struct.pack('>I', len(serialized_value)) + serialized_value +class empty: + pass + +class Process(Awaitable[TReturn]): + returncode: Optional[int] = None -class Process: - def __init__(self, - async_fn: Callable[..., Any], - args: Sequence[Any]) -> None: + _pid: Optional[int] = None + _result: Optional[TReturn] = empty + _returncode: Optional[int] = None + _error: Optional[BaseException] = None + + def __init__(self, async_fn: Callable[..., TReturn], args: Sequence[TReturn]) -> None: self._async_fn = async_fn self._args = args - async def run_process(self): - parent_r, child_w = os.pipe() - child_r, parent_w = os.pipe() - parent_pid = os.getpid() - - command = get_subprocess_command( - child_r, - child_w, - parent_pid, - ) - - async with await trio.open_file(parent_w, 'wb', closefd=True) as to_child: - async with await trio.open_file(parent_r, 'rb', closefd=True) as from_child: - async with await trio.open_file(child_w, 'wb', closefd=False) as to_parent: - async with await trio.open_file(child_r, 'rb', closefd=False) as from_parent: - proc = await trio.open_process(command, stdin=from_parent, stdout=to_parent) - async with proc: - await to_child.write(pickle_value((self._async_fn, self._args))) - await to_child.flush() - - if proc.returncode == 0: - result = await coro_receive_pickled_value(from_child) - return result - else: - error = await coro_receive_pickled_value(from_child) - raise error - - -async def run_in_process(async_fn: Callable[..., TReturn], *args) -> TReturn: + self._has_pid = trio.Event() + self._has_returncode = trio.Event() + self._has_result = trio.Event() + self._has_error = trio.Event() + + def __await__(self) -> TReturn: + return self.run().__await__() + + # + # PID + # + @property + def pid(self) -> int: + if self._pid is None: + raise AttributeError("No PID set for process") + return self._pid + + @pid.setter + def pid(self, value: int) -> None: + self._pid = value + self._has_pid.set() + + async def wait_pid(self) -> int: + await self._has_pid.wait() + return self.pid + + # + # Result + # + @property + def result(self) -> int: + if self._result is empty: + raise AttributeError("No result set") + return self._result + + @result.setter + def result(self, value: int) -> None: + self._result = value + self._has_result.set() + + async def wait_result(self) -> int: + await self._has_result.wait() + return self.result + + # + # Return Code + # + @property + def returncode(self) -> int: + if self._returncode is None: + raise AttributeError("No returncode set") + return self._returncode + + @returncode.setter + def returncode(self, value: int) -> None: + self._returncode = value + self._has_returncode.set() + + async def wait_returncode(self) -> int: + await self._has_returncode.wait() + return self.returncode + + # + # Error + # + @property + def error(self) -> int: + if self._error is None: + raise AttributeError("No error set") + return self._error + + @error.setter + def error(self, value: int) -> None: + self._error = value + self._has_error.set() + + async def wait_error(self) -> int: + await self._has_error.wait() + return self.error + + async def wait(self) -> TReturn: + """ + Block until the process has exited. + """ + await self._has_returncode.wait() + if self.returncode == 0: + return await self.wait_result() + else: + raise await self.wait_error() + + def poll(self) -> Optional[int]: + """ + Check if the process has finished. Returns `None` if the re + """ + return self.returncode + + +async def _monitor_sub_proc(proc: Process) -> None: + parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() + parent_pid = os.getpid() + + command = get_subprocess_command( + child_r, + child_w, + parent_pid, + ) + + async with await trio.open_file(parent_w, 'wb', closefd=True) as to_child: + async with await trio.open_file(parent_r, 'rb', closefd=True) as from_child: + async with await trio.open_file(child_w, 'wb', closefd=False) as to_parent: + async with await trio.open_file(child_r, 'rb', closefd=False) as from_parent: + sub_proc = await trio.open_process(command, stdin=from_parent, stdout=to_parent) + async with sub_proc: + # set the process ID + proc.pid = sub_proc.pid + + # pass the child process the serialized `async_fn` + # and `args` over stdin. + await to_child.write(pickle_value((proc._async_fn, proc._args))) + await to_child.flush() + + proc.returncode = sub_proc.returncode + + if proc.returncode == 0: + proc.result = await coro_receive_pickled_value(from_child) + else: + proc.error = await coro_receive_pickled_value(from_child) + + +@asynccontextmanager +@trio_typing.takes_callable_and_args +async def open_in_process(async_fn: Callable[..., TReturn], *args: Any) -> AsyncIterator[Process]: proc = Process(async_fn, args) - # TODO: signal handling - return await proc.run_process() + + async with trio.open_nursery() as nursery: + nursery.start_soon(_monitor_sub_proc, proc) + + await proc.wait_pid() + + yield proc + + await proc.wait() + + +@trio_typing.takes_callable_and_args +async def run_in_process(async_fn: Callable[..., TReturn], *args: Any) -> TReturn: + async with open_in_process(async_fn, *args) as proc: + return await proc.wait() # diff --git a/tests-trio/test_trio_run_in_process.py b/tests-trio/test_trio_run_in_process.py index c696b3be99..77ff2d4cae 100644 --- a/tests-trio/test_trio_run_in_process.py +++ b/tests-trio/test_trio_run_in_process.py @@ -24,7 +24,8 @@ async def test_run_in_process_with_result(): async def return7(): return 7 - result = await run_in_process(return7) + with trio.fail_after(5): + result = await run_in_process(return7) assert result == 7 @@ -33,5 +34,16 @@ async def test_run_in_process_with_error(): async def raise_err(): raise ValueError("Some err") - with pytest.raises(ValueError, match="Some err"): - await run_in_process(raise_err) + with trio.fail_after(5): + with pytest.raises(ValueError, match="Some err"): + await run_in_process(raise_err) + + +@pytest.mark.trio +async def test_run_in_process_handles_keyboard_interrupt(): + async def raise_err(): + raise ValueError("Some err") + + with trio.fail_after(5): + with pytest.raises(ValueError, match="Some err"): + await run_in_process(raise_err) From b0e456e54f602e819997fd7077f7c2191a54182f Mon Sep 17 00:00:00 2001 From: Piper Merriam Date: Sun, 8 Sep 2019 17:36:48 -0600 Subject: [PATCH 4/6] d --- p2p/trio_run_in_process.py | 46 +++++++++++++--------- tests-trio/test_trio_run_in_process.py | 53 +++++++++++++++++++------- 2 files changed, 69 insertions(+), 30 deletions(-) diff --git a/p2p/trio_run_in_process.py b/p2p/trio_run_in_process.py index c7eb9526c4..e5681fd708 100644 --- a/p2p/trio_run_in_process.py +++ b/p2p/trio_run_in_process.py @@ -2,6 +2,7 @@ import io import logging import os +import signal import struct import sys @@ -186,26 +187,35 @@ async def _monitor_sub_proc(proc: Process) -> None: parent_pid, ) - async with await trio.open_file(parent_w, 'wb', closefd=True) as to_child: - async with await trio.open_file(parent_r, 'rb', closefd=True) as from_child: - async with await trio.open_file(child_w, 'wb', closefd=False) as to_parent: - async with await trio.open_file(child_r, 'rb', closefd=False) as from_parent: - sub_proc = await trio.open_process(command, stdin=from_parent, stdout=to_parent) - async with sub_proc: - # set the process ID - proc.pid = sub_proc.pid + async with await trio.open_file(child_w, 'wb', closefd=False) as to_parent: + async with await trio.open_file(child_r, 'rb', closefd=False) as from_parent: + sub_proc = await trio.open_process(command, stdin=from_parent, stdout=to_parent) + async with sub_proc: + # set the process ID + proc.pid = sub_proc.pid - # pass the child process the serialized `async_fn` - # and `args` over stdin. - await to_child.write(pickle_value((proc._async_fn, proc._args))) - await to_child.flush() + # pass the child process the serialized `async_fn` + # and `args` over stdin. + async with await trio.open_file(parent_w, 'wb', closefd=True) as to_child: + await to_child.write(pickle_value((proc._async_fn, proc._args))) + await to_child.flush() - proc.returncode = sub_proc.returncode + proc.returncode = sub_proc.returncode - if proc.returncode == 0: - proc.result = await coro_receive_pickled_value(from_child) - else: - proc.error = await coro_receive_pickled_value(from_child) + async with await trio.open_file(parent_r, 'rb', closefd=True) as from_child: + if proc.returncode == 0: + proc.result = await coro_receive_pickled_value(from_child) + else: + proc.error = await coro_receive_pickled_value(from_child) + + +async def _monitor_for_sigterm(proc: Process) -> None: + await trio.sleep_forever() + # with trio.open_signal_receiver(signal.SIGTERM) as signal_aiter: + # async for signum in signal_aiter: + # if signum == signal.SIGTERM: + # logger.info('GOT SIGTERM') + # os.kill(proc.pid, signal.SIGTERM) @asynccontextmanager @@ -214,6 +224,7 @@ async def open_in_process(async_fn: Callable[..., TReturn], *args: Any) -> Async proc = Process(async_fn, args) async with trio.open_nursery() as nursery: + nursery.start_soon(_monitor_for_sigterm, proc) nursery.start_soon(_monitor_sub_proc, proc) await proc.wait_pid() @@ -221,6 +232,7 @@ async def open_in_process(async_fn: Callable[..., TReturn], *args: Any) -> Async yield proc await proc.wait() + nursery.cancel_scope.cancel() @trio_typing.takes_callable_and_args diff --git a/tests-trio/test_trio_run_in_process.py b/tests-trio/test_trio_run_in_process.py index 77ff2d4cae..e9a03be059 100644 --- a/tests-trio/test_trio_run_in_process.py +++ b/tests-trio/test_trio_run_in_process.py @@ -1,10 +1,12 @@ +import os +import signal import tempfile import pytest import trio -from p2p.trio_run_in_process import run_in_process +from p2p.trio_run_in_process import run_in_process, open_in_process @pytest.mark.trio @@ -12,11 +14,12 @@ async def test_run_in_process(): async def touch_file(path: trio.Path): await path.touch() - with tempfile.TemporaryDirectory() as base_dir: - path = trio.Path(base_dir) / 'test.txt' - assert not await path.exists() - await run_in_process(touch_file, path) - assert await path.exists() + with trio.fail_after(2): + with tempfile.TemporaryDirectory() as base_dir: + path = trio.Path(base_dir) / 'test.txt' + assert not await path.exists() + await run_in_process(touch_file, path) + assert await path.exists() @pytest.mark.trio @@ -24,7 +27,7 @@ async def test_run_in_process_with_result(): async def return7(): return 7 - with trio.fail_after(5): + with trio.fail_after(2): result = await run_in_process(return7) assert result == 7 @@ -34,16 +37,40 @@ async def test_run_in_process_with_error(): async def raise_err(): raise ValueError("Some err") - with trio.fail_after(5): + with trio.fail_after(2): with pytest.raises(ValueError, match="Some err"): await run_in_process(raise_err) @pytest.mark.trio async def test_run_in_process_handles_keyboard_interrupt(): - async def raise_err(): - raise ValueError("Some err") + async def monitor_for_interrupt(path): + import trio + try: + await trio.sleep_forever() + except KeyboardInterrupt: + await path.touch() + else: + assert False - with trio.fail_after(5): - with pytest.raises(ValueError, match="Some err"): - await run_in_process(raise_err) + async def wrap_and_get_interrupted(path): + try: + await run_in_process(monitor_for_interrupt, path) + except KeyboardInterrupt: + pass + else: + assert False + + with trio.fail_after(2): + with tempfile.TemporaryDirectory() as base_dir: + # TODO + path = trio.Path(base_dir) / 'test.txt' + assert not await path.exists() + async with open_in_process(wrap_and_get_interrupted, path) as proc: + print('killing') + os.kill(proc.pid, signal.SIGTERM) + print('killed') + assert await path.exists() + print('exited1') + print('exited2') + print('exited3') From 42a7d9e73bacb1e36aaa635f30c51fd75a2ceb96 Mon Sep 17 00:00:00 2001 From: Piper Merriam Date: Sun, 8 Sep 2019 21:05:40 -0600 Subject: [PATCH 5/6] try to handle keyboard interrupt --- p2p/trio_run_in_process.py | 51 ++++++++++++++++++-------- tests-trio/test_trio_run_in_process.py | 33 +++++++++-------- 2 files changed, 53 insertions(+), 31 deletions(-) diff --git a/p2p/trio_run_in_process.py b/p2p/trio_run_in_process.py index e5681fd708..a8d5df3022 100644 --- a/p2p/trio_run_in_process.py +++ b/p2p/trio_run_in_process.py @@ -159,13 +159,14 @@ async def wait_error(self) -> int: await self._has_error.wait() return self.error - async def wait(self) -> TReturn: + async def wait(self) -> None: """ Block until the process has exited. """ - await self._has_returncode.wait() + await self.wait_returncode() + if self.returncode == 0: - return await self.wait_result() + await self.wait_result() else: raise await self.wait_error() @@ -175,6 +176,15 @@ def poll(self) -> Optional[int]: """ return self.returncode + def kill(self) -> None: + self.send_signal(signal.SIGKILL) + + def terminate(self) -> None: + self.send_signal(signal.SIGTERM) + + def send_signal(self, sig: int) -> None: + os.kill(self.pid, sig) + async def _monitor_sub_proc(proc: Process) -> None: parent_r, child_w = os.pipe() @@ -190,32 +200,40 @@ async def _monitor_sub_proc(proc: Process) -> None: async with await trio.open_file(child_w, 'wb', closefd=False) as to_parent: async with await trio.open_file(child_r, 'rb', closefd=False) as from_parent: sub_proc = await trio.open_process(command, stdin=from_parent, stdout=to_parent) + logger.debug('starting subprocess to run %s', proc) async with sub_proc: # set the process ID proc.pid = sub_proc.pid + logger.debug('subprocess for %s started. pid=%d', proc, proc.pid) + logger.debug('writing execution data for %s over stdin', proc) # pass the child process the serialized `async_fn` # and `args` over stdin. async with await trio.open_file(parent_w, 'wb', closefd=True) as to_child: await to_child.write(pickle_value((proc._async_fn, proc._args))) await to_child.flush() + logger.debug('waiting for process %s finish', proc) proc.returncode = sub_proc.returncode + logger.debug('process %s finished: returncode=%d', proc, proc.returncode) async with await trio.open_file(parent_r, 'rb', closefd=True) as from_child: if proc.returncode == 0: + logger.debug('setting result for process %s', proc) proc.result = await coro_receive_pickled_value(from_child) else: - proc.error = await coro_receive_pickled_value(from_child) + with trio.move_on_after(2) as scope: + logger.debug('setting error for process %s', proc) + proc.error = await coro_receive_pickled_value(from_child) + if scope.cancelled_caught: + logger.debug('process %s exited due unknown reason.', proc) + proc.error = SystemExit(proc.returncode) -async def _monitor_for_sigterm(proc: Process) -> None: - await trio.sleep_forever() - # with trio.open_signal_receiver(signal.SIGTERM) as signal_aiter: - # async for signum in signal_aiter: - # if signum == signal.SIGTERM: - # logger.info('GOT SIGTERM') - # os.kill(proc.pid, signal.SIGTERM) +async def _monitory_signals(proc: Process, signal_aiter: AsyncIterator[int]) -> None: + async for signum in signal_aiter: + logger.info('GOT SIGNAL: %s', signum) + proc.send_signal(signum) @asynccontextmanager @@ -224,21 +242,24 @@ async def open_in_process(async_fn: Callable[..., TReturn], *args: Any) -> Async proc = Process(async_fn, args) async with trio.open_nursery() as nursery: - nursery.start_soon(_monitor_for_sigterm, proc) nursery.start_soon(_monitor_sub_proc, proc) await proc.wait_pid() - yield proc + with trio.open_signal_receiver(signal.SIGTERM) as signal_aiter: + nursery.start_soon(_monitory_signals, proc, signal_aiter) + + yield proc + await proc.wait() - await proc.wait() nursery.cancel_scope.cancel() @trio_typing.takes_callable_and_args async def run_in_process(async_fn: Callable[..., TReturn], *args: Any) -> TReturn: async with open_in_process(async_fn, *args) as proc: - return await proc.wait() + await proc.wait() + return proc.result # diff --git a/tests-trio/test_trio_run_in_process.py b/tests-trio/test_trio_run_in_process.py index e9a03be059..f7e521a772 100644 --- a/tests-trio/test_trio_run_in_process.py +++ b/tests-trio/test_trio_run_in_process.py @@ -1,5 +1,3 @@ -import os -import signal import tempfile import pytest @@ -10,7 +8,7 @@ @pytest.mark.trio -async def test_run_in_process(): +async def test_run_in_process_touch_file(): async def touch_file(path: trio.Path): await path.touch() @@ -42,6 +40,19 @@ async def raise_err(): await run_in_process(raise_err) +@pytest.mark.trio +async def test_open_in_proc_can_terminate(): + async def do_sleep_forever(): + import trio + await trio.sleep_forever() + + with trio.fail_after(2): + async with open_in_process(do_sleep_forever) as proc: + proc.terminate() + assert proc.returncode == 1 + + +@pytest.mark.skip @pytest.mark.trio async def test_run_in_process_handles_keyboard_interrupt(): async def monitor_for_interrupt(path): @@ -53,24 +64,14 @@ async def monitor_for_interrupt(path): else: assert False - async def wrap_and_get_interrupted(path): - try: - await run_in_process(monitor_for_interrupt, path) - except KeyboardInterrupt: - pass - else: - assert False - with trio.fail_after(2): with tempfile.TemporaryDirectory() as base_dir: # TODO path = trio.Path(base_dir) / 'test.txt' assert not await path.exists() - async with open_in_process(wrap_and_get_interrupted, path) as proc: + async with open_in_process(monitor_for_interrupt, path) as proc: print('killing') - os.kill(proc.pid, signal.SIGTERM) + proc.terminate() print('killed') + print('finished') assert await path.exists() - print('exited1') - print('exited2') - print('exited3') From 99d193ecca43cad059a269175885a7be8d5a909a Mon Sep 17 00:00:00 2001 From: Piper Merriam Date: Mon, 9 Sep 2019 13:07:45 -0600 Subject: [PATCH 6/6] move away from stdin/stdout communication --- p2p/trio_run_in_process.py | 452 +++++++++++++++++++------ tests-trio/test_trio_run_in_process.py | 98 ++++-- 2 files changed, 421 insertions(+), 129 deletions(-) diff --git a/p2p/trio_run_in_process.py b/p2p/trio_run_in_process.py index a8d5df3022..da75887d50 100644 --- a/p2p/trio_run_in_process.py +++ b/p2p/trio_run_in_process.py @@ -1,31 +1,34 @@ import argparse +import enum import io import logging import os import signal import struct +# import subprocess import sys - -from async_generator import asynccontextmanager -import trio_typing from typing import ( Any, AsyncIterator, Awaitable, + BinaryIO, Callable, Optional, Sequence, TypeVar, ) +from async_generator import asynccontextmanager import cloudpickle +from eth_utils.toolz import sliding_window import trio +import trio_typing TReturn = TypeVar('TReturn') -logger = logging.getLogger('trio.multiprocessing') +logger = logging.getLogger('trio-run-in-process') def get_subprocess_command(child_r, child_w, parent_pid): @@ -38,42 +41,48 @@ def get_subprocess_command(child_r, child_w, parent_pid): ) -def pickle_value(value: Any) -> bytes: - serialized_value = cloudpickle.dumps(value) - return struct.pack('>I', len(serialized_value)) + serialized_value +class State(enum.Enum): + """ + Child process lifecycle + """ + INITIALIZING = b'\x00' + INITIALIZED = b'\x01' + WAIT_EXEC_DATA = b'\x02' + BOOTING = b'\x03' + STARTED = b'\x04' + EXECUTING = b'\x05' + STOPPING = b'\x06' + FINISHED = b'\x07' + def as_int(self) -> int: + return self.value[0] -async def coro_read_exactly(stream: trio.abc.ReceiveStream, num_bytes: int) -> bytes: - buffer = io.BytesIO() - bytes_remaining = num_bytes - while bytes_remaining > 0: - data = await stream.read(bytes_remaining) - if data == b'': - raise Exception("End of stream...") - buffer.write(data) - bytes_remaining -= len(data) + def is_next(self, other: 'State') -> bool: + return other.as_int() == self.as_int() + 1 - return buffer.getvalue() + def is_on_or_after(self, other: 'State') -> bool: + return self.value[0] >= other.value[0] - -async def coro_receive_pickled_value(stream: trio.abc.ReceiveStream) -> Any: - len_bytes = await coro_read_exactly(stream, 4) - serialized_len = int.from_bytes(len_bytes, 'big') - serialized_result = await coro_read_exactly(stream, serialized_len) - return cloudpickle.loads(serialized_result) + def is_before(self, other: 'State') -> bool: + return self.value[0] < other.value[0] class empty: pass +class ProcessException(Exception): + pass + + class Process(Awaitable[TReturn]): returncode: Optional[int] = None _pid: Optional[int] = None - _result: Optional[TReturn] = empty _returncode: Optional[int] = None + _return_value: Optional[TReturn] = empty _error: Optional[BaseException] = None + _state: State = State.INITIALIZING def __init__(self, async_fn: Callable[..., TReturn], args: Sequence[TReturn]) -> None: self._async_fn = async_fn @@ -81,12 +90,46 @@ def __init__(self, async_fn: Callable[..., TReturn], args: Sequence[TReturn]) -> self._has_pid = trio.Event() self._has_returncode = trio.Event() - self._has_result = trio.Event() + self._has_return_value = trio.Event() self._has_error = trio.Event() + self._state_changed = trio.Event() def __await__(self) -> TReturn: return self.run().__await__() + # + # State + # + @property + def state(self) -> State: + return self._state + + @state.setter + def state(self, value: State) -> State: + if not self._state.is_next(value): + raise Exception(f"Invalid state transition: {self.state} -> {value}") + self._state = value + self._state_changed.set() + self._state_changed = trio.Event() + + async def wait_for_state(self, state: State) -> None: + """ + Block until the process as reached the + """ + if self.state.is_on_or_after(state): + return + + for _ in range(len(State)): + await self._state_changed.wait() + if self.state.is_on_or_after(state): + break + else: + raise Exception( + f"This code path should not be reachable since there are a " + f"finite number of state transitions. Current state is " + f"{self.state}" + ) + # # PID # @@ -106,22 +149,22 @@ async def wait_pid(self) -> int: return self.pid # - # Result + # Return Value # @property - def result(self) -> int: - if self._result is empty: - raise AttributeError("No result set") - return self._result + def return_value(self) -> int: + if self._return_value is empty: + raise AttributeError("No return_value set") + return self._return_value - @result.setter - def result(self, value: int) -> None: - self._result = value - self._has_result.set() + @return_value.setter + def return_value(self, value: int) -> None: + self._return_value = value + self._has_return_value.set() - async def wait_result(self) -> int: - await self._has_result.wait() - return self.result + async def wait_return_value(self) -> int: + await self._has_return_value.wait() + return self.return_value # # Return Code @@ -159,6 +202,35 @@ async def wait_error(self) -> int: await self._has_error.wait() return self.error + # + # Result + # + @property + def result(self) -> TReturn: + if self._error is None and self._return_value is empty: + raise AttributeError("Process not done") + elif self._error is not None: + raise self._error + elif self._return_value is not empty: + return self._return_value + else: + raise Exception("Code path should be unreachable") + + async def wait_result(self) -> TReturn: + """ + Block until the process has exited, either returning the return value + if execution was successful, or raising an exception if it failed + """ + await self.wait_returncode() + + if self.returncode == 0: + return await self.wait_return_value() + else: + raise await self.wait_error() + + # + # Lifecycle management APIs + # async def wait(self) -> None: """ Block until the process has exited. @@ -166,9 +238,9 @@ async def wait(self) -> None: await self.wait_returncode() if self.returncode == 0: - await self.wait_result() + await self.wait_return_value() else: - raise await self.wait_error() + await self.wait_error() def poll(self) -> Optional[int]: """ @@ -178,6 +250,8 @@ def poll(self) -> Optional[int]: def kill(self) -> None: self.send_signal(signal.SIGKILL) + self.status = State.FINISHED + self.error = ProcessException("Process terminated with SIGKILL") def terminate(self) -> None: self.send_signal(signal.SIGTERM) @@ -186,73 +260,178 @@ def send_signal(self, sig: int) -> None: os.kill(self.pid, sig) -async def _monitor_sub_proc(proc: Process) -> None: - parent_r, child_w = os.pipe() - child_r, parent_w = os.pipe() - parent_pid = os.getpid() +def pickle_value(value: Any) -> bytes: + serialized_value = cloudpickle.dumps(value) + return struct.pack('>I', len(serialized_value)) + serialized_value - command = get_subprocess_command( - child_r, - child_w, - parent_pid, - ) - async with await trio.open_file(child_w, 'wb', closefd=False) as to_parent: - async with await trio.open_file(child_r, 'rb', closefd=False) as from_parent: - sub_proc = await trio.open_process(command, stdin=from_parent, stdout=to_parent) - logger.debug('starting subprocess to run %s', proc) - async with sub_proc: - # set the process ID - proc.pid = sub_proc.pid - logger.debug('subprocess for %s started. pid=%d', proc, proc.pid) - - logger.debug('writing execution data for %s over stdin', proc) - # pass the child process the serialized `async_fn` - # and `args` over stdin. - async with await trio.open_file(parent_w, 'wb', closefd=True) as to_child: - await to_child.write(pickle_value((proc._async_fn, proc._args))) - await to_child.flush() - logger.debug('waiting for process %s finish', proc) +async def coro_read_exactly(stream: trio.abc.ReceiveStream, num_bytes: int) -> bytes: + buffer = io.BytesIO() + bytes_remaining = num_bytes + while bytes_remaining > 0: + data = await stream.receive_some(bytes_remaining) + if data == b'': + raise Exception("End of stream...") + buffer.write(data) + bytes_remaining -= len(data) + + return buffer.getvalue() + + +async def coro_receive_pickled_value(stream: trio.abc.ReceiveStream) -> Any: + logger.info('waiting for pickled length') + len_bytes = await coro_read_exactly(stream, 4) + serialized_len = int.from_bytes(len_bytes, 'big') + logger.info('got pickled length: %s', serialized_len) + logger.info('waiting for pickled payload') + serialized_result = await coro_read_exactly(stream, serialized_len) + logger.info('got pickled payload') + return cloudpickle.loads(serialized_result) + + +async def _monitor_sub_proc(proc: Process, sub_proc: trio.Process, parent_w: int) -> None: + logger.debug('starting subprocess to run %s', proc) + async with sub_proc: + # set the process ID + proc.pid = sub_proc.pid + logger.debug('subprocess for %s started. pid=%d', proc, proc.pid) + + # we write the execution data immediately without waiting for the + # `WAIT_EXEC_DATA` state to ensure that the child process doesn't have + # to wait for that data due to the round trip times between processes. + logger.debug('writing execution data for %s over stdin', proc) + # pass the child process the serialized `async_fn` and `args` + async with trio.hazmat.FdStream(parent_w) as to_child: + await to_child.send_all(pickle_value((proc._async_fn, proc._args))) + + # this wait ensures that we + with trio.fail_after(5): + await proc.wait_for_state(State.WAIT_EXEC_DATA) + + with trio.fail_after(5): + await proc.wait_for_state(State.EXECUTING) + logger.debug('waiting for process %s finish', proc) proc.returncode = sub_proc.returncode logger.debug('process %s finished: returncode=%d', proc, proc.returncode) - async with await trio.open_file(parent_r, 'rb', closefd=True) as from_child: - if proc.returncode == 0: - logger.debug('setting result for process %s', proc) - proc.result = await coro_receive_pickled_value(from_child) - else: - with trio.move_on_after(2) as scope: - logger.debug('setting error for process %s', proc) - proc.error = await coro_receive_pickled_value(from_child) - if scope.cancelled_caught: - logger.debug('process %s exited due unknown reason.', proc) - proc.error = SystemExit(proc.returncode) - -async def _monitory_signals(proc: Process, signal_aiter: AsyncIterator[int]) -> None: +async def _relay_signals(proc: Process, signal_aiter: AsyncIterator[int]) -> None: async for signum in signal_aiter: - logger.info('GOT SIGNAL: %s', signum) + if proc.state.is_before(State.STARTED): + # If the process has not reached the state where the child process + # can properly handle the signal, give it a moment to reach the + # `STARTED` stage. + with trio.fail_after(1): + await proc.wait_for_state(State.STARTED) + logger.debug('relaying signal %s to child process %s', signum, proc) proc.send_signal(signum) +async def _monitor_state(proc: Process, from_child: trio.hazmat.FdStream) -> None: + for current_state, next_state in sliding_window(2, State): + if proc.state is not current_state: + raise Exception( + f"Invalid state. proc in state {proc.state} but expected state {current_state}" + ) + + child_state_as_byte = await coro_read_exactly(from_child, 1) + + try: + child_state = State(child_state_as_byte) + except TypeError: + raise Exception(f"Invalid state. child sent state: {child_state_as_byte.hex()}") + + if child_state is not next_state: + raise Exception( + f"Invalid state. child sent state {child_state_as_byte.hex()} " + f"but expected state {next_state}" + ) + + proc.state = child_state + + if proc.state is not State.FINISHED: + raise Exception(f"Invalid final state: {proc.state}") + + result = await coro_receive_pickled_value(from_child) + + # The `returncode` should already be set but we do a quick wait to ensure + # that it will be set when we access it below. + with trio.fail_after(5): + await proc.wait_returncode() + + if proc.returncode == 0: + proc.return_value = result + else: + proc.error = result + + +RELAY_SIGNALS = (signal.SIGINT, signal.SIGTERM, signal.SIGHUP) + + @asynccontextmanager @trio_typing.takes_callable_and_args async def open_in_process(async_fn: Callable[..., TReturn], *args: Any) -> AsyncIterator[Process]: proc = Process(async_fn, args) - async with trio.open_nursery() as nursery: - nursery.start_soon(_monitor_sub_proc, proc) - - await proc.wait_pid() + parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() + parent_pid = os.getpid() - with trio.open_signal_receiver(signal.SIGTERM) as signal_aiter: - nursery.start_soon(_monitory_signals, proc, signal_aiter) + command = get_subprocess_command( + child_r, + child_w, + parent_pid, + ) - yield proc - await proc.wait() + sub_proc = await trio.open_process( + command, + # stdin=subprocess.PIPE, + # stdout=subprocess.PIPE, + # stderr=subprocess.PIPE, + pass_fds=(child_r, child_w), + ) - nursery.cancel_scope.cancel() + async with trio.open_nursery() as nursery: + nursery.start_soon(_monitor_sub_proc, proc, sub_proc, parent_w) + + async with trio.hazmat.FdStream(parent_r) as from_child: + with trio.open_signal_receiver(*RELAY_SIGNALS) as signal_aiter: + # Monitor the child stream for incoming updates to the state of + # the child process. + nursery.start_soon(_monitor_state, proc, from_child) + + # Relay any appropriate signals to the child process. + nursery.start_soon(_relay_signals, proc, signal_aiter) + + await proc.wait_pid() + + # Wait until the child process has reached the STARTED + # state before yielding the context. This ensures that any + # calls to things like `terminate` or `kill` will be handled + # properly in the child process. + # + # The timeout ensures that if something is fundamentally wrong + # with the subprocess we don't hang indefinitely. + with trio.fail_after(5): + await proc.wait_for_state(State.STARTED) + + try: + yield proc + except KeyboardInterrupt as err: + # If a keyboard interrupt is encountered relay it to the + # child process and then give it a moment to cleanup before + # re-raising + try: + proc.send_signal(signal.SIGINT) + with trio.move_on_after(1): + await proc.wait() + finally: + raise err + + await proc.wait() + + nursery.cancel_scope.cancel() @trio_typing.takes_callable_and_args @@ -312,24 +491,93 @@ def receive_pickled_value(stream: io.BytesIO) -> Any: return cloudpickle.loads(serialized_result) +def update_state(to_parent: BinaryIO, state: State) -> None: + to_parent.write(state.value) + to_parent.flush() + + +def update_state_finished(to_parent: BinaryIO, finished_payload: bytes) -> None: + payload = State.FINISHED.value + finished_payload + to_parent.write(payload) + to_parent.flush() + + +SHUTDOWN_SIGNALS = {signal.SIGTERM} + + +async def _do_monitor_signals(signal_aiter: AsyncIterator[int]): + async for signum in signal_aiter: + raise SystemExit(signum) + + +@trio_typing.takes_callable_and_args +async def _do_async_fn(async_fn: Callable[..., TReturn], + args: Sequence[Any], + to_parent: trio.hazmat.FdStream) -> TReturn: + with trio.open_signal_receiver(*SHUTDOWN_SIGNALS) as signal_aiter: + # state: STARTED + update_state(to_parent, State.STARTED) + + async with trio.open_nursery() as nursery: + nursery.start_soon(_do_monitor_signals, signal_aiter) + + # state: EXECUTING + update_state(to_parent, State.EXECUTING) + + result = await async_fn(*args) + + # state: STOPPING + update_state(to_parent, State.STOPPING) + + nursery.cancel_scope.cancel() + return result + + def _run_process(parent_pid: int, fd_read: int, fd_write: int) -> None: - with os.fdopen(sys.stdin.fileno(), 'rb', closefd=True) as stdin_binary: - async_fn, args = receive_pickled_value(stdin_binary) - - # TODO: signal handling - try: - result = trio.run(async_fn, *args) - except BaseException as err: - with os.fdopen(sys.stdout.fileno(), 'wb', closefd=True) as stdout_binary: - stdout_binary.write(pickle_value(err)) - sys.exit(1) - else: - logger.debug("Ran successfully: %r", result) - with os.fdopen(sys.stdout.fileno(), 'wb', closefd=True) as stdout_binary: - stdout_binary.write(pickle_value(result)) - sys.exit(0) + """ + Run the child process + """ + # state: INITIALIZING + with os.fdopen(fd_write, 'wb', closefd=True) as to_parent: + # state: INITIALIZED + update_state(to_parent, State.INITIALIZED) + with os.fdopen(fd_read, 'rb', closefd=True) as from_parent: + # state: WAIT_EXEC_DATA + update_state(to_parent, State.WAIT_EXEC_DATA) + async_fn, args = receive_pickled_value(from_parent) + + # state: BOOTING + update_state(to_parent, State.BOOTING) + + try: + try: + result = trio.run( + _do_async_fn, + async_fn, + args, + to_parent, + ) + except BaseException as err: + # state: STOPPING + update_state(to_parent, State.STOPPING) + finished_payload = pickle_value(err) + raise + except KeyboardInterrupt: + code = 2 + except SystemExit as err: + code = err.args[0] + except BaseException: + code = 1 + else: + # state: STOPPING (set from within _do_async_fn) + finished_payload = pickle_value(result) + code = 0 + finally: + # state: FINISHED + update_state_finished(to_parent, finished_payload) + sys.exit(code) if __name__ == "__main__": diff --git a/tests-trio/test_trio_run_in_process.py b/tests-trio/test_trio_run_in_process.py index f7e521a772..6b6a66aac5 100644 --- a/tests-trio/test_trio_run_in_process.py +++ b/tests-trio/test_trio_run_in_process.py @@ -1,23 +1,29 @@ +import pickle +import signal import tempfile import pytest import trio -from p2p.trio_run_in_process import run_in_process, open_in_process +from p2p.trio_run_in_process import run_in_process, open_in_process, ProcessException + + +@pytest.fixture +def touch_path(): + with tempfile.TemporaryDirectory() as base_dir: + yield trio.Path(base_dir) / 'touch.txt' @pytest.mark.trio -async def test_run_in_process_touch_file(): +async def test_run_in_process_touch_file(touch_path): async def touch_file(path: trio.Path): await path.touch() with trio.fail_after(2): - with tempfile.TemporaryDirectory() as base_dir: - path = trio.Path(base_dir) / 'test.txt' - assert not await path.exists() - await run_in_process(touch_file, path) - assert await path.exists() + assert not await touch_path.exists() + await run_in_process(touch_file, touch_path) + assert await touch_path.exists() @pytest.mark.trio @@ -41,7 +47,7 @@ async def raise_err(): @pytest.mark.trio -async def test_open_in_proc_can_terminate(): +async def test_open_in_proc_termination_while_running(): async def do_sleep_forever(): import trio await trio.sleep_forever() @@ -49,29 +55,67 @@ async def do_sleep_forever(): with trio.fail_after(2): async with open_in_process(do_sleep_forever) as proc: proc.terminate() + assert proc.returncode == 15 + + +@pytest.mark.trio +async def test_open_in_proc_kill_while_running(): + async def do_sleep_forever(): + import trio + await trio.sleep_forever() + + with trio.fail_after(2): + async with open_in_process(do_sleep_forever) as proc: + proc.kill() + assert proc.returncode == -9 + assert isinstance(proc.error, ProcessException) + + +@pytest.mark.trio +async def test_open_proc_interrupt_while_running(): + async def monitor_for_interrupt(): + import trio + await trio.sleep_forever() + + with trio.fail_after(2): + async with open_in_process(monitor_for_interrupt) as proc: + proc.send_signal(signal.SIGINT) + assert proc.returncode == 2 + + +@pytest.mark.trio +async def test_open_proc_invalid_function_call(): + async def takes_no_args(): + pass + + with trio.fail_after(2): + async with open_in_process(takes_no_args, 1, 2, 3) as proc: + pass assert proc.returncode == 1 + assert isinstance(proc.error, TypeError) -@pytest.mark.skip @pytest.mark.trio -async def test_run_in_process_handles_keyboard_interrupt(): - async def monitor_for_interrupt(path): +async def test_open_proc_unpickleable_params(touch_path): + async def takes_open_file(f): + pass + + with trio.fail_after(2): + with pytest.raises(pickle.PickleError): + with open(touch_path, 'w') as touch_file: + async with open_in_process(takes_open_file, touch_file): + # this code block shouldn't get executed + assert False + + +@pytest.mark.trio +async def test_open_proc_outer_KeyboardInterrupt(): + async def sleep_forever(): import trio - try: - await trio.sleep_forever() - except KeyboardInterrupt: - await path.touch() - else: - assert False + await trio.sleep_forever() with trio.fail_after(2): - with tempfile.TemporaryDirectory() as base_dir: - # TODO - path = trio.Path(base_dir) / 'test.txt' - assert not await path.exists() - async with open_in_process(monitor_for_interrupt, path) as proc: - print('killing') - proc.terminate() - print('killed') - print('finished') - assert await path.exists() + with pytest.raises(KeyboardInterrupt): + async with open_in_process(sleep_forever) as proc: + raise KeyboardInterrupt + assert proc.returncode == 2