diff --git a/howto/actions_as_observation.md b/howto/actions_as_observation.md new file mode 100644 index 00000000..d7f318ae --- /dev/null +++ b/howto/actions_as_observation.md @@ -0,0 +1,27 @@ +# Actions as Observations Wrapper +In this how-to, some indications are given on how to use the Actions as Observations Wrapper. + +When you want to add the last `n` actions to the observations, you must specify three parameters in the [`./configs/env/default.yaml`](../sheeprl/configs/env/default.yaml) file: +- `actions_as_observation.num_stack` (integer greater than 0): The number of actions to add to the observations. +- `actions_as_observation.dilation` (integer greater than 0): The dilation (number of steps) between one action and the next one. +- `actions_as_observation.noop` (integer or float or list of integer): The noop action to use when resetting the environment, the buffer is filled with this action. Every environment has its own NOOP action, it is strongly recommended to use that action for the correct learning of the algorithm. + +## NOOP Parameter +The NOOP parameter must be: +- An integer for discrete action spaces +- A float for continuous action spaces +- A list of integers for multi-discrete action spaces: the length of the list must be equal to the number of actions in the environment. + +Each environment has its own NOOP action, usually it is specified in the documentation. Below we reported the list of noop actions of the environments supported in SheepRL: +- MuJoCo (both gymnasium and DMC) environments: `0.0`. +- Atari environments: `0`. +- Crafter: `0`. +- MineRL: `0`. +- MineDojo: `[0, 0, 0]`. +- Super Mario Bros: `0`. +- Diambra: + - Discrete: `0`. + - Multi-discrete: `[0, 0]`. +- Box2D (gymnasium): + - Discrete: `0`. + - Continuous: `0.0`. \ No newline at end of file diff --git a/howto/configs.md b/howto/configs.md index 3c91fc5d..1db471b7 100644 --- a/howto/configs.md +++ b/howto/configs.md @@ -422,10 +422,19 @@ grayscale: False clip_rewards: False capture_video: True frame_stack_dilation: 1 +actions_as_observation: + num_stack: -1 + noop: "You MUST define the NOOP" + dilation: 1 max_episode_steps: null reward_as_observation: False +wrapper: ??? ``` +> [!NOTE] +> +> The actions as observations wrapper is used for adding the last `n` actions to the observations. For more information, check the corresponding [howto file](./actions_as_observation.md). + Every custom environment must then "inherit" from this default config, override the particular parameters, and define the `wrapper` field, which is the one that will be directly instantiated at runtime. The `wrapper` field must define all the specific parameters to be passed to the `_target_` function when the wrapper will be instantiated. Take for example the `atari.yaml` config: ```yaml diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index e3b23ee1..98ba65c0 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -236,7 +236,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Sample an action given the observation received by the environment # This calls the `forward` method of the PyTorch module, escaping from Fabric # because we don't want this to be a synchronization point - torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs) + torch_obs = prepare_obs( + fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs + ) actions, _, values = player(torch_obs) if is_continuous: real_actions = torch.stack(actions, -1).cpu().numpy() @@ -304,7 +306,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.inference_mode(): - torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs) + torch_obs = prepare_obs(fabric, next_obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs) next_values = player.get_values(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), diff --git a/sheeprl/algos/a2c/utils.py b/sheeprl/algos/a2c/utils.py index c26fbbaf..88fb0099 100644 --- a/sheeprl/algos/a2c/utils.py +++ b/sheeprl/algos/a2c/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Dict, Sequence import numpy as np import torch @@ -13,8 +13,10 @@ AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"} -def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Dict[str, Tensor]: - torch_obs = {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()} +def prepare_obs( + fabric: Fabric, obs: Dict[str, np.ndarray], *, mlp_keys: Sequence[str] = [], num_envs: int = 1, **kwargs +) -> Dict[str, Tensor]: + torch_obs = {k: torch.from_numpy(obs[k].copy()).to(fabric.device).float().reshape(num_envs, -1) for k in mlp_keys} return torch_obs @@ -28,7 +30,7 @@ def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): while not done: # Convert observations to tensors - torch_obs = prepare_obs(fabric, obs) + torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder) # Act greedly through the environment actions = agent.get_actions(torch_obs, greedy=True) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 9c9a500e..b5cf8c35 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -305,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else: with torch.inference_mode(): # Sample an action given the observation received by the environment - torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs) + torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs) actions = player(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index e8cd9e75..4abe9506 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -256,7 +256,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else: # Sample an action given the observation received by the environment with torch.inference_mode(): - torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs) + torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs) actions = player(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index ceb49b78..705e1dd6 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -187,7 +187,7 @@ def player( actions = envs.action_space.sample() else: # Sample an action given the observation received by the environment - torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs) + torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder, num_envs=cfg.env.num_envs) actions = actor(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) diff --git a/sheeprl/algos/sac/utils.py b/sheeprl/algos/sac/utils.py index 6912f278..9432db3f 100644 --- a/sheeprl/algos/sac/utils.py +++ b/sheeprl/algos/sac/utils.py @@ -28,9 +28,11 @@ MODELS_TO_REGISTER = {"agent"} -def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Tensor: +def prepare_obs( + fabric: Fabric, obs: Dict[str, np.ndarray], *, mlp_keys: Sequence[str] = [], num_envs: int = 1, **kwargs +) -> Tensor: with fabric.device: - torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in obs.keys()], dim=-1) + torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in mlp_keys], dim=-1) return torch_obs.reshape(num_envs, -1) @@ -43,7 +45,7 @@ def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): obs = env.reset(seed=cfg.seed)[0] while not done: # Act greedly through the environment - torch_obs = prepare_obs(fabric, obs) + torch_obs = prepare_obs(fabric, obs, mlp_keys=cfg.algo.mlp_keys.encoder) action = actor.get_actions(torch_obs, greedy=True) # Single environment step diff --git a/sheeprl/configs/env/default.yaml b/sheeprl/configs/env/default.yaml index d80a6333..459d0cab 100644 --- a/sheeprl/configs/env/default.yaml +++ b/sheeprl/configs/env/default.yaml @@ -8,6 +8,10 @@ grayscale: False clip_rewards: False capture_video: True frame_stack_dilation: 1 +actions_as_observation: + num_stack: -1 + noop: "You MUST define the NOOP" + dilation: 1 max_episode_steps: null reward_as_observation: False wrapper: ??? diff --git a/sheeprl/envs/wrappers.py b/sheeprl/envs/wrappers.py index a5fa5904..cc285b11 100644 --- a/sheeprl/envs/wrappers.py +++ b/sheeprl/envs/wrappers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import time from collections import deque @@ -251,3 +253,90 @@ def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: if len(frame.shape) == 3 and frame.shape[-1] == 1: frame = frame.repeat(3, axis=-1) return frame + + +class ActionsAsObservationWrapper(gym.Wrapper): + def __init__(self, env: Env, num_stack: int, noop: float | int | List[int], dilation: int = 1): + super().__init__(env) + if num_stack < 1: + raise ValueError( + "The number of actions to the `action_stack` observation " + f"must be greater or equal than 1, got: {num_stack}" + ) + if dilation < 1: + raise ValueError(f"The actions stack dilation argument must be greater than zero, got: {dilation}") + if not isinstance(noop, (int, float, list)): + raise ValueError(f"The noop action must be an integer or float or list, got: {noop} ({type(noop)})") + self._num_stack = num_stack + self._dilation = dilation + self._actions = deque(maxlen=num_stack * dilation) + self._is_continuous = isinstance(self.env.action_space, gym.spaces.Box) + self._is_multidiscrete = isinstance(self.env.action_space, gym.spaces.MultiDiscrete) + self.observation_space = copy.deepcopy(self.env.observation_space) + if self._is_continuous: + self._action_shape = self.env.action_space.shape[0] + low = np.resize(self.env.action_space.low, self._action_shape * num_stack) + high = np.resize(self.env.action_space.high, self._action_shape * num_stack) + elif self._is_multidiscrete: + low = 0 + high = 1 # one-hot encoding + # one one-hot for each action + self._action_shape = sum(self.env.action_space.nvec) + else: + low = 0 + high = 1 # one-hot encoding + self._action_shape = self.env.action_space.n + self.observation_space["action_stack"] = gym.spaces.Box( + low=low, high=high, shape=(self._action_shape * num_stack,), dtype=np.float32 + ) + if self._is_continuous: + if isinstance(noop, list): + raise ValueError(f"The noop actions must be a float for continuous action spaces, got: {noop}") + self.noop = np.full((self._action_shape,), noop, dtype=np.float32) + elif self._is_multidiscrete: + if not isinstance(noop, list): + raise ValueError(f"The noop actions must be a list for multi-discrete action spaces, got: {noop}") + if len(self.env.action_space.nvec) != len(noop): + raise RuntimeError( + "The number of noop actions must be equal to the number of actions of the environment. " + f"Got env_action_space = {self.env.action_space.nvec} and {noop =}" + ) + noops = [] + for act, n in zip(noop, self.env.action_space.nvec): + noops.append(np.zeros((n,), dtype=np.float32)) + noops[-1][noop[act]] = 1.0 + self.noop = np.concatenate(noops, axis=-1) + else: + if isinstance(noop, (list, float)): + raise ValueError(f"The noop actions must be an integer for discrete action spaces, got: {noop}") + self.noop = np.zeros((self._action_shape,), dtype=np.float32) + self.noop[noop] = 1.0 + + def step(self, action: Any) -> Tuple[Any | SupportsFloat | bool | Dict[str, Any]]: + if self._is_continuous: + self._actions.append(action) + elif self._is_multidiscrete: + one_hot_actions = [] + for act, n in zip(action, self.env.action_space.nvec): + one_hot_actions.append(np.zeros((n,), dtype=np.float32)) + one_hot_actions[-1][act] = 1.0 + self._actions.append(np.concatenate(one_hot_actions, axis=-1)) + else: + one_hot_action = np.zeros((self._action_shape,), dtype=np.float32) + one_hot_action[action] = 1.0 + self._actions.append(one_hot_action) + obs, reward, done, truncated, info = super().step(action) + obs["action_stack"] = self._get_actions_stack() + return obs, reward, done, truncated, info + + def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None) -> Tuple[Any | Dict[str, Any]]: + obs, info = super().reset(seed=seed, options=options) + self._actions.clear() + [self._actions.append(self.noop) for _ in range(self._num_stack * self._dilation)] + obs["action_stack"] = self._get_actions_stack() + return obs, info + + def _get_actions_stack(self) -> np.ndarray: + actions_stack = list(self._actions)[self._dilation - 1 :: self._dilation] + actions = np.concatenate(actions_stack, axis=-1) + return actions.astype(np.float32) diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py index 14d57103..750d85ee 100644 --- a/sheeprl/utils/env.py +++ b/sheeprl/utils/env.py @@ -9,6 +9,7 @@ from sheeprl.envs.wrappers import ( ActionRepeat, + ActionsAsObservationWrapper, FrameStack, GrayscaleRenderWrapper, MaskVelocityWrapper, @@ -207,6 +208,9 @@ def transform_obs(obs: Dict[str, Any]): ) env = FrameStack(env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation) + if cfg.env.actions_as_observation.num_stack > 0 and "diambra" not in cfg.env.wrapper._target_: + env = ActionsAsObservationWrapper(env, **cfg.env.actions_as_observation) + if cfg.env.reward_as_observation: env = RewardAsObservationWrapper(env) diff --git a/tests/test_envs/test_wrappers.py b/tests/test_envs/test_wrappers.py index 2c0b8dbf..651679db 100644 --- a/tests/test_envs/test_wrappers.py +++ b/tests/test_envs/test_wrappers.py @@ -1,10 +1,102 @@ import gymnasium as gym +import numpy as np import pytest -from sheeprl.envs.wrappers import MaskVelocityWrapper +from sheeprl.envs.dummy import ContinuousDummyEnv, DiscreteDummyEnv, MultiDiscreteDummyEnv +from sheeprl.envs.wrappers import ActionsAsObservationWrapper, MaskVelocityWrapper + +ENVIRONMENTS = { + "discrete_dummy": DiscreteDummyEnv, + "multidiscrete_dummy": MultiDiscreteDummyEnv, + "continuous_dummy": ContinuousDummyEnv, +} def test_mask_velocities_fail(): with pytest.raises(NotImplementedError): env = gym.make("CarRacing-v2") env = MaskVelocityWrapper(env) + + +@pytest.mark.parametrize("num_stack", [1, 4, 8]) +@pytest.mark.parametrize("dilation", [1, 2, 4]) +@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) +def test_actions_as_observation_wrapper(env_id: str, num_stack, dilation): + env = ENVIRONMENTS[env_id]() + if isinstance(env.action_space, gym.spaces.MultiDiscrete): + noop = [0, 0] + else: + noop = 0 + env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=dilation) + + o = env.reset()[0] + assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape) + for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape): + assert d1 == d2 + + for _ in range(64): + o = env.step(env.action_space.sample())[0] + assert len(o["action_stack"].shape) == len(env.observation_space["action_stack"].shape) + for d1, d2 in zip(o["action_stack"].shape, env.observation_space["action_stack"].shape): + assert d1 == d2 + + +@pytest.mark.parametrize("num_stack", [-1, 0]) +@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) +def test_actions_as_observation_wrapper_invalid_num_stack(env_id, num_stack): + env = ENVIRONMENTS[env_id]() + if isinstance(env.action_space, gym.spaces.MultiDiscrete): + noop = [0, 0] + else: + noop = 0 + with pytest.raises(ValueError, match="The number of actions to the"): + env = ActionsAsObservationWrapper(env, num_stack=num_stack, noop=noop, dilation=3) + + +@pytest.mark.parametrize("dilation", [-1, 0]) +@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) +def test_actions_as_observation_wrapper_invalid_dilation(env_id, dilation): + env = ENVIRONMENTS[env_id]() + if isinstance(env.action_space, gym.spaces.MultiDiscrete): + noop = [0, 0] + else: + noop = 0 + with pytest.raises(ValueError, match="The actions stack dilation argument must be greater than zero"): + env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=dilation) + + +@pytest.mark.parametrize("noop", [set([0, 0, 0]), "this is an invalid type", np.array([0, 0, 0])]) +@pytest.mark.parametrize("env_id", ["discrete_dummy", "multidiscrete_dummy", "continuous_dummy"]) +def test_actions_as_observation_wrapper_invalid_noop_type(env_id, noop): + env = ENVIRONMENTS[env_id]() + with pytest.raises(ValueError, match="The noop action must be an integer or float or list"): + env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2) + + +def test_actions_as_observation_wrapper_invalid_noop_continuous_type(): + env = ContinuousDummyEnv() + with pytest.raises(ValueError, match="The noop actions must be a float for continuous action spaces"): + env = ActionsAsObservationWrapper(env, num_stack=3, noop=[0, 0, 0], dilation=2) + + +@pytest.mark.parametrize("noop", [[0, 0, 0], 0.0]) +def test_actions_as_observation_wrapper_invalid_noop_discrete_type(noop): + env = DiscreteDummyEnv() + with pytest.raises(ValueError, match="The noop actions must be an integer for discrete action spaces"): + env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2) + + +@pytest.mark.parametrize("noop", [0, 0.0]) +def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_type(noop): + env = MultiDiscreteDummyEnv() + with pytest.raises(ValueError, match="The noop actions must be a list for multi-discrete action spaces"): + env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2) + + +@pytest.mark.parametrize("noop", [[0], [0, 0, 0]]) +def test_actions_as_observation_wrapper_invalid_noop_multidiscrete_n_actions(noop): + env = MultiDiscreteDummyEnv() + with pytest.raises( + RuntimeError, match="The number of noop actions must be equal to the number of actions of the environment" + ): + env = ActionsAsObservationWrapper(env, num_stack=3, noop=noop, dilation=2)