From c68b7001432f5f02ee15c19f1613c08943afbbd3 Mon Sep 17 00:00:00 2001 From: Golf Player <> Date: Fri, 12 Jun 2020 22:33:29 -0500 Subject: [PATCH 1/2] Add basic hooks during execution This will enable tracking of execution process without subclassing the way papermill does. --- nbclient/client.py | 58 ++++++++++++++++++++++++++++++++++++++++++---- nbclient/util.py | 15 +++++++++++- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/nbclient/client.py b/nbclient/client.py index 43ff16e0..cdbac0e9 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -26,7 +26,7 @@ DeadKernelError, ) from .output_widget import OutputWidget -from .util import ensure_async, run_sync +from .util import ensure_async, run_sync, run_hook def timestamp(msg: Optional[Dict] = None) -> str: @@ -261,6 +261,45 @@ class NotebookClient(LoggingConfigurable): kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.') + on_execution_start: t.Optional[t.Callable] = Any( + default_value=None, + allow_none=True, + help=dedent(""" + Called after the kernel manager and kernel client are setup, and cells + are about to execute. + Called with kwargs `kernel_id`. + """), + ).tag(config=True) + + on_cell_start: t.Optional[t.Callable] = Any( + default_value=None, + allow_none=True, + help=dedent(""" + A callable which executes before a cell is executed. + Called with kwargs `cell`, and `cell_index`. + """), + ).tag(config=True) + + on_cell_complete: t.Optional[t.Callable] = Any( + default_value=None, + allow_none=True, + help=dedent(""" + A callable which executes after a cell execution is complete. It is + called even when a cell results in a failure. + Called with kwargs `cell`, and `cell_index`. + """), + ).tag(config=True) + + on_cell_error: t.Optional[t.Callable] = Any( + default_value=None, + allow_none=True, + help=dedent(""" + A callable which executes when a cell execution results in an error. 
+ This is executed even if errors are suppressed with `cell_allows_errors`. + Called with kwargs `cell`, and `cell_index`. + """), + ).tag(config=True) + @default('kernel_manager_class') def _kernel_manager_class_default(self) -> KernelManager: """Use a dynamic default to avoid importing jupyter_client at startup""" @@ -442,6 +481,7 @@ async def async_start_new_kernel_client(self) -> KernelClient: await self._async_cleanup_kernel() raise self.kc.allow_stdin = False + run_hook(sself.on_execution_start) return self.kc start_new_kernel_client = run_sync(async_start_new_kernel_client) @@ -745,7 +785,11 @@ def _passed_deadline(self, deadline: int) -> bool: return True return False - def _check_raise_for_error(self, cell: NotebookNode, exec_reply: t.Optional[t.Dict]) -> None: + def _check_raise_for_error( + self, + cell: NotebookNode, + cell_index: int, + exec_reply: t.Optional[t.Dict]) -> None: if exec_reply is None: return None @@ -760,8 +804,10 @@ def _check_raise_for_error(self, cell: NotebookNode, exec_reply: t.Optional[t.Di or "raises-exception" in cell.metadata.get("tags", []) ) - if not cell_allows_errors: - raise CellExecutionError.from_cell_and_msg(cell, exec_reply_content) + if (exec_reply is not None) and exec_reply['content']['status'] == 'error': + run_hook(self.on_cell_error, cell=cell, cell_index=cell_index) + if self.force_raise_errors or not cell_allows_errors: + raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content']) async def async_execute_cell( self, @@ -821,11 +867,13 @@ async def async_execute_cell( self.allow_errors or "raises-exception" in cell.metadata.get("tags", []) ) + run_hook(self.on_cell_start, cell=cell, cell_index=cell_index) parent_msg_id = await ensure_async( self.kc.execute( cell.source, store_history=store_history, stop_on_error=not cell_allows_errors ) ) + run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index) # We launched a code cell to execute self.code_cells_executed += 1 exec_timeout = 
self._get_timeout(cell) @@ -859,7 +907,7 @@ async def async_execute_cell( if execution_count: cell['execution_count'] = execution_count - self._check_raise_for_error(cell, exec_reply) + self._check_raise_for_error(cell, cell_index, exec_reply) self.nb['cells'][cell_index] = cell return cell diff --git a/nbclient/util.py b/nbclient/util.py index 1bc83c2e..6221bc2c 100644 --- a/nbclient/util.py +++ b/nbclient/util.py @@ -6,7 +6,8 @@ import asyncio import inspect import sys -from typing import Any, Awaitable, Callable, Union +from typing import Any, Awaitable, Callable, Optional, Union +from functools import partial def check_ipython() -> None: @@ -102,3 +103,15 @@ async def ensure_async(obj: Union[Awaitable, Any]) -> Any: return result # obj doesn't need to be awaited return obj + + +def run_hook(hook: Optional[Callable], **kwargs) -> None: + if hook is None: + return + if inspect.iscoroutinefunction(hook): + future = hook(**kwargs) + else: + loop = asyncio.get_event_loop() + hook_with_kwargs = partial(hook, **kwargs) + future = loop.run_in_executor(None, hook_with_kwargs) + asyncio.ensure_future(future) From 51c3eceea8ed06df0f4aeff98a664a438cde3a38 Mon Sep 17 00:00:00 2001 From: Devin Tang Date: Thu, 23 Dec 2021 14:14:57 -0800 Subject: [PATCH 2/2] Rebased with master and added tests Run_hook is now async and renamed util to test_util so it gets picked up by pytest. Also added new hooks: on_notebook_error, on_cell_execute Updated docs --- docs/client.rst | 30 ++++ nbclient/client.py | 135 +++++++++++------ nbclient/tests/test_client.py | 175 +++++++++++++++++++++++ nbclient/tests/{util.py => test_util.py} | 18 ++- nbclient/util.py | 13 +- requirements-dev.txt | 1 + 6 files changed, 322 insertions(+), 50 deletions(-) rename nbclient/tests/{util.py => test_util.py} (79%) diff --git a/docs/client.rst b/docs/client.rst index cb6eb156..086bf029 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -96,6 +96,36 @@ on both versions. 
Here the traitlet ``kernel_name`` helps simplify and maintain consistency: we can just run a notebook twice, specifying first "python2" and then "python3" as the kernel name. +Hooks before and after notebook or cell execution +------------------------------------------------- +There are several configurable hooks that allow the user to execute code before and +after a notebook or a cell is executed. Each one is configured with a function that will be called in its +respective place in the execution pipeline. +Each is described below: + +**Notebook-level hooks**: These hooks are called with a single extra parameter: + +- ``notebook=NotebookNode``: the current notebook being executed. + +Here are the available hooks: + +- ``on_notebook_start`` will run when the notebook client is initialized, before any execution has happened. +- ``on_notebook_complete`` will run when the notebook client has finished executing, after kernel cleanup. +- ``on_notebook_error`` will run when the notebook client has encountered an exception before kernel cleanup. + +**Cell-level hooks**: These hooks are called with two parameters: + +- ``cell=NotebookNode``: a reference to the current cell. +- ``cell_index=int``: the index of the cell in the current notebook's list of cells. + +Here are the available hooks: + +- ``on_cell_start`` will run for all cell types before the cell is executed. +- ``on_cell_execute`` will run right before the code cell is executed. +- ``on_cell_complete`` will run after execution, even if the cell results in an error. +- ``on_cell_error`` will run if there is an error during cell execution. 
+ + Handling errors and exceptions ------------------------------ diff --git a/nbclient/client.py b/nbclient/client.py index cdbac0e9..ab7615bd 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -15,7 +15,18 @@ from jupyter_client.client import KernelClient from nbformat import NotebookNode from nbformat.v4 import output_from_msg -from traitlets import Any, Bool, Dict, Enum, Integer, List, Type, Unicode, default +from traitlets import ( + Any, + Bool, + Callable, + Dict, + Enum, + Integer, + List, + Type, + Unicode, + default, +) from traitlets.config.configurable import LoggingConfigurable from .exceptions import ( @@ -26,7 +37,7 @@ DeadKernelError, ) from .output_widget import OutputWidget -from .util import ensure_async, run_sync, run_hook +from .util import ensure_async, run_hook, run_sync def timestamp(msg: Optional[Dict] = None) -> str: @@ -261,43 +272,85 @@ class NotebookClient(LoggingConfigurable): kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.') - on_execution_start: t.Optional[t.Callable] = Any( + on_notebook_start: t.Optional[t.Callable] = Callable( default_value=None, allow_none=True, - help=dedent(""" - Called after the kernel manager and kernel client are setup, and cells - are about to execute. - Called with kwargs `kernel_id`. - """), + help=dedent( + """ + A callable which executes after the kernel manager and kernel client are setup, and + cells are about to execute. + Called with kwargs `notebook`. + """ + ), ).tag(config=True) - on_cell_start: t.Optional[t.Callable] = Any( + on_notebook_complete: t.Optional[t.Callable] = Callable( default_value=None, allow_none=True, - help=dedent(""" - A callable which executes before a cell is executed. - Called with kwargs `cell`, and `cell_index`. - """), + help=dedent( + """ + A callable which executes after the kernel is cleaned up. + Called with kwargs `notebook`. 
+ """ + ), ).tag(config=True) - on_cell_complete: t.Optional[t.Callable] = Any( + on_notebook_error: t.Optional[t.Callable] = Callable( default_value=None, allow_none=True, - help=dedent(""" - A callable which executes after a cell execution is complete. It is - called even when a cell results in a failure. - Called with kwargs `cell`, and `cell_index`. - """), + help=dedent( + """ + A callable which executes when the notebook encounters an error. + Called with kwargs `notebook`. + """ + ), ).tag(config=True) - on_cell_error: t.Optional[t.Callable] = Any( + on_cell_start: t.Optional[t.Callable] = Callable( default_value=None, allow_none=True, - help=dedent(""" - A callable which executes when a cell execution results in an error. - This is executed even if errors are suppressed with `cell_allows_errors`. - Called with kwargs `cell`, and `cell_index`. - """), + help=dedent( + """ + A callable which executes before a cell is executed and before non-executing cells + are skipped. + Called with kwargs `cell` and `cell_index`. + """ + ), + ).tag(config=True) + + on_cell_execute: t.Optional[t.Callable] = Callable( + default_value=None, + allow_none=True, + help=dedent( + """ + A callable which executes just before a code cell is executed. + Called with kwargs `cell` and `cell_index`. + """ + ), + ).tag(config=True) + + on_cell_complete: t.Optional[t.Callable] = Callable( + default_value=None, + allow_none=True, + help=dedent( + """ + A callable which executes after a cell execution is complete. It is + called even when a cell results in a failure. + Called with kwargs `cell` and `cell_index`. + """ + ), + ).tag(config=True) + + on_cell_error: t.Optional[t.Callable] = Callable( + default_value=None, + allow_none=True, + help=dedent( + """ + A callable which executes when a cell execution results in an error. + This is executed even if errors are suppressed with `cell_allows_errors`. + Called with kwargs `cell` and `cell_index`. 
+ """ + ), ).tag(config=True) @default('kernel_manager_class') @@ -481,7 +534,7 @@ async def async_start_new_kernel_client(self) -> KernelClient: await self._async_cleanup_kernel() raise self.kc.allow_stdin = False - run_hook(sself.on_execution_start) + await run_hook(self.on_notebook_start, notebook=self.nb) return self.kc start_new_kernel_client = run_sync(async_start_new_kernel_client) @@ -553,10 +606,13 @@ def on_signal(): await self.async_start_new_kernel_client() try: yield + except RuntimeError as e: + await run_hook(self.on_notebook_error, notebook=self.nb) + raise e finally: if cleanup_kc: await self._async_cleanup_kernel() - + await run_hook(self.on_notebook_complete, notebook=self.nb) atexit.unregister(self._cleanup_kernel) try: loop.remove_signal_handler(signal.SIGINT) @@ -785,11 +841,9 @@ def _passed_deadline(self, deadline: int) -> bool: return True return False - def _check_raise_for_error( - self, - cell: NotebookNode, - cell_index: int, - exec_reply: t.Optional[t.Dict]) -> None: + async def _check_raise_for_error( + self, cell: NotebookNode, cell_index: int, exec_reply: t.Optional[t.Dict] + ) -> None: if exec_reply is None: return None @@ -803,11 +857,9 @@ def _check_raise_for_error( or exec_reply_content.get('ename') in self.allow_error_names or "raises-exception" in cell.metadata.get("tags", []) ) - - if (exec_reply is not None) and exec_reply['content']['status'] == 'error': - run_hook(self.on_cell_error, cell=cell, cell_index=cell_index) - if self.force_raise_errors or not cell_allows_errors: - raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content']) + await run_hook(self.on_cell_error, cell=cell, cell_index=cell_index) + if not cell_allows_errors: + raise CellExecutionError.from_cell_and_msg(cell, exec_reply_content) async def async_execute_cell( self, @@ -850,6 +902,9 @@ async def async_execute_cell( The cell which was just processed. 
""" assert self.kc is not None + + await run_hook(self.on_cell_start, cell=cell, cell_index=cell_index) + if cell.cell_type != 'code' or not cell.source.strip(): self.log.debug("Skipping non-executing cell %s", cell_index) return cell @@ -867,13 +922,13 @@ async def async_execute_cell( self.allow_errors or "raises-exception" in cell.metadata.get("tags", []) ) - run_hook(self.on_cell_start, cell=cell, cell_index=cell_index) + await run_hook(self.on_cell_execute, cell=cell, cell_index=cell_index) parent_msg_id = await ensure_async( self.kc.execute( cell.source, store_history=store_history, stop_on_error=not cell_allows_errors ) ) - run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index) + await run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index) # We launched a code cell to execute self.code_cells_executed += 1 exec_timeout = self._get_timeout(cell) @@ -907,7 +962,7 @@ async def async_execute_cell( if execution_count: cell['execution_count'] = execution_count - self._check_raise_for_error(cell, cell_index, exec_reply) + await self._check_raise_for_error(cell, cell_index, exec_reply) self.nb['cells'][cell_index] = cell return cell diff --git a/nbclient/tests/test_client.py b/nbclient/tests/test_client.py index 1fde7807..bf777802 100644 --- a/nbclient/tests/test_client.py +++ b/nbclient/tests/test_client.py @@ -33,6 +33,16 @@ # see: https://github.com/ipython/ipython/blob/master/docs/source/whatsnew/version8.rst#traceback-improvements # noqa ipython8_input_pat = re.compile(r'Input In \[\d+\],') +hook_methods = [ + "on_cell_start", + "on_cell_execute", + "on_cell_complete", + "on_cell_error", + "on_notebook_start", + "on_notebook_complete", + "on_notebook_error", +] + class AsyncMock(Mock): pass @@ -741,6 +751,82 @@ def test_widgets(self): assert 'version_major' in wdata assert 'version_minor' in wdata + def test_execution_hook(self): + filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb') + with open(filename) as f: + input_nb = 
nbformat.read(f, 4) + hooks = [MagicMock() for i in range(7)] + executor = NotebookClient(input_nb) + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + executor.execute() + for hook in hooks[:3]: + hook.assert_called_once() + hooks[3].assert_not_called() + for hook in hooks[4:6]: + hook.assert_called_once() + hooks[6].assert_not_called() + + def test_error_execution_hook_error(self): + filename = os.path.join(current_dir, 'files', 'Error.ipynb') + with open(filename) as f: + input_nb = nbformat.read(f, 4) + hooks = [MagicMock() for i in range(7)] + executor = NotebookClient(input_nb) + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + with pytest.raises(CellExecutionError): + executor.execute() + for hook in hooks[:5]: + hook.assert_called_once() + hooks[6].assert_not_called() + + def test_error_notebook_hook(self): + filename = os.path.join(current_dir, 'files', 'Autokill.ipynb') + with open(filename) as f: + input_nb = nbformat.read(f, 4) + hooks = [MagicMock() for i in range(7)] + executor = NotebookClient(input_nb) + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + with pytest.raises(RuntimeError): + executor.execute() + for hook in hooks[:3]: + hook.assert_called_once() + hooks[3].assert_not_called() + for hook in hooks[4:]: + hook.assert_called_once() + + def test_async_execution_hook(self): + filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb') + with open(filename) as f: + input_nb = nbformat.read(f, 4) + hooks = [AsyncMock() for i in range(7)] + executor = NotebookClient(input_nb) + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + executor.execute() + for hook in hooks[:3]: + hook.assert_called_once() + hooks[3].assert_not_called() + for hook in hooks[4:6]: + hook.assert_called_once() + hooks[6].assert_not_called() + + def test_error_async_execution_hook(self): 
+ filename = os.path.join(current_dir, 'files', 'Error.ipynb') + with open(filename) as f: + input_nb = nbformat.read(f, 4) + hooks = [AsyncMock() for i in range(7)] + executor = NotebookClient(input_nb) + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + with pytest.raises(CellExecutionError): + executor.execute().execute() + for hook in hooks[:5]: + hook.assert_called_once() + hooks[6].assert_not_called() + class TestRunCell(NBClientTestsBase): """Contains test functions for NotebookClient.execute_cell""" @@ -1524,3 +1610,92 @@ def test_no_source(self, executor, cell_mock, message_mock): assert message_mock.call_count == 0 # Should also consume the message stream assert cell_mock.outputs == [] + + @prepare_cell_mocks() + def test_cell_hooks(self, executor, cell_mock, message_mock): + hooks = [MagicMock() for i in range(7)] + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + executor.execute_cell(cell_mock, 0) + for hook in hooks[:3]: + hook.assert_called_once_with(cell=cell_mock, cell_index=0) + for hook in hooks[4:]: + hook.assert_not_called() + + @prepare_cell_mocks( + { + 'msg_type': 'error', + 'header': {'msg_type': 'error'}, + 'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']}, + }, + reply_msg={ + 'msg_type': 'execute_reply', + 'header': {'msg_type': 'execute_reply'}, + # ERROR + 'content': {'status': 'error'}, + }, + ) + def test_error_cell_hooks(self, executor, cell_mock, message_mock): + hooks = [MagicMock() for i in range(7)] + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + with self.assertRaises(CellExecutionError): + executor.execute_cell(cell_mock, 0) + for hook in hooks[:4]: + hook.assert_called_once_with(cell=cell_mock, cell_index=0) + for hook in hooks[5:]: + hook.assert_not_called() + + @prepare_cell_mocks( + reply_msg={ + 'msg_type': 'execute_reply', + 'header': {'msg_type': 
'execute_reply'}, + # ERROR + 'content': {'status': 'error'}, + } + ) + def test_non_code_cell_hooks(self, executor, cell_mock, message_mock): + cell_mock = NotebookNode(source='"foo" = "bar"', metadata={}, cell_type='raw', outputs=[]) + hooks = [MagicMock() for i in range(7)] + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + executor.execute_cell(cell_mock, 0) + for hook in hooks[:1]: + hook.assert_called_once_with(cell=cell_mock, cell_index=0) + for hook in hooks[1:]: + hook.assert_not_called() + + @prepare_cell_mocks() + def test_async_cell_hooks(self, executor, cell_mock, message_mock): + hooks = [AsyncMock() for i in range(7)] + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + executor.execute_cell(cell_mock, 0) + for hook in hooks[:3]: + hook.assert_called_once_with(cell=cell_mock, cell_index=0) + for hook in hooks[4:]: + hook.assert_not_called() + + @prepare_cell_mocks( + { + 'msg_type': 'error', + 'header': {'msg_type': 'error'}, + 'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']}, + }, + reply_msg={ + 'msg_type': 'execute_reply', + 'header': {'msg_type': 'execute_reply'}, + # ERROR + 'content': {'status': 'error'}, + }, + ) + def test_error_async_cell_hooks(self, executor, cell_mock, message_mock): + hooks = [AsyncMock() for i in range(7)] + for executor_hook, hook in zip(hook_methods, hooks): + setattr(executor, executor_hook, hook) + with self.assertRaises(CellExecutionError): + executor.execute_cell(cell_mock, 0) + for hook in hooks[:4]: + hook.assert_called_once_with(cell=cell_mock, cell_index=0) + for hook in hooks[4:]: + hook.assert_not_called() diff --git a/nbclient/tests/util.py b/nbclient/tests/test_util.py similarity index 79% rename from nbclient/tests/util.py rename to nbclient/tests/test_util.py index 2e864b85..a40e48d2 100644 --- a/nbclient/tests/util.py +++ b/nbclient/tests/test_util.py @@ -1,8 +1,10 @@ import asyncio +from 
unittest.mock import MagicMock +import pytest import tornado -from nbclient.util import run_sync +from nbclient.util import run_hook, run_sync @run_sync @@ -55,3 +57,17 @@ async def run(): assert some_sync_function() == 42 ioloop.run_sync(run) + + +@pytest.mark.asyncio +async def test_run_hook_sync(): + some_sync_function = MagicMock() + await run_hook(some_sync_function) + assert some_sync_function.call_count == 1 + + +@pytest.mark.asyncio +async def test_run_hook_async(): + hook = MagicMock(return_value=some_async_function()) + await run_hook(hook) + assert hook.call_count == 1 diff --git a/nbclient/util.py b/nbclient/util.py index 6221bc2c..45901a6a 100644 --- a/nbclient/util.py +++ b/nbclient/util.py @@ -7,7 +7,6 @@ import inspect import sys from typing import Any, Awaitable, Callable, Optional, Union -from functools import partial def check_ipython() -> None: @@ -105,13 +104,9 @@ async def ensure_async(obj: Union[Awaitable, Any]) -> Any: return obj -def run_hook(hook: Optional[Callable], **kwargs) -> None: +async def run_hook(hook: Optional[Callable], **kwargs) -> None: if hook is None: return - if inspect.iscoroutinefunction(hook): - future = hook(**kwargs) - else: - loop = asyncio.get_event_loop() - hook_with_kwargs = partial(hook, **kwargs) - future = loop.run_in_executor(None, hook_with_kwargs) - asyncio.ensure_future(future) + res = hook(**kwargs) + if inspect.isawaitable(res): + await res diff --git a/requirements-dev.txt b/requirements-dev.txt index 941626b5..556e4ae7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ ipython ipykernel ipywidgets<8.0.0 pytest>=4.1 +pytest-asyncio pytest-cov>=2.6.1 check-manifest flake8