Replace imitation environments with seals #541

Merged
merged 18 commits into from
Oct 12, 2022
12 changes: 7 additions & 5 deletions docs/algorithms/mce_irl.rst
@@ -13,6 +13,8 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce`

from functools import partial

from seals import base_envs
from seals.diagnostics.cliff_world import CliffWorldEnv
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv
@@ -23,24 +25,24 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce`
mce_partition_fh,
)
from imitation.data import rollout
from imitation.envs import resettable_env
from imitation.envs.examples.model_envs import CliffWorld
from imitation.rewards import reward_nets

rng = np.random.default_rng(0)

env_creator = partial(CliffWorld, height=4, horizon=8, width=7, use_xy_obs=True)
env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)
env_single = env_creator()

state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())

# This is just a vectorized environment because `generate_trajectories` expects one
state_venv = resettable_env.DictExtractWrapper(DummyVecEnv([env_creator] * 4), "state")
state_venv = DummyVecEnv([state_env_creator] * 4)

_, _, pi = mce_partition_fh(env_single)

_, om = mce_occupancy_measures(env_single, pi=pi)

reward_net = reward_nets.BasicRewardNet(
env_single.pomdp_observation_space,
env_single.observation_space,
env_single.action_space,
hid_sizes=[256],
use_action=False,
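Assembled, the updated documentation example now sets up the environment as follows. This is a condensed sketch built from the added lines above; anything not shown here is untouched by this PR.

from functools import partial

import numpy as np
from seals import base_envs
from seals.diagnostics.cliff_world import CliffWorldEnv
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.rewards import reward_nets

rng = np.random.default_rng(0)  # used later in the docs example for training

# seals' CliffWorldEnv replaces the removed imitation.envs CliffWorld.
env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)
env_single = env_creator()

# Expose the underlying POMDP state as the observation; this replaces the old
# DictExtractWrapper(..., "state") pattern.
state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())

# Vectorized because `generate_trajectories` expects a VecEnv.
state_venv = DummyVecEnv([state_env_creator] * 4)

# The reward net now reads the plain `observation_space` attribute.
reward_net = reward_nets.BasicRewardNet(
    env_single.observation_space,
    env_single.action_space,
    hid_sizes=[256],
    use_action=False,
)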
17 changes: 10 additions & 7 deletions docs/tutorials/6_train_mce.ipynb
@@ -25,25 +25,28 @@
"source": [
"from functools import partial\n",
"\n",
"import numpy as np\n",
"from seals import base_envs\n",
"from seals.diagnostics.cliff_world import CliffWorldEnv\n",
"from stable_baselines3.common.vec_env import DummyVecEnv\n",
"\n",
"import numpy as np\n",
"\n",
"from imitation.algorithms.mce_irl import (\n",
" MCEIRL,\n",
" mce_occupancy_measures,\n",
" mce_partition_fh,\n",
" TabularPolicy,\n",
")\n",
"from imitation.data import rollout\n",
"from imitation.envs import resettable_env\n",
"from imitation.envs.examples.model_envs import CliffWorld\n",
"from imitation.rewards import reward_nets\n",
"\n",
"env_creator = partial(CliffWorld, height=4, horizon=8, width=7, use_xy_obs=True)\n",
"env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)\n",
"env_single = env_creator()\n",
"\n",
"state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())\n",
"\n",
"# This is just a vectorized environment because `generate_trajectories` expects one\n",
"state_venv = resettable_env.DictExtractWrapper(DummyVecEnv([env_creator] * 4), \"state\")"
"state_venv = DummyVecEnv([state_env_creator] * 4)"
]
},
{
@@ -65,7 +68,7 @@
"\n",
"rng = np.random.default_rng()\n",
"expert = TabularPolicy(\n",
" state_space=env_single.pomdp_state_space,\n",
" state_space=env_single.state_space,\n",
" action_space=env_single.action_space,\n",
" pi=pi,\n",
" rng=rng,\n",
@@ -102,7 +105,7 @@
"\n",
"def train_mce_irl(demos, hidden_sizes, lr=0.01, **kwargs):\n",
" reward_net = reward_nets.BasicRewardNet(\n",
" env_single.pomdp_observation_space,\n",
" env_single.observation_space,\n",
" env_single.action_space,\n",
" hid_sizes=hidden_sizes,\n",
" use_action=False,\n",
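Taken together, the changed notebook cells amount to the same attribute renames; roughly, and assuming `env_single`, `pi`, and the imports from the first cell shown above:

rng = np.random.default_rng()
expert = TabularPolicy(
    state_space=env_single.state_space,  # previously env_single.pomdp_state_space
    action_space=env_single.action_space,
    pi=pi,
    rng=rng,
)

def train_mce_irl(demos, hidden_sizes, lr=0.01, **kwargs):
    reward_net = reward_nets.BasicRewardNet(
        env_single.observation_space,  # previously env_single.pomdp_observation_space
        env_single.action_space,
        hid_sizes=hidden_sizes,
        use_action=False,
    )
    ...  # the rest of the notebook's training function is unchanged by this PR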
3 changes: 1 addition & 2 deletions setup.py
@@ -30,7 +30,6 @@
# working versions to make our CI/CD pipeline as stable as possible.
TESTS_REQUIRE = (
[
"seals==0.1.2",
"black[jupyter]~=22.6.0",
"coverage~=6.4.2",
"codecov~=2.1.12",
@@ -74,7 +73,6 @@
"sphinx-github-changelog~=1.2.0",
"myst-nb==0.16.0",
"ipykernel~=6.15.2",
"seals==0.1.2",
] + ATARI_REQUIRE


@@ -200,6 +198,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"torch>=1.4.0",
"tqdm",
"scikit-learn>=0.21.2",
"seals==0.1.4",
STABLE_BASELINES3,
# TODO(adam) switch to upstream release if they make it
# See https://github.com/IDSIA/sacred/issues/879
26 changes: 13 additions & 13 deletions src/imitation/algorithms/mce_irl.py
@@ -13,18 +13,18 @@
import numpy as np
import scipy.special
import torch as th
from seals import base_envs
from stable_baselines3.common import policies

from imitation.algorithms import base
from imitation.data import rollout, types
from imitation.envs import resettable_env
from imitation.rewards import reward_nets
from imitation.util import logger as imit_logger
from imitation.util import networks, util


def mce_partition_fh(
env: resettable_env.TabularModelEnv,
env: base_envs.TabularModelPOMDP,
*,
reward: Optional[np.ndarray] = None,
discount: float = 1.0,
@@ -46,8 +46,8 @@ def mce_partition_fh(
"""
# shorthand
horizon = env.horizon
n_states = env.n_states
n_actions = env.n_actions
n_states = env.state_dim
n_actions = env.action_dim
T = env.transition_matrix
if reward is None:
reward = env.reward_matrix
@@ -77,7 +77,7 @@


def mce_occupancy_measures(
env: resettable_env.TabularModelEnv,
env: base_envs.TabularModelPOMDP,
*,
reward: Optional[np.ndarray] = None,
pi: Optional[np.ndarray] = None,
@@ -102,8 +102,8 @@
"""
# shorthand
horizon = env.horizon
n_states = env.n_states
n_actions = env.n_actions
n_states = env.state_dim
n_actions = env.action_dim
T = env.transition_matrix
if reward is None:
reward = env.reward_matrix
@@ -257,7 +257,7 @@ class MCEIRL(base.DemonstrationAlgorithm[types.TransitionsMinimal]):
def __init__(
self,
demonstrations: Optional[MCEDemonstrations],
env: resettable_env.TabularModelEnv,
env: base_envs.TabularModelPOMDP,
reward_net: reward_nets.RewardNet,
rng: np.random.Generator,
optimizer_cls: Type[th.optim.Optimizer] = th.optim.Adam,
@@ -318,17 +318,17 @@ def __init__(
# Initialize policy to be uniform random. We don't use this for MCE IRL
# training, but it gives us something to return at all times with `policy`
# property, similar to other algorithms.
ones = np.ones((self.env.horizon, self.env.n_states, self.env.n_actions))
uniform_pi = ones / self.env.n_actions
ones = np.ones((self.env.horizon, self.env.state_dim, self.env.action_dim))
uniform_pi = ones / self.env.action_dim
self._policy = TabularPolicy(
state_space=self.env.pomdp_state_space,
state_space=self.env.state_space,
action_space=self.env.action_space,
pi=uniform_pi,
rng=self.rng,
)

def _set_demo_from_trajectories(self, trajs: Iterable[types.Trajectory]) -> None:
self.demo_state_om = np.zeros((self.env.n_states,))
self.demo_state_om = np.zeros((self.env.state_dim,))
num_demos = 0
for traj in trajs:
cum_discount = 1.0
@@ -344,7 +344,7 @@ def _set_demo_from_obs(
dones: Optional[np.ndarray],
next_obses: Optional[np.ndarray],
) -> None:
self.demo_state_om = np.zeros((self.env.n_states,))
self.demo_state_om = np.zeros((self.env.state_dim,))

for obs in obses:
if isinstance(obs, th.Tensor):
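Putting the renames together (n_states/n_actions become state_dim/action_dim, and pomdp_state_space/pomdp_observation_space become state_space/observation_space), end-to-end usage against a seals tabular POMDP looks roughly like the sketch below. The final MCEIRL(...).train() call illustrates typical usage and is not taken from this diff.

import numpy as np
from seals.diagnostics.cliff_world import CliffWorldEnv

from imitation.algorithms.mce_irl import MCEIRL, mce_occupancy_measures, mce_partition_fh
from imitation.rewards import reward_nets

env_single = CliffWorldEnv(height=4, horizon=8, width=7, use_xy_obs=True)

_, _, pi = mce_partition_fh(env_single)            # expert policy for the tabular POMDP
_, om = mce_occupancy_measures(env_single, pi=pi)  # expert state occupancy measure

reward_net = reward_nets.BasicRewardNet(
    env_single.observation_space,  # seals attribute names, not pomdp_*
    env_single.action_space,
    use_action=False,
)

# Pass the expert occupancy measure as the demonstrations and train the reward net.
mce_irl = MCEIRL(om, env_single, reward_net, rng=np.random.default_rng(0))
mce_irl.train()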
1 change: 0 additions & 1 deletion src/imitation/envs/__init__.py

This file was deleted.

8 changes: 0 additions & 8 deletions src/imitation/envs/examples/__init__.py

This file was deleted.
