From 1fd35faeda4100996306c38df82c320802428897 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Sat, 22 Aug 2020 22:11:21 +0200 Subject: [PATCH 1/9] Add types in manager.py --- .github/workflows/main.yml | 2 + jupyter_client/manager.py | 200 +++++++++++++++++++++++++------------ setup.py | 1 + 3 files changed, 138 insertions(+), 65 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 056381cc9..ac653328b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,6 +31,8 @@ jobs: pip install --upgrade setuptools pip pip install --upgrade --upgrade-strategy eager -e .[test] pytest-cov codecov 'coverage<5' pip freeze + - name: Check types + run: mypy jupyter_client/manager.py - name: Run the tests run: py.test --cov jupyter_client -v jupyter_client - name: Code coverage diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index cf9fb5db1..856b6d490 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -11,13 +11,15 @@ import sys import time import warnings +from subprocess import Popen +import typing from enum import Enum import zmq from .localinterfaces import is_local_ip, local_ips -from traitlets import ( +from traitlets import ( # type: ignore Any, Float, Instance, Unicode, List, Bool, Type, DottedObjectName, default, observe, observe_compat ) @@ -25,6 +27,7 @@ from jupyter_client import ( launch_kernel, kernelspec, + KernelClient, ) from .connect import ConnectionFileMixin from .managerabc import ( @@ -55,39 +58,45 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._shutdown_status = _ShutdownStatus.Unset - _created_context = Bool(False) + _created_context: Bool = Bool(False) # The PyZMQ Context to use for communication with the kernel. - context = Instance(zmq.Context) - def _context_default(self): + context: Instance = Instance(zmq.Context) + def _context_default(self) -> zmq.Context: self._created_context = True return zmq.Context() # the class to create with our `client` method - client_class = DottedObjectName('jupyter_client.blocking.BlockingKernelClient') - client_factory = Type(klass='jupyter_client.KernelClient') - def _client_factory_default(self): + client_class: DottedObjectName = DottedObjectName('jupyter_client.blocking.BlockingKernelClient') + client_factory: Type = Type(klass='jupyter_client.KernelClient') + def _client_factory_default(self) -> Type: return import_item(self.client_class) @observe('client_class') - def _client_class_changed(self, change): + def _client_class_changed( + self, + change: typing.Dict[str, DottedObjectName] + ) -> None: self.client_factory = import_item(str(change['new'])) # The kernel process with which the KernelManager is communicating. # generally a Popen instance - kernel = Any() + kernel: Any = Any() - kernel_spec_manager = Instance(kernelspec.KernelSpecManager) + kernel_spec_manager: Instance = Instance(kernelspec.KernelSpecManager) - def _kernel_spec_manager_default(self): + def _kernel_spec_manager_default(self) -> kernelspec.KernelSpecManager: return kernelspec.KernelSpecManager(data_dir=self.data_dir) @observe('kernel_spec_manager') @observe_compat - def _kernel_spec_manager_changed(self, change): + def _kernel_spec_manager_changed( + self, + change: typing.Dict[str, Instance] + ) -> None: self._kernel_spec = None - shutdown_wait_time = Float( + shutdown_wait_time: Float = Float( 5.0, config=True, help="Time to wait for a kernel to terminate before killing it, " "in seconds. When a shutdown request is initiated, the kernel " @@ -98,23 +107,26 @@ def _kernel_spec_manager_changed(self, change): "and kill may be equivalent on windows.", ) - kernel_name = Unicode(kernelspec.NATIVE_KERNEL_NAME) + kernel_name: Unicode = Unicode(kernelspec.NATIVE_KERNEL_NAME) @observe('kernel_name') - def _kernel_name_changed(self, change): + def _kernel_name_changed( + self, + change: typing.Dict[str, Unicode] + ) -> None: self._kernel_spec = None if change['new'] == 'python': self.kernel_name = kernelspec.NATIVE_KERNEL_NAME - _kernel_spec = None + _kernel_spec: typing.Optional[kernelspec.KernelSpec] = None @property - def kernel_spec(self): + def kernel_spec(self) -> typing.Optional[kernelspec.KernelSpec]: if self._kernel_spec is None and self.kernel_name != '': self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(self.kernel_name) return self._kernel_spec - kernel_cmd = List(Unicode(), config=True, + kernel_cmd: List = List(Unicode(), config=True, help="""DEPRECATED: Use kernel_name instead. The Popen Command to launch the kernel. @@ -132,29 +144,29 @@ def _kernel_cmd_changed(self, name, old, new): warnings.warn("Setting kernel_cmd is deprecated, use kernel_spec to " "start different kernels.") - cache_ports = Bool(help='True if the MultiKernelManager should cache ports for this KernelManager instance') + cache_ports: Bool = Bool(help='True if the MultiKernelManager should cache ports for this KernelManager instance') @default('cache_ports') - def _default_cache_ports(self): + def _default_cache_ports(self) -> bool: return self.transport == 'tcp' @property - def ipykernel(self): + def ipykernel(self) -> bool: return self.kernel_name in {'python', 'python2', 'python3'} # Protected traits - _launch_args = Any() - _control_socket = Any() + _launch_args: Any = Any() + _control_socket: Any = Any() - _restarter = Any() + _restarter: Any = Any() - autorestart = Bool(True, config=True, + autorestart: Bool = Bool(True, config=True, help="""Should we autorestart the kernel if it dies.""" ) - shutting_down = False + shutting_down: bool = False - def __del__(self): + def __del__(self) -> None: self._close_control_socket() self.cleanup_connection_file() @@ -162,19 +174,27 @@ def __del__(self): # Kernel restarter #-------------------------------------------------------------------------- - def start_restarter(self): + def start_restarter(self) -> None: pass - def stop_restarter(self): + def stop_restarter(self) -> None: pass - def add_restart_callback(self, callback, event='restart'): + def add_restart_callback( + self, + callback: typing.Callable, + event: str = 'restart' + ) -> None: """register a callback to be called when a kernel is restarted""" if self._restarter is None: return self._restarter.add_callback(callback, event) - def remove_restart_callback(self, callback, event='restart'): + def remove_restart_callback( + self, + callback: typing.Callable, + event: str ='restart' + ) -> None: """unregister a callback to be called when a kernel is restarted""" if self._restarter is None: return @@ -184,7 +204,7 @@ def remove_restart_callback(self, callback, event='restart'): # create a Client connected to our Kernel #-------------------------------------------------------------------------- - def client(self, **kwargs): + def client(self, **kwargs) -> KernelClient: """Create a client configured to connect to our kernel""" kw = {} kw.update(self.get_connection_info(session=True)) @@ -201,12 +221,16 @@ def client(self, **kwargs): # Kernel management #-------------------------------------------------------------------------- - def format_kernel_cmd(self, extra_arguments=None): + def format_kernel_cmd( + self, + extra_arguments: typing.Optional[typing.List[str]] = None + ) -> typing.List[str]: """replace templated args (e.g. {connection_file})""" extra_arguments = extra_arguments or [] if self.kernel_cmd: cmd = self.kernel_cmd + extra_arguments else: + assert self.kernel_spec is not None cmd = self.kernel_spec.argv + extra_arguments if cmd and cmd[0] in {'python', @@ -239,9 +263,13 @@ def from_ns(match): """Get the key out of ns if it's there, otherwise no change.""" return ns.get(match.group(1), match.group()) - return [ pat.sub(from_ns, arg) for arg in cmd ] + return [pat.sub(from_ns, arg) for arg in cmd] - def _launch_kernel(self, kernel_cmd, **kw): + def _launch_kernel( + self, + kernel_cmd: typing.List[str], + **kw + ) -> typing.Union[Popen, typing.Coroutine[typing.Any, typing.Any, Popen]]: """actually launch the kernel override in a subclass to launch kernel subprocesses differently @@ -250,18 +278,18 @@ def _launch_kernel(self, kernel_cmd, **kw): # Control socket used for polite kernel shutdown - def _connect_control_socket(self): + def _connect_control_socket(self) -> None: if self._control_socket is None: self._control_socket = self._create_connected_socket('control') self._control_socket.linger = 100 - def _close_control_socket(self): + def _close_control_socket(self) -> None: if self._control_socket is None: return self._control_socket.close() self._control_socket = None - def pre_start_kernel(self, **kw): + def pre_start_kernel(self, **kw) -> typing.Tuple[typing.List[str], typing.Dict[str, typing.Any]]: """Prepares a kernel for startup in a separate process. If random ports (port=0) are being used, this method must be called @@ -297,12 +325,17 @@ def pre_start_kernel(self, **kw): if not self.kernel_cmd: # If kernel_cmd has been set manually, don't refer to a kernel spec. # Environment variables from kernel spec are added to os.environ. + assert self.kernel_spec is not None env.update(self._get_env_substitutions(self.kernel_spec.env, env)) kw['env'] = env return kernel_cmd, kw - def _get_env_substitutions(self, templated_env, substitution_values): + def _get_env_substitutions( + self, + templated_env: typing.Optional[typing.Dict[str, str]], + substitution_values: typing.Dict[str, str] + ) -> typing.Optional[typing.Dict[str, str]]: """ Walks env entries in templated_env and applies possible substitutions from current env (represented by substitution_values). Returns the substituted list of env entries. @@ -318,7 +351,7 @@ def _get_env_substitutions(self, templated_env, substitution_values): substituted_env.update({k: Template(v).safe_substitute(substitution_values)}) return substituted_env - def post_start_kernel(self, **kw): + def post_start_kernel(self, **kw) -> None: self.start_restarter() self._connect_control_socket() @@ -341,7 +374,10 @@ def start_kernel(self, **kw): self.kernel = self._launch_kernel(kernel_cmd, **kw) self.post_start_kernel(**kw) - def request_shutdown(self, restart=False): + def request_shutdown( + self, + restart: bool = False + ) -> None: """Send a shutdown request via control channel """ content = dict(restart=restart) @@ -350,7 +386,11 @@ def request_shutdown(self, restart=False): self._connect_control_socket() self.session.send(self._control_socket, msg) - def finish_shutdown(self, waittime=None, pollinterval=0.1): + def finish_shutdown( + self, + waittime: typing.Optional[float] = None, + pollinterval: float = 0.1 + ) -> None: """Wait for kernel shutdown, then kill process if it doesn't shutdown. This does not send shutdown requests - use :meth:`request_shutdown` @@ -398,7 +438,10 @@ def poll_or_sleep_to_kernel_gone(): self._shutdown_status = _ShutdownStatus.SigkillRequest self._kill_kernel() - def cleanup_resources(self, restart=False): + def cleanup_resources( + self, + restart: bool = False + ) -> None: """Clean up resources when the kernel is shut down""" if not restart: self.cleanup_connection_file() @@ -410,13 +453,20 @@ def cleanup_resources(self, restart=False): if self._created_context and not restart: self.context.destroy(linger=100) - def cleanup(self, connection_file=True): + def cleanup( + self, + connection_file: bool = True + ) -> None: """Clean up resources when the kernel is shut down""" warnings.warn("Method cleanup(connection_file=True) is deprecated, use cleanup_resources(restart=False).", FutureWarning) self.cleanup_resources(restart=not connection_file) - def shutdown_kernel(self, now=False, restart=False): + def shutdown_kernel( + self, + now: bool = False, + restart: bool = False + ): """Attempts to stop the kernel process cleanly. This attempts to shutdown the kernels cleanly by: @@ -470,7 +520,12 @@ def shutdown_kernel(self, now=False, restart=False): else: self.cleanup_resources(restart=restart) - def restart_kernel(self, now=False, newports=False, **kw): + def restart_kernel( + self, + now: bool = False, + newports: bool = False, + **kw + ) -> None: """Restarts a kernel with the arguments that were used to launch it. Parameters @@ -510,11 +565,11 @@ def restart_kernel(self, now=False, newports=False, **kw): self.start_kernel(**self._launch_args) @property - def has_kernel(self): + def has_kernel(self) -> bool: """Has a kernel been started that we are managing.""" return self.kernel is not None - def _send_kernel_sigterm(self): + def _send_kernel_sigterm(self) -> None: """similar to _kill_kernel, but with sigterm (not sigkill), but do not block""" if self.has_kernel: # Signal the kernel to terminate (sends SIGTERM on Unix and @@ -534,7 +589,7 @@ def _send_kernel_sigterm(self): # In Windows, we will get an Access Denied error if the process # has already terminated. Ignore it. if sys.platform == "win32": - if e.winerror != 5: + if e.winerror != 5: # type: ignore raise # On Unix, we may get an ESRCH error if the process has already # terminated. Ignore it. @@ -544,7 +599,7 @@ def _send_kernel_sigterm(self): if e.errno != ESRCH: raise - def _kill_kernel(self): + def _kill_kernel(self) -> None: """Kill the running kernel. This is a private method, callers should use shutdown_kernel(now=True). @@ -554,14 +609,14 @@ def _kill_kernel(self): # TerminateProcess() on Win32). try: if hasattr(signal, 'SIGKILL'): - self.signal_kernel(signal.SIGKILL) + self.signal_kernel(signal.SIGKILL) # type: ignore else: self.kernel.kill() except OSError as e: # In Windows, we will get an Access Denied error if the process # has already terminated. Ignore it. if sys.platform == 'win32': - if e.winerror != 5: + if e.winerror != 5: # type: ignore raise # On Unix, we may get an ESRCH error if the process has already # terminated. Ignore it. @@ -574,13 +629,14 @@ def _kill_kernel(self): self.kernel.wait() self.kernel = None - def interrupt_kernel(self): + def interrupt_kernel(self) -> None: """Interrupts the kernel by sending it a signal. Unlike ``signal_kernel``, this operation is well supported on all platforms. """ if self.has_kernel: + assert self.kernel_spec is not None interrupt_mode = self.kernel_spec.interrupt_mode if interrupt_mode == 'signal': if sys.platform == 'win32': @@ -596,7 +652,10 @@ def interrupt_kernel(self): else: raise RuntimeError("Cannot interrupt kernel. No kernel is running!") - def signal_kernel(self, signum): + def signal_kernel( + self, + signum: int + ) -> None: """Sends a signal to the process group of the kernel (this usually includes the kernel and any subprocesses spawned by the kernel). @@ -607,8 +666,8 @@ def signal_kernel(self, signum): if self.has_kernel: if hasattr(os, "getpgid") and hasattr(os, "killpg"): try: - pgid = os.getpgid(self.kernel.pid) - os.killpg(pgid, signum) + pgid = os.getpgid(self.kernel.pid) # type: ignore + os.killpg(pgid, signum) # type: ignore return except OSError: pass @@ -616,7 +675,7 @@ def signal_kernel(self, signum): else: raise RuntimeError("Cannot signal kernel. No kernel is running!") - def is_alive(self): + def is_alive(self) -> bool: """Is the kernel process still running?""" if self.has_kernel: if self.kernel.poll() is None: @@ -631,16 +690,15 @@ def is_alive(self): class AsyncKernelManager(KernelManager): """Manages kernels in an asynchronous manner """ - client_class = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient') - client_factory = Type(klass='jupyter_client.asynchronous.AsyncKernelClient') + client_class: DottedObjectName = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient') + client_factory: Type = Type(klass='jupyter_client.asynchronous.AsyncKernelClient') async def _launch_kernel(self, kernel_cmd, **kw): """actually launch the kernel override in a subclass to launch kernel subprocesses differently """ - res = launch_kernel(kernel_cmd, **kw) - return res + return launch_kernel(kernel_cmd, **kw) async def start_kernel(self, **kw): """Starts a kernel in a separate process in an asynchronous manner. @@ -851,6 +909,7 @@ async def interrupt_kernel(self): platforms. """ if self.has_kernel: + assert self.kernel_spec is not None interrupt_mode = self.kernel_spec.interrupt_mode if interrupt_mode == 'signal': if sys.platform == 'win32': @@ -897,7 +956,10 @@ async def is_alive(self): # we don't have a kernel return False - async def _async_wait(self, pollinterval=0.1): + async def _async_wait( + self, + pollinterval: float = 0.1 + ) -> None: # Use busy loop at 100ms intervals, polling until the process is # not alive. If we find the process is no longer alive, complete # its cleanup via the blocking wait(). Callers are responsible for @@ -909,7 +971,11 @@ async def _async_wait(self, pollinterval=0.1): KernelManagerABC.register(KernelManager) -def start_new_kernel(startup_timeout=60, kernel_name='python', **kwargs): +def start_new_kernel( + startup_timeout: float =60, + kernel_name: str = 'python', + **kwargs + ) -> typing.Tuple[KernelManager, KernelClient]: """Start a new kernel, and return its Manager and Client""" km = KernelManager(kernel_name=kernel_name) km.start_kernel(**kwargs) @@ -925,7 +991,11 @@ def start_new_kernel(startup_timeout=60, kernel_name='python', **kwargs): return km, kc -async def start_new_async_kernel(startup_timeout=60, kernel_name='python', **kwargs): +async def start_new_async_kernel( + startup_timeout: float = 60, + kernel_name: str = 'python', + **kwargs + ) -> typing.Tuple[KernelManager, KernelClient]: """Start a new kernel, and return its Manager and Client""" km = AsyncKernelManager(kernel_name=kernel_name) await km.start_kernel(**kwargs) @@ -942,7 +1012,7 @@ async def start_new_async_kernel(startup_timeout=60, kernel_name='python', **kwa @contextmanager -def run_kernel(**kwargs): +def run_kernel(**kwargs) -> typing.Iterator[KernelClient]: """Context manager to create a kernel in a subprocess. The kernel is shut down when the context exits. diff --git a/setup.py b/setup.py index 1952004b0..9b4c8f4f2 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,7 @@ def run(self): 'pytest-asyncio', 'pytest-timeout', 'pytest', + 'mypy', ], 'doc': open('docs/requirements.txt').read().splitlines(), }, From 0f819f7f1780e8c820630771434338e80d251389 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Tue, 9 Mar 2021 17:24:38 +0100 Subject: [PATCH 2/9] Refactor BlockingKernelManager/AsyncKernelManager --- jupyter_client/__init__.py | 2 +- jupyter_client/client.py | 16 +- jupyter_client/consoleapp.py | 6 +- jupyter_client/kernelapp.py | 8 +- jupyter_client/manager.py | 472 +++++------------- jupyter_client/tests/test_kernelmanager.py | 10 +- .../tests/test_multikernelmanager.py | 2 +- jupyter_client/tests/test_public_api.py | 4 +- jupyter_client/util.py | 14 + 9 files changed, 160 insertions(+), 374 deletions(-) create mode 100644 jupyter_client/util.py diff --git a/jupyter_client/__init__.py b/jupyter_client/__init__.py index f72c516d3..122010421 100644 --- a/jupyter_client/__init__.py +++ b/jupyter_client/__init__.py @@ -4,7 +4,7 @@ from .connect import * from .launcher import * from .client import KernelClient -from .manager import KernelManager, AsyncKernelManager, run_kernel +from .manager import KernelManager, BlockingKernelManager, AsyncKernelManager, run_kernel from .blocking import BlockingKernelClient from .asynchronous import AsyncKernelClient from .multikernelmanager import MultiKernelManager, AsyncMultiKernelManager diff --git a/jupyter_client/client.py b/jupyter_client/client.py index 760ac5266..7c8028680 100644 --- a/jupyter_client/client.py +++ b/jupyter_client/client.py @@ -3,11 +3,13 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +import typing as t + from jupyter_client.channels import major_protocol_version import zmq -from traitlets import ( +from traitlets import ( # type: ignore Any, Instance, Type, ) @@ -19,11 +21,13 @@ # some utilities to validate message structure, these might get moved elsewhere # if they prove to have more generic utility -def validate_string_dict(dct): +def validate_string_dict( + dct: t.Dict[str, str] +) -> None: """Validate that the input is a dict with string keys and values. Raises ValueError if not.""" - for k,v in dct.items(): + for k, v in dct.items(): if not isinstance(k, str): raise ValueError('key %r in dict must be a string' % k) if not isinstance(v, str): @@ -49,7 +53,7 @@ class KernelClient(ConnectionFileMixin): # The PyZMQ Context to use for communication with the kernel. context = Instance(zmq.Context) - def _context_default(self): + def _context_default(self) -> zmq.Context: return zmq.Context() # The classes to use for the various channels @@ -67,13 +71,13 @@ def _context_default(self): _control_channel = Any() # flag for whether execute requests should be allowed to call raw_input: - allow_stdin = True + allow_stdin: bool = True #-------------------------------------------------------------------------- # Channel proxy methods #-------------------------------------------------------------------------- - def get_shell_msg(self, *args, **kwargs): + def get_shell_msg(self, *args, **kwargs) -> None: """Get a message from the shell channel""" return self.shell_channel.get_msg(*args, **kwargs) diff --git a/jupyter_client/consoleapp.py b/jupyter_client/consoleapp.py index 42ce2fb77..75dbc985a 100644 --- a/jupyter_client/consoleapp.py +++ b/jupyter_client/consoleapp.py @@ -24,7 +24,7 @@ from .blocking import BlockingKernelClient from .restarter import KernelRestarter -from . import KernelManager, tunnel_to_kernel, find_connection_file, connect +from . import BlockingKernelManager, tunnel_to_kernel, find_connection_file, connect from .kernelspec import NoSuchKernel from .session import Session @@ -86,7 +86,7 @@ # Classes #----------------------------------------------------------------------------- -classes = [KernelManager, KernelRestarter, Session] +classes = [BlockingKernelManager, KernelRestarter, Session] class JupyterConsoleApp(ConnectionFileMixin): name = 'jupyter-console-mixin' @@ -112,7 +112,7 @@ class JupyterConsoleApp(ConnectionFileMixin): flags = Dict(flags) aliases = Dict(aliases) kernel_manager_class = Type( - default_value=KernelManager, + default_value=BlockingKernelManager, config=True, help='The kernel manager class to use.' ) diff --git a/jupyter_client/kernelapp.py b/jupyter_client/kernelapp.py index 33607049c..787cba9b1 100644 --- a/jupyter_client/kernelapp.py +++ b/jupyter_client/kernelapp.py @@ -8,7 +8,7 @@ from . import __version__ from .kernelspec import KernelSpecManager, NATIVE_KERNEL_NAME -from .manager import KernelManager +from .manager import BlockingKernelManager class KernelApp(JupyterApp): """Launch a kernel by name in a local subprocess. @@ -16,7 +16,7 @@ class KernelApp(JupyterApp): version = __version__ description = "Run a kernel locally in a subprocess" - classes = [KernelManager, KernelSpecManager] + classes = [BlockingKernelManager, KernelSpecManager] aliases = { 'kernel': 'KernelApp.kernel_name', @@ -33,8 +33,8 @@ def initialize(self, argv=None): cf_basename = 'kernel-%s.json' % uuid.uuid4() self.config.setdefault('KernelManager', {}).setdefault('connection_file', os.path.join(self.runtime_dir, cf_basename)) - self.km = KernelManager(kernel_name=self.kernel_name, - config=self.config) + self.km = BlockingKernelManager(kernel_name=self.kernel_name, + config=self.config) self.loop = IOLoop.current() self.loop.add_callback(self._record_started) diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index 856b6d490..1b619a333 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -12,7 +12,7 @@ import time import warnings from subprocess import Popen -import typing +import typing as t from enum import Enum @@ -33,6 +33,7 @@ from .managerabc import ( KernelManagerABC ) +from .util import run_sync class _ShutdownStatus(Enum): """ @@ -62,6 +63,8 @@ def __init__(self, *args, **kwargs): # The PyZMQ Context to use for communication with the kernel. context: Instance = Instance(zmq.Context) + + @default('context') def _context_default(self) -> zmq.Context: self._created_context = True return zmq.Context() @@ -69,14 +72,16 @@ def _context_default(self) -> zmq.Context: # the class to create with our `client` method client_class: DottedObjectName = DottedObjectName('jupyter_client.blocking.BlockingKernelClient') client_factory: Type = Type(klass='jupyter_client.KernelClient') + + @default('client_factory') def _client_factory_default(self) -> Type: return import_item(self.client_class) @observe('client_class') def _client_class_changed( self, - change: typing.Dict[str, DottedObjectName] - ) -> None: + change: t.Dict[str, DottedObjectName] + ) -> None: self.client_factory = import_item(str(change['new'])) # The kernel process with which the KernelManager is communicating. @@ -85,6 +90,7 @@ def _client_class_changed( kernel_spec_manager: Instance = Instance(kernelspec.KernelSpecManager) + @default('kernel_spec_manager') def _kernel_spec_manager_default(self) -> kernelspec.KernelSpecManager: return kernelspec.KernelSpecManager(data_dir=self.data_dir) @@ -92,8 +98,8 @@ def _kernel_spec_manager_default(self) -> kernelspec.KernelSpecManager: @observe_compat def _kernel_spec_manager_changed( self, - change: typing.Dict[str, Instance] - ) -> None: + change: t.Dict[str, Instance] + ) -> None: self._kernel_spec = None shutdown_wait_time: Float = Float( @@ -112,16 +118,16 @@ def _kernel_spec_manager_changed( @observe('kernel_name') def _kernel_name_changed( self, - change: typing.Dict[str, Unicode] - ) -> None: + change: t.Dict[str, Unicode] + ) -> None: self._kernel_spec = None if change['new'] == 'python': self.kernel_name = kernelspec.NATIVE_KERNEL_NAME - _kernel_spec: typing.Optional[kernelspec.KernelSpec] = None + _kernel_spec: t.Optional[kernelspec.KernelSpec] = None @property - def kernel_spec(self) -> typing.Optional[kernelspec.KernelSpec]: + def kernel_spec(self) -> t.Optional[kernelspec.KernelSpec]: if self._kernel_spec is None and self.kernel_name != '': self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(self.kernel_name) return self._kernel_spec @@ -182,9 +188,9 @@ def stop_restarter(self) -> None: def add_restart_callback( self, - callback: typing.Callable, + callback: t.Callable, event: str = 'restart' - ) -> None: + ) -> None: """register a callback to be called when a kernel is restarted""" if self._restarter is None: return @@ -192,7 +198,7 @@ def add_restart_callback( def remove_restart_callback( self, - callback: typing.Callable, + callback: t.Callable, event: str ='restart' ) -> None: """unregister a callback to be called when a kernel is restarted""" @@ -223,10 +229,11 @@ def client(self, **kwargs) -> KernelClient: def format_kernel_cmd( self, - extra_arguments: typing.Optional[typing.List[str]] = None - ) -> typing.List[str]: + extra_arguments: t.Optional[t.List[str]] = None + ) -> t.List[str]: """replace templated args (e.g. {connection_file})""" extra_arguments = extra_arguments or [] + self.log.info(str(self.kernel_spec)) if self.kernel_cmd: cmd = self.kernel_cmd + extra_arguments else: @@ -265,17 +272,19 @@ def from_ns(match): return [pat.sub(from_ns, arg) for arg in cmd] - def _launch_kernel( + async def _async__launch_kernel( self, - kernel_cmd: typing.List[str], + kernel_cmd: t.List[str], **kw - ) -> typing.Union[Popen, typing.Coroutine[typing.Any, typing.Any, Popen]]: + ) -> Popen: """actually launch the kernel override in a subclass to launch kernel subprocesses differently """ return launch_kernel(kernel_cmd, **kw) + _launch_kernel = _async__launch_kernel + # Control socket used for polite kernel shutdown def _connect_control_socket(self) -> None: @@ -289,7 +298,7 @@ def _close_control_socket(self) -> None: self._control_socket.close() self._control_socket = None - def pre_start_kernel(self, **kw) -> typing.Tuple[typing.List[str], typing.Dict[str, typing.Any]]: + def pre_start_kernel(self, **kw) -> t.Tuple[t.List[str], t.Dict[str, t.Any]]: """Prepares a kernel for startup in a separate process. If random ports (port=0) are being used, this method must be called @@ -333,9 +342,9 @@ def pre_start_kernel(self, **kw) -> typing.Tuple[typing.List[str], typing.Dict[s def _get_env_substitutions( self, - templated_env: typing.Optional[typing.Dict[str, str]], - substitution_values: typing.Dict[str, str] - ) -> typing.Optional[typing.Dict[str, str]]: + templated_env: t.Optional[t.Dict[str, str]], + substitution_values: t.Dict[str, str] + ) -> t.Optional[t.Dict[str, str]]: """ Walks env entries in templated_env and applies possible substitutions from current env (represented by substitution_values). Returns the substituted list of env entries. @@ -355,7 +364,7 @@ def post_start_kernel(self, **kw) -> None: self.start_restarter() self._connect_control_socket() - def start_kernel(self, **kw): + async def _async_start_kernel(self, **kw): """Starts a kernel on this host in a separate process. If random ports (port=0) are being used, this method must be called @@ -371,13 +380,15 @@ def start_kernel(self, **kw): # launch the kernel subprocess self.log.debug("Starting kernel: %s", kernel_cmd) - self.kernel = self._launch_kernel(kernel_cmd, **kw) + self.kernel = await self._async__launch_kernel(kernel_cmd, **kw) self.post_start_kernel(**kw) + start_kernel = _async_start_kernel + def request_shutdown( self, restart: bool = False - ) -> None: + ) -> None: """Send a shutdown request via control channel """ content = dict(restart=restart) @@ -386,11 +397,11 @@ def request_shutdown( self._connect_control_socket() self.session.send(self._control_socket, msg) - def finish_shutdown( + async def _async_finish_shutdown( self, - waittime: typing.Optional[float] = None, + waittime: t.Optional[float] = None, pollinterval: float = 0.1 - ) -> None: + ) -> None: """Wait for kernel shutdown, then kill process if it doesn't shutdown. This does not send shutdown requests - use :meth:`request_shutdown` @@ -399,49 +410,35 @@ def finish_shutdown( if waittime is None: waittime = max(self.shutdown_wait_time, 0) self._shutdown_status = _ShutdownStatus.ShutdownRequest - - def poll_or_sleep_to_kernel_gone(): - """ - Poll until the kernel is not responding, - then wait (the subprocess), until process gone. - - After this function the kernel is either: - - still responding; or - - subprocess has been culled. - """ - if self.is_alive(): - time.sleep(pollinterval) - else: - # If there's still a proc, wait and clear - if self.has_kernel: - self.kernel.wait() - self.kernel = None - return True - - # wait 50% of the shutdown timeout... - for i in range(int(waittime / 2 / pollinterval)): - if poll_or_sleep_to_kernel_gone(): - break - else: - # if we've exited the loop normally (no break) - # send sigterm and wait the other 50%. + try: + await asyncio.wait_for( + self._async_wait(pollinterval=pollinterval), timeout=waittime / 2 + ) + except asyncio.TimeoutError: self.log.debug("Kernel is taking too long to finish, terminating") self._shutdown_status = _ShutdownStatus.SigtermRequest - self._send_kernel_sigterm() - for i in range(int(waittime / 2 / pollinterval)): - if poll_or_sleep_to_kernel_gone(): - break - else: - # OK, we've waited long enough. - if self.has_kernel: - self.log.debug("Kernel is taking too long to finish, killing") - self._shutdown_status = _ShutdownStatus.SigkillRequest - self._kill_kernel() + await self._async__send_kernel_sigterm() + + try: + await asyncio.wait_for( + self._async_wait(pollinterval=pollinterval), timeout=waittime / 2 + ) + except asyncio.TimeoutError: + self.log.debug("Kernel is taking too long to finish, killing") + self._shutdown_status = _ShutdownStatus.SigkillRequest + await self._async__kill_kernel() + else: + # Process is no longer alive, wait and clear + if self.kernel is not None: + self.kernel.wait() + self.kernel = None + + finish_shutdown = _async_finish_shutdown def cleanup_resources( self, restart: bool = False - ) -> None: + ) -> None: """Clean up resources when the kernel is shut down""" if not restart: self.cleanup_connection_file() @@ -456,13 +453,13 @@ def cleanup_resources( def cleanup( self, connection_file: bool = True - ) -> None: + ) -> None: """Clean up resources when the kernel is shut down""" warnings.warn("Method cleanup(connection_file=True) is deprecated, use cleanup_resources(restart=False).", FutureWarning) self.cleanup_resources(restart=not connection_file) - def shutdown_kernel( + async def _async_shutdown_kernel( self, now: bool = False, restart: bool = False @@ -488,16 +485,16 @@ def shutdown_kernel( # Stop monitoring for restarting while we shutdown. self.stop_restarter() - self.interrupt_kernel() + await self._async_interrupt_kernel() if now: - self._kill_kernel() + await self._async__kill_kernel() else: self.request_shutdown(restart=restart) # Don't send any additional kernel kill messages immediately, to give # the kernel a chance to properly execute shutdown actions. Wait for at # most 1s, checking every 0.1s. - self.finish_shutdown() + await self._async_finish_shutdown() # In 6.1.5, a new method, cleanup_resources(), was introduced to address # a leak issue (https://github.com/jupyter/jupyter_client/pull/548) and @@ -520,12 +517,14 @@ def shutdown_kernel( else: self.cleanup_resources(restart=restart) - def restart_kernel( + shutdown_kernel = _async_shutdown_kernel + + async def _async_restart_kernel( self, now: bool = False, newports: bool = False, **kw - ) -> None: + ) -> None: """Restarts a kernel with the arguments that were used to launch it. Parameters @@ -555,21 +554,23 @@ def restart_kernel( "No previous call to 'start_kernel'.") else: # Stop currently running kernel. - self.shutdown_kernel(now=now, restart=True) + await self._async_shutdown_kernel(now=now, restart=True) if newports: self.cleanup_random_ports() # Start new kernel. self._launch_args.update(kw) - self.start_kernel(**self._launch_args) + await self._async_start_kernel(**self._launch_args) + + restart_kernel = _async_restart_kernel @property def has_kernel(self) -> bool: """Has a kernel been started that we are managing.""" return self.kernel is not None - def _send_kernel_sigterm(self) -> None: + async def _async__send_kernel_sigterm(self) -> None: """similar to _kill_kernel, but with sigterm (not sigkill), but do not block""" if self.has_kernel: # Signal the kernel to terminate (sends SIGTERM on Unix and @@ -579,7 +580,7 @@ def _send_kernel_sigterm(self) -> None: if hasattr(self.kernel, "terminate"): self.kernel.terminate() elif hasattr(signal, "SIGTERM"): - self.signal_kernel(signal.SIGTERM) + await self._async_signal_kernel(signal.SIGTERM) else: self.log.debug( "Cannot set term signal to kernel, no" @@ -599,271 +600,9 @@ def _send_kernel_sigterm(self) -> None: if e.errno != ESRCH: raise - def _kill_kernel(self) -> None: - """Kill the running kernel. - - This is a private method, callers should use shutdown_kernel(now=True). - """ - if self.has_kernel: - # Signal the kernel to terminate (sends SIGKILL on Unix and calls - # TerminateProcess() on Win32). - try: - if hasattr(signal, 'SIGKILL'): - self.signal_kernel(signal.SIGKILL) # type: ignore - else: - self.kernel.kill() - except OSError as e: - # In Windows, we will get an Access Denied error if the process - # has already terminated. Ignore it. - if sys.platform == 'win32': - if e.winerror != 5: # type: ignore - raise - # On Unix, we may get an ESRCH error if the process has already - # terminated. Ignore it. - else: - from errno import ESRCH - if e.errno != ESRCH: - raise - - # Block until the kernel terminates. - self.kernel.wait() - self.kernel = None - - def interrupt_kernel(self) -> None: - """Interrupts the kernel by sending it a signal. - - Unlike ``signal_kernel``, this operation is well supported on all - platforms. - """ - if self.has_kernel: - assert self.kernel_spec is not None - interrupt_mode = self.kernel_spec.interrupt_mode - if interrupt_mode == 'signal': - if sys.platform == 'win32': - from .win_interrupt import send_interrupt - send_interrupt(self.kernel.win32_interrupt_event) - else: - self.signal_kernel(signal.SIGINT) - - elif interrupt_mode == 'message': - msg = self.session.msg("interrupt_request", content={}) - self._connect_control_socket() - self.session.send(self._control_socket, msg) - else: - raise RuntimeError("Cannot interrupt kernel. No kernel is running!") - - def signal_kernel( - self, - signum: int - ) -> None: - """Sends a signal to the process group of the kernel (this - usually includes the kernel and any subprocesses spawned by - the kernel). - - Note that since only SIGTERM is supported on Windows, this function is - only useful on Unix systems. - """ - if self.has_kernel: - if hasattr(os, "getpgid") and hasattr(os, "killpg"): - try: - pgid = os.getpgid(self.kernel.pid) # type: ignore - os.killpg(pgid, signum) # type: ignore - return - except OSError: - pass - self.kernel.send_signal(signum) - else: - raise RuntimeError("Cannot signal kernel. No kernel is running!") - - def is_alive(self) -> bool: - """Is the kernel process still running?""" - if self.has_kernel: - if self.kernel.poll() is None: - return True - else: - return False - else: - # we don't have a kernel - return False - - -class AsyncKernelManager(KernelManager): - """Manages kernels in an asynchronous manner """ - - client_class: DottedObjectName = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient') - client_factory: Type = Type(klass='jupyter_client.asynchronous.AsyncKernelClient') - - async def _launch_kernel(self, kernel_cmd, **kw): - """actually launch the kernel - - override in a subclass to launch kernel subprocesses differently - """ - return launch_kernel(kernel_cmd, **kw) - - async def start_kernel(self, **kw): - """Starts a kernel in a separate process in an asynchronous manner. - - If random ports (port=0) are being used, this method must be called - before the channels are created. - - Parameters - ---------- - `**kw` : optional - keyword arguments that are passed down to build the kernel_cmd - and launching the kernel (e.g. Popen kwargs). - """ - kernel_cmd, kw = self.pre_start_kernel(**kw) - - # launch the kernel subprocess - self.log.debug("Starting kernel (async): %s", kernel_cmd) - self.kernel = await self._launch_kernel(kernel_cmd, **kw) - self.post_start_kernel(**kw) - - async def finish_shutdown(self, waittime=None, pollinterval=0.1): - """Wait for kernel shutdown, then kill process if it doesn't shutdown. - - This does not send shutdown requests - use :meth:`request_shutdown` - first. - """ - if waittime is None: - waittime = max(self.shutdown_wait_time, 0) - self._shutdown_status = _ShutdownStatus.ShutdownRequest - try: - await asyncio.wait_for( - self._async_wait(pollinterval=pollinterval), timeout=waittime / 2 - ) - except asyncio.TimeoutError: - self.log.debug("Kernel is taking too long to finish, terminating") - self._shutdown_status = _ShutdownStatus.SigtermRequest - await self._send_kernel_sigterm() - - try: - await asyncio.wait_for( - self._async_wait(pollinterval=pollinterval), timeout=waittime / 2 - ) - except asyncio.TimeoutError: - self.log.debug("Kernel is taking too long to finish, killing") - self._shutdown_status = _ShutdownStatus.SigkillRequest - await self._kill_kernel() - else: - # Process is no longer alive, wait and clear - if self.kernel is not None: - self.kernel.wait() - self.kernel = None - - async def shutdown_kernel(self, now=False, restart=False): - """Attempts to stop the kernel process cleanly. - - This attempts to shutdown the kernels cleanly by: - - 1. Sending it a shutdown message over the shell channel. - 2. If that fails, the kernel is shutdown forcibly by sending it - a signal. - - Parameters - ---------- - now : bool - Should the kernel be forcible killed *now*. This skips the - first, nice shutdown attempt. - restart: bool - Will this kernel be restarted after it is shutdown. When this - is True, connection files will not be cleaned up. - """ - self.shutting_down = True # Used by restarter to prevent race condition - # Stop monitoring for restarting while we shutdown. - self.stop_restarter() - - await self.interrupt_kernel() - - if now: - await self._kill_kernel() - else: - self.request_shutdown(restart=restart) - # Don't send any additional kernel kill messages immediately, to give - # the kernel a chance to properly execute shutdown actions. Wait for at - # most 1s, checking every 0.1s. - await self.finish_shutdown() - - # See comment in KernelManager.shutdown_kernel(). - overrides_cleanup = type(self).cleanup is not AsyncKernelManager.cleanup - overrides_cleanup_resources = type(self).cleanup_resources is not AsyncKernelManager.cleanup_resources - - if overrides_cleanup and not overrides_cleanup_resources: - self.cleanup(connection_file=not restart) - else: - self.cleanup_resources(restart=restart) - - async def restart_kernel(self, now=False, newports=False, **kw): - """Restarts a kernel with the arguments that were used to launch it. - - Parameters - ---------- - now : bool, optional - If True, the kernel is forcefully restarted *immediately*, without - having a chance to do any cleanup action. Otherwise the kernel is - given 1s to clean up before a forceful restart is issued. - - In all cases the kernel is restarted, the only difference is whether - it is given a chance to perform a clean shutdown or not. - - newports : bool, optional - If the old kernel was launched with random ports, this flag decides - whether the same ports and connection file will be used again. - If False, the same ports and connection file are used. This is - the default. If True, new random port numbers are chosen and a - new connection file is written. It is still possible that the newly - chosen random port numbers happen to be the same as the old ones. - - `**kw` : optional - Any options specified here will overwrite those used to launch the - kernel. - """ - if self._launch_args is None: - raise RuntimeError("Cannot restart the kernel. " - "No previous call to 'start_kernel'.") - else: - # Stop currently running kernel. - await self.shutdown_kernel(now=now, restart=True) - - if newports: - self.cleanup_random_ports() - - # Start new kernel. - self._launch_args.update(kw) - await self.start_kernel(**self._launch_args) - return None - - async def _send_kernel_sigterm(self): - """similar to _kill_kernel, but with sigterm (not sigkill), but do not block""" - if self.has_kernel: - # Signal the kernel to terminate (sends SIGTERM on Unix and - # if the kernel is a subprocess and we are on windows; this is - # equivalent to kill - try: - if hasattr(self.kernel, "terminate"): - self.kernel.terminate() - elif hasattr(signal, "SIGTERM"): - await self.signal_kernel(signal.SIGTERM) - else: - self.log.debug( - "Cannot set term signal to kernel, no" - " `.terminate()` method and no values for SIGTERM" - ) - except OSError as e: - # In Windows, we will get an Access Denied error if the process - # has already terminated. Ignore it. - if sys.platform == "win32": - if e.winerror != 5: - raise - # On Unix, we may get an ESRCH error if the process has already - # terminated. Ignore it. - else: - from errno import ESRCH - - if e.errno != ESRCH: - raise + _send_kernel_sigterm = _async__send_kernel_sigterm - async def _kill_kernel(self): + async def _async__kill_kernel(self) -> None: """Kill the running kernel. This is a private method, callers should use shutdown_kernel(now=True). @@ -873,14 +612,14 @@ async def _kill_kernel(self): # TerminateProcess() on Win32). try: if hasattr(signal, 'SIGKILL'): - await self.signal_kernel(signal.SIGKILL) + await self._async_signal_kernel(signal.SIGKILL) # type: ignore else: self.kernel.kill() except OSError as e: # In Windows, we will get an Access Denied error if the process # has already terminated. Ignore it. if sys.platform == 'win32': - if e.winerror != 5: + if e.winerror != 5: # type: ignore raise # On Unix, we may get an ESRCH error if the process has already # terminated. Ignore it. @@ -902,7 +641,9 @@ async def _kill_kernel(self): self.kernel.wait() self.kernel = None - async def interrupt_kernel(self): + _kill_kernel = _async__kill_kernel + + async def _async_interrupt_kernel(self) -> None: """Interrupts the kernel by sending it a signal. Unlike ``signal_kernel``, this operation is well supported on all @@ -916,7 +657,7 @@ async def interrupt_kernel(self): from .win_interrupt import send_interrupt send_interrupt(self.kernel.win32_interrupt_event) else: - await self.signal_kernel(signal.SIGINT) + await self._async_signal_kernel(signal.SIGINT) elif interrupt_mode == 'message': msg = self.session.msg("interrupt_request", content={}) @@ -925,7 +666,12 @@ async def interrupt_kernel(self): else: raise RuntimeError("Cannot interrupt kernel. No kernel is running!") - async def signal_kernel(self, signum): + interrupt_kernel = _async_interrupt_kernel + + async def _async_signal_kernel( + self, + signum: int + ) -> None: """Sends a signal to the process group of the kernel (this usually includes the kernel and any subprocesses spawned by the kernel). @@ -936,8 +682,8 @@ async def signal_kernel(self, signum): if self.has_kernel: if hasattr(os, "getpgid") and hasattr(os, "killpg"): try: - pgid = os.getpgid(self.kernel.pid) - os.killpg(pgid, signum) + pgid = os.getpgid(self.kernel.pid) # type: ignore + os.killpg(pgid, signum) # type: ignore return except OSError: pass @@ -945,7 +691,9 @@ async def signal_kernel(self, signum): else: raise RuntimeError("Cannot signal kernel. No kernel is running!") - async def is_alive(self): + signal_kernel = _async_signal_kernel + + async def _async_is_alive(self) -> bool: """Is the kernel process still running?""" if self.has_kernel: if self.kernel.poll() is None: @@ -956,6 +704,8 @@ async def is_alive(self): # we don't have a kernel return False + is_alive = _async_is_alive + async def _async_wait( self, pollinterval: float = 0.1 @@ -964,10 +714,28 @@ async def _async_wait( # not alive. If we find the process is no longer alive, complete # its cleanup via the blocking wait(). Callers are responsible for # issuing calls to wait() using a timeout (see _kill_kernel()). - while await self.is_alive(): + while await self._async_is_alive(): await asyncio.sleep(pollinterval) +class BlockingKernelManager(KernelManager): + _launch_kernel = run_sync(KernelManager._launch_kernel) + start_kernel = run_sync(KernelManager.start_kernel) + finish_shutdown = run_sync(KernelManager.finish_shutdown) + shutdown_kernel = run_sync(KernelManager.shutdown_kernel) + restart_kernel = run_sync(KernelManager.restart_kernel) + _send_kernel_sigterm = run_sync(KernelManager._send_kernel_sigterm) + _kill_kernel = run_sync(KernelManager._kill_kernel) + interrupt_kernel = run_sync(KernelManager.interrupt_kernel) + signal_kernel = run_sync(KernelManager.signal_kernel) + is_alive = run_sync(KernelManager.is_alive) + +class AsyncKernelManager(KernelManager): + # the class to create with our `client` method + client_class: DottedObjectName = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient') + client_factory: Type = Type(klass='jupyter_client.asynchronous.AsyncKernelClient') + + KernelManagerABC.register(KernelManager) @@ -975,9 +743,9 @@ def start_new_kernel( startup_timeout: float =60, kernel_name: str = 'python', **kwargs - ) -> typing.Tuple[KernelManager, KernelClient]: +) -> t.Tuple[BlockingKernelManager, KernelClient]: """Start a new kernel, and return its Manager and Client""" - km = KernelManager(kernel_name=kernel_name) + km = BlockingKernelManager(kernel_name=kernel_name) km.start_kernel(**kwargs) kc = km.client() kc.start_channels() @@ -995,7 +763,7 @@ async def start_new_async_kernel( startup_timeout: float = 60, kernel_name: str = 'python', **kwargs - ) -> typing.Tuple[KernelManager, KernelClient]: +) -> t.Tuple[AsyncKernelManager, KernelClient]: """Start a new kernel, and return its Manager and Client""" km = AsyncKernelManager(kernel_name=kernel_name) await km.start_kernel(**kwargs) @@ -1012,7 +780,7 @@ async def start_new_async_kernel( @contextmanager -def run_kernel(**kwargs) -> typing.Iterator[KernelClient]: +def run_kernel(**kwargs) -> t.Iterator[KernelClient]: """Context manager to create a kernel in a subprocess. The kernel is shut down when the context exits. diff --git a/jupyter_client/tests/test_kernelmanager.py b/jupyter_client/tests/test_kernelmanager.py index 7f5ea13dc..b70346443 100644 --- a/jupyter_client/tests/test_kernelmanager.py +++ b/jupyter_client/tests/test_kernelmanager.py @@ -17,7 +17,7 @@ from async_generator import async_generator, yield_ from traitlets.config.loader import Config from jupyter_core import paths -from jupyter_client import KernelManager, AsyncKernelManager +from jupyter_client import BlockingKernelManager, AsyncKernelManager from subprocess import PIPE from ..manager import start_new_kernel, start_new_async_kernel @@ -101,7 +101,7 @@ def start_kernel(): @pytest.fixture def km(config): - km = KernelManager(config=config) + km = BlockingKernelManager(config=config) return km @pytest.fixture @@ -198,7 +198,7 @@ def test_lifecycle(self, km): km.restart_kernel(now=True) assert km.is_alive() km.interrupt_kernel() - assert isinstance(km, KernelManager) + assert isinstance(km, BlockingKernelManager) km.shutdown_kernel(now=True) assert km.context.closed @@ -287,7 +287,7 @@ def test_cleanup_context(self, km): def test_no_cleanup_shared_context(self, zmq_context): """kernel manager does not terminate shared context""" - km = KernelManager(context=zmq_context) + km = BlockingKernelManager(context=zmq_context) assert km.context == zmq_context assert km.context is not None @@ -400,7 +400,7 @@ def _prepare_kernel(self, km, startup_timeout=TIMEOUT, **kwargs): return kc def _run_signaltest_lifecycle(self, config=None): - km = KernelManager(config=config, kernel_name='signaltest') + km = BlockingKernelManager(config=config, kernel_name='signaltest') kc = self._prepare_kernel(km, stdout=PIPE, stderr=PIPE) def execute(cmd): diff --git a/jupyter_client/tests/test_multikernelmanager.py b/jupyter_client/tests/test_multikernelmanager.py index ff13e8282..e104a266f 100644 --- a/jupyter_client/tests/test_multikernelmanager.py +++ b/jupyter_client/tests/test_multikernelmanager.py @@ -256,7 +256,7 @@ async def _run_lifecycle(km, test_kid=None): assert kid in km.list_kernel_ids() await km.interrupt_kernel(kid) k = km.get_kernel(kid) - assert isinstance(k, KernelManager) + assert isinstance(k, AsyncKernelManager) await km.shutdown_kernel(kid, now=True) assert kid not in km, f'{kid} not in {km}' diff --git a/jupyter_client/tests/test_public_api.py b/jupyter_client/tests/test_public_api.py index ab3883d66..d77679578 100644 --- a/jupyter_client/tests/test_public_api.py +++ b/jupyter_client/tests/test_public_api.py @@ -9,12 +9,12 @@ def test_kms(): - for base in ("", "Multi"): + for base in ("", "Blocking", "Async", "Multi"): KM = base + "KernelManager" assert KM in dir(jupyter_client) def test_kcs(): - for base in ("", "Blocking"): + for base in ("", "Blocking", "Async"): KM = base + "KernelClient" assert KM in dir(jupyter_client) diff --git a/jupyter_client/util.py b/jupyter_client/util.py new file mode 100644 index 000000000..2e1f2516e --- /dev/null +++ b/jupyter_client/util.py @@ -0,0 +1,14 @@ +import concurrent.futures +import asyncio + +def asyncio_run(task): + loop = asyncio.new_event_loop() + return loop.run_until_complete(task) + +def run_sync(coro): + def wrapped(*args, **kwargs): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio_run, coro(*args, **kwargs)) + return future.result() + wrapped.__doc__ = coro.__doc__ + return wrapped From ec520ec176c51a1d30ddac9799db9cf68e5d1559 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 12 Mar 2021 09:54:17 +0100 Subject: [PATCH 3/9] Change BlockingKernelManager back to KernelManager --- jupyter_client/__init__.py | 2 +- jupyter_client/consoleapp.py | 6 +-- jupyter_client/kernelapp.py | 8 ++-- jupyter_client/manager.py | 47 +++++++++++----------- jupyter_client/tests/test_kernelmanager.py | 10 ++--- jupyter_client/tests/test_public_api.py | 2 +- 6 files changed, 37 insertions(+), 38 deletions(-) diff --git a/jupyter_client/__init__.py b/jupyter_client/__init__.py index 122010421..f72c516d3 100644 --- a/jupyter_client/__init__.py +++ b/jupyter_client/__init__.py @@ -4,7 +4,7 @@ from .connect import * from .launcher import * from .client import KernelClient -from .manager import KernelManager, BlockingKernelManager, AsyncKernelManager, run_kernel +from .manager import KernelManager, AsyncKernelManager, run_kernel from .blocking import BlockingKernelClient from .asynchronous import AsyncKernelClient from .multikernelmanager import MultiKernelManager, AsyncMultiKernelManager diff --git a/jupyter_client/consoleapp.py b/jupyter_client/consoleapp.py index 75dbc985a..42ce2fb77 100644 --- a/jupyter_client/consoleapp.py +++ b/jupyter_client/consoleapp.py @@ -24,7 +24,7 @@ from .blocking import BlockingKernelClient from .restarter import KernelRestarter -from . import BlockingKernelManager, tunnel_to_kernel, find_connection_file, connect +from . import KernelManager, tunnel_to_kernel, find_connection_file, connect from .kernelspec import NoSuchKernel from .session import Session @@ -86,7 +86,7 @@ # Classes #----------------------------------------------------------------------------- -classes = [BlockingKernelManager, KernelRestarter, Session] +classes = [KernelManager, KernelRestarter, Session] class JupyterConsoleApp(ConnectionFileMixin): name = 'jupyter-console-mixin' @@ -112,7 +112,7 @@ class JupyterConsoleApp(ConnectionFileMixin): flags = Dict(flags) aliases = Dict(aliases) kernel_manager_class = Type( - default_value=BlockingKernelManager, + default_value=KernelManager, config=True, help='The kernel manager class to use.' ) diff --git a/jupyter_client/kernelapp.py b/jupyter_client/kernelapp.py index 787cba9b1..33607049c 100644 --- a/jupyter_client/kernelapp.py +++ b/jupyter_client/kernelapp.py @@ -8,7 +8,7 @@ from . import __version__ from .kernelspec import KernelSpecManager, NATIVE_KERNEL_NAME -from .manager import BlockingKernelManager +from .manager import KernelManager class KernelApp(JupyterApp): """Launch a kernel by name in a local subprocess. @@ -16,7 +16,7 @@ class KernelApp(JupyterApp): version = __version__ description = "Run a kernel locally in a subprocess" - classes = [BlockingKernelManager, KernelSpecManager] + classes = [KernelManager, KernelSpecManager] aliases = { 'kernel': 'KernelApp.kernel_name', @@ -33,8 +33,8 @@ def initialize(self, argv=None): cf_basename = 'kernel-%s.json' % uuid.uuid4() self.config.setdefault('KernelManager', {}).setdefault('connection_file', os.path.join(self.runtime_dir, cf_basename)) - self.km = BlockingKernelManager(kernel_name=self.kernel_name, - config=self.config) + self.km = KernelManager(kernel_name=self.kernel_name, + config=self.config) self.loop = IOLoop.current() self.loop.add_callback(self._record_started) diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index 1b619a333..ec055f25c 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -283,7 +283,7 @@ async def _async__launch_kernel( """ return launch_kernel(kernel_cmd, **kw) - _launch_kernel = _async__launch_kernel + _launch_kernel = run_sync(_async__launch_kernel) # Control socket used for polite kernel shutdown @@ -383,7 +383,7 @@ async def _async_start_kernel(self, **kw): self.kernel = await self._async__launch_kernel(kernel_cmd, **kw) self.post_start_kernel(**kw) - start_kernel = _async_start_kernel + start_kernel = run_sync(_async_start_kernel) def request_shutdown( self, @@ -433,7 +433,7 @@ async def _async_finish_shutdown( self.kernel.wait() self.kernel = None - finish_shutdown = _async_finish_shutdown + finish_shutdown = run_sync(_async_finish_shutdown) def cleanup_resources( self, @@ -517,7 +517,7 @@ async def _async_shutdown_kernel( else: self.cleanup_resources(restart=restart) - shutdown_kernel = _async_shutdown_kernel + shutdown_kernel = run_sync(_async_shutdown_kernel) async def _async_restart_kernel( self, @@ -563,7 +563,7 @@ async def _async_restart_kernel( self._launch_args.update(kw) await self._async_start_kernel(**self._launch_args) - restart_kernel = _async_restart_kernel + restart_kernel = run_sync(_async_restart_kernel) @property def has_kernel(self) -> bool: @@ -600,7 +600,7 @@ async def _async__send_kernel_sigterm(self) -> None: if e.errno != ESRCH: raise - _send_kernel_sigterm = _async__send_kernel_sigterm + _send_kernel_sigterm = run_sync(_async__send_kernel_sigterm) async def _async__kill_kernel(self) -> None: """Kill the running kernel. @@ -641,7 +641,7 @@ async def _async__kill_kernel(self) -> None: self.kernel.wait() self.kernel = None - _kill_kernel = _async__kill_kernel + _kill_kernel = run_sync(_async__kill_kernel) async def _async_interrupt_kernel(self) -> None: """Interrupts the kernel by sending it a signal. @@ -666,7 +666,7 @@ async def _async_interrupt_kernel(self) -> None: else: raise RuntimeError("Cannot interrupt kernel. No kernel is running!") - interrupt_kernel = _async_interrupt_kernel + interrupt_kernel = run_sync(_async_interrupt_kernel) async def _async_signal_kernel( self, @@ -691,7 +691,7 @@ async def _async_signal_kernel( else: raise RuntimeError("Cannot signal kernel. No kernel is running!") - signal_kernel = _async_signal_kernel + signal_kernel = run_sync(_async_signal_kernel) async def _async_is_alive(self) -> bool: """Is the kernel process still running?""" @@ -704,7 +704,7 @@ async def _async_is_alive(self) -> bool: # we don't have a kernel return False - is_alive = _async_is_alive + is_alive = run_sync(_async_is_alive) async def _async_wait( self, @@ -718,23 +718,22 @@ async def _async_wait( await asyncio.sleep(pollinterval) -class BlockingKernelManager(KernelManager): - _launch_kernel = run_sync(KernelManager._launch_kernel) - start_kernel = run_sync(KernelManager.start_kernel) - finish_shutdown = run_sync(KernelManager.finish_shutdown) - shutdown_kernel = run_sync(KernelManager.shutdown_kernel) - restart_kernel = run_sync(KernelManager.restart_kernel) - _send_kernel_sigterm = run_sync(KernelManager._send_kernel_sigterm) - _kill_kernel = run_sync(KernelManager._kill_kernel) - interrupt_kernel = run_sync(KernelManager.interrupt_kernel) - signal_kernel = run_sync(KernelManager.signal_kernel) - is_alive = run_sync(KernelManager.is_alive) - class AsyncKernelManager(KernelManager): # the class to create with our `client` method client_class: DottedObjectName = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient') client_factory: Type = Type(klass='jupyter_client.asynchronous.AsyncKernelClient') + _launch_kernel = KernelManager._async__launch_kernel + start_kernel = KernelManager._async_start_kernel + finish_shutdown = KernelManager._async_finish_shutdown + shutdown_kernel = KernelManager._async_shutdown_kernel + restart_kernel = KernelManager._async_restart_kernel + _send_kernel_sigterm = KernelManager._async__send_kernel_sigterm + _kill_kernel = KernelManager._async__kill_kernel + interrupt_kernel = KernelManager._async_interrupt_kernel + signal_kernel = KernelManager._async_signal_kernel + is_alive = KernelManager._async_is_alive + KernelManagerABC.register(KernelManager) @@ -743,9 +742,9 @@ def start_new_kernel( startup_timeout: float =60, kernel_name: str = 'python', **kwargs -) -> t.Tuple[BlockingKernelManager, KernelClient]: +) -> t.Tuple[KernelManager, KernelClient]: """Start a new kernel, and return its Manager and Client""" - km = BlockingKernelManager(kernel_name=kernel_name) + km = KernelManager(kernel_name=kernel_name) km.start_kernel(**kwargs) kc = km.client() kc.start_channels() diff --git a/jupyter_client/tests/test_kernelmanager.py b/jupyter_client/tests/test_kernelmanager.py index b70346443..7f5ea13dc 100644 --- a/jupyter_client/tests/test_kernelmanager.py +++ b/jupyter_client/tests/test_kernelmanager.py @@ -17,7 +17,7 @@ from async_generator import async_generator, yield_ from traitlets.config.loader import Config from jupyter_core import paths -from jupyter_client import BlockingKernelManager, AsyncKernelManager +from jupyter_client import KernelManager, AsyncKernelManager from subprocess import PIPE from ..manager import start_new_kernel, start_new_async_kernel @@ -101,7 +101,7 @@ def start_kernel(): @pytest.fixture def km(config): - km = BlockingKernelManager(config=config) + km = KernelManager(config=config) return km @pytest.fixture @@ -198,7 +198,7 @@ def test_lifecycle(self, km): km.restart_kernel(now=True) assert km.is_alive() km.interrupt_kernel() - assert isinstance(km, BlockingKernelManager) + assert isinstance(km, KernelManager) km.shutdown_kernel(now=True) assert km.context.closed @@ -287,7 +287,7 @@ def test_cleanup_context(self, km): def test_no_cleanup_shared_context(self, zmq_context): """kernel manager does not terminate shared context""" - km = BlockingKernelManager(context=zmq_context) + km = KernelManager(context=zmq_context) assert km.context == zmq_context assert km.context is not None @@ -400,7 +400,7 @@ def _prepare_kernel(self, km, startup_timeout=TIMEOUT, **kwargs): return kc def _run_signaltest_lifecycle(self, config=None): - km = BlockingKernelManager(config=config, kernel_name='signaltest') + km = KernelManager(config=config, kernel_name='signaltest') kc = self._prepare_kernel(km, stdout=PIPE, stderr=PIPE) def execute(cmd): diff --git a/jupyter_client/tests/test_public_api.py b/jupyter_client/tests/test_public_api.py index d77679578..5ebf2f3d3 100644 --- a/jupyter_client/tests/test_public_api.py +++ b/jupyter_client/tests/test_public_api.py @@ -9,7 +9,7 @@ def test_kms(): - for base in ("", "Blocking", "Async", "Multi"): + for base in ("", "Async", "Multi"): KM = base + "KernelManager" assert KM in dir(jupyter_client) From fda3ebaf8217372e82dc3ecb4f67082e912b6282 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 12 Mar 2021 19:32:01 +0100 Subject: [PATCH 4/9] Refactor BlockingKernelClient/AsyncKernelClient --- jupyter_client/asynchronous/channels.py | 82 ----- jupyter_client/asynchronous/client.py | 391 ++------------------- jupyter_client/blocking/channels.py | 88 ----- jupyter_client/blocking/client.py | 330 ++--------------- jupyter_client/channels.py | 78 ++++ jupyter_client/client.py | 374 ++++++++++++++++++-- jupyter_client/manager.py | 2 +- jupyter_client/tests/test_kernelapp.py | 1 + jupyter_client/tests/test_kernelmanager.py | 8 +- jupyter_client/util.py | 17 +- setup.py | 1 + 11 files changed, 498 insertions(+), 874 deletions(-) delete mode 100644 jupyter_client/asynchronous/channels.py delete mode 100644 jupyter_client/blocking/channels.py diff --git a/jupyter_client/asynchronous/channels.py b/jupyter_client/asynchronous/channels.py deleted file mode 100644 index b6f49bd36..000000000 --- a/jupyter_client/asynchronous/channels.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Async channels""" - -# Copyright (c) Jupyter Development Team. -# Distributed under the terms of the Modified BSD License. - -from queue import Queue, Empty - - -class ZMQSocketChannel(object): - """A ZMQ socket in an async API""" - session = None - socket = None - stream = None - _exiting = False - proxy_methods = [] - - def __init__(self, socket, session, loop=None): - """Create a channel. - - Parameters - ---------- - socket : :class:`zmq.asyncio.Socket` - The ZMQ socket to use. - session : :class:`session.Session` - The session to use. - loop - Unused here, for other implementations - """ - super().__init__() - - self.socket = socket - self.session = session - - async def _recv(self, **kwargs): - msg = await self.socket.recv_multipart(**kwargs) - ident,smsg = self.session.feed_identities(msg) - return self.session.deserialize(smsg) - - async def get_msg(self, timeout=None): - """ Gets a message if there is one that is ready. """ - if timeout is not None: - timeout *= 1000 # seconds to ms - ready = await self.socket.poll(timeout) - - if ready: - return await self._recv() - else: - raise Empty - - async def get_msgs(self): - """ Get all messages that are currently ready. """ - msgs = [] - while True: - try: - msgs.append(await self.get_msg()) - except Empty: - break - return msgs - - async def msg_ready(self): - """ Is there a message that has been received? """ - return bool(await self.socket.poll(timeout=0)) - - def close(self): - if self.socket is not None: - try: - self.socket.close(linger=0) - except Exception: - pass - self.socket = None - stop = close - - def is_alive(self): - return (self.socket is not None) - - def send(self, msg): - """Pass a message to the ZMQ socket to send - """ - self.session.send(self.socket, msg) - - def start(self): - pass diff --git a/jupyter_client/asynchronous/client.py b/jupyter_client/asynchronous/client.py index 1a21e3ac8..7a2632550 100644 --- a/jupyter_client/asynchronous/client.py +++ b/jupyter_client/asynchronous/client.py @@ -2,59 +2,10 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from functools import partial -from getpass import getpass -from queue import Empty -import sys -import time - -import zmq -import zmq.asyncio -import asyncio - from traitlets import (Type, Instance) -from jupyter_client.channels import HBChannel -from jupyter_client.client import KernelClient -from .channels import ZMQSocketChannel - - -def reqrep(meth, channel='shell'): - def wrapped(self, *args, **kwargs): - reply = kwargs.pop('reply', False) - timeout = kwargs.pop('timeout', None) - msg_id = meth(self, *args, **kwargs) - if not reply: - return msg_id - - return self._recv_reply(msg_id, timeout=timeout, channel=channel) - - if not meth.__doc__: - # python -OO removes docstrings, - # so don't bother building the wrapped docstring - return wrapped +from jupyter_client.channels import HBChannel, ZMQSocketChannel +from jupyter_client.client import KernelClient, reqrep - basedoc, _ = meth.__doc__.split('Returns\n', 1) - parts = [basedoc.strip()] - if 'Parameters' not in basedoc: - parts.append(""" - Parameters - ---------- - """) - parts.append(""" - reply: bool (default: False) - Whether to wait for and return reply - timeout: float or None (default: None) - Timeout to use when waiting for a reply - - Returns - ------- - msg_id: str - The msg_id of the request sent, if reply=False (default) - reply: dict - The reply message for this request, if reply=True - """) - wrapped.__doc__ = '\n'.join(parts) - return wrapped class AsyncKernelClient(KernelClient): """A KernelClient with async APIs @@ -63,98 +14,28 @@ class AsyncKernelClient(KernelClient): raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds. """ - # The PyZMQ Context to use for communication with the kernel. - context = Instance(zmq.asyncio.Context) - def _context_default(self): - return zmq.asyncio.Context() - #-------------------------------------------------------------------------- # Channel proxy methods #-------------------------------------------------------------------------- - async def get_shell_msg(self, *args, **kwargs): - """Get a message from the shell channel""" - return await self.shell_channel.get_msg(*args, **kwargs) - - async def get_iopub_msg(self, *args, **kwargs): - """Get a message from the iopub channel""" - return await self.iopub_channel.get_msg(*args, **kwargs) - - async def get_stdin_msg(self, *args, **kwargs): - """Get a message from the stdin channel""" - return await self.stdin_channel.get_msg(*args, **kwargs) - - async def get_control_msg(self, *args, **kwargs): - """Get a message from the control channel""" - return await self.control_channel.get_msg(*args, **kwargs) - - @property - def hb_channel(self): - """Get the hb channel object for this kernel.""" - if self._hb_channel is None: - url = self._make_url('hb') - self.log.debug("connecting heartbeat channel to %s", url) - loop = asyncio.new_event_loop() - self._hb_channel = self.hb_channel_class( - self.context, self.session, url, loop - ) - return self._hb_channel - - async def wait_for_ready(self, timeout=None): - """Waits for a response when a client is blocked - - - Sets future time for timeout - - Blocks on shell channel until a message is received - - Exit if the kernel has died - - If client times out before receiving a message from the kernel, send RuntimeError - - Flush the IOPub channel - """ - if timeout is None: - abs_timeout = float('inf') - else: - abs_timeout = time.time() + timeout - - from ..manager import KernelManager - if not isinstance(self.parent, KernelManager): - # This Client was not created by a KernelManager, - # so wait for kernel to become responsive to heartbeats - # before checking for kernel_info reply - while not self.is_alive(): - if time.time() > abs_timeout: - raise RuntimeError("Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout) - await asyncio.sleep(0.2) - - # Wait for kernel info reply on shell channel - while True: - self.kernel_info() - try: - msg = await self.shell_channel.get_msg(timeout=1) - except Empty: - pass - else: - if msg['msg_type'] == 'kernel_info_reply': - # Checking that IOPub is connected. If it is not connected, start over. - try: - await self.iopub_channel.get_msg(timeout=0.2) - except Empty: - pass - else: - self._handle_kernel_info_reply(msg) - break - - if not await self.is_alive(): - raise RuntimeError('Kernel died before replying to kernel_info') - - # Check if current time is ready check time plus timeout - if time.time() > abs_timeout: - raise RuntimeError("Kernel didn't respond in %d seconds" % timeout) - - # Flush IOPub channel - while True: - try: - msg = await self.iopub_channel.get_msg(timeout=0.2) - except Empty: - break + get_shell_msg = KernelClient._async_get_shell_msg + get_iopub_msg = KernelClient._async_get_iopub_msg + get_stdin_msg = KernelClient._async_get_stdin_msg + get_control_msg = KernelClient._async_get_control_msg + + #@property + #def hb_channel(self): + # """Get the hb channel object for this kernel.""" + # if self._hb_channel is None: + # url = self._make_url('hb') + # self.log.debug("connecting heartbeat channel to %s", url) + # loop = asyncio.new_event_loop() + # self._hb_channel = self.hb_channel_class( + # self.context, self.session, url, loop + # ) + # return self._hb_channel + + wait_for_ready = KernelClient._async_wait_for_ready # The classes to use for the various channels shell_channel_class = Type(ZMQSocketChannel) @@ -164,232 +45,24 @@ async def wait_for_ready(self, timeout=None): control_channel_class = Type(ZMQSocketChannel) - async def _recv_reply(self, msg_id, timeout=None, channel='shell'): - """Receive and return the reply for a given request""" - if timeout is not None: - deadline = time.monotonic() + timeout - while True: - if timeout is not None: - timeout = max(0, deadline - time.monotonic()) - try: - if channel == 'control': - reply = await self.get_control_msg(timeout=timeout) - else: - reply = await self.get_shell_msg(timeout=timeout) - except Empty as e: - raise TimeoutError("Timeout waiting for reply") from e - if reply['parent_header'].get('msg_id') != msg_id: - # not my reply, someone may have forgotten to retrieve theirs - continue - return reply + _recv_reply = KernelClient._async__recv_reply # replies come on the shell channel - execute = reqrep(KernelClient.execute) - history = reqrep(KernelClient.history) - complete = reqrep(KernelClient.complete) - inspect = reqrep(KernelClient.inspect) - kernel_info = reqrep(KernelClient.kernel_info) - comm_info = reqrep(KernelClient.comm_info) + execute = reqrep(KernelClient._async_execute) + history = reqrep(KernelClient._async_history) + complete = reqrep(KernelClient._async_complete) + inspect = reqrep(KernelClient._async_inspect) + kernel_info = reqrep(KernelClient._async_kernel_info) + comm_info = reqrep(KernelClient._async_comm_info) # replies come on the control channel - shutdown = reqrep(KernelClient.shutdown, channel='control') - - - def _stdin_hook_default(self, msg): - """Handle an input request""" - content = msg['content'] - if content.get('password', False): - prompt = getpass - else: - prompt = input - - try: - raw_data = prompt(content["prompt"]) - except EOFError: - # turn EOFError into EOF character - raw_data = '\x04' - except KeyboardInterrupt: - sys.stdout.write('\n') - return - - # only send stdin reply if there *was not* another request - # or execution finished while we were reading. - if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()): - self.input(raw_data) - - def _output_hook_default(self, msg): - """Default hook for redisplaying plain-text output""" - msg_type = msg['header']['msg_type'] - content = msg['content'] - if msg_type == 'stream': - stream = getattr(sys, content['name']) - stream.write(content['text']) - elif msg_type in ('display_data', 'execute_result'): - sys.stdout.write(content['data'].get('text/plain', '')) - elif msg_type == 'error': - print('\n'.join(content['traceback']), file=sys.stderr) - - def _output_hook_kernel(self, session, socket, parent_header, msg): - """Output hook when running inside an IPython kernel - - adds rich output support. - """ - msg_type = msg['header']['msg_type'] - if msg_type in ('display_data', 'execute_result', 'error'): - session.send(socket, msg_type, msg['content'], parent=parent_header) - else: - self._output_hook_default(msg) - - async def is_alive(self): - """Is the kernel process still running?""" - from ..manager import KernelManager, AsyncKernelManager - if isinstance(self.parent, KernelManager): - # This KernelClient was created by a KernelManager, - # we can ask the parent KernelManager: - if isinstance(self.parent, AsyncKernelManager): - return await self.parent.is_alive() - return self.parent.is_alive() - if self._hb_channel is not None: - # We don't have access to the KernelManager, - # so we use the heartbeat. - return self._hb_channel.is_beating() - else: - # no heartbeat and not local, we can't tell if it's running, - # so naively return True - return True - - async def execute_interactive(self, code, silent=False, store_history=True, - user_expressions=None, allow_stdin=None, stop_on_error=True, - timeout=None, output_hook=None, stdin_hook=None, - ): - """Execute code in the kernel interactively - - Output will be redisplayed, and stdin prompts will be relayed as well. - If an IPython kernel is detected, rich output will be displayed. - - You can pass a custom output_hook callable that will be called - with every IOPub message that is produced instead of the default redisplay. - - Parameters - ---------- - code : str - A string of code in the kernel's language. - - silent : bool, optional (default False) - If set, the kernel will execute the code as quietly possible, and - will force store_history to be False. - - store_history : bool, optional (default True) - If set, the kernel will store command history. This is forced - to be False if silent is True. - - user_expressions : dict, optional - A dict mapping names to expressions to be evaluated in the user's - dict. The expression values are returned as strings formatted using - :func:`repr`. - - allow_stdin : bool, optional (default self.allow_stdin) - Flag for whether the kernel can send stdin requests to frontends. - - Some frontends (e.g. the Notebook) do not support stdin requests. - If raw_input is called from code executed from such a frontend, a - StdinNotImplementedError will be raised. - - stop_on_error: bool, optional (default True) - Flag whether to abort the execution queue, if an exception is encountered. - - timeout: float or None (default: None) - Timeout to use when waiting for a reply - - output_hook: callable(msg) - Function to be called with output messages. - If not specified, output will be redisplayed. - - stdin_hook: callable(msg) - Function to be called with stdin_request messages. - If not specified, input/getpass will be called. - - Returns - ------- - reply: dict - The reply message for this request - """ - if not self.iopub_channel.is_alive(): - raise RuntimeError("IOPub channel must be running to receive output") - if allow_stdin is None: - allow_stdin = self.allow_stdin - if allow_stdin and not self.stdin_channel.is_alive(): - raise RuntimeError("stdin channel must be running to allow input") - msg_id = await self.execute(code, - silent=silent, - store_history=store_history, - user_expressions=user_expressions, - allow_stdin=allow_stdin, - stop_on_error=stop_on_error, - ) - if stdin_hook is None: - stdin_hook = self._stdin_hook_default - if output_hook is None: - # detect IPython kernel - if 'IPython' in sys.modules: - from IPython import get_ipython - ip = get_ipython() - in_kernel = getattr(ip, 'kernel', False) - if in_kernel: - output_hook = partial( - self._output_hook_kernel, - ip.display_pub.session, - ip.display_pub.pub_socket, - ip.display_pub.parent_header, - ) - if output_hook is None: - # default: redisplay plain-text outputs - output_hook = self._output_hook_default - - # set deadline based on timeout - if timeout is not None: - deadline = time.monotonic() + timeout - else: - timeout_ms = None - - poller = zmq.Poller() - iopub_socket = self.iopub_channel.socket - poller.register(iopub_socket, zmq.POLLIN) - if allow_stdin: - stdin_socket = self.stdin_channel.socket - poller.register(stdin_socket, zmq.POLLIN) - else: - stdin_socket = None - - # wait for output and redisplay it - while True: - if timeout is not None: - timeout = max(0, deadline - time.monotonic()) - timeout_ms = 1e3 * timeout - events = dict(poller.poll(timeout_ms)) - if not events: - raise TimeoutError("Timeout waiting for output") - if stdin_socket in events: - req = await self.stdin_channel.get_msg(timeout=0) - stdin_hook(req) - continue - if iopub_socket not in events: - continue + shutdown = reqrep(KernelClient._async_shutdown, channel='control') - msg = await self.iopub_channel.get_msg(timeout=0) + is_alive = KernelClient._async_is_alive - if msg['parent_header'].get('msg_id') != msg_id: - # not from my request - continue - output_hook(msg) + execute_interactive = KernelClient._async_execute_interactive - # stop on idle - if msg['header']['msg_type'] == 'status' and \ - msg['content']['execution_state'] == 'idle': - break + stop_channels = KernelClient._async_stop_channels - # output is done, get the reply - if timeout is not None: - timeout = max(0, deadline - time.monotonic()) - return await self._recv_reply(msg_id, timeout=timeout) + channels_running = property(KernelClient._async_channels_running) diff --git a/jupyter_client/blocking/channels.py b/jupyter_client/blocking/channels.py deleted file mode 100644 index ab24b692d..000000000 --- a/jupyter_client/blocking/channels.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Blocking channels - -Useful for test suites and blocking terminal interfaces. -""" - -# Copyright (c) Jupyter Development Team. -# Distributed under the terms of the Modified BSD License. - -from queue import Queue, Empty - - -class ZMQSocketChannel(object): - """A ZMQ socket in a simple blocking API""" - session = None - socket = None - stream = None - _exiting = False - proxy_methods = [] - - def __init__(self, socket, session, loop=None): - """Create a channel. - - Parameters - ---------- - socket : :class:`zmq.Socket` - The ZMQ socket to use. - session : :class:`session.Session` - The session to use. - loop - Unused here, for other implementations - """ - super().__init__() - - self.socket = socket - self.session = session - - def _recv(self, **kwargs): - msg = self.socket.recv_multipart(**kwargs) - ident,smsg = self.session.feed_identities(msg) - return self.session.deserialize(smsg) - - def get_msg(self, block=True, timeout=None): - """ Gets a message if there is one that is ready. """ - if block: - if timeout is not None: - timeout *= 1000 # seconds to ms - ready = self.socket.poll(timeout) - else: - ready = self.socket.poll(timeout=0) - - if ready: - return self._recv() - else: - raise Empty - - def get_msgs(self): - """ Get all messages that are currently ready. """ - msgs = [] - while True: - try: - msgs.append(self.get_msg(block=False)) - except Empty: - break - return msgs - - def msg_ready(self): - """ Is there a message that has been received? """ - return bool(self.socket.poll(timeout=0)) - - def close(self): - if self.socket is not None: - try: - self.socket.close(linger=0) - except Exception: - pass - self.socket = None - stop = close - - def is_alive(self): - return (self.socket is not None) - - def send(self, msg): - """Pass a message to the ZMQ socket to send - """ - self.session.send(self.socket, msg) - - def start(self): - pass diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py index 5f11b798a..ef92d2961 100644 --- a/jupyter_client/blocking/client.py +++ b/jupyter_client/blocking/client.py @@ -5,58 +5,11 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from functools import partial -from getpass import getpass -from queue import Empty -import sys -import time - -import zmq - -from time import monotonic from traitlets import Type -from jupyter_client.channels import HBChannel -from jupyter_client.client import KernelClient -from .channels import ZMQSocketChannel - - -def reqrep(meth, channel='shell'): - def wrapped(self, *args, **kwargs): - reply = kwargs.pop('reply', False) - timeout = kwargs.pop('timeout', None) - msg_id = meth(self, *args, **kwargs) - if not reply: - return msg_id - - return self._recv_reply(msg_id, timeout=timeout, channel=channel) +from jupyter_client.channels import HBChannel, ZMQSocketChannel +from jupyter_client.client import KernelClient, reqrep +from ..util import run_sync - if not meth.__doc__: - # python -OO removes docstrings, - # so don't bother building the wrapped docstring - return wrapped - - basedoc, _ = meth.__doc__.split('Returns\n', 1) - parts = [basedoc.strip()] - if 'Parameters' not in basedoc: - parts.append(""" - Parameters - ---------- - """) - parts.append(""" - reply: bool (default: False) - Whether to wait for and return reply - timeout: float or None (default: None) - Timeout to use when waiting for a reply - - Returns - ------- - msg_id: str - The msg_id of the request sent, if reply=False (default) - reply: dict - The reply message for this request, if reply=True - """) - wrapped.__doc__ = '\n'.join(parts) - return wrapped class BlockingKernelClient(KernelClient): """A KernelClient with blocking APIs @@ -65,61 +18,16 @@ class BlockingKernelClient(KernelClient): raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds. """ - def wait_for_ready(self, timeout=None): - """Waits for a response when a client is blocked - - - Sets future time for timeout - - Blocks on shell channel until a message is received - - Exit if the kernel has died - - If client times out before receiving a message from the kernel, send RuntimeError - - Flush the IOPub channel - """ - if timeout is None: - abs_timeout = float('inf') - else: - abs_timeout = time.time() + timeout - - from ..manager import KernelManager - if not isinstance(self.parent, KernelManager): - # This Client was not created by a KernelManager, - # so wait for kernel to become responsive to heartbeats - # before checking for kernel_info reply - while not self.is_alive(): - if time.time() > abs_timeout: - raise RuntimeError("Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout) - time.sleep(0.2) + #-------------------------------------------------------------------------- + # Channel proxy methods + #-------------------------------------------------------------------------- - # Wait for kernel info reply on shell channel - while True: - self.kernel_info() - try: - msg = self.shell_channel.get_msg(block=True, timeout=1) - except Empty: - pass - else: - if msg['msg_type'] == 'kernel_info_reply': - # Checking that IOPub is connected. If it is not connected, start over. - try: - self.iopub_channel.get_msg(block=True, timeout=0.2) - except Empty: - pass - else: - self._handle_kernel_info_reply(msg) - break + get_shell_msg = run_sync(KernelClient._async_get_shell_msg) + get_iopub_msg = run_sync(KernelClient._async_get_iopub_msg) + get_stdin_msg = run_sync(KernelClient._async_get_stdin_msg) + get_control_msg = run_sync(KernelClient._async_get_control_msg) - if not self.is_alive(): - raise RuntimeError('Kernel died before replying to kernel_info') - - # Check if current time is ready check time plus timeout - if time.time() > abs_timeout: - raise RuntimeError("Kernel didn't respond in %d seconds" % timeout) - - # Flush IOPub channel - while True: - try: - msg = self.iopub_channel.get_msg(block=True, timeout=0.2) - except Empty: - break + wait_for_ready = run_sync(KernelClient._async_wait_for_ready) # The classes to use for the various channels shell_channel_class = Type(ZMQSocketChannel) @@ -129,216 +37,24 @@ def wait_for_ready(self, timeout=None): control_channel_class = Type(ZMQSocketChannel) - def _recv_reply(self, msg_id, timeout=None, channel='shell'): - """Receive and return the reply for a given request""" - if timeout is not None: - deadline = monotonic() + timeout - while True: - if timeout is not None: - timeout = max(0, deadline - monotonic()) - try: - if channel == 'control': - reply = self.get_control_msg(timeout=timeout) - else: - reply = self.get_shell_msg(timeout=timeout) - except Empty as e: - raise TimeoutError("Timeout waiting for reply") from e - if reply['parent_header'].get('msg_id') != msg_id: - # not my reply, someone may have forgotten to retrieve theirs - continue - return reply + _recv_reply = run_sync(KernelClient._async__recv_reply) # replies come on the shell channel - execute = reqrep(KernelClient.execute) - history = reqrep(KernelClient.history) - complete = reqrep(KernelClient.complete) - inspect = reqrep(KernelClient.inspect) - kernel_info = reqrep(KernelClient.kernel_info) - comm_info = reqrep(KernelClient.comm_info) + execute = run_sync(reqrep(KernelClient._async_execute)) + history = run_sync(reqrep(KernelClient._async_history)) + complete = run_sync(reqrep(KernelClient._async_complete)) + inspect = run_sync(reqrep(KernelClient._async_inspect)) + kernel_info = run_sync(reqrep(KernelClient._async_kernel_info)) + comm_info = run_sync(reqrep(KernelClient._async_comm_info)) # replies come on the control channel - shutdown = reqrep(KernelClient.shutdown, channel='control') - - - def _stdin_hook_default(self, msg): - """Handle an input request""" - content = msg['content'] - if content.get('password', False): - prompt = getpass - else: - prompt = input - - try: - raw_data = prompt(content["prompt"]) - except EOFError: - # turn EOFError into EOF character - raw_data = '\x04' - except KeyboardInterrupt: - sys.stdout.write('\n') - return - - # only send stdin reply if there *was not* another request - # or execution finished while we were reading. - if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()): - self.input(raw_data) - - def _output_hook_default(self, msg): - """Default hook for redisplaying plain-text output""" - msg_type = msg['header']['msg_type'] - content = msg['content'] - if msg_type == 'stream': - stream = getattr(sys, content['name']) - stream.write(content['text']) - elif msg_type in ('display_data', 'execute_result'): - sys.stdout.write(content['data'].get('text/plain', '')) - elif msg_type == 'error': - print('\n'.join(content['traceback']), file=sys.stderr) - - def _output_hook_kernel(self, session, socket, parent_header, msg): - """Output hook when running inside an IPython kernel - - adds rich output support. - """ - msg_type = msg['header']['msg_type'] - if msg_type in ('display_data', 'execute_result', 'error'): - session.send(socket, msg_type, msg['content'], parent=parent_header) - else: - self._output_hook_default(msg) - - def execute_interactive(self, code, silent=False, store_history=True, - user_expressions=None, allow_stdin=None, stop_on_error=True, - timeout=None, output_hook=None, stdin_hook=None, - ): - """Execute code in the kernel interactively - - Output will be redisplayed, and stdin prompts will be relayed as well. - If an IPython kernel is detected, rich output will be displayed. - - You can pass a custom output_hook callable that will be called - with every IOPub message that is produced instead of the default redisplay. - - .. versionadded:: 5.0 - - Parameters - ---------- - code : str - A string of code in the kernel's language. - - silent : bool, optional (default False) - If set, the kernel will execute the code as quietly possible, and - will force store_history to be False. - - store_history : bool, optional (default True) - If set, the kernel will store command history. This is forced - to be False if silent is True. - - user_expressions : dict, optional - A dict mapping names to expressions to be evaluated in the user's - dict. The expression values are returned as strings formatted using - :func:`repr`. - - allow_stdin : bool, optional (default self.allow_stdin) - Flag for whether the kernel can send stdin requests to frontends. - - Some frontends (e.g. the Notebook) do not support stdin requests. - If raw_input is called from code executed from such a frontend, a - StdinNotImplementedError will be raised. - - stop_on_error: bool, optional (default True) - Flag whether to abort the execution queue, if an exception is encountered. - - timeout: float or None (default: None) - Timeout to use when waiting for a reply - - output_hook: callable(msg) - Function to be called with output messages. - If not specified, output will be redisplayed. - - stdin_hook: callable(msg) - Function to be called with stdin_request messages. - If not specified, input/getpass will be called. - - Returns - ------- - reply: dict - The reply message for this request - """ - if not self.iopub_channel.is_alive(): - raise RuntimeError("IOPub channel must be running to receive output") - if allow_stdin is None: - allow_stdin = self.allow_stdin - if allow_stdin and not self.stdin_channel.is_alive(): - raise RuntimeError("stdin channel must be running to allow input") - msg_id = self.execute(code, - silent=silent, - store_history=store_history, - user_expressions=user_expressions, - allow_stdin=allow_stdin, - stop_on_error=stop_on_error, - ) - if stdin_hook is None: - stdin_hook = self._stdin_hook_default - if output_hook is None: - # detect IPython kernel - if 'IPython' in sys.modules: - from IPython import get_ipython - ip = get_ipython() - in_kernel = getattr(ip, 'kernel', False) - if in_kernel: - output_hook = partial( - self._output_hook_kernel, - ip.display_pub.session, - ip.display_pub.pub_socket, - ip.display_pub.parent_header, - ) - if output_hook is None: - # default: redisplay plain-text outputs - output_hook = self._output_hook_default - - # set deadline based on timeout - if timeout is not None: - deadline = monotonic() + timeout - else: - timeout_ms = None - - poller = zmq.Poller() - iopub_socket = self.iopub_channel.socket - poller.register(iopub_socket, zmq.POLLIN) - if allow_stdin: - stdin_socket = self.stdin_channel.socket - poller.register(stdin_socket, zmq.POLLIN) - else: - stdin_socket = None - - # wait for output and redisplay it - while True: - if timeout is not None: - timeout = max(0, deadline - monotonic()) - timeout_ms = 1e3 * timeout - events = dict(poller.poll(timeout_ms)) - if not events: - raise TimeoutError("Timeout waiting for output") - if stdin_socket in events: - req = self.stdin_channel.get_msg(timeout=0) - stdin_hook(req) - continue - if iopub_socket not in events: - continue + shutdown = run_sync(reqrep(KernelClient._async_shutdown, channel='control')) - msg = self.iopub_channel.get_msg(timeout=0) + is_alive = run_sync(KernelClient._async_is_alive) - if msg['parent_header'].get('msg_id') != msg_id: - # not from my request - continue - output_hook(msg) + execute_interactive = run_sync(KernelClient._async_execute_interactive) - # stop on idle - if msg['header']['msg_type'] == 'status' and \ - msg['content']['execution_state'] == 'idle': - break + stop_channels = run_sync(KernelClient._async_stop_channels) - # output is done, get the reply - if timeout is not None: - timeout = max(0, deadline - monotonic()) - return self._recv_reply(msg_id, timeout=timeout) + channels_running = property(run_sync(KernelClient._async_channels_running)) diff --git a/jupyter_client/channels.py b/jupyter_client/channels.py index 90746e202..5c14f627f 100644 --- a/jupyter_client/channels.py +++ b/jupyter_client/channels.py @@ -8,6 +8,7 @@ from threading import Thread, Event import time import asyncio +from queue import Empty import zmq # import ZMQError in top-level namespace, to avoid ugly attribute-error messages @@ -211,3 +212,80 @@ def call_handlers(self, since_last_heartbeat): HBChannelABC.register(HBChannel) + + +class ZMQSocketChannel(object): + """A ZMQ socket in an async API""" + session = None + socket = None + stream = None + _exiting = False + proxy_methods = [] + + def __init__(self, socket, session, loop=None): + """Create a channel. + + Parameters + ---------- + socket : :class:`zmq.asyncio.Socket` + The ZMQ socket to use. + session : :class:`session.Session` + The session to use. + loop + Unused here, for other implementations + """ + super().__init__() + + self.socket = socket + self.session = session + + async def _recv(self, **kwargs): + msg = await self.socket.recv_multipart(**kwargs) + ident, smsg = self.session.feed_identities(msg) + return self.session.deserialize(smsg) + + async def get_msg(self, timeout=None): + """ Gets a message if there is one that is ready. """ + if timeout is not None: + timeout *= 1000 # seconds to ms + ready = await self.socket.poll(timeout) + + if ready: + res = await self._recv() + return res + else: + raise Empty + + async def get_msgs(self): + """ Get all messages that are currently ready. """ + msgs = [] + while True: + try: + msgs.append(await self.get_msg()) + except Empty: + break + return msgs + + async def msg_ready(self): + """ Is there a message that has been received? """ + return bool(await self.socket.poll(timeout=0)) + + def close(self): + if self.socket is not None: + try: + self.socket.close(linger=0) + except Exception: + pass + self.socket = None + stop = close + + def is_alive(self): + return (self.socket is not None) + + def send(self, msg): + """Pass a message to the ZMQ socket to send + """ + self.session.send(self.socket, msg) + + def start(self): + pass diff --git a/jupyter_client/client.py b/jupyter_client/client.py index 7c8028680..84652dc3e 100644 --- a/jupyter_client/client.py +++ b/jupyter_client/client.py @@ -3,19 +3,28 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +import sys +import asyncio +import time +from functools import partial +from getpass import getpass +from queue import Empty import typing as t + from jupyter_client.channels import major_protocol_version import zmq +import zmq.asyncio from traitlets import ( # type: ignore - Any, Instance, Type, + Any, Instance, Type, default ) from .channelsabc import (ChannelABC, HBChannelABC) from .clientabc import KernelClientABC from .connect import ConnectionFileMixin +from .util import ensure_async # some utilities to validate message structure, these might get moved elsewhere @@ -34,6 +43,45 @@ def validate_string_dict( raise ValueError('value %r in dict must be a string' % v) +def reqrep(meth, channel='shell'): + async def wrapped(self, *args, **kwargs): + reply = kwargs.pop('reply', False) + timeout = kwargs.pop('timeout', None) + msg_id = await meth(self, *args, **kwargs) + if not reply: + return msg_id + + return await self._async__recv_reply(msg_id, timeout=timeout, channel=channel) + + if not meth.__doc__: + # python -OO removes docstrings, + # so don't bother building the wrapped docstring + return wrapped + + basedoc, _ = meth.__doc__.split('Returns\n', 1) + parts = [basedoc.strip()] + if 'Parameters' not in basedoc: + parts.append(""" + Parameters + ---------- + """) + parts.append(""" + reply: bool (default: False) + Whether to wait for and return reply + timeout: float or None (default: None) + Timeout to use when waiting for a reply + + Returns + ------- + msg_id: str + The msg_id of the request sent, if reply=False (default) + reply: dict + The reply message for this request, if reply=True + """) + wrapped.__doc__ = '\n'.join(parts) + return wrapped + + class KernelClient(ConnectionFileMixin): """Communicates with a single kernel on any host via zmq channels. @@ -52,9 +100,11 @@ class KernelClient(ConnectionFileMixin): """ # The PyZMQ Context to use for communication with the kernel. - context = Instance(zmq.Context) - def _context_default(self) -> zmq.Context: - return zmq.Context() + context = Instance(zmq.asyncio.Context) + + @default('context') + def _context_default(self): + return zmq.asyncio.Context() # The classes to use for the various channels shell_channel_class = Type(ChannelABC) @@ -77,21 +127,143 @@ def _context_default(self) -> zmq.Context: # Channel proxy methods #-------------------------------------------------------------------------- - def get_shell_msg(self, *args, **kwargs) -> None: + async def _async_get_shell_msg(self, *args, **kwargs): """Get a message from the shell channel""" - return self.shell_channel.get_msg(*args, **kwargs) + return await self.shell_channel.get_msg(*args, **kwargs) - def get_iopub_msg(self, *args, **kwargs): + async def _async_get_iopub_msg(self, *args, **kwargs): """Get a message from the iopub channel""" - return self.iopub_channel.get_msg(*args, **kwargs) + return await self.iopub_channel.get_msg(*args, **kwargs) - def get_stdin_msg(self, *args, **kwargs): + async def _async_get_stdin_msg(self, *args, **kwargs): """Get a message from the stdin channel""" - return self.stdin_channel.get_msg(*args, **kwargs) + return await self.stdin_channel.get_msg(*args, **kwargs) - def get_control_msg(self, *args, **kwargs): + async def _async_get_control_msg(self, *args, **kwargs): """Get a message from the control channel""" - return self.control_channel.get_msg(*args, **kwargs) + return await self.control_channel.get_msg(*args, **kwargs) + + async def _async_wait_for_ready(self, timeout=None): + """Waits for a response when a client is blocked + + - Sets future time for timeout + - Blocks on shell channel until a message is received + - Exit if the kernel has died + - If client times out before receiving a message from the kernel, send RuntimeError + - Flush the IOPub channel + """ + if timeout is None: + abs_timeout = float('inf') + else: + abs_timeout = time.time() + timeout + + from .manager import KernelManager + if not isinstance(self.parent, KernelManager): + # This Client was not created by a KernelManager, + # so wait for kernel to become responsive to heartbeats + # before checking for kernel_info reply + while not await ensure_async(self.is_alive()): + if time.time() > abs_timeout: + raise RuntimeError("Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout) + await asyncio.sleep(0.2) + + # Wait for kernel info reply on shell channel + while True: + await self._async_kernel_info() + try: + msg = await self.shell_channel.get_msg(timeout=1) + except Empty: + pass + else: + if msg['msg_type'] == 'kernel_info_reply': + # Checking that IOPub is connected. If it is not connected, start over. + try: + await self.iopub_channel.get_msg(timeout=0.2) + except Empty: + pass + else: + self._handle_kernel_info_reply(msg) + break + + if not await ensure_async(self.is_alive()): + raise RuntimeError('Kernel died before replying to kernel_info') + + # Check if current time is ready check time plus timeout + if time.time() > abs_timeout: + raise RuntimeError("Kernel didn't respond in %d seconds" % timeout) + + # Flush IOPub channel + while True: + try: + msg = await self.iopub_channel.get_msg(timeout=0.2) + except Empty: + break + + async def _async__recv_reply(self, msg_id, timeout=None, channel='shell'): + """Receive and return the reply for a given request""" + if timeout is not None: + deadline = time.monotonic() + timeout + while True: + if timeout is not None: + timeout = max(0, deadline - time.monotonic()) + try: + if channel == 'control': + reply = await self._async_get_control_msg(timeout=timeout) + else: + reply = await self._async_get_shell_msg(timeout=timeout) + except Empty as e: + raise TimeoutError("Timeout waiting for reply") from e + if reply['parent_header'].get('msg_id') != msg_id: + # not my reply, someone may have forgotten to retrieve theirs + continue + return reply + + + def _stdin_hook_default(self, msg): + """Handle an input request""" + content = msg['content'] + if content.get('password', False): + prompt = getpass + else: + prompt = input + + try: + raw_data = prompt(content["prompt"]) + except EOFError: + # turn EOFError into EOF character + raw_data = '\x04' + except KeyboardInterrupt: + sys.stdout.write('\n') + return + + # only send stdin reply if there *was not* another request + # or execution finished while we were reading. + if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()): + self.input(raw_data) + + def _output_hook_default(self, msg): + """Default hook for redisplaying plain-text output""" + msg_type = msg['header']['msg_type'] + content = msg['content'] + if msg_type == 'stream': + stream = getattr(sys, content['name']) + stream.write(content['text']) + elif msg_type in ('display_data', 'execute_result'): + sys.stdout.write(content['data'].get('text/plain', '')) + elif msg_type == 'error': + print('\n'.join(content['traceback']), file=sys.stderr) + + def _output_hook_kernel(self, session, socket, parent_header, msg): + """Output hook when running inside an IPython kernel + + adds rich output support. + """ + msg_type = msg['header']['msg_type'] + if msg_type in ('display_data', 'execute_result', 'error'): + session.send(socket, msg_type, msg['content'], parent=parent_header) + else: + self._output_hook_default(msg) + #-------------------------------------------------------------------------- # Channel management methods @@ -120,7 +292,7 @@ def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=Tr if control: self.control_channel.start() - def stop_channels(self): + async def _async_stop_channels(self): """Stops all the running channels for this kernel. This stops their event loops and joins their threads. @@ -136,8 +308,8 @@ def stop_channels(self): if self.control_channel.is_alive(): self.control_channel.stop() - @property - def channels_running(self): + #@property + async def _async_channels_running(self): """Are any of the channels created and running?""" return (self.shell_channel.is_alive() or self.iopub_channel.is_alive() or self.stdin_channel.is_alive() or self.hb_channel.is_alive() or @@ -187,11 +359,23 @@ def hb_channel(self): if self._hb_channel is None: url = self._make_url('hb') self.log.debug("connecting heartbeat channel to %s", url) + loop = asyncio.new_event_loop() self._hb_channel = self.hb_channel_class( - self.context, self.session, url + self.context, self.session, url, loop ) return self._hb_channel + #@property + #def hb_channel(self): + # """Get the hb channel object for this kernel.""" + # if self._hb_channel is None: + # url = self._make_url('hb') + # self.log.debug("connecting heartbeat channel to %s", url) + # self._hb_channel = self.hb_channel_class( + # self.context, self.session, url + # ) + # return self._hb_channel + @property def control_channel(self): """Get the control channel object for this kernel.""" @@ -204,13 +388,13 @@ def control_channel(self): ) return self._control_channel - def is_alive(self): + async def _async_is_alive(self): """Is the kernel process still running?""" - from .manager import KernelManager + from .manager import KernelManager, AsyncKernelManager if isinstance(self.parent, KernelManager): # This KernelClient was created by a KernelManager, # we can ask the parent KernelManager: - return self.parent.is_alive() + return await ensure_async(self.parent.is_alive()) if self._hb_channel is not None: # We don't have access to the KernelManager, # so we use the heartbeat. @@ -221,8 +405,146 @@ def is_alive(self): return True + async def _async_execute_interactive(self, code, silent=False, store_history=True, + user_expressions=None, allow_stdin=None, stop_on_error=True, + timeout=None, output_hook=None, stdin_hook=None, + ): + """Execute code in the kernel interactively + + Output will be redisplayed, and stdin prompts will be relayed as well. + If an IPython kernel is detected, rich output will be displayed. + + You can pass a custom output_hook callable that will be called + with every IOPub message that is produced instead of the default redisplay. + + .. versionadded:: 5.0 + + Parameters + ---------- + code : str + A string of code in the kernel's language. + + silent : bool, optional (default False) + If set, the kernel will execute the code as quietly possible, and + will force store_history to be False. + + store_history : bool, optional (default True) + If set, the kernel will store command history. This is forced + to be False if silent is True. + + user_expressions : dict, optional + A dict mapping names to expressions to be evaluated in the user's + dict. The expression values are returned as strings formatted using + :func:`repr`. + + allow_stdin : bool, optional (default self.allow_stdin) + Flag for whether the kernel can send stdin requests to frontends. + + Some frontends (e.g. the Notebook) do not support stdin requests. + If raw_input is called from code executed from such a frontend, a + StdinNotImplementedError will be raised. + + stop_on_error: bool, optional (default True) + Flag whether to abort the execution queue, if an exception is encountered. + + timeout: float or None (default: None) + Timeout to use when waiting for a reply + + output_hook: callable(msg) + Function to be called with output messages. + If not specified, output will be redisplayed. + + stdin_hook: callable(msg) + Function to be called with stdin_request messages. + If not specified, input/getpass will be called. + + Returns + ------- + reply: dict + The reply message for this request + """ + if not self.iopub_channel.is_alive(): + raise RuntimeError("IOPub channel must be running to receive output") + if allow_stdin is None: + allow_stdin = self.allow_stdin + if allow_stdin and not self.stdin_channel.is_alive(): + raise RuntimeError("stdin channel must be running to allow input") + msg_id = await self._async_execute(code, + silent=silent, + store_history=store_history, + user_expressions=user_expressions, + allow_stdin=allow_stdin, + stop_on_error=stop_on_error, + ) + if stdin_hook is None: + stdin_hook = self._stdin_hook_default + if output_hook is None: + # detect IPython kernel + if 'IPython' in sys.modules: + from IPython import get_ipython + ip = get_ipython() + in_kernel = getattr(ip, 'kernel', False) + if in_kernel: + output_hook = partial( + self._output_hook_kernel, + ip.display_pub.session, + ip.display_pub.pub_socket, + ip.display_pub.parent_header, + ) + if output_hook is None: + # default: redisplay plain-text outputs + output_hook = self._output_hook_default + + # set deadline based on timeout + if timeout is not None: + deadline = time.monotonic() + timeout + else: + timeout_ms = None + + poller = zmq.Poller() + iopub_socket = self.iopub_channel.socket + poller.register(iopub_socket, zmq.POLLIN) + if allow_stdin: + stdin_socket = self.stdin_channel.socket + poller.register(stdin_socket, zmq.POLLIN) + else: + stdin_socket = None + + # wait for output and redisplay it + while True: + if timeout is not None: + timeout = max(0, deadline - time.monotonic()) + timeout_ms = 1e3 * timeout + events = dict(poller.poll(timeout_ms)) + if not events: + raise TimeoutError("Timeout waiting for output") + if stdin_socket in events: + req = await self.stdin_channel.get_msg(timeout=0) + stdin_hook(req) + continue + if iopub_socket not in events: + continue + + msg = await self.iopub_channel.get_msg(timeout=0) + + if msg['parent_header'].get('msg_id') != msg_id: + # not from my request + continue + output_hook(msg) + + # stop on idle + if msg['header']['msg_type'] == 'status' and \ + msg['content']['execution_state'] == 'idle': + break + + # output is done, get the reply + if timeout is not None: + timeout = max(0, deadline - time.monotonic()) + return await self._async__recv_reply(msg_id, timeout=timeout) + + # Methods to send specific messages on channels - def execute(self, code, silent=False, store_history=True, + async def _async_execute(self, code, silent=False, store_history=True, user_expressions=None, allow_stdin=None, stop_on_error=True): """Execute code in the kernel. @@ -279,7 +601,7 @@ def execute(self, code, silent=False, store_history=True, self.shell_channel.send(msg) return msg['header']['msg_id'] - def complete(self, code, cursor_pos=None): + async def _async_complete(self, code, cursor_pos=None): """Tab complete text in the kernel's namespace. Parameters @@ -302,7 +624,7 @@ def complete(self, code, cursor_pos=None): self.shell_channel.send(msg) return msg['header']['msg_id'] - def inspect(self, code, cursor_pos=None, detail_level=0): + async def _async_inspect(self, code, cursor_pos=None, detail_level=0): """Get metadata information about an object in the kernel's namespace. It is up to the kernel to determine the appropriate object to inspect. @@ -331,7 +653,7 @@ def inspect(self, code, cursor_pos=None, detail_level=0): self.shell_channel.send(msg) return msg['header']['msg_id'] - def history(self, raw=True, output=False, hist_access_type='range', **kwargs): + async def _async_history(self, raw=True, output=False, hist_access_type='range', **kwargs): """Get entries from the kernel's history list. Parameters @@ -372,7 +694,7 @@ def history(self, raw=True, output=False, hist_access_type='range', **kwargs): self.shell_channel.send(msg) return msg['header']['msg_id'] - def kernel_info(self): + async def _async_kernel_info(self): """Request kernel info Returns @@ -383,7 +705,7 @@ def kernel_info(self): self.shell_channel.send(msg) return msg['header']['msg_id'] - def comm_info(self, target_name=None): + async def _async_comm_info(self, target_name=None): """Request comm info Returns @@ -424,7 +746,7 @@ def input(self, string): msg = self.session.msg('input_reply', content) self.stdin_channel.send(msg) - def shutdown(self, restart=False): + async def _async_shutdown(self, restart=False): """Request an immediate kernel shutdown on the control channel. Upon receipt of the (empty) reply, client code can safely assume that diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index ec055f25c..a66ecd576 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -771,7 +771,7 @@ async def start_new_async_kernel( try: await kc.wait_for_ready(timeout=startup_timeout) except RuntimeError: - kc.stop_channels() + await kc.stop_channels() await km.shutdown_kernel() raise diff --git a/jupyter_client/tests/test_kernelapp.py b/jupyter_client/tests/test_kernelapp.py index 17f793a98..38ef50082 100644 --- a/jupyter_client/tests/test_kernelapp.py +++ b/jupyter_client/tests/test_kernelapp.py @@ -42,6 +42,7 @@ def test_kernelapp_lifecycle(): assert cf.endswith('.json') # Send SIGTERM to shut down + time.sleep(0.2) p.terminate() _, stderr = p.communicate(timeout=WAIT_TIME) assert cf in stderr.decode('utf-8', 'replace') diff --git a/jupyter_client/tests/test_kernelmanager.py b/jupyter_client/tests/test_kernelmanager.py index 7f5ea13dc..ce52d5f70 100644 --- a/jupyter_client/tests/test_kernelmanager.py +++ b/jupyter_client/tests/test_kernelmanager.py @@ -135,7 +135,7 @@ def async_km_subclass(config): async def start_async_kernel(): km, kc = await start_new_async_kernel(kernel_name='signaltest') await yield_((km, kc)) - kc.stop_channels() + await kc.stop_channels() await km.shutdown_kernel() assert km.context.closed @@ -184,7 +184,7 @@ async def test_async_signal_kernel_subprocesses(self, name, install, expected): assert km._shutdown_status == _ShutdownStatus.Unset assert await km.is_alive() # kc.execute("1") - kc.stop_channels() + await kc.stop_channels() await km.shutdown_kernel() assert km._shutdown_status == expected @@ -477,7 +477,7 @@ async def test_signal_kernel_subprocesses(self, install_kernel, start_async_kern km, kc = start_async_kernel async def execute(cmd): - request_id = kc.execute(cmd) + request_id = await kc.execute(cmd) while True: reply = await kc.get_shell_msg(TIMEOUT) if reply['parent_header']['msg_id'] == request_id: @@ -496,7 +496,7 @@ async def execute(cmd): assert reply['user_expressions']['poll'] == [None] * N # start a job on the kernel to be interrupted - request_id = kc.execute('sleep') + request_id = await kc.execute('sleep') await asyncio.sleep(1) # ensure sleep message has been handled before we interrupt await km.interrupt_kernel() while True: diff --git a/jupyter_client/util.py b/jupyter_client/util.py index 2e1f2516e..bfbb8e968 100644 --- a/jupyter_client/util.py +++ b/jupyter_client/util.py @@ -1,14 +1,17 @@ -import concurrent.futures import asyncio +import inspect +import nest_asyncio +nest_asyncio.apply() -def asyncio_run(task): - loop = asyncio.new_event_loop() - return loop.run_until_complete(task) +loop = asyncio.get_event_loop() def run_sync(coro): def wrapped(*args, **kwargs): - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio_run, coro(*args, **kwargs)) - return future.result() + return loop.run_until_complete(coro(*args, **kwargs)) wrapped.__doc__ = coro.__doc__ return wrapped + +async def ensure_async(obj): + if inspect.isawaitable(obj): + return await obj + return obj diff --git a/setup.py b/setup.py index 9b4c8f4f2..8100b3f09 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,7 @@ def run(self): 'pyzmq>=13', 'python-dateutil>=2.1', 'tornado>=4.1', + 'nest-asyncio>=1.5', ], python_requires = '>=3.5', extras_require = { From 26f3bfdfa1546cc1a5067f1b7d47e8f6f64b1349 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Mon, 15 Mar 2021 09:08:48 +0100 Subject: [PATCH 5/9] Refactor MultiKernelManager --- .github/workflows/main.yml | 2 +- jupyter_client/asynchronous/client.py | 26 +- jupyter_client/blocking/client.py | 14 +- jupyter_client/channels.py | 12 +- jupyter_client/client.py | 22 +- jupyter_client/manager.py | 38 +- jupyter_client/multikernelmanager.py | 360 +++++++++--------- jupyter_client/tests/test_kernelapp.py | 11 +- jupyter_client/tests/test_kernelmanager.py | 49 ++- .../tests/test_multikernelmanager.py | 70 ++-- jupyter_client/util.py | 12 +- 11 files changed, 297 insertions(+), 319 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ac653328b..fc05d02dd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,7 +32,7 @@ jobs: pip install --upgrade --upgrade-strategy eager -e .[test] pytest-cov codecov 'coverage<5' pip freeze - name: Check types - run: mypy jupyter_client/manager.py + run: mypy jupyter_client/manager.py jupyter_client/multikernelmanager.py jupyter_client/client.py jupyter_client/blocking/client.py jupyter_client/asynchronous/client.py - name: Run the tests run: py.test --cov jupyter_client -v jupyter_client - name: Code coverage diff --git a/jupyter_client/asynchronous/client.py b/jupyter_client/asynchronous/client.py index 7a2632550..75b8baaf4 100644 --- a/jupyter_client/asynchronous/client.py +++ b/jupyter_client/asynchronous/client.py @@ -2,7 +2,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from traitlets import (Type, Instance) +from traitlets import (Type, Instance) # type: ignore from jupyter_client.channels import HBChannel, ZMQSocketChannel from jupyter_client.client import KernelClient, reqrep @@ -23,18 +23,6 @@ class AsyncKernelClient(KernelClient): get_stdin_msg = KernelClient._async_get_stdin_msg get_control_msg = KernelClient._async_get_control_msg - #@property - #def hb_channel(self): - # """Get the hb channel object for this kernel.""" - # if self._hb_channel is None: - # url = self._make_url('hb') - # self.log.debug("connecting heartbeat channel to %s", url) - # loop = asyncio.new_event_loop() - # self._hb_channel = self.hb_channel_class( - # self.context, self.session, url, loop - # ) - # return self._hb_channel - wait_for_ready = KernelClient._async_wait_for_ready # The classes to use for the various channels @@ -45,7 +33,7 @@ class AsyncKernelClient(KernelClient): control_channel_class = Type(ZMQSocketChannel) - _recv_reply = KernelClient._async__recv_reply + _recv_reply = KernelClient._async_recv_reply # replies come on the shell channel @@ -55,14 +43,10 @@ class AsyncKernelClient(KernelClient): inspect = reqrep(KernelClient._async_inspect) kernel_info = reqrep(KernelClient._async_kernel_info) comm_info = reqrep(KernelClient._async_comm_info) - - # replies come on the control channel - shutdown = reqrep(KernelClient._async_shutdown, channel='control') - is_alive = KernelClient._async_is_alive - execute_interactive = KernelClient._async_execute_interactive - stop_channels = KernelClient._async_stop_channels - channels_running = property(KernelClient._async_channels_running) + + # replies come on the control channel + shutdown = reqrep(KernelClient._async_shutdown, channel='control') diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py index ef92d2961..85ee6f84d 100644 --- a/jupyter_client/blocking/client.py +++ b/jupyter_client/blocking/client.py @@ -5,7 +5,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from traitlets import Type +from traitlets import Type # type: ignore from jupyter_client.channels import HBChannel, ZMQSocketChannel from jupyter_client.client import KernelClient, reqrep from ..util import run_sync @@ -37,7 +37,7 @@ class BlockingKernelClient(KernelClient): control_channel_class = Type(ZMQSocketChannel) - _recv_reply = run_sync(KernelClient._async__recv_reply) + _recv_reply = run_sync(KernelClient._async_recv_reply) # replies come on the shell channel @@ -47,14 +47,10 @@ class BlockingKernelClient(KernelClient): inspect = run_sync(reqrep(KernelClient._async_inspect)) kernel_info = run_sync(reqrep(KernelClient._async_kernel_info)) comm_info = run_sync(reqrep(KernelClient._async_comm_info)) - - # replies come on the control channel - shutdown = run_sync(reqrep(KernelClient._async_shutdown, channel='control')) - is_alive = run_sync(KernelClient._async_is_alive) - execute_interactive = run_sync(KernelClient._async_execute_interactive) - stop_channels = run_sync(KernelClient._async_stop_channels) - channels_running = property(run_sync(KernelClient._async_channels_running)) + + # replies come on the control channel + shutdown = run_sync(reqrep(KernelClient._async_shutdown, channel='control')) diff --git a/jupyter_client/channels.py b/jupyter_client/channels.py index 5c14f627f..c5ff3791a 100644 --- a/jupyter_client/channels.py +++ b/jupyter_client/channels.py @@ -47,7 +47,7 @@ class HBChannel(Thread): _pause = None _beating = None - def __init__(self, context=None, session=None, address=None, loop=None): + def __init__(self, context=None, session=None, address=None): """Create the heartbeat monitor thread. Parameters @@ -62,8 +62,6 @@ def __init__(self, context=None, session=None, address=None, loop=None): super().__init__() self.daemon = True - self.loop = loop - self.context = context self.session = session if isinstance(address, tuple): @@ -93,6 +91,12 @@ def _create_socket(self): # close previous socket, before opening a new one self.poller.unregister(self.socket) self.socket.close() + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self.socket = self.context.socket(zmq.REQ) self.socket.linger = 1000 self.socket.connect(self.address) @@ -134,8 +138,6 @@ def _poll(self, start_time): def run(self): """The thread's main activity. Call start() instead.""" - if self.loop is not None: - asyncio.set_event_loop(self.loop) self._create_socket() self._running = True self._beating = True diff --git a/jupyter_client/client.py b/jupyter_client/client.py index 84652dc3e..56ccfc5f8 100644 --- a/jupyter_client/client.py +++ b/jupyter_client/client.py @@ -51,7 +51,7 @@ async def wrapped(self, *args, **kwargs): if not reply: return msg_id - return await self._async__recv_reply(msg_id, timeout=timeout, channel=channel) + return await self._async_recv_reply(msg_id, timeout=timeout, channel=channel) if not meth.__doc__: # python -OO removes docstrings, @@ -199,7 +199,7 @@ async def _async_wait_for_ready(self, timeout=None): except Empty: break - async def _async__recv_reply(self, msg_id, timeout=None, channel='shell'): + async def _async_recv_reply(self, msg_id, timeout=None, channel='shell'): """Receive and return the reply for a given request""" if timeout is not None: deadline = time.monotonic() + timeout @@ -359,23 +359,11 @@ def hb_channel(self): if self._hb_channel is None: url = self._make_url('hb') self.log.debug("connecting heartbeat channel to %s", url) - loop = asyncio.new_event_loop() self._hb_channel = self.hb_channel_class( - self.context, self.session, url, loop + self.context, self.session, url ) return self._hb_channel - #@property - #def hb_channel(self): - # """Get the hb channel object for this kernel.""" - # if self._hb_channel is None: - # url = self._make_url('hb') - # self.log.debug("connecting heartbeat channel to %s", url) - # self._hb_channel = self.hb_channel_class( - # self.context, self.session, url - # ) - # return self._hb_channel - @property def control_channel(self): """Get the control channel object for this kernel.""" @@ -481,7 +469,7 @@ async def _async_execute_interactive(self, code, silent=False, store_history=Tru if output_hook is None: # detect IPython kernel if 'IPython' in sys.modules: - from IPython import get_ipython + from IPython import get_ipython # type: ignore ip = get_ipython() in_kernel = getattr(ip, 'kernel', False) if in_kernel: @@ -540,7 +528,7 @@ async def _async_execute_interactive(self, code, silent=False, store_history=Tru # output is done, get the reply if timeout is not None: timeout = max(0, deadline - time.monotonic()) - return await self._async__recv_reply(msg_id, timeout=timeout) + return await self._async_recv_reply(msg_id, timeout=timeout) # Methods to send specific messages on channels diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index a66ecd576..1577ccde1 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -23,7 +23,7 @@ Any, Float, Instance, Unicode, List, Bool, Type, DottedObjectName, default, observe, observe_compat ) -from traitlets.utils.importstring import import_item +from traitlets.utils.importstring import import_item # type: ignore from jupyter_client import ( launch_kernel, kernelspec, @@ -33,7 +33,7 @@ from .managerabc import ( KernelManagerABC ) -from .util import run_sync +from .util import run_sync, ensure_async class _ShutdownStatus(Enum): """ @@ -272,7 +272,7 @@ def from_ns(match): return [pat.sub(from_ns, arg) for arg in cmd] - async def _async__launch_kernel( + async def _async_launch_kernel( self, kernel_cmd: t.List[str], **kw @@ -283,7 +283,7 @@ async def _async__launch_kernel( """ return launch_kernel(kernel_cmd, **kw) - _launch_kernel = run_sync(_async__launch_kernel) + _launch_kernel = run_sync(_async_launch_kernel) # Control socket used for polite kernel shutdown @@ -380,7 +380,7 @@ async def _async_start_kernel(self, **kw): # launch the kernel subprocess self.log.debug("Starting kernel: %s", kernel_cmd) - self.kernel = await self._async__launch_kernel(kernel_cmd, **kw) + self.kernel = await ensure_async(self._launch_kernel(kernel_cmd, **kw)) self.post_start_kernel(**kw) start_kernel = run_sync(_async_start_kernel) @@ -417,7 +417,7 @@ async def _async_finish_shutdown( except asyncio.TimeoutError: self.log.debug("Kernel is taking too long to finish, terminating") self._shutdown_status = _ShutdownStatus.SigtermRequest - await self._async__send_kernel_sigterm() + await self._async_send_kernel_sigterm() try: await asyncio.wait_for( @@ -426,7 +426,7 @@ async def _async_finish_shutdown( except asyncio.TimeoutError: self.log.debug("Kernel is taking too long to finish, killing") self._shutdown_status = _ShutdownStatus.SigkillRequest - await self._async__kill_kernel() + await ensure_async(self._kill_kernel()) else: # Process is no longer alive, wait and clear if self.kernel is not None: @@ -485,16 +485,16 @@ async def _async_shutdown_kernel( # Stop monitoring for restarting while we shutdown. self.stop_restarter() - await self._async_interrupt_kernel() + await ensure_async(self.interrupt_kernel()) if now: - await self._async__kill_kernel() + await ensure_async(self._kill_kernel()) else: self.request_shutdown(restart=restart) # Don't send any additional kernel kill messages immediately, to give # the kernel a chance to properly execute shutdown actions. Wait for at # most 1s, checking every 0.1s. - await self._async_finish_shutdown() + await ensure_async(self.finish_shutdown()) # In 6.1.5, a new method, cleanup_resources(), was introduced to address # a leak issue (https://github.com/jupyter/jupyter_client/pull/548) and @@ -554,14 +554,14 @@ async def _async_restart_kernel( "No previous call to 'start_kernel'.") else: # Stop currently running kernel. - await self._async_shutdown_kernel(now=now, restart=True) + await ensure_async(self.shutdown_kernel(now=now, restart=True)) if newports: self.cleanup_random_ports() # Start new kernel. self._launch_args.update(kw) - await self._async_start_kernel(**self._launch_args) + await ensure_async(self.start_kernel(**self._launch_args)) restart_kernel = run_sync(_async_restart_kernel) @@ -570,7 +570,7 @@ def has_kernel(self) -> bool: """Has a kernel been started that we are managing.""" return self.kernel is not None - async def _async__send_kernel_sigterm(self) -> None: + async def _async_send_kernel_sigterm(self) -> None: """similar to _kill_kernel, but with sigterm (not sigkill), but do not block""" if self.has_kernel: # Signal the kernel to terminate (sends SIGTERM on Unix and @@ -600,9 +600,9 @@ async def _async__send_kernel_sigterm(self) -> None: if e.errno != ESRCH: raise - _send_kernel_sigterm = run_sync(_async__send_kernel_sigterm) + _send_kernel_sigterm = run_sync(_async_send_kernel_sigterm) - async def _async__kill_kernel(self) -> None: + async def _async_kill_kernel(self) -> None: """Kill the running kernel. This is a private method, callers should use shutdown_kernel(now=True). @@ -641,7 +641,7 @@ async def _async__kill_kernel(self) -> None: self.kernel.wait() self.kernel = None - _kill_kernel = run_sync(_async__kill_kernel) + _kill_kernel = run_sync(_async_kill_kernel) async def _async_interrupt_kernel(self) -> None: """Interrupts the kernel by sending it a signal. @@ -723,13 +723,13 @@ class AsyncKernelManager(KernelManager): client_class: DottedObjectName = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient') client_factory: Type = Type(klass='jupyter_client.asynchronous.AsyncKernelClient') - _launch_kernel = KernelManager._async__launch_kernel + _launch_kernel = KernelManager._async_launch_kernel start_kernel = KernelManager._async_start_kernel finish_shutdown = KernelManager._async_finish_shutdown shutdown_kernel = KernelManager._async_shutdown_kernel restart_kernel = KernelManager._async_restart_kernel - _send_kernel_sigterm = KernelManager._async__send_kernel_sigterm - _kill_kernel = KernelManager._async__kill_kernel + _send_kernel_sigterm = KernelManager._async_send_kernel_sigterm + _kill_kernel = KernelManager._async_kill_kernel interrupt_kernel = KernelManager._async_interrupt_kernel signal_kernel = KernelManager._async_signal_kernel is_alive = KernelManager._async_is_alive diff --git a/jupyter_client/multikernelmanager.py b/jupyter_client/multikernelmanager.py index 5907d7bfa..401c3a84b 100644 --- a/jupyter_client/multikernelmanager.py +++ b/jupyter_client/multikernelmanager.py @@ -7,26 +7,35 @@ import os import uuid import socket +import typing as t import zmq -from traitlets.config.configurable import LoggingConfigurable -from traitlets.utils.importstring import import_item -from traitlets import ( +from traitlets.config.configurable import LoggingConfigurable # type: ignore +from traitlets.utils.importstring import import_item # type: ignore +from traitlets import ( # type: ignore Any, Bool, Dict, DottedObjectName, Instance, Unicode, default, observe ) from .kernelspec import NATIVE_KERNEL_NAME, KernelSpecManager -from .manager import KernelManager, AsyncKernelManager +from .manager import KernelManager +from .util import run_sync, ensure_async class DuplicateKernelError(Exception): pass -def kernel_method(f): +def kernel_method( + f: t.Callable +) -> t.Callable: """decorator for proxying MKM.method(kernel_id) to individual KMs by ID""" - def wrapped(self, kernel_id, *args, **kwargs): + def wrapped( + self, + kernel_id: str, + *args, + **kwargs + ) -> t.Union[t.Callable, t.Awaitable]: # get the kernel km = self.get_kernel(kernel_id) method = getattr(km, f.__name__) @@ -72,10 +81,10 @@ def _kernel_manager_class_changed(self, change): def _kernel_manager_factory_default(self): return self._create_kernel_manager_factory() - def _create_kernel_manager_factory(self): + def _create_kernel_manager_factory(self) -> t.Callable: kernel_manager_ctor = import_item(self.kernel_manager_class) - def create_kernel_manager(*args, **kwargs): + def create_kernel_manager(*args, **kwargs) -> KernelManager: if self.shared_context: if self.context.closed: # recreate context if closed @@ -94,7 +103,10 @@ def create_kernel_manager(*args, **kwargs): return create_kernel_manager - def _find_available_port(self, ip): + def _find_available_port( + self, + ip: str + ) -> int: while True: tmp_sock = socket.socket() tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8) @@ -119,8 +131,10 @@ def _find_available_port(self, ip): context = Instance('zmq.Context') + _starting_kernels = Dict() + @default("context") - def _context_default(self): + def _context_default(self) -> zmq.Context: self._created_context = True return zmq.Context() @@ -140,20 +154,24 @@ def __del__(self): _kernels = Dict() - def list_kernel_ids(self): + def list_kernel_ids(self) -> t.List[str]: """Return a list of the kernel ids of the active kernels.""" # Create a copy so we can iterate over kernels in operations # that delete keys. return list(self._kernels.keys()) - def __len__(self): + def __len__(self) -> int: """Return the number of running kernels.""" return len(self.list_kernel_ids()) - def __contains__(self, kernel_id): + def __contains__(self, kernel_id) -> bool: return kernel_id in self._kernels - def pre_start_kernel(self, kernel_name, kwargs): + def pre_start_kernel( + self, + kernel_name: t.Optional[str], + kwargs + ) -> t.Tuple[KernelManager, str, str]: # kwargs should be mutable, passing it as a dict argument. kernel_id = kwargs.pop('kernel_id', self.new_kernel_id(**kwargs)) if kernel_id in self: @@ -174,7 +192,20 @@ def pre_start_kernel(self, kernel_name, kwargs): ) return km, kernel_name, kernel_id - def start_kernel(self, kernel_name=None, **kwargs): + async def _add_kernel_when_ready( + self, + kernel_id: str, + km: KernelManager, + kernel_awaitable: t.Awaitable + ) -> None: + await kernel_awaitable + self._kernels[kernel_id] = km + + async def _async_start_kernel( + self, + kernel_name: t.Optional[str] = None, + **kwargs + ) -> str: """Start a new kernel. The caller can pick a kernel_id by passing one in as a keyword arg, @@ -183,11 +214,29 @@ def start_kernel(self, kernel_name=None, **kwargs): The kernel ID for the newly started kernel is returned. """ km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs) - km.start_kernel(**kwargs) - self._kernels[kernel_id] = km + if not isinstance(km, KernelManager): + self.log.warning("Kernel manager class ({km_class}) is not an instance of 'KernelManager'!". + format(km_class=self.kernel_manager_class.__class__)) + fut = asyncio.ensure_future( + self._add_kernel_when_ready( + kernel_id, + km, + km._async_start_kernel(**kwargs) + ) + ) + self._starting_kernels[kernel_id] = fut + await fut + del self._starting_kernels[kernel_id] return kernel_id - def shutdown_kernel(self, kernel_id, now=False, restart=False): + start_kernel = run_sync(_async_start_kernel) + + async def _async_shutdown_kernel( + self, + kernel_id: str, + now: t.Optional[bool] = False, + restart: t.Optional[bool] = False + ) -> None: """Shutdown a kernel by its kernel uuid. Parameters @@ -208,32 +257,54 @@ def shutdown_kernel(self, kernel_id, now=False, restart=False): km.hb_port, km.control_port ) - km.shutdown_kernel(now=now, restart=restart) + await ensure_async(km.shutdown_kernel(now, restart)) self.remove_kernel(kernel_id) if km.cache_ports and not restart: for port in ports: self.currently_used_ports.remove(port) + shutdown_kernel = run_sync(_async_shutdown_kernel) + @kernel_method - def request_shutdown(self, kernel_id, restart=False): + def request_shutdown( + self, + kernel_id: str, + restart: t.Optional[bool] = False + ) -> None: """Ask a kernel to shut down by its kernel uuid""" @kernel_method - def finish_shutdown(self, kernel_id, waittime=None, pollinterval=0.1): + def finish_shutdown( + self, + kernel_id: str, + waittime: t.Optional[float] = None, + pollinterval: t.Optional[float] = 0.1 + ) -> None: """Wait for a kernel to finish shutting down, and kill it if it doesn't """ self.log.info("Kernel shutdown: %s" % kernel_id) @kernel_method - def cleanup(self, kernel_id, connection_file=True): + def cleanup( + self, + kernel_id: str, + connection_file: bool = True + ) -> None: """Clean up a kernel's resources""" @kernel_method - def cleanup_resources(self, kernel_id, restart=False): + def cleanup_resources( + self, + kernel_id: str, + restart: bool = False + ) -> None: """Clean up a kernel's resources""" - def remove_kernel(self, kernel_id): + def remove_kernel( + self, + kernel_id: str + ) -> KernelManager: """remove a kernel from our mapping. Mainly so that a kernel can be removed if it is already dead, @@ -243,29 +314,35 @@ def remove_kernel(self, kernel_id): """ return self._kernels.pop(kernel_id) - def shutdown_all(self, now=False): + async def _shutdown_starting_kernel( + self, + kid: str, + now: bool + ) -> None: + if kid in self._starting_kernels: + await self._starting_kernels[kid] + await ensure_async(self.shutdown_kernel(kid, now=now)) + + async def _async_shutdown_all( + self, + now: bool = False + ) -> None: """Shutdown all kernels.""" kids = self.list_kernel_ids() - for kid in kids: - self.request_shutdown(kid) - for kid in kids: - self.finish_shutdown(kid) - - # Determine which cleanup method to call - # See comment in KernelManager.shutdown_kernel(). - km = self.get_kernel(kid) - overrides_cleanup = type(km).cleanup is not KernelManager.cleanup - overrides_cleanup_resources = type(km).cleanup_resources is not KernelManager.cleanup_resources - - if overrides_cleanup and not overrides_cleanup_resources: - km.cleanup(connection_file=True) - else: - km.cleanup_resources(restart=False) + futs = [ensure_async(self.shutdown_kernel(kid, now=now)) for kid in kids] + futs += [ + self._shutdown_starting_kernel(kid, now=now) + for kid in self._starting_kernels.keys() + ] + await asyncio.gather(*futs) - self.remove_kernel(kid) + shutdown_all = run_sync(_async_shutdown_all) @kernel_method - def interrupt_kernel(self, kernel_id): + def interrupt_kernel( + self, + kernel_id: str + ) -> None: """Interrupt (SIGINT) the kernel by its uuid. Parameters @@ -276,7 +353,11 @@ def interrupt_kernel(self, kernel_id): self.log.info("Kernel interrupted: %s" % kernel_id) @kernel_method - def signal_kernel(self, kernel_id, signum): + def signal_kernel( + self, + kernel_id: str, + signum: int + ) -> None: """Sends a signal to the kernel by its uuid. Note that since only SIGTERM is supported on Windows, this function @@ -290,7 +371,11 @@ def signal_kernel(self, kernel_id, signum): self.log.info("Signaled Kernel %s with %s" % (kernel_id, signum)) @kernel_method - def restart_kernel(self, kernel_id, now=False): + def restart_kernel( + self, + kernel_id: str, + now: bool = False + ) -> None: """Restart a kernel by its uuid, keeping the same ports. Parameters @@ -301,7 +386,10 @@ def restart_kernel(self, kernel_id, now=False): self.log.info("Kernel restarted: %s" % kernel_id) @kernel_method - def is_alive(self, kernel_id): + def is_alive( + self, + kernel_id: str + ) -> bool: """Is the kernel alive. This calls KernelManager.is_alive() which calls Popen.poll on the @@ -313,12 +401,18 @@ def is_alive(self, kernel_id): The id of the kernel. """ - def _check_kernel_id(self, kernel_id): + def _check_kernel_id( + self, + kernel_id: str + ) -> None: """check that a kernel id is valid""" if kernel_id not in self: raise KeyError("Kernel with id not found: %s" % kernel_id) - def get_kernel(self, kernel_id): + def get_kernel( + self, + kernel_id: str + ) -> KernelManager: """Get the single KernelManager object for a kernel by its uuid. Parameters @@ -330,15 +424,28 @@ def get_kernel(self, kernel_id): return self._kernels[kernel_id] @kernel_method - def add_restart_callback(self, kernel_id, callback, event='restart'): + def add_restart_callback( + self, + kernel_id: str, + callback: t.Callable, + event: str = 'restart' + ) -> None: """add a callback for the KernelRestarter""" @kernel_method - def remove_restart_callback(self, kernel_id, callback, event='restart'): + def remove_restart_callback( + self, + kernel_id: str, + callback: t.Callable, + event: str = 'restart' + ) -> None: """remove a callback for the KernelRestarter""" @kernel_method - def get_connection_info(self, kernel_id): + def get_connection_info( + self, + kernel_id: str + ) -> t.Dict[str, t.Any]: """Return a dictionary of connection data for a kernel. Parameters @@ -356,7 +463,11 @@ def get_connection_info(self, kernel_id): """ @kernel_method - def connect_iopub(self, kernel_id, identity=None): + def connect_iopub( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the iopub channel. Parameters @@ -372,7 +483,11 @@ def connect_iopub(self, kernel_id, identity=None): """ @kernel_method - def connect_shell(self, kernel_id, identity=None): + def connect_shell( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the shell channel. Parameters @@ -388,7 +503,11 @@ def connect_shell(self, kernel_id, identity=None): """ @kernel_method - def connect_control(self, kernel_id, identity=None): + def connect_control( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the control channel. Parameters @@ -404,7 +523,11 @@ def connect_control(self, kernel_id, identity=None): """ @kernel_method - def connect_stdin(self, kernel_id, identity=None): + def connect_stdin( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the stdin channel. Parameters @@ -420,7 +543,11 @@ def connect_stdin(self, kernel_id, identity=None): """ @kernel_method - def connect_hb(self, kernel_id, identity=None): + def connect_hb( + self, + kernel_id: str, + identity: t.Optional[bytes] = None + ) -> socket.socket: """Return a zmq Socket connected to the hb channel. Parameters @@ -435,7 +562,7 @@ def connect_hb(self, kernel_id, identity=None): stream : zmq Socket or ZMQStream """ - def new_kernel_id(self, **kwargs): + def new_kernel_id(self, **kwargs) -> str: """ Returns the id to associate with the kernel for this request. Subclasses may override this method to substitute other sources of kernel ids. @@ -454,121 +581,6 @@ class AsyncMultiKernelManager(MultiKernelManager): """ ) - _starting_kernels = Dict() - - async def _add_kernel_when_ready(self, kernel_id, km, kernel_awaitable): - await kernel_awaitable - self._kernels[kernel_id] = km - - async def start_kernel(self, kernel_name=None, **kwargs): - """Start a new kernel. - - The caller can pick a kernel_id by passing one in as a keyword arg, - otherwise one will be generated using new_kernel_id(). - - The kernel ID for the newly started kernel is returned. - """ - km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs) - if not isinstance(km, AsyncKernelManager): - self.log.warning("Kernel manager class ({km_class}) is not an instance of 'AsyncKernelManager'!". - format(km_class=self.kernel_manager_class.__class__)) - fut = asyncio.ensure_future( - self._add_kernel_when_ready( - kernel_id, - km, - km.start_kernel(**kwargs) - ) - ) - self._starting_kernels[kernel_id] = fut - await fut - del self._starting_kernels[kernel_id] - return kernel_id - - async def shutdown_kernel(self, kernel_id, now=False, restart=False): - """Shutdown a kernel by its kernel uuid. - - Parameters - ========== - kernel_id : uuid - The id of the kernel to shutdown. - now : bool - Should the kernel be shutdown forcibly using a signal. - restart : bool - Will the kernel be restarted? - """ - self.log.info("Kernel shutdown: %s" % kernel_id) - - km = self.get_kernel(kernel_id) - - ports = ( - km.shell_port, km.iopub_port, km.stdin_port, - km.hb_port, km.control_port - ) - - await km.shutdown_kernel(now, restart) - self.remove_kernel(kernel_id) - - if km.cache_ports and not restart: - for port in ports: - self.currently_used_ports.remove(port) - - async def finish_shutdown(self, kernel_id, waittime=None, pollinterval=0.1): - """Wait for a kernel to finish shutting down, and kill it if it doesn't - """ - km = self.get_kernel(kernel_id) - await km.finish_shutdown(waittime, pollinterval) - self.log.info("Kernel shutdown: %s" % kernel_id) - - async def interrupt_kernel(self, kernel_id): - """Interrupt (SIGINT) the kernel by its uuid. - - Parameters - ========== - kernel_id : uuid - The id of the kernel to interrupt. - """ - km = self.get_kernel(kernel_id) - await km.interrupt_kernel() - self.log.info("Kernel interrupted: %s" % kernel_id) - - async def signal_kernel(self, kernel_id, signum): - """Sends a signal to the kernel by its uuid. - - Note that since only SIGTERM is supported on Windows, this function - is only useful on Unix systems. - - Parameters - ========== - kernel_id : uuid - The id of the kernel to signal. - """ - km = self.get_kernel(kernel_id) - await km.signal_kernel(signum) - self.log.info("Signaled Kernel %s with %s" % (kernel_id, signum)) - - async def restart_kernel(self, kernel_id, now=False): - """Restart a kernel by its uuid, keeping the same ports. - - Parameters - ========== - kernel_id : uuid - The id of the kernel to interrupt. - """ - km = self.get_kernel(kernel_id) - await km.restart_kernel(now) - self.log.info("Kernel restarted: %s" % kernel_id) - - async def _shutdown_starting_kernel(self, kid, now): - if kid in self._starting_kernels: - await self._starting_kernels[kid] - await self.shutdown_kernel(kid, now=now) - - async def shutdown_all(self, now=False): - """Shutdown all kernels.""" - kids = self.list_kernel_ids() - futs = [self.shutdown_kernel(kid, now=now) for kid in kids] - futs += [ - self._shutdown_starting_kernel(kid, now=now) - for kid in self._starting_kernels.keys() - ] - await asyncio.gather(*futs) + start_kernel = MultiKernelManager._async_start_kernel + shutdown_kernel = MultiKernelManager._async_shutdown_kernel + shutdown_all = MultiKernelManager._async_shutdown_all diff --git a/jupyter_client/tests/test_kernelapp.py b/jupyter_client/tests/test_kernelapp.py index 38ef50082..af28f814a 100644 --- a/jupyter_client/tests/test_kernelapp.py +++ b/jupyter_client/tests/test_kernelapp.py @@ -35,14 +35,21 @@ def test_kernelapp_lifecycle(): .format(WAIT_TIME)) # Connection file should be there by now - files = os.listdir(runtime_dir) + for _ in range(WAIT_TIME * POLL_FREQ): + files = os.listdir(runtime_dir) + if files: + break + time.sleep(1 / POLL_FREQ) + else: + raise AssertionError("No connection file created in {} seconds" + .format(WAIT_TIME)) assert len(files) == 1 cf = files[0] assert cf.startswith('kernel') assert cf.endswith('.json') # Send SIGTERM to shut down - time.sleep(0.2) + time.sleep(1) p.terminate() _, stderr = p.communicate(timeout=WAIT_TIME) assert cf in stderr.decode('utf-8', 'replace') diff --git a/jupyter_client/tests/test_kernelmanager.py b/jupyter_client/tests/test_kernelmanager.py index ce52d5f70..30b4d82da 100644 --- a/jupyter_client/tests/test_kernelmanager.py +++ b/jupyter_client/tests/test_kernelmanager.py @@ -10,10 +10,10 @@ import signal import sys import time -import threading -import multiprocessing as mp +import concurrent.futures import pytest +import nest_asyncio from async_generator import async_generator, yield_ from traitlets.config.loader import Config from jupyter_core import paths @@ -350,41 +350,36 @@ def test_start_parallel_thread_kernels(self, config, install_kernel): pytest.skip("IPC transport is currently not working for this test!") self._run_signaltest_lifecycle(config) - thread = threading.Thread(target=self._run_signaltest_lifecycle, args=(config,)) - thread2 = threading.Thread(target=self._run_signaltest_lifecycle, args=(config,)) - try: - thread.start() - thread2.start() - finally: - thread.join() - thread2.join() + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as thread_executor: + future1 = thread_executor.submit(self._run_signaltest_lifecycle, config) + future2 = thread_executor.submit(self._run_signaltest_lifecycle, config) + future1.result() + future2.result() @pytest.mark.timeout(TIMEOUT) + @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error') + @pytest.mark.skipif((sys.platform == 'win32') and (sys.version_info >= (3, 8)) and (sys.version_info < (3, 10)), reason='"Timeout" error') def test_start_parallel_process_kernels(self, config, install_kernel): if config.KernelManager.transport == 'ipc': # FIXME pytest.skip("IPC transport is currently not working for this test!") self._run_signaltest_lifecycle(config) - thread = threading.Thread(target=self._run_signaltest_lifecycle, args=(config,)) - proc = mp.Process(target=self._run_signaltest_lifecycle, args=(config,)) - try: - thread.start() - proc.start() - finally: - thread.join() - proc.join() - - assert proc.exitcode == 0 + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor: + future1 = thread_executor.submit(self._run_signaltest_lifecycle, config) + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as process_executor: + future2 = process_executor.submit(self._run_signaltest_lifecycle, config) + future2.result() + future1.result() @pytest.mark.timeout(TIMEOUT) + @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 10)), reason='"Bad file descriptor" error') + @pytest.mark.skipif((sys.platform == 'win32') and (sys.version_info >= (3, 8)) and (sys.version_info < (3, 10)), reason='"Timeout" error') def test_start_sequence_process_kernels(self, config, install_kernel): + if config.KernelManager.transport == 'ipc': # FIXME + pytest.skip("IPC transport is currently not working for this test!") self._run_signaltest_lifecycle(config) - proc = mp.Process(target=self._run_signaltest_lifecycle, args=(config,)) - try: - proc.start() - finally: - proc.join() - - assert proc.exitcode == 0 + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool_executor: + future = pool_executor.submit(self._run_signaltest_lifecycle, config) + future.result() def _prepare_kernel(self, km, startup_timeout=TIMEOUT, **kwargs): km.start_kernel(**kwargs) diff --git a/jupyter_client/tests/test_multikernelmanager.py b/jupyter_client/tests/test_multikernelmanager.py index e104a266f..47fa46e20 100644 --- a/jupyter_client/tests/test_multikernelmanager.py +++ b/jupyter_client/tests/test_multikernelmanager.py @@ -1,10 +1,11 @@ """Tests for the notebook kernel and session manager.""" import asyncio -import threading +import concurrent.futures import uuid -import multiprocessing as mp +import sys +import pytest from subprocess import PIPE from unittest import TestCase from tornado.testing import AsyncTestCase, gen_test @@ -134,30 +135,23 @@ def tcp_lifecycle_with_loop(self): def test_start_parallel_thread_kernels(self): self.test_tcp_lifecycle() - thread = threading.Thread(target=self.tcp_lifecycle_with_loop) - thread2 = threading.Thread(target=self.tcp_lifecycle_with_loop) - try: - thread.start() - thread2.start() - finally: - thread.join() - thread2.join() + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as thread_executor: + future1 = thread_executor.submit(self.tcp_lifecycle_with_loop) + future2 = thread_executor.submit(self.tcp_lifecycle_with_loop) + future1.result() + future2.result() + @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error') def test_start_parallel_process_kernels(self): self.test_tcp_lifecycle() - thread = threading.Thread(target=self.tcp_lifecycle_with_loop) - # Windows tests needs this target to be picklable: - proc = mp.Process(target=self.test_tcp_lifecycle) - - try: - thread.start() - proc.start() - finally: - thread.join() - proc.join() - - assert proc.exitcode == 0 + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor: + future1 = thread_executor.submit(self.tcp_lifecycle_with_loop) + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as process_executor: + # Windows tests needs this target to be picklable: + future2 = process_executor.submit(self.test_tcp_lifecycle) + future2.result() + future1.result() def test_subclass_callables(self): km = self._get_tcp_km_sub() @@ -389,31 +383,23 @@ def raw_tcp_lifecycle_sync(cls, test_kid=None): async def test_start_parallel_thread_kernels(self): await self.raw_tcp_lifecycle() - thread = threading.Thread(target=self.tcp_lifecycle_with_loop) - thread2 = threading.Thread(target=self.tcp_lifecycle_with_loop) - try: - thread.start() - thread2.start() - finally: - thread.join() - thread2.join() + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as thread_executor: + future1 = thread_executor.submit(self.tcp_lifecycle_with_loop) + future2 = thread_executor.submit(self.tcp_lifecycle_with_loop) + future1.result() + future2.result() @gen_test async def test_start_parallel_process_kernels(self): await self.raw_tcp_lifecycle() - thread = threading.Thread(target=self.tcp_lifecycle_with_loop) - # Windows tests needs this target to be picklable: - proc = mp.Process(target=self.raw_tcp_lifecycle_sync) - - try: - thread.start() - proc.start() - finally: - proc.join() - thread.join() - - assert proc.exitcode == 0 + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor: + future1 = thread_executor.submit(self.tcp_lifecycle_with_loop) + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as process_executor: + # Windows tests needs this target to be picklable: + future2 = process_executor.submit(self.raw_tcp_lifecycle_sync) + future2.result() + future1.result() @gen_test async def test_subclass_callables(self): diff --git a/jupyter_client/util.py b/jupyter_client/util.py index bfbb8e968..6640ed111 100644 --- a/jupyter_client/util.py +++ b/jupyter_client/util.py @@ -1,12 +1,20 @@ +import os +import sys import asyncio import inspect import nest_asyncio -nest_asyncio.apply() -loop = asyncio.get_event_loop() +if os.name == 'nt' and sys.version_info >= (3, 7): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) def run_sync(coro): def wrapped(*args, **kwargs): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + nest_asyncio.apply(loop) return loop.run_until_complete(coro(*args, **kwargs)) wrapped.__doc__ = coro.__doc__ return wrapped From 7951b238bd38e691d4771065a77e6792ea64bb85 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Wed, 17 Mar 2021 16:52:33 +0100 Subject: [PATCH 6/9] Add more types --- .github/workflows/main.yml | 2 +- jupyter_client/adapter.py | 163 ++++++++++++---- jupyter_client/asynchronous/client.py | 3 +- jupyter_client/blocking/client.py | 3 +- jupyter_client/channels.py | 108 +++++++---- jupyter_client/client.py | 162 +++++++++++----- jupyter_client/connect.py | 118 ++++++++---- jupyter_client/consoleapp.py | 72 +++---- jupyter_client/jsonutil.py | 5 +- jupyter_client/kernelapp.py | 17 +- jupyter_client/manager.py | 2 +- jupyter_client/multikernelmanager.py | 2 +- jupyter_client/session.py | 181 ++++++++++++------ jupyter_client/tests/test_kernelmanager.py | 8 +- .../tests/test_multikernelmanager.py | 14 +- 15 files changed, 586 insertions(+), 274 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fc05d02dd..0faa14717 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,7 +32,7 @@ jobs: pip install --upgrade --upgrade-strategy eager -e .[test] pytest-cov codecov 'coverage<5' pip freeze - name: Check types - run: mypy jupyter_client/manager.py jupyter_client/multikernelmanager.py jupyter_client/client.py jupyter_client/blocking/client.py jupyter_client/asynchronous/client.py + run: mypy jupyter_client/manager.py jupyter_client/multikernelmanager.py jupyter_client/client.py jupyter_client/blocking/client.py jupyter_client/asynchronous/client.py jupyter_client/channels.py jupyter_client/session.py jupyter_client/adapter.py jupyter_client/connect.py jupyter_client/consoleapp.py jupyter_client/jsonutil.py jupyter_client/kernelapp.py - name: Run the tests run: py.test --cov jupyter_client -v jupyter_client - name: Code coverage diff --git a/jupyter_client/adapter.py b/jupyter_client/adapter.py index 94109dceb..e4e09a54c 100644 --- a/jupyter_client/adapter.py +++ b/jupyter_client/adapter.py @@ -5,10 +5,14 @@ import re import json +from typing import List, Tuple, Dict, Any from jupyter_client import protocol_version_info -def code_to_line(code, cursor_pos): +def code_to_line( + code: str, + cursor_pos: int +) -> Tuple[str, int]: """Turn a multiline code block and cursor position into a single line and new cursor position. @@ -29,14 +33,17 @@ def code_to_line(code, cursor_pos): _end_bracket = re.compile(r'\([^\(]*$', re.UNICODE) _identifier = re.compile(r'[a-z_][0-9a-z._]*', re.I|re.UNICODE) -def extract_oname_v4(code, cursor_pos): +def extract_oname_v4( + code: str, + cursor_pos: int +) -> str: """Reimplement token-finding logic from IPython 2.x javascript - + for adapting object_info_request from v5 to v4 """ - + line, _ = code_to_line(code, cursor_pos) - + oldline = line line = _match_bracket.sub('', line) while oldline != line: @@ -58,29 +65,44 @@ class Adapter(object): Override message_type(msg) methods to create adapters. """ - msg_type_map = {} + msg_type_map: Dict[str, str] = {} - def update_header(self, msg): + def update_header( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: return msg - def update_metadata(self, msg): + def update_metadata( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: return msg - def update_msg_type(self, msg): + def update_msg_type( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: header = msg['header'] msg_type = header['msg_type'] if msg_type in self.msg_type_map: msg['msg_type'] = header['msg_type'] = self.msg_type_map[msg_type] return msg - def handle_reply_status_error(self, msg): + def handle_reply_status_error( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: """This will be called *instead of* the regular handler on any reply with status != ok """ return msg - def __call__(self, msg): + def __call__( + self, + msg: Dict[str, Any] + ): msg = self.update_header(msg) msg = self.update_metadata(msg) msg = self.update_msg_type(msg) @@ -95,7 +117,9 @@ def __call__(self, msg): return self.handle_reply_status_error(msg) return handler(msg) -def _version_str_to_list(version): +def _version_str_to_list( + version: str +) -> List[int]: """convert a version string to a list of ints non-int segments are excluded @@ -121,14 +145,20 @@ class V5toV4(Adapter): 'inspect_reply' : 'object_info_reply', } - def update_header(self, msg): + def update_header( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: msg['header'].pop('version', None) msg['parent_header'].pop('version', None) return msg # shell channel - def kernel_info_reply(self, msg): + def kernel_info_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: v4c = {} content = msg['content'] for key in ('language_version', 'protocol_version'): @@ -145,18 +175,27 @@ def kernel_info_reply(self, msg): msg['content'] = v4c return msg - def execute_request(self, msg): + def execute_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content.setdefault('user_variables', []) return msg - def execute_reply(self, msg): + def execute_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content.setdefault('user_variables', {}) # TODO: handle payloads return msg - def complete_request(self, msg): + def complete_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] code = content['code'] cursor_pos = content['cursor_pos'] @@ -169,7 +208,10 @@ def complete_request(self, msg): new_content['cursor_pos'] = cursor_pos return msg - def complete_reply(self, msg): + def complete_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] cursor_start = content.pop('cursor_start') cursor_end = content.pop('cursor_end') @@ -178,7 +220,10 @@ def complete_reply(self, msg): content.pop('metadata', None) return msg - def object_info_request(self, msg): + def object_info_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] code = content['code'] cursor_pos = content['cursor_pos'] @@ -189,19 +234,28 @@ def object_info_request(self, msg): new_content['detail_level'] = content['detail_level'] return msg - def object_info_reply(self, msg): + def object_info_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: """inspect_reply can't be easily backward compatible""" msg['content'] = {'found' : False, 'oname' : 'unknown'} return msg # iopub channel - def stream(self, msg): + def stream( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content['data'] = content.pop('text') return msg - def display_data(self, msg): + def display_data( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content.setdefault("source", "display") data = content['data'] @@ -215,7 +269,10 @@ def display_data(self, msg): # stdin channel - def input_request(self, msg): + def input_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: msg['content'].pop('password', None) return msg @@ -227,7 +284,10 @@ class V4toV5(Adapter): # invert message renames above msg_type_map = {v:k for k,v in V5toV4.msg_type_map.items()} - def update_header(self, msg): + def update_header( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: msg['header']['version'] = self.version if msg['parent_header']: msg['parent_header']['version'] = self.version @@ -235,7 +295,10 @@ def update_header(self, msg): # shell channel - def kernel_info_reply(self, msg): + def kernel_info_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] for key in ('protocol_version', 'ipython_version'): if key in content: @@ -257,7 +320,10 @@ def kernel_info_reply(self, msg): content['banner'] = '' return msg - def execute_request(self, msg): + def execute_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] user_variables = content.pop('user_variables', []) user_expressions = content.setdefault('user_expressions', {}) @@ -265,7 +331,10 @@ def execute_request(self, msg): user_expressions[v] = v return msg - def execute_reply(self, msg): + def execute_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] user_expressions = content.setdefault('user_expressions', {}) user_variables = content.pop('user_variables', {}) @@ -281,7 +350,10 @@ def execute_reply(self, msg): return msg - def complete_request(self, msg): + def complete_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: old_content = msg['content'] new_content = msg['content'] = {} @@ -289,7 +361,10 @@ def complete_request(self, msg): new_content['cursor_pos'] = old_content['cursor_pos'] return msg - def complete_reply(self, msg): + def complete_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: # complete_reply needs more context than we have to get cursor_start and end. # use special end=null to indicate current cursor position and negative offset # for start relative to the cursor. @@ -306,7 +381,10 @@ def complete_reply(self, msg): new_content['metadata'] = {} return msg - def inspect_request(self, msg): + def inspect_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] name = content['oname'] @@ -316,7 +394,10 @@ def inspect_request(self, msg): new_content['detail_level'] = content['detail_level'] return msg - def inspect_reply(self, msg): + def inspect_reply( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: """inspect_reply can't be easily backward compatible""" content = msg['content'] new_content = msg['content'] = {'status' : 'ok'} @@ -340,12 +421,18 @@ def inspect_reply(self, msg): # iopub channel - def stream(self, msg): + def stream( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content['text'] = content.pop('data') return msg - def display_data(self, msg): + def display_data( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: content = msg['content'] content.pop("source", None) data = content['data'] @@ -359,13 +446,19 @@ def display_data(self, msg): # stdin channel - def input_request(self, msg): + def input_request( + self, + msg: Dict[str, Any] + ) -> Dict[str, Any]: msg['content'].setdefault('password', False) return msg -def adapt(msg, to_version=protocol_version_info[0]): +def adapt( + msg: Dict[str, Any], + to_version: int =protocol_version_info[0] + ) -> Dict[str, Any]: """Adapt a single message to a target version Parameters diff --git a/jupyter_client/asynchronous/client.py b/jupyter_client/asynchronous/client.py index 75b8baaf4..4d4038985 100644 --- a/jupyter_client/asynchronous/client.py +++ b/jupyter_client/asynchronous/client.py @@ -43,10 +43,9 @@ class AsyncKernelClient(KernelClient): inspect = reqrep(KernelClient._async_inspect) kernel_info = reqrep(KernelClient._async_kernel_info) comm_info = reqrep(KernelClient._async_comm_info) + is_alive = KernelClient._async_is_alive execute_interactive = KernelClient._async_execute_interactive - stop_channels = KernelClient._async_stop_channels - channels_running = property(KernelClient._async_channels_running) # replies come on the control channel shutdown = reqrep(KernelClient._async_shutdown, channel='control') diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py index 85ee6f84d..bc1b8651a 100644 --- a/jupyter_client/blocking/client.py +++ b/jupyter_client/blocking/client.py @@ -47,10 +47,9 @@ class BlockingKernelClient(KernelClient): inspect = run_sync(reqrep(KernelClient._async_inspect)) kernel_info = run_sync(reqrep(KernelClient._async_kernel_info)) comm_info = run_sync(reqrep(KernelClient._async_comm_info)) + is_alive = run_sync(KernelClient._async_is_alive) execute_interactive = run_sync(KernelClient._async_execute_interactive) - stop_channels = run_sync(KernelClient._async_stop_channels) - channels_running = property(run_sync(KernelClient._async_channels_running)) # replies come on the control channel shutdown = run_sync(reqrep(KernelClient._async_shutdown, channel='control')) diff --git a/jupyter_client/channels.py b/jupyter_client/channels.py index c5ff3791a..a94e959fa 100644 --- a/jupyter_client/channels.py +++ b/jupyter_client/channels.py @@ -9,8 +9,10 @@ import time import asyncio from queue import Empty +import typing as t import zmq +import zmq.asyncio # import ZMQError in top-level namespace, to avoid ugly attribute-error messages # during garbage collection of threads at exit: from zmq import ZMQError @@ -18,6 +20,7 @@ from jupyter_client import protocol_version_info from .channelsabc import HBChannelABC +from .session import Session #----------------------------------------------------------------------------- # Constants and exceptions @@ -35,24 +38,27 @@ class HBChannel(Thread): this channel, the kernel manager will ensure that it is paused and un-paused as appropriate. """ - context = None session = None socket = None address = None _exiting = False - time_to_dead = 1. - poller = None + time_to_dead: float = 1. _running = None _pause = None _beating = None - def __init__(self, context=None, session=None, address=None): + def __init__( + self, + context: zmq.asyncio.Context, + session: t.Optional[Session] = None, + address: t.Union[t.Tuple[str, int], str] = '' + ): """Create the heartbeat monitor thread. Parameters ---------- - context : :class:`zmq.Context` + context : :class:`zmq.asyncio.Context` The ZMQ context to use. session : :class:`session.Session` The session to use. @@ -68,8 +74,10 @@ def __init__(self, context=None, session=None, address=None): if address[1] == 0: message = 'The port number for a channel cannot be 0.' raise InvalidPortNumber(message) - address = "tcp://%s:%i" % address - self.address = address + address_str = "tcp://%s:%i" % address + else: + address_str = address + self.address = address_str # running is False until `.start()` is called self._running = False @@ -80,30 +88,27 @@ def __init__(self, context=None, session=None, address=None): @staticmethod @atexit.register - def _notice_exit(): + def _notice_exit() -> None: # Class definitions can be torn down during interpreter shutdown. # We only need to set _exiting flag if this hasn't happened. if HBChannel is not None: HBChannel._exiting = True - def _create_socket(self): + def _create_socket(self) -> None: if self.socket is not None: # close previous socket, before opening a new one self.poller.unregister(self.socket) self.socket.close() - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self.socket = self.context.socket(zmq.REQ) self.socket.linger = 1000 self.socket.connect(self.address) self.poller.register(self.socket, zmq.POLLIN) - def _poll(self, start_time): + def _poll( + self, + start_time: float + ) -> t.List[t.Any]: """poll for heartbeat replies until we reach self.time_to_dead. Ignores interrupts, and returns the result of poll(), which @@ -117,7 +122,7 @@ def _poll(self, start_time): events = [] while True: try: - events = self.poller.poll(1000 * until_dead) + events = self.poller.poll(int(1000 * until_dead)) except ZMQError as e: if e.errno == errno.EINTR: # ignore interrupts during heartbeat @@ -136,11 +141,17 @@ def _poll(self, start_time): break return events - def run(self): + def run(self) -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._async_run()) + + async def _async_run(self) -> None: """The thread's main activity. Call start() instead.""" self._create_socket() self._running = True self._beating = True + assert self.socket is not None while self._running: if self._pause: @@ -151,13 +162,13 @@ def run(self): since_last_heartbeat = 0.0 # no need to catch EFSM here, because the previous event was # either a recv or connect, which cannot be followed by EFSM - self.socket.send(b'ping') + await self.socket.send(b'ping') request_time = time.time() ready = self._poll(request_time) if ready: self._beating = True # the poll above guarantees we have something to recv - self.socket.recv() + await self.socket.recv() # sleep the remainder of the cycle remainder = self.time_to_dead - (time.time() - request_time) if remainder > 0: @@ -172,29 +183,29 @@ def run(self): self._create_socket() continue - def pause(self): + def pause(self) -> None: """Pause the heartbeat.""" self._pause = True - def unpause(self): + def unpause(self) -> None: """Unpause the heartbeat.""" self._pause = False - def is_beating(self): + def is_beating(self) -> bool: """Is the heartbeat running and responsive (and not paused).""" if self.is_alive() and not self._pause and self._beating: return True else: return False - def stop(self): + def stop(self) -> None: """Stop the channel's event loop and join its thread.""" self._running = False self._exit.set() self.join() self.close() - def close(self): + def close(self) -> None: if self.socket is not None: try: self.socket.close(linger=0) @@ -202,7 +213,10 @@ def close(self): pass self.socket = None - def call_handlers(self, since_last_heartbeat): + def call_handlers( + self, + since_last_heartbeat: float + ) -> None: """This method is called in the ioloop thread when a message arrives. Subclasses should override this method to handle incoming messages. @@ -218,13 +232,13 @@ def call_handlers(self, since_last_heartbeat): class ZMQSocketChannel(object): """A ZMQ socket in an async API""" - session = None - socket = None - stream = None - _exiting = False - proxy_methods = [] - def __init__(self, socket, session, loop=None): + def __init__( + self, + socket: zmq.sugar.socket.Socket, + session: Session, + loop: t.Any = None + ) -> None: """Create a channel. Parameters @@ -238,18 +252,23 @@ def __init__(self, socket, session, loop=None): """ super().__init__() - self.socket = socket + self.socket: t.Optional[zmq.sugar.socket.Socket] = socket self.session = session - async def _recv(self, **kwargs): + async def _recv(self, **kwargs) -> t.Dict[str, t.Any]: + assert self.socket is not None msg = await self.socket.recv_multipart(**kwargs) ident, smsg = self.session.feed_identities(msg) return self.session.deserialize(smsg) - async def get_msg(self, timeout=None): + async def get_msg( + self, + timeout: t.Optional[float] = None + ) -> t.Dict[str, t.Any]: """ Gets a message if there is one that is ready. """ if timeout is not None: timeout *= 1000 # seconds to ms + assert self.socket is not None ready = await self.socket.poll(timeout) if ready: @@ -258,7 +277,7 @@ async def get_msg(self, timeout=None): else: raise Empty - async def get_msgs(self): + async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]: """ Get all messages that are currently ready. """ msgs = [] while True: @@ -268,26 +287,31 @@ async def get_msgs(self): break return msgs - async def msg_ready(self): + async def msg_ready(self) -> bool: """ Is there a message that has been received? """ + assert self.socket is not None return bool(await self.socket.poll(timeout=0)) - def close(self): + def close(self) -> None: if self.socket is not None: try: self.socket.close(linger=0) except Exception: pass self.socket = None - stop = close + stop = close - def is_alive(self): + def is_alive(self) -> bool: return (self.socket is not None) - def send(self, msg): + def send( + self, + msg: t.Dict[str, t.Any] + ) -> None: """Pass a message to the ZMQ socket to send """ + assert self.socket is not None self.session.send(self.socket, msg) - def start(self): + def start(self) -> None: pass diff --git a/jupyter_client/client.py b/jupyter_client/client.py index 56ccfc5f8..524ff8cb4 100644 --- a/jupyter_client/client.py +++ b/jupyter_client/client.py @@ -9,6 +9,7 @@ from functools import partial from getpass import getpass from queue import Empty +import socket import typing as t @@ -24,6 +25,7 @@ from .channelsabc import (ChannelABC, HBChannelABC) from .clientabc import KernelClientABC from .connect import ConnectionFileMixin +from .session import Session from .util import ensure_async @@ -43,8 +45,11 @@ def validate_string_dict( raise ValueError('value %r in dict must be a string' % v) -def reqrep(meth, channel='shell'): - async def wrapped(self, *args, **kwargs): +def reqrep( + meth: t.Callable, + channel: str = 'shell' +) -> t.Callable: + async def wrapped(self, *args, **kwargs) -> t.Union[str, t.Dict[str, t.Any]]: reply = kwargs.pop('reply', False) timeout = kwargs.pop('timeout', None) msg_id = await meth(self, *args, **kwargs) @@ -101,9 +106,7 @@ class KernelClient(ConnectionFileMixin): # The PyZMQ Context to use for communication with the kernel. context = Instance(zmq.asyncio.Context) - - @default('context') - def _context_default(self): + def _context_default(self) -> zmq.asyncio.Context: return zmq.asyncio.Context() # The classes to use for the various channels @@ -127,23 +130,26 @@ def _context_default(self): # Channel proxy methods #-------------------------------------------------------------------------- - async def _async_get_shell_msg(self, *args, **kwargs): + async def _async_get_shell_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]: """Get a message from the shell channel""" return await self.shell_channel.get_msg(*args, **kwargs) - async def _async_get_iopub_msg(self, *args, **kwargs): + async def _async_get_iopub_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]: """Get a message from the iopub channel""" return await self.iopub_channel.get_msg(*args, **kwargs) - async def _async_get_stdin_msg(self, *args, **kwargs): + async def _async_get_stdin_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]: """Get a message from the stdin channel""" return await self.stdin_channel.get_msg(*args, **kwargs) - async def _async_get_control_msg(self, *args, **kwargs): + async def _async_get_control_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]: """Get a message from the control channel""" return await self.control_channel.get_msg(*args, **kwargs) - async def _async_wait_for_ready(self, timeout=None): + async def _async_wait_for_ready( + self, + timeout: t.Optional[float] = None + ) -> None: """Waits for a response when a client is blocked - Sets future time for timeout @@ -153,9 +159,8 @@ async def _async_wait_for_ready(self, timeout=None): - Flush the IOPub channel """ if timeout is None: - abs_timeout = float('inf') - else: - abs_timeout = time.time() + timeout + timeout = float('inf') + abs_timeout = time.time() + timeout from .manager import KernelManager if not isinstance(self.parent, KernelManager): @@ -199,7 +204,12 @@ async def _async_wait_for_ready(self, timeout=None): except Empty: break - async def _async_recv_reply(self, msg_id, timeout=None, channel='shell'): + async def _async_recv_reply( + self, + msg_id: str, + timeout: t.Optional[float] = None, + channel: str = 'shell' + ) -> t.Dict[str, t.Any]: """Receive and return the reply for a given request""" if timeout is not None: deadline = time.monotonic() + timeout @@ -219,13 +229,16 @@ async def _async_recv_reply(self, msg_id, timeout=None, channel='shell'): return reply - def _stdin_hook_default(self, msg): + def _stdin_hook_default( + self, + msg: t.Dict[str, t.Any] + ) -> None: """Handle an input request""" content = msg['content'] if content.get('password', False): prompt = getpass else: - prompt = input + prompt = input # type: ignore try: raw_data = prompt(content["prompt"]) @@ -241,7 +254,10 @@ def _stdin_hook_default(self, msg): if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()): self.input(raw_data) - def _output_hook_default(self, msg): + def _output_hook_default( + self, + msg: t.Dict[str, t.Any] + ) -> None: """Default hook for redisplaying plain-text output""" msg_type = msg['header']['msg_type'] content = msg['content'] @@ -253,7 +269,13 @@ def _output_hook_default(self, msg): elif msg_type == 'error': print('\n'.join(content['traceback']), file=sys.stderr) - def _output_hook_kernel(self, session, socket, parent_header, msg): + def _output_hook_kernel( + self, + session: Session, + socket: zmq.sugar.socket.Socket, + parent_header, + msg: t.Dict[str, t.Any] + ) -> None: """Output hook when running inside an IPython kernel adds rich output support. @@ -269,7 +291,14 @@ def _output_hook_kernel(self, session, socket, parent_header, msg): # Channel management methods #-------------------------------------------------------------------------- - def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True): + def start_channels( + self, + shell: bool = True, + iopub: bool = True, + stdin: bool = True, + hb: bool = True, + control: bool = True + ) -> None: """Starts the channels for this kernel. This will create the channels if they do not exist and then start @@ -292,7 +321,7 @@ def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=Tr if control: self.control_channel.start() - async def _async_stop_channels(self): + def stop_channels(self) -> None: """Stops all the running channels for this kernel. This stops their event loops and joins their threads. @@ -308,8 +337,8 @@ async def _async_stop_channels(self): if self.control_channel.is_alive(): self.control_channel.stop() - #@property - async def _async_channels_running(self): + @property + def channels_running(self) -> bool: """Are any of the channels created and running?""" return (self.shell_channel.is_alive() or self.iopub_channel.is_alive() or self.stdin_channel.is_alive() or self.hb_channel.is_alive() or @@ -318,7 +347,7 @@ async def _async_channels_running(self): ioloop = None # Overridden in subclasses that use pyzmq event loop @property - def shell_channel(self): + def shell_channel(self) -> t.Any: """Get the shell channel object for this kernel.""" if self._shell_channel is None: url = self._make_url('shell') @@ -330,7 +359,7 @@ def shell_channel(self): return self._shell_channel @property - def iopub_channel(self): + def iopub_channel(self) -> t.Any: """Get the iopub channel object for this kernel.""" if self._iopub_channel is None: url = self._make_url('iopub') @@ -342,7 +371,7 @@ def iopub_channel(self): return self._iopub_channel @property - def stdin_channel(self): + def stdin_channel(self) -> t.Any: """Get the stdin channel object for this kernel.""" if self._stdin_channel is None: url = self._make_url('stdin') @@ -354,7 +383,7 @@ def stdin_channel(self): return self._stdin_channel @property - def hb_channel(self): + def hb_channel(self) -> t.Any: """Get the hb channel object for this kernel.""" if self._hb_channel is None: url = self._make_url('hb') @@ -365,7 +394,7 @@ def hb_channel(self): return self._hb_channel @property - def control_channel(self): + def control_channel(self) -> t.Any: """Get the control channel object for this kernel.""" if self._control_channel is None: url = self._make_url('control') @@ -376,7 +405,7 @@ def control_channel(self): ) return self._control_channel - async def _async_is_alive(self): + async def _async_is_alive(self) -> bool: """Is the kernel process still running?""" from .manager import KernelManager, AsyncKernelManager if isinstance(self.parent, KernelManager): @@ -393,10 +422,18 @@ async def _async_is_alive(self): return True - async def _async_execute_interactive(self, code, silent=False, store_history=True, - user_expressions=None, allow_stdin=None, stop_on_error=True, - timeout=None, output_hook=None, stdin_hook=None, - ): + async def _async_execute_interactive( + self, + code: str, + silent: bool = False, + store_history: bool = True, + user_expressions: t.Optional[t.Dict[str, t.Any]] = None, + allow_stdin: t.Optional[bool] = None, + stop_on_error: bool = True, + timeout: t.Optional[float] = None, + output_hook: t.Optional[t.Callable] = None, + stdin_hook: t.Optional[t.Callable] =None, + ) -> t.Dict[str, t.Any]: """Execute code in the kernel interactively Output will be redisplayed, and stdin prompts will be relayed as well. @@ -502,7 +539,7 @@ async def _async_execute_interactive(self, code, silent=False, store_history=Tru while True: if timeout is not None: timeout = max(0, deadline - time.monotonic()) - timeout_ms = 1e3 * timeout + timeout_ms = int(1000 * timeout) events = dict(poller.poll(timeout_ms)) if not events: raise TimeoutError("Timeout waiting for output") @@ -532,8 +569,15 @@ async def _async_execute_interactive(self, code, silent=False, store_history=Tru # Methods to send specific messages on channels - async def _async_execute(self, code, silent=False, store_history=True, - user_expressions=None, allow_stdin=None, stop_on_error=True): + async def _async_execute( + self, + code: str, + silent: bool = False, + store_history: bool = True, + user_expressions: t.Optional[t.Dict[str, t.Any]] = None, + allow_stdin: t.Optional[bool] = None, + stop_on_error: bool = True + ) -> str: """Execute code in the kernel. Parameters @@ -589,7 +633,11 @@ async def _async_execute(self, code, silent=False, store_history=True, self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_complete(self, code, cursor_pos=None): + async def _async_complete( + self, + code: str, + cursor_pos: t.Optional[int] = None + ) -> str: """Tab complete text in the kernel's namespace. Parameters @@ -612,7 +660,12 @@ async def _async_complete(self, code, cursor_pos=None): self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_inspect(self, code, cursor_pos=None, detail_level=0): + async def _async_inspect( + self, + code: str, + cursor_pos: t.Optional[int] = None, + detail_level: int = 0 + ) -> str: """Get metadata information about an object in the kernel's namespace. It is up to the kernel to determine the appropriate object to inspect. @@ -641,7 +694,13 @@ async def _async_inspect(self, code, cursor_pos=None, detail_level=0): self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_history(self, raw=True, output=False, hist_access_type='range', **kwargs): + async def _async_history( + self, + raw: bool = True, + output: bool = False, + hist_access_type: str = 'range', + **kwargs + ) -> str: """Get entries from the kernel's history list. Parameters @@ -682,7 +741,7 @@ async def _async_history(self, raw=True, output=False, hist_access_type='range', self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_kernel_info(self): + async def _async_kernel_info(self) -> str: """Request kernel info Returns @@ -693,7 +752,10 @@ async def _async_kernel_info(self): self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_comm_info(self, target_name=None): + async def _async_comm_info( + self, + target_name: t.Optional[str] = None + ) -> str: """Request comm info Returns @@ -708,7 +770,10 @@ async def _async_comm_info(self, target_name=None): self.shell_channel.send(msg) return msg['header']['msg_id'] - def _handle_kernel_info_reply(self, msg): + def _handle_kernel_info_reply( + self, + msg: t.Dict[str, t.Any] + ) -> None: """handle kernel info reply sets protocol adaptation version. This might @@ -718,13 +783,19 @@ def _handle_kernel_info_reply(self, msg): if adapt_version != major_protocol_version: self.session.adapt_version = adapt_version - def is_complete(self, code): + def is_complete( + self, + code: str + ) -> str: """Ask the kernel whether some code is complete and ready to execute.""" msg = self.session.msg('is_complete_request', {'code': code}) self.shell_channel.send(msg) return msg['header']['msg_id'] - def input(self, string): + def input( + self, + string: str + ) -> None: """Send a string of raw input to the kernel. This should only be called in response to the kernel sending an @@ -734,7 +805,10 @@ def input(self, string): msg = self.session.msg('input_reply', content) self.stdin_channel.send(msg) - async def _async_shutdown(self, restart=False): + async def _async_shutdown( + self, + restart: bool = False + ) -> str: """Request an immediate kernel shutdown on the control channel. Upon receipt of the (empty) reply, client code can safely assume that diff --git a/jupyter_client/connect.py b/jupyter_client/connect.py index e31f9ec1d..2f8e70352 100644 --- a/jupyter_client/connect.py +++ b/jupyter_client/connect.py @@ -17,23 +17,33 @@ import warnings from getpass import getpass from contextlib import contextmanager +from typing import Union, Optional, List, Tuple, Dict, Any, cast import zmq -from traitlets.config import LoggingConfigurable +from traitlets.config import LoggingConfigurable # type: ignore from .localinterfaces import localhost -from traitlets import ( +from traitlets import ( # type: ignore Bool, Integer, Unicode, CaselessStrEnum, Instance, Type, observe ) -from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write +from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write # type: ignore from .utils import _filefind -def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0, - control_port=0, ip='', key=b'', transport='tcp', - signature_scheme='hmac-sha256', kernel_name='' - ): +def write_connection_file( + fname: Optional[str] = None, + shell_port: int = 0, + iopub_port: int = 0, + stdin_port: int = 0, + hb_port: int = 0, + control_port: int = 0, + ip: str = '', + key: bytes = b'', + transport: str = 'tcp', + signature_scheme: str = 'hmac-sha256', + kernel_name: str = '' +) -> Tuple[str, Dict[str, Union[int, str]]]: """Generates a JSON config file, including the selection of random ports. Parameters @@ -83,7 +93,8 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, # Find open ports as necessary. - ports = [] + ports: List[int] = [] + sockets: List[socket.socket] = [] ports_needed = int(shell_port <= 0) + \ int(iopub_port <= 0) + \ int(stdin_port <= 0) + \ @@ -95,11 +106,11 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, # struct.pack('ii', (0,0)) is 8 null bytes sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8) sock.bind((ip, 0)) - ports.append(sock) - for i, sock in enumerate(ports): + sockets.append(sock) + for sock in sockets: port = sock.getsockname()[1] sock.close() - ports[i] = port + ports.append(port) else: N = 1 for i in range(ports_needed): @@ -118,7 +129,7 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, if hb_port <= 0: hb_port = ports.pop(0) - cfg = dict( shell_port=shell_port, + cfg: Dict[str, Union[int, str]] = dict( shell_port=shell_port, iopub_port=iopub_port, stdin_port=stdin_port, control_port=control_port, @@ -165,7 +176,11 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, return fname, cfg -def find_connection_file(filename='kernel-*.json', path=None, profile=None): +def find_connection_file( + filename: str ='kernel-*.json', + path: Optional[Union[str, List[str]]] = None, + profile: Optional[str] = None +) -> str: """find a connection file, and return its absolute path. The current working directory and optional search path @@ -222,7 +237,11 @@ def find_connection_file(filename='kernel-*.json', path=None, profile=None): return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1] -def tunnel_to_kernel(connection_info, sshserver, sshkey=None): +def tunnel_to_kernel( + connection_info: Union[str, Dict[str, Any]], + sshserver: str, + sshkey: Optional[str] = None +) -> Tuple[Any, ...]: """tunnel connections to a kernel via ssh This will open five SSH tunnels from localhost on this machine to the @@ -254,7 +273,7 @@ def tunnel_to_kernel(connection_info, sshserver, sshkey=None): with open(connection_info) as f: connection_info = json.loads(f.read()) - cf = connection_info + cf = cast(Dict[str, Any], connection_info) lports = tunnel.select_random_ports(5) rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port'], cf['control_port'] @@ -262,11 +281,11 @@ def tunnel_to_kernel(connection_info, sshserver, sshkey=None): remote_ip = cf['ip'] if tunnel.try_passwordless_ssh(sshserver, sshkey): - password=False + password: Union[bool, str] = False else: password = getpass("SSH Password for %s: " % sshserver) - for lp,rp in zip(lports, rports): + for lp, rp in zip(lports, rports): tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password) return tuple(lports) @@ -341,10 +360,10 @@ def _ip_changed(self, change): help="set the control (ROUTER) port [default: random]") # names of the ports with random assignment - _random_port_names = None + _random_port_names: Optional[List[str]] = None @property - def ports(self): + def ports(self) -> List[int]: return [ getattr(self, name) for name in port_names ] # The Session to use for communication with the kernel. @@ -357,7 +376,10 @@ def _session_default(self): # Connection and ipc file management #-------------------------------------------------------------------------- - def get_connection_info(self, session=False): + def get_connection_info( + self, + session: bool =False + ) -> Dict[str, Any]: """Return the connection info as a dict Parameters @@ -403,7 +425,7 @@ def blocking_client(self): bc.session.key = self.session.key return bc - def cleanup_connection_file(self): + def cleanup_connection_file(self) -> None: """Cleanup connection file *if we wrote it* Will not raise if the connection file was already removed somehow. @@ -416,7 +438,7 @@ def cleanup_connection_file(self): except (IOError, OSError, AttributeError): pass - def cleanup_ipc_files(self): + def cleanup_ipc_files(self) -> None: """Cleanup ipc files if we wrote them.""" if self.transport != 'ipc': return @@ -427,7 +449,7 @@ def cleanup_ipc_files(self): except (IOError, OSError): pass - def _record_random_port_names(self): + def _record_random_port_names(self) -> None: """Records which of the ports are randomly assigned. Records on first invocation, if the transport is tcp. @@ -443,7 +465,7 @@ def _record_random_port_names(self): if getattr(self, name) <= 0: self._random_port_names.append(name) - def cleanup_random_ports(self): + def cleanup_random_ports(self) -> None: """Forgets randomly assigned port numbers and cleans up the connection file. Does nothing if no port numbers have been randomly assigned. @@ -458,7 +480,7 @@ def cleanup_random_ports(self): self.cleanup_connection_file() - def write_connection_file(self): + def write_connection_file(self) -> None: """Write connection info to JSON dict in self.connection_file.""" if self._connection_file_written and os.path.exists(self.connection_file): return @@ -478,7 +500,10 @@ def write_connection_file(self): self._connection_file_written = True - def load_connection_file(self, connection_file=None): + def load_connection_file( + self, + connection_file: Optional[str] = None + ) -> None: """Load connection info from JSON dict in self.connection_file. Parameters @@ -494,7 +519,10 @@ def load_connection_file(self, connection_file=None): info = json.load(f) self.load_connection_info(info) - def load_connection_info(self, info): + def load_connection_info( + self, + info: Dict[str, int] + ) -> None: """Load connection info from a dict containing connection info. Typically this data comes from a connection file @@ -529,7 +557,10 @@ def load_connection_info(self, info): # Creating connected sockets #-------------------------------------------------------------------------- - def _make_url(self, channel): + def _make_url( + self, + channel: str + ) -> str: """Make a ZeroMQ URL for a given channel.""" transport = self.transport ip = self.ip @@ -540,7 +571,11 @@ def _make_url(self, channel): else: return "%s://%s-%s" % (transport, ip, port) - def _create_connected_socket(self, channel, identity=None): + def _create_connected_socket( + self, + channel: str, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """Create a zmq Socket and connect it to the kernel.""" url = self._make_url(channel) socket_type = channel_socket_types[channel] @@ -553,25 +588,40 @@ def _create_connected_socket(self, channel, identity=None): sock.connect(url) return sock - def connect_iopub(self, identity=None): + def connect_iopub( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the IOPub channel""" sock = self._create_connected_socket('iopub', identity=identity) sock.setsockopt(zmq.SUBSCRIBE, b'') return sock - def connect_shell(self, identity=None): + def connect_shell( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Shell channel""" return self._create_connected_socket('shell', identity=identity) - def connect_stdin(self, identity=None): + def connect_stdin( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the StdIn channel""" return self._create_connected_socket('stdin', identity=identity) - def connect_hb(self, identity=None): + def connect_hb( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Heartbeat channel""" return self._create_connected_socket('hb', identity=identity) - def connect_control(self, identity=None): + def connect_control( + self, + identity: Optional[bytes] = None + ) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Control channel""" return self._create_connected_socket('control', identity=identity) diff --git a/jupyter_client/consoleapp.py b/jupyter_client/consoleapp.py index 42ce2fb77..e491dcc24 100644 --- a/jupyter_client/consoleapp.py +++ b/jupyter_client/consoleapp.py @@ -13,14 +13,15 @@ import sys import uuid import warnings +from typing import cast -from traitlets.config.application import boolean_flag -from traitlets import ( +from traitlets.config.application import boolean_flag # type: ignore +from traitlets import ( # type: ignore Dict, List, Unicode, CUnicode, CBool, Any, Type ) -from jupyter_core.application import base_flags, base_aliases +from jupyter_core.application import base_flags, base_aliases # type: ignore from .blocking import BlockingKernelClient from .restarter import KernelRestarter @@ -93,19 +94,19 @@ class JupyterConsoleApp(ConnectionFileMixin): description = """ The Jupyter Console Mixin. - + This class contains the common portions of console client (QtConsole, ZMQ-based terminal console, etc). It is not a full console, in that launched terminal subprocesses will not be able to accept input. - + The Console using this mixing supports various extra features beyond the single-process Terminal IPython shell, such as connecting to existing kernel, via: - + jupyter console --existing - + as well as tunnel via SSH - + """ classes = classes @@ -121,13 +122,13 @@ class JupyterConsoleApp(ConnectionFileMixin): kernel_argv = List(Unicode()) # connection info: - + sshserver = Unicode('', config=True, help="""The SSH server to use to connect to the kernel.""") sshkey = Unicode('', config=True, help="""Path to the ssh key to use for logging in to the ssh server.""") - - def _connection_file_default(self): + + def _connection_file_default(self) -> str: return 'kernel-%i.json' % os.getpid() existing = CUnicode('', config=True, @@ -141,26 +142,26 @@ def _connection_file_default(self): Set to display confirmation dialog on exit. You can always use 'exit' or 'quit', to force a direct exit without any confirmation.""", ) - - def build_kernel_argv(self, argv=None): + + def build_kernel_argv(self, argv=None) -> None: """build argv to be passed to kernel subprocess - + Override in subclasses if any args should be passed to the kernel """ self.kernel_argv = self.extra_args - - def init_connection_file(self): + + def init_connection_file(self) -> None: """find the connection file, and load the info if found. - + The current working directory and the current profile's security directory will be searched for the file if it is not given by absolute path. - + When attempting to connect to an existing kernel and the `--existing` argument does not match an existing file, it will be interpreted as a fileglob, and the matching file in the current profile's security dir with the latest access time will be used. - + After this method is called, self.connection_file contains the *full path* to the connection file, never just its name. """ @@ -192,7 +193,7 @@ def init_connection_file(self): except IOError: self.log.debug("Connection File not found: %s", self.connection_file) return - + # should load_connection_file only be used for existing? # as it is now, this allows reusing ports if an existing # file is requested @@ -201,25 +202,25 @@ def init_connection_file(self): except Exception: self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True) self.exit(1) - - def init_ssh(self): + + def init_ssh(self) -> None: """set up ssh tunnels, if needed.""" if not self.existing or (not self.sshserver and not self.sshkey): return self.load_connection_file() - + transport = self.transport ip = self.ip - + if transport != 'tcp': self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport) sys.exit(-1) - + if self.sshkey and not self.sshserver: # specifying just the key implies that we are connecting directly self.sshserver = ip ip = localhost() - + # build connection dict for tunnels: info = dict(ip=ip, shell_port=self.shell_port, @@ -228,9 +229,9 @@ def init_ssh(self): hb_port=self.hb_port, control_port=self.control_port ) - + self.log.info("Forwarding connections to %s via %s"%(ip, self.sshserver)) - + # tunnels return a new set of ports, which will be on localhost: self.ip = localhost() try: @@ -239,17 +240,17 @@ def init_ssh(self): # even catch KeyboardInterrupt self.log.error("Could not setup tunnels", exc_info=True) self.exit(1) - + self.shell_port, self.iopub_port, self.stdin_port, self.hb_port, self.control_port = newports - + cf = self.connection_file root, ext = os.path.splitext(cf) self.connection_file = root + '-ssh' + ext self.write_connection_file() # write the new connection file self.log.info("To connect another client via this tunnel, use:") self.log.info("--existing %s" % os.path.basename(self.connection_file)) - - def _new_connection_file(self): + + def _new_connection_file(self) -> str: cf = '' while not cf: # we don't need a 128b id to distinguish kernels, use more readable @@ -262,7 +263,7 @@ def _new_connection_file(self): cf = cf if not os.path.exists(cf) else '' return cf - def init_kernel_manager(self): + def init_kernel_manager(self) -> None: # Don't let Qt or ZMQ swallow KeyboardInterupts. if self.existing: self.kernel_manager = None @@ -289,6 +290,7 @@ def init_kernel_manager(self): self.log.critical("Could not find kernel %s", self.kernel_name) self.exit(1) + self.kernel_manager = cast(KernelManager, self.kernel_manager) self.kernel_manager.client_factory = self.kernel_client_class kwargs = {} kwargs['extra_arguments'] = self.kernel_argv @@ -310,7 +312,7 @@ def init_kernel_manager(self): atexit.register(self.kernel_manager.cleanup_connection_file) - def init_kernel_client(self): + def init_kernel_client(self) -> None: if self.kernel_manager is not None: self.kernel_client = self.kernel_manager.client() else: @@ -331,7 +333,7 @@ def init_kernel_client(self): - def initialize(self, argv=None): + def initialize(self, argv=None) -> None: """ Classes which mix this class in should call: JupyterConsoleApp.initialize(self,argv) diff --git a/jupyter_client/jsonutil.py b/jupyter_client/jsonutil.py index d3a472fee..667e33f1f 100644 --- a/jupyter_client/jsonutil.py +++ b/jupyter_client/jsonutil.py @@ -6,6 +6,7 @@ from datetime import datetime import re import warnings +from typing import Optional, Union from dateutil.parser import parse as _dateutil_parse from dateutil.tz import tzlocal @@ -28,7 +29,7 @@ # Classes and functions #----------------------------------------------------------------------------- -def _ensure_tzinfo(dt): +def _ensure_tzinfo(dt: datetime) -> datetime: """Ensure a datetime object has tzinfo If no tzinfo is present, add tzlocal @@ -41,7 +42,7 @@ def _ensure_tzinfo(dt): dt = dt.replace(tzinfo=tzlocal()) return dt -def parse_date(s): +def parse_date(s: Optional[str]) -> Optional[Union[str, datetime]]: """parse an ISO8601 date string If it is None or not a valid ISO8601 timestamp, diff --git a/jupyter_client/kernelapp.py b/jupyter_client/kernelapp.py index 33607049c..b95afb0b0 100644 --- a/jupyter_client/kernelapp.py +++ b/jupyter_client/kernelapp.py @@ -2,9 +2,9 @@ import signal import uuid -from jupyter_core.application import JupyterApp, base_flags +from jupyter_core.application import JupyterApp, base_flags # type: ignore from tornado.ioloop import IOLoop -from traitlets import Unicode +from traitlets import Unicode # type: ignore from . import __version__ from .kernelspec import KernelSpecManager, NATIVE_KERNEL_NAME @@ -39,7 +39,7 @@ def initialize(self, argv=None): self.loop = IOLoop.current() self.loop.add_callback(self._record_started) - def setup_signals(self): + def setup_signals(self) -> None: """Shutdown on SIGTERM or SIGINT (Ctrl-C)""" if os.name == 'nt': return @@ -49,17 +49,20 @@ def shutdown_handler(signo, frame): for sig in [signal.SIGTERM, signal.SIGINT]: signal.signal(sig, shutdown_handler) - def shutdown(self, signo): + def shutdown( + self, + signo: int + ) -> None: self.log.info('Shutting down on signal %d' % signo) self.km.shutdown_kernel() self.loop.stop() - def log_connection_info(self): + def log_connection_info(self) -> None: cf = self.km.connection_file self.log.info('Connection file: %s', cf) self.log.info("To connect a client: --existing %s", os.path.basename(cf)) - def _record_started(self): + def _record_started(self) -> None: """For tests, create a file to indicate that we've started Do not rely on this except in our own tests! @@ -69,7 +72,7 @@ def _record_started(self): with open(fn, 'wb'): pass - def start(self): + def start(self) -> None: self.log.info('Starting kernel %r', self.kernel_name) try: self.km.start_kernel() diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index 1577ccde1..da60fb1b1 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -771,7 +771,7 @@ async def start_new_async_kernel( try: await kc.wait_for_ready(timeout=startup_timeout) except RuntimeError: - await kc.stop_channels() + kc.stop_channels() await km.shutdown_kernel() raise diff --git a/jupyter_client/multikernelmanager.py b/jupyter_client/multikernelmanager.py index 401c3a84b..ef85028f4 100644 --- a/jupyter_client/multikernelmanager.py +++ b/jupyter_client/multikernelmanager.py @@ -221,7 +221,7 @@ async def _async_start_kernel( self._add_kernel_when_ready( kernel_id, km, - km._async_start_kernel(**kwargs) + ensure_async(km.start_kernel(**kwargs)) ) ) self._starting_kernels[kernel_id] = fut diff --git a/jupyter_client/session.py b/jupyter_client/session.py index 88fc4746f..fde87d165 100644 --- a/jupyter_client/session.py +++ b/jupyter_client/session.py @@ -21,6 +21,7 @@ import pprint import random import warnings +import typing as t from datetime import datetime from datetime import timezone @@ -42,13 +43,13 @@ from jupyter_client import protocol_version from jupyter_client.adapter import adapt -from traitlets import ( +from traitlets import ( # type: ignore CBytes, Unicode, Bool, Any, Instance, Set, DottedObjectName, CUnicode, Dict, Integer, TraitError, observe ) -from traitlets.log import get_logger -from traitlets.utils.importstring import import_item -from traitlets.config.configurable import Configurable, LoggingConfigurable +from traitlets.log import get_logger # type: ignore +from traitlets.utils.importstring import import_item # type: ignore +from traitlets.config.configurable import Configurable, LoggingConfigurable # type: ignore #----------------------------------------------------------------------------- # utility functions @@ -98,7 +99,7 @@ def squash_unicode(obj): # Mixin tools for apps that use Sessions #----------------------------------------------------------------------------- -def new_id(): +def new_id() -> str: """Generate a new random id. Avoids problematic runtime import in stdlib uuid on Python 2. @@ -113,7 +114,7 @@ def new_id(): buf[:4], buf[4:] )) -def new_id_bytes(): +def new_id_bytes() -> bytes: """Return new_id as ascii bytes""" return new_id().encode('ascii') @@ -123,7 +124,7 @@ def new_id_bytes(): keyfile = 'Session.keyfile', ) -session_flags = { +session_flags = { 'secure' : ({'Session' : { 'key' : new_id_bytes(), 'keyfile' : '' }}, """Use HMAC digests for authentication of messages. @@ -133,7 +134,7 @@ def new_id_bytes(): """Don't authenticate messages."""), } -def default_secure(cfg): +def default_secure(cfg) -> None: """Set the default behavior for a config environment to be secure. If Session.key/keyfile have not been set, set Session.key to @@ -146,7 +147,7 @@ def default_secure(cfg): # key/keyfile not specified, generate new UUID: cfg.Session.key = new_id_bytes() -def utcnow(): +def utcnow() -> datetime: """Return timezone-aware UTC timestamp""" return datetime.utcnow().replace(tzinfo=utc) @@ -162,12 +163,12 @@ class SessionFactory(LoggingConfigurable): logname = Unicode('') @observe('logname') - def _logname_changed(self, change): + def _logname_changed(self, change) -> None: self.log = logging.getLogger(change['new']) # not configurable: context = Instance('zmq.Context') - def _context_default(self): + def _context_default(self) -> zmq.Context: return zmq.Context() session = Instance('jupyter_client.session.Session', @@ -191,7 +192,10 @@ class Message(object): A Message can be created from a dict and a dict from a Message instance simply by calling dict(msg_obj).""" - def __init__(self, msg_dict): + def __init__( + self, + msg_dict: t.Dict[str, t.Any] + ) -> None: dct = self.__dict__ for k, v in dict(msg_dict).items(): if isinstance(v, dict): @@ -199,29 +203,36 @@ def __init__(self, msg_dict): dct[k] = v # Having this iterator lets dict(msg_obj) work out of the box. - def __iter__(self): + def __iter__(self) -> t.ItemsView[str, t.Any]: return self.__dict__.items() - def __repr__(self): + def __repr__(self) -> str: return repr(self.__dict__) - def __str__(self): + def __str__(self) -> str: return pprint.pformat(self.__dict__) - def __contains__(self, k): + def __contains__(self, k) -> bool: return k in self.__dict__ - def __getitem__(self, k): + def __getitem__(self, k) -> t.Any: return self.__dict__[k] -def msg_header(msg_id, msg_type, username, session): +def msg_header( + msg_id: str, + msg_type: str, + username: str, + session: 'Session' +) -> t.Dict[str, t.Any]: """Create a new message header""" date = utcnow() version = protocol_version return locals() -def extract_header(msg_or_header): +def extract_header( + msg_or_header: t.Dict[str, t.Any] +) -> t.Dict[str, t.Any]: """Given a message or header, return the header.""" if not msg_or_header: return {} @@ -328,7 +339,7 @@ def _unpacker_changed(self, change): session = CUnicode('', config=True, help="""The UUID identifying this session.""") - def _session_default(self): + def _session_default(self) -> str: u = new_id() self.bsession = u.encode('ascii') return u @@ -355,7 +366,7 @@ def _session_changed(self, change): key = CBytes(config=True, help="""execution key, for signing messages.""") - def _key_default(self): + def _key_default(self) -> bytes: return new_id_bytes() @observe('key') @@ -380,12 +391,12 @@ def _signature_scheme_changed(self, change): self._new_auth() digest_mod = Any() - def _digest_mod_default(self): + def _digest_mod_default(self) -> t.Callable: return hashlib.sha256 auth = Instance(hmac.HMAC, allow_none=True) - def _new_auth(self): + def _new_auth(self) -> None: if self.key: self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod) else: @@ -491,7 +502,7 @@ def __init__(self, **kwargs): if not self.key: get_logger().warning("Message signing is disabled. This is insecure and not recommended!") - def clone(self): + def clone(self) -> 'Session': """Create a copy of this Session Useful when connecting multiple times to a given kernel. @@ -511,28 +522,28 @@ def clone(self): message_count = 0 @property - def msg_id(self): + def msg_id(self) -> str: message_number = self.message_count self.message_count += 1 return '{}_{}'.format(self.session, message_number) - def _check_packers(self): + def _check_packers(self) -> None: """check packers for datetime support.""" pack = self.pack unpack = self.unpack # check simple serialization - msg = dict(a=[1,'hi']) + msg_list = dict(a=[1,'hi']) try: - packed = pack(msg) + packed = pack(msg_list) except Exception as e: - msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}" + error_msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}" if self.packer == 'json': jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod else: jsonmsg = "" raise ValueError( - msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg) + error_msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg) ) from e # ensure packed message is bytes @@ -542,31 +553,41 @@ def _check_packers(self): # check that unpack is pack's inverse try: unpacked = unpack(packed) - assert unpacked == msg + assert unpacked == msg_list except Exception as e: - msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}" + error_msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}" if self.packer == 'json': jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod else: jsonmsg = "" raise ValueError( - msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg) + error_msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg) ) from e # check datetime support - msg = dict(t=utcnow()) + msg_datetime = dict(t=utcnow()) try: - unpacked = unpack(pack(msg)) + unpacked = unpack(pack(msg_datetime)) if isinstance(unpacked['t'], datetime): raise ValueError("Shouldn't deserialize to datetime") except Exception: self.pack = lambda o: pack(squash_dates(o)) self.unpack = lambda s: unpack(s) - def msg_header(self, msg_type): + def msg_header( + self, + msg_type: str + ) -> t.Dict[str, t.Any]: return msg_header(self.msg_id, msg_type, self.username, self.session) - def msg(self, msg_type, content=None, parent=None, header=None, metadata=None): + def msg( + self, + msg_type: str, + content: t.Optional[t.Dict] = None, + parent: t.Optional[t.Dict[str, t.Any]] = None, + header: t.Optional[t.Dict[str, t.Any]] = None, + metadata: t.Optional[t.Dict[str, t.Any]] = None + ) -> t.Dict[str, t.Any]: """Return the nested message dict. This format is different from what is sent over the wire. The @@ -585,7 +606,10 @@ def msg(self, msg_type, content=None, parent=None, header=None, metadata=None): msg['metadata'].update(metadata) return msg - def sign(self, msg_list): + def sign( + self, + msg_list: t.List + ) -> bytes: """Sign a message with HMAC digest. If no auth, return b''. Parameters @@ -600,7 +624,11 @@ def sign(self, msg_list): h.update(m) return h.hexdigest().encode() - def serialize(self, msg, ident=None): + def serialize( + self, + msg: t.Dict[str, t.Any], + ident: t.Optional[t.Union[t.List[bytes], bytes]] = None + ) -> t.List[bytes]: """Serialize the message components to bytes. This is roughly the inverse of deserialize. The serialize/deserialize @@ -659,8 +687,18 @@ def serialize(self, msg, ident=None): return to_send - def send(self, stream, msg_or_type, content=None, parent=None, ident=None, - buffers=None, track=False, header=None, metadata=None): + def send( + self, + stream: zmq.sugar.socket.Socket, + msg_or_type: t.Union[t.Dict[str, t.Any], str], + content: t.Optional[t.Dict[str, t.Any]] = None, + parent: t.Optional[t.Dict[str, t.Any]] = None, + ident: t.Optional[t.Union[bytes, t.List[bytes]]] = None, + buffers: t.Optional[t.List[bytes]] = None, + track: bool = False, + header: t.Optional[t.Dict[str, t.Any]] = None, + metadata: t.Optional[t.Dict[str, t.Any]] = None + ) -> t.Optional[t.Dict[str, t.Any]]: """Build and send a message via stream or socket. The message format used by this function internally is as follows: @@ -720,7 +758,7 @@ def send(self, stream, msg_or_type, content=None, parent=None, ident=None, get_logger().warning("WARNING: attempted to send message from fork\n%s", msg ) - return + return None buffers = [] if buffers is None else buffers for idx, buf in enumerate(buffers): if isinstance(buf, memoryview): @@ -761,7 +799,14 @@ def send(self, stream, msg_or_type, content=None, parent=None, ident=None, return msg - def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None): + def send_raw( + self, + stream: zmq.sugar.socket.Socket, + msg_list: t.List, + flags: int = 0, + copy: bool = True, + ident: t.Optional[t.Union[bytes, t.List[bytes]]] = None, + ) -> None: """Send a raw message via ident path. This method is used to send a already serialized message. @@ -789,7 +834,13 @@ def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None): to_send.extend(msg_list) stream.send_multipart(to_send, flags, copy=copy) - def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): + def recv( + self, + socket: zmq.sugar.socket.Socket, + mode: int =zmq.NOBLOCK, + content: bool =True, + copy: bool = True + ) -> t.Tuple[t.Optional[t.List[bytes]], t.Optional[t.Dict[str, t.Any]]]: """Receive and unpack a message. Parameters @@ -811,7 +862,7 @@ def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): if e.errno == zmq.EAGAIN: # We can convert EAGAIN to None as we know in this case # recv_multipart won't return None. - return None,None + return None, None else: raise # split multipart message into identity list and message dict @@ -823,7 +874,11 @@ def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): # TODO: handle it raise e - def feed_identities(self, msg_list, copy=True): + def feed_identities( + self, + msg_list: t.Union[t.List[bytes], t.List[zmq.Message]], + copy: bool =True + ) -> t.Tuple[t.List[bytes], t.Union[t.List[bytes], t.List[zmq.Message]]]: """Split the identities from the rest of the message. Feed until DELIM is reached, then return the prefix as idents and @@ -847,20 +902,25 @@ def feed_identities(self, msg_list, copy=True): point. """ if copy: + msg_list = t.cast(t.List[bytes], msg_list) idx = msg_list.index(DELIM) return msg_list[:idx], msg_list[idx+1:] else: + msg_list = t.cast(t.List[zmq.Message], msg_list) failed = True - for idx,m in enumerate(msg_list): + for idx, m in enumerate(msg_list): if m.bytes == DELIM: failed = False break if failed: raise ValueError("DELIM not in msg_list") idents, msg_list = msg_list[:idx], msg_list[idx+1:] - return [m.bytes for m in idents], msg_list + return [bytes(m.bytes) for m in idents], msg_list - def _add_digest(self, signature): + def _add_digest( + self, + signature: bytes + ) -> None: """add a digest to history to protect against replay attacks""" if self.digest_history_size == 0: # no history, never add digests @@ -871,7 +931,7 @@ def _add_digest(self, signature): # threshold reached, cull 10% self._cull_digest_history() - def _cull_digest_history(self): + def _cull_digest_history(self) -> None: """cull the digest history Removes a randomly selected 10% of the digest history @@ -884,7 +944,12 @@ def _cull_digest_history(self): to_cull = random.sample(tuple(sorted(self.digest_history)), n_to_cull) self.digest_history.difference_update(to_cull) - def deserialize(self, msg_list, content=True, copy=True): + def deserialize( + self, + msg_list: t.Union[t.List[bytes], t.List[zmq.Message]], + content: bool =True, + copy: bool =True + ) -> t.Dict[str, t.Any]: """Unserialize a msg_list to a nested message dict. This is roughly the inverse of serialize. The serialize/deserialize @@ -913,10 +978,13 @@ def deserialize(self, msg_list, content=True, copy=True): message = {} if not copy: # pyzmq didn't copy the first parts of the message, so we'll do it - for i in range(minlen): - msg_list[i] = msg_list[i].bytes + msg_list = t.cast(t.List[zmq.Message], msg_list) + msg_list_beginning = [bytes(msg.bytes) for msg in msg_list[:minlen]] + msg_list = t.cast(t.List[bytes], msg_list) + msg_list = msg_list_beginning + msg_list[minlen:] + msg_list = t.cast(t.List[bytes], msg_list) if self.auth is not None: - signature = msg_list[0] + signature = t.cast(bytes, msg_list[0]) if not signature: raise ValueError("Unsigned Message") if signature in self.digest_history: @@ -942,14 +1010,15 @@ def deserialize(self, msg_list, content=True, copy=True): buffers = [memoryview(b) for b in msg_list[5:]] if buffers and buffers[0].shape is None: # force copy to workaround pyzmq #646 - buffers = [memoryview(b.bytes) for b in msg_list[5:]] + msg_list = t.cast(t.List[zmq.Message], msg_list) + buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]] message['buffers'] = buffers if self.debug: pprint.pprint(message) # adapt to the current version return adapt(message) - def unserialize(self, *args, **kwargs): + def unserialize(self, *args, **kwargs) -> t.Dict[str, t.Any]: warnings.warn( "Session.unserialize is deprecated. Use Session.deserialize.", DeprecationWarning, diff --git a/jupyter_client/tests/test_kernelmanager.py b/jupyter_client/tests/test_kernelmanager.py index 30b4d82da..dd5539ed8 100644 --- a/jupyter_client/tests/test_kernelmanager.py +++ b/jupyter_client/tests/test_kernelmanager.py @@ -135,7 +135,7 @@ def async_km_subclass(config): async def start_async_kernel(): km, kc = await start_new_async_kernel(kernel_name='signaltest') await yield_((km, kc)) - await kc.stop_channels() + kc.stop_channels() await km.shutdown_kernel() assert km.context.closed @@ -184,7 +184,7 @@ async def test_async_signal_kernel_subprocesses(self, name, install, expected): assert km._shutdown_status == _ShutdownStatus.Unset assert await km.is_alive() # kc.execute("1") - await kc.stop_channels() + kc.stop_channels() await km.shutdown_kernel() assert km._shutdown_status == expected @@ -358,7 +358,6 @@ def test_start_parallel_thread_kernels(self, config, install_kernel): @pytest.mark.timeout(TIMEOUT) @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error') - @pytest.mark.skipif((sys.platform == 'win32') and (sys.version_info >= (3, 8)) and (sys.version_info < (3, 10)), reason='"Timeout" error') def test_start_parallel_process_kernels(self, config, install_kernel): if config.KernelManager.transport == 'ipc': # FIXME pytest.skip("IPC transport is currently not working for this test!") @@ -371,8 +370,7 @@ def test_start_parallel_process_kernels(self, config, install_kernel): future1.result() @pytest.mark.timeout(TIMEOUT) - @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 10)), reason='"Bad file descriptor" error') - @pytest.mark.skipif((sys.platform == 'win32') and (sys.version_info >= (3, 8)) and (sys.version_info < (3, 10)), reason='"Timeout" error') + @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error') def test_start_sequence_process_kernels(self, config, install_kernel): if config.KernelManager.transport == 'ipc': # FIXME pytest.skip("IPC transport is currently not working for this test!") diff --git a/jupyter_client/tests/test_multikernelmanager.py b/jupyter_client/tests/test_multikernelmanager.py index 47fa46e20..6dfa1469e 100644 --- a/jupyter_client/tests/test_multikernelmanager.py +++ b/jupyter_client/tests/test_multikernelmanager.py @@ -10,7 +10,7 @@ from unittest import TestCase from tornado.testing import AsyncTestCase, gen_test from traitlets.config.loader import Config -from jupyter_client import KernelManager +from jupyter_client import KernelManager, AsyncKernelManager from jupyter_client.multikernelmanager import MultiKernelManager, AsyncMultiKernelManager from .utils import skip_win32, SyncMKMSubclass, AsyncMKMSubclass, SyncKMSubclass, AsyncKMSubclass from ..localinterfaces import localhost @@ -200,10 +200,10 @@ def test_subclass_callables(self): km.get_kernel(kid).reset_counts() km.reset_counts() km.shutdown_all(now=True) - assert km.call_count('shutdown_kernel') == 0 + #assert km.call_count('shutdown_kernel') == 0 assert km.call_count('remove_kernel') == 1 - assert km.call_count('request_shutdown') == 1 - assert km.call_count('finish_shutdown') == 1 + #assert km.call_count('request_shutdown') == 1 + #assert km.call_count('finish_shutdown') == 1 assert km.call_count('cleanup_resources') == 0 assert kid not in km, f'{kid} not in {km}' @@ -449,10 +449,10 @@ async def test_subclass_callables(self): mkm.get_kernel(kid).reset_counts() mkm.reset_counts() await mkm.shutdown_all(now=True) - assert mkm.call_count('shutdown_kernel') == 1 + #assert mkm.call_count('shutdown_kernel') == 1 assert mkm.call_count('remove_kernel') == 1 - assert mkm.call_count('request_shutdown') == 0 - assert mkm.call_count('finish_shutdown') == 0 + #assert mkm.call_count('request_shutdown') == 0 + #assert mkm.call_count('finish_shutdown') == 0 assert mkm.call_count('cleanup_resources') == 0 assert kid not in mkm, f'{kid} not in {mkm}' From 1c822ff58dfa7cdeb08e9722abf2fb6fb8249a8b Mon Sep 17 00:00:00 2001 From: David Brochart Date: Mon, 22 Mar 2021 12:00:00 +0100 Subject: [PATCH 7/9] Add types --- .github/workflows/main.yml | 2 +- jupyter_client/launcher.py | 29 ++++++++++++------- jupyter_client/manager.py | 7 +++-- jupyter_client/tests/test_kernelmanager.py | 4 +-- .../tests/test_multikernelmanager.py | 12 ++++---- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0faa14717..7335cc831 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -32,7 +32,7 @@ jobs: pip install --upgrade --upgrade-strategy eager -e .[test] pytest-cov codecov 'coverage<5' pip freeze - name: Check types - run: mypy jupyter_client/manager.py jupyter_client/multikernelmanager.py jupyter_client/client.py jupyter_client/blocking/client.py jupyter_client/asynchronous/client.py jupyter_client/channels.py jupyter_client/session.py jupyter_client/adapter.py jupyter_client/connect.py jupyter_client/consoleapp.py jupyter_client/jsonutil.py jupyter_client/kernelapp.py + run: mypy jupyter_client/manager.py jupyter_client/multikernelmanager.py jupyter_client/client.py jupyter_client/blocking/client.py jupyter_client/asynchronous/client.py jupyter_client/channels.py jupyter_client/session.py jupyter_client/adapter.py jupyter_client/connect.py jupyter_client/consoleapp.py jupyter_client/jsonutil.py jupyter_client/kernelapp.py jupyter_client/launcher.py - name: Run the tests run: py.test --cov jupyter_client -v jupyter_client - name: Code coverage diff --git a/jupyter_client/launcher.py b/jupyter_client/launcher.py index 0646a434a..930ee74b0 100644 --- a/jupyter_client/launcher.py +++ b/jupyter_client/launcher.py @@ -6,12 +6,21 @@ import os import sys from subprocess import Popen, PIPE +from typing import List, Dict, Optional -from traitlets.log import get_logger +from traitlets.log import get_logger # type: ignore -def launch_kernel(cmd, stdin=None, stdout=None, stderr=None, env=None, - independent=False, cwd=None, **kw): +def launch_kernel( + cmd: List[str], + stdin: Optional[int] = None, + stdout: Optional[int] = None, + stderr: Optional[int] = None, + env: Optional[Dict[str, str]] = None, + independent: bool = False, + cwd: Optional[str] = None, + **kw +) -> Popen: """ Launches a localhost kernel, binding to the specified ports. Parameters @@ -90,11 +99,11 @@ def launch_kernel(cmd, stdin=None, stdout=None, stderr=None, env=None, env["IPY_INTERRUPT_EVENT"] = env["JPY_INTERRUPT_EVENT"] try: - from _winapi import DuplicateHandle, GetCurrentProcess, \ - DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP + from _winapi import (DuplicateHandle, GetCurrentProcess, + DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP) except: - from _subprocess import DuplicateHandle, GetCurrentProcess, \ - DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP + from _subprocess import (DuplicateHandle, GetCurrentProcess, # type: ignore + DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP) # type: ignore # create a handle on the parent to be inherited if independent: @@ -127,8 +136,7 @@ def launch_kernel(cmd, stdin=None, stdout=None, stderr=None, env=None, try: # Allow to use ~/ in the command or its arguments - cmd = list(map(os.path.expanduser, cmd)) - + cmd = [os.path.expanduser(s) for s in cmd] proc = Popen(cmd, **kwargs) except Exception as exc: msg = ( @@ -145,11 +153,12 @@ def launch_kernel(cmd, stdin=None, stdout=None, stderr=None, env=None, if sys.platform == 'win32': # Attach the interrupt event to the Popen objet so it can be used later. - proc.win32_interrupt_event = interrupt_event + proc.win32_interrupt_event = interrupt_event # type: ignore # Clean up pipes created to work around Popen bug. if redirect_in: if stdin is None: + assert proc.stdin is not None proc.stdin.close() return proc diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index da60fb1b1..735348cb7 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -233,7 +233,6 @@ def format_kernel_cmd( ) -> t.List[str]: """replace templated args (e.g. {connection_file})""" extra_arguments = extra_arguments or [] - self.log.info(str(self.kernel_spec)) if self.kernel_cmd: cmd = self.kernel_cmd + extra_arguments else: @@ -430,7 +429,8 @@ async def _async_finish_shutdown( else: # Process is no longer alive, wait and clear if self.kernel is not None: - self.kernel.wait() + while self.kernel.poll() is None: + await asyncio.sleep(pollinterval) self.kernel = None finish_shutdown = run_sync(_async_finish_shutdown) @@ -638,7 +638,8 @@ async def _async_kill_kernel(self) -> None: else: # Process is no longer alive, wait and clear if self.kernel is not None: - self.kernel.wait() + while self.kernel.poll() is None: + await asyncio.sleep(0.1) self.kernel = None _kill_kernel = run_sync(_async_kill_kernel) diff --git a/jupyter_client/tests/test_kernelmanager.py b/jupyter_client/tests/test_kernelmanager.py index dd5539ed8..1a35e3ad7 100644 --- a/jupyter_client/tests/test_kernelmanager.py +++ b/jupyter_client/tests/test_kernelmanager.py @@ -242,8 +242,8 @@ def execute(cmd): content = reply['content'] assert content['status'] == 'ok' assert content['user_expressions']['interrupted'] - # wait up to 5s for subprocesses to handle signal - for i in range(50): + # wait up to 10s for subprocesses to handle signal + for i in range(100): reply = execute('check') if reply['user_expressions']['poll'] != [-signal.SIGINT] * N: time.sleep(0.1) diff --git a/jupyter_client/tests/test_multikernelmanager.py b/jupyter_client/tests/test_multikernelmanager.py index 6dfa1469e..fba0eff04 100644 --- a/jupyter_client/tests/test_multikernelmanager.py +++ b/jupyter_client/tests/test_multikernelmanager.py @@ -200,10 +200,10 @@ def test_subclass_callables(self): km.get_kernel(kid).reset_counts() km.reset_counts() km.shutdown_all(now=True) - #assert km.call_count('shutdown_kernel') == 0 + assert km.call_count('shutdown_kernel') == 1 assert km.call_count('remove_kernel') == 1 - #assert km.call_count('request_shutdown') == 1 - #assert km.call_count('finish_shutdown') == 1 + assert km.call_count('request_shutdown') == 0 + assert km.call_count('finish_shutdown') == 0 assert km.call_count('cleanup_resources') == 0 assert kid not in km, f'{kid} not in {km}' @@ -449,10 +449,10 @@ async def test_subclass_callables(self): mkm.get_kernel(kid).reset_counts() mkm.reset_counts() await mkm.shutdown_all(now=True) - #assert mkm.call_count('shutdown_kernel') == 1 + assert mkm.call_count('shutdown_kernel') == 1 assert mkm.call_count('remove_kernel') == 1 - #assert mkm.call_count('request_shutdown') == 0 - #assert mkm.call_count('finish_shutdown') == 0 + assert mkm.call_count('request_shutdown') == 0 + assert mkm.call_count('finish_shutdown') == 0 assert mkm.call_count('cleanup_resources') == 0 assert kid not in mkm, f'{kid} not in {mkm}' From 5774be3d2557dc0caa8a7afcc8a3ad9b60fc3e34 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Mon, 22 Mar 2021 15:23:59 +0100 Subject: [PATCH 8/9] Move util.py into utils.py --- jupyter_client/blocking/client.py | 2 +- jupyter_client/client.py | 3 +-- jupyter_client/manager.py | 2 +- jupyter_client/multikernelmanager.py | 2 +- jupyter_client/util.py | 25 ---------------------- jupyter_client/utils.py | 31 +++++++++++++++++++++++++++- 6 files changed, 34 insertions(+), 31 deletions(-) delete mode 100644 jupyter_client/util.py diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py index bc1b8651a..9233b4eeb 100644 --- a/jupyter_client/blocking/client.py +++ b/jupyter_client/blocking/client.py @@ -8,7 +8,7 @@ from traitlets import Type # type: ignore from jupyter_client.channels import HBChannel, ZMQSocketChannel from jupyter_client.client import KernelClient, reqrep -from ..util import run_sync +from ..utils import run_sync class BlockingKernelClient(KernelClient): diff --git a/jupyter_client/client.py b/jupyter_client/client.py index 524ff8cb4..a65060e49 100644 --- a/jupyter_client/client.py +++ b/jupyter_client/client.py @@ -12,7 +12,6 @@ import socket import typing as t - from jupyter_client.channels import major_protocol_version import zmq @@ -26,7 +25,7 @@ from .clientabc import KernelClientABC from .connect import ConnectionFileMixin from .session import Session -from .util import ensure_async +from .utils import ensure_async # some utilities to validate message structure, these might get moved elsewhere diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index 735348cb7..a2c2ad60c 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -33,7 +33,7 @@ from .managerabc import ( KernelManagerABC ) -from .util import run_sync, ensure_async +from .utils import run_sync, ensure_async class _ShutdownStatus(Enum): """ diff --git a/jupyter_client/multikernelmanager.py b/jupyter_client/multikernelmanager.py index ef85028f4..8d8cb5d9c 100644 --- a/jupyter_client/multikernelmanager.py +++ b/jupyter_client/multikernelmanager.py @@ -19,7 +19,7 @@ from .kernelspec import NATIVE_KERNEL_NAME, KernelSpecManager from .manager import KernelManager -from .util import run_sync, ensure_async +from .utils import run_sync, ensure_async class DuplicateKernelError(Exception): diff --git a/jupyter_client/util.py b/jupyter_client/util.py deleted file mode 100644 index 6640ed111..000000000 --- a/jupyter_client/util.py +++ /dev/null @@ -1,25 +0,0 @@ -import os -import sys -import asyncio -import inspect -import nest_asyncio - -if os.name == 'nt' and sys.version_info >= (3, 7): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - -def run_sync(coro): - def wrapped(*args, **kwargs): - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - nest_asyncio.apply(loop) - return loop.run_until_complete(coro(*args, **kwargs)) - wrapped.__doc__ = coro.__doc__ - return wrapped - -async def ensure_async(obj): - if inspect.isawaitable(obj): - return await obj - return obj diff --git a/jupyter_client/utils.py b/jupyter_client/utils.py index 2f49d2103..942932b12 100644 --- a/jupyter_client/utils.py +++ b/jupyter_client/utils.py @@ -1,8 +1,37 @@ """ -Utils vendored from ipython_genutils that should be retired at some point. +utils: +- provides utility wrapeprs to run asynchronous functions in a blocking environment. +- vendor functions from ipython_genutils that should be retired at some point. """ import os +import sys +import asyncio +import inspect +import nest_asyncio + + +if os.name == 'nt' and sys.version_info >= (3, 7): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + +def run_sync(coro): + def wrapped(*args, **kwargs): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + nest_asyncio.apply(loop) + return loop.run_until_complete(coro(*args, **kwargs)) + wrapped.__doc__ = coro.__doc__ + return wrapped + + +async def ensure_async(obj): + if inspect.isawaitable(obj): + return await obj + return obj def _filefind(filename, path_dirs=None): From 1ea1afca1de8ef9c2ec40ae58c262db6b59aa452 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Wed, 24 Mar 2021 17:44:24 +0100 Subject: [PATCH 9/9] Make requests sync or async depending on reply parameter --- jupyter_client/asynchronous/client.py | 25 ++++++++++---- jupyter_client/blocking/client.py | 25 ++++++++++---- jupyter_client/client.py | 39 +++++++++------------- jupyter_client/tests/test_kernelmanager.py | 4 +-- 4 files changed, 54 insertions(+), 39 deletions(-) diff --git a/jupyter_client/asynchronous/client.py b/jupyter_client/asynchronous/client.py index 4d4038985..86fb8737e 100644 --- a/jupyter_client/asynchronous/client.py +++ b/jupyter_client/asynchronous/client.py @@ -7,6 +7,17 @@ from jupyter_client.client import KernelClient, reqrep +def wrapped(meth, channel): + def _(self, *args, **kwargs): + reply = kwargs.pop('reply', False) + timeout = kwargs.pop('timeout', None) + msg_id = meth(self, *args, **kwargs) + if not reply: + return msg_id + return self._async_recv_reply(msg_id, timeout=timeout, channel=channel) + return _ + + class AsyncKernelClient(KernelClient): """A KernelClient with async APIs @@ -37,15 +48,15 @@ class AsyncKernelClient(KernelClient): # replies come on the shell channel - execute = reqrep(KernelClient._async_execute) - history = reqrep(KernelClient._async_history) - complete = reqrep(KernelClient._async_complete) - inspect = reqrep(KernelClient._async_inspect) - kernel_info = reqrep(KernelClient._async_kernel_info) - comm_info = reqrep(KernelClient._async_comm_info) + execute = reqrep(wrapped, KernelClient._execute) + history = reqrep(wrapped, KernelClient._history) + complete = reqrep(wrapped, KernelClient._complete) + inspect = reqrep(wrapped, KernelClient._inspect) + kernel_info = reqrep(wrapped, KernelClient._kernel_info) + comm_info = reqrep(wrapped, KernelClient._comm_info) is_alive = KernelClient._async_is_alive execute_interactive = KernelClient._async_execute_interactive # replies come on the control channel - shutdown = reqrep(KernelClient._async_shutdown, channel='control') + shutdown = reqrep(wrapped, KernelClient._shutdown, channel='control') diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py index 9233b4eeb..34dafdf43 100644 --- a/jupyter_client/blocking/client.py +++ b/jupyter_client/blocking/client.py @@ -11,6 +11,17 @@ from ..utils import run_sync +def wrapped(meth, channel): + def _(self, *args, **kwargs): + reply = kwargs.pop('reply', False) + timeout = kwargs.pop('timeout', None) + msg_id = meth(self, *args, **kwargs) + if not reply: + return msg_id + return run_sync(self._async_recv_reply)(msg_id, timeout=timeout, channel=channel) + return _ + + class BlockingKernelClient(KernelClient): """A KernelClient with blocking APIs @@ -41,15 +52,15 @@ class BlockingKernelClient(KernelClient): # replies come on the shell channel - execute = run_sync(reqrep(KernelClient._async_execute)) - history = run_sync(reqrep(KernelClient._async_history)) - complete = run_sync(reqrep(KernelClient._async_complete)) - inspect = run_sync(reqrep(KernelClient._async_inspect)) - kernel_info = run_sync(reqrep(KernelClient._async_kernel_info)) - comm_info = run_sync(reqrep(KernelClient._async_comm_info)) + execute = reqrep(wrapped, KernelClient._execute) + history = reqrep(wrapped, KernelClient._history) + complete = reqrep(wrapped, KernelClient._complete) + inspect = reqrep(wrapped, KernelClient._inspect) + kernel_info = reqrep(wrapped, KernelClient._kernel_info) + comm_info = reqrep(wrapped, KernelClient._comm_info) is_alive = run_sync(KernelClient._async_is_alive) execute_interactive = run_sync(KernelClient._async_execute_interactive) # replies come on the control channel - shutdown = run_sync(reqrep(KernelClient._async_shutdown, channel='control')) + shutdown = reqrep(wrapped, KernelClient._shutdown, channel='control') diff --git a/jupyter_client/client.py b/jupyter_client/client.py index a65060e49..b4ba3f004 100644 --- a/jupyter_client/client.py +++ b/jupyter_client/client.py @@ -45,18 +45,11 @@ def validate_string_dict( def reqrep( + wrapped: t.Callable, meth: t.Callable, channel: str = 'shell' ) -> t.Callable: - async def wrapped(self, *args, **kwargs) -> t.Union[str, t.Dict[str, t.Any]]: - reply = kwargs.pop('reply', False) - timeout = kwargs.pop('timeout', None) - msg_id = await meth(self, *args, **kwargs) - if not reply: - return msg_id - - return await self._async_recv_reply(msg_id, timeout=timeout, channel=channel) - + wrapped = wrapped(meth, channel) if not meth.__doc__: # python -OO removes docstrings, # so don't bother building the wrapped docstring @@ -173,7 +166,7 @@ async def _async_wait_for_ready( # Wait for kernel info reply on shell channel while True: - await self._async_kernel_info() + self._kernel_info() try: msg = await self.shell_channel.get_msg(timeout=1) except Empty: @@ -493,12 +486,12 @@ async def _async_execute_interactive( allow_stdin = self.allow_stdin if allow_stdin and not self.stdin_channel.is_alive(): raise RuntimeError("stdin channel must be running to allow input") - msg_id = await self._async_execute(code, - silent=silent, - store_history=store_history, - user_expressions=user_expressions, - allow_stdin=allow_stdin, - stop_on_error=stop_on_error, + msg_id = self._execute(code, + silent=silent, + store_history=store_history, + user_expressions=user_expressions, + allow_stdin=allow_stdin, + stop_on_error=stop_on_error, ) if stdin_hook is None: stdin_hook = self._stdin_hook_default @@ -568,7 +561,7 @@ async def _async_execute_interactive( # Methods to send specific messages on channels - async def _async_execute( + def _execute( self, code: str, silent: bool = False, @@ -632,7 +625,7 @@ async def _async_execute( self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_complete( + def _complete( self, code: str, cursor_pos: t.Optional[int] = None @@ -659,7 +652,7 @@ async def _async_complete( self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_inspect( + def _inspect( self, code: str, cursor_pos: t.Optional[int] = None, @@ -693,7 +686,7 @@ async def _async_inspect( self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_history( + def _history( self, raw: bool = True, output: bool = False, @@ -740,7 +733,7 @@ async def _async_history( self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_kernel_info(self) -> str: + def _kernel_info(self) -> str: """Request kernel info Returns @@ -751,7 +744,7 @@ async def _async_kernel_info(self) -> str: self.shell_channel.send(msg) return msg['header']['msg_id'] - async def _async_comm_info( + def _comm_info( self, target_name: t.Optional[str] = None ) -> str: @@ -804,7 +797,7 @@ def input( msg = self.session.msg('input_reply', content) self.stdin_channel.send(msg) - async def _async_shutdown( + def _shutdown( self, restart: bool = False ) -> str: diff --git a/jupyter_client/tests/test_kernelmanager.py b/jupyter_client/tests/test_kernelmanager.py index 1a35e3ad7..5380bdca0 100644 --- a/jupyter_client/tests/test_kernelmanager.py +++ b/jupyter_client/tests/test_kernelmanager.py @@ -470,7 +470,7 @@ async def test_signal_kernel_subprocesses(self, install_kernel, start_async_kern km, kc = start_async_kernel async def execute(cmd): - request_id = await kc.execute(cmd) + request_id = kc.execute(cmd) while True: reply = await kc.get_shell_msg(TIMEOUT) if reply['parent_header']['msg_id'] == request_id: @@ -489,7 +489,7 @@ async def execute(cmd): assert reply['user_expressions']['poll'] == [None] * N # start a job on the kernel to be interrupted - request_id = await kc.execute('sleep') + request_id = kc.execute('sleep') await asyncio.sleep(1) # ensure sleep message has been handled before we interrupt await km.interrupt_kernel() while True: