Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

"Above" set query util #1817

Merged
merged 7 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,88 @@ def get_obj_from_handle(
return aom.get_object_by_handle(obj_handle)

return None


def get_rigid_object_global_keypoints(
objectA: habitat_sim.physics.ManagedRigidObject,
) -> List[mn.Vector3]:
"""
Get a list of rigid object keypoints in global space.
0th point is the bounding box center, others are bounding box corners.

:param objectA: The ManagedRigidObject from which to extract keypoints.

:return: A set of global 3D keypoints for the object.
"""

bb = objectA.root_scene_node.cumulative_bb
local_keypoints = [bb.center()]
local_keypoints.extend(get_bb_corners(bb))
global_keypoints = [
objectA.transformation.transform_point(key_point)
for key_point in local_keypoints
]
return global_keypoints


def object_keypoint_cast(
sim: habitat_sim.Simulator,
objectA: habitat_sim.physics.ManagedRigidObject,
direction: mn.Vector3 = None,
) -> List[habitat_sim.physics.RaycastResults]:
"""
Computes object global keypoints, casts rays from each in the specified direction and returns the resulting RaycastResults.

:param sim: The Simulator instance.
:param objectA: The ManagedRigidObject from which to extract keypoints and raycast.
:param direction: Optionally provide a unit length global direction vector for the raycast. If None, default to -Y.

:return: A list of RaycastResults, one from each object keypoint.
"""

if direction is None:
# default to downward raycast
direction = mn.Vector3(0, -1, 0)

global_keypoints = get_rigid_object_global_keypoints(objectA)
return [
sim.cast_ray(habitat_sim.geo.Ray(keypoint, direction))
for keypoint in global_keypoints
]


# ============================================================
# Utilities for Querying Object Relationships
# ============================================================


def above(
sim: habitat_sim.Simulator,
objectA: Union[
habitat_sim.physics.ManagedRigidObject,
habitat_sim.physics.ManagedArticulatedObject,
],
) -> List[int]:
"""
Get a list of all objects that a particular objectA is 'above'.
Concretely, 'above' is defined as: a downward raycast of any object keypoint hits the object below.

:param sim: The Simulator instance.
:param objectA: The ManagedRigidObject for which to query the 'above' set.

:return: a list of object ids.
"""

# get object ids of all objects below this one
above_object_ids = [
hit.object_id
for keypoint_raycast_result in object_keypoint_cast(sim, objectA)
for hit in keypoint_raycast_result.hits
]
above_object_ids = list(set(above_object_ids))

# remove self from the list if present
if objectA.object_id in above_object_ids:
above_object_ids.remove(objectA.object_id)

return above_object_ids
66 changes: 66 additions & 0 deletions test/test_sim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import pytest

from habitat.sims.habitat_simulator.sim_utilities import (
above,
bb_ray_prescreen,
get_all_object_ids,
get_all_objects,
get_ao_link_id_map,
get_obj_from_handle,
get_obj_from_id,
object_keypoint_cast,
snap_down,
)
from habitat_sim import Simulator, built_with_bullet
Expand Down Expand Up @@ -222,3 +224,67 @@ def test_object_getters():
sim, link_object_id, ao_link_map
)
assert obj_from_id_getter.object_id == ao.object_id


@pytest.mark.skipif(
not built_with_bullet,
reason="Raycasting API requires Bullet physics.",
)
@pytest.mark.skipif(
not osp.exists("data/replica_cad/"),
reason="Requires ReplicaCAD dataset.",
)
def test_keypoint_cast_prepositions():
sim_settings = default_sim_settings.copy()
sim_settings[
"scene_dataset_config_file"
] = "data/replica_cad/replicaCAD.scene_dataset_config.json"
sim_settings["scene"] = "apt_0"
hab_cfg = make_cfg(sim_settings)
with Simulator(hab_cfg) as sim:
all_objects = get_all_object_ids(sim)

mixer_object = get_obj_from_handle(
sim, "frl_apartment_small_appliance_01_:0000"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these sim entities stable enough to be included in tests, or should we be placing new entities?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming these are stable because they are ReplicaCAD scenes which we have no intention of modifying. This scene "apt_0" is a reconstruction of the first static "FRL apartment" arrangement from Replica.

If anything were to change in the future it would most likely be dynamic settling which should preserve the tested relationships here.

)
mixer_above = above(sim, mixer_object)
mixer_above_strings = [
all_objects[obj_id] for obj_id in mixer_above if obj_id >= 0
]
expected_mixer_above_strings = [
"kitchen_counter_:0000",
"kitchen_counter_:0000 -- drawer2_bottom",
"kitchen_counter_:0000 -- drawer2_middle",
"kitchen_counter_:0000 -- drawer2_top",
]
for expected in expected_mixer_above_strings:
assert expected in mixer_above_strings
assert len(mixer_above_strings) == len(expected_mixer_above_strings)

tv_object = get_obj_from_handle(sim, "frl_apartment_tv_screen_:0000")
tv_above = above(sim, tv_object)
tv_above_strings = [
all_objects[obj_id] for obj_id in tv_above if obj_id >= 0
]
expected_tv_above_strings = [
"frl_apartment_tvstand_:0000",
"frl_apartment_chair_01_:0000",
]

for expected in expected_tv_above_strings:
assert expected in tv_above_strings
assert len(tv_above_strings) == len(expected_tv_above_strings)

# now define a custom keypoint cast from the mixer constructed to include tv in the set
mixer_to_tv = (
tv_object.translation - mixer_object.translation
).normalized()
mixer_to_tv_object_ids = [
hit.object_id
for keypoint_raycast_result in object_keypoint_cast(
sim, mixer_object, direction=mixer_to_tv
)
for hit in keypoint_raycast_result.hits
]
mixer_to_tv_object_ids = list(set(mixer_to_tv_object_ids))
assert tv_object.object_id in mixer_to_tv_object_ids