Skip to content

Commit

Permalink
Merge pull request #900 from maartenbreddels/widget_state
Browse files Browse the repository at this point in the history
Embed widget state in notebook on execute
  • Loading branch information
MSeal authored Mar 27, 2019
2 parents 914be01 + 9877c9e commit 1fc3968
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 3 deletions.
61 changes: 59 additions & 2 deletions nbconvert/preprocessors/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.

import base64
from textwrap import dedent
from contextlib import contextmanager

Expand Down Expand Up @@ -172,6 +172,15 @@ class ExecutePreprocessor(Preprocessor):
)
).tag(config=True)

store_widget_state = Bool(True,
help=dedent(
"""
If `True` (default), then the state of the Jupyter widgets created
at the kernel will be stored in the metadata of the notebook.
"""
)
).tag(config=True)

iopub_timeout = Integer(4, allow_none=False,
help=dedent(
"""
Expand Down Expand Up @@ -292,6 +301,8 @@ def setup_preprocessor(self, nb, resources, km=None):
self.nb = nb
# clear display_id map
self._display_id_map = {}
self.widget_state = {}
self.widget_buffers = {}

if km is None:
self.km, self.kc = self.start_new_kernel(cwd=path)
Expand Down Expand Up @@ -354,9 +365,27 @@ def preprocess(self, nb, resources, km=None):
nb, resources = super(ExecutePreprocessor, self).preprocess(nb, resources)
info_msg = self._wait_for_reply(self.kc.kernel_info())
nb.metadata['language_info'] = info_msg['content']['language_info']
self.set_widgets_metadata()

return nb, resources

def set_widgets_metadata(self):
if self.widget_state:
self.nb.metadata.widgets = {
'application/vnd.jupyter.widget-state+json': {
'state': {
model_id: _serialize_widget_state(state)
for model_id, state in self.widget_state.items() if '_model_name' in state
},
'version_major': 2,
'version_minor': 0,
}
}
for key, widget in self.nb.metadata.widgets['application/vnd.jupyter.widget-state+json']['state'].items():
buffers = self.widget_buffers.get(key)
if buffers:
widget['buffers'] = buffers

def preprocess_cell(self, cell, resources, cell_index):
"""
Executes a single code cell. See base.py for details.
Expand Down Expand Up @@ -550,7 +579,12 @@ def clear_display_id_mapping(self, cell_index):
cell_map[cell_index] = []

def handle_comm_msg(self, outs, msg, cell_index):
pass
content = msg['content']
data = content['data']
if self.store_widget_state and 'state' in data: # ignore custom msg'es
self.widget_state.setdefault(content['comm_id'], {}).update(data['state'])
if 'buffer_paths' in data and data['buffer_paths']:
self.widget_buffers[content['comm_id']] = _get_buffer_data(msg)

def executenb(nb, cwd=None, km=None, **kwargs):
"""Execute a notebook's code, updating outputs within the notebook object.
Expand All @@ -574,3 +608,26 @@ def executenb(nb, cwd=None, km=None, **kwargs):
resources['metadata'] = {'path': cwd}
ep = ExecutePreprocessor(**kwargs)
return ep.preprocess(nb, resources, km=km)[0]


def _serialize_widget_state(state):
"""Serialize a widget state, following format in @jupyter-widgets/schema."""
return {
'model_name': state.get('_model_name'),
'model_module': state.get('_model_module'),
'model_module_version': state.get('_model_module_version'),
'state': state,
}


def _get_buffer_data(msg):
encoded_buffers = []
paths = msg['content']['data']['buffer_paths']
buffers = msg['buffers']
for path, buffer in zip(paths, buffers):
encoded_buffers.append({
'data': base64.b64encode(buffer).decode('utf-8'),
'encoding': 'base64',
'path': path
})
return encoded_buffers
94 changes: 94 additions & 0 deletions nbconvert/preprocessors/tests/files/JupyterWidgets.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f46f26da84b54255bccc3a69d7eb08de",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Label(value='Hello World')"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import ipywidgets\n",
"label = ipywidgets.Label('Hello World')\n",
"label"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# it should also handle custom msg'es\n",
"label.send({'msg': 'Hello'})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {
"8273e8fe9d9941a4a63c062158e0a630": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.4.0",
"model_name": "DescriptionStyleModel",
"state": {
"description_width": ""
}
},
"a72770a4f541425f8fe85833a3dc2a8e": {
"model_module": "@jupyter-widgets/controls",
"model_module_version": "1.4.0",
"model_name": "LabelModel",
"state": {
"context_menu": null,
"layout": "IPY_MODEL_dec20f599109458ca607b1df5959469b",
"style": "IPY_MODEL_8273e8fe9d9941a4a63c062158e0a630",
"value": "Hello World"
}
},
"dec20f599109458ca607b1df5959469b": {
"model_module": "@jupyter-widgets/base",
"model_module_version": "1.1.0",
"model_name": "LayoutModel",
"state": {}
}
},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
35 changes: 35 additions & 0 deletions nbconvert/preprocessors/tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,13 @@ def normalize_output(output):
if 'text/plain' in output.get('data', {}):
output['data']['text/plain'] = \
re.sub(addr_pat, '<HEXADDR>', output['data']['text/plain'])
if 'application/vnd.jupyter.widget-view+json' in output.get('data', {}):
output['data']['application/vnd.jupyter.widget-view+json'] \
['model_id'] = '<MODEL_ID>'
for key, value in output.get('data', {}).items():
if isinstance(value, string_types):
if sys.version_info.major == 2:
value = value.replace('u\'', '\'')
output['data'][key] = _normalize_base64(value)
if 'traceback' in output:
tb = [
Expand Down Expand Up @@ -305,3 +310,33 @@ def test_execute_function(self):
original = copy.deepcopy(input_nb)
executed = executenb(original, os.path.dirname(filename))
self.assert_notebooks_equal(original, executed)

def test_widgets(self):
"""Runs a test notebook with widgets and checks the widget state is saved."""
input_file = os.path.join(current_dir, 'files', 'JupyterWidgets.ipynb')
opts = dict(kernel_name="python")
res = self.build_resources()
res['metadata']['path'] = os.path.dirname(input_file)
input_nb, output_nb = self.run_notebook(input_file, opts, res)

output_data = [
output.get('data', {})
for cell in output_nb['cells']
for output in cell['outputs']
]

model_ids = [
data['application/vnd.jupyter.widget-view+json']['model_id']
for data in output_data
if 'application/vnd.jupyter.widget-view+json' in data
]

wdata = output_nb['metadata']['widgets'] \
['application/vnd.jupyter.widget-state+json']
for k in model_ids:
d = wdata['state'][k]
assert 'model_name' in d
assert 'model_module' in d
assert 'state' in d
assert 'version_major' in wdata
assert 'version_minor' in wdata
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def run(self):
jupyter_client_req = 'jupyter_client>=4.2'

extra_requirements = {
'test': ['pytest', 'pytest-cov', 'ipykernel', jupyter_client_req],
'test': ['pytest', 'pytest-cov', 'ipykernel', jupyter_client_req, 'ipywidgets>=7'],
'serve': ['tornado>=4.0'],
'execute': [jupyter_client_req],
'docs': ['sphinx>=1.5.1',
Expand Down

0 comments on commit 1fc3968

Please sign in to comment.