Add synchronous kernel setup/cleanup to nbclient alongside the async API.

Kernel manager/client calls are routed through await_or_block so that both
blocking and coroutine implementations are handled, and the cell control
exceptions now share the CellControlSignal base class.

diff --git a/nbclient/client.py b/nbclient/client.py
index 87489cd5..2dd977af 100644
--- a/nbclient/client.py
+++ b/nbclient/client.py
@@ -5,6 +5,7 @@
 # For python 3.5 compatibility we import asynccontextmanager from async_generator instead of
 # contextlib, and we `await yield_()` instead of just `yield`
 from async_generator import asynccontextmanager, async_generator, yield_
+from contextlib import contextmanager
 from time import monotonic
 from queue import Empty
 
@@ -15,8 +16,14 @@
 
 from nbformat.v4 import output_from_msg
 
-from .exceptions import CellTimeoutError, DeadKernelError, CellExecutionComplete, CellExecutionError
-from .util import run_sync
+from .exceptions import (
+    CellControlSignal,
+    CellTimeoutError,
+    DeadKernelError,
+    CellExecutionComplete,
+    CellExecutionError
+)
+from .util import run_sync, await_or_block
 
 
 def timestamp():
@@ -324,7 +331,28 @@ def start_kernel_manager(self):
             self.km.client_class = 'jupyter_client.asynchronous.AsyncKernelClient'
         return self.km
 
-    async def start_new_kernel_client(self, **kwargs):
+    async def _async_cleanup_kernel(self):
+        try:
+            # Send a polite shutdown request
+            await await_or_block(self.kc.shutdown)
+            try:
+                # Queue the manager to kill the process, sometimes the built-in and above
+                # shutdowns have not been successful or called yet, so give a direct kill
+                # call here and recover gracefully if it's already dead.
+                await await_or_block(self.km.shutdown_kernel, now=True)
+            except RuntimeError as e:
+                # The error isn't specialized, so we have to check the message
+                if 'No kernel is running!' not in str(e):
+                    raise
+        finally:
+            # Remove any state left over even if we failed to stop the kernel
+            await await_or_block(self.km.cleanup)
+            await await_or_block(self.kc.stop_channels)
+            self.kc = None
+
+    _cleanup_kernel = run_sync(_async_cleanup_kernel)
+
+    async def async_start_new_kernel_client(self, **kwargs):
         """Creates a new kernel client.
 
         Parameters
@@ -346,22 +374,44 @@ async def start_new_kernel_client(self, **kwargs):
         if self.km.ipykernel and self.ipython_hist_file:
             self.extra_arguments += ['--HistoryManager.hist_file={}'.format(self.ipython_hist_file)]
 
-        await self.km.start_kernel(extra_arguments=self.extra_arguments, **kwargs)
+        await await_or_block(self.km.start_kernel, extra_arguments=self.extra_arguments, **kwargs)
 
         self.kc = self.km.client()
-        self.kc.start_channels()
+        await await_or_block(self.kc.start_channels)
         try:
-            await self.kc.wait_for_ready(timeout=self.startup_timeout)
+            await await_or_block(self.kc.wait_for_ready, timeout=self.startup_timeout)
         except RuntimeError:
-            self.kc.stop_channels()
-            await self.km.shutdown_kernel()
+            await self._async_cleanup_kernel()
             raise
         self.kc.allow_stdin = False
         return self.kc
 
+    start_new_kernel_client = run_sync(async_start_new_kernel_client)
+
+    @contextmanager
+    def setup_kernel(self, **kwargs):
+        """
+        Context manager for setting up the kernel to execute a notebook.
+
+        This assigns the Kernel Manager (`self.km`) if missing and Kernel Client(`self.kc`).
+
+        When control returns from the yield it stops the client's zmq channels, and shuts
+        down the kernel.
+        """
+        # Can't use run_until_complete on an asynccontextmanager function :(
+        if self.km is None:
+            self.start_kernel_manager()
+
+        if not self.km.has_kernel:
+            self.start_new_kernel_client(**kwargs)
+        try:
+            yield
+        finally:
+            self._cleanup_kernel()
+
     @asynccontextmanager
     @async_generator  # needed for python 3.5 compatibility
-    async def setup_kernel(self, **kwargs):
+    async def async_setup_kernel(self, **kwargs):
         """
         Context manager for setting up the kernel to execute a notebook.
 
@@ -374,12 +424,11 @@ async def setup_kernel(self, **kwargs):
             self.start_kernel_manager()
 
         if not self.km.has_kernel:
-            await self.start_new_kernel_client(**kwargs)
+            await self.async_start_new_kernel_client(**kwargs)
         try:
             await yield_(None)  # would just yield in python >3.5
         finally:
-            self.kc.stop_channels()
-            self.kc = None
+            await self._async_cleanup_kernel()
 
     async def async_execute(self, **kwargs):
         """
@@ -392,7 +441,7 @@ async def async_execute(self, **kwargs):
         """
         self.reset_execution_trackers()
 
