Skip to content

Commit

Permalink
Fixed a number of async issues and enhanced process cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
MSeal committed Mar 27, 2020
1 parent 9c941f9 commit b3a9169
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 39 deletions.
130 changes: 96 additions & 34 deletions nbclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# For python 3.5 compatibility we import asynccontextmanager from async_generator instead of
# contextlib, and we `await yield_()` instead of just `yield`
from async_generator import asynccontextmanager, async_generator, yield_
from contextlib import contextmanager

from time import monotonic
from queue import Empty
Expand All @@ -15,8 +16,14 @@

from nbformat.v4 import output_from_msg

from .exceptions import CellTimeoutError, DeadKernelError, CellExecutionComplete, CellExecutionError
from .util import run_sync
from .exceptions import (
CellControlSignal,
CellTimeoutError,
DeadKernelError,
CellExecutionComplete,
CellExecutionError
)
from .util import run_sync, await_or_block


def timestamp():
Expand Down Expand Up @@ -324,7 +331,28 @@ def start_kernel_manager(self):
self.km.client_class = 'jupyter_client.asynchronous.AsyncKernelClient'
return self.km

async def start_new_kernel_client(self, **kwargs):
async def _async_cleanup_kernel(self):
    """Shut the kernel down and release the kernel client.

    A shutdown request is first sent through the client (``self.kc``), then
    the manager (``self.km``) is asked to shut the kernel down immediately.
    Whatever happens during shutdown, the manager's ``cleanup`` is run, the
    client's channels are stopped and ``self.kc`` is reset to ``None``.

    Raises
    ------
    RuntimeError
        Re-raised from ``km.shutdown_kernel`` unless its message contains
        'No kernel is running!'.
    """
    try:
        # Send a polite shutdown request
        await await_or_block(self.kc.shutdown)
        try:
            # Queue the manager to kill the process, sometimes the built-in and above
            # shutdowns have not been successful or called yet, so give a direct kill
            # call here and recover gracefully if it's already dead.
            await await_or_block(self.km.shutdown_kernel, now=True)
        except RuntimeError as e:
            # The error isn't specialized, so we have to check the message
            if 'No kernel is running!' not in str(e):
                raise
    finally:
        # Remove any state left over even if we failed to stop the kernel.
        # Each teardown step is guarded separately so that a failure in one
        # step cannot skip the following ones or leave a stale client behind.
        try:
            await await_or_block(self.km.cleanup)
        finally:
            try:
                await await_or_block(self.kc.stop_channels)
            finally:
                self.kc = None

_cleanup_kernel = run_sync(_async_cleanup_kernel)

async def async_start_new_kernel_client(self, **kwargs):
"""Creates a new kernel client.
Parameters
Expand All @@ -346,22 +374,44 @@ async def start_new_kernel_client(self, **kwargs):
if self.km.ipykernel and self.ipython_hist_file:
self.extra_arguments += ['--HistoryManager.hist_file={}'.format(self.ipython_hist_file)]

await self.km.start_kernel(extra_arguments=self.extra_arguments, **kwargs)
await await_or_block(self.km.start_kernel, extra_arguments=self.extra_arguments, **kwargs)

self.kc = self.km.client()
self.kc.start_channels()
await await_or_block(self.kc.start_channels)
try:
await self.kc.wait_for_ready(timeout=self.startup_timeout)
await await_or_block(self.kc.wait_for_ready, timeout=self.startup_timeout)
except RuntimeError:
self.kc.stop_channels()
await self.km.shutdown_kernel()
await self._async_cleanup_kernel()
raise
self.kc.allow_stdin = False
return self.kc

start_new_kernel_client = run_sync(async_start_new_kernel_client)

@contextmanager
def setup_kernel(self, **kwargs):
    """
    Context manager for setting up the kernel to execute a notebook.

    This assigns the Kernel Manager (`self.km`) if missing and, when the
    manager has no kernel yet, starts a new Kernel Client (`self.kc`) with
    the given keyword arguments.

    When control returns from the yield it stops the client's zmq channels, and shuts
    down the kernel.
    """
    # Can't use run_until_complete on an asynccontextmanager function :(
    if self.km is None:
        self.start_kernel_manager()

    kernel_missing = not self.km.has_kernel
    if kernel_missing:
        self.start_new_kernel_client(**kwargs)

    try:
        yield
    finally:
        # Always tear the kernel down, even if the managed body raised
        self._cleanup_kernel()

@asynccontextmanager
@async_generator # needed for python 3.5 compatibility
async def setup_kernel(self, **kwargs):
async def async_setup_kernel(self, **kwargs):
"""
Context manager for setting up the kernel to execute a notebook.
Expand All @@ -374,12 +424,11 @@ async def setup_kernel(self, **kwargs):
self.start_kernel_manager()

if not self.km.has_kernel:
await self.start_new_kernel_client(**kwargs)
await self.async_start_new_kernel_client(**kwargs)
try:
await yield_(None) # would just yield in python >3.5
finally:
self.kc.stop_channels()
self.kc = None
await self._async_cleanup_kernel()

