Feature/actions as obs (#291)
* feat: added actions as obs wrapper

* fix: actions shape

* fix: action_stack key

* feat: added controls

* fix: multi-discrete action stack

* test: update

* feat: added mlp_keys to prepare obs of sac, droq and a2c

* feat: added from __future__ import annotations

* feat: added noop + test

* feat: update tests + added controls in wrapper + update docs

* fix: typo
michele-milesi authored May 27, 2024
1 parent 28d14ba commit 878620a
Showing 12 changed files with 244 additions and 13 deletions.
27 changes: 27 additions & 0 deletions howto/actions_as_observation.md
@@ -0,0 +1,27 @@
# Actions as Observations Wrapper
This how-to explains how to use the Actions as Observations Wrapper, which adds the last `n` actions played to the observations returned by the environment.

When you want to add the last `n` actions to the observations, you must specify three parameters in the [`./configs/env/default.yaml`](../sheeprl/configs/env/default.yaml) file (a configuration example follows the list):
- `actions_as_observation.num_stack` (integer greater than 0): The number of actions to add to the observations.
- `actions_as_observation.dilation` (integer greater than 0): The number of environment steps between two stacked actions. For example, with `num_stack=3` and `dilation=2`, the observation contains the current action and the actions taken 2 and 4 steps earlier.
- `actions_as_observation.noop` (integer, float, or list of integers): The NOOP action used to fill the action buffer when the environment is reset. Every environment has its own NOOP action; it is strongly recommended to use that action so the algorithm can learn correctly.
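
For example, to stack the last 3 consecutive actions of an environment whose NOOP action is `0` (the values below are illustrative, not the defaults):

```yaml
actions_as_observation:
  num_stack: 3
  dilation: 1
  noop: 0
```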

## NOOP Parameter
The NOOP parameter must be:
- An integer for discrete action spaces
- A float for continuous action spaces
- A list of integers for multi-discrete action spaces: the length of the list must be equal to the number of actions in the environment.

Each environment has its own NOOP action, which is usually specified in its documentation. Below is the list of NOOP actions for the environments supported in SheepRL:
- MuJoCo (both gymnasium and DMC) environments: `0.0`.
- Atari environments: `0`.
- Crafter: `0`.
- MineRL: `0`.
- MineDojo: `[0, 0, 0]`.
- Super Mario Bros: `0`.
- Diambra:
- Discrete: `0`.
- Multi-discrete: `[0, 0]`.
- Box2D (gymnasium):
- Discrete: `0`.
- Continuous: `0.0`.
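
Since the wrapper validates the NOOP type against the action space, the following is a minimal sketch of a helper that returns a zero NOOP of the correct type. This `default_noop` function is illustrative and not part of SheepRL; environments whose NOOP differs from zero need the value reported above instead.

```python
import gymnasium as gym


def default_noop(action_space: gym.spaces.Space):
    """Return a zero NOOP of the type expected by the Actions as Observations Wrapper."""
    if isinstance(action_space, gym.spaces.Box):  # continuous -> single float
        return 0.0
    if isinstance(action_space, gym.spaces.MultiDiscrete):  # multi-discrete -> list of ints
        return [0] * len(action_space.nvec)
    if isinstance(action_space, gym.spaces.Discrete):  # discrete -> single int
        return 0
    raise ValueError(f"Unsupported action space: {type(action_space)}")
```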
9 changes: 9 additions & 0 deletions howto/configs.md
@@ -422,10 +422,19 @@ grayscale: False
clip_rewards: False
capture_video: True
frame_stack_dilation: 1
actions_as_observation:
  num_stack: -1
  noop: "You MUST define the NOOP"
  dilation: 1
max_episode_steps: null
reward_as_observation: False
wrapper: ???
```

> [!NOTE]
>
> The actions as observations wrapper adds the last `n` actions to the observations. For more information, check the corresponding [how-to file](./actions_as_observation.md).

Every custom environment must then "inherit" from this default config, override the particular parameters, and define the `wrapper` field, which is the one that will be directly instantiated at runtime. The `wrapper` field must define all the specific parameters to be passed to the `_target_` function when the wrapper will be instantiated. Take for example the `atari.yaml` config:

```yaml
6 changes: 4 additions & 2 deletions sheeprl/algos/a2c/a2c.py
@@ -236,7 +236,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sample an action given the observation received by the environment
# This calls the `forward` method of the PyTorch module, escaping from Fabric
# because we don't want this to be a synchronization point
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
torch_obs = prepare_obs(
    fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs
)
actions, _, values = player(torch_obs)
if is_continuous:
    real_actions = torch.stack(actions, -1).cpu().numpy()
@@ -304,7 +306,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.inference_mode():
    torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
    torch_obs = prepare_obs(fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
    next_values = player.get_values(torch_obs)
returns, advantages = gae(
local_data["rewards"].to(torch.float64),
10 changes: 6 additions & 4 deletions sheeprl/algos/a2c/utils.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict
from typing import Any, Dict, Sequence

import numpy as np
import torch
@@ -13,8 +13,10 @@
AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Dict[str, Tensor]:
    torch_obs = {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()}
def prepare_obs(
    fabric: Fabric, obs: Dict[str, np.ndarray], *, mlp_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
    torch_obs = {k: torch.from_numpy(obs[k].copy()).to(fabric.device).float().reshape(num_envs, -1) for k in mlp_keys}
    return torch_obs
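
A rough, self-contained illustration of the new `mlp_keys` filtering (hypothetical key names and shapes; the `fabric.device` transfer is omitted): only the observations listed in `mlp_keys` are kept, converted to float tensors, and flattened to `(num_envs, -1)`.

```python
import numpy as np
import torch

num_envs = 2
obs = {
    "state": np.random.rand(num_envs, 4).astype(np.float32),
    "action_stack": np.zeros((num_envs, 6), dtype=np.float32),  # added by the wrapper
    "rgb": np.zeros((num_envs, 64, 64, 3), dtype=np.uint8),  # not an MLP key, so it is skipped
}
mlp_keys = ["state", "action_stack"]
torch_obs = {k: torch.from_numpy(obs[k].copy()).float().reshape(num_envs, -1) for k in mlp_keys}
assert set(torch_obs) == {"state", "action_stack"}  # only the selected keys survive
```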


@@ -28,7 +30,7 @@ def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):

while not done:
    # Convert observations to tensors
    torch_obs = prepare_obs(fabric, obs)
    torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder)

    # Act greedily through the environment
    actions = agent.get_actions(torch_obs, greedy=True)
2 changes: 1 addition & 1 deletion sheeprl/algos/droq/droq.py
@@ -305,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
else:
    with torch.inference_mode():
        # Sample an action given the observation received by the environment
        torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
        torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
        actions = player(torch_obs)
        actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac.py
@@ -256,7 +256,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
else:
    # Sample an action given the observation received by the environment
    with torch.inference_mode():
        torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
        torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
        actions = player(torch_obs)
        actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
2 changes: 1 addition & 1 deletion sheeprl/algos/sac/sac_decoupled.py
@@ -187,7 +187,7 @@ def player(
    actions = envs.action_space.sample()
else:
    # Sample an action given the observation received by the environment
    torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs)
    torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs)
    actions = actor(torch_obs)
    actions = actions.cpu().numpy()
next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
8 changes: 5 additions & 3 deletions sheeprl/algos/sac/utils.py
@@ -28,9 +28,11 @@
MODELS_TO_REGISTER = {"agent"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Tensor:
def prepare_obs(
    fabric: Fabric, obs: Dict[str, np.ndarray], *, mlp_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Tensor:
    with fabric.device:
        torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in obs.keys()], dim=-1)
        torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in mlp_keys], dim=-1)
    return torch_obs.reshape(num_envs, -1)
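
For contrast with the A2C version above, a sketch of what the SAC `prepare_obs` produces (again with hypothetical shapes and without the device transfer): the selected MLP observations are concatenated into a single flat tensor per environment.

```python
import numpy as np
import torch

num_envs = 2
obs = {"state": np.random.rand(num_envs, 4), "action_stack": np.zeros((num_envs, 6))}
mlp_keys = ["state", "action_stack"]
torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in mlp_keys], dim=-1)
torch_obs = torch_obs.reshape(num_envs, -1)
assert torch_obs.shape == (num_envs, 10)  # 4 state features + 6 one-hot action entries
```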


@@ -43,7 +45,7 @@ def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
obs = env.reset(seed=cfg.seed)[0]
while not done:
    # Act greedily through the environment
    torch_obs = prepare_obs(fabric, obs)
    torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder)
    action = actor.get_actions(torch_obs, greedy=True)

    # Single environment step
4 changes: 4 additions & 0 deletions sheeprl/configs/env/default.yaml
@@ -8,6 +8,10 @@ grayscale: False
clip_rewards: False
capture_video: True
frame_stack_dilation: 1
actions_as_observation:
  num_stack: -1
  noop: "You MUST define the NOOP"
  dilation: 1
max_episode_steps: null
reward_as_observation: False
wrapper: ???
89 changes: 89 additions & 0 deletions sheeprl/envs/wrappers.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import time
from collections import deque
@@ -251,3 +253,90 @@ def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
if len(frame.shape) == 3 and frame.shape[-1] == 1:
    frame = frame.repeat(3, axis=-1)
return frame


class ActionsAsObservationWrapper(gym.Wrapper):
    def __init__(self, env: Env, num_stack: int, noop: float | int | List[int], dilation: int = 1):
        super().__init__(env)
        if num_stack < 1:
            raise ValueError(
                "The number of actions to the `action_stack` observation "
                f"must be greater than or equal to 1, got: {num_stack}"
            )
        if dilation < 1:
            raise ValueError(f"The actions stack dilation argument must be greater than zero, got: {dilation}")
        if not isinstance(noop, (int, float, list)):
            raise ValueError(f"The noop action must be an integer or float or list, got: {noop} ({type(noop)})")
        self._num_stack = num_stack
        self._dilation = dilation
        self._actions = deque(maxlen=num_stack * dilation)
        self._is_continuous = isinstance(self.env.action_space, gym.spaces.Box)
        self._is_multidiscrete = isinstance(self.env.action_space, gym.spaces.MultiDiscrete)
        self.observation_space = copy.deepcopy(self.env.observation_space)
        if self._is_continuous:
            self._action_shape = self.env.action_space.shape[0]
            low = np.resize(self.env.action_space.low, self._action_shape * num_stack)
            high = np.resize(self.env.action_space.high, self._action_shape * num_stack)
        elif self._is_multidiscrete:
            low = 0
            high = 1  # one-hot encoding
            # one one-hot vector for each sub-action
            self._action_shape = sum(self.env.action_space.nvec)
        else:
            low = 0
            high = 1  # one-hot encoding
            self._action_shape = self.env.action_space.n
        self.observation_space["action_stack"] = gym.spaces.Box(
            low=low, high=high, shape=(self._action_shape * num_stack,), dtype=np.float32
        )
        if self._is_continuous:
            if isinstance(noop, list):
                raise ValueError(f"The noop actions must be a float for continuous action spaces, got: {noop}")
            self.noop = np.full((self._action_shape,), noop, dtype=np.float32)
        elif self._is_multidiscrete:
            if not isinstance(noop, list):
                raise ValueError(f"The noop actions must be a list for multi-discrete action spaces, got: {noop}")
            if len(self.env.action_space.nvec) != len(noop):
                raise RuntimeError(
                    "The number of noop actions must be equal to the number of actions of the environment. "
                    f"Got env_action_space = {self.env.action_space.nvec} and {noop =}"
                )
            # One one-hot noop vector per sub-action
            noops = []
            for act, n in zip(noop, self.env.action_space.nvec):
                noops.append(np.zeros((n,), dtype=np.float32))
                noops[-1][act] = 1.0
            self.noop = np.concatenate(noops, axis=-1)
        else:
            if isinstance(noop, (list, float)):
                raise ValueError(f"The noop actions must be an integer for discrete action spaces, got: {noop}")
            self.noop = np.zeros((self._action_shape,), dtype=np.float32)
            self.noop[noop] = 1.0

    def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
        if self._is_continuous:
            self._actions.append(action)
        elif self._is_multidiscrete:
            one_hot_actions = []
            for act, n in zip(action, self.env.action_space.nvec):
                one_hot_actions.append(np.zeros((n,), dtype=np.float32))
                one_hot_actions[-1][act] = 1.0
            self._actions.append(np.concatenate(one_hot_actions, axis=-1))
        else:
            one_hot_action = np.zeros((self._action_shape,), dtype=np.float32)
            one_hot_action[action] = 1.0
            self._actions.append(one_hot_action)
        obs, reward, done, truncated, info = super().step(action)
        obs["action_stack"] = self._get_actions_stack()
        return obs, reward, done, truncated, info

    def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None) -> Tuple[Any, Dict[str, Any]]:
        obs, info = super().reset(seed=seed, options=options)
        # Fill the buffer with the noop action
        self._actions.clear()
        for _ in range(self._num_stack * self._dilation):
            self._actions.append(self.noop)
        obs["action_stack"] = self._get_actions_stack()
        return obs, info

    def _get_actions_stack(self) -> np.ndarray:
        # Keep one action every `dilation` steps, ending with the most recent one
        actions_stack = list(self._actions)[self._dilation - 1 :: self._dilation]
        actions = np.concatenate(actions_stack, axis=-1)
        return actions.astype(np.float32)
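
A usage sketch of the new wrapper, mirroring the tests below (it relies on the repository's `DiscreteDummyEnv`; the printed shapes depend on that environment's action space):

```python
from sheeprl.envs.dummy import DiscreteDummyEnv
from sheeprl.envs.wrappers import ActionsAsObservationWrapper

# Stack the last 3 actions, keeping one action every 2 environment steps.
env = ActionsAsObservationWrapper(DiscreteDummyEnv(), num_stack=3, noop=0, dilation=2)
obs, info = env.reset()
# The wrapper adds a flat one-hot "action_stack" key of size action_space.n * num_stack.
print(env.observation_space["action_stack"].shape)
obs, reward, done, truncated, info = env.step(env.action_space.sample())
print(obs["action_stack"].shape)  # same shape after every step
```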
4 changes: 4 additions & 0 deletions sheeprl/utils/env.py
@@ -9,6 +9,7 @@

from sheeprl.envs.wrappers import (
    ActionRepeat,
    ActionsAsObservationWrapper,
    FrameStack,
    GrayscaleRenderWrapper,
    MaskVelocityWrapper,
@@ -207,6 +208,9 @@ def transform_obs(obs: Dict[str, Any]):
)
env = FrameStack(env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation)

if cfg.env.actions_as_observation.num_stack > 0 and "diambra" not in cfg.env.wrapper._target_:
    env = ActionsAsObservationWrapper(env, **cfg.env.actions_as_observation)

if cfg.env.reward_as_observation:
    env = RewardAsObservationWrapper(env)

94 changes: 93 additions & 1 deletion tests/test_envs/test_wrappers.py
@@ -1,10 +1,102 @@
import gymnasium as gym
import numpy as np
import pytest

from sheeprl.envs.wrappers import MaskVelocityWrapper
from sheeprl.envs.dummy import ContinuousDummyEnv, DiscreteDummyEnv, MultiDiscreteDummyEnv
from sheeprl.envs.wrappers import ActionsAsObservationWrapper, MaskVelocityWrapper

ENVIRONMENTS = {
    "discrete_dummy": DiscreteDummyEnv,
    "multidiscrete_dummy": MultiDiscreteDummyEnv,
    "continuous_dummy": ContinuousDummyEnv,
}


def test_mask_velocities_fail():
    with pytest.raises(NotImplementedError):
        env = gym.make("CarRacing-v2")
        env = MaskVelocityWrapper(env)


@pytest.mark.parametrize("num_stack", [1, 4, 8])
@pytest.mark.parametrize("dilation", [1, 2, 4])
@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
def test_actions_as_observation_wrapper(env_id: str, num_stack, dilation):
    env = ENVIRONMENTS[env_id]()
    if isinstance(env.action_space, gym.spaces.MultiDiscrete):
        noop = [0, 0]
    else:
        noop = 0
    env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=dilation)

    o = env.reset()[0]
    assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape)
    for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape):
        assert d1 == d2

    for _ in range(64):
        o = env.step(env.action_space.sample())[0]
        assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape)
        for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape):
            assert d1 == d2


@pytest.mark.parametrize("num_stack", [-1, 0])
@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
def test_actions_as_observation_wrapper_invalid_num_stack(env_id, num_stack):
    env = ENVIRONMENTS[env_id]()
    if isinstance(env.action_space, gym.spaces.MultiDiscrete):
        noop = [0, 0]
    else:
        noop = 0
    with pytest.raises(ValueError, match="The number of actions to the"):
        env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=3)


@pytest.mark.parametrize("dilation", [-1, 0])
@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
def test_actions_as_observation_wrapper_invalid_dilation(env_id, dilation):
    env = ENVIRONMENTS[env_id]()
    if isinstance(env.action_space, gym.spaces.MultiDiscrete):
        noop = [0, 0]
    else:
        noop = 0
    with pytest.raises(ValueError, match="The actions stack dilation argument must be greater than zero"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=dilation)


@pytest.mark.parametrize("noop", [set([0, 0, 0]), "this is an invalid type", np.array([0, 0, 0])])
@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"])
def test_actions_as_observation_wrapper_invalid_noop_type(env_id, noop):
    env = ENVIRONMENTS[env_id]()
    with pytest.raises(ValueError, match="The noop action must be an integer or float or list"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)


def test_actions_as_observation_wrapper_invalid_noop_continuous_type():
    env = ContinuousDummyEnv()
    with pytest.raises(ValueError, match="The noop actions must be a float for continuous action spaces"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=[0, 0, 0], dilation=2)


@pytest.mark.parametrize("noop", [[0, 0, 0], 0.0])
def test_actions_as_observation_wrapper_invalid_noop_discrete_type(noop):
    env = DiscreteDummyEnv()
    with pytest.raises(ValueError, match="The noop actions must be an integer for discrete action spaces"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)


@pytest.mark.parametrize("noop", [0, 0.0])
def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_type(noop):
    env = MultiDiscreteDummyEnv()
    with pytest.raises(ValueError, match="The noop actions must be a list for multi-discrete action spaces"):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)


@pytest.mark.parametrize("noop", [[0], [0, 0, 0]])
def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_n_actions(noop):
    env = MultiDiscreteDummyEnv()
    with pytest.raises(
        RuntimeError, match="The number of noop actions must be equal to the number of actions of the environment"
    ):
        env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)
