Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor configs #54

Merged
merged 11 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 18 additions & 29 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,27 @@ Our rendering functions support multiple backends: `egl`, `glfw`, and `osmesa`.
## Usage
1. Update the .yaml files in `config/` with the proper information (details WIP).

2. Run stac-mjx with its basic api: `load_configs` for loading configs and `run_stac` for the keypoint registration. Below is an example script, found in `demos/use_api.ipynb`. A CLI script using the rodent model is also provided at `run_rodent.py`
2. Run stac-mjx with its basic api: `load_configs` for loading configs and `run_stac` for the keypoint registration. Below is an example script, found in `demos/use_api.ipynb`. A CLI script is also provided at `run_stac.py`. Refer to [hydra documention](https://hydra.cc/docs/advanced/override_grammar/basic/) for formatting args to override configs.

```python
from stac_mjx import main
from stac_mjx import utils
import stac_mjx
from pathlib import Path
import os
# XLA flags for Nvidia GPU
if xla_bridge.get_backend().platform == "gpu":
os.environ["XLA_FLAGS"] = (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=True "
)

# Set base path to the parent directory of your config files
base_path = Path.cwd()
stac_config_path = base_path / "demos/demo_stac.yaml"
model_config_path = base_path / "configs/rodent.yaml"

# Enable XLA flags if on GPU
stac_mjx.enable_xla_flags()

# Choose parent directory as base path for data files
base_path = Path.cwd().parent

# Load configs
stac_cfg, model_cfg = main.load_configs(stac_config_path, model_config_path)
cfg = stac_mjx.load_configs(base_path / "configs")

# Load data
data_path = base_path / cfg.paths.data_path
kp_data, sorted_kp_names = utils.load_data(data_path, model_cfg)
kp_data, sorted_kp_names = stac_mjx.load_data(cfg, base_path)

# Run stac
fit_path, transform_path = main.run_stac(
stac_cfg,
model_cfg,
fit_path, transform_path = stac_mjx.run_stac(
cfg,
kp_data,
sorted_kp_names,
base_path=base_path
Expand All @@ -73,16 +64,14 @@ Our rendering functions support multiple backends: `egl`, `glfw`, and `osmesa`.

3. Render the resulting data using `mujoco_viz()` (example notebook found in `demos/viz_usage.ipynb`):
```python
import os
import mediapy as media
import stac_mjx

from stac_mjx.viz import viz_stac
from stac_mjx import main
import mediapy as media
from pathlib import Path
import os

base_path = Path.cwd()
stac_config_path = base_path / "demos/demo_stac.yaml"
model_config_path = base_path / "configs/rodent.yaml"
cfg = stac_mjx.load_configs(base_path / "configs")

stac_cfg, model_cfg = main.load_configs(stac_config_path, model_config_path)

Expand All @@ -91,10 +80,10 @@ Our rendering functions support multiple backends: `egl`, `glfw`, and `osmesa`.
save_path = base_path / "videos/direct_render.mp4"

# Call mujoco_viz
frames = viz_stac(data_path, stac_cfg, model_cfg, n_frames, save_path, start_frame=0, camera="close_profile", base_path=Path.cwd().parent)
frames = viz_stac(data_path, cfg, n_frames, save_path, start_frame=0, camera="close_profile", base_path=Path.cwd().parent)

# Show the video in the notebook (it is also saved to the save_path)
media.show_video(frames, fps=model_cfg["RENDER_FPS"])
media.show_video(frames, fps=cfg.model.RENDER_FPS)
```

4. If the rendering is poor, it's likely that some hyperparameter tuning is necessary. (details WIP)
4 changes: 4 additions & 0 deletions configs/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
defaults:
- stac: demo
- model: rodent
- _self_
Comment on lines +1 to +4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Configuration looks shipshape, but don't forget the newline at the end, partner.

The configuration using Hydra's features is spot on. However, it's best practice to include a newline at the end of the file to avoid any issues with version control systems or text editors that might get finicky about such things.

Add a newline at the end of the file to fix the issue flagged by yamllint:

  - _self_
+ 
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
defaults:
- stac: demo
- model: rodent
- _self_
defaults:
- stac: demo
- model: rodent
- _self_
Tools
yamllint

[error] 4-4: no new line character at the end of file

(new-line-at-end-of-file)

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion configs/stac_mouse.yaml → configs/stac/stac_mouse.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ fit_path: "fit_mouse.p"
transform_path: "transform_mouse.p"
data_path: "tests/data/test_mouse_mocap_3600_frames.h5"

n_fit_frames: 3600
n_fit_frames: 250
skip_fit: False
skip_transform: True

Expand Down
199 changes: 98 additions & 101 deletions demos/api_usage.ipynb

Large diffs are not rendered by default.

25 changes: 14 additions & 11 deletions demos/viz_usage.ipynb

Large diffs are not rendered by default.

49 changes: 0 additions & 49 deletions run_mouse.py

This file was deleted.

49 changes: 0 additions & 49 deletions run_rodent.py

This file was deleted.

30 changes: 30 additions & 0 deletions run_stac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""CLI script for running rodent skeletal registration"""

import logging
import hydra
from omegaconf import DictConfig, OmegaConf

import stac_mjx


def load_and_run_stac(cfg):
kp_data, sorted_kp_names = stac_mjx.load_data(cfg)

fit_path, transform_path = stac_mjx.run_stac(cfg, kp_data, sorted_kp_names)

logging.info(
f"Run complete. \n fit path: {fit_path} \n transform path: {transform_path}"
)
Comment on lines +10 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check the logging format, partner.

Y'all got a nice logging statement here, but consider using structured logging instead of f-string for better scalability and filtering in production environments. Here's a tweak for ya:

- logging.info(
-     f"Run complete. \n fit path: {fit_path} \n transform path: {transform_path}"
- )
+ logging.info("Run complete.", extra={'fit_path': fit_path, 'transform_path': transform_path})
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def load_and_run_stac(cfg):
kp_data, sorted_kp_names = stac_mjx.load_data(cfg)
fit_path, transform_path = stac_mjx.run_stac(cfg, kp_data, sorted_kp_names)
logging.info(
f"Run complete. \n fit path: {fit_path} \n transform path: {transform_path}"
)
def load_and_run_stac(cfg):
kp_data, sorted_kp_names = stac_mjx.load_data(cfg)
fit_path, transform_path = stac_mjx.run_stac(cfg, kp_data, sorted_kp_names)
logging.info("Run complete.", extra={'fit_path': fit_path, 'transform_path': transform_path})



@hydra.main(config_path="./configs", config_name="config", version_base=None)
def hydra_entry(cfg: DictConfig):
logging.info(f"cfg: {OmegaConf.to_yaml(cfg)}")

stac_mjx.enable_xla_flags()

load_and_run_stac(cfg)


if __name__ == "__main__":
hydra_entry()
Loading
Loading