async def async_execute(self, **kwargs):
"""
Expand All @@ -392,15 +441,16 @@ async def async_execute(self, **kwargs):
"""
self.reset_execution_trackers()

async with self.setup_kernel(**kwargs):
async with self.async_setup_kernel(**kwargs):
self.log.info("Executing notebook with kernel: %s" % self.kernel_name)
for index, cell in enumerate(self.nb.cells):
# Ignore `'execution_count' in content` as it's always 1
# when store_history is False
await self.async_execute_cell(
cell, index, execution_count=self.code_cells_executed + 1
)
info_msg = await self._wait_for_reply(self.kc.kernel_info())
msg_id = await await_or_block(self.kc.kernel_info)
info_msg = await self.async_wait_for_reply(msg_id)
self.nb.metadata['language_info'] = info_msg['content']['language_info']
self.set_widgets_metadata()

Expand Down Expand Up @@ -450,12 +500,12 @@ def _update_display_id(self, display_id, msg):
outputs[output_idx]['data'] = out['data']
outputs[output_idx]['metadata'] = out['metadata']

async def _poll_for_reply(self, msg_id, cell, timeout, task_poll_output_msg):
async def _async_poll_for_reply(self, msg_id, cell, timeout, task_poll_output_msg):
if timeout is not None:
deadline = monotonic() + timeout
while True:
try:
msg = await self.kc.shell_channel.get_msg(timeout=timeout)
msg = await await_or_block(self.kc.shell_channel.get_msg, timeout=timeout)
if msg['parent_header'].get('msg_id') == msg_id:
if self.record_timing:
cell['metadata']['execution']['shell.execute_reply'] = timestamp()
Expand All @@ -474,12 +524,12 @@ async def _poll_for_reply(self, msg_id, cell, timeout, task_poll_output_msg):
timeout = max(0, deadline - monotonic())
except Empty:
# received no message, check if kernel is still alive
await self._check_alive()
await self._handle_timeout(timeout, cell)
await self._async_check_alive()
await self._async_handle_timeout(timeout, cell)

async def _poll_output_msg(self, parent_msg_id, cell, cell_index):
async def _async_poll_output_msg(self, parent_msg_id, cell, cell_index):
while True:
msg = await self.kc.iopub_channel.get_msg(timeout=None)
msg = await await_or_block(self.kc.iopub_channel.get_msg, timeout=None)
if msg['parent_header'].get('msg_id') == parent_msg_id:
try:
# Will raise CellExecutionComplete when completed
Expand All @@ -498,39 +548,42 @@ def _get_timeout(self, cell):

return timeout

async def _handle_timeout(self, timeout, cell=None):
async def _async_handle_timeout(self, timeout, cell=None):
self.log.error("Timeout waiting for execute reply (%is)." % timeout)
if self.interrupt_on_timeout:
self.log.error("Interrupting kernel")
await self.km.interrupt_kernel()
await await_or_block(self.km.interrupt_kernel)
else:
raise CellTimeoutError.error_from_timeout_and_cell(
"Cell execution timed out", timeout, cell
)

async def _check_alive(self):
if not await self.kc.is_alive():
async def _async_check_alive(self):
if not await await_or_block(self.kc.is_alive):
self.log.error("Kernel died while waiting for execute reply.")
raise DeadKernelError("Kernel died")

async def _wait_for_reply(self, msg_id, cell=None):
async def async_wait_for_reply(self, msg_id, cell=None):
# wait for finish, with timeout
timeout = self._get_timeout(cell)
cummulative_time = 0
self.shell_timeout_interval = 5
while True:
try:
msg = await self.kc.shell_channel.get_msg(timeout=self.shell_timeout_interval)
msg = await await_or_block(self.kc.shell_channel.get_msg, timeout=self.shell_timeout_interval)
except Empty:
await self._check_alive()
await self._async_check_alive()
cummulative_time += self.shell_timeout_interval
if timeout and cummulative_time > timeout:
await self._handle_timeout(timeout, cell)
await self._async_handle_timeout(timeout, cell)
break
else:
if msg['parent_header'].get('msg_id') == msg_id:
return msg

wait_for_reply = run_sync(async_wait_for_reply)
# Backwards compatibility naming for papermill
_wait_for_reply = wait_for_reply

