diff --git a/examples/hitl/rearrange_v2/rearrange_v2.py b/examples/hitl/rearrange_v2/rearrange_v2.py index 259d080573..f6d50696bc 100644 --- a/examples/hitl/rearrange_v2/rearrange_v2.py +++ b/examples/hitl/rearrange_v2/rearrange_v2.py @@ -5,17 +5,18 @@ # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Set +from typing import Any, List, Tuple import hydra import magnum as mn import numpy as np +from ui import UI -import habitat_sim from habitat.sims.habitat_simulator import sim_utilities from habitat_hitl._internal.networking.average_rate_tracker import ( AverageRateTracker, ) +from habitat_hitl.app_states.app_service import AppService from habitat_hitl.app_states.app_state_abc import AppState from habitat_hitl.core.client_helper import ClientHelper from habitat_hitl.core.gui_input import GuiInput @@ -28,15 +29,9 @@ GuiHumanoidController, GuiRobotController, ) -from habitat_hitl.environment.gui_pick_helper import GuiPickHelper -from habitat_hitl.environment.gui_placement_helper import GuiPlacementHelper from habitat_hitl.environment.hablab_utils import get_agent_art_obj_transform from habitat_sim.utils.common import quat_from_magnum, quat_to_coeffs -ENABLE_ARTICULATED_OPEN_CLOSE = False -# Visually snap picked objects into the humanoid's hand. May be useful in third-person mode. Beware that this conflicts with GuiPlacementHelper. 
-DO_HUMANOID_GRASP_OBJECTS = False - class DataLogger: def __init__(self, app_service): @@ -100,7 +95,7 @@ class AppStateRearrangeV2(AppState): Todo """ - def __init__(self, app_service): + def __init__(self, app_service: AppService): self._app_service = app_service self._gui_agent_controllers = self._app_service.gui_agent_controllers self._num_users = len(self._gui_agent_controllers) @@ -109,37 +104,34 @@ def __init__(self, app_service): ) self._sim = app_service.sim - self._ao_root_bbs: Dict = None - self._opened_ao_set: Set = set() - self._cam_transform = None self._camera_user_index = 0 - self._held_obj_id = None - self._recent_reach_pos = None self._paused = False - self._hide_gui_text = False - self._can_place_object = False + self._show_gui_text = True self._camera_helper = CameraHelper( self._app_service.hitl_config, self._app_service.gui_input, ) - - self._pick_helper = GuiPickHelper(self._app_service, user_index=0) - self._placement_helper = GuiPlacementHelper( - self._app_service, user_index=0 - ) self._client_helper = None if self._app_service.hitl_config.networking.enable: self._client_helper = ClientHelper(self._app_service) - self._has_grasp_preview = False - self._frame_counter = 0 self._sps_tracker = AverageRateTracker(2.0) self._task_instruction = "" self._data_logger = DataLogger(app_service=self._app_service) + self._ui = UI( + hitl_config=app_service.hitl_config, + user_index=0, + gui_controller=self._gui_agent_controllers[0], + sim=self._sim, + gui_input=app_service.gui_input, + gui_drawer=app_service.gui_drawer, + camera_helper=self._camera_helper, + ) + if self._app_service.hitl_config.networking.enable: self._app_service.remote_client_state.on_client_connected.registerCallback( self._on_client_connected @@ -160,93 +152,9 @@ def _on_client_disconnected(self, disconnection: DisconnectionRecord): def get_sim_utilities() -> Any: return sim_utilities - def _remap_key(self, user_index, key): - key_remap = { - GuiInput.KeyNS.SPACE: 
GuiInput.KeyNS.N, - GuiInput.KeyNS.Z: GuiInput.KeyNS.X, - } - if user_index == 1: - assert key in key_remap - key = key_remap[key] - return key - - def _get_user_key_down(self, user_index, key): - return self._app_service.gui_input.get_key_down( - self._remap_key(user_index, key) - ) - - def _open_close_ao(self, ao_handle: str): - if not ENABLE_ARTICULATED_OPEN_CLOSE: - return - - ao = self.get_sim_utilities().get_obj_from_handle(self._sim, ao_handle) - - # Check whether the ao is opened - is_opened = ao_handle in self._opened_ao_set - - # Set ao joint positions - joint_limits = ao.joint_position_limits - joint_limits = joint_limits[0] if is_opened else joint_limits[1] - ao.joint_positions = joint_limits - ao.clamp_joint_limits() - - # Remove ao from opened set - if is_opened: - self._opened_ao_set.remove(ao_handle) - else: - self._opened_ao_set.add(ao_handle) - - def _find_reachable_ao(self, player_pos) -> str: - """Returns the handle of the nearest reachable articulated object. Returns None if none is found.""" - if not ENABLE_ARTICULATED_OPEN_CLOSE: - return None - - max_distance = 2.0 # TODO: Const - player_pos_xz = mn.Vector3(player_pos.x, 0.0, player_pos.z) - min_dist: float = max_distance - output: str = None - - # TODO: Caching - # TODO: Improve heuristic using bounding box sizes and view angle - for handle, _ in self._ao_root_bbs.items(): - ao = self.get_sim_utilities().get_obj_from_handle( - self._sim, handle - ) - ao_pos = ao.translation - ao_pos_xz = mn.Vector3(ao_pos.x, 0.0, ao_pos.z) - dist_xz = (ao_pos_xz - player_pos_xz).length() - if dist_xz < max_distance and dist_xz < min_dist: - min_dist = dist_xz - output = handle - - return output - - def _highlight_ao(self, handle: str): - assert ENABLE_ARTICULATED_OPEN_CLOSE - bb = self._ao_root_bbs[handle] - ao = self.get_sim_utilities().get_obj_from_handle(self._sim, handle) - ao_pos = ao.translation - ao_pos.y = 0.0 # project to ground - radius = max(bb.size_x(), bb.size_y(), bb.size_z()) / 2.0 - # 
sloppy: use private GuiPickHelper._add_highlight_ring - self._pick_helper._add_highlight_ring( - ao_pos, mn.Color3(0, 1, 0), radius, do_pulse=False, billboard=False - ) - def on_environment_reset(self, episode_recorder_dict): - if ENABLE_ARTICULATED_OPEN_CLOSE: - self._ao_root_bbs = self.get_sim_utilities().get_ao_root_bbs( - self._sim - ) - # HACK: Remove humans and spot from the AO collections - handle_filter = ["male", "female", "hab_spot_arm"] - for key in list(self._ao_root_bbs.keys()): - if any(handle in key for handle in handle_filter): - del self._ao_root_bbs[key] - - self._held_obj_id = None - - self._pick_helper.on_environment_reset() + object_receptacle_pairs = self._create_goal_object_receptacle_pairs() + self._ui.reset(object_receptacle_pairs) self._camera_helper.update(self._get_camera_lookat_pos(), dt=0) @@ -260,132 +168,83 @@ def on_environment_reset(self, episode_recorder_dict): client_message_manager = self._app_service.client_message_manager if client_message_manager: client_message_manager.signal_scene_change() - # Not currently needed since the browser client doesn't have a notion of a humanoid. Here for reference. 
- # human_pos = ( - # self.get_sim() - # .get_agent_data(self.get_gui_controlled_agent_index()) - # .articulated_agent.base_pos - # ) - # client_message_manager.change_humanoid_position(human_pos) - # client_message_manager.update_navmesh_triangles( - # self._get_navmesh_triangle_vertices() - # ) def get_sim(self): return self._app_service.sim - def _get_gui_agent_translation(self, user_index): - return get_agent_art_obj_transform( - self.get_sim(), self.get_gui_controlled_agent_index(user_index) - ).translation + def _create_goal_object_receptacle_pairs( + self, + ) -> List[Tuple[List[int], List[int]]]: + """Parse the current episode and returns the goal object-receptacle pairs.""" + sim = self.get_sim() + paired_goal_ids: List[Tuple[List[int], List[int]]] = [] + current_episode = self._app_service.env.current_episode + if current_episode.info.get("extra_info") is not None: + extra_info = current_episode.info["extra_info"] + self._task_instruction = extra_info["instruction"] + for proposition in extra_info["evaluation_propositions"]: + object_ids: List[int] = [] + object_handles = proposition["args"]["object_handles"] + for object_handle in object_handles: + obj = sim_utilities.get_obj_from_handle(sim, object_handle) + object_id = obj.object_id + object_ids.append(object_id) + receptacle_ids: List[int] = [] + receptacle_handles = proposition["args"]["receptacle_handles"] + for receptacle_handle in receptacle_handles: + obj = sim_utilities.get_obj_from_handle( + sim, receptacle_handle + ) + object_id = obj.object_id + # TODO: Support for finding links by handle. + receptacle_ids.append(object_id) + paired_goal_ids.append((object_ids, receptacle_ids)) + return paired_goal_ids def _update_grasping_and_set_act_hints(self, user_index): - drop_pos = None - grasp_object_id = None - throw_vel = None - reach_pos = None - - self._has_grasp_preview = False - - # todo: implement grasping properly for each user. _held_obj_id, _has_grasp_preview, etc. must be tracked per user. 
- if self._held_obj_id is not None: - if ( - self._get_user_key_down(user_index, GuiInput.KeyNS.SPACE) - and self._can_place_object - ): - if DO_HUMANOID_GRASP_OBJECTS: - # todo: better drop pos - drop_pos = self._get_gui_agent_translation( - user_index - ) # self._gui_agent_controllers.get_base_translation() - else: - # GuiPlacementHelper has already placed this object. - pass - self._held_obj_id = None - else: - query_pos = self._get_gui_agent_translation(user_index) - obj_id = self._pick_helper.get_pick_object_near_query_position( - query_pos - ) - if obj_id: - if self._get_user_key_down(user_index, GuiInput.KeyNS.SPACE): - if DO_HUMANOID_GRASP_OBJECTS: - grasp_object_id = obj_id - self._held_obj_id = obj_id - else: - self._has_grasp_preview = True - - walk_dir = None - distance_multiplier = 1.0 - - # reference code for click-to-walk - # if self._app_service.gui_input.get_mouse_button( - # GuiInput.MouseNS.RIGHT - # ): - # ( - # candidate_walk_dir, - # candidate_distance_multiplier, - # ) = self._nav_helper.get_humanoid_walk_hints_from_ray_cast( - # visualize_path=True - # ) - # walk_dir = candidate_walk_dir - # distance_multiplier = candidate_distance_multiplier - gui_agent_controller = self._gui_agent_controllers[user_index] assert isinstance( gui_agent_controller, (GuiHumanoidController, GuiRobotController) ) gui_agent_controller.set_act_hints( - walk_dir, - distance_multiplier, - grasp_object_id, - drop_pos, - self._camera_helper.lookat_offset_yaw, - throw_vel=throw_vel, - reach_pos=reach_pos, + walk_dir=None, + distance_multiplier=1.0, + grasp_obj_idx=None, + do_drop=None, + cam_yaw=self._camera_helper.lookat_offset_yaw, + throw_vel=None, + reach_pos=None, ) - return drop_pos - def get_gui_controlled_agent_index(self, user_index): return self._gui_agent_controllers[user_index]._agent_idx def _get_controls_text(self): - def get_grasp_release_controls_text(): - if self._held_obj_id is not None: - return "Space/N: put down\n" - elif self._has_grasp_preview: - 
return "Space/N: pick up\n" - else: - return "" + if self._paused: + return "Session ended." - controls_str: str = "" - if not self._hide_gui_text: - if self._sps_tracker.get_smoothed_rate() is not None: - controls_str += f"server SPS: {self._sps_tracker.get_smoothed_rate():.1f}\n" - if self._client_helper and self._client_helper.display_latency_ms: - controls_str += f"latency: {self._client_helper.display_latency_ms:.0f}ms\n" - controls_str += "H: show/hide help text\n" - controls_str += "P: pause\n" - controls_str += "I, K: look up, down\n" - controls_str += "A, D: turn\n" - controls_str += "W/F, S/V: walk\n" - controls_str += "N: next episode\n" - if ENABLE_ARTICULATED_OPEN_CLOSE: - controls_str += "Z/X: open/close receptacle\n" - controls_str += get_grasp_release_controls_text() - if self._num_users > 1 and self._held_obj_id is None: - controls_str += "T: toggle camera user\n" + if not self._show_gui_text: + return "" + controls_str: str = "" + controls_str += "H: Toggle help\n" + controls_str += "Look: Middle click (drag), I, K\n" + controls_str += "Walk: W, S\n" + controls_str += "Turn: A, D\n" + controls_str += "Finish episode: Zero (0)\n" + controls_str += "Open/close: Double-click\n" + controls_str += "Pick object: Double-click\n" + controls_str += "Place object: Right click (hold)\n" return controls_str def _get_status_text(self): + if self._paused: + return "" + status_str = "" if len(self._task_instruction) > 0: - status_str += "\nInstruction: " + self._task_instruction + "\n" - if self._paused: - status_str += "\n\npaused\n" + status_str += "Instruction: " + self._task_instruction + "\n" if ( self._client_helper and self._client_helper.do_show_idle_kick_warning @@ -397,12 +256,6 @@ def _get_status_text(self): return status_str def _update_help_text(self): - controls_str = self._get_controls_text() - if len(controls_str) > 0: - self._app_service.text_drawer.add_text( - controls_str, TextOnScreenAlignment.TOP_LEFT - ) - status_str = 
self._get_status_text() if len(status_str) > 0: self._app_service.text_drawer.add_text( @@ -412,6 +265,12 @@ def _update_help_text(self): text_delta_y=-50, ) + controls_str = self._get_controls_text() + if len(controls_str) > 0: + self._app_service.text_drawer.add_text( + controls_str, TextOnScreenAlignment.TOP_LEFT + ) + def _get_camera_lookat_pos(self): agent_root = get_agent_art_obj_transform( self.get_sim(), @@ -426,31 +285,13 @@ def is_user_idle_this_frame(self) -> bool: def _check_change_episode(self): if self._paused or not self._app_service.gui_input.get_key_down( - GuiInput.KeyNS.N + GuiInput.KeyNS.ZERO ): return if self._app_service.episode_helper.next_episode_exists(): self._app_service.end_episode(do_reset=True) - def _update_held_object_placement(self): - if not self._held_obj_id: - return - - ray = habitat_sim.geo.Ray() - ray.origin = self._camera_helper.get_eye_pos() - ray.direction = ( - self._camera_helper.get_lookat_pos() - - self._camera_helper.get_eye_pos() - ).normalized() - - if self._placement_helper.update(ray, self._held_obj_id): - # sloppy: save another keyframe here since we just moved the held object - self.get_sim().gfx_replay_manager.save_keyframe() - self._can_place_object = True - else: - self._can_place_object = False - def sim_update(self, dt, post_sim_update_dict): if ( not self._app_service.hitl_config.networking.enable @@ -468,57 +309,31 @@ def sim_update(self, dt, post_sim_update_dict): self._sps_tracker.get_smoothed_rate(), ) - if self._app_service.gui_input.get_key_down(GuiInput.KeyNS.P): - self._paused = not self._paused - if self._app_service.gui_input.get_key_down(GuiInput.KeyNS.H): - self._hide_gui_text = not self._hide_gui_text + self._show_gui_text = not self._show_gui_text self._check_change_episode() - for user_index in range(self._num_users): - reachable_ao_handle = self._find_reachable_ao( - self._get_gui_agent_translation(user_index) - ) - if reachable_ao_handle is not None: - 
self._highlight_ao(reachable_ao_handle) - if self._get_user_key_down(user_index, GuiInput.KeyNS.Z): - self._open_close_ao(reachable_ao_handle) - if not self._paused: for user_index in range(self._num_users): + self._ui.update() # TODO: One UI per user. self._update_grasping_and_set_act_hints(user_index) self._app_service.compute_action_and_step_env() else: # temp hack: manually add a keyframe while paused self.get_sim().gfx_replay_manager.save_keyframe() - # todo: visualize objects properly for each user (this requires a separate debug_line_render per user!), or find a reasonable debug line visualization that can be shared between both users every frame. - if self._held_obj_id is None: - self._pick_helper.viz_objects() - - if ( - self._num_users > 1 - and self._held_obj_id is None - and self._app_service.gui_input.get_key_down(GuiInput.KeyNS.T) - ): - self._camera_user_index = ( - self._camera_user_index + 1 - ) % self._num_users - self._camera_helper.update(self._get_camera_lookat_pos(), dt) - # after camera update - self._update_held_object_placement() - self._cam_transform = self._camera_helper.get_cam_transform() post_sim_update_dict["cam_transform"] = self._cam_transform + self._ui.draw_ui() # TODO: One UI per user. self._update_help_text() def record_state(self): task_completed = self._app_service.gui_input.get_key_down( - GuiInput.KeyNS.N + GuiInput.KeyNS.ZERO ) self._data_logger.record_state(task_completed=task_completed) diff --git a/examples/hitl/rearrange_v2/ui.py b/examples/hitl/rearrange_v2/ui.py new file mode 100644 index 0000000000..8345887206 --- /dev/null +++ b/examples/hitl/rearrange_v2/ui.py @@ -0,0 +1,445 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set, Tuple

import magnum as mn

from habitat.sims.habitat_simulator import sim_utilities
from habitat.tasks.rearrange.articulated_agent_manager import (
    ArticulatedAgentManager,
)
from habitat.tasks.rearrange.rearrange_sim import RearrangeSim
from habitat_hitl.core.gui_drawer import GuiDrawer
from habitat_hitl.core.gui_input import GuiInput
from habitat_hitl.core.key_mapping import MouseButton
from habitat_hitl.core.selection import Selection
from habitat_hitl.core.user_mask import Mask
from habitat_hitl.environment.camera_helper import CameraHelper
from habitat_hitl.environment.controllers.controller_abc import GuiController
from habitat_hitl.environment.hablab_utils import get_agent_art_obj_transform
from habitat_sim.physics import ManagedArticulatedObject

# Verticality threshold for successful placement.
MINIMUM_DROP_VERTICALITY: float = 0.9

# Maximum delay (in seconds) between two clicks to be registered as a double-click.
DOUBLE_CLICK_DELAY: float = 0.33

_HI = 0.8
_LO = 0.4
# Color for a valid action.
COLOR_VALID = mn.Color4(0.0, _HI, 0.0, 1.0)  # Green
# Color for an invalid action.
COLOR_INVALID = mn.Color4(_HI, 0.0, 0.0, 1.0)  # Red
# Colors for goal object-receptacle pairs. Pairs cycle through this palette.
COLOR_GOALS: List[mn.Color4] = [
    mn.Color4(0.0, _HI, _HI, 1.0),  # Cyan
    mn.Color4(_HI, 0.0, _HI, 1.0),  # Magenta
    mn.Color4(_HI, _HI, 0.0, 1.0),  # Yellow
    mn.Color4(_HI, 0.0, _LO, 1.0),  # Rose
    mn.Color4(_LO, _HI, 0.0, 1.0),  # Lime
]


class UI:
    """
    User interface for the rearrange_v2 app.

    Tracks mouse selections, lets the user pick, place and open/close objects,
    and draws the matching highlights.
    Each user is meant to have their own UI instance.
    """

    def __init__(
        self,
        hitl_config,
        user_index: int,
        gui_controller: GuiController,
        sim: RearrangeSim,
        gui_input: GuiInput,
        gui_drawer: GuiDrawer,
        camera_helper: CameraHelper,
    ):
        """
        :param hitl_config: HITL config; `can_grasp_place_threshold` is read from it.
        :param user_index: Index of the user served by this UI. Used to build the draw mask.
        :param gui_controller: Controller of the agent driven by the user.
        :param sim: Simulator that is queried and mutated by the UI.
        :param gui_input: Input source (mouse) that is tracked.
        :param gui_drawer: Drawer used for all highlights.
        :param camera_helper: Camera used to position the held object.
        """
        self._user_index = user_index
        self._dest_mask = Mask.from_index(self._user_index)
        self._gui_controller = gui_controller
        self._sim = sim
        self._gui_input = gui_input
        self._gui_drawer = gui_drawer
        self._camera_helper = camera_helper

        self._can_grasp_place_threshold = hitl_config.can_grasp_place_threshold

        # ID of the object being held. None if no object is held.
        self._held_object_id: Optional[int] = None
        # Cache of all link IDs and their parent articulated objects.
        self._link_id_to_ao_map: Dict[int, int] = {}
        # Cache of all opened articulated object links.
        self._opened_link_set: Set = set()
        # Cache of pickable objects IDs.
        self._pickable_object_ids: Set[int] = set()
        # Cache of interactable objects IDs.
        self._interactable_object_ids: Set[int] = set()
        # Last time a click was done. Used to track double-clicking.
        self._last_click_time: datetime = datetime.now()
        # Cache of goal object-receptacle pairs.
        self._object_receptacle_pairs: List[Tuple[List[int], List[int]]] = []

        # Selection trackers.
        self._selections: List[Selection] = []
        # Track hovered object.
        self._hover_selection = Selection(
            self._sim, self._gui_input, Selection.hover_fn
        )
        self._selections.append(self._hover_selection)
        # Track left-clicked object.
        self._click_selection = Selection(
            self._sim,
            self._gui_input,
            Selection.left_click_fn,
        )
        self._selections.append(self._click_selection)

        # Track drop placement.
        # The release frame is included so that the selection is still
        # refreshed on the frame where the drop is executed.
        def place_selection_fn(gui_input: GuiInput) -> bool:
            return gui_input.get_mouse_button(
                MouseButton.RIGHT
            ) or gui_input.get_mouse_button_up(MouseButton.RIGHT)

        self._place_selection = Selection(
            self._sim,
            self._gui_input,
            place_selection_fn,
        )
        self._selections.append(self._place_selection)

    def reset(
        self, object_receptacle_pairs: List[Tuple[List[int], List[int]]]
    ) -> None:
        """
        Reset the UI. Call on simulator reset.

        :param object_receptacle_pairs: Goal pairs as (object IDs, receptacle IDs) tuples.
        """
        sim = self._sim

        self._held_object_id = None
        self._link_id_to_ao_map = sim_utilities.get_ao_link_id_map(sim)
        self._opened_link_set = set()
        self._object_receptacle_pairs = object_receptacle_pairs
        self._last_click_time = datetime.now()
        for selection in self._selections:
            selection.deselect()

        self._pickable_object_ids = set(sim._scene_obj_ids)
        for pickable_obj_id in self._pickable_object_ids:
            rigid_obj = self._get_rigid_object(pickable_obj_id)
            # Ensure that rigid objects are collidable.
            if rigid_obj is not None:
                rigid_obj.collidable = True

        # Collect the articulated objects that belong to agents so that they
        # can be excluded from the interactable set.
        agent_manager: ArticulatedAgentManager = sim.agents_mgr
        agent_ao_object_ids: Set[int] = {
            agent_manager[agent_index].articulated_agent.sim_obj.object_id
            for agent_index in range(len(agent_manager))
        }

        # Add all non-root links that are not part of an agent.
        self._interactable_object_ids = set()
        aom = sim.get_articulated_object_manager()
        ao: ManagedArticulatedObject
        for ao in aom.get_objects_by_handle_substring().values():
            if ao.object_id in agent_ao_object_ids:
                continue
            for link_object_id in ao.link_object_ids:
                if link_object_id != ao.object_id:
                    self._interactable_object_ids.add(link_object_id)

    def update(self) -> None:
        """
        Handle user actions and update the UI.
        """
        for selection in self._selections:
            selection.update()

        if self._gui_input.get_mouse_button_down(MouseButton.LEFT):
            clicked_object_id = self._click_selection.object_id
            if self._consume_double_click():
                # Double-click to select pickable.
                if self._is_object_pickable(clicked_object_id):
                    self._pick_object(clicked_object_id)
                # Double-click to interact.
                elif self._is_object_interactable(clicked_object_id):
                    self._interact_with_object(clicked_object_id)

        # Drop when releasing right click.
        if self._gui_input.get_mouse_button_up(MouseButton.RIGHT):
            self._place_object()
            self._place_selection.deselect()

    def draw_ui(self) -> None:
        """
        Draw the UI.
        """
        self._update_held_object_placement()
        self._draw_place_selection()
        self._draw_hovered_interactable()
        self._draw_hovered_pickable()
        self._draw_goals()

    def _consume_double_click(self) -> bool:
        """
        Register a left click and return true if it completes a double-click.

        A completed double-click is consumed, so that a third rapid click
        starts a new sequence instead of triggering a second action.
        """
        now = datetime.now()
        is_double_click = now - self._last_click_time < timedelta(
            seconds=DOUBLE_CLICK_DELAY
        )
        self._last_click_time = datetime.min if is_double_click else now
        return is_double_click

    def _pick_object(self, object_id: int) -> None:
        """Pick the specified object_id. The object must be pickable."""
        if self._is_holding_object() or not self._is_object_pickable(
            object_id
        ):
            return

        rigid_object = self._get_rigid_object(object_id)
        if rigid_object is None:
            return

        if self._is_within_reach(rigid_object.translation):
            # Pick the object.
            self._held_object_id = object_id
            self._place_selection.deselect()

    def _update_held_object_placement(self) -> None:
        """Update the location of the held object."""
        object_id = self._held_object_id
        # Compare against None explicitly: 0 is a valid object ID.
        if object_id is None:
            return

        rigid_object = self._get_rigid_object(object_id)
        if rigid_object is None:
            return

        eye_position = self._camera_helper.get_eye_pos()
        forward_vector = (
            self._camera_helper.get_lookat_pos() - eye_position
        ).normalized()
        # Keep the held object one unit in front of the camera.
        rigid_object.translation = eye_position + forward_vector

    def _place_object(self) -> None:
        """Place the currently held object."""
        if not self._place_selection.selected:
            return

        object_id = self._held_object_id
        point = self._place_selection.point
        normal = self._place_selection.normal
        if (
            object_id is None
            # The held object cannot be placed onto itself.
            or object_id == self._place_selection.object_id
            or not self._is_location_suitable_for_placement(point, normal)
        ):
            return

        rigid_object = self._get_rigid_object(object_id)
        if rigid_object is None:
            return

        # Drop the object. Offset by half of its height so that it rests on
        # the selected surface.
        rigid_object.translation = point + mn.Vector3(
            0.0, rigid_object.collision_shape_aabb.size_y() / 2, 0.0
        )
        self._held_object_id = None
        self._place_selection.deselect()

    def _interact_with_object(self, object_id: int) -> None:
        """Open/close the selected object. Must be interactable."""
        if not self._is_object_interactable(object_id):
            return

        link_id = object_id
        link_index = self._get_link_index(link_id)
        # Compare against None explicitly: 0 is a valid link index.
        if link_index is None:
            return

        ao_id = self._link_id_to_ao_map[link_id]
        ao = self._get_articulated_object(ao_id)
        link_node = ao.get_link_scene_node(link_index)
        if not self._is_within_reach(link_node.translation):
            return

        # Open/close object.
        if link_id in self._opened_link_set:
            sim_utilities.close_link(ao, link_index)
            self._opened_link_set.remove(link_id)
        else:
            sim_utilities.open_link(ao, link_index)
            self._opened_link_set.add(link_id)

    def _user_pos(self) -> mn.Vector3:
        """Get the translation of the agent controlled by the user."""
        return get_agent_art_obj_transform(
            self._sim, self._gui_controller._agent_idx
        ).translation

    def _get_rigid_object(self, object_id: int) -> Optional[Any]:
        """Get the rigid object with the specified ID. Returns None if unsuccessful."""
        rom = self._sim.get_rigid_object_manager()
        return rom.get_object_by_id(object_id)

    def _get_articulated_object(self, object_id: int) -> Optional[Any]:
        """Get the articulated object with the specified ID. Returns None if unsuccessful."""
        aom = self._sim.get_articulated_object_manager()
        return aom.get_object_by_id(object_id)

    def _get_link_index(self, object_id: int) -> Optional[int]:
        """Get the index of a link. Returns None if unsuccessful."""
        link_id = object_id
        ao_id = self._link_id_to_ao_map.get(link_id)
        if ao_id is None:
            return None
        ao = self._get_articulated_object(ao_id)
        link_id_to_index: Dict[int, int] = ao.link_object_ids
        return link_id_to_index.get(link_id)

    def _horizontal_distance(self, a: mn.Vector3, b: mn.Vector3) -> float:
        """Compute the distance between two points on the horizontal plane."""
        displacement = a - b
        return mn.Vector3(displacement.x, 0.0, displacement.z).length()

    def _is_object_pickable(self, object_id: int) -> bool:
        """Returns true if the object can be picked."""
        return object_id is not None and object_id in self._pickable_object_ids

    def _is_object_interactable(self, object_id: int) -> bool:
        """Returns true if the object can be opened or closed."""
        return (
            object_id is not None
            and object_id in self._interactable_object_ids
            and object_id in self._link_id_to_ao_map
        )

    def _is_holding_object(self) -> bool:
        """Returns true if the user is holding an object."""
        return self._held_object_id is not None

    def _is_within_reach(self, target_pos: mn.Vector3) -> bool:
        """Returns true if the target can be reached by the user."""
        return (
            self._horizontal_distance(self._user_pos(), target_pos)
            < self._can_grasp_place_threshold
        )

    def _is_location_suitable_for_placement(
        self, point: mn.Vector3, normal: mn.Vector3
    ) -> bool:
        """Returns true if the target location is suitable for placement."""
        # Only near-horizontal surfaces (normal close to +Y) are accepted.
        placement_verticality = mn.math.dot(normal, mn.Vector3(0, 1, 0))
        placement_valid = placement_verticality > MINIMUM_DROP_VERTICALITY
        return placement_valid and self._is_within_reach(point)

    def _draw_aabb(
        self, aabb: mn.Range3D, transform: mn.Matrix4, color: mn.Color4
    ) -> None:
        """Draw an AABB, expressed in the local space defined by `transform`."""
        self._gui_drawer.push_transform(
            transform, destination_mask=self._dest_mask
        )
        self._gui_drawer.draw_box(
            min_extent=aabb.back_bottom_left,
            max_extent=aabb.front_top_right,
            color=color,
            destination_mask=self._dest_mask,
        )
        self._gui_drawer.pop_transform(destination_mask=self._dest_mask)

    def _draw_place_selection(self) -> None:
        """Draw the object placement selection."""
        if not self._place_selection.selected or self._held_object_id is None:
            return

        point = self._place_selection.point
        normal = self._place_selection.normal
        placement_valid = self._is_location_suitable_for_placement(
            point, normal
        )
        color = COLOR_VALID if placement_valid else COLOR_INVALID
        radius = 0.15 if placement_valid else 0.05
        self._gui_drawer.draw_circle(
            translation=point,
            radius=radius,
            color=color,
            normal=normal,
            billboard=False,
            destination_mask=self._dest_mask,
        )

    def _draw_hovered_interactable(self) -> None:
        """Highlight the hovered interactable object."""
        if not self._hover_selection.selected:
            return

        object_id = self._hover_selection.object_id
        if not self._is_object_interactable(object_id):
            return

        link_index = self._get_link_index(object_id)
        # Compare against None explicitly: 0 is a valid link index.
        if link_index is None:
            return

        ao = sim_utilities.get_obj_from_id(
            self._sim, object_id, self._link_id_to_ao_map
        )
        link_node = ao.get_link_scene_node(link_index)
        aabb = link_node.cumulative_bb
        reachable = self._is_within_reach(link_node.translation)
        color = COLOR_VALID if reachable else COLOR_INVALID
        self._draw_aabb(aabb, link_node.transformation, color)

    def _draw_hovered_pickable(self) -> None:
        """Highlight the hovered pickable object."""
        if not self._hover_selection.selected or self._is_holding_object():
            return

        object_id = self._hover_selection.object_id
        if not self._is_object_pickable(object_id):
            return

        managed_object = sim_utilities.get_obj_from_id(
            self._sim, object_id, self._link_id_to_ao_map
        )
        translation = managed_object.translation
        reachable = self._is_within_reach(translation)
        color = COLOR_VALID if reachable else COLOR_INVALID
        aabb = managed_object.collision_shape_aabb
        self._draw_aabb(aabb, managed_object.transformation, color)

    def _draw_goals(self) -> None:
        """Draw goal object-receptacle pairs."""
        # TODO: Cache
        sim = self._sim
        link_id_to_ao_map = self._link_id_to_ao_map

        for pair_index, (rigid_ids, receptacle_ids) in enumerate(
            self._object_receptacle_pairs
        ):
            goal_pair_color = COLOR_GOALS[pair_index % len(COLOR_GOALS)]
            for rigid_id in rigid_ids:
                managed_object = sim_utilities.get_obj_from_id(
                    sim, rigid_id, link_id_to_ao_map
                )
                if managed_object is None:
                    continue
                self._gui_drawer.draw_circle(
                    translation=managed_object.translation,
                    radius=0.25,
                    color=goal_pair_color,
                    billboard=True,
                    destination_mask=self._dest_mask,
                )
            for receptacle_id in receptacle_ids:
                aabb, matrix = sim_utilities.get_bb_for_object_id(
                    sim, receptacle_id, link_id_to_ao_map
                )
                if aabb is not None:
                    self._draw_aabb(aabb, matrix, goal_pair_color)
#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional

import magnum as mn

from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim
from habitat_hitl.core.gui_input import GuiInput
from habitat_hitl.core.key_mapping import MouseButton
from habitat_sim.geo import Ray
from habitat_sim.physics import RayHitInfo


def _accept_any_object_id(_object_id: int) -> bool:
    """Pick any object ID."""
    return True


class Selection:
    """
    Class that handles selection by tracking a given GuiInput.

    On every `update()`, if the selection function reports that the user is
    attempting a selection, the mouse ray is cast into the simulator and the
    nearest hit becomes the current selection.
    """

    @staticmethod
    def hover_fn(_gui_input: GuiInput) -> bool:
        """Select the object under the cursor every frame."""
        return True

    @staticmethod
    def left_click_fn(_gui_input: GuiInput) -> bool:
        """Select the object under the cursor when left clicking."""
        return _gui_input.get_mouse_button_down(MouseButton.LEFT)

    @staticmethod
    def right_click_fn(_gui_input: GuiInput) -> bool:
        """Select the object under the cursor when right clicking."""
        return _gui_input.get_mouse_button_down(MouseButton.RIGHT)

    # Exposed as a static method, while `__init__` below defaults to the plain
    # function: `staticmethod` objects are not callable before Python 3.10.
    default_discriminator = staticmethod(_accept_any_object_id)

    def __init__(
        self,
        simulator: HabitatSim,
        gui_input: GuiInput,
        selection_fn: Callable[[GuiInput], bool],
        object_id_discriminator: Callable[[int], bool] = _accept_any_object_id,
    ):
        """
        :param simulator: Simulator that is raycast upon.
        :param gui_input: GuiInput to track.
        :param selection_fn: Function that returns true if gui_input is attempting selection.
        :param object_id_discriminator: Function that determines whether an object ID is selectable.
                                        By default, all objects are selectable.
        """
        self._sim = simulator
        self._gui_input = gui_input
        self._discriminator = object_id_discriminator
        self._selection_fn = selection_fn

        self._selected = False
        self._object_id: Optional[int] = None
        self._point: Optional[mn.Vector3] = None
        self._normal: Optional[mn.Vector3] = None

    @property
    def selected(self) -> bool:
        """Returns true if something is selected."""
        return self._selected

    @property
    def object_id(self) -> Optional[int]:
        """Currently selected object ID."""
        return self._object_id

    @property
    def point(self) -> Optional[mn.Vector3]:
        """Point of the currently selected location."""
        return self._point

    @property
    def normal(self) -> Optional[mn.Vector3]:
        """Normal at the currently selected location."""
        return self._normal

    def deselect(self) -> None:
        """Clear selection."""
        self._selected = False
        self._object_id = None
        self._point = None
        self._normal = None

    def update(self) -> None:
        """
        Update selection.

        The selection is left untouched when no selection is attempted or when
        no mouse ray is available. It is cleared when the ray hits nothing or
        hits an object rejected by the discriminator.
        """
        if not self._selection_fn(self._gui_input):
            return

        ray = self._gui_input.mouse_ray
        if ray is None:
            return

        hit_info = self._raycast(ray)
        if hit_info is None or not self._discriminator(hit_info.object_id):
            self.deselect()
            return

        self._selected = True
        self._object_id = hit_info.object_id
        self._point = hit_info.point
        self._normal = hit_info.normal

    def _raycast(self, ray: Ray) -> Optional[RayHitInfo]:
        """Cast `ray` into the simulator. Returns the nearest hit, or None if nothing is hit."""
        raycast_results = self._sim.cast_ray(ray=ray)
        if not raycast_results.has_hits():
            return None
        # Results are sorted by distance. [0] is the nearest one.
        return raycast_results.hits[0]