Replace imitation environments with seals (#541)
* Replace imitation with seals

* Fix bug in test

* Manually force CI to fetch latest seals changes

* Update code for new seals changes.

* Fix notebook example

* Update seals version

* Replace old seals naming convention

* Fix docs examples

* Rename environment identifier

* Rename environment identifier

* Rename wrong attribute

* Remove empty files from merge conflict

* Remove unused env testing files

* Remove import rename for seals package
Rocamonde authored Oct 12, 2022
1 parent 41c41b1 commit 288c25a
Showing 11 changed files with 131 additions and 910 deletions.
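
Taken together, the diffs below swap imitation's bundled tabular environments for their seals equivalents. A minimal sketch of the resulting setup (not part of the commit; assumes seals==0.1.4 and simply consolidates the renames applied across the changed files):

from functools import partial

from seals import base_envs                               # replaces imitation.envs.resettable_env
from seals.diagnostics.cliff_world import CliffWorldEnv   # replaces imitation.envs.examples.model_envs.CliffWorld
from stable_baselines3.common.vec_env import DummyVecEnv

env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)
env_single = env_creator()

# ExposePOMDPStateWrapper replaces the old DictExtractWrapper(venv, "state") pattern.
state_venv = DummyVecEnv([lambda: base_envs.ExposePOMDPStateWrapper(env_creator())] * 4)

# Attribute renames used throughout the diffs:
#   n_states -> state_dim, n_actions -> action_dim,
#   pomdp_state_space -> state_space, pomdp_observation_space -> observation_space
print(env_single.state_dim, env_single.action_dim, env_single.state_space)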
12 changes: 7 additions & 5 deletions docs/algorithms/mce_irl.rst
@@ -13,6 +13,8 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce`

from functools import partial

from seals import base_envs
from seals.diagnostics.cliff_world import CliffWorldEnv
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv
@@ -23,24 +25,24 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce`
mce_partition_fh,
)
from imitation.data import rollout
from imitation.envs import resettable_env
from imitation.envs.examples.model_envs import CliffWorld
from imitation.rewards import reward_nets

rng = np.random.default_rng(0)

env_creator = partial(CliffWorld, height=4, horizon=8, width=7, use_xy_obs=True)
env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)
env_single = env_creator()

state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())

# This is just a vectorized environment because `generate_trajectories` expects one
state_venv = resettable_env.DictExtractWrapper(DummyVecEnv([env_creator] * 4), "state")
state_venv = DummyVecEnv([state_env_creator] * 4)

_, _, pi = mce_partition_fh(env_single)

_, om = mce_occupancy_measures(env_single, pi=pi)

reward_net = reward_nets.BasicRewardNet(
env_single.pomdp_observation_space,
env_single.observation_space,
env_single.action_space,
hid_sizes=[256],
use_action=False,
17 changes: 10 additions & 7 deletions docs/tutorials/6_train_mce.ipynb
@@ -25,25 +25,28 @@
"source": [
"from functools import partial\n",
"\n",
"import numpy as np\n",
"from seals import base_envs\n",
"from seals.diagnostics.cliff_world import CliffWorldEnv\n",
"from stable_baselines3.common.vec_env import DummyVecEnv\n",
"\n",
"import numpy as np\n",
"\n",
"from imitation.algorithms.mce_irl import (\n",
" MCEIRL,\n",
" mce_occupancy_measures,\n",
" mce_partition_fh,\n",
" TabularPolicy,\n",
")\n",
"from imitation.data import rollout\n",
"from imitation.envs import resettable_env\n",
"from imitation.envs.examples.model_envs import CliffWorld\n",
"from imitation.rewards import reward_nets\n",
"\n",
"env_creator = partial(CliffWorld, height=4, horizon=8, width=7, use_xy_obs=True)\n",
"env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)\n",
"env_single = env_creator()\n",
"\n",
"state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())\n",
"\n",
"# This is just a vectorized environment because `generate_trajectories` expects one\n",
"state_venv = resettable_env.DictExtractWrapper(DummyVecEnv([env_creator] * 4), \"state\")"
"state_venv = DummyVecEnv([state_env_creator] * 4)"
]
},
{
@@ -65,7 +68,7 @@
"\n",
"rng = np.random.default_rng()\n",
"expert = TabularPolicy(\n",
" state_space=env_single.pomdp_state_space,\n",
" state_space=env_single.state_space,\n",
" action_space=env_single.action_space,\n",
" pi=pi,\n",
" rng=rng,\n",
@@ -102,7 +105,7 @@
"\n",
"def train_mce_irl(demos, hidden_sizes, lr=0.01, **kwargs):\n",
" reward_net = reward_nets.BasicRewardNet(\n",
" env_single.pomdp_observation_space,\n",
" env_single.observation_space,\n",
" env_single.action_space,\n",
" hid_sizes=hidden_sizes,\n",
" use_action=False,\n",
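
The notebook's training helper is truncated above. As a rough, self-contained sketch of the post-change flow: the MCEIRL argument order follows the constructor shown in src/imitation/algorithms/mce_irl.py below, while passing an occupancy measure directly as the demonstrations and the train(max_iter=...) call are assumptions based on imitation's MCE IRL tutorial, not part of this diff.

import numpy as np
from seals.diagnostics.cliff_world import CliffWorldEnv

from imitation.algorithms.mce_irl import MCEIRL, mce_occupancy_measures, mce_partition_fh
from imitation.rewards import reward_nets

rng = np.random.default_rng(0)
env_single = CliffWorldEnv(height=4, horizon=8, width=7, use_xy_obs=True)

# Expert policy and its expected state visitation counts on the tabular model.
_, _, pi = mce_partition_fh(env_single)
_, om = mce_occupancy_measures(env_single, pi=pi)

reward_net = reward_nets.BasicRewardNet(
    env_single.observation_space,  # was env_single.pomdp_observation_space before this commit
    env_single.action_space,
    hid_sizes=[256],
    use_action=False,
)

# Assumption: MCEIRL accepts a state occupancy measure directly as its demonstrations.
mce_irl = MCEIRL(om, env_single, reward_net, rng=rng)
mce_irl.train(max_iter=100)  # assumption: max_iter keyword, as in the imitation tutorial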
3 changes: 1 addition & 2 deletions setup.py
@@ -30,7 +30,6 @@
# working versions to make our CI/CD pipeline as stable as possible.
TESTS_REQUIRE = (
[
"seals==0.1.2",
"black[jupyter]~=22.6.0",
"coverage~=6.4.2",
"codecov~=2.1.12",
@@ -74,7 +73,6 @@
"sphinx-github-changelog~=1.2.0",
"myst-nb==0.16.0",
"ipykernel~=6.15.2",
"seals==0.1.2",
] + ATARI_REQUIRE


@@ -200,6 +198,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"torch>=1.4.0",
"tqdm",
"scikit-learn>=0.21.2",
"seals==0.1.4",
STABLE_BASELINES3,
# TODO(adam) switch to upstream release if they make it
# See https://github.com/IDSIA/sacred/issues/879
26 changes: 13 additions & 13 deletions src/imitation/algorithms/mce_irl.py
@@ -13,18 +13,18 @@
import numpy as np
import scipy.special
import torch as th
from seals import base_envs
from stable_baselines3.common import policies

from imitation.algorithms import base
from imitation.data import rollout, types
from imitation.envs import resettable_env
from imitation.rewards import reward_nets
from imitation.util import logger as imit_logger
from imitation.util import networks, util


def mce_partition_fh(
env: resettable_env.TabularModelEnv,
env: base_envs.TabularModelPOMDP,
*,
reward: Optional[np.ndarray] = None,
discount: float = 1.0,
@@ -46,8 +46,8 @@ def mce_partition_fh(
"""
# shorthand
horizon = env.horizon
n_states = env.n_states
n_actions = env.n_actions
n_states = env.state_dim
n_actions = env.action_dim
T = env.transition_matrix
if reward is None:
reward = env.reward_matrix
@@ -77,7 +77,7 @@


def mce_occupancy_measures(
env: resettable_env.TabularModelEnv,
env: base_envs.TabularModelPOMDP,
*,
reward: Optional[np.ndarray] = None,
pi: Optional[np.ndarray] = None,
@@ -102,8 +102,8 @@
"""
# shorthand
horizon = env.horizon
n_states = env.n_states
n_actions = env.n_actions
n_states = env.state_dim
n_actions = env.action_dim
T = env.transition_matrix
if reward is None:
reward = env.reward_matrix
@@ -257,7 +257,7 @@ class MCEIRL(base.DemonstrationAlgorithm[types.TransitionsMinimal]):
def __init__(
self,
demonstrations: Optional[MCEDemonstrations],
env: resettable_env.TabularModelEnv,
env: base_envs.TabularModelPOMDP,
reward_net: reward_nets.RewardNet,
rng: np.random.Generator,
optimizer_cls: Type[th.optim.Optimizer] = th.optim.Adam,
@@ -318,17 +318,17 @@ def __init__(
# Initialize policy to be uniform random. We don't use this for MCE IRL
# training, but it gives us something to return at all times with `policy`
# property, similar to other algorithms.
ones = np.ones((self.env.horizon, self.env.n_states, self.env.n_actions))
uniform_pi = ones / self.env.n_actions
ones = np.ones((self.env.horizon, self.env.state_dim, self.env.action_dim))
uniform_pi = ones / self.env.action_dim
self._policy = TabularPolicy(
state_space=self.env.pomdp_state_space,
state_space=self.env.state_space,
action_space=self.env.action_space,
pi=uniform_pi,
rng=self.rng,
)

def _set_demo_from_trajectories(self, trajs: Iterable[types.Trajectory]) -> None:
self.demo_state_om = np.zeros((self.env.n_states,))
self.demo_state_om = np.zeros((self.env.state_dim,))
num_demos = 0
for traj in trajs:
cum_discount = 1.0
@@ -344,7 +344,7 @@ def _set_demo_from_obs(
dones: Optional[np.ndarray],
next_obses: Optional[np.ndarray],
) -> None:
self.demo_state_om = np.zeros((self.env.n_states,))
self.demo_state_om = np.zeros((self.env.state_dim,))

for obs in obses:
if isinstance(obs, th.Tensor):
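
In terms of the renamed members, the uniform fallback policy built in __init__ above indexes (timestep, state, action). A hedged standalone illustration; the TabularPolicy keyword names are taken from the hunks in this commit, everything else is an assumption for the sake of the example:

import numpy as np
from seals.diagnostics.cliff_world import CliffWorldEnv

from imitation.algorithms.mce_irl import TabularPolicy

env = CliffWorldEnv(height=4, horizon=8, width=7, use_xy_obs=True)
rng = np.random.default_rng(0)

# One categorical action distribution per (timestep, state); uniform here,
# mirroring the initial policy MCEIRL exposes before training.
ones = np.ones((env.horizon, env.state_dim, env.action_dim))
uniform_pi = ones / env.action_dim

policy = TabularPolicy(
    state_space=env.state_space,  # was env.pomdp_state_space before this commit
    action_space=env.action_space,
    pi=uniform_pi,
    rng=rng,
)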
1 change: 0 additions & 1 deletion src/imitation/envs/__init__.py

This file was deleted.

8 changes: 0 additions & 8 deletions src/imitation/envs/examples/__init__.py

This file was deleted.

(Diffs for the remaining 5 changed files were not loaded on this page.)
