Skip to content

Commit

Permalink
Modal cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Dec 11, 2024
1 parent 2d4a1ae commit e942b04
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 18 deletions.
7 changes: 5 additions & 2 deletions src/viser/_gui_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
self._container_handle_from_uuid: dict[str, GuiContainerProtocol] = {
"root": _RootGuiContainer({})
}
self._modal_handle_from_uuid: dict[str, GuiModalHandle] = {}
self._current_file_upload_states: dict[str, _FileUploadState] = {}

# Set to True when plotly.min.js has been sent to client.
Expand Down Expand Up @@ -372,7 +373,7 @@ def _get_container_uuid(self) -> str:
"""Get container ID associated with the current thread."""
return self._target_container_from_thread_id.get(threading.get_ident(), "root")

def _set_container_uid(self, container_uuid: str) -> None:
def _set_container_uuid(self, container_uuid: str) -> None:
"""Set container ID associated with the current thread."""
self._target_container_from_thread_id[threading.get_ident()] = container_uuid

Expand All @@ -381,6 +382,8 @@ def reset(self) -> None:
root_container = self._container_handle_from_uuid["root"]
while len(root_container._children) > 0:
next(iter(root_container._children.values())).remove()
while len(self._modal_handle_from_uuid) > 0:
next(iter(self._modal_handle_from_uuid.values())).close()

def set_panel_label(self, label: str | None) -> None:
"""Set the main label that appears in the GUI panel.
Expand Down Expand Up @@ -535,7 +538,7 @@ def add_modal(
)
return GuiModalHandle(
_gui_api=self,
_uid=modal_container_id,
_uuid=modal_container_id,
)

def add_tab_group(
Expand Down
30 changes: 16 additions & 14 deletions src/viser/_gui_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,13 +598,13 @@ class GuiTabHandle:

def __enter__(self) -> GuiTabHandle:
self._container_id_restore = self._parent._impl.gui_api._get_container_uuid()
self._parent._impl.gui_api._set_container_uid(self._id)
self._parent._impl.gui_api._set_container_uuid(self._id)
return self

def __exit__(self, *args) -> None:
del args
assert self._container_id_restore is not None
self._parent._impl.gui_api._set_container_uid(self._container_id_restore)
self._parent._impl.gui_api._set_container_uuid(self._container_id_restore)
self._container_id_restore = None

def __post_init__(self) -> None:
Expand Down Expand Up @@ -662,13 +662,13 @@ def __init__(self, _impl: _GuiHandleState[None]) -> None:

def __enter__(self) -> GuiFolderHandle:
self._container_id_restore = self._impl.gui_api._get_container_uuid()
self._impl.gui_api._set_container_uid(self._impl.uuid)
self._impl.gui_api._set_container_uuid(self._impl.uuid)
return self

def __exit__(self, *args) -> None:
del args
assert self._container_id_restore is not None
self._impl.gui_api._set_container_uid(self._container_id_restore)
self._impl.gui_api._set_container_uuid(self._container_id_restore)
self._container_id_restore = None

def remove(self) -> None:
Expand Down Expand Up @@ -703,34 +703,36 @@ class GuiModalHandle:
"""Use as a context to place GUI elements into a modal."""

_gui_api: GuiApi
_uid: str # Used as container ID of children.
_container_uid_restore: str | None = None
_uuid: str # Used as container ID of children.
_container_uuid_restore: str | None = None
_children: dict[str, SupportsRemoveProtocol] = dataclasses.field(
default_factory=dict
)

def __enter__(self) -> GuiModalHandle:
self._container_uid_restore = self._gui_api._get_container_uuid()
self._gui_api._set_container_uid(self._uid)
self._container_uuid_restore = self._gui_api._get_container_uuid()
self._gui_api._set_container_uuid(self._uuid)
return self

def __exit__(self, *args) -> None:
del args
assert self._container_uid_restore is not None
self._gui_api._set_container_uid(self._container_uid_restore)
self._container_uid_restore = None
assert self._container_uuid_restore is not None
self._gui_api._set_container_uuid(self._container_uuid_restore)
self._container_uuid_restore = None

def __post_init__(self) -> None:
self._gui_api._container_handle_from_uuid[self._uid] = self
self._gui_api._container_handle_from_uuid[self._uuid] = self
self._gui_api._modal_handle_from_uuid[self._uuid] = self

def close(self) -> None:
"""Close this modal and permananently remove all contained GUI elements."""
self._gui_api._websock_interface.queue_message(
GuiCloseModalMessage(self._uid),
GuiCloseModalMessage(self._uuid),
)
for child in tuple(self._children.values()):
child.remove()
self._gui_api._container_handle_from_uuid.pop(self._uid)
self._gui_api._container_handle_from_uuid.pop(self._uuid)
self._gui_api._modal_handle_from_uuid.pop(self._uuid)


def _get_data_url(url: str, image_root: Path | None) -> str:
Expand Down
8 changes: 8 additions & 0 deletions src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,11 +933,19 @@ class GuiModalMessage(Message):
uuid: str
title: str

@override
def redundancy_key(self) -> str:
return f"modal-{self.uuid}"


@dataclasses.dataclass
class GuiCloseModalMessage(Message):
uuid: str

@override
def redundancy_key(self) -> str:
return f"modal-{self.uuid}"


@dataclasses.dataclass
class GuiButtonProps(GuiBaseProps):
Expand Down
4 changes: 2 additions & 2 deletions src/viser/_scene_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,13 +697,13 @@ def __init__(self, impl: _SceneNodeHandleState, gui_api: GuiApi, container_id: s

def __enter__(self) -> Gui3dContainerHandle:
self._container_id_restore = self._gui_api._get_container_uuid()
self._gui_api._set_container_uid(self._container_id)
self._gui_api._set_container_uuid(self._container_id)
return self

def __exit__(self, *args) -> None:
del args
assert self._container_id_restore is not None
self._gui_api._set_container_uid(self._container_id_restore)
self._gui_api._set_container_uuid(self._container_id_restore)
self._container_id_restore = None

def remove(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions src/viser/_viser.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,8 @@ def _run_garbage_collector(self, force: bool = False) -> None:
elif isinstance(message, _messages.GuiRemoveMessage):
remove_message_ids.append(id)
remove_gui_uuids.add(message.uuid)
elif isinstance(message, _messages.GuiCloseModalMessage):
remove_message_ids.append(id)

# For removed elements, no need to send any update messages.
if (
Expand Down
24 changes: 24 additions & 0 deletions tests/test_garbage_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,27 @@ def test_remove_gui_element() -> None:
assert len(internal_message_dict) > orig_len
server._run_garbage_collector(force=True)
assert len(internal_message_dict) == orig_len


def test_remove_gui_in_modal() -> None:
"""Test that viser's internal message buffer is cleaned up properly when we
remove GUI elements."""

# def test_server_port_is_freed():
# Mock the client autobuild to avoid building the client.
viser._client_autobuild.ensure_client_is_built = lambda: None

server = viser.ViserServer()

internal_message_dict = server._websock_server._broadcast_buffer.message_from_id
orig_len = len(internal_message_dict)

with server.gui.add_modal("Buttons in folder") as modal:
for i in range(50):
server.gui.add_button(f"Button {i}")

assert len(internal_message_dict) > orig_len
modal.close()
assert len(internal_message_dict) > orig_len
server._run_garbage_collector(force=True)
assert len(internal_message_dict) == orig_len

0 comments on commit e942b04

Please sign in to comment.