Skip to content

Commit

Permalink
Refactor configs (#54)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Charles Zhang <charleszhang@holylogin04.rc.fas.harvard.edu>
  • Loading branch information
charles-zhng and Charles Zhang authored Sep 18, 2024
1 parent 20d4617 commit ea9c1b6
Show file tree
Hide file tree
Showing 30 changed files with 371 additions and 429 deletions.
47 changes: 18 additions & 29 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,27 @@ Our rendering functions support multiple backends: `egl`, `glfw`, and `osmesa`.
## Usage
1. Update the .yaml files in `config/` with the proper information (details WIP).

2. Run stac-mjx with its basic api: `load_configs` for loading configs and `run_stac` for the keypoint registration. Below is an example script, found in `demos/use_api.ipynb`. A CLI script using the rodent model is also provided at `run_rodent.py`
2. Run stac-mjx with its basic api: `load_configs` for loading configs and `run_stac` for the keypoint registration. Below is an example script, found in `demos/use_api.ipynb`. A CLI script is also provided at `run_stac.py`. Refer to [hydra documention](https://hydra.cc/docs/advanced/override_grammar/basic/) for formatting args to override configs.

```python
from stac_mjx import main
from stac_mjx import utils
import stac_mjx
from pathlib import Path
import os
# XLA flags for Nvidia GPU
if xla_bridge.get_backend().platform == "gpu":
os.environ["XLA_FLAGS"] = (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=True "
)
# Set base path to the parent directory of your config files
base_path = Path.cwd()
stac_config_path = base_path / "demos/demo_stac.yaml"
model_config_path = base_path / "configs/rodent.yaml"
# Enable XLA flags if on GPU
stac_mjx.enable_xla_flags()
# Choose parent directory as base path for data files
base_path = Path.cwd().parent
# Load configs
stac_cfg, model_cfg = main.load_configs(stac_config_path, model_config_path)
cfg = stac_mjx.load_configs(base_path / "configs")
# Load data
data_path = base_path / cfg.paths.data_path
kp_data, sorted_kp_names = utils.load_data(data_path, model_cfg)
kp_data, sorted_kp_names = stac_mjx.load_data(cfg, base_path)
# Run stac
fit_path, transform_path = main.run_stac(
stac_cfg,
model_cfg,
fit_path, transform_path = stac_mjx.run_stac(
cfg,
kp_data,
sorted_kp_names,
base_path=base_path
Expand All @@ -73,16 +64,14 @@ Our rendering functions support multiple backends: `egl`, `glfw`, and `osmesa`.

3. Render the resulting data using `mujoco_viz()` (example notebook found in `demos/viz_usage.ipynb`):
```python
import os
import mediapy as media
import stac_mjx

from stac_mjx.viz import viz_stac
from stac_mjx import main
import mediapy as media
from pathlib import Path
import os

base_path = Path.cwd()
stac_config_path = base_path / "demos/demo_stac.yaml"
model_config_path = base_path / "configs/rodent.yaml"
cfg = stac_mjx.load_configs(base_path / "configs")

stac_cfg, model_cfg = main.load_configs(stac_config_path, model_config_path)

Expand All @@ -91,10 +80,10 @@ Our rendering functions support multiple backends: `egl`, `glfw`, and `osmesa`.
save_path = base_path / "videos/direct_render.mp4"

# Call mujoco_viz
frames = viz_stac(data_path, stac_cfg, model_cfg, n_frames, save_path, start_frame=0, camera="close_profile", base_path=Path.cwd().parent)
frames = viz_stac(data_path, cfg, n_frames, save_path, start_frame=0, camera="close_profile", base_path=Path.cwd().parent)

# Show the video in the notebook (it is also saved to the save_path)
media.show_video(frames, fps=model_cfg["RENDER_FPS"])
media.show_video(frames, fps=cfg.model.RENDER_FPS)
```
4. If the rendering is poor, it's likely that some hyperparameter tuning is necessary. (details WIP)
4 changes: 4 additions & 0 deletions configs/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
defaults:
- stac: demo
- model: rodent
- _self_
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion configs/stac_mouse.yaml → configs/stac/stac_mouse.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ fit_path: "fit_mouse.p"
transform_path: "transform_mouse.p"
data_path: "tests/data/test_mouse_mocap_3600_frames.h5"

n_fit_frames: 3600
n_fit_frames: 250
skip_fit: False
skip_transform: True

Expand Down
199 changes: 98 additions & 101 deletions demos/api_usage.ipynb

Large diffs are not rendered by default.

25 changes: 14 additions & 11 deletions demos/viz_usage.ipynb

Large diffs are not rendered by default.

49 changes: 0 additions & 49 deletions run_mouse.py

This file was deleted.

49 changes: 0 additions & 49 deletions run_rodent.py

This file was deleted.

30 changes: 30 additions & 0 deletions run_stac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""CLI script for running rodent skeletal registration"""

import logging
import hydra
from omegaconf import DictConfig, OmegaConf

import stac_mjx


def load_and_run_stac(cfg):
kp_data, sorted_kp_names = stac_mjx.load_data(cfg)

fit_path, transform_path = stac_mjx.run_stac(cfg, kp_data, sorted_kp_names)

logging.info(
f"Run complete. \n fit path: {fit_path} \n transform path: {transform_path}"
)


@hydra.main(config_path="./configs", config_name="config", version_base=None)
def hydra_entry(cfg: DictConfig):
logging.info(f"cfg: {OmegaConf.to_yaml(cfg)}")

stac_mjx.enable_xla_flags()

load_and_run_stac(cfg)


if __name__ == "__main__":
hydra_entry()
Loading

0 comments on commit ea9c1b6

Please sign in to comment.