diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py
index c0f7a369..4b693395 100644
--- a/sheeprl/algos/dreamer_v2/agent.py
+++ b/sheeprl/algos/dreamer_v2/agent.py
@@ -122,7 +122,7 @@ def __init__(
         self.output_dim = dense_units
 
     def forward(self, obs: Dict[str, Tensor]) -> Tensor:
-        x = torch.cat([obs[k] for k in self.keys], -1).type(torch.float32)
+        x = torch.cat([obs[k] for k in self.keys], -1)
         return self.model(x)
 
 
diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py
index 03733eea..adcbf195 100644
--- a/sheeprl/algos/dreamer_v3/agent.py
+++ b/sheeprl/algos/dreamer_v3/agent.py
@@ -995,7 +995,7 @@ def build_agent(
             mlp_layers=world_model_cfg.encoder.mlp_layers,
             dense_units=world_model_cfg.encoder.dense_units,
             activation=eval(world_model_cfg.encoder.dense_act),
-            layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder._mlp_layer_norm.cls),
+            layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder.mlp_layer_norm.cls),
             layer_norm_kw=world_model_cfg.encoder.mlp_layer_norm.kw,
         )
         if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0
diff --git a/sheeprl/envs/dummy.py b/sheeprl/envs/dummy.py
index 3e6c77f8..8c6ccac0 100644
--- a/sheeprl/envs/dummy.py
+++ b/sheeprl/envs/dummy.py
@@ -1,46 +1,24 @@
-from typing import List, Tuple
+from abc import ABC, abstractmethod
+from typing import Dict, List, Tuple
 
 import gymnasium as gym
 import numpy as np
 
 
-class ContinuousDummyEnv(gym.Env):
-    def __init__(self, action_dim: int = 2, size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 128):
-        self.action_space = gym.spaces.Box(-np.inf, np.inf, shape=(action_dim,))
-        self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8)
-        self.reward_range = (-np.inf, np.inf)
-        self._current_step = 0
-        self._n_steps = n_steps
-
-    def step(self, action):
-        done = self._current_step == self._n_steps
-        self._current_step += 1
-        return (
-            np.zeros(self.observation_space.shape, dtype=np.uint8),
-            np.zeros(1, dtype=np.float32).item(),
-            done,
-            False,
-            {},
+class BaseDummyEnv(gym.Env, ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        image_size: Tuple[int, int, int] = (3, 64, 64),
+        n_steps: int = 128,
+        vector_shape: Tuple[int] = (10,),
+    ):
+        self.observation_space = gym.spaces.Dict(
+            {
+                "rgb": gym.spaces.Box(0, 256, shape=image_size, dtype=np.uint8),
+                "state": gym.spaces.Box(-20, 20, shape=vector_shape, dtype=np.float32),
+            }
         )
-
-    def reset(self, seed=None, options=None):
-        self._current_step = 0
-        return np.zeros(self.observation_space.shape, dtype=np.uint8), {}
-
-    def render(self, mode="human", close=False):
-        pass
-
-    def close(self):
-        pass
-
-    def seed(self, seed=None):
-        pass
-
-
-class DiscreteDummyEnv(gym.Env):
-    def __init__(self, action_dim: int = 2, size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 4):
-        self.action_space = gym.spaces.Discrete(action_dim)
-        self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8)
         self.reward_range = (-np.inf, np.inf)
         self._current_step = 0
         self._n_steps = n_steps
@@ -49,16 +27,22 @@ def step(self, action):
         done = self._current_step == self._n_steps
         self._current_step += 1
         return (
-            np.random.randint(0, 256, self.observation_space.shape, dtype=np.uint8),
+            self.get_obs(),
             np.zeros(1, dtype=np.float32).item(),
             done,
             False,
             {},
         )
 
+    def get_obs(self) -> Dict[str, np.ndarray]:
+        return {
+            "rgb": np.zeros(self.observation_space["rgb"].shape, dtype=np.uint8),
+            "state": np.zeros(self.observation_space["state"].shape, dtype=np.float32),
+        }
+
     def reset(self, seed=None, options=None):
         self._current_step = 0
-        return np.zeros(self.observation_space.shape, dtype=np.uint8), {}
+        return self.get_obs(), {}
 
     def render(self, mode="human", close=False):
         pass
@@ -70,34 +54,37 @@ def seed(self, seed=None):
         pass
 
 
-class MultiDiscreteDummyEnv(gym.Env):
-    def __init__(self, action_dims: List[int] = [2, 2], size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 128):
-        self.action_space = gym.spaces.MultiDiscrete(action_dims)
-        self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8)
-        self.reward_range = (-np.inf, np.inf)
-        self._current_step = 0
-        self._n_steps = n_steps
-
-    def step(self, action):
-        done = self._current_step == self._n_steps
-        self._current_step += 1
-        return (
-            np.zeros(self.observation_space.shape, dtype=np.uint8),
-            np.zeros(1, dtype=np.float32).item(),
-            done,
-            False,
-            {},
-        )
+class ContinuousDummyEnv(BaseDummyEnv):
+    def __init__(
+        self,
+        image_size: Tuple[int, int, int] = (3, 64, 64),
+        n_steps: int = 128,
+        vector_shape: Tuple[int] = (10,),
+        action_dim: int = 2,
+    ):
+        self.action_space = gym.spaces.Box(-np.inf, np.inf, shape=(action_dim,))
+        super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
 
-    def reset(self, seed=None, options=None):
-        self._current_step = 0
-        return np.zeros(self.observation_space.shape, dtype=np.uint8), {}
 
-    def render(self, mode="human", close=False):
-        pass
+class DiscreteDummyEnv(BaseDummyEnv):
+    def __init__(
+        self,
+        image_size: Tuple[int, int, int] = (3, 64, 64),
+        n_steps: int = 4,
+        vector_shape: Tuple[int] = (10,),
+        action_dim: int = 2,
+    ):
+        self.action_space = gym.spaces.Discrete(action_dim)
+        super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
 
-    def close(self):
-        pass
 
-    def seed(self, seed=None):
-        pass
+class MultiDiscreteDummyEnv(BaseDummyEnv):
+    def __init__(
+        self,
+        image_size: Tuple[int, int, int] = (3, 64, 64),
+        n_steps: int = 128,
+        vector_shape: Tuple[int] = (10,),
+        action_dims: List[int] = [2, 2],
+    ):
+        self.action_space = gym.spaces.MultiDiscrete(action_dims)
+        super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py
index db6f22a5..a15b6834 100644
--- a/tests/test_algos/test_algos.py
+++ b/tests/test_algos/test_algos.py
@@ -175,7 +175,7 @@ def test_ppo(standard_args, start_time, env_id):
         f"run_name={run_name}",
         f"env.id={env_id}",
         "algo.cnn_keys.encoder=[rgb]",
-        "algo.mlp_keys.encoder=[]",
+        "algo.mlp_keys.encoder=[state]",
     ]
 
     with mock.patch.object(sys, "argv", args):
@@ -198,7 +198,7 @@ def test_ppo_decoupled(standard_args, start_time, env_id):
         f"run_name={run_name}",
         f"env.id={env_id}",
         "algo.cnn_keys.encoder=[rgb]",
-        "algo.mlp_keys.encoder=[]",
+        "algo.mlp_keys.encoder=[state]",
     ]
 
     with mock.patch.object(sys, "argv", args):
@@ -249,6 +249,8 @@ def test_dreamer_v1(standard_args, env_id, start_time):
         "algo.world_model.recurrent_model.recurrent_state_size=8",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
     ]
 
     with mock.patch.object(sys, "argv", args):
@@ -283,6 +285,8 @@ def test_p2e_dv1(standard_args, env_id, start_time):
         "buffer.checkpoint=True",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "checkpoint.save_last=True",
     ]
 
@@ -324,6 +328,8 @@
         "algo.world_model.transition_model.hidden_size=2",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
     ]
     with mock.patch.object(sys, "argv", args):
         run()
@@ -356,6 +362,8 @@ def test_dreamer_v2(standard_args, env_id, start_time):
         "algo.layer_norm=True",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
     ]
 
     with mock.patch.object(sys, "argv", args):
@@ -390,6 +398,8 @@ def test_p2e_dv2(standard_args, env_id, start_time):
         "buffer.checkpoint=True",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "checkpoint.save_last=True",
     ]
 
@@ -431,6 +441,8 @@
         "algo.world_model.transition_model.hidden_size=2",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
     ]
     with mock.patch.object(sys, "argv", args):
         run()
@@ -460,6 +472,8 @@ def test_dreamer_v3(standard_args, env_id, start_time):
         "algo.world_model.transition_model.hidden_size=8",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
         "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
     ]
@@ -496,6 +510,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
         "buffer.checkpoint=True",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "checkpoint.save_last=True",
         "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
         "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
@@ -539,6 +555,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
         "algo.world_model.transition_model.hidden_size=8",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
         "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
     ]
diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py
index 0e465d24..275ac9b4 100644
--- a/tests/test_algos/test_cli.py
+++ b/tests/test_algos/test_cli.py
@@ -125,6 +125,7 @@ def test_resume_from_checkpoint():
         sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
         "env.capture_video=False algo.dense_units=8 algo.horizon=8 "
         "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
         "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
         "algo.world_model.recurrent_model.recurrent_state_size=8 "
         "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
@@ -144,7 +145,9 @@
     subprocess.run(
         sys.executable
         + f" sheeprl.py exp=dreamer_v3 env=dummy checkpoint.resume_from={ckpt_path} "
-        + "root_dir=pytest_resume_ckpt run_name=test_resume metric.log_level=0",
+        + "root_dir=pytest_resume_ckpt run_name=test_resume metric.log_level=0 "
+        + "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        + "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state]",
         shell=True,
         check=True,
     )
@@ -168,6 +171,7 @@ def test_resume_from_checkpoint_env_error():
         sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
         "env.capture_video=False algo.dense_units=8 algo.horizon=8 "
         "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
         "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
         "algo.world_model.recurrent_model.recurrent_state_size=8 "
         "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
@@ -221,6 +225,7 @@ def test_resume_from_checkpoint_algo_error():
         sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
         "env.capture_video=False algo.dense_units=8 algo.horizon=8 "
         "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
        "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
         "algo.world_model.recurrent_model.recurrent_state_size=8 "
         "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
@@ -276,6 +281,7 @@ def test_evaluate():
         sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
         "env.capture_video=False algo.dense_units=8 algo.horizon=8 "
         "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
         "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
         "algo.world_model.recurrent_model.recurrent_state_size=8 "
         "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
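
Note: a minimal usage sketch of the refactored dummy environments, not part of the patch. Class names, keyword arguments, and defaults are taken from sheeprl/envs/dummy.py above; the asserted shapes and dtypes are just the defaults, and the snippet assumes the package is importable as sheeprl.

import numpy as np

from sheeprl.envs.dummy import DiscreteDummyEnv

# The dummy envs now return Dict observations with an image ("rgb") and a flat
# vector ("state"); this is why the tests above pass both
# algo.cnn_keys.encoder=[rgb] and algo.mlp_keys.encoder=[state].
env = DiscreteDummyEnv(image_size=(3, 64, 64), vector_shape=(10,), action_dim=2)

obs, info = env.reset()
assert set(obs.keys()) == {"rgb", "state"}
assert obs["rgb"].shape == (3, 64, 64) and obs["rgb"].dtype == np.uint8
assert obs["state"].shape == (10,) and obs["state"].dtype == np.float32

# step() keeps the gymnasium 5-tuple API and returns the same dict observation.
obs, reward, done, truncated, info = env.step(env.action_space.sample())
assert isinstance(obs, dict) and reward == 0.0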