Add support for special VecEnv (brax, IsaacSim, ...) (#484)
* Allow changing the default VecEnv

* Add default vec env cls argument

* Fix normalization loading for objects

* Save the exact command used when training, update changelog

* Update HF API usage

* Fix log-interval default behavior and upgrade to Gymnasium v1.1

* Add HF token to CI
araffin authored Mar 5, 2025
1 parent 2e99bec commit 06ab062
Showing 10 changed files with 69 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -14,6 +14,7 @@ jobs:
     env:
       TERM: xterm-256color
       FORCE_COLOR: 1
+      HF_TOKEN: ${{ secrets.HF_TOKEN }}
     # Skip CI if [ci skip] in the commit message
     if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
     runs-on: ubuntu-latest
2 changes: 1 addition & 1 deletion .github/workflows/trained_agents.yml
@@ -14,7 +14,7 @@ jobs:
     env:
       TERM: xterm-256color
       FORCE_COLOR: 1
-
+      HF_TOKEN: ${{ secrets.HF_TOKEN }}
     # Skip CI if [ci skip] in the commit message
     if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
     runs-on: ubuntu-latest
18 changes: 18 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,21 @@
+## Release 2.6.0a2 (WIP)
+
+### Breaking Changes
+- Upgraded to SB3 >= 2.6.0
+
+### New Features
+- Save the exact command line used to launch a training run
+- Added support for special vectorized envs (e.g. Brax, IsaacSim) by allowing users to override the `VecEnv` class used to instantiate the env in the `ExperimentManager`
+- Allow disabling auto-logging by passing `--log-interval -2` (useful when logging things manually)
+- Added Gymnasium v1.1 support
+
+### Bug fixes
+- Fixed use of old HF API in `get_hf_trained_models()`
+
+### Documentation
+
+### Other
+
 ## Release 2.5.0 (2025-01-27)
 
 ### Breaking Changes
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,7 +1,7 @@
 gym==0.26.2
-stable-baselines3[extra,tests,docs]>=2.5.0,<3.0
+stable-baselines3[extra,tests,docs]>=2.6.0a2,<3.0
 box2d-py==2.3.8
-pybullet_envs_gymnasium>=0.5.0
+pybullet_envs_gymnasium>=0.6.0
 # minigrid
 cloudpickle>=2.2.1
 # optuna plots:
1 change: 1 addition & 0 deletions rl_zoo3/enjoy.py
@@ -162,6 +162,7 @@ def enjoy() -> None:  # noqa: C901
         should_render=not args.no_render,
         hyperparams=hyperparams,
         env_kwargs=env_kwargs,
+        vec_env_cls=ExperimentManager.default_vec_env_cls,
     )

     kwargs = dict(seed=args.seed)
20 changes: 20 additions & 0 deletions rl_zoo3/exp_manager.py
@@ -2,6 +2,7 @@
 import importlib
 import os
 import pickle as pkl
+import sys
 import time
 import warnings
 from collections import OrderedDict
@@ -60,6 +61,9 @@ class ExperimentManager:
     Please take a look at `train.py` to have the details for each argument.
     """
 
+    # For special VecEnv like Brax, IsaacLab, ...
+    default_vec_env_cls: Optional[type[VecEnv]] = None
+
     def __init__(
         self,
         args: argparse.Namespace,
@@ -122,6 +126,10 @@ def __init__(
         self.optimization_log_path = optimization_log_path
 
         self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
+        # Override the default VecEnv class when one is set
+        if self.default_vec_env_cls is not None:
+            self.vec_env_class = self.default_vec_env_cls
+
         self.vec_env_wrapper: Optional[Callable] = None
 
         self.vec_env_kwargs: dict[str, Any] = {}
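Example (illustrative, not part of the commit): how a user could plug a special VecEnv into the zoo via the new hook; `BraxVecEnv` and its module are hypothetical placeholders for an env-specific `VecEnv` subclass.

    from rl_zoo3.exp_manager import ExperimentManager
    from rl_zoo3.train import train

    from my_brax_wrappers import BraxVecEnv  # hypothetical wrapper module

    # Must be set before the ExperimentManager is created: it then replaces
    # the usual {"dummy": DummyVecEnv, "subproc": SubprocVecEnv} selection.
    ExperimentManager.default_vec_env_cls = BraxVecEnv
    train()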
@@ -224,8 +232,13 @@ def learn(self, model: BaseAlgorithm) -> None:
         :param model: an initialized RL model
         """
         kwargs: dict[str, Any] = {}
+        # log_interval == -1 -> use the algorithm default
+        # log_interval <= -2 -> disable auto-logging
         if self.log_interval > -1:
             kwargs = {"log_interval": self.log_interval}
+        elif self.log_interval < -1:
+            # Deactivate auto-logging, helpful when using a callback like LogEveryNTimesteps
+            kwargs = {"log_interval": None}
 
         if len(self.callbacks) > 0:
             kwargs["callback"] = self.callbacks
@@ -288,6 +301,13 @@ def _save_config(self, saved_hyperparams: dict[str, Any]) -> None:
             ordered_args = OrderedDict([(key, vars(self.args)[key]) for key in sorted(vars(self.args).keys())])
             yaml.dump(ordered_args, f)
 
+        # Save the exact command used to launch the training
+        command = "python3 " + " ".join(sys.argv)
+        # sys.orig_argv (Python 3.10+) also contains the interpreter and its options
+        if hasattr(sys, "orig_argv"):
+            command = " ".join(sys.orig_argv)
+        (Path(self.params_path) / "command.txt").write_text(command)
+
         print(f"Log path: {self.save_path}")
     def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
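Example (illustrative, not part of the commit): the `sys.argv` vs `sys.orig_argv` difference the code above handles, for a run launched as `python3 -O train.py --algo ppo`.

    import sys

    print(sys.argv)  # ['train.py', '--algo', 'ppo'] -- interpreter and its flags are lost
    if hasattr(sys, "orig_argv"):  # Python 3.10+
        # Full command line, including the interpreter, replayable verbatim
        print(sys.orig_argv)  # ['python3', '-O', 'train.py', '--algo', 'ppo']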
7 changes: 6 additions & 1 deletion rl_zoo3/train.py
@@ -32,7 +32,12 @@ def train() -> None:
     )
     parser.add_argument("-n", "--n-timesteps", help="Overwrite the number of timesteps", default=-1, type=int)
     parser.add_argument("--num-threads", help="Number of threads for PyTorch (-1 to use default)", default=-1, type=int)
-    parser.add_argument("--log-interval", help="Override log interval (default: -1, no change)", default=-1, type=int)
+    parser.add_argument(
+        "--log-interval",
+        help="Override log interval (default: -1, no change; -2: disable auto-logging, useful with a custom logging frequency)",
+        default=-1,
+        type=int,
+    )
     parser.add_argument(
         "--eval-freq",
         help="Evaluate the agent every n steps (if negative, no evaluation). "
25 changes: 17 additions & 8 deletions rl_zoo3/utils.py
@@ -208,6 +208,8 @@ def create_test_env(
     should_render: bool = True,
     hyperparams: Optional[dict[str, Any]] = None,
     env_kwargs: Optional[dict[str, Any]] = None,
+    vec_env_cls: Optional[type[VecEnv]] = None,
+    vec_env_kwargs: Optional[dict[str, Any]] = None,
 ) -> VecEnv:
"""
Create environment for testing a trained agent
Expand All @@ -220,6 +222,8 @@ def create_test_env(
     :param should_render: For Pybullet env, display the GUI
     :param hyperparams: Additional hyperparams (ex: n_stack)
     :param env_kwargs: Optional keyword argument to pass to the env constructor
+    :param vec_env_cls: ``VecEnv`` class constructor.
+    :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
     :return:
     """
     # Create the environment and wrap it if necessary
@@ -231,9 +235,9 @@ def create_test_env(
if "env_wrapper" in hyperparams.keys():
del hyperparams["env_wrapper"]

vec_env_kwargs: dict[str, Any] = {}
# Avoid potential shared memory issue
vec_env_cls = SubprocVecEnv if n_envs > 1 else DummyVecEnv
if vec_env_cls is None:
vec_env_cls = SubprocVecEnv if n_envs > 1 else DummyVecEnv

     # Fix for gym 0.26, to keep old behavior
     env_kwargs = env_kwargs or {}
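Example (illustrative, not part of the commit): calling the extended helper with an explicit VecEnv class; the env id is arbitrary and only the two new keyword arguments come from this diff.

    from stable_baselines3.common.vec_env import DummyVecEnv

    from rl_zoo3.utils import create_test_env

    # An explicit vec_env_cls bypasses the SubprocVecEnv/DummyVecEnv heuristic
    env = create_test_env(
        "CartPole-v1",
        n_envs=1,
        should_render=False,
        vec_env_cls=DummyVecEnv,
        vec_env_kwargs={},  # forwarded to the VecEnv constructor
    )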
@@ -349,21 +353,24 @@ def get_hf_trained_models(organization: str = "sb3", check_filename: bool = False
     for model in models:
         # Try to extract algorithm and environment id from model card
         try:
-            env_id = model.cardData["model-index"][0]["results"][0]["dataset"]["name"]
-            algo = model.cardData["model-index"][0]["name"].lower()
+            assert model.card_data is not None
+            env_id = model.card_data["model-index"][0]["results"][0]["dataset"]["name"]
+            algo = model.card_data["model-index"][0]["name"].lower()
             # RecurrentPPO alias is "ppo_lstm" in the rl zoo
             if algo == "recurrentppo":
                 algo = "ppo_lstm"
-        except (KeyError, IndexError):
-            print(f"Skipping {model.modelId}")
+        except (KeyError, IndexError, AssertionError):
+            print(f"Skipping {model.id}")
             continue  # skip model if the env id or algo name could not be found
 
         env_name = EnvironmentName(env_id)
         model_name = ModelName(algo, env_name)
 
         # check if there is a model file in the repo
-        if check_filename and not any(f.rfilename == model_name.filename for f in api.model_info(model.modelId).siblings):
-            continue  # skip model if the repo contains no properly named model file
+        if check_filename:
+            maybe_siblings = api.model_info(model.id).siblings
+            if maybe_siblings and not any(f.rfilename == model_name.filename for f in maybe_siblings):
+                continue  # skip model if the repo contains no properly named model file
 
         trained_models[model_name] = (algo, env_id)
 
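Example (illustrative, not part of the commit): the renamed huggingface_hub fields this hunk migrates to; `id` and `card_data` replace the deprecated `modelId` and `cardData`, and the listing parameters shown are assumptions.

    from huggingface_hub import HfApi

    api = HfApi()
    for model in api.list_models(author="sb3", cardData=True, limit=5):
        print(model.id)
        if model.card_data is not None:  # card metadata may be absent
            print(model.card_data["model-index"][0]["name"])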
@@ -422,6 +429,8 @@ def get_saved_hyperparams(
             normalize_kwargs = eval(hyperparams["normalize"])
             if test_mode:
                 normalize_kwargs["norm_reward"] = norm_reward
+        elif isinstance(hyperparams["normalize"], dict):
+            normalize_kwargs = hyperparams["normalize"]
         else:
             normalize_kwargs = {"norm_obs": hyperparams["normalize"], "norm_reward": norm_reward}
         hyperparams["normalize_kwargs"] = normalize_kwargs
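Example (illustrative, not part of the commit): the three `normalize` forms `get_saved_hyperparams` now handles.

    # 1. Boolean flag: norm_obs taken from the flag, norm_reward from the caller
    hyperparams = {"normalize": True}
    # 2. String from a hyperparameter file, eval'd into kwargs
    hyperparams = {"normalize": "dict(norm_obs=True, norm_reward=False)"}
    # 3. Already-parsed dict, as stored in a saved config.yml (the case this commit fixes)
    hyperparams = {"normalize": {"norm_obs": True, "norm_reward": False}}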
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
@@ -1 +1 @@
-2.5.0
+2.6.0a2
4 changes: 2 additions & 2 deletions setup.py
@@ -15,8 +15,8 @@
 See https://github.com/DLR-RM/rl-baselines3-zoo
 """
 install_requires = [
-    "sb3_contrib>=2.5.0,<3.0",
-    "gymnasium>=0.29.1,<1.1.0",
+    "sb3_contrib>=2.6.0a2,<3.0",
+    "gymnasium>=0.29.1,<1.2.0",
     "huggingface_sb3>=3.0,<4.0",
     "tqdm",
     "rich",
