Skip to content

Commit

Permalink
Transition from Models Hub to Datasets Hub for expert trajectories (#723
Browse files Browse the repository at this point in the history
)

* Remove load_rolluts_from_huggingface and replace it with code in demonstrations.py that loads demonstrations from huggingface datasets instead of huggingface models.

* Allow specifying the repo_id directly in the loader_kwargs of the demonstrations ingredient and pass remaining loader_kwargs to datasets.load_dataset.

* Simplify demonstrations ingredient configuration and make it more flexible at the same time.

* Remove now obsolete test.

* Rename rollout_type to type and rollout_path to path and make default type "generated" to match previous behavior.

* Fix documentation of the Raises: section of get_exper_trajectories() and improve wording of ValueError when n_expert_demos is missing while generating trajectories.

* Simplify unnecessarily complex regexes to match raised exceptions during testing.

* Fix regex for ValueError to reflect updated ValueError string.

* Add an edge case to accommodate the fact that the HuggingFace Hub only has an expert for seals/Cartpole while the testdata folder only has an expert for normal Cartpole.

* Make it explicit that in some tests the rollout should be loaded locally from disk.

* Fix formatting issues in test_scripts.py

* Rename demonstrations.type to demonstrations.source to overcome name-clash with build-in keyword of python.

* Make sure to load local demonstrations in quickstart.sh

* Fix formatting issue.

* Ensure the readme contains the same snippet as the examples.
  • Loading branch information
ernestum committed Jul 7, 2023
1 parent 688e163 commit b452384
Show file tree
Hide file tree
Showing 27 changed files with 188 additions and 187 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ From [examples/quickstart.sh:](examples/quickstart.sh)
python -m imitation.scripts.train_rl with pendulum environment.fast policy_evaluation.fast rl.fast fast logging.log_dir=quickstart/rl/

# Train GAIL from demonstrations. Tensorboard logs saved in output/ (default log directory).
python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local

# Train AIRL from demonstrations. Tensorboard logs saved in output/ (default log directory).
python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local
```

Tips:
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_airl_seals_ant_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"reward": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"reward": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_airl_seals_hopper_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"reward": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_airl_seals_swimmer_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"expert": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_airl_seals_walker_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"expert": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_bc_seals_ant_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_bc_seals_hopper_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_bc_seals_swimmer_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_bc_seals_walker_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_dagger_seals_ant_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_dagger_seals_hopper_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_dagger_seals_swimmer_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_dagger_seals_walker_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
"use_offline_rollouts": false
},
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"policy": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_gail_seals_ant_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"reward": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"reward": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_gail_seals_hopper_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"reward": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_gail_seals_swimmer_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"expert": {
Expand Down
3 changes: 2 additions & 1 deletion benchmarking/example_gail_seals_walker_best_hp_eval.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
},
"checkpoint_interval": 0,
"demonstrations": {
"rollout_type": "ppo-huggingface",
"source": "huggingface",
"algo_name": "ppo",
"n_expert_demos": null
},
"expert": {
Expand Down
2 changes: 1 addition & 1 deletion benchmarking/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def clean_config_file(file: pathlib.Path, write_path: pathlib.Path, /) -> None:
# remove key 'agent_path'
config.pop("agent_path")
config.pop("seed")
config.get("demonstrations", {}).pop("rollout_path")
config.get("demonstrations", {}).pop("path")
config.get("expert", {}).get("loader_kwargs", {}).pop("path", None)
env_name = config.pop("environment").pop("gym_id")
config["environment"] = {"gym_id": env_name}
Expand Down
4 changes: 2 additions & 2 deletions examples/quickstart.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
python -m imitation.scripts.train_rl with pendulum environment.fast policy_evaluation.fast rl.fast fast logging.log_dir=quickstart/rl/

# Train GAIL from demonstrations. Tensorboard logs saved in output/ (default log directory).
python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
python -m imitation.scripts.train_adversarial gail with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local

# Train AIRL from demonstrations. Tensorboard logs saved in output/ (default log directory).
python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.rollout_path=quickstart/rl/rollouts/final.npz
python -m imitation.scripts.train_adversarial airl with pendulum environment.fast demonstrations.fast policy_evaluation.fast rl.fast fast demonstrations.path=quickstart/rl/rollouts/final.npz demonstrations.source=local
12 changes: 0 additions & 12 deletions src/imitation/data/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Mapping, Sequence, cast

import datasets
import huggingface_sb3 as hfsb3
import numpy as np

from imitation.data import huggingface_utils
Expand Down Expand Up @@ -87,14 +86,3 @@ def load_with_rewards(path: AnyPath) -> Sequence[TrajectoryWithRew]:
)

return cast(Sequence[TrajectoryWithRew], data)


def load_rollouts_from_huggingface(
algo_name: str,
env_name: str,
organization: str = "HumanCompatibleAI",
) -> str:
model_name = hfsb3.ModelName(algo_name, hfsb3.EnvironmentName(env_name))
repo_id = hfsb3.ModelRepoId(organization, model_name)
filename = hfsb3.load_from_hub(repo_id, "rollouts.npz")
return filename
Loading

0 comments on commit b452384

Please sign in to comment.