diff --git a/distributed/process.py b/distributed/process.py index debbf025cc6..0663229ccc5 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -1,13 +1,16 @@ from __future__ import annotations import asyncio +import inspect import logging import multiprocessing import os import re import threading import weakref +from collections.abc import Callable from queue import Queue as PyQueue +from typing import TypeVar from tornado.concurrent import Future from tornado.ioloop import IOLoop @@ -47,6 +50,9 @@ class _ProcessState: exitcode = None +_T_async_process = TypeVar("_T_async_process", bound="AsyncProcess") + + class AsyncProcess: """ A coroutine-compatible multiprocessing.Process-alike. @@ -126,9 +132,9 @@ def stop_thread(q): self._thread_finalizer = weakref.finalize(self, stop_thread, q=self._watch_q) self._thread_finalizer.atexit = False - def _on_exit(self, exitcode): + def _on_exit(self, exitcode: int) -> None: # Called from the event loop when the child process exited - self._process = None + self._process = None # type: ignore[assignment] if self._exit_callback is not None: self._exit_callback(self) self._exit_future.set_result(exitcode) @@ -311,14 +317,19 @@ def close(self): self._process = None self._closed = True - def set_exit_callback(self, func): + def set_exit_callback( + self: _T_async_process, func: Callable[[_T_async_process], None] + ) -> None: """ Set a function to be called by the event loop when the process exits. The function is called with the AsyncProcess as sole argument. - The function may be a coroutine function. + The function may not be a coroutine function. """ # XXX should this be a property instead? + assert not inspect.iscoroutinefunction( + func + ), "exit callback may not be a coroutine function" assert callable(func), "exit callback should be callable" assert ( self._state.pid is None