From b4523848c9dc1f787b97a02c5ee47f09549cc990 Mon Sep 17 00:00:00 2001
From: "M. Ernestus"
Date: Fri, 7 Jul 2023 18:50:52 +0200
Subject: [PATCH] Transition from Models Hub to Datasets Hub for expert trajectories (#723)

* Remove load_rollouts_from_huggingface and replace it with code in
  demonstrations.py that loads demonstrations from HuggingFace datasets
  instead of HuggingFace models.
* Allow specifying the repo_id directly in the loader_kwargs of the
  demonstrations ingredient and pass the remaining loader_kwargs on to
  datasets.load_dataset.
* Simplify the demonstrations ingredient configuration while making it
  more flexible at the same time.
* Remove a now-obsolete test.
* Rename rollout_type to type and rollout_path to path, and make the
  default type "generated" to match the previous behavior.
* Fix the documentation of the Raises: section of get_expert_trajectories()
  and improve the wording of the ValueError raised when n_expert_demos is
  missing while generating trajectories.
* Simplify unnecessarily complex regexes used to match raised exceptions
  during testing.
* Fix the regex for ValueError to reflect the updated ValueError string.
* Add an edge case to accommodate the fact that the HuggingFace Hub only
  has an expert for seals/CartPole while the testdata folder only has an
  expert for regular CartPole.
* Make it explicit that in some tests the rollout should be loaded locally
  from disk.
* Fix formatting issues in test_scripts.py.
* Rename demonstrations.type to demonstrations.source to avoid a name clash
  with the Python built-in type.
* Make sure to load local demonstrations in quickstart.sh.
* Fix a formatting issue.
* Ensure the README contains the same snippet as the examples.
---
 README.md | 4 +-
 .../example_airl_seals_ant_best_hp_eval.json | 3 +-
 ..._airl_seals_half_cheetah_best_hp_eval.json | 3 +-
 ...xample_airl_seals_hopper_best_hp_eval.json | 3 +-
 ...ample_airl_seals_swimmer_best_hp_eval.json | 3 +-
 ...xample_airl_seals_walker_best_hp_eval.json | 3 +-
 .../example_bc_seals_ant_best_hp_eval.json | 3 +-
 ...le_bc_seals_half_cheetah_best_hp_eval.json | 3 +-
 .../example_bc_seals_hopper_best_hp_eval.json | 3 +-
 ...example_bc_seals_swimmer_best_hp_eval.json | 3 +-
 .../example_bc_seals_walker_best_hp_eval.json | 3 +-
 ...example_dagger_seals_ant_best_hp_eval.json | 3 +-
 ...agger_seals_half_cheetah_best_hp_eval.json | 3 +-
 ...mple_dagger_seals_hopper_best_hp_eval.json | 3 +-
 ...ple_dagger_seals_swimmer_best_hp_eval.json | 3 +-
 ...mple_dagger_seals_walker_best_hp_eval.json | 3 +-
 .../example_gail_seals_ant_best_hp_eval.json | 3 +-
 ..._gail_seals_half_cheetah_best_hp_eval.json | 3 +-
 ...xample_gail_seals_hopper_best_hp_eval.json | 3 +-
 ...ample_gail_seals_swimmer_best_hp_eval.json | 3 +-
 ...xample_gail_seals_walker_best_hp_eval.json | 3 +-
 benchmarking/util.py | 2 +-
 examples/quickstart.sh | 4 +-
 src/imitation/data/serialize.py | 12 --
 .../scripts/ingredients/demonstrations.py | 161 ++++++++++--------
 .../scripts/ingredients/policy_evaluation.py | 2 +-
 tests/scripts/test_scripts.py | 130 ++++++--------
 27 files changed, 188 insertions(+), 187 deletions(-)

diff --git a/README.md b/README.md
index 6915814b1..50fdf5faa 100644
--- a/README.md
+++ b/README.md
@@ -74,10 +74,10 @@ From [examples/quickstart.sh:](examples/quickstart.sh)
 python -m imitation.scripts.train_rl with pendulum environment.fast policy_evaluation.fast rl.fast fast logging.log_dir=quickstart/rl/
 
 # Train GAIL from demonstrations. Tensorboard logs saved in output/ (default log directory).
-python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
+python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local
 
 # Train AIRL from demonstrations. Tensorboard logs saved in output/ (default log directory).
-python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
+python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local
 ```
 
 Tips:
diff --git a/benchmarking/example_airl_seals_ant_best_hp_eval.json b/benchmarking/example_airl_seals_ant_best_hp_eval.json
index bd8f14e96..17f969ff0 100644
--- a/benchmarking/example_airl_seals_ant_best_hp_eval.json
+++ b/benchmarking/example_airl_seals_ant_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "reward": {
diff --git a/benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json b/benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json
index a6eb8ec87..754ba6736 100644
--- a/benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json
+++ b/benchmarking/example_airl_seals_half_cheetah_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "reward": {
diff --git a/benchmarking/example_airl_seals_hopper_best_hp_eval.json b/benchmarking/example_airl_seals_hopper_best_hp_eval.json
index 995be5fcd..91080d7ce 100644
--- a/benchmarking/example_airl_seals_hopper_best_hp_eval.json
+++ b/benchmarking/example_airl_seals_hopper_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "reward": {
diff --git a/benchmarking/example_airl_seals_swimmer_best_hp_eval.json b/benchmarking/example_airl_seals_swimmer_best_hp_eval.json
index edc904378..fcca8e6b3 100644
--- a/benchmarking/example_airl_seals_swimmer_best_hp_eval.json
+++ b/benchmarking/example_airl_seals_swimmer_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "expert": {
diff --git a/benchmarking/example_airl_seals_walker_best_hp_eval.json b/benchmarking/example_airl_seals_walker_best_hp_eval.json
index eb81d9fea..c63070751 100644
--- a/benchmarking/example_airl_seals_walker_best_hp_eval.json
+++ b/benchmarking/example_airl_seals_walker_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "expert": {
diff --git a/benchmarking/example_bc_seals_ant_best_hp_eval.json b/benchmarking/example_bc_seals_ant_best_hp_eval.json
index 56f15e063..108a93ce7 100644
--- a/benchmarking/example_bc_seals_ant_best_hp_eval.json
+++ b/benchmarking/example_bc_seals_ant_best_hp_eval.json
@@ -20,7 +20,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json b/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json
index fcffdf182..ecaff2eb0 100644
--- a/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json
+++ b/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json
@@ -20,7 +20,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_bc_seals_hopper_best_hp_eval.json b/benchmarking/example_bc_seals_hopper_best_hp_eval.json
index ea75003d3..e8c821841 100644
--- a/benchmarking/example_bc_seals_hopper_best_hp_eval.json
+++ b/benchmarking/example_bc_seals_hopper_best_hp_eval.json
@@ -20,7 +20,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_bc_seals_swimmer_best_hp_eval.json b/benchmarking/example_bc_seals_swimmer_best_hp_eval.json
index 9d3d2caf9..30884c9c4 100644
--- a/benchmarking/example_bc_seals_swimmer_best_hp_eval.json
+++ b/benchmarking/example_bc_seals_swimmer_best_hp_eval.json
@@ -20,7 +20,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_bc_seals_walker_best_hp_eval.json b/benchmarking/example_bc_seals_walker_best_hp_eval.json
index 4066c5b8b..0ca30120e 100644
--- a/benchmarking/example_bc_seals_walker_best_hp_eval.json
+++ b/benchmarking/example_bc_seals_walker_best_hp_eval.json
@@ -20,7 +20,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_dagger_seals_ant_best_hp_eval.json b/benchmarking/example_dagger_seals_ant_best_hp_eval.json
index f44dd3b67..de75b80f1 100644
--- a/benchmarking/example_dagger_seals_ant_best_hp_eval.json
+++ b/benchmarking/example_dagger_seals_ant_best_hp_eval.json
@@ -24,7 +24,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json b/benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json
index f59ed616e..7f42bfdf9 100644
--- a/benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json
+++ b/benchmarking/example_dagger_seals_half_cheetah_best_hp_eval.json
@@ -24,7 +24,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_dagger_seals_hopper_best_hp_eval.json b/benchmarking/example_dagger_seals_hopper_best_hp_eval.json
index 9e0f2e4dc..1cf29a1a4 100644
--- a/benchmarking/example_dagger_seals_hopper_best_hp_eval.json
+++ b/benchmarking/example_dagger_seals_hopper_best_hp_eval.json
@@ -24,7 +24,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_dagger_seals_swimmer_best_hp_eval.json b/benchmarking/example_dagger_seals_swimmer_best_hp_eval.json
index 6a345e93a..c112db680 100644
--- a/benchmarking/example_dagger_seals_swimmer_best_hp_eval.json
+++ b/benchmarking/example_dagger_seals_swimmer_best_hp_eval.json
@@ -24,7 +24,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_dagger_seals_walker_best_hp_eval.json b/benchmarking/example_dagger_seals_walker_best_hp_eval.json
index 7395260fe..e59bef464 100644
--- a/benchmarking/example_dagger_seals_walker_best_hp_eval.json
+++ b/benchmarking/example_dagger_seals_walker_best_hp_eval.json
@@ -24,7 +24,8 @@
         "use_offline_rollouts": false
     },
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "policy": {
diff --git a/benchmarking/example_gail_seals_ant_best_hp_eval.json b/benchmarking/example_gail_seals_ant_best_hp_eval.json
index d3d5c7be4..81399b00c 100644
--- a/benchmarking/example_gail_seals_ant_best_hp_eval.json
+++ b/benchmarking/example_gail_seals_ant_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "reward": {
diff --git a/benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json b/benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json
index 39fee5a6d..1d2f26648 100644
--- a/benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json
+++ b/benchmarking/example_gail_seals_half_cheetah_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "reward": {
diff --git a/benchmarking/example_gail_seals_hopper_best_hp_eval.json b/benchmarking/example_gail_seals_hopper_best_hp_eval.json
index 19c9e909a..70787ff7e 100644
--- a/benchmarking/example_gail_seals_hopper_best_hp_eval.json
+++ b/benchmarking/example_gail_seals_hopper_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "reward": {
diff --git a/benchmarking/example_gail_seals_swimmer_best_hp_eval.json b/benchmarking/example_gail_seals_swimmer_best_hp_eval.json
index 1279fc319..650c5f46a 100644
--- a/benchmarking/example_gail_seals_swimmer_best_hp_eval.json
+++ b/benchmarking/example_gail_seals_swimmer_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "expert": {
diff --git a/benchmarking/example_gail_seals_walker_best_hp_eval.json b/benchmarking/example_gail_seals_walker_best_hp_eval.json
index d7de6c5a5..d85eb46d5 100644
--- a/benchmarking/example_gail_seals_walker_best_hp_eval.json
+++ b/benchmarking/example_gail_seals_walker_best_hp_eval.json
@@ -6,7 +6,8 @@
     },
     "checkpoint_interval": 0,
     "demonstrations": {
-        "rollout_type": "ppo-huggingface",
+        "source": "huggingface",
+        "algo_name": "ppo",
         "n_expert_demos": null
     },
     "expert": {
diff --git a/benchmarking/util.py b/benchmarking/util.py
index c7843d15f..408f0d812 100644
--- a/benchmarking/util.py
+++ b/benchmarking/util.py
@@ -71,7 +71,7 @@ def clean_config_file(file: pathlib.Path, write_path: pathlib.Path, /) -> None:
     # remove key 'agent_path'
     config.pop("agent_path")
     config.pop("seed")
-    config.get("demonstrations", {}).pop("rollout_path")
+    config.get("demonstrations", {}).pop("path")
     config.get("expert", {}).get("loader_kwargs", {}).pop("path", None)
     env_name = config.pop("environment").pop("gym_id")
     config["environment"] = {"gym_id": env_name}
diff --git a/examples/quickstart.sh b/examples/quickstart.sh
index 52283cd92..0c36898ef 100755
--- a/examples/quickstart.sh
+++ b/examples/quickstart.sh
@@ -4,7 +4,7 @@
 python -m imitation.scripts.train_rl with pendulum environment.fast policy_evaluation.fast rl.fast fast logging.log_dir=quickstart/rl/
 
 # Train GAIL from demonstrations. Tensorboard logs saved in output/ (default log directory).
-python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
+python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local
 
 # Train AIRL from demonstrations. Tensorboard logs saved in output/ (default log directory).
-python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
+python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local
diff --git a/src/imitation/data/serialize.py b/src/imitation/data/serialize.py
index ddc3bd9b2..ee9d0d3f8 100644
--- a/src/imitation/data/serialize.py
+++ b/src/imitation/data/serialize.py
@@ -5,7 +5,6 @@
 from typing import Mapping, Sequence, cast
 
 import datasets
-import huggingface_sb3 as hfsb3
 import numpy as np
 
 from imitation.data import huggingface_utils
@@ -87,14 +86,3 @@ def load_with_rewards(path: AnyPath) -> Sequence[TrajectoryWithRew]:
     )
 
     return cast(Sequence[TrajectoryWithRew], data)
-
-
-def load_rollouts_from_huggingface(
-    algo_name: str,
-    env_name: str,
-    organization: str = "HumanCompatibleAI",
-) -> str:
-    model_name = hfsb3.ModelName(algo_name, hfsb3.EnvironmentName(env_name))
-    repo_id = hfsb3.ModelRepoId(organization, model_name)
-    filename = hfsb3.load_from_hub(repo_id, "rollouts.npz")
-    return filename
diff --git a/src/imitation/scripts/ingredients/demonstrations.py b/src/imitation/scripts/ingredients/demonstrations.py
index a8d8e9ced..f6b1fb2ac 100644
--- a/src/imitation/scripts/ingredients/demonstrations.py
+++ b/src/imitation/scripts/ingredients/demonstrations.py
@@ -1,14 +1,14 @@
 """Ingredient for scripts learning from demonstrations."""
 
 import logging
-import pathlib
-import warnings
-from typing import Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence
 
+import datasets
+import huggingface_sb3 as hfsb3
 import numpy as np
 import sacred
 
-from imitation.data import rollout, serialize, types
+from imitation.data import huggingface_utils, rollout, serialize, types
 from imitation.scripts.ingredients import environment, expert
 from imitation.scripts.ingredients import logging as logging_ingredient
@@ -25,11 +25,25 @@
 
 @demonstrations_ingredient.config
 def config():
-    rollout_type = "local"
-    # path to file containing rollouts. If rollout_path is None
-    # and rollout_type is local, they are sampled from the expert.
-    rollout_path = None
-    n_expert_demos = None  # Num demos used or sampled. None loads every demo possible.
+    # Either "local", "huggingface", or "generated".
+    source = "generated"
+
+    # Local path or HuggingFace repo id to load rollouts from.
+    path = None
+
+    # Passed to `datasets.load_dataset` if `source` is "huggingface".
+    loader_kwargs: Dict[str, Any] = dict(
+        split="train",
+    )
+
+    # Used to deduce the HuggingFace repo id if `path` is None.
+    organization = "HumanCompatibleAI"
+
+    # Used to deduce the HuggingFace repo id if `path` is None.
+    algo_name = "ppo"
+
+    # Num demos used or sampled. None loads every demo possible.
+    n_expert_demos = None
 
     locals()  # quieten flake8
@@ -42,45 +56,69 @@ def fast():
 
 @demonstrations_ingredient.capture
 def get_expert_trajectories(
-    rollout_type: str,
-    rollout_path: str,
+    source: str,
+    path: str,
 ) -> Sequence[types.Trajectory]:
     """Loads expert demonstrations.
 
     Args:
-        rollout_type: Can be either `local` to load rollouts from the disk or to
-            generate them locally or of the format `{algo}-huggingface` to load
-            from the huggingface hub of expert trained using `{algo}`.
-        rollout_path: A path containing a pickled sequence of `types.Trajectory`.
+        source: Can be either `local` to load rollouts from the disk,
+            `huggingface` to load from the HuggingFace hub, or
+            `generated` to generate the expert trajectories.
+        path: A path containing a pickled sequence of `types.Trajectory`.
 
     Returns:
         The expert trajectories.
 
     Raises:
-        ValueError: if `rollout_type` is not "local" or of the form {algo}-huggingface.
+        ValueError: if `source` is not in ["local", "huggingface", "generated"].
     """
-    if rollout_type.endswith("-huggingface"):
-        if rollout_path is not None:
-            warnings.warn(
-                "Ignoring `rollout_path` since `rollout_type` is set to download the "
-                "rollouts from the huggingface-hub. If you want to load the rollouts "
-                'from disk, set `rollout_type`="local" and the path in `rollout_path`.',
-                RuntimeWarning,
+    if source == "local":
+        if path is None:
+            raise ValueError(
+                "When source is 'local', path must be set.",
             )
-        rollout_path = _download_expert_rollouts(rollout_type)
-    elif rollout_type != "local":
-        raise ValueError(
-            "`rollout_type` can either be `local` or of the form `{algo}-huggingface`.",
-        )
+        return _constrain_number_of_demos(serialize.load(path))
+
+    if source == "huggingface":
+        return _constrain_number_of_demos(_download_expert_rollouts())
+
+    if source == "generated":
+        if path is not None:
+            logger.warning("Ignoring path when source is 'generated'")
+        return _generate_expert_trajs()
 
-    if rollout_path is not None:
-        return load_local_expert_trajs(rollout_path)
+    raise ValueError(
+        "`source` can either be `local`, `huggingface`, or `generated`.",
+    )
+
+
+@demonstrations_ingredient.capture
+def _constrain_number_of_demos(
+    demos: Sequence[types.Trajectory],
+    n_expert_demos: Optional[int],
+) -> Sequence[types.Trajectory]:
+    """Constrains the number of demonstrations to n_expert_demos if it is not None."""
+    if n_expert_demos is None:
+        return demos
     else:
-        return generate_expert_trajs()
+        if len(demos) < n_expert_demos:
+            raise ValueError(
+                f"Want to use n_expert_demos={n_expert_demos} trajectories, but only "
+                f"{len(demos)} are available.",
+            )
+        if len(demos) > n_expert_demos:
+            logger.warning(
+                f"Using only the first {n_expert_demos} trajectories out of "
+                f"{len(demos)} available.",
+            )
+            return demos[:n_expert_demos]
+        else:
+            return demos
 
 
 @demonstrations_ingredient.capture
-def generate_expert_trajs(
+def _generate_expert_trajs(
     n_expert_demos: Optional[int],
     _rnd: np.random.Generator,
 ) -> Optional[Sequence[types.Trajectory]]:
@@ -98,8 +136,9 @@
         ValueError: If n_expert_demos is None.
     """
     if n_expert_demos is None:
-        raise ValueError("n_expert_demos must be specified when rollout_path is None")
+        raise ValueError("n_expert_demos must be specified when generating demos.")
 
+    logger.info(f"Generating {n_expert_demos} expert trajectories")
     with environment.make_rollout_venv() as rollout_env:
         return rollout.rollout(
             expert.get_expert_policy(rollout_env),
@@ -110,42 +149,22 @@
 
 
 @demonstrations_ingredient.capture
-def load_local_expert_trajs(
-    rollout_path: Union[str, pathlib.Path],
-    n_expert_demos: Optional[int],
-) -> Sequence[types.Trajectory]:
-    """Loads expert demonstrations from a local path.
-
-    Args:
-        rollout_path: A path containing a pickled sequence of `types.Trajectory`.
-        n_expert_demos: The number of trajectories to load.
-            Dataset is truncated to this length if specified.
-
-    Returns:
-        The expert trajectories.
+def _download_expert_rollouts(
+    environment: Dict[str, Any],
+    path: Optional[str],
+    organization: Optional[str],
+    algo_name: Optional[str],
+    loader_kwargs: Dict[str, Any],
+):
+    if path is not None:
+        repo_id = path
+    else:
+        model_name = hfsb3.ModelName(
+            algo_name,
+            hfsb3.EnvironmentName(environment["gym_id"]),
+        )
+        repo_id = hfsb3.ModelRepoId(organization, model_name)
 
-    Raises:
-        ValueError: There are fewer trajectories than `n_expert_demos`.
-    """
-    expert_trajs = serialize.load(rollout_path)
-    logger.info(f"Loaded {len(expert_trajs)} expert trajectories from '{rollout_path}'")
-    if n_expert_demos is not None:
-        if len(expert_trajs) < n_expert_demos:
-            raise ValueError(
-                f"Want to use n_expert_demos={n_expert_demos} trajectories, but only "
-                f"{len(expert_trajs)} are available via {rollout_path}.",
-            )
-        expert_trajs = expert_trajs[:n_expert_demos]
-        logger.info(f"Truncated to {n_expert_demos} expert trajectories")
-    return expert_trajs
-
-
-@demonstrations_ingredient.capture(prefix="expert")
-def _download_expert_rollouts(rollout_type, loader_kwargs):
-    assert rollout_type.endswith("-huggingface")
-    algo_name = rollout_type.split("-")[0]
-    return serialize.load_rollouts_from_huggingface(
-        algo_name,
-        env_name=loader_kwargs["env_name"],
-        organization=loader_kwargs["organization"],
-    )
+    logger.info(f"Loading expert trajectories from {repo_id}")
+    dataset = datasets.load_dataset(repo_id, **loader_kwargs)
+    return huggingface_utils.TrajectoryDatasetSequence(dataset)
diff --git a/src/imitation/scripts/ingredients/policy_evaluation.py b/src/imitation/scripts/ingredients/policy_evaluation.py
index 80083f361..7f3348f5d 100644
--- a/src/imitation/scripts/ingredients/policy_evaluation.py
+++ b/src/imitation/scripts/ingredients/policy_evaluation.py
@@ -46,7 +46,7 @@ def eval_policy(
         `rollout_stats()` on rollouts test-reward-wrapped environment, using the final
         policy (remember that the ground-truth reward can be recovered from the
         "monitor_return" key). "expert_stats" gives the return value of
-        `rollout_stats()` on the expert demonstrations loaded from `rollout_path`.
+        `rollout_stats()` on the expert demonstrations loaded from `path`.
     """
     sample_until_eval = rollout.make_min_episodes(n_episodes_eval)
     if isinstance(rl_algo, base_class.BaseAlgorithm):
diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py
index a815c8a01..704828ba4 100644
--- a/tests/scripts/test_scripts.py
+++ b/tests/scripts/test_scripts.py
@@ -199,7 +199,7 @@ def test_train_preference_comparisons_sac(tmpdir):
 
     # Make sure rl.sac named_config is called after rl.fast to overwrite
     # rl_kwargs.batch_size to None
-    with pytest.raises(Exception, match=".*set 'batch_size' at top-level.*"):
+    with pytest.raises(Exception, match="set 'batch_size' at top-level"):
         train_preference_comparisons.train_preference_comparisons_ex.run(
             named_configs=["pendulum"]
             + RL_SAC_NAMED_CONFIGS
@@ -233,9 +233,9 @@ def _run_reward_relabel_sac_preference_comparisons(buffer_cls):
     assert run.status == "COMPLETED"
     del run
 
-    with pytest.raises(AssertionError, match=".*only ReplayBuffer is supported.*"):
+    with pytest.raises(AssertionError, match="only ReplayBuffer is supported"):
         _run_reward_relabel_sac_preference_comparisons(buffers.DictReplayBuffer)
-    with pytest.raises(AssertionError, match=".*only ReplayBuffer is supported.*"):
+    with pytest.raises(AssertionError, match="only ReplayBuffer is supported"):
         _run_reward_relabel_sac_preference_comparisons(HerReplayBuffer)
 
 
@@ -272,7 +272,7 @@ def test_train_dagger_main(tmpdir):
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
         config_updates=dict(
             logging=dict(log_root=tmpdir),
-            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
         ),
     )
     for warning in record:
@@ -292,7 +292,7 @@ def test_train_dagger_warmstart(tmpdir):
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
         config_updates=dict(
             logging=dict(log_root=tmpdir),
-            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
         ),
     )
     assert run.status == "COMPLETED"
@@ -304,7 +304,7 @@
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
         config_updates=dict(
             logging=dict(log_root=tmpdir),
-            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
             bc=dict(agent_path=policy_path),
         ),
     )
@@ -313,7 +313,7 @@
 
 def test_train_bc_main_with_none_demonstrations_raises_value_error(tmpdir):
-    with pytest.raises(ValueError, match=".*n_expert_demos.*rollout_path.*"):
+    with pytest.raises(ValueError, match="n_expert_demos must be specified"):
         train_imitation.train_imitation_ex.run(
             command_name="bc",
             named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
@@ -330,48 +330,14 @@ def test_train_bc_main_with_demonstrations_from_huggingface(tmpdir):
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
         config_updates=dict(
             logging=dict(log_root=tmpdir),
-            demonstrations=dict(rollout_type="ppo-huggingface"),
+            demonstrations=dict(
+                source="huggingface",
+                algo_name="ppo",
+            ),
         ),
     )
 
 
-def test_train_bc_main_with_demonstrations_raises_error_on_wrong_huggingface_format(
-    tmpdir,
-):
-    with pytest.raises(
-        ValueError,
-        match="`rollout_type` can either be `local` or of the form .*-huggingface.S*",
-    ):
-        train_imitation.train_imitation_ex.run(
-            command_name="bc",
-            named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
-            config_updates=dict(
-                logging=dict(log_root=tmpdir),
-                demonstrations=dict(rollout_type="huggingface-ppo"),
-            ),
-        )
-
-
-def test_train_bc_main_with_demonstrations_warns_setting_rollout_type(
-    tmpdir,
-):
-    with pytest.warns(
-        RuntimeWarning,
-        match="Ignoring `rollout_path` .*",
-    ):
-        train_imitation.train_imitation_ex.run(
-            command_name="bc",
-            named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
-            config_updates=dict(
-                logging=dict(log_root=tmpdir),
-                demonstrations=dict(
-                    rollout_type="ppo-huggingface",
-                    rollout_path="path",
-                ),
-            ),
-        )
-
-
 @pytest.fixture(
     params=[
         "expert_from_path",
@@ -381,22 +347,30 @@
     ],
 )
 def bc_config(tmpdir, request):
-    expert_config = dict(
-        expert_from_path=dict(
+    environment_named_config = "seals_cartpole"
+
+    if request.param == "expert_from_path":
+        expert_config = dict(
             policy_type="ppo",
             loader_kwargs=dict(path=CARTPOLE_TEST_POLICY_PATH / "model.zip"),
-        ),
-        expert_from_huggingface=dict(policy_type="ppo-huggingface"),
-        random_expert=dict(policy_type="random"),
-        zero_expert=dict(policy_type="zero"),
-    )[request.param]
+        )
+        # Note: we don't have a seals_cartpole expert in our testdata folder,
+        # so we use the cartpole environment in this case.
+        environment_named_config = "cartpole"
+    elif request.param == "expert_from_huggingface":
+        expert_config = dict(policy_type="ppo-huggingface")
+    elif request.param == "random_expert":
+        expert_config = dict(policy_type="random")
+    elif request.param == "zero_expert":
+        expert_config = dict(policy_type="zero")
+
     return dict(
         command_name="bc",
-        named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
+        named_configs=[environment_named_config] + ALGO_FAST_CONFIGS["imitation"],
         config_updates=dict(
             logging=dict(log_root=tmpdir),
             expert=expert_config,
-            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
         ),
     )
@@ -413,7 +387,7 @@ def test_train_bc_warmstart(tmpdir):
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
         config_updates=dict(
             logging=dict(log_root=tmpdir),
-            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
             expert=dict(policy_type="ppo-huggingface"),
         ),
     )
@@ -426,7 +400,7 @@
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
         config_updates=dict(
             logging=dict(log_root=tmpdir),
-            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
             bc=dict(agent_path=policy_path),
         ),
     )
@@ -461,7 +435,7 @@ def test_train_rl_main(tmpdir, rl_train_ppo_config):
 
 def test_train_rl_wb_logging(tmpdir):
     """Smoke test for imitation.scripts.ingredients.logging.wandb_logging."""
-    with pytest.raises(Exception, match=".*api_key not configured.*"):
+    with pytest.raises(Exception, match="api_key not configured"):
         train_rl.train_rl_ex.run(
             named_configs=["cartpole"]
             + ALGO_FAST_CONFIGS["rl"]
@@ -557,7 +531,7 @@ def test_train_adversarial(tmpdir, named_configs, command):
     )
     config_updates = {
         "logging": dict(log_root=tmpdir),
-        "demonstrations": dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+        "demonstrations": dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
         # TensorBoard logs to get extra coverage
         "algorithm_kwargs": dict(init_tensorboard=True),
     }
@@ -575,7 +549,7 @@ def test_train_adversarial_warmstart(tmpdir, command):
     named_configs = ["cartpole"] + ALGO_FAST_CONFIGS["adversarial"]
     config_updates = {
         "logging": dict(log_root=tmpdir),
-        "demonstrations": dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+        "demonstrations": dict(path=CARTPOLE_TEST_ROLLOUT_PATH, source="local"),
     }
     run = train_adversarial.train_adversarial_ex.run(
         command_name=command,
@@ -609,7 +583,7 @@ def test_train_adversarial_sac(tmpdir, command):
     )
     config_updates = {
         "logging": dict(log_root=tmpdir),
-        "demonstrations": dict(rollout_path=PENDULUM_TEST_ROLLOUT_PATH),
+        "demonstrations": dict(path=PENDULUM_TEST_ROLLOUT_PATH),
     }
     run = train_adversarial.train_adversarial_ex.run(
         command_name=command,
@@ -627,11 +601,11 @@ def test_train_adversarial_algorithm_value_error(tmpdir):
     base_config_updates = collections.ChainMap(
         {
             "logging": dict(log_root=tmpdir),
-            "demonstrations": dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            "demonstrations": dict(path=CARTPOLE_TEST_ROLLOUT_PATH, source="local"),
         },
     )
 
-    with pytest.raises(TypeError, match=".*BAD_VALUE.*"):
+    with pytest.raises(TypeError, match="BAD_VALUE"):
         train_adversarial.train_adversarial_ex.run(
             command_name="gail",
             named_configs=base_named_configs,
             config_updates=base_config_updates.new_child(
             ),
         )
 
-    with pytest.raises(FileNotFoundError, match=".*BAD_VALUE.*"):
+    with pytest.raises(FileNotFoundError, match="BAD_VALUE"):
         train_adversarial.train_adversarial_ex.run(
             command_name="gail",
             named_configs=base_named_configs,
             config_updates=base_config_updates.new_child(
-                {"demonstrations.rollout_path": "path/BAD_VALUE"},
+                {"demonstrations.path": "path/BAD_VALUE"},
             ),
         )
 
     n_traj = 1234567
-    with pytest.raises(ValueError, match=f".*{n_traj}.*"):
+    with pytest.raises(ValueError, match=f"{n_traj}"):
         train_adversarial.train_adversarial_ex.run(
             command_name="gail",
             named_configs=base_named_configs,
@@ -675,7 +649,7 @@ def test_transfer_learning(tmpdir: str) -> None:
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["adversarial"],
         config_updates=dict(
             logging=dict(log_dir=log_dir_train),
-            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
         ),
     )
     assert run.status == "COMPLETED"
@@ -772,7 +746,7 @@ def test_train_rl_double_normalization(tmpdir: str, rng):
     log_dir_data = os.path.join(tmpdir, "train_rl")
     with pytest.warns(
         RuntimeWarning,
-        match=r"Applying normalization to already normalized reward function.*",
+        match=r"Applying normalization to already normalized reward function",
    ):
         train_rl.train_rl_ex.run(
             named_configs=["cartpole"] + ALGO_FAST_CONFIGS["rl"],
@@ -826,7 +800,7 @@ def test_train_rl_cnn_policy(tmpdir: str, rng):
         base_named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["adversarial"],
         base_config_updates={
             # Need absolute path because raylet runs in different working directory.
-            "demonstrations.rollout_path": CARTPOLE_TEST_ROLLOUT_PATH.absolute(),
+            "demonstrations.path": CARTPOLE_TEST_ROLLOUT_PATH.absolute(),
         },
         search_space={
             "command_name": tune.grid_search(["gail", "airl"]),
@@ -859,24 +833,24 @@ def test_parallel_arg_errors(tmpdir):
     config_updates.setdefault("base_config_updates", {})["logging.log_root"] = tmpdir
     config_updates = collections.ChainMap(config_updates)
 
-    with pytest.raises(TypeError, match=".*Sequence.*"):
+    with pytest.raises(TypeError, match="Sequence"):
         parallel.parallel_ex.run(
             config_updates=config_updates.new_child(dict(base_named_configs={})),
         )
 
-    with pytest.raises(TypeError, match=".*Mapping.*"):
+    with pytest.raises(TypeError, match="Mapping"):
         parallel.parallel_ex.run(
             config_updates=config_updates.new_child(dict(base_config_updates=())),
         )
 
-    with pytest.raises(TypeError, match=".*Sequence.*"):
+    with pytest.raises(TypeError, match="Sequence"):
         parallel.parallel_ex.run(
             config_updates=config_updates.new_child(
                 dict(search_space={"named_configs": {}}),
             ),
         )
 
-    with pytest.raises(TypeError, match=".*Mapping.*"):
+    with pytest.raises(TypeError, match="Mapping"):
         parallel.parallel_ex.run(
             config_updates=config_updates.new_child(
                 dict(search_space={"config_updates": ()}),
             ),
@@ -892,8 +866,8 @@ def _generate_test_rollouts(tmpdir: str, env_named_config: str) -> pathlib.Path:
             logging=dict(log_dir=tmpdir),
         ),
     )
-    rollout_path = tmpdir_path / "rollouts/final.npz"
-    return rollout_path.absolute()
+    path = tmpdir_path / "rollouts/final.npz"
+    return path.absolute()
 
 
 def test_parallel_train_adversarial_custom_env(tmpdir):
@@ -904,7 +878,7 @@
     )
 
     env_named_config = "pendulum"
-    rollout_path = _generate_test_rollouts(tmpdir, env_named_config)
+    path = _generate_test_rollouts(tmpdir, env_named_config)
 
     config_updates = dict(
         sacred_ex_name="train_adversarial",
         base_named_configs=[env_named_config] + ALGO_FAST_CONFIGS["adversarial"],
         base_config_updates=dict(
             logging=dict(log_root=tmpdir),
-            demonstrations=dict(rollout_path=rollout_path),
+            demonstrations=dict(path=path),
         ),
         search_space=dict(command_name="gail"),
     )
@@ -927,7 +901,7 @@ def _run_train_adv_for_test_analyze_imit(run_name, sacred_logs_dir, log_dir):
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["adversarial"],
         config_updates=dict(
             logging=dict(log_root=log_dir),
-            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
             checkpoint_interval=-1,
         ),
         options={"--name": run_name, "--file_storage": sacred_logs_dir},
@@ -941,7 +915,7 @@ def _run_train_bc_for_test_analyze_imit(run_name, sacred_logs_dir, log_dir):
         named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
         config_updates=dict(
             logging=dict(log_dir=log_dir),
-            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            demonstrations=dict(path=CARTPOLE_TEST_ROLLOUT_PATH),
         ),
         options={"--name": run_name, "--file_storage": sacred_logs_dir},
     )
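
For context, the new loading path in _download_expert_rollouts reduces to the standalone sketch below. It only uses calls that appear in this patch (hfsb3.ModelName, hfsb3.EnvironmentName, hfsb3.ModelRepoId, datasets.load_dataset, huggingface_utils.TrajectoryDatasetSequence); the organization, algorithm, and environment values are illustrative placeholders, and the repo id they produce is only valid if a matching dataset actually exists on the Hub.

```python
"""Illustrative sketch of the new HuggingFace Datasets loading path."""
import datasets
import huggingface_sb3 as hfsb3

from imitation.data import huggingface_utils

# Deduce the repo id the same way _download_expert_rollouts does when
# demonstrations.path is unset (the values below are placeholders).
model_name = hfsb3.ModelName("ppo", hfsb3.EnvironmentName("seals/CartPole-v0"))
repo_id = hfsb3.ModelRepoId("HumanCompatibleAI", model_name)

# loader_kwargs from the ingredient config are forwarded verbatim to
# datasets.load_dataset; the default config only sets split="train".
dataset = datasets.load_dataset(repo_id, split="train")

# Wrap the dataset so it behaves like a Sequence[types.Trajectory].
trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset)
print(f"Loaded {len(trajectories)} expert trajectories from {repo_id}")
```

From the command line, the equivalent is `demonstrations.source=huggingface demonstrations.algo_name=ppo`, or `demonstrations.path=<repo id>` to name a dataset repo directly.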