Skip to content

Commit

Permalink
feat: Hoist modules
Browse files Browse the repository at this point in the history
  • Loading branch information
manzt committed Jul 24, 2024
1 parent e1d230b commit 3d9e9e8
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 21 deletions.
9 changes: 2 additions & 7 deletions anywidget/_static_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import pathlib
import typing

import traitlets as t

from anywidget._file_contents import VirtualFileContents

from ._descriptor import open_comm
Expand Down Expand Up @@ -56,8 +54,5 @@ def __del__(self) -> None:
"""Close the comm when the asset is deleted."""
self._comm.close()

def as_traittype(self) -> t.TraitType:
"""Return a traitlet that represents the asset."""
return t.Instance(StaticAsset, default_value=self).tag(
sync=True, to_json=lambda *_: "anywidget-static-asset:" + self._comm.comm_id
)
def serialize(self) -> str:
return f"anywidget-static-asset:{self._comm.comm_id}"
59 changes: 45 additions & 14 deletions anywidget/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import typing
from contextlib import contextmanager

import ipywidgets
import traitlets.traitlets as t
Expand Down Expand Up @@ -36,30 +37,29 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
if in_colab():
enable_custom_widget_manager_once()

anywidget_traits = {}
self._anywidget_internal_state = {}
for key in (_ESM_KEY, _CSS_KEY):
if hasattr(self, key) and not self.has_trait(key):
self._anywidget_internal_state[key] = getattr(self, key)

if not hasattr(self, _ESM_KEY):
anywidget_traits[_ESM_KEY] = StaticAsset(_DEFAULT_ESM).as_traittype()
self._anywidget_internal_state[_ESM_KEY] = _DEFAULT_ESM

self._anywidget_internal_state[_ANYWIDGET_ID_KEY] = _id_for(self)

# TODO: a better way to uniquely identify this subclasses?
# We use the fully-qualified name to get an id which we
# can use to update CSS if necessary.
anywidget_traits[_ANYWIDGET_ID_KEY] = t.Unicode(
f"{self.__class__.__module__}.{self.__class__.__name__}"
).tag(sync=True)
with _patch_get_state(self, self._anywidget_internal_state):
super().__init__(*args, **kwargs)

self.add_traits(**anywidget_traits)
super().__init__(*args, **kwargs)
_register_anywidget_commands(self)

def __init_subclass__(cls, **kwargs: dict) -> None:
"""Coerces _esm and _css to FileContents if they are files."""
super().__init_subclass__(**kwargs)
for key in (_ESM_KEY, _CSS_KEY) & cls.__dict__.keys():
# TODO: Upgrate to := when we drop Python 3.7
value = getattr(cls, key)
if isinstance(value, t.TraitType):
# we don't know how to handle this
continue
setattr(cls, key, StaticAsset(value).as_traittype())
if not isinstance(value, StaticAsset):
setattr(cls, key, StaticAsset(value))
_collect_anywidget_commands(cls)

def _repr_mimebundle_(self, **kwargs: dict) -> tuple[dict, dict] | None:
Expand All @@ -69,3 +69,34 @@ def _repr_mimebundle_(self, **kwargs: dict) -> tuple[dict, dict] | None:
if self._view_name is None:
return None # type: ignore[unreachable]
return repr_mimebundle(model_id=self.model_id, repr_text=plaintext)


def _id_for(obj: typing.Any) -> str:
"""Return a unique identifier for an object."""
# TODO: a better way to uniquely identify this subclasses?
# We use the fully-qualified name to get an id which we
# can use to update CSS if necessary.
return f"{obj.__class__.__module__}.{obj.__class__.__name__}"


@contextmanager
def _patch_get_state(
widget: AnyWidget, extra_state: dict[str, str | StaticAsset]
) -> typing.Generator[None, None, None]:
"""Patch get_state to include anywidget-specific data."""
original_get_state = widget.get_state

def temp_get_state():
return {
**original_get_state(),
**{
k: v.serialize() if isinstance(v, StaticAsset) else v
for k, v in extra_state.items()
},
}

widget.get_state = temp_get_state
try:
yield
finally:
widget.get_state = original_get_state
22 changes: 22 additions & 0 deletions packages/anywidget/src/widget.js
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,22 @@ class Runtime {
}
}

let anywidget_static_asset = {
/** @param {{ model_id: string }} model */
serialize(model) {
return `anywidget-static-asset:${model.model_id}`;
},
/**
* @param {string} value
* @param {import("@jupyter-widgets/base").DOMWidgetModel["widget_manager"]} widget_manager
*/
async deserialize(value, widget_manager) {
let model_id = value.slice("anywidget-static-asset:".length);
let model = await widget_manager.get_model(model_id);
return model;
},
};

// @ts-expect-error - injected by bundler
let version = globalThis.VERSION;

Expand Down Expand Up @@ -498,6 +514,12 @@ export default function ({ DOMWidgetModel, DOMWidgetView }) {
RUNTIMES.set(this, runtime);
}

static serializers = {
...DOMWidgetModel.serializers,
_esm: anywidget_static_asset,
_css: anywidget_static_asset,
};

/**
* @param {Record<string, any>} state
*
Expand Down

0 comments on commit 3d9e9e8

Please sign in to comment.