Skip to content

Commit

Permalink
basic python api (#31)
Browse files Browse the repository at this point in the history
Implement basic python api functions, load_configs and run_stac, along with unit tests and various small fixes. Includes refactoring, example notebooks and an updated README. 

---------

Co-authored-by: Charles Zhang <charleszhang216@mgmail.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Talmo Pereira <talmo@salk.edu>
  • Loading branch information
4 people authored Aug 15, 2024
1 parent 0d5e07b commit 04568a1
Show file tree
Hide file tree
Showing 30 changed files with 1,263 additions and 1,071 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
lcov.info

.DS_Store
snippets*
# error data files
Expand All @@ -7,7 +9,6 @@ snippets*
*.p
videos/*
!videos/.dummy
*.mat
outputs/
*.mp4

Expand Down
80 changes: 62 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,85 @@ Implementation of [STAC](https://ieeexplore.ieee.org/document/7030016) using [MJ

## Installation
stac-mjx relies on many prerequisites, therefore we suggest installing in a new conda environment, using the provided `environment.yaml`:

Create and activate the `stac-mjx-env` environment:
[Local installation before package is officially published]
1. Clone the repository `git clone https://github.com/talmolab/stac-mjx.git` and `cd` into it
2. Create and activate the `stac-mjx-env` environment:

```
conda env create -f environment.yaml
conda activate stac-mjx-env
```

## Usage
1. Update the .yaml files in `config/` with the proper information (details WIP).

2. For new data, first run stac on just a small subset of the data with

`python stac_mjx/main.py test.skip_transform=True`

Note: this currently will fail w/o supplying a data file.

3. Render the resulting data using `mujoco_viz()` from within `viz_usage.ipynb`. Currently, this uses headless rendering on CPU via `osmesa`, which requires its own setup. To set up (currently on supported on Linux), execute the following commands sequentially:
```
Our rendering functions support multiple backends: `egl`, `glfw`, and `osmesa`. We show `osmesa` setup as it supports headless rendering, which is common in remote/cluster setups. To set up (currently on supported on Linux), execute the following commands sequentially:
```bash
sudo apt-get install libglfw3 libglew2.0 libgl1-mesa-glx libosmesa6
conda install -c conda-forge glew
conda install -c conda-forge mesalib
conda install -c anaconda mesa-libgl-cos6-x86_64
conda install -c menpo glfw3
```
Finally, set the following environment variables, and reactivate the conda environment:
```
```bash
conda env config vars set MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa
conda deactivate && conda activate base
```
To ensure all of the above changes are encapsulated in your Jupyter kernel, a create a new kernel with:
```
To ensure all of the above changes are encapsulated in your Jupyter kernel, create a new kernel with:
```bash
conda install ipykernel
python -m ipykernel install --user --name stac-mjx-env --display-name "Python (stac-mjx-env)"
```

4. After tuning parameters and confirming the small clip is processed well, run through the whole thing with
`python stac-mjx/main.py`

## Usage
1. Update the .yaml files in `config/` with the proper information (details WIP).

2. Run stac-mjx with its basic api: `load_configs` for loading configs and `run_stac` for the keypoint registration. Below is an example script, found in `demos/use_api.ipynb`.

```python
from stac_mjx import main
from stac_mjx import utils
from pathlib import Path
# Set base path to the parent directory of your config files
base_path = Path.cwd()
stac_config_path = base_path / "demos/demo_stac.yaml"
model_config_path = base_path / "configs/rodent.yaml"
# Load configs
cfg = main.load_configs(stac_config_path, model_config_path)
# Load data
data_path = base_path / cfg.paths.data_path
kp_data = utils.load_data(data_path, utils.params)
# Run stac
fit_path, transform_path = main.run_stac(cfg, kp_data, base_path)
```

3. Render the resulting data using `mujoco_viz()` (example notebook found in `demos/viz_usage.ipynb`):
```python
import os
import mediapy as media

from stac_mjx.viz import mujoco_viz
from stac_mjx import main
from stac_mjx import utils

stac_config_path = "../configs/stac.yaml"
model_config_path = "../configs/rodent.yaml"

cfg = main.load_configs(stac_config_path, model_config_path)

xml_path = "../models/rodent.xml"
data_path = "../output.p"
n_frames=250
save_path="../videos/direct_render.mp4"

# Call mujoco_viz
frames = mujoco_viz(data_path, xml_path, n_frames, save_path, start_frame=0)

# Show the video in the notebook (it is also saved to the save_path)
media.show_video(frames, fps=utils.params["RENDER_FPS"])
```
4. If the rendering is poor, it's likely that some hyperparameter tuning is necessary. (details WIP)
19 changes: 10 additions & 9 deletions configs/rodent.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Frames per clip for transform.
N_FRAMES_PER_CLIP: 360
N_FRAMES_PER_CLIP: 250

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
FTOL: 5.0e-03
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

KP_NAMES:
- 'Snout'
Expand Down Expand Up @@ -176,14 +185,6 @@ RENDER_FPS: 50

N_SAMPLE_FRAMES: 100

# Tolerance for the optimizations of the full model, limb, and root.
FTOL: 1.0e-02
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1
Expand Down
20 changes: 7 additions & 13 deletions configs/stac.yaml
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
paths:
model_config: "rodent"
xml: "././models/rodent.xml"
fit_path: "fit_sq.p"
transform_path: "transform_sq.p"
data_path: "tests/data/test_mocap_1000_frames.nwb"

n_fit_frames: 1000

sampler: "first" # first, every, or random
first_start: 0 # starting frame for "first" sampler
xml: "models/rodent.xml"
fit_path: "fit.p"
transform_path: "transform.p"
data_path: "tests/data/test_rodent_mocap_1000_frames.nwb"

# Should this be included?
test:
skip_fit: False
skip_transform: False
n_fit_frames: 1000
skip_fit: False
skip_transform: True

mujoco:
solver: "newton"
Expand Down
Empty file added conftest.py
Empty file.
Loading

0 comments on commit 04568a1

Please sign in to comment.