From e105a0f507c5258733fd017d654527f023f97281 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Wed, 25 Sep 2024 03:13:14 +0000 Subject: [PATCH] Callback, scene node removal API improvements --- src/viser/_gui_handles.py | 61 ++++++++++++++++++- src/viser/_messages.py | 2 - src/viser/_scene_api.py | 12 +++- src/viser/_scene_handles.py | 54 +++++++++++++--- .../client/src/ControlPanel/ControlPanel.tsx | 8 +-- src/viser/client/src/SceneTree.tsx | 24 +++++--- src/viser/client/src/ThreeAssets.tsx | 40 +++--------- src/viser/client/src/WebsocketMessages.ts | 2 - 8 files changed, 143 insertions(+), 60 deletions(-) diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 93aabf70..a8d15d08 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -15,6 +15,7 @@ Dict, Generic, Iterable, + Literal, Tuple, TypeVar, cast, @@ -106,6 +107,8 @@ class _GuiHandleState(Generic[T]): sync_cb: Callable[[ClientId, dict[str, Any]], None] | None = None """Callback for synchronizing inputs across clients.""" + removed: bool = False + class _OverridableGuiPropApi: """Mixin that allows reading/assigning properties defined in each scene node message.""" @@ -157,6 +160,17 @@ def __init__(self, _impl: _GuiHandleState[T]) -> None: def remove(self) -> None: """Permanently remove this GUI element from the visualizer.""" + + # Warn if already removed. + if self._impl.removed: + warnings.warn( + f"Attempted to remove an already removed {self.__class__.__name__}.", + stacklevel=2, + ) + return + self._impl.removed = True + + # Send remove to client(s) + update internal state. self._impl.gui_api._websock_interface.queue_message( GuiRemoveMessage(self._impl.id) ) @@ -241,10 +255,25 @@ class GuiInputHandle(_GuiInputHandle[T], Generic[T]): def on_update( self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], Any] ) -> Callable[[GuiEvent[TGuiHandle]], None]: - """Attach a function to call when a GUI input is updated. Happens in a thread.""" + """Attach a function to call when a GUI input is updated. Callbacks stack (need + to be manually removed via :meth:`remove_update_callback()`) and will be called + from a thread.""" self._impl.update_cb.append(func) return func + def remove_update_callback( + self, callback: Literal["all"] | Callable = "all" + ) -> None: + """Remove update callbacks from the GUI input. + + Args: + callback: Either "all" to remove all callbacks, or a specific callback function to remove. + """ + if callback == "all": + self._impl.update_cb.clear() + else: + self._impl.update_cb = [cb for cb in self._impl.update_cb if cb != callback] + class GuiCheckboxHandle(GuiInputHandle[bool], GuiCheckboxProps): """Handle for checkbox inputs. @@ -506,6 +535,16 @@ def __post_init__(self) -> None: def remove(self) -> None: """Remove this tab group and all contained GUI elements.""" + # Warn if already removed. + if self._impl.removed: + warnings.warn( + f"Attempted to remove an already removed {self.__class__.__name__}.", + stacklevel=2, + ) + return + self._impl.removed = True + + # Remove tabs, then self. for tab in tuple(self._tab_handles): tab.remove() gui_api = self._impl.gui_api @@ -524,6 +563,7 @@ class GuiTabHandle: _children: dict[str, SupportsRemoveProtocol] = dataclasses.field( default_factory=dict ) + _removed: bool = False def __enter__(self) -> GuiTabHandle: self._container_id_restore = self._parent._impl.gui_api._get_container_id() @@ -542,6 +582,15 @@ def __post_init__(self) -> None: def remove(self) -> None: """Permanently remove this tab and all contained GUI elements from the visualizer.""" + # Warn if already removed. + if self._removed: + warnings.warn( + f"Attempted to remove an already removed {self.__class__.__name__}.", + stacklevel=2, + ) + return + self._removed = True + # We may want to make this thread-safe in the future. found_index = -1 for i, tab in enumerate(self._parent._tab_handles): @@ -594,6 +643,16 @@ def __exit__(self, *args) -> None: def remove(self) -> None: """Permanently remove this folder and all contained GUI elements from the visualizer.""" + # Warn if already removed. + if self._impl.removed: + warnings.warn( + f"Attempted to remove an already removed {self.__class__.__name__}.", + stacklevel=2, + ) + return + self._impl.removed = True + + # Remove children, then self. self._impl.gui_api._websock_interface.queue_message( GuiRemoveMessage(self._impl.id) ) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 77748a44..555d745a 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -520,8 +520,6 @@ class MeshProps: """A numpy array of faces, where each face is represented by indices of vertices. Should have shape (F, 3). Synchronized automatically when assigned.""" color: Optional[Tuple[int, int, int]] """Color of the mesh as RGB integers. Synchronized automatically when assigned.""" - vertex_colors: Optional[npt.NDArray[np.uint8]] - """Optional array of vertex colors. Synchronized automatically when assigned.""" wireframe: bool """Boolean indicating if the mesh should be rendered as a wireframe. Synchronized automatically when assigned.""" opacity: Optional[float] diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index f2acef20..507d0678 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1037,6 +1037,7 @@ def add_mesh_skinned( stacklevel=2, ) + assert len(bone_wxyzs) == len(bone_positions) num_bones = len(bone_wxyzs) assert skin_weights.shape == (vertices.shape[0], num_bones) @@ -1059,7 +1060,6 @@ def add_mesh_skinned( vertices=vertices.astype(np.float32), faces=faces.astype(np.uint32), color=_encode_rgb(color), - vertex_colors=None, wireframe=wireframe, opacity=opacity, flat_shading=flat_shading, @@ -1153,7 +1153,6 @@ def add_mesh_simple( vertices=vertices.astype(np.float32), faces=faces.astype(np.uint32), color=_encode_rgb(color), - vertex_colors=None, wireframe=wireframe, opacity=opacity, flat_shading=flat_shading, @@ -1757,3 +1756,12 @@ def add_3d_gui_container( self, message, name, wxyz, position, visible=visible ) return Gui3dContainerHandle(node_handle._impl, gui_api, container_id) + + def remove_by_name(self, name: str) -> None: + """Helper to call `.remove()` on the scene node handles of the `name` + element or any of its children.""" + handle_from_node_name = self._handle_from_node_name.copy() + name = name.rstrip("/") # '/parent/' => '/parent' + for node_name, handle in handle_from_node_name.items(): + if node_name == name or node_name.startswith(name + "/"): + handle.remove() diff --git a/src/viser/_scene_handles.py b/src/viser/_scene_handles.py index ec62e562..429b3492 100644 --- a/src/viser/_scene_handles.py +++ b/src/viser/_scene_handles.py @@ -17,7 +17,7 @@ import numpy as np import numpy.typing as onpt -from typing_extensions import get_type_hints +from typing_extensions import Self, get_type_hints from . import _messages from .infra._infra import WebsockClientConnection, WebsockServer @@ -123,10 +123,10 @@ class _SceneNodeHandleState: default_factory=lambda: np.array([0.0, 0.0, 0.0]) ) visible: bool = True - # TODO: we should remove SceneNodeHandle as an argument here. - click_cb: list[Callable[[SceneNodePointerEvent[SceneNodeHandle]], None]] | None = ( - None - ) + click_cb: list[ + Callable[[SceneNodePointerEvent[_ClickableSceneNodeHandle]], None] + ] = dataclasses.field(default_factory=list) + removed: bool = False class _SceneNodeMessage(Protocol): @@ -223,6 +223,10 @@ def visible(self, visible: bool) -> None: def remove(self) -> None: """Remove the node from the scene.""" + # No-op if already removed. + if self._impl.removed: + return + self._impl.removed = True self._impl.api._websock_interface.queue_message( _messages.RemoveSceneNodeMessage(self._impl.name) ) @@ -253,18 +257,35 @@ class SceneNodePointerEvent(Generic[TSceneNodeHandle]): class _ClickableSceneNodeHandle(SceneNodeHandle): def on_click( - self: TSceneNodeHandle, - func: Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None], - ) -> Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None]: + self: Self, + func: Callable[[SceneNodePointerEvent[Self]], None], + ) -> Callable[[SceneNodePointerEvent[Self]], None]: """Attach a callback for when a scene node is clicked.""" self._impl.api._websock_interface.queue_message( _messages.SetSceneNodeClickableMessage(self._impl.name, True) ) if self._impl.click_cb is None: self._impl.click_cb = [] - self._impl.click_cb.append(func) # type: ignore + self._impl.click_cb.append( + cast( + Callable[[SceneNodePointerEvent[_ClickableSceneNodeHandle]], None], func + ) + ) return func + def remove_click_callback( + self, callback: Literal["all"] | Callable = "all" + ) -> None: + """Remove click callbacks from scene node. + + Args: + callback: Either "all" to remove all callbacks, or a specific callback function to remove. + """ + if callback == "all": + self._impl.click_cb.clear() + else: + self._impl.click_cb = [cb for cb in self._impl.click_cb if cb != callback] + class CameraFrustumHandle( _ClickableSceneNodeHandle, @@ -510,6 +531,21 @@ def on_update( self._impl_aux.update_cb.append(func) return func + def remove_update_callback( + self, callback: Literal["all"] | Callable = "all" + ) -> None: + """Remove update callbacks from the transform controls. + + Args: + callback: Either "all" to remove all callbacks, or a specific callback function to remove. + """ + if callback == "all": + self._impl_aux.update_cb.clear() + else: + self._impl_aux.update_cb = [ + cb for cb in self._impl_aux.update_cb if cb != callback + ] + class Gui3dContainerHandle( SceneNodeHandle, diff --git a/src/viser/client/src/ControlPanel/ControlPanel.tsx b/src/viser/client/src/ControlPanel/ControlPanel.tsx index 59c515aa..80909be8 100644 --- a/src/viser/client/src/ControlPanel/ControlPanel.tsx +++ b/src/viser/client/src/ControlPanel/ControlPanel.tsx @@ -62,10 +62,10 @@ export default function ControlPanel(props: { controlWidthString == "small" ? "16em" : controlWidthString == "medium" - ? "20em" - : controlWidthString == "large" - ? "24em" - : null + ? "20em" + : controlWidthString == "large" + ? "24em" + : null )!; const generatedServerToggleButton = ( diff --git a/src/viser/client/src/SceneTree.tsx b/src/viser/client/src/SceneTree.tsx index d82fd97d..288a3e8f 100644 --- a/src/viser/client/src/SceneTree.tsx +++ b/src/viser/client/src/SceneTree.tsx @@ -222,16 +222,20 @@ function useObjectFactory(message: SceneNodeMessage | undefined): { message.props.plane == "xz" ? new THREE.Euler(0.0, 0.0, 0.0) : message.props.plane == "xy" - ? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0) - : message.props.plane == "yx" - ? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0) - : message.props.plane == "yz" - ? new THREE.Euler(0.0, 0.0, Math.PI / 2.0) - : message.props.plane == "zx" - ? new THREE.Euler(0.0, Math.PI / 2.0, 0.0) - : message.props.plane == "zy" - ? new THREE.Euler(-Math.PI / 2.0, 0.0, -Math.PI / 2.0) - : undefined + ? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0) + : message.props.plane == "yx" + ? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0) + : message.props.plane == "yz" + ? new THREE.Euler(0.0, 0.0, Math.PI / 2.0) + : message.props.plane == "zx" + ? new THREE.Euler(0.0, Math.PI / 2.0, 0.0) + : message.props.plane == "zy" + ? new THREE.Euler( + -Math.PI / 2.0, + 0.0, + -Math.PI / 2.0, + ) + : undefined } /> diff --git a/src/viser/client/src/ThreeAssets.tsx b/src/viser/client/src/ThreeAssets.tsx index 2d45dae1..0a0ad98c 100644 --- a/src/viser/client/src/ThreeAssets.tsx +++ b/src/viser/client/src/ThreeAssets.tsx @@ -408,19 +408,6 @@ export const InstancedAxes = React.forwardRef< }); /** Convert raw RGB color buffers to linear color buffers. **/ -function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) { - return new THREE.Float32BufferAttribute( - new Float32Array(new Uint8Array(colors)).map((value) => { - value = value / 255.0; - if (value <= 0.04045) { - return value / 12.92; - } else { - return Math.pow((value + 0.055) / 1.055, 2.4); - } - }), - 3, - ); -} export const ViserMesh = React.forwardRef< THREE.Mesh | THREE.SkinnedMesh, MeshMessage | SkinnedMeshMessage @@ -448,7 +435,6 @@ export const ViserMesh = React.forwardRef< const standardArgs = { color: message.props.color === null ? undefined : rgbToInt(message.props.color), - vertexColors: message.props.vertex_colors !== null, wireframe: message.props.wireframe, transparent: message.props.opacity !== null, opacity: message.props.opacity ?? 1.0, @@ -474,16 +460,16 @@ export const ViserMesh = React.forwardRef< message.props.material == "standard" || message.props.wireframe ? new THREE.MeshStandardMaterial(standardArgs) : message.props.material == "toon3" - ? new THREE.MeshToonMaterial({ - gradientMap: generateGradientMap(3), - ...standardArgs, - }) - : message.props.material == "toon5" - ? new THREE.MeshToonMaterial({ - gradientMap: generateGradientMap(5), - ...standardArgs, - }) - : assertUnreachable(message.props.material); + ? new THREE.MeshToonMaterial({ + gradientMap: generateGradientMap(3), + ...standardArgs, + }) + : message.props.material == "toon5" + ? new THREE.MeshToonMaterial({ + gradientMap: generateGradientMap(5), + ...standardArgs, + }) + : assertUnreachable(message.props.material); const geometry = new THREE.BufferGeometry(); geometry.setAttribute( "position", @@ -498,12 +484,6 @@ export const ViserMesh = React.forwardRef< 3, ), ); - if (message.props.vertex_colors !== null) { - geometry.setAttribute( - "color", - threeColorBufferFromUint8Buffer(message.props.vertex_colors), - ); - } geometry.setIndex( new THREE.Uint32BufferAttribute( diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 1518a40c..c49e0804 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -324,7 +324,6 @@ export interface MeshMessage { vertices: Uint8Array; faces: Uint8Array; color: [number, number, number] | null; - vertex_colors: Uint8Array | null; wireframe: boolean; opacity: number | null; flat_shading: boolean; @@ -343,7 +342,6 @@ export interface SkinnedMeshMessage { vertices: Uint8Array; faces: Uint8Array; color: [number, number, number] | null; - vertex_colors: Uint8Array | null; wireframe: boolean; opacity: number | null; flat_shading: boolean;