-        async with self.setup_kernel(**kwargs):
+        async with self.async_setup_kernel(**kwargs):
             self.log.info("Executing notebook with kernel: %s" % self.kernel_name)
             for index, cell in enumerate(self.nb.cells):
                 # Ignore `'execution_count' in content` as it's always 1
@@ -400,7 +449,8 @@ async def async_execute(self, **kwargs):
                 await self.async_execute_cell(
                     cell, index, execution_count=self.code_cells_executed + 1
                 )
-            info_msg = await self._wait_for_reply(self.kc.kernel_info())
+            msg_id = await await_or_block(self.kc.kernel_info)
+            info_msg = await self.async_wait_for_reply(msg_id)
             self.nb.metadata['language_info'] = info_msg['content']['language_info']
             self.set_widgets_metadata()
 
@@ -450,12 +500,12 @@ def _update_display_id(self, display_id, msg):
                 outputs[output_idx]['data'] = out['data']
                 outputs[output_idx]['metadata'] = out['metadata']
 
-    async def _poll_for_reply(self, msg_id, cell, timeout, task_poll_output_msg):
+    async def _async_poll_for_reply(self, msg_id, cell, timeout, task_poll_output_msg):
         if timeout is not None:
             deadline = monotonic() + timeout
         while True:
             try:
-                msg = await self.kc.shell_channel.get_msg(timeout=timeout)
+                msg = await await_or_block(self.kc.shell_channel.get_msg, timeout=timeout)
                 if msg['parent_header'].get('msg_id') == msg_id:
                     if self.record_timing:
                         cell['metadata']['execution']['shell.execute_reply'] = timestamp()
@@ -474,12 +524,12 @@ async def _poll_for_reply(self, msg_id, cell, timeout, task_poll_output_msg):
                         timeout = max(0, deadline - monotonic())
             except Empty:
                 # received no message, check if kernel is still alive
-                await self._check_alive()
-                await self._handle_timeout(timeout, cell)
+                await self._async_check_alive()
+                await self._async_handle_timeout(timeout, cell)
 
-    async def _poll_output_msg(self, parent_msg_id, cell, cell_index):
+    async def _async_poll_output_msg(self, parent_msg_id, cell, cell_index):
         while True:
-            msg = await self.kc.iopub_channel.get_msg(timeout=None)
+            msg = await await_or_block(self.kc.iopub_channel.get_msg, timeout=None)
             if msg['parent_header'].get('msg_id') == parent_msg_id:
                 try:
                     # Will raise CellExecutionComplete when completed
@@ -498,39 +548,42 @@ def _get_timeout(self, cell):
 
         return timeout
 
-    async def _handle_timeout(self, timeout, cell=None):
+    async def _async_handle_timeout(self, timeout, cell=None):
         self.log.error("Timeout waiting for execute reply (%is)." % timeout)
         if self.interrupt_on_timeout:
             self.log.error("Interrupting kernel")
-            await self.km.interrupt_kernel()
+            await await_or_block(self.km.interrupt_kernel)
         else:
             raise CellTimeoutError.error_from_timeout_and_cell(
                 "Cell execution timed out", timeout, cell
             )
 
-    async def _check_alive(self):
-        if not await self.kc.is_alive():
+    async def _async_check_alive(self):
+        if not await await_or_block(self.kc.is_alive):
             self.log.error("Kernel died while waiting for execute reply.")
             raise DeadKernelError("Kernel died")
 
-    async def _wait_for_reply(self, msg_id, cell=None):
+    async def async_wait_for_reply(self, msg_id, cell=None):
         # wait for finish, with timeout
         timeout = self._get_timeout(cell)
         cummulative_time = 0
-        self.shell_timeout_interval = 5
         while True:
             try:
-                msg = await self.kc.shell_channel.get_msg(timeout=self.shell_timeout_interval)
+                msg = await await_or_block(self.kc.shell_channel.get_msg, timeout=self.shell_timeout_interval)
             except Empty:
-                await self._check_alive()
+                await self._async_check_alive()
                 cummulative_time += self.shell_timeout_interval
                 if timeout and cummulative_time > timeout:
-                    await self._handle_timeout(timeout, cell)
+                    await self._async_handle_timeout(timeout, cell)
                     break
             else:
                 if msg['parent_header'].get('msg_id') == msg_id:
                     return msg
 
+    wait_for_reply = run_sync(async_wait_for_reply)
+    # Backwards compatibility naming for papermill
+    _wait_for_reply = wait_for_reply
+
     def _timeout_with_deadline(self, timeout, deadline):
         if deadline is not None and deadline - monotonic() < timeout:
             timeout = deadline - monotonic()
