From 3d9e9e89bf192e786398d540013581631968bc7f Mon Sep 17 00:00:00 2001 From: Trevor Manz Date: Wed, 24 Jul 2024 14:38:13 -0400 Subject: [PATCH] feat: Hoist modules --- anywidget/_static_asset.py | 9 ++--- anywidget/widget.py | 59 ++++++++++++++++++++++++-------- packages/anywidget/src/widget.js | 22 ++++++++++++ 3 files changed, 69 insertions(+), 21 deletions(-) diff --git a/anywidget/_static_asset.py b/anywidget/_static_asset.py index e806ef88..500b01be 100644 --- a/anywidget/_static_asset.py +++ b/anywidget/_static_asset.py @@ -3,8 +3,6 @@ import pathlib import typing -import traitlets as t - from anywidget._file_contents import VirtualFileContents from ._descriptor import open_comm @@ -56,8 +54,5 @@ def __del__(self) -> None: """Close the comm when the asset is deleted.""" self._comm.close() - def as_traittype(self) -> t.TraitType: - """Return a traitlet that represents the asset.""" - return t.Instance(StaticAsset, default_value=self).tag( - sync=True, to_json=lambda *_: "anywidget-static-asset:" + self._comm.comm_id - ) + def serialize(self) -> str: + return f"anywidget-static-asset:{self._comm.comm_id}" diff --git a/anywidget/widget.py b/anywidget/widget.py index c52b38d4..3cd298c5 100644 --- a/anywidget/widget.py +++ b/anywidget/widget.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing +from contextlib import contextmanager import ipywidgets import traitlets.traitlets as t @@ -36,30 +37,29 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: if in_colab(): enable_custom_widget_manager_once() - anywidget_traits = {} + self._anywidget_internal_state = {} + for key in (_ESM_KEY, _CSS_KEY): + if hasattr(self, key) and not self.has_trait(key): + self._anywidget_internal_state[key] = getattr(self, key) + if not hasattr(self, _ESM_KEY): - anywidget_traits[_ESM_KEY] = StaticAsset(_DEFAULT_ESM).as_traittype() + self._anywidget_internal_state[_ESM_KEY] = _DEFAULT_ESM + + self._anywidget_internal_state[_ANYWIDGET_ID_KEY] = _id_for(self) - # TODO: a better way to uniquely identify this subclasses? - # We use the fully-qualified name to get an id which we - # can use to update CSS if necessary. - anywidget_traits[_ANYWIDGET_ID_KEY] = t.Unicode( - f"{self.__class__.__module__}.{self.__class__.__name__}" - ).tag(sync=True) + with _patch_get_state(self, self._anywidget_internal_state): + super().__init__(*args, **kwargs) - self.add_traits(**anywidget_traits) - super().__init__(*args, **kwargs) _register_anywidget_commands(self) def __init_subclass__(cls, **kwargs: dict) -> None: """Coerces _esm and _css to FileContents if they are files.""" super().__init_subclass__(**kwargs) for key in (_ESM_KEY, _CSS_KEY) & cls.__dict__.keys(): + # TODO: Upgrate to := when we drop Python 3.7 value = getattr(cls, key) - if isinstance(value, t.TraitType): - # we don't know how to handle this - continue - setattr(cls, key, StaticAsset(value).as_traittype()) + if not isinstance(value, StaticAsset): + setattr(cls, key, StaticAsset(value)) _collect_anywidget_commands(cls) def _repr_mimebundle_(self, **kwargs: dict) -> tuple[dict, dict] | None: @@ -69,3 +69,34 @@ def _repr_mimebundle_(self, **kwargs: dict) -> tuple[dict, dict] | None: if self._view_name is None: return None # type: ignore[unreachable] return repr_mimebundle(model_id=self.model_id, repr_text=plaintext) + + +def _id_for(obj: typing.Any) -> str: + """Return a unique identifier for an object.""" + # TODO: a better way to uniquely identify this subclasses? + # We use the fully-qualified name to get an id which we + # can use to update CSS if necessary. + return f"{obj.__class__.__module__}.{obj.__class__.__name__}" + + +@contextmanager +def _patch_get_state( + widget: AnyWidget, extra_state: dict[str, str | StaticAsset] +) -> typing.Generator[None, None, None]: + """Patch get_state to include anywidget-specific data.""" + original_get_state = widget.get_state + + def temp_get_state(): + return { + **original_get_state(), + **{ + k: v.serialize() if isinstance(v, StaticAsset) else v + for k, v in extra_state.items() + }, + } + + widget.get_state = temp_get_state + try: + yield + finally: + widget.get_state = original_get_state diff --git a/packages/anywidget/src/widget.js b/packages/anywidget/src/widget.js index ba6e6d20..57f3b623 100644 --- a/packages/anywidget/src/widget.js +++ b/packages/anywidget/src/widget.js @@ -467,6 +467,22 @@ class Runtime { } } +let anywidget_static_asset = { + /** @param {{ model_id: string }} model */ + serialize(model) { + return `anywidget-static-asset:${model.model_id}`; + }, + /** + * @param {string} value + * @param {import("@jupyter-widgets/base").DOMWidgetModel["widget_manager"]} widget_manager + */ + async deserialize(value, widget_manager) { + let model_id = value.slice("anywidget-static-asset:".length); + let model = await widget_manager.get_model(model_id); + return model; + }, +}; + // @ts-expect-error - injected by bundler let version = globalThis.VERSION; @@ -498,6 +514,12 @@ export default function ({ DOMWidgetModel, DOMWidgetView }) { RUNTIMES.set(this, runtime); } + static serializers = { + ...DOMWidgetModel.serializers, + _esm: anywidget_static_asset, + _css: anywidget_static_asset, + }; + /** * @param {Record} state *