diff --git a/habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py b/habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py index d58505a8d5..695ccff8d0 100644 --- a/habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py +++ b/habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py @@ -434,3 +434,88 @@ def get_obj_from_handle( return aom.get_object_by_handle(obj_handle) return None + + +def get_rigid_object_global_keypoints( + objectA: habitat_sim.physics.ManagedRigidObject, +) -> List[mn.Vector3]: + """ + Get a list of rigid object keypoints in global space. + 0th point is the bounding box center, others are bounding box corners. + + :param objectA: The ManagedRigidObject from which to extract keypoints. + + :return: A set of global 3D keypoints for the object. + """ + + bb = objectA.root_scene_node.cumulative_bb + local_keypoints = [bb.center()] + local_keypoints.extend(get_bb_corners(bb)) + global_keypoints = [ + objectA.transformation.transform_point(key_point) + for key_point in local_keypoints + ] + return global_keypoints + + +def object_keypoint_cast( + sim: habitat_sim.Simulator, + objectA: habitat_sim.physics.ManagedRigidObject, + direction: mn.Vector3 = None, +) -> List[habitat_sim.physics.RaycastResults]: + """ + Computes object global keypoints, casts rays from each in the specified direction and returns the resulting RaycastResults. + + :param sim: The Simulator instance. + :param objectA: The ManagedRigidObject from which to extract keypoints and raycast. + :param direction: Optionally provide a unit length global direction vector for the raycast. If None, default to -Y. + + :return: A list of RaycastResults, one from each object keypoint. + """ + + if direction is None: + # default to downward raycast + direction = mn.Vector3(0, -1, 0) + + global_keypoints = get_rigid_object_global_keypoints(objectA) + return [ + sim.cast_ray(habitat_sim.geo.Ray(keypoint, direction)) + for keypoint in global_keypoints + ] + + +# ============================================================ +# Utilities for Querying Object Relationships +# ============================================================ + + +def above( + sim: habitat_sim.Simulator, + objectA: Union[ + habitat_sim.physics.ManagedRigidObject, + habitat_sim.physics.ManagedArticulatedObject, + ], +) -> List[int]: + """ + Get a list of all objects that a particular objectA is 'above'. + Concretely, 'above' is defined as: a downward raycast of any object keypoint hits the object below. + + :param sim: The Simulator instance. + :param objectA: The ManagedRigidObject for which to query the 'above' set. + + :return: a list of object ids. + """ + + # get object ids of all objects below this one + above_object_ids = [ + hit.object_id + for keypoint_raycast_result in object_keypoint_cast(sim, objectA) + for hit in keypoint_raycast_result.hits + ] + above_object_ids = list(set(above_object_ids)) + + # remove self from the list if present + if objectA.object_id in above_object_ids: + above_object_ids.remove(objectA.object_id) + + return above_object_ids diff --git a/test/test_sim_utils.py b/test/test_sim_utils.py index 2f4e8eaea5..fb7c26e1ae 100644 --- a/test/test_sim_utils.py +++ b/test/test_sim_utils.py @@ -10,12 +10,14 @@ import pytest from habitat.sims.habitat_simulator.sim_utilities import ( + above, bb_ray_prescreen, get_all_object_ids, get_all_objects, get_ao_link_id_map, get_obj_from_handle, get_obj_from_id, + object_keypoint_cast, snap_down, ) from habitat_sim import Simulator, built_with_bullet @@ -222,3 +224,67 @@ def test_object_getters(): sim, link_object_id, ao_link_map ) assert obj_from_id_getter.object_id == ao.object_id + + +@pytest.mark.skipif( + not built_with_bullet, + reason="Raycasting API requires Bullet physics.", +) +@pytest.mark.skipif( + not osp.exists("data/replica_cad/"), + reason="Requires ReplicaCAD dataset.", +) +def test_keypoint_cast_prepositions(): + sim_settings = default_sim_settings.copy() + sim_settings[ + "scene_dataset_config_file" + ] = "data/replica_cad/replicaCAD.scene_dataset_config.json" + sim_settings["scene"] = "apt_0" + hab_cfg = make_cfg(sim_settings) + with Simulator(hab_cfg) as sim: + all_objects = get_all_object_ids(sim) + + mixer_object = get_obj_from_handle( + sim, "frl_apartment_small_appliance_01_:0000" + ) + mixer_above = above(sim, mixer_object) + mixer_above_strings = [ + all_objects[obj_id] for obj_id in mixer_above if obj_id >= 0 + ] + expected_mixer_above_strings = [ + "kitchen_counter_:0000", + "kitchen_counter_:0000 -- drawer2_bottom", + "kitchen_counter_:0000 -- drawer2_middle", + "kitchen_counter_:0000 -- drawer2_top", + ] + for expected in expected_mixer_above_strings: + assert expected in mixer_above_strings + assert len(mixer_above_strings) == len(expected_mixer_above_strings) + + tv_object = get_obj_from_handle(sim, "frl_apartment_tv_screen_:0000") + tv_above = above(sim, tv_object) + tv_above_strings = [ + all_objects[obj_id] for obj_id in tv_above if obj_id >= 0 + ] + expected_tv_above_strings = [ + "frl_apartment_tvstand_:0000", + "frl_apartment_chair_01_:0000", + ] + + for expected in expected_tv_above_strings: + assert expected in tv_above_strings + assert len(tv_above_strings) == len(expected_tv_above_strings) + + # now define a custom keypoint cast from the mixer constructed to include tv in the set + mixer_to_tv = ( + tv_object.translation - mixer_object.translation + ).normalized() + mixer_to_tv_object_ids = [ + hit.object_id + for keypoint_raycast_result in object_keypoint_cast( + sim, mixer_object, direction=mixer_to_tv + ) + for hit in keypoint_raycast_result.hits + ] + mixer_to_tv_object_ids = list(set(mixer_to_tv_object_ids)) + assert tv_object.object_id in mixer_to_tv_object_ids