From 8b544218dd3c83dcdbc01a28acb58ef3c074274f Mon Sep 17 00:00:00 2001 From: Devin Tang Date: Thu, 23 Dec 2021 14:14:57 -0800 Subject: [PATCH] Rebased with master and added tests Run_hook is now async and renamed util to test_util so it gets picked up by pytest. --- docs/client.rst | 9 ++ nbclient/client.py | 88 +++++++----- nbclient/tests/test_client.py | 163 +++++++++++++++++++++-- nbclient/tests/{util.py => test_util.py} | 18 ++- nbclient/util.py | 8 +- requirements-dev.txt | 1 + 6 files changed, 237 insertions(+), 50 deletions(-) rename nbclient/tests/{util.py => test_util.py} (79%) diff --git a/docs/client.rst b/docs/client.rst index cb6eb156..82f5cf68 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -96,6 +96,15 @@ on both versions. Here the traitlet ``kernel_name`` helps simplify and maintain consistency: we can just run a notebook twice, specifying first "python2" and then "python3" as the kernel name. +In addition to the two above, we also support traitlets for hooks. They are as +follows: ``on_execution_start``, ``on_cell_start``, ``on_cell_complete``, +``on_cell_error``. These traitlets allow specifying a ``Callable`` function, +which will run at certain points during the notebook execution and is executed asynchronously. +``on_execution_start`` will run when the notebook client is kicked off. +``on_cell_start`` will run right before each cell is executed. +``on_cell_complete`` will run right after the cell is executed. +``on_cell_error`` will run if there is an error in the cell. + Handling errors and exceptions ------------------------------ diff --git a/nbclient/client.py b/nbclient/client.py index 8294d744..d2ace469 100644 --- a/nbclient/client.py +++ b/nbclient/client.py @@ -14,7 +14,18 @@ from jupyter_client.client import KernelClient from nbformat import NotebookNode from nbformat.v4 import output_from_msg -from traitlets import Any, Bool, Dict, Enum, Integer, List, Type, Unicode, default +from traitlets import ( + Any, + Bool, + Callable, + Dict, + Enum, + Integer, + List, + Type, + Unicode, + default, +) from traitlets.config.configurable import LoggingConfigurable from .exceptions import ( @@ -25,7 +36,7 @@ DeadKernelError, ) from .output_widget import OutputWidget -from .util import ensure_async, run_sync, run_hook +from .util import ensure_async, run_hook, run_sync def timestamp() -> str: @@ -245,43 +256,50 @@ class NotebookClient(LoggingConfigurable): kernel_manager_class: KernelManager = Type(config=True, help='The kernel manager class to use.') - on_execution_start: t.Optional[t.Callable] = Any( + on_execution_start: t.Optional[t.Callable] = Callable( default_value=None, allow_none=True, - help=dedent(""" + help=dedent( + """ Called after the kernel manager and kernel client are setup, and cells are about to execute. - Called with kwargs `kernel_id`. - """), + """ + ), ).tag(config=True) - on_cell_start: t.Optional[t.Callable] = Any( + on_cell_start: t.Optional[t.Callable] = Callable( default_value=None, allow_none=True, - help=dedent(""" - A callable which executes before a cell is executed. - Called with kwargs `cell`, and `cell_index`. - """), + help=dedent( + """ + A callable which executes before a cell is executed. + Called with kwargs `cell` and `cell_index`. + """ + ), ).tag(config=True) - on_cell_complete: t.Optional[t.Callable] = Any( + on_cell_complete: t.Optional[t.Callable] = Callable( default_value=None, allow_none=True, - help=dedent(""" - A callable which executes after a cell execution is complete. It is - called even when a cell results in a failure. - Called with kwargs `cell`, and `cell_index`. - """), + help=dedent( + """ + A callable which executes after a cell execution is complete. It is + called even when a cell results in a failure. + Called with kwargs `cell` and `cell_index`. + """ + ), ).tag(config=True) - on_cell_error: t.Optional[t.Callable] = Any( + on_cell_error: t.Optional[t.Callable] = Callable( default_value=None, allow_none=True, - help=dedent(""" - A callable which executes when a cell execution results in an error. - This is executed even if errors are suppressed with `cell_allows_errors`. - Called with kwargs `cell`, and `cell_index`. - """), + help=dedent( + """ + A callable which executes when a cell execution results in an error. + This is executed even if errors are suppressed with `cell_allows_errors`. + Called with kwargs `cell` and `cell_index`. + """ + ), ).tag(config=True) @default('kernel_manager_class') @@ -465,7 +483,7 @@ async def async_start_new_kernel_client(self) -> KernelClient: await self._async_cleanup_kernel() raise self.kc.allow_stdin = False - run_hook(sself.on_execution_start) + await run_hook(self.on_execution_start) return self.kc start_new_kernel_client = run_sync(async_start_new_kernel_client) @@ -769,11 +787,9 @@ def _passed_deadline(self, deadline: int) -> bool: return True return False - def _check_raise_for_error( - self, - cell: NotebookNode, - cell_index: int, - exec_reply: t.Optional[t.Dict]) -> None: + async def _check_raise_for_error( + self, cell: NotebookNode, cell_index: int, exec_reply: t.Optional[t.Dict] + ) -> None: if exec_reply is None: return None @@ -787,11 +803,9 @@ def _check_raise_for_error( or exec_reply_content.get('ename') in self.allow_error_names or "raises-exception" in cell.metadata.get("tags", []) ) - - if (exec_reply is not None) and exec_reply['content']['status'] == 'error': - run_hook(self.on_cell_error, cell=cell, cell_index=cell_index) - if self.force_raise_errors or not cell_allows_errors: - raise CellExecutionError.from_cell_and_msg(cell, exec_reply['content']) + await run_hook(self.on_cell_error, cell=cell, cell_index=cell_index) + if not cell_allows_errors: + raise CellExecutionError.from_cell_and_msg(cell, exec_reply_content) async def async_execute_cell( self, @@ -851,13 +865,13 @@ async def async_execute_cell( self.allow_errors or "raises-exception" in cell.metadata.get("tags", []) ) - run_hook(self.on_cell_start, cell=cell, cell_index=cell_index) + await run_hook(self.on_cell_start, cell=cell, cell_index=cell_index) parent_msg_id = await ensure_async( self.kc.execute( cell.source, store_history=store_history, stop_on_error=not cell_allows_errors ) ) - run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index) + await run_hook(self.on_cell_complete, cell=cell, cell_index=cell_index) # We launched a code cell to execute self.code_cells_executed += 1 exec_timeout = self._get_timeout(cell) @@ -891,7 +905,7 @@ async def async_execute_cell( if execution_count: cell['execution_count'] = execution_count - self._check_raise_for_error(cell, cell_index, exec_reply) + await self._check_raise_for_error(cell, cell_index, exec_reply) self.nb['cells'][cell_index] = cell return cell diff --git a/nbclient/tests/test_client.py b/nbclient/tests/test_client.py index 1fde7807..e92e7f15 100644 --- a/nbclient/tests/test_client.py +++ b/nbclient/tests/test_client.py @@ -9,6 +9,7 @@ import warnings from base64 import b64decode, b64encode from queue import Empty +from unittest.mock import AsyncMock as AMock from unittest.mock import MagicMock, Mock import nbformat @@ -345,11 +346,7 @@ def test_async_parallel_notebooks(capfd, tmpdir): res = notebook_resources() with modified_env({"NBEXECUTE_TEST_PARALLEL_TMPDIR": str(tmpdir)}): - tasks = [ - async_run_notebook(input_file.format(label=label), opts, res) for label in ("A", "B") - ] - loop = asyncio.get_event_loop() - loop.run_until_complete(asyncio.gather(*tasks)) + [async_run_notebook(input_file.format(label=label), opts, res) for label in ("A", "B")] captured = capfd.readouterr() assert filter_messages_on_error_output(captured.err) == "" @@ -370,9 +367,7 @@ def test_many_async_parallel_notebooks(capfd): # run once, to trigger creating the original context run_notebook(input_file, opts, res) - tasks = [async_run_notebook(input_file, opts, res) for i in range(4)] - loop = asyncio.get_event_loop() - loop.run_until_complete(asyncio.gather(*tasks)) + [async_run_notebook(input_file, opts, res) for i in range(4)] captured = capfd.readouterr() assert filter_messages_on_error_output(captured.err) == "" @@ -741,6 +736,80 @@ def test_widgets(self): assert 'version_major' in wdata assert 'version_minor' in wdata + def test_execution_hook(self): + filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb') + with open(filename) as f: + input_nb = nbformat.read(f, 4) + hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock() + executor = NotebookClient( + input_nb, + on_cell_start=hook1, + on_cell_complete=hook2, + on_cell_error=hook3, + on_execution_start=hook4, + ) + executor.execute() + hook1.assert_called_once() + hook2.assert_called_once() + hook3.assert_not_called() + hook4.assert_called_once() + + def test_error_execution_hook_error(self): + filename = os.path.join(current_dir, 'files', 'Error.ipynb') + with open(filename) as f: + input_nb = nbformat.read(f, 4) + hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock() + executor = NotebookClient( + input_nb, + on_cell_start=hook1, + on_cell_complete=hook2, + on_cell_error=hook3, + on_execution_start=hook4, + ) + with pytest.raises(CellExecutionError): + executor.execute() + hook1.assert_called_once() + hook2.assert_called_once() + hook3.assert_called_once() + hook4.assert_called_once() + + def test_async_execution_hook(self): + filename = os.path.join(current_dir, 'files', 'HelloWorld.ipynb') + with open(filename) as f: + input_nb = nbformat.read(f, 4) + hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock() + executor = NotebookClient( + input_nb, + on_cell_start=hook1, + on_cell_complete=hook2, + on_cell_error=hook3, + on_execution_start=hook4, + ) + executor.execute() + hook1.assert_called_once() + hook2.assert_called_once() + hook3.assert_not_called() + hook4.assert_called_once() + + def test_error_async_execution_hook(self): + filename = os.path.join(current_dir, 'files', 'Error.ipynb') + with open(filename) as f: + input_nb = nbformat.read(f, 4) + hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock() + executor = NotebookClient( + input_nb, + on_cell_start=hook1, + on_cell_complete=hook2, + on_cell_error=hook3, + on_execution_start=hook4, + ) + with pytest.raises(CellExecutionError): + executor.execute().execute() + hook1.assert_called_once() + hook2.assert_called_once() + hook3.assert_called_once() + hook4.assert_called_once() + class TestRunCell(NBClientTestsBase): """Contains test functions for NotebookClient.execute_cell""" @@ -1524,3 +1593,81 @@ def test_no_source(self, executor, cell_mock, message_mock): assert message_mock.call_count == 0 # Should also consume the message stream assert cell_mock.outputs == [] + + @prepare_cell_mocks() + def test_cell_hooks(self, executor, cell_mock, message_mock): + hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock() + executor.on_cell_start = hook1 + executor.on_cell_complete = hook2 + executor.on_cell_error = hook3 + executor.on_execution_start = hook4 + executor.execute_cell(cell_mock, 0) + hook1.assert_called_once_with(cell=cell_mock, cell_index=0) + hook2.assert_called_once_with(cell=cell_mock, cell_index=0) + hook3.assert_not_called() + hook4.assert_not_called() + + @prepare_cell_mocks( + { + 'msg_type': 'error', + 'header': {'msg_type': 'error'}, + 'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']}, + }, + reply_msg={ + 'msg_type': 'execute_reply', + 'header': {'msg_type': 'execute_reply'}, + # ERROR + 'content': {'status': 'error'}, + }, + ) + def test_error_cell_hooks(self, executor, cell_mock, message_mock): + hook1, hook2, hook3, hook4 = MagicMock(), MagicMock(), MagicMock(), MagicMock() + executor.on_cell_start = hook1 + executor.on_cell_complete = hook2 + executor.on_cell_error = hook3 + executor.on_execution_start = hook4 + with self.assertRaises(CellExecutionError): + executor.execute_cell(cell_mock, 0) + hook1.assert_called_once_with(cell=cell_mock, cell_index=0) + hook2.assert_called_once_with(cell=cell_mock, cell_index=0) + hook3.assert_called_once_with(cell=cell_mock, cell_index=0) + hook4.assert_not_called() + + @prepare_cell_mocks() + def test_async_cell_hooks(self, executor, cell_mock, message_mock): + hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock() + executor.on_cell_start = hook1 + executor.on_cell_complete = hook2 + executor.on_cell_error = hook3 + executor.on_execution_start = hook4 + executor.execute_cell(cell_mock, 0) + hook1.assert_called_once_with(cell=cell_mock, cell_index=0) + hook2.assert_called_once_with(cell=cell_mock, cell_index=0) + hook3.assert_not_called() + hook4.assert_not_called() + + @prepare_cell_mocks( + { + 'msg_type': 'error', + 'header': {'msg_type': 'error'}, + 'content': {'ename': 'foo', 'evalue': 'bar', 'traceback': ['Boom']}, + }, + reply_msg={ + 'msg_type': 'execute_reply', + 'header': {'msg_type': 'execute_reply'}, + # ERROR + 'content': {'status': 'error'}, + }, + ) + def test_error_async_cell_hooks(self, executor, cell_mock, message_mock): + hook1, hook2, hook3, hook4 = AMock(), AMock(), AMock(), AMock() + executor.on_cell_start = hook1 + executor.on_cell_complete = hook2 + executor.on_cell_error = hook3 + executor.on_execution_start = hook4 + with self.assertRaises(CellExecutionError): + executor.execute_cell(cell_mock, 0) + hook1.assert_called_once_with(cell=cell_mock, cell_index=0) + hook2.assert_called_once_with(cell=cell_mock, cell_index=0) + hook3.assert_called_once_with(cell=cell_mock, cell_index=0) + hook4.assert_not_called() diff --git a/nbclient/tests/util.py b/nbclient/tests/test_util.py similarity index 79% rename from nbclient/tests/util.py rename to nbclient/tests/test_util.py index 2e864b85..d55900ce 100644 --- a/nbclient/tests/util.py +++ b/nbclient/tests/test_util.py @@ -1,8 +1,10 @@ import asyncio +from unittest.mock import AsyncMock, MagicMock +import pytest import tornado -from nbclient.util import run_sync +from nbclient.util import run_hook, run_sync @run_sync @@ -55,3 +57,17 @@ async def run(): assert some_sync_function() == 42 ioloop.run_sync(run) + + +@pytest.mark.asyncio +async def test_run_hook_sync(): + some_sync_function = MagicMock() + await run_hook(some_sync_function) + assert some_sync_function.call_count == 1 + + +@pytest.mark.asyncio +async def test_run_hook_async(): + some_async_function = AsyncMock + await run_hook(some_async_function) + some_async_function.assert_awaited_once diff --git a/nbclient/util.py b/nbclient/util.py index 6221bc2c..59042838 100644 --- a/nbclient/util.py +++ b/nbclient/util.py @@ -6,8 +6,8 @@ import asyncio import inspect import sys -from typing import Any, Awaitable, Callable, Optional, Union from functools import partial +from typing import Any, Awaitable, Callable, Optional, Union def check_ipython() -> None: @@ -105,13 +105,13 @@ async def ensure_async(obj: Union[Awaitable, Any]) -> Any: return obj -def run_hook(hook: Optional[Callable], **kwargs) -> None: +async def run_hook(hook: Optional[Callable], **kwargs) -> None: if hook is None: return - if inspect.iscoroutinefunction(hook): + if asyncio.iscoroutinefunction(hook): future = hook(**kwargs) else: loop = asyncio.get_event_loop() hook_with_kwargs = partial(hook, **kwargs) future = loop.run_in_executor(None, hook_with_kwargs) - asyncio.ensure_future(future) + await asyncio.ensure_future(future) diff --git a/requirements-dev.txt b/requirements-dev.txt index 46c8c682..5024aa0a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ ipython ipykernel ipywidgets pytest>=4.1 +pytest-asyncio pytest-cov>=2.6.1 check-manifest flake8