Add support for special VecEnv (brax, IsaacSim, ...) #484

Merged 7 commits on Mar 5, 2025
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -14,6 +14,7 @@ jobs:
env:
TERM: xterm-256color
FORCE_COLOR: 1
+HF_TOKEN: ${{ secrets.HF_TOKEN }}
# Skip CI if [ci skip] in the commit message
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
runs-on: ubuntu-latest
2 changes: 1 addition & 1 deletion .github/workflows/trained_agents.yml
@@ -14,7 +14,7 @@ jobs:
env:
TERM: xterm-256color
FORCE_COLOR: 1
-
+HF_TOKEN: ${{ secrets.HF_TOKEN }}
# Skip CI if [ci skip] in the commit message
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
runs-on: ubuntu-latest
18 changes: 18 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,21 @@
+## Release 2.6.0a2 (WIP)
+
+### Breaking Changes
+- Upgraded to SB3 >= 2.6.0
+
+### New Features
+- Save the exact command line used to launch a training
+- Added support for special vectorized envs (e.g. Brax, IsaacSim) by allowing the `VecEnv` class used to instantiate the env in the `ExperimentManager` to be overridden
+- Allow disabling auto-logging by passing `--log-interval -2` (useful when logging things manually)
+- Added Gymnasium v1.1 support
+
+### Bug fixes
+- Fixed use of old HF API in `get_hf_trained_models()`
+
+### Documentation
+
+### Other
+
## Release 2.5.0 (2025-01-27)

### Breaking Changes
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,7 +1,7 @@
gym==0.26.2
-stable-baselines3[extra,tests,docs]>=2.5.0,<3.0
+stable-baselines3[extra,tests,docs]>=2.6.0a2,<3.0
box2d-py==2.3.8
-pybullet_envs_gymnasium>=0.5.0
+pybullet_envs_gymnasium>=0.6.0
# minigrid
cloudpickle>=2.2.1
# optuna plots:
1 change: 1 addition & 0 deletions rl_zoo3/enjoy.py
@@ -162,6 +162,7 @@ def enjoy() -> None: # noqa: C901
should_render=not args.no_render,
hyperparams=hyperparams,
env_kwargs=env_kwargs,
+vec_env_cls=ExperimentManager.default_vec_env_cls,
)

kwargs = dict(seed=args.seed)
20 changes: 20 additions & 0 deletions rl_zoo3/exp_manager.py
@@ -2,6 +2,7 @@
import importlib
import os
import pickle as pkl
+import sys
import time
import warnings
from collections import OrderedDict
@@ -60,6 +61,9 @@ class ExperimentManager:
Please take a look at `train.py` to have the details for each argument.
"""

+# For special VecEnv like Brax, IsaacLab, ...
+default_vec_env_cls: Optional[type[VecEnv]] = None
+
def __init__(
self,
args: argparse.Namespace,
@@ -122,6 +126,10 @@ def __init__(
self.optimization_log_path = optimization_log_path

self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
+# Override the dummy/subproc choice when a special VecEnv class is set
+if self.default_vec_env_cls is not None:
+self.vec_env_class = self.default_vec_env_cls
+
self.vec_env_wrapper: Optional[Callable] = None

self.vec_env_kwargs: dict[str, Any] = {}
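The two hunks above are the heart of the PR: a class-level `default_vec_env_cls` that, when set, takes precedence over the `--vec-env dummy|subproc` choice. A minimal sketch of how a downstream launcher script could use it; `BraxVecEnv` and its module are hypothetical stand-ins for any class implementing the SB3 `VecEnv` interface, not part of this PR:

import rl_zoo3.train
from rl_zoo3.exp_manager import ExperimentManager
# Hypothetical wrapper exposing a Brax simulator as an SB3 VecEnv
from my_brax_integration import BraxVecEnv

# Every env created by the ExperimentManager now uses BraxVecEnv
# instead of DummyVecEnv/SubprocVecEnv
ExperimentManager.default_vec_env_cls = BraxVecEnv
# Arguments are still parsed from the command line, as in train.py
rl_zoo3.train.train()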
@@ -224,8 +232,13 @@ def learn(self, model: BaseAlgorithm) -> None:
:param model: an initialized RL model
"""
kwargs: dict[str, Any] = {}
+# log_interval == -1 -> use the algorithm default
+# log_interval <= -2 -> disable auto-logging
if self.log_interval > -1:
kwargs = {"log_interval": self.log_interval}
+elif self.log_interval < -1:
+# Deactivate auto-logging, helpful when using a callback like LogEveryNTimesteps
+kwargs = {"log_interval": None}

if len(self.callbacks) > 0:
kwargs["callback"] = self.callbacks
@@ -288,6 +301,13 @@ def _save_config(self, saved_hyperparams: dict[str, Any]) -> None:
ordered_args = OrderedDict([(key, vars(self.args)[key]) for key in sorted(vars(self.args).keys())])
yaml.dump(ordered_args, f)

+# Save the exact command used to launch the training
+command = "python3 " + " ".join(sys.argv)
+# sys.orig_argv (interpreter + its arguments) is only available in Python 3.10+
+if hasattr(sys, "orig_argv"):
+command = " ".join(sys.orig_argv)
+(Path(self.params_path) / "command.txt").write_text(command)

print(f"Log path: {self.save_path}")

def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
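The `command.txt` logic above prefers `sys.orig_argv` because `sys.argv` loses the interpreter and its flags. A quick illustration of the difference (output values are examples; `sys.orig_argv` requires Python 3.10+):

import sys

# For a run launched as: python3 -O train.py --algo ppo --env CartPole-v1
print(sys.argv)       # ['train.py', '--algo', 'ppo', '--env', 'CartPole-v1']
print(sys.orig_argv)  # ['python3', '-O', 'train.py', '--algo', 'ppo', '--env', 'CartPole-v1']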
7 changes: 6 additions & 1 deletion rl_zoo3/train.py
@@ -32,7 +32,12 @@ def train() -> None:
)
parser.add_argument("-n", "--n-timesteps", help="Overwrite the number of timesteps", default=-1, type=int)
parser.add_argument("--num-threads", help="Number of threads for PyTorch (-1 to use default)", default=-1, type=int)
parser.add_argument("--log-interval", help="Override log interval (default: -1, no change)", default=-1, type=int)
parser.add_argument(
"--log-interval",
help="Override log interval (default: -1, no change, -2: no logging useful when using custom logging freq)",
default=-1,
type=int,
)
parser.add_argument(
"--eval-freq",
help="Evaluate the agent every n steps (if negative, no evaluation). "
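Combined with the `learn()` change in `exp_manager.py` above, `--log-interval -2` amounts to roughly the following call (a sketch assuming SB3 >= 2.6.0, where `log_interval=None` disables episode-based auto-logging and the `LogEveryNTimesteps` callback is available):

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import LogEveryNTimesteps

model = PPO("MlpPolicy", "CartPole-v1")
# Equivalent of --log-interval -2: no auto-logging,
# dump logs every 10_000 timesteps via the callback instead
model.learn(100_000, log_interval=None, callback=LogEveryNTimesteps(n_steps=10_000))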
25 changes: 17 additions & 8 deletions rl_zoo3/utils.py
@@ -208,6 +208,8 @@ def create_test_env(
should_render: bool = True,
hyperparams: Optional[dict[str, Any]] = None,
env_kwargs: Optional[dict[str, Any]] = None,
+vec_env_cls: Optional[type[VecEnv]] = None,
+vec_env_kwargs: Optional[dict[str, Any]] = None,
) -> VecEnv:
"""
Create environment for testing a trained agent
@@ -220,6 +222,8 @@
:param should_render: For Pybullet env, display the GUI
:param hyperparams: Additional hyperparams (ex: n_stack)
:param env_kwargs: Optional keyword argument to pass to the env constructor
+:param vec_env_cls: ``VecEnv`` class constructor.
+:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
:return:
"""
# Create the environment and wrap it if necessary
@@ -231,9 +235,9 @@
if "env_wrapper" in hyperparams.keys():
del hyperparams["env_wrapper"]

-vec_env_kwargs: dict[str, Any] = {}
# Avoid potential shared memory issue
-vec_env_cls = SubprocVecEnv if n_envs > 1 else DummyVecEnv
+if vec_env_cls is None:
+vec_env_cls = SubprocVecEnv if n_envs > 1 else DummyVecEnv

# Fix for gym 0.26, to keep old behavior
env_kwargs = env_kwargs or {}
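With the two new parameters, callers can pin the vectorization class instead of relying on the `n_envs`-based default. A minimal sketch (the env id and kwargs are arbitrary examples):

from stable_baselines3.common.vec_env import SubprocVecEnv
from rl_zoo3.utils import create_test_env

# Force subprocess-based vectorization even for a single env
env = create_test_env(
    "CartPole-v1",
    n_envs=1,
    vec_env_cls=SubprocVecEnv,
    vec_env_kwargs={"start_method": "spawn"},
)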
@@ -349,21 +353,24 @@ def get_hf_trained_models(organization: str = "sb3", check_filename: bool = Fals
for model in models:
# Try to extract algorithm and environment id from model card
try:
-env_id = model.cardData["model-index"][0]["results"][0]["dataset"]["name"]
-algo = model.cardData["model-index"][0]["name"].lower()
+assert model.card_data is not None
+env_id = model.card_data["model-index"][0]["results"][0]["dataset"]["name"]
+algo = model.card_data["model-index"][0]["name"].lower()
# RecurrentPPO alias is "ppo_lstm" in the rl zoo
if algo == "recurrentppo":
algo = "ppo_lstm"
-except (KeyError, IndexError):
-print(f"Skipping {model.modelId}")
+except (KeyError, IndexError, AssertionError):
+print(f"Skipping {model.id}")
continue # skip model if the env id or algo name could not be found

env_name = EnvironmentName(env_id)
model_name = ModelName(algo, env_name)

# check if there is a model file in the repo
-if check_filename and not any(f.rfilename == model_name.filename for f in api.model_info(model.modelId).siblings):
-continue # skip model if the repo contains no properly named model file
+if check_filename:
+maybe_siblings = api.model_info(model.id).siblings
+if maybe_siblings and not any(f.rfilename == model_name.filename for f in maybe_siblings):
+continue # skip model if the repo contains no properly named model file

trained_models[model_name] = (algo, env_id)

@@ -422,6 +429,8 @@ def get_saved_hyperparams(
normalize_kwargs = eval(hyperparams["normalize"])
if test_mode:
normalize_kwargs["norm_reward"] = norm_reward
+elif isinstance(hyperparams["normalize"], dict):
+normalize_kwargs = hyperparams["normalize"]
else:
normalize_kwargs = {"norm_obs": hyperparams["normalize"], "norm_reward": norm_reward}
hyperparams["normalize_kwargs"] = normalize_kwargs
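The new `elif` branch in `get_saved_hyperparams` means a saved `normalize` value can now be a dict, alongside the two forms that were already handled. Roughly (a sketch of the three accepted values):

# The three forms the saved `normalize` hyperparameter may take:
normalize = True                                      # bool -> {"norm_obs": True, "norm_reward": norm_reward}
normalize = "dict(norm_obs=True, norm_reward=False)"  # str  -> eval'd into normalize_kwargs
normalize = {"norm_obs": True, "norm_reward": False}  # dict -> used as normalize_kwargs as-is (new)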
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
@@ -1 +1 @@
-2.5.0
+2.6.0a2
4 changes: 2 additions & 2 deletions setup.py
@@ -15,8 +15,8 @@
See https://github.com/DLR-RM/rl-baselines3-zoo
"""
install_requires = [
"sb3_contrib>=2.5.0,<3.0",
"gymnasium>=0.29.1,<1.1.0",
"sb3_contrib>=2.6.0a2,<3.0",
"gymnasium>=0.29.1,<1.2.0",
"huggingface_sb3>=3.0,<4.0",
"tqdm",
"rich",