diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index e7232213..a29aeb5f 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -206,6 +206,7 @@ def __init__( self._container_handle_from_uuid: dict[str, GuiContainerProtocol] = { "root": _RootGuiContainer({}) } + self._modal_handle_from_uuid: dict[str, GuiModalHandle] = {} self._current_file_upload_states: dict[str, _FileUploadState] = {} # Set to True when plotly.min.js has been sent to client. @@ -372,7 +373,7 @@ def _get_container_uuid(self) -> str: """Get container ID associated with the current thread.""" return self._target_container_from_thread_id.get(threading.get_ident(), "root") - def _set_container_uid(self, container_uuid: str) -> None: + def _set_container_uuid(self, container_uuid: str) -> None: """Set container ID associated with the current thread.""" self._target_container_from_thread_id[threading.get_ident()] = container_uuid @@ -381,6 +382,8 @@ def reset(self) -> None: root_container = self._container_handle_from_uuid["root"] while len(root_container._children) > 0: next(iter(root_container._children.values())).remove() + while len(self._modal_handle_from_uuid) > 0: + next(iter(self._modal_handle_from_uuid.values())).close() def set_panel_label(self, label: str | None) -> None: """Set the main label that appears in the GUI panel. @@ -535,7 +538,7 @@ def add_modal( ) return GuiModalHandle( _gui_api=self, - _uid=modal_container_id, + _uuid=modal_container_id, ) def add_tab_group( diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 56c55887..b245ebdc 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -598,13 +598,13 @@ class GuiTabHandle: def __enter__(self) -> GuiTabHandle: self._container_id_restore = self._parent._impl.gui_api._get_container_uuid() - self._parent._impl.gui_api._set_container_uid(self._id) + self._parent._impl.gui_api._set_container_uuid(self._id) return self def __exit__(self, *args) -> None: del args assert self._container_id_restore is not None - self._parent._impl.gui_api._set_container_uid(self._container_id_restore) + self._parent._impl.gui_api._set_container_uuid(self._container_id_restore) self._container_id_restore = None def __post_init__(self) -> None: @@ -662,13 +662,13 @@ def __init__(self, _impl: _GuiHandleState[None]) -> None: def __enter__(self) -> GuiFolderHandle: self._container_id_restore = self._impl.gui_api._get_container_uuid() - self._impl.gui_api._set_container_uid(self._impl.uuid) + self._impl.gui_api._set_container_uuid(self._impl.uuid) return self def __exit__(self, *args) -> None: del args assert self._container_id_restore is not None - self._impl.gui_api._set_container_uid(self._container_id_restore) + self._impl.gui_api._set_container_uuid(self._container_id_restore) self._container_id_restore = None def remove(self) -> None: @@ -703,34 +703,36 @@ class GuiModalHandle: """Use as a context to place GUI elements into a modal.""" _gui_api: GuiApi - _uid: str # Used as container ID of children. - _container_uid_restore: str | None = None + _uuid: str # Used as container ID of children. + _container_uuid_restore: str | None = None _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( default_factory=dict ) def __enter__(self) -> GuiModalHandle: - self._container_uid_restore = self._gui_api._get_container_uuid() - self._gui_api._set_container_uid(self._uid) + self._container_uuid_restore = self._gui_api._get_container_uuid() + self._gui_api._set_container_uuid(self._uuid) return self def __exit__(self, *args) -> None: del args - assert self._container_uid_restore is not None - self._gui_api._set_container_uid(self._container_uid_restore) - self._container_uid_restore = None + assert self._container_uuid_restore is not None + self._gui_api._set_container_uuid(self._container_uuid_restore) + self._container_uuid_restore = None def __post_init__(self) -> None: - self._gui_api._container_handle_from_uuid[self._uid] = self + self._gui_api._container_handle_from_uuid[self._uuid] = self + self._gui_api._modal_handle_from_uuid[self._uuid] = self def close(self) -> None: """Close this modal and permananently remove all contained GUI elements.""" self._gui_api._websock_interface.queue_message( - GuiCloseModalMessage(self._uid), + GuiCloseModalMessage(self._uuid), ) for child in tuple(self._children.values()): child.remove() - self._gui_api._container_handle_from_uuid.pop(self._uid) + self._gui_api._container_handle_from_uuid.pop(self._uuid) + self._gui_api._modal_handle_from_uuid.pop(self._uuid) def _get_data_url(url: str, image_root: Path | None) -> str: diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 12eb7aea..6f543e85 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -933,11 +933,19 @@ class GuiModalMessage(Message): uuid: str title: str + @override + def redundancy_key(self) -> str: + return f"modal-{self.uuid}" + @dataclasses.dataclass class GuiCloseModalMessage(Message): uuid: str + @override + def redundancy_key(self) -> str: + return f"modal-{self.uuid}" + @dataclasses.dataclass class GuiButtonProps(GuiBaseProps): diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index b8b1c82e..26566d15 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -697,13 +697,13 @@ def __init__(self, impl: _SceneNodeHandleState, gui_api: GuiApi, container_id: s def __enter__(self) -> Gui3dContainerHandle: self._container_id_restore = self._gui_api._get_container_uuid() - self._gui_api._set_container_uid(self._container_id) + self._gui_api._set_container_uuid(self._container_id) return self def __exit__(self, *args) -> None: del args assert self._container_id_restore is not None - self._gui_api._set_container_uid(self._container_id_restore) + self._gui_api._set_container_uuid(self._container_id_restore) self._container_id_restore = None def remove(self) -> None: diff --git a/src/viser/_viser.py b/src/viser/_viser.py index d0fceca6..e20d95d3 100644 --- a/src/viser/_viser.py +++ b/src/viser/_viser.py @@ -737,6 +737,8 @@ def _run_garbage_collector(self, force: bool = False) -> None: elif isinstance(message, _messages.GuiRemoveMessage): remove_message_ids.append(id) remove_gui_uuids.add(message.uuid) + elif isinstance(message, _messages.GuiCloseModalMessage): + remove_message_ids.append(id) # For removed elements, no need to send any update messages. if ( diff --git a/tests/test_garbage_collection.py b/tests/test_garbage_collection.py index a9707132..b6724e53 100644 --- a/tests/test_garbage_collection.py +++ b/tests/test_garbage_collection.py @@ -50,3 +50,27 @@ def test_remove_gui_element() -> None: assert len(internal_message_dict) > orig_len server._run_garbage_collector(force=True) assert len(internal_message_dict) == orig_len + + +def test_remove_gui_in_modal() -> None: + """Test that viser's internal message buffer is cleaned up properly when we + remove GUI elements.""" + + # def test_server_port_is_freed(): + # Mock the client autobuild to avoid building the client. + viser._client_autobuild.ensure_client_is_built = lambda: None + + server = viser.ViserServer() + + internal_message_dict = server._websock_server._broadcast_buffer.message_from_id + orig_len = len(internal_message_dict) + + with server.gui.add_modal("Buttons in folder") as modal: + for i in range(50): + server.gui.add_button(f"Button {i}") + + assert len(internal_message_dict) > orig_len + modal.close() + assert len(internal_message_dict) > orig_len + server._run_garbage_collector(force=True) + assert len(internal_message_dict) == orig_len