def _timeout_with_deadline(self, timeout, deadline):
if deadline is not None and deadline - monotonic() < timeout:
timeout = deadline - monotonic()
Expand Down Expand Up @@ -596,7 +649,7 @@ async def async_execute_cell(self, cell, cell_index, execution_count=None, store
cell['metadata']['execution'] = {}

self.log.debug("Executing cell:\n%s", cell.source)
parent_msg_id = self.kc.execute(
parent_msg_id = await await_or_block(self.kc.execute,
cell.source, store_history=store_history, stop_on_error=not self.allow_errors
)
# We launched a code cell to execute
Expand All @@ -607,11 +660,20 @@ async def async_execute_cell(self, cell, cell_index, execution_count=None, store
self.clear_before_next_output = False

task_poll_output_msg = asyncio.ensure_future(
self._poll_output_msg(parent_msg_id, cell, cell_index)
)
exec_reply = await self._poll_for_reply(
parent_msg_id, cell, exec_timeout, task_poll_output_msg
self._async_poll_output_msg(parent_msg_id, cell, cell_index)
)
try:
exec_reply = await self._async_poll_for_reply(
parent_msg_id, cell, exec_timeout, task_poll_output_msg
)
except Exception as e:
# Best effort to cancel request if it hasn't been resolved
try:
# Check if the task_poll_output is doing the raising for us
if not isinstance(e, CellControlSignal):
task_poll_output_msg.cancel()
finally:
raise

if execution_count:
cell['execution_count'] = execution_count
Expand Down
15 changes: 12 additions & 3 deletions nbclient/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
class CellTimeoutError(TimeoutError):
class CellControlSignal(Exception):
    """
    Base class for exceptions that are raised to drive cell control actions
    rather than to report a failure of their own.

    Not the best model, but it's needed to cover existing behavior without
    major refactors.
    """


class CellTimeoutError(TimeoutError, CellControlSignal):
"""
A custom exception to capture when a cell has timed out during execution.
"""
Expand All @@ -21,7 +30,7 @@ class DeadKernelError(RuntimeError):
pass


class CellExecutionComplete(Exception):
class CellExecutionComplete(CellControlSignal):
"""
Used as a control signal for cell execution across execute_cell and
process_message function calls. Raised when all execution requests
Expand All @@ -32,7 +41,7 @@ class CellExecutionComplete(Exception):
pass


class CellExecutionError(Exception):
class CellExecutionError(CellControlSignal):
"""
Custom exception to propagate exceptions that are raised during
notebook execution to the caller. This is mostly useful when
Expand Down
18 changes: 16 additions & 2 deletions nbclient/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
IPY_MAJOR = IPython.version_info[0]


class AsyncMock(Mock):
    """Plain ``Mock`` subclass that adds no behaviour of its own.

    Used for the channel ``get_msg`` mocks, whose configured return values
    are coroutines.
    """


def make_async(mock_value):
async def _():
return mock_value
Expand Down Expand Up @@ -116,7 +120,7 @@ def prepare_cell_mocks(*messages, reply_msg=None):
def shell_channel_message_mock():
# Return the message generator for
# self.kc.shell_channel.get_msg => {'parent_header': {'msg_id': parent_id}}
return MagicMock(
return AsyncMock(
return_value=make_async(NBClientTestsBase.merge_dicts(
{
'parent_header': {'msg_id': parent_id},
Expand All @@ -129,7 +133,7 @@ def shell_channel_message_mock():
def iopub_messages_mock():
# Return the message generator for
# self.kc.iopub_channel.get_msg => messages[i]
return Mock(
return AsyncMock(
side_effect=[
# Default the parent_header so mocks don't need to include this
make_async(
Expand Down Expand Up @@ -386,6 +390,16 @@ def get_time_from_str(s):
assert status_idle - cell_end < delta


def test_synchronous_setup_kernel():
    """The synchronous ``setup_kernel`` context manager sets and then clears ``kc``."""
    executor = NotebookClient(nbformat.v4.new_notebook())
    with executor.setup_kernel():
        # Prove it initialized the client
        assert executor.kc is not None
    # Prove it removed the client (and hopefully cleaned up)
    assert executor.kc is None


class TestExecute(NBClientTestsBase):
"""Contains test functions for execute.py"""

Expand Down
16 changes: 16 additions & 0 deletions nbclient/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import asyncio

from typing import Coroutine


def run_sync(coro):
"""Runs a coroutine and blocks until it has executed.
Expand Down Expand Up @@ -45,3 +47,17 @@ def wrapped(self, *args, **kwargs):
return result
wrapped.__doc__ = coro.__doc__
return wrapped


async def await_or_block(func, *args, **kwargs):
    """Call ``func`` and return its result, awaiting the result when needed.

    Parameters
    ----------
    func : callable
        Either an asynchronous or a regular (blocking) callable.
    *args, **kwargs
        Passed through to ``func`` unchanged.

    Returns
    -------
    The value returned by ``func``; if that value is awaitable (a coroutine,
    future, task or any object with ``__await__``) the awaited result is
    returned instead.
    """
    # Imported here so the helper does not depend on module-level imports.
    import inspect

    result = func(*args, **kwargs)
    # Decide on the *result* rather than on the callable: mocks and other
    # wrappers mask that a function is a coroutine function, and regular
    # functions may hand back futures that still have to be awaited.
    if inspect.isawaitable(result):
        return await result
    return result

0 comments on commit b3a9169

Please sign in to comment.