Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

render new marker offsets, add option to visualize marker error #63

Merged
merged 6 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified demos/demo_viz.p
Binary file not shown.
22 changes: 11 additions & 11 deletions demos/viz_usage.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion stac_mjx/compute_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@

print(f"offset optimization finished in {time.time()-s}")

return mjx_model, mjx_data
return mjx_model, mjx_data, offset_opt_param

Check warning on line 179 in stac_mjx/compute_stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/compute_stac.py#L179

Added line #L179 was not covered by tests


def pose_optimization(
Expand Down
64 changes: 49 additions & 15 deletions stac_mjx/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
name=key,
type="sphere",
size=[0.005],
rgba="0 0 0 1",
rgba="0 0 0 0.8",
pos=pos,
group=3,
)
Expand Down Expand Up @@ -258,7 +258,7 @@
print(f"Standard deviation: {std}")

print("starting offset optimization")
mjx_model, mjx_data = compute_stac.offset_optimization(
mjx_model, mjx_data, self._offsets = compute_stac.offset_optimization(

Check warning on line 261 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L261

Added line #L261 was not covered by tests
mjx_model,
mjx_data,
kp_data,
Expand Down Expand Up @@ -386,11 +386,11 @@
if batched:
# prepare batched data to be packaged
get_batch_offsets = jax.vmap(op.get_site_pos, in_axes=(0, None))
offsets = get_batch_offsets(mjx_model, self._body_site_idxs).copy()[0]
offsets = get_batch_offsets(mjx_model, self._body_site_idxs)[0]

Check warning on line 389 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L389

Added line #L389 was not covered by tests
x = x.reshape(-1, x.shape[-1])
q = q.reshape(-1, q.shape[-1])
else:
offsets = op.get_site_pos(mjx_model, self._body_site_idxs).copy()
offsets = self._offsets

Check warning on line 393 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L393

Added line #L393 was not covered by tests

kp_data = kp_data.reshape(-1, kp_data.shape[-1])

Expand Down Expand Up @@ -468,6 +468,7 @@
camera: Union[int, str] = 0,
height: int = 1200,
width: int = 1920,
show_marker_error: bool = False,
):
"""Creates rendering using the instantiated model, given the user's qposes and kp_data.

Expand All @@ -481,6 +482,7 @@
camera (Union[int, str], optional): Mujoco camera name. Defaults to 0.
height (int, optional): Height in pixels. Defaults to 1200.
width (int, optional): Width in pixels. Defaults to 1920.
show_marker_error (bool, optional): Show distance between marker and keypoint. Defaults to False.

Raises:
ValueError: qposes and kp_data must have same length (shape[0])
Expand All @@ -506,28 +508,59 @@
render_mj_model, body_site_idxs, keypoint_site_idxs = (
self._create_keypoint_sites()
)
render_mj_model.site_pos[body_site_idxs] = offsets

# Add body sites for new offsets
for (key, v), pos in zip(

Check warning on line 513 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L513

Added line #L513 was not covered by tests
self.cfg.model.KEYPOINT_MODEL_PAIRS.items(), offsets.reshape((-1, 3))
):
parent = self._root.find("body", v)
parent.add(

Check warning on line 517 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L516-L517

Added lines #L516 - L517 were not covered by tests
"site",
name=key + "_new",
type="sphere",
size=[0.005],
rgba="0 0 0 1",
pos=pos,
group=2,
)

# Tendons from new marker sites to kp
if show_marker_error:
for key, v in self.cfg.model.KEYPOINT_MODEL_PAIRS.items():
tendon = self._root.tendon.add(

Check warning on line 530 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L528-L530

Added lines #L528 - L530 were not covered by tests
"spatial",
name=key + "-" + v,
width="0.001",
rgba="255 0 0 1", # Red
limited=False,
)
tendon.add("site", site=key + "_kp")
tendon.add("site", site=key + "_new")

Check warning on line 538 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L537-L538

Added lines #L537 - L538 were not covered by tests

physics = mjcf.Physics.from_mjcf_model(self._root)
render_mj_model = deepcopy(physics.model.ptr)

Check warning on line 541 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L540-L541

Added lines #L540 - L541 were not covered by tests

scene_option = mujoco.MjvOption()
scene_option.geomgroup[1] = 0

Check warning on line 544 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L544

Added line #L544 was not covered by tests
scene_option.geomgroup[2] = 1

scene_option.sitegroup[2] = 1

scene_option.sitegroup[3] = 1
scene_option.sitegroup[3] = 0

Check warning on line 549 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L549

Added line #L549 was not covered by tests
scene_option.flags[enums.mjtVisFlag.mjVIS_TRANSPARENT] = True
scene_option.flags[enums.mjtVisFlag.mjVIS_LIGHT] = False
scene_option.flags[enums.mjtVisFlag.mjVIS_LIGHT] = True

Check warning on line 551 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L551

Added line #L551 was not covered by tests
scene_option.flags[enums.mjtVisFlag.mjVIS_CONVEXHULL] = True
scene_option.flags[enums.mjtRndFlag.mjRND_SHADOW] = False
scene_option.flags[enums.mjtRndFlag.mjRND_REFLECTION] = False
scene_option.flags[enums.mjtRndFlag.mjRND_SKYBOX] = False
scene_option.flags[enums.mjtRndFlag.mjRND_FOG] = False

scene_option.flags[enums.mjtRndFlag.mjRND_SHADOW] = True
scene_option.flags[enums.mjtRndFlag.mjRND_REFLECTION] = True
scene_option.flags[enums.mjtRndFlag.mjRND_SKYBOX] = True
scene_option.flags[enums.mjtRndFlag.mjRND_FOG] = True

Check warning on line 556 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L553-L556

Added lines #L553 - L556 were not covered by tests
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved
mj_data = mujoco.MjData(render_mj_model)

mujoco.mj_kinematics(render_mj_model, mj_data)

renderer = mujoco.Renderer(render_mj_model, height=height, width=width)

# slice kp_data to match qposes length
# Slice kp_data to match qposes length
kp_data = kp_data[: qposes.shape[0]]

# Slice arrays to be the range that is being rendered
Expand All @@ -538,10 +571,11 @@
# render while stepping using mujoco
with imageio.get_writer(save_path, fps=self.cfg.model.RENDER_FPS) as video:
for qpos, kps in tqdm(zip(qposes, kp_data)):
# Set keypoints
# Set keypoints--they're in cartesian space, but since they're attached to the worldbody they're the same as offsets
render_mj_model.site_pos[keypoint_site_idxs] = np.reshape(kps, (-1, 3))
mj_data.qpos = qpos
mujoco.mj_forward(render_mj_model, mj_data)

mujoco.mj_fwdPosition(render_mj_model, mj_data)

Check warning on line 578 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L578

Added line #L578 was not covered by tests

renderer.update_scene(mj_data, camera=camera, scene_option=scene_option)
pixels = renderer.render()
Expand Down
2 changes: 2 additions & 0 deletions stac_mjx/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def viz_stac(
height: int = 1200,
width: int = 1920,
base_path=None,
show_marker_error=False,
):
"""Render forward kinematics from keypoint positions.

Expand Down Expand Up @@ -61,4 +62,5 @@ def viz_stac(
camera,
height,
width,
show_marker_error,
)
Loading