Skip to content

Commit

Permalink
Synthetic data smoke test. (#75)
Browse files Browse the repository at this point in the history
* Model runs + draws in notebook, no data output

* Configs etc - model not yet working.

* Disable energy

* Offset shape bugfix (#73)

* fix offset shape when saving

* update demo config

* linter

* fix demo.yaml

* update configs

* update config test

---------

Co-authored-by: Charles Zhang <charleszhang@boslogin02.rc.fas.harvard.edu>

* Smoke test (#74)

* run stac experiment

* Fixed yaml

* Test fail with bad input.

* Should fail.

* Corrected input - test should pass.

* Update demo.yaml - enable ik_only()

* Revert update demo.yaml

* IT'S WORKING

* Offset shape bugfix (#73)

* fix offset shape when saving

* update demo config

* linter

* fix demo.yaml

* update configs

* update config test

---------

Co-authored-by: Charles Zhang <charleszhang@boslogin02.rc.fas.harvard.edu>

* Configs etc - model not yet working.

* Offset shape bugfix (#73)

* fix offset shape when saving

* update demo config

* linter

* fix demo.yaml

* update configs

* update config test

---------

Co-authored-by: Charles Zhang <charleszhang@boslogin02.rc.fas.harvard.edu>

* Configs etc - model not yet working.

* Fix weird merge.

* Clean up synth_model config file.

* Remove TIME_BINS (which was a merge accident.)

* Fix smoke test.

* Fix smoke test.

* Clean up.

* Fixed root optimization, but still some debug code.

* Add root_kp_index

* Forgot model yaml.

* Reset rodent configs, enable synth config.

* Add synth_data smoke test.

* Missed data file.

* Clean up.

* Add root opt keypoint to model configs + clean up.

* Clean up.

* CR feedback.

* Add synth data generation program.

* Add comments.

---------

Co-authored-by: Charles Zhang <33401293+charles-zhng@users.noreply.github.com>
Co-authored-by: Charles Zhang <charleszhang@boslogin02.rc.fas.harvard.edu>
  • Loading branch information
3 people authored Nov 1, 2024
1 parent 11819ee commit f3980e4
Show file tree
Hide file tree
Showing 19 changed files with 949 additions and 14 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ jobs:
verbose: false
token: ${{ secrets.CODECOV_TOKEN }}

# Test probably delete this
# Smoke test. Shows end to end run with out crashing.
- name: Smoke Test
shell: bash -l {0}
run: python run_stac.py
run: python run_stac.py stac=stac_synth_data model=synth_data
594 changes: 594 additions & 0 deletions Mat-to-Nwb-Synth-Data.ipynb

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions configs/model/fly_tethered.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,3 @@ N_SAMPLE_FRAMES: 100
# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1

TIME_BINS: 0.02
4 changes: 2 additions & 2 deletions configs/model/fly_treadmill.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ KEYPOINT_INITIAL_OFFSETS:
l3: 0. 0. 0.
r3: 0. 0. 0.

ROOT_OPTIMIZATION_KEYPOINT: head

TRUNK_OPTIMIZATION_KEYPOINTS:
- 'head'
- 'thorax'
Expand Down Expand Up @@ -101,5 +103,3 @@ N_SAMPLE_FRAMES: 100
# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1

TIME_BINS: 0.02
2 changes: 2 additions & 0 deletions configs/model/mouse.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ KEYPOINT_INITIAL_OFFSETS:
Lisfranc_L: 0.0 0.0 0.0
MTP_R: 0.0 0.0 0.0

ROOT_OPTIMIZATION_KEYPOINT: Trunk

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Trunk"
- "HipL"
Expand Down
2 changes: 2 additions & 0 deletions configs/model/rodent.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

TRUNK_OPTIMIZATION_KEYPOINTS:
- "SpineF"
- "SpineL"
Expand Down
58 changes: 58 additions & 0 deletions configs/model/synth_data.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

MJCF_PATH: 'models/synth_model.xml'

# Frames per clip for transform.
N_FRAMES_PER_CLIP: 1

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
FTOL: 5.0e-03
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 1

KP_NAMES:
- part_0

ROOT_OPTIMIZATION_KEYPOINT: part_0

# The model sites used to register the keypoints.
KEYPOINT_MODEL_PAIRS:
part_0: base

# The initial offsets for each keypoint in meters.
KEYPOINT_INITIAL_OFFSETS:
part_0: 0 0 0.01

TRUNK_OPTIMIZATION_KEYPOINTS:
- part_0

INDIVIDUAL_PART_OPTIMIZATION:
model_base: [base]

# Color to use for each keypoint when visualizing the results
KEYPOINT_COLOR_PAIRS:
part_0: 0 .5 1 1

# What is the size of the animal you'd like to register, relative to the model?
SCALE_FACTOR: 1

# Multiplier to put the mocap data into the same scale as the data. Eg, if the
# mocap data is known to be in millimeters and the model is in meters, this is
# .001
MOCAP_SCALE_FACTOR: 1

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using this with M_REG_COEF.
SITES_TO_REGULARIZE:
- part_0

RENDER_FPS: 200

N_SAMPLE_FRAMES: 1

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1
2 changes: 1 addition & 1 deletion configs/stac/stac_fly_treadmill.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ik_only_path: "transform_treadmill.p"
# File is too large to commit
# DL from: https://datadryad.org/stash/dataset/doi:10.5061/dryad.mpg4f4r73
# Actual file: https://datadryad.org/stash/downloads/file_stream/3361804
data_path: "/tests/data/wt_berlin_linear_treadmill_dataset.csv"
data_path: "../tests/data/wt_berlin_linear_treadmill_dataset.csv"

n_fit_frames: 1800
skip_fit: False
Expand Down
12 changes: 12 additions & 0 deletions configs/stac/stac_synth_data.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
fit_offsets_path: "synth_fit.p"
ik_only_path: "synth_ik_only.p"
data_path: "tests/data/test_synth_1_frames.nwb"

n_fit_frames: 1
skip_fit_offsets: False
skip_ik_only: False

mujoco:
solver: newton
iterations: 1
ls_iterations: 4
224 changes: 224 additions & 0 deletions demos/create_synth_data.ipynb

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions models/synth_model.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<mujoco>
<option timestep=".001">
</option>

<default>
<joint type="hinge" axis="0 -1 0"/>
<geom type="capsule" size=".02"/>
</default>

<worldbody>
<light pos="0 -.4 1"/>
<camera name="fixed" pos="0 -1 0" xyaxes="1 0 0 0 0 1"/>
<body name="base" pos="0 0 .2">
<joint type="free" name="root"/>
<geom fromto="0 0 0 0 0 -.25" rgba="1 1 0 1"/>
</body>
</worldbody>
</mujoco>
3 changes: 2 additions & 1 deletion stac_mjx/compute_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def root_optimization(
mjx_model,
mjx_data,
kp_data: jp.ndarray,
root_kp_idx: int,
lb: jp.ndarray,
ub: jp.ndarray,
site_idxs: jp.ndarray,
Expand Down Expand Up @@ -50,7 +51,7 @@ def root_optimization(
# necessarily exactly so. The value of 3*18 is chosen for the
# rodent.xml, corresponding to the index of 'SpineL' keypoint.
# For the mouse model this should be 3*5, corresponding 'Trunk'
root_kp_idx = 3 * 18
# root_kp_idx = 3 * 18
# FLY_MODEL:
# root_kp_idx = 0
q0.at[:3].set(kp_data[frame, :][root_kp_idx : root_kp_idx + 3])
Expand Down
27 changes: 22 additions & 5 deletions stac_mjx/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,22 @@ def __init__(self, xml_path: str, cfg: DictConfig, kp_names: List[str]):
self._mj_model.body(i).name for i in range(self._mj_model.nbody)
]

joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)]
if "ROOT_OPTIMIZATION_KEYPOINT" in self.cfg.model:
self._root_kp_idx = self._kp_names.index(
self.cfg.model.ROOT_OPTIMIZATION_KEYPOINT
)
else:
self._root_kp_idx = -1

# Set up bounds and part_names based on joint ranges, taking into account the dimensionality of parameters
joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)]
self._lb, self._ub, self._part_names = _align_joint_dims(
self._mj_model.jnt_type, self._mj_model.jnt_range, joint_names
)

self._indiv_parts = self.part_opt_setup()

# Generate boolean flags for keypoints included in trunk optimization.
self._trunk_kps = jp.array(
[n in self.cfg.model.TRUNK_OPTIMIZATION_KEYPOINTS for n in kp_names],
)
Expand All @@ -113,7 +120,7 @@ def get_part_ids(parts: List) -> jp.ndarray:
[any(part in name for part in parts) for name in self._part_names]
)

if self.cfg.model.INDIVIDUAL_PART_OPTIMIZATION is None:
if "INDIVIDUAL_PART_OPTIMIZATION" not in self.cfg.model:
indiv_parts = []
else:
indiv_parts = jp.array(
Expand Down Expand Up @@ -224,11 +231,16 @@ def fit_offsets(self, kp_data):

# Begin optimization steps
# Skip root optimization if model is fixed (no free joint at root)
if self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:
if self._root_kp_idx == -1:
print(
"ROOT_OPTIMIZATION_KEYPOINT not specified, skipping Root Optimization."
)
elif self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:
mjx_data = compute_stac.root_optimization(
mjx_model,
mjx_data,
kp_data,
self._root_kp_idx,
self._lb,
self._ub,
self._body_site_idxs,
Expand Down Expand Up @@ -339,15 +351,20 @@ def mjx_setup(kp_data, mj_model):
)

# q_phase - root
if self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:
if self._root_kp_idx == -1:
print(
"Missing or invalid ROOT_OPTIMIZATION_KEYPOINT, skipping root_optimization()"
)
elif self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:
vmap_root_opt = jax.vmap(
compute_stac.root_optimization,
in_axes=(0, 0, 0, None, None, None, None),
in_axes=(0, 0, 0, None, None, None, None, None),
)
mjx_data = vmap_root_opt(
mjx_model,
mjx_data,
batched_kp_data,
self._root_kp_idx,
self._lb,
self._ub,
self._body_site_idxs,
Expand Down
2 changes: 2 additions & 0 deletions tests/configs/model/test_mouse.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ KEYPOINT_INITIAL_OFFSETS:
Lisfranc_L: 0.0 0.0 0.0
MTP_R: 0.0 0.0 0.0

ROOT_OPTIMIZATION_KEYPOINT: Trunk

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Trunk"
- "HipL"
Expand Down
2 changes: 2 additions & 0 deletions tests/configs/model/test_rodent.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Spine"
- "Hip"
Expand Down
3 changes: 2 additions & 1 deletion tests/configs/model/test_rodent_label3d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ N_FRAMES_PER_CLIP: 360
# presumed to be derived from label3d:
KP_NAMES_LABEL3D_PATH: "tests/data/rat23.mat"


# The model sites used to register the keypoints.
KEYPOINT_MODEL_PAIRS:
AnkleL: lower_leg_L
Expand Down Expand Up @@ -61,6 +60,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Spine"
- "Hip"
Expand Down
2 changes: 2 additions & 0 deletions tests/configs/model/test_rodent_less_kp_names.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Spine"
- "Hip"
Expand Down
2 changes: 2 additions & 0 deletions tests/configs/model/test_rodent_no_kp_names.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Spine"
- "Hip"
Expand Down
Binary file added tests/data/test_synth_1_frames.nwb
Binary file not shown.

0 comments on commit f3980e4

Please sign in to comment.