Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-zhng committed Sep 25, 2024
1 parent 3dd72ab commit 21d7e3c
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions stac_mjx/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,13 +386,11 @@ def _package_data(self, mjx_model, q, x, walker_body_sites, kp_data, batched=Fal
if batched:
# prepare batched data to be packaged
get_batch_offsets = jax.vmap(op.get_site_pos, in_axes=(0, None))
offsets = get_batch_offsets(mjx_model, self._body_site_idxs).copy()[0]
offsets = get_batch_offsets(mjx_model, self._body_site_idxs)[0]

Check warning on line 389 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L389

Added line #L389 was not covered by tests
x = x.reshape(-1, x.shape[-1])
q = q.reshape(-1, q.shape[-1])
else:
offsets = (
self._offsets
) # op.get_site_pos(mjx_model, self._body_site_idxs).copy()
offsets = self._offsets

Check warning on line 393 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L393

Added line #L393 was not covered by tests

kp_data = kp_data.reshape(-1, kp_data.shape[-1])

Expand Down Expand Up @@ -511,7 +509,7 @@ def render(
self._create_keypoint_sites()
)

# Add body sites for new offset positions
# Add body sites for new offsets
for (key, v), pos in zip(

Check warning on line 513 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L513

Added line #L513 was not covered by tests
self.cfg.model.KEYPOINT_MODEL_PAIRS.items(), offsets.reshape((-1, 3))
):
Expand All @@ -526,11 +524,9 @@ def render(
group=2,
)

# tendons from new marker sites to kp
# Tendons from new marker sites to kp
if show_marker_error:
for key, v in self.cfg.model.KEYPOINT_MODEL_PAIRS.items():
# pos = utils.params["KEYPOINT_INITIAL_OFFSETS"][key]
rgba = self.cfg.model.KEYPOINT_COLOR_PAIRS[key]
tendon = self._root.tendon.add(

Check warning on line 530 in stac_mjx/controller.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/controller.py#L528-L530

Added lines #L528 - L530 were not covered by tests
"spatial",
name=key + "-" + v,
Expand Down Expand Up @@ -575,7 +571,7 @@ def render(
# render while stepping using mujoco
with imageio.get_writer(save_path, fps=self.cfg.model.RENDER_FPS) as video:
for qpos, kps in tqdm(zip(qposes, kp_data)):
# Set keypoints--they're in cartesian space, but since they're attached to the worldbody it's the same
# Set keypoints--they're in cartesian space, but since they're attached to the worldbody they're the same as offsets
render_mj_model.site_pos[keypoint_site_idxs] = np.reshape(kps, (-1, 3))
mj_data.qpos = qpos

Expand Down

0 comments on commit 21d7e3c

Please sign in to comment.