-
Notifications
You must be signed in to change notification settings - Fork 247
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Pin SB3 version to 1.7.0 (#738) * Update conftest.py (#742) * Custom environment tutorial (#746) * Custom environment tutorial draft * Update the docs website * Clean notebook * Text clarification and new environment * Decrease training duration to hopefully make CI happy * Clarify that BC itself does not learn rewards --------- Co-authored-by: Ariel Kwiatkowski <ariel.j.kwiatkowski@gmail.com> * Tutorial on comparing algorithm performance (#747) * Add a new tutorial * Update index.rst * Improvements to the tutorial * Some more caution words * Fix typos --------- Co-authored-by: Ariel Kwiatkowski <ariel.j.kwiatkowski@gmail.com> --------- Co-authored-by: Adam Gleave <adam@gleave.me>
- Loading branch information
1 parent
68f693b
commit c4b0521
Showing
4 changed files
with
850 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,366 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8_train_custom_env.ipynb)\n", | ||
"# Train Behavior Cloning in a Custom Environment\n", | ||
"\n", | ||
"You can use `imitation` to train a policy (and, for many imitation learning algorithm, learn rewards) in a custom environment.\n", | ||
"\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Step 1: Define the environment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We will use a simple ObservationMatching environment as an example. The premise is simple -- the agent receives a vector of observations, and must output a vector of actions that matches the observations as closely as possible.\n", | ||
"\n", | ||
"If you have your own environment that you'd like to use, you can replace the code below with your own environment. Make sure it complies with the standard Gym API, and that the observation and action spaces are specified correctly." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import gym\n", | ||
"\n", | ||
"from gym.spaces import Box\n", | ||
"from gym.utils import seeding\n", | ||
"\n", | ||
"\n", | ||
"class ObservationMatchingEnv(gym.Env):\n", | ||
" def __init__(self, num_options: int = 2):\n", | ||
" self.num_options = num_options\n", | ||
" self.observation_space = Box(0, 1, shape=(num_options,), dtype=np.float32)\n", | ||
" self.action_space = Box(0, 1, shape=(num_options,), dtype=np.float32)\n", | ||
" self.seed()\n", | ||
"\n", | ||
" def seed(self, seed=None):\n", | ||
" self.np_random, seed = seeding.np_random(seed)\n", | ||
" return [seed]\n", | ||
"\n", | ||
" def reset(self):\n", | ||
" self.state = self.np_random.uniform(size=self.num_options)\n", | ||
" return self.state\n", | ||
"\n", | ||
" def step(self, action):\n", | ||
" reward = -np.abs(self.state - action).mean()\n", | ||
" self.state = self.np_random.uniform(size=self.num_options)\n", | ||
" return self.state, reward, False, {}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Step 2: create the environment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"From here, we have two options:\n", | ||
"- Add the environment to the gym registry, and use it with existing utilities (e.g. `make`)\n", | ||
"- Use the environment directly\n", | ||
"\n", | ||
"You only need to execute the cells in step 2a, or step 2b to proceed.\n", | ||
"\n", | ||
"At the end of these steps, we want to have:\n", | ||
"- `env`: a single environment that we can use for training an expert with SB3\n", | ||
"- `venv`: a vectorized environment where each individual environment is wrapped in `RolloutInfoWrapper`, that we can use for collecting rollouts with `imitation`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Step 2a (recommended): add the environment to the gym registry" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The standard approach is adding the environment to the gym registry." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"gym.register(\n", | ||
" id=\"custom/ObservationMatching-v0\",\n", | ||
" entry_point=ObservationMatchingEnv, # This can also be the path to the class, e.g. `observation_matching:ObservationMatchingEnv`\n", | ||
" max_episode_steps=500,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"After registering, you can create an environment is `gym.make(env_id)` which automatically handles the `TimeLimit` wrapper.\n", | ||
"\n", | ||
"To create a vectorized env, you can use the `make_vec_env` helper function (Option A), or create it directly (Options B1 and B2)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from gym.wrappers import TimeLimit\n", | ||
"from imitation.data import rollout\n", | ||
"from imitation.data.wrappers import RolloutInfoWrapper\n", | ||
"from imitation.util.util import make_vec_env\n", | ||
"from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv\n", | ||
"\n", | ||
"# Create a single environment for training an expert with SB3\n", | ||
"env = gym.make(\"custom/ObservationMatching-v0\")\n", | ||
"\n", | ||
"\n", | ||
"# Create a vectorized environment for training with `imitation`\n", | ||
"\n", | ||
"# Option A: use the `make_vec_env` helper function - make sure to pass `post_wrappers=[lambda env, _: RolloutInfoWrapper(env)]`\n", | ||
"venv = make_vec_env(\n", | ||
" \"custom/ObservationMatching-v0\",\n", | ||
" rng=np.random.default_rng(),\n", | ||
" n_envs=4,\n", | ||
" post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n", | ||
")\n", | ||
"\n", | ||
"\n", | ||
"# Option B1: use a custom env creator, and create VecEnv directly\n", | ||
"# def _make_env():\n", | ||
"# \"\"\"Helper function to create a single environment. Put any logic here, but make sure to return a RolloutInfoWrapper.\"\"\"\n", | ||
"# _env = gym.make(\"custom/ObservationMatching-v0\")\n", | ||
"# _env = RolloutInfoWrapper(_env)\n", | ||
"# return _env\n", | ||
"#\n", | ||
"# venv = DummyVecEnv([_make_env for _ in range(4)])\n", | ||
"#\n", | ||
"# # Option B2: we can also use a parallel VecEnv implementation\n", | ||
"# venv = SubprocVecEnv([_make_env for _ in range(4)])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"\n", | ||
"## Step 2b: directly use the environment\n", | ||
"\n", | ||
"Alternatively, we can directly initialize the environment by instantiating the class we created earlier, and handle all the additional logic ourselves." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from gym.wrappers import TimeLimit\n", | ||
"from imitation.data import rollout\n", | ||
"from imitation.data.wrappers import RolloutInfoWrapper\n", | ||
"from stable_baselines3.common.vec_env import DummyVecEnv\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"# Create a single environment for training with SB3\n", | ||
"env = ObservationMatchingEnv()\n", | ||
"env = TimeLimit(env, max_episode_steps=500)\n", | ||
"\n", | ||
"# Create a vectorized environment for training with `imitation`\n", | ||
"\n", | ||
"\n", | ||
"# Option A: use a helper function to create multiple environments\n", | ||
"def _make_env():\n", | ||
" \"\"\"Helper function to create a single environment. Put any logic here, but make sure to return a RolloutInfoWrapper.\"\"\"\n", | ||
" _env = ObservationMatchingEnv()\n", | ||
" _env = TimeLimit(_env, max_episode_steps=500)\n", | ||
" _env = RolloutInfoWrapper(_env)\n", | ||
" return _env\n", | ||
"\n", | ||
"\n", | ||
"venv = DummyVecEnv([_make_env for _ in range(4)])\n", | ||
"\n", | ||
"\n", | ||
"# Option B: use a single environment\n", | ||
"# env = FixedHorizonCartPoleEnv()\n", | ||
"# venv = DummyVecEnv([lambda: RolloutInfoWrapper(env)]) # Wrap a single environment -- only useful for simple testing like this\n", | ||
"\n", | ||
"# Option C: use multiple environments\n", | ||
"# venv = DummyVecEnv([lambda: RolloutInfoWrapper(ObservationMatchingEnv()) for _ in range(4)]) # Wrap multiple environments" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Step 3: Training" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"And now we're just about done! Whether you used step 2a or 2b, your environment should now be ready to use with SB3 and `imitation`.\n", | ||
"\n", | ||
"For the sake of completeness, we'll train a BC model, the same way as in the first tutorial, but with our custom environment.\n", | ||
"\n", | ||
"Keep in mind that while we're using BC in this tutorial, you can just as easily use any of the other algorithms with the environment prepared in this way." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from stable_baselines3 import PPO\n", | ||
"from stable_baselines3.ppo import MlpPolicy\n", | ||
"from stable_baselines3.common.evaluation import evaluate_policy\n", | ||
"from gym.wrappers import TimeLimit\n", | ||
"\n", | ||
"expert = PPO(\n", | ||
" policy=MlpPolicy,\n", | ||
" env=env,\n", | ||
" seed=0,\n", | ||
" batch_size=64,\n", | ||
" ent_coef=0.0,\n", | ||
" learning_rate=0.0003,\n", | ||
" n_epochs=10,\n", | ||
" n_steps=64,\n", | ||
")\n", | ||
"\n", | ||
"reward, _ = evaluate_policy(expert, env, 10)\n", | ||
"print(f\"Reward before training: {reward}\")\n", | ||
"\n", | ||
"\n", | ||
"# Note: if you followed step 2a, i.e. registered the environment, you can use the environment name directly\n", | ||
"\n", | ||
"# expert = PPO(\n", | ||
"# policy=MlpPolicy,\n", | ||
"# env=\"custom/ObservationMatching-v0\",\n", | ||
"# seed=0,\n", | ||
"# batch_size=64,\n", | ||
"# ent_coef=0.0,\n", | ||
"# learning_rate=0.0003,\n", | ||
"# n_epochs=10,\n", | ||
"# n_steps=64,\n", | ||
"# )\n", | ||
"expert.learn(10_000) # Note: set to 100000 to train a proficient expert\n", | ||
"\n", | ||
"reward, _ = evaluate_policy(expert, env, 10)\n", | ||
"print(f\"Expert reward: {reward}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"rng = np.random.default_rng()\n", | ||
"rollouts = rollout.rollout(\n", | ||
" expert,\n", | ||
" venv,\n", | ||
" rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n", | ||
" rng=rng,\n", | ||
")\n", | ||
"transitions = rollout.flatten_trajectories(rollouts)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from imitation.algorithms import bc\n", | ||
"\n", | ||
"bc_trainer = bc.BC(\n", | ||
" observation_space=env.observation_space,\n", | ||
" action_space=env.action_space,\n", | ||
" demonstrations=transitions,\n", | ||
" rng=rng,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"As before, the untrained policy only gets poor rewards:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"reward_before_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n", | ||
"print(f\"Reward before training: {reward_before_training}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"After training, we can get much closer to the expert's performance:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"bc_trainer.train(n_epochs=1)\n", | ||
"reward_after_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n", | ||
"print(f\"Reward after training: {reward_after_training}\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"interpreter": { | ||
"hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
Oops, something went wrong.