@@ -596,7 +649,7 @@ async def async_execute_cell(self, cell, cell_index, execution_count=None, store
             cell['metadata']['execution'] = {}
 
         self.log.debug("Executing cell:\n%s", cell.source)
-        parent_msg_id = self.kc.execute(
+        parent_msg_id = await await_or_block(self.kc.execute,
             cell.source, store_history=store_history, stop_on_error=not self.allow_errors
         )
         # We launched a code cell to execute
@@ -607,11 +660,20 @@ async def async_execute_cell(self, cell, cell_index, execution_count=None, store
         self.clear_before_next_output = False
 
         task_poll_output_msg = asyncio.ensure_future(
-            self._poll_output_msg(parent_msg_id, cell, cell_index)
-        )
-        exec_reply = await self._poll_for_reply(
-            parent_msg_id, cell, exec_timeout, task_poll_output_msg
+            self._async_poll_output_msg(parent_msg_id, cell, cell_index)
         )
+        try:
+            exec_reply = await self._async_poll_for_reply(
+                parent_msg_id, cell, exec_timeout, task_poll_output_msg
+            )
+        except Exception as e:
+            # Best effort to cancel request if it hasn't been resolved
+            try:
+                # Check if the task_poll_output is doing the raising for us
+                if not isinstance(e, CellControlSignal):
+                    task_poll_output_msg.cancel()
+            finally:
+                raise
 
         if execution_count:
             cell['execution_count'] = execution_count
diff --git a/nbclient/exceptions.py b/nbclient/exceptions.py
index e0485d77..c6aa52ac 100644
--- a/nbclient/exceptions.py
+++ b/nbclient/exceptions.py
@@ -1,4 +1,13 @@
-class CellTimeoutError(TimeoutError):
+class CellControlSignal(Exception):
+    """
+    A custom exception used to indicate that the exception is used for cell
+    control actions (not the best model, but it's needed to cover existing
+    behavior without major refactors).
+    """
+    pass
+
+
+class CellTimeoutError(TimeoutError, CellControlSignal):
     """
     A custom exception to capture when a cell has timed out during execution.
     """
@@ -21,7 +30,7 @@ class DeadKernelError(RuntimeError):
     pass
 
 
-class CellExecutionComplete(Exception):
+class CellExecutionComplete(CellControlSignal):
     """
     Used as a control signal for cell execution across execute_cell and
     process_message function calls. Raised when all execution requests
@@ -32,7 +41,7 @@ class CellExecutionComplete(Exception):
     pass
 
 
-class CellExecutionError(Exception):
+class CellExecutionError(CellControlSignal):
     """
     Custom exception to propagate exceptions that are raised during
     notebook execution to the caller. This is mostly useful when
diff --git a/nbclient/tests/test_client.py b/nbclient/tests/test_client.py
index 7cb1a7f6..537c6b6e 100644
--- a/nbclient/tests/test_client.py
+++ b/nbclient/tests/test_client.py
@@ -36,6 +36,10 @@
 IPY_MAJOR = IPython.version_info[0]
 
 
+class AsyncMock(Mock):
+    pass
+
+
 def make_async(mock_value):
     async def _():
         return mock_value
@@ -116,7 +120,7 @@ def prepare_cell_mocks(*messages, reply_msg=None):
     def shell_channel_message_mock():
         # Return the message generator for
         # self.kc.shell_channel.get_msg => {'parent_header': {'msg_id': parent_id}}
-        return MagicMock(
+        return AsyncMock(
             return_value=make_async(NBClientTestsBase.merge_dicts(
                 {
                     'parent_header': {'msg_id': parent_id},
@@ -129,7 +133,7 @@ def shell_channel_message_mock():
     def iopub_messages_mock():
         # Return the message generator for
         # self.kc.iopub_channel.get_msg => messages[i]
-        return Mock(
+        return AsyncMock(
             side_effect=[
                 # Default the parent_header so mocks don't need to include this
                 make_async(
@@ -386,6 +390,16 @@ def get_time_from_str(s):
     assert status_idle - cell_end < delta
 
 
+def test_synchronous_setup_kernel():
+    nb = nbformat.v4.new_notebook()
+    executor = NotebookClient(nb)
+    with executor.setup_kernel():
+        # Prove it initialized client
+        assert executor.kc is not None
+    # Prove it removed the client (and hopefully cleaned up)
+    assert executor.kc is None
+
+
 class TestExecute(NBClientTestsBase):
     """Contains test functions for execute.py"""
 
diff --git a/nbclient/util.py b/nbclient/util.py
index 1a274792..c6590e9e 100644
--- a/nbclient/util.py
+++ b/nbclient/util.py
@@ -5,6 +5,8 @@
 
 import asyncio
 
+from typing import Coroutine
+
 
 def run_sync(coro):
     """Runs a coroutine and blocks until it has executed.
@@ -45,3 +47,17 @@ def wrapped(self, *args, **kwargs):
         return result
     wrapped.__doc__ = coro.__doc__
     return wrapped
+
+
+async def await_or_block(func, *args, **kwargs):
+    """Awaits the function if it's an asynchronous function. Otherwise block
+    on execution.
+    """
+    if asyncio.iscoroutinefunction(func):
+        return await func(*args, **kwargs)
+    else:
+        result = func(*args, **kwargs)
+        # Mocks mask that the function is a coroutine :/
+        if isinstance(result, Coroutine):
+            return await result
+        return result