
Commit

fix: tests (#255)
michele-milesi authored Apr 4, 2024
1 parent 0da0194 commit e25da73
Showing 5 changed files with 82 additions and 71 deletions.
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v2/agent.py
@@ -122,7 +122,7 @@ def __init__(
self.output_dim = dense_units

def forward(self, obs: Dict[str, Tensor]) -> Tensor:
x = torch.cat([obs[k] for k in self.keys], -1).type(torch.float32)
x = torch.cat([obs[k] for k in self.keys], -1)
return self.model(x)


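For context, a minimal sketch of what this MLP encoder's forward now does (not the actual sheeprl class; the layer sizes and key names below are illustrative): the selected vector observations are concatenated along the last dimension and fed through the MLP, with the explicit float32 cast dropped on the assumption that observations already arrive as float32.

from typing import Dict, Sequence

import torch
from torch import Tensor, nn


class MLPEncoderSketch(nn.Module):
    """Illustrative stand-in for the encoder touched above, not the sheeprl implementation."""

    def __init__(self, keys: Sequence[str], input_dim: int, dense_units: int = 8):
        super().__init__()
        self.keys = keys
        self.model = nn.Sequential(nn.Linear(input_dim, dense_units), nn.SiLU())
        self.output_dim = dense_units

    def forward(self, obs: Dict[str, Tensor]) -> Tensor:
        # Concatenate the selected vector observations along the last dimension;
        # the explicit .type(torch.float32) cast is gone, assuming inputs are already float32.
        x = torch.cat([obs[k] for k in self.keys], -1)
        return self.model(x)


# Example with a single 10-dimensional "state" observation.
encoder = MLPEncoderSketch(keys=["state"], input_dim=10)
out = encoder({"state": torch.zeros(1, 10)})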
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/agent.py
@@ -995,7 +995,7 @@ def build_agent(
mlp_layers=world_model_cfg.encoder.mlp_layers,
dense_units=world_model_cfg.encoder.dense_units,
activation=eval(world_model_cfg.encoder.dense_act),
layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder._mlp_layer_norm.cls),
layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder.mlp_layer_norm.cls),
layer_norm_kw=world_model_cfg.encoder.mlp_layer_norm.kw,
)
if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0
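The corrected config key (mlp_layer_norm instead of _mlp_layer_norm) is resolved with hydra.utils.get_class, which turns a dotted class path from the config into the actual class. A small illustration, with config values borrowed from the test overrides further down (the kwargs are an assumption):

import hydra
import torch

# "torch.nn.LayerNorm" matches the override used by the tests below; the kwargs are illustrative.
layer_norm_cls = hydra.utils.get_class("torch.nn.LayerNorm")
assert layer_norm_cls is torch.nn.LayerNorm
norm = layer_norm_cls(8, eps=1e-3)  # instantiated with the config's layer_norm keyword arguments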
119 changes: 53 additions & 66 deletions sheeprl/envs/dummy.py
@@ -1,46 +1,24 @@
from typing import List, Tuple
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

import gymnasium as gym
import numpy as np


class ContinuousDummyEnv(gym.Env):
def __init__(self, action_dim: int = 2, size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 128):
self.action_space = gym.spaces.Box(-np.inf, np.inf, shape=(action_dim,))
self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8)
self.reward_range = (-np.inf, np.inf)
self._current_step = 0
self._n_steps = n_steps

def step(self, action):
done = self._current_step == self._n_steps
self._current_step += 1
return (
np.zeros(self.observation_space.shape, dtype=np.uint8),
np.zeros(1, dtype=np.float32).item(),
done,
False,
{},
class BaseDummyEnv(gym.Env, ABC):
@abstractmethod
def __init__(
self,
image_size: Tuple[int, int, int] = (3, 64, 64),
n_steps: int = 128,
vector_shape: Tuple[int] = (10,),
):
self.observation_space = gym.spaces.Dict(
{
"rgb": gym.spaces.Box(0, 256, shape=image_size, dtype=np.uint8),
"state": gym.spaces.Box(-20, 20, shape=vector_shape, dtype=np.float32),
}
)

def reset(self, seed=None, options=None):
self._current_step = 0
return np.zeros(self.observation_space.shape, dtype=np.uint8), {}

def render(self, mode="human", close=False):
pass

def close(self):
pass

def seed(self, seed=None):
pass


class DiscreteDummyEnv(gym.Env):
def __init__(self, action_dim: int = 2, size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 4):
self.action_space = gym.spaces.Discrete(action_dim)
self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8)
self.reward_range = (-np.inf, np.inf)
self._current_step = 0
self._n_steps = n_steps
@@ -49,16 +27,22 @@ def step(self, action):
done = self._current_step == self._n_steps
self._current_step += 1
return (
np.random.randint(0, 256, self.observation_space.shape, dtype=np.uint8),
self.get_obs(),
np.zeros(1, dtype=np.float32).item(),
done,
False,
{},
)

def get_obs(self) -> Dict[str, np.ndarray]:
return {
"rgb": np.zeros(self.observation_space["rgb"].shape, dtype=np.uint8),
"state": np.zeros(self.observation_space["state"].shape, dtype=np.float32),
}

def reset(self, seed=None, options=None):
self._current_step = 0
return np.zeros(self.observation_space.shape, dtype=np.uint8), {}
return self.get_obs(), {}

def render(self, mode="human", close=False):
pass
@@ -70,34 +54,37 @@ def seed(self, seed=None):
pass


class MultiDiscreteDummyEnv(gym.Env):
def __init__(self, action_dims: List[int] = [2, 2], size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 128):
self.action_space = gym.spaces.MultiDiscrete(action_dims)
self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8)
self.reward_range = (-np.inf, np.inf)
self._current_step = 0
self._n_steps = n_steps

def step(self, action):
done = self._current_step == self._n_steps
self._current_step += 1
return (
np.zeros(self.observation_space.shape, dtype=np.uint8),
np.zeros(1, dtype=np.float32).item(),
done,
False,
{},
)
class ContinuousDummyEnv(BaseDummyEnv):
def __init__(
self,
image_size: Tuple[int, int, int] = (3, 64, 64),
n_steps: int = 128,
vector_shape: Tuple[int] = (10,),
action_dim: int = 2,
):
self.action_space = gym.spaces.Box(-np.inf, np.inf, shape=(action_dim,))
super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)

def reset(self, seed=None, options=None):
self._current_step = 0
return np.zeros(self.observation_space.shape, dtype=np.uint8), {}

def render(self, mode="human", close=False):
pass
class DiscreteDummyEnv(BaseDummyEnv):
def __init__(
self,
image_size: Tuple[int, int, int] = (3, 64, 64),
n_steps: int = 4,
vector_shape: Tuple[int] = (10,),
action_dim: int = 2,
):
self.action_space = gym.spaces.Discrete(action_dim)
super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)

def close(self):
pass

def seed(self, seed=None):
pass
class MultiDiscreteDummyEnv(BaseDummyEnv):
def __init__(
self,
image_size: Tuple[int, int, int] = (3, 64, 64),
n_steps: int = 128,
vector_shape: Tuple[int] = (10,),
action_dims: List[int] = [2, 2],
):
self.action_space = gym.spaces.MultiDiscrete(action_dims)
super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
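After this refactor the dummy environments share a BaseDummyEnv that exposes a Dict observation space with an "rgb" image and a "state" vector. A short usage sketch, with shapes and dtypes following the defaults shown in the diff above:

import numpy as np
from sheeprl.envs.dummy import DiscreteDummyEnv

env = DiscreteDummyEnv()  # defaults from the diff: image_size=(3, 64, 64), vector_shape=(10,), n_steps=4
obs, info = env.reset()
assert set(obs.keys()) == {"rgb", "state"}
assert obs["rgb"].shape == (3, 64, 64) and obs["rgb"].dtype == np.uint8
assert obs["state"].shape == (10,) and obs["state"].dtype == np.float32
obs, reward, done, truncated, info = env.step(env.action_space.sample())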
22 changes: 20 additions & 2 deletions tests/test_algos/test_algos.py
@@ -175,7 +175,7 @@ def test_ppo(standard_args, start_time, env_id):
f"run_name={run_name}",
f"env.id={env_id}",
"algo.cnn_keys.encoder=[rgb]",
"algo.mlp_keys.encoder=[]",
"algo.mlp_keys.encoder=[state]",
]

with mock.patch.object(sys, "argv", args):
@@ -198,7 +198,7 @@ def test_ppo_decoupled(standard_args, start_time, env_id):
f"run_name={run_name}",
f"env.id={env_id}",
"algo.cnn_keys.encoder=[rgb]",
"algo.mlp_keys.encoder=[]",
"algo.mlp_keys.encoder=[state]",
]

with mock.patch.object(sys, "argv", args):
@@ -249,6 +249,8 @@ def test_dreamer_v1(standard_args, env_id, start_time):
"algo.world_model.recurrent_model.recurrent_state_size=8",
"algo.cnn_keys.encoder=[rgb]",
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
]

with mock.patch.object(sys, "argv", args):
@@ -283,6 +285,8 @@ def test_p2e_dv1(standard_args, env_id, start_time):
"buffer.checkpoint=True",
"algo.cnn_keys.encoder=[rgb]",
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
"checkpoint.save_last=True",
]

@@ -324,6 +328,8 @@ def test_p2e_dv1(standard_args, env_id, start_time):
"algo.world_model.transition_model.hidden_size=2",
"algo.cnn_keys.encoder=[rgb]",
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
]
with mock.patch.object(sys, "argv", args):
run()
@@ -356,6 +362,8 @@ def test_dreamer_v2(standard_args, env_id, start_time):
"algo.layer_norm=True",
"algo.cnn_keys.encoder=[rgb]",
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
]

with mock.patch.object(sys, "argv", args):
@@ -390,6 +398,8 @@ def test_p2e_dv2(standard_args, env_id, start_time):
"buffer.checkpoint=True",
"algo.cnn_keys.encoder=[rgb]",
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
"checkpoint.save_last=True",
]

@@ -431,6 +441,8 @@ def test_p2e_dv2(standard_args, env_id, start_time):
"algo.world_model.transition_model.hidden_size=2",
"algo.cnn_keys.encoder=[rgb]",
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
]
with mock.patch.object(sys, "argv", args):
run()
@@ -460,6 +472,8 @@ def test_dreamer_v3(standard_args, env_id, start_time):
"algo.world_model.transition_model.hidden_size=8",
"algo.cnn_keys.encoder=[rgb]",
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
"algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
"algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
]
@@ -496,6 +510,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
"buffer.checkpoint=True",
"algo.cnn_keys.encoder=[rgb]",
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
"checkpoint.save_last=True",
"algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
"algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
@@ -539,6 +555,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
"algo.world_model.transition_model.hidden_size=8",
"algo.cnn_keys.encoder=[rgb]",
"algo.cnn_keys.decoder=[rgb]",
"algo.mlp_keys.encoder=[state]",
"algo.mlp_keys.decoder=[state]",
"algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
"algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
]
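All of these tests follow the same pattern: Hydra overrides are injected through sys.argv before invoking the algorithm entry point, and every dreamer/p2e run on the dummy envs now has to select the vector keys explicitly. A condensed sketch of that pattern (the run import path is an assumption for illustration; the real tests obtain it through their own setup):

import sys
from unittest import mock

from sheeprl.cli import run  # assumed entry-point import

args = [
    "sheeprl.py",
    "exp=dreamer_v3",
    "env=dummy",
    "algo.cnn_keys.encoder=[rgb]",
    "algo.cnn_keys.decoder=[rgb]",
    "algo.mlp_keys.encoder=[state]",  # now required: the dummy envs expose a "state" vector observation
    "algo.mlp_keys.decoder=[state]",
]
with mock.patch.object(sys, "argv", args):
    run()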
8 changes: 7 additions & 1 deletion tests/test_algos/test_cli.py
@@ -125,6 +125,7 @@ def test_resume_from_checkpoint():
sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
"env.capture_video=False algo.dense_units=8 algo.horizon=8 "
"algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
"algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
"algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
"algo.world_model.recurrent_model.recurrent_state_size=8 "
"algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
@@ -144,7 +145,9 @@ def test_resume_from_checkpoint():
subprocess.run(
sys.executable
+ f" sheeprl.py exp=dreamer_v3 env=dummy checkpoint.resume_from={ckpt_path} "
+ "root_dir=pytest_resume_ckpt run_name=test_resume metric.log_level=0",
+ "root_dir=pytest_resume_ckpt run_name=test_resume metric.log_level=0 "
+ "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+ "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state]",
shell=True,
check=True,
)
@@ -168,6 +171,7 @@ def test_resume_from_checkpoint_env_error():
sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
"env.capture_video=False algo.dense_units=8 algo.horizon=8 "
"algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
"algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
"algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
"algo.world_model.recurrent_model.recurrent_state_size=8 "
"algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
@@ -221,6 +225,7 @@ def test_resume_from_checkpoint_algo_error():
sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
"env.capture_video=False algo.dense_units=8 algo.horizon=8 "
"algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
"algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
"algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
"algo.world_model.recurrent_model.recurrent_state_size=8 "
"algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
@@ -276,6 +281,7 @@ def test_evaluate():
sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
"env.capture_video=False algo.dense_units=8 algo.horizon=8 "
"algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
"algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
"algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
"algo.world_model.recurrent_model.recurrent_state_size=8 "
"algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
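The CLI tests exercise the same overrides through a real subprocess instead of sys.argv patching; a trimmed-down version of the invocation pattern used above:

import subprocess
import sys

subprocess.run(
    sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
    "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
    "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
    "metric.log_level=0",
    shell=True,
    check=True,
)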
