fix: tests #255

Merged 1 commit on Apr 4, 2024
sheeprl/algos/dreamer_v2/agent.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -122,7 +122,7 @@ def __init__(
         self.output_dim = dense_units
 
     def forward(self, obs: Dict[str, Tensor]) -> Tensor:
-        x = torch.cat([obs[k] for k in self.keys], -1).type(torch.float32)
+        x = torch.cat([obs[k] for k in self.keys], -1)
         return self.model(x)
```


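The deleted line explicitly cast the concatenated features to `float32`. As a side note (not part of the PR), a minimal sketch of the relevant PyTorch behavior: `torch.cat` keeps the dtype of its inputs, so the result is already `float32` whenever the individual observation tensors are.

```python
import torch

# Hypothetical float32 observation tensors, keyed like the encoder's self.keys.
obs = {"state": torch.zeros(4, 10), "extra": torch.ones(4, 3)}
x = torch.cat([obs[k] for k in ("state", "extra")], -1)
print(x.shape, x.dtype)  # torch.Size([4, 13]) torch.float32: dtype is inherited from the inputs
```
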
sheeprl/algos/dreamer_v3/agent.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -995,7 +995,7 @@ def build_agent(
             mlp_layers=world_model_cfg.encoder.mlp_layers,
             dense_units=world_model_cfg.encoder.dense_units,
             activation=eval(world_model_cfg.encoder.dense_act),
-            layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder._mlp_layer_norm.cls),
+            layer_norm_cls=hydra.utils.get_class(world_model_cfg.encoder.mlp_layer_norm.cls),
             layer_norm_kw=world_model_cfg.encoder.mlp_layer_norm.kw,
         )
         if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0
```
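The deleted line read the layer-norm class from a non-existent `_mlp_layer_norm` config key; the fix points `hydra.utils.get_class` at `mlp_layer_norm.cls`. A minimal sketch of what that call does, assuming the key resolves to `torch.nn.LayerNorm` as set explicitly in the updated tests below:

```python
import hydra
import torch

# hydra.utils.get_class turns a dotted import path into the class object itself.
layer_norm_cls = hydra.utils.get_class("torch.nn.LayerNorm")
assert layer_norm_cls is torch.nn.LayerNorm

# The class is later instantiated with the keyword arguments from *.mlp_layer_norm.kw.
norm = layer_norm_cls(8)
print(norm)  # LayerNorm((8,), eps=1e-05, elementwise_affine=True)
```
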
sheeprl/envs/dummy.py (119 changes: 53 additions & 66 deletions)
```diff
@@ -1,46 +1,24 @@
-from typing import List, Tuple
+from abc import ABC, abstractmethod
+from typing import Dict, List, Tuple
 
 import gymnasium as gym
 import numpy as np
 
 
-class ContinuousDummyEnv(gym.Env):
-    def __init__(self, action_dim: int = 2, size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 128):
-        self.action_space = gym.spaces.Box(-np.inf, np.inf, shape=(action_dim,))
-        self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8)
-        self.reward_range = (-np.inf, np.inf)
-        self._current_step = 0
-        self._n_steps = n_steps
-
-    def step(self, action):
-        done = self._current_step == self._n_steps
-        self._current_step += 1
-        return (
-            np.zeros(self.observation_space.shape, dtype=np.uint8),
-            np.zeros(1, dtype=np.float32).item(),
-            done,
-            False,
-            {},
+class BaseDummyEnv(gym.Env, ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        image_size: Tuple[int, int, int] = (3, 64, 64),
+        n_steps: int = 128,
+        vector_shape: Tuple[int] = (10,),
+    ):
+        self.observation_space = gym.spaces.Dict(
+            {
+                "rgb": gym.spaces.Box(0, 256, shape=image_size, dtype=np.uint8),
+                "state": gym.spaces.Box(-20, 20, shape=vector_shape, dtype=np.float32),
+            }
         )
-
-    def reset(self, seed=None, options=None):
-        self._current_step = 0
-        return np.zeros(self.observation_space.shape, dtype=np.uint8), {}
-
-    def render(self, mode="human", close=False):
-        pass
-
-    def close(self):
-        pass
-
-    def seed(self, seed=None):
-        pass
-
-
-class DiscreteDummyEnv(gym.Env):
-    def __init__(self, action_dim: int = 2, size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 4):
-        self.action_space = gym.spaces.Discrete(action_dim)
-        self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8)
         self.reward_range = (-np.inf, np.inf)
         self._current_step = 0
         self._n_steps = n_steps
@@ -49,16 +27,22 @@ def step(self, action):
         done = self._current_step == self._n_steps
         self._current_step += 1
         return (
-            np.random.randint(0, 256, self.observation_space.shape, dtype=np.uint8),
+            self.get_obs(),
             np.zeros(1, dtype=np.float32).item(),
            done,
            False,
            {},
        )
 
+    def get_obs(self) -> Dict[str, np.ndarray]:
+        return {
+            "rgb": np.zeros(self.observation_space["rgb"].shape, dtype=np.uint8),
+            "state": np.zeros(self.observation_space["state"].shape, dtype=np.float32),
+        }
+
     def reset(self, seed=None, options=None):
         self._current_step = 0
-        return np.zeros(self.observation_space.shape, dtype=np.uint8), {}
+        return self.get_obs(), {}
 
     def render(self, mode="human", close=False):
         pass
@@ -70,34 +54,37 @@ def seed(self, seed=None):
         pass
 
 
-class MultiDiscreteDummyEnv(gym.Env):
-    def __init__(self, action_dims: List[int] = [2, 2], size: Tuple[int, int, int] = (3, 64, 64), n_steps: int = 128):
-        self.action_space = gym.spaces.MultiDiscrete(action_dims)
-        self.observation_space = gym.spaces.Box(0, 256, shape=size, dtype=np.uint8)
-        self.reward_range = (-np.inf, np.inf)
-        self._current_step = 0
-        self._n_steps = n_steps
-
-    def step(self, action):
-        done = self._current_step == self._n_steps
-        self._current_step += 1
-        return (
-            np.zeros(self.observation_space.shape, dtype=np.uint8),
-            np.zeros(1, dtype=np.float32).item(),
-            done,
-            False,
-            {},
-        )
+class ContinuousDummyEnv(BaseDummyEnv):
+    def __init__(
+        self,
+        image_size: Tuple[int, int, int] = (3, 64, 64),
+        n_steps: int = 128,
+        vector_shape: Tuple[int] = (10,),
+        action_dim: int = 2,
+    ):
+        self.action_space = gym.spaces.Box(-np.inf, np.inf, shape=(action_dim,))
+        super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
 
-    def reset(self, seed=None, options=None):
-        self._current_step = 0
-        return np.zeros(self.observation_space.shape, dtype=np.uint8), {}
 
-    def render(self, mode="human", close=False):
-        pass
+class DiscreteDummyEnv(BaseDummyEnv):
+    def __init__(
+        self,
+        image_size: Tuple[int, int, int] = (3, 64, 64),
+        n_steps: int = 4,
+        vector_shape: Tuple[int] = (10,),
+        action_dim: int = 2,
+    ):
+        self.action_space = gym.spaces.Discrete(action_dim)
+        super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
 
-    def close(self):
-        pass
 
-    def seed(self, seed=None):
-        pass
+class MultiDiscreteDummyEnv(BaseDummyEnv):
+    def __init__(
+        self,
+        image_size: Tuple[int, int, int] = (3, 64, 64),
+        n_steps: int = 128,
+        vector_shape: Tuple[int] = (10,),
+        action_dims: List[int] = [2, 2],
+    ):
+        self.action_space = gym.spaces.MultiDiscrete(action_dims)
+        super().__init__(image_size=image_size, n_steps=n_steps, vector_shape=vector_shape)
```
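A minimal usage sketch (not part of the diff) of the refactored dummy environments: every env now produces dictionary observations with an image entry (`rgb`) and a vector entry (`state`), which is what the `[state]` MLP keys added to the tests below select.

```python
from sheeprl.envs.dummy import DiscreteDummyEnv

env = DiscreteDummyEnv()  # defaults: image_size=(3, 64, 64), vector_shape=(10,), action_dim=2
obs, info = env.reset()
print(obs["rgb"].shape, obs["rgb"].dtype)      # (3, 64, 64) uint8
print(obs["state"].shape, obs["state"].dtype)  # (10,) float32
obs, reward, done, truncated, info = env.step(env.action_space.sample())
```
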
tests/test_algos/test_algos.py (22 changes: 20 additions & 2 deletions)
```diff
@@ -175,7 +175,7 @@ def test_ppo(standard_args, start_time, env_id):
         f"run_name={run_name}",
         f"env.id={env_id}",
         "algo.cnn_keys.encoder=[rgb]",
-        "algo.mlp_keys.encoder=[]",
+        "algo.mlp_keys.encoder=[state]",
     ]
 
     with mock.patch.object(sys, "argv", args):
@@ -198,7 +198,7 @@ def test_ppo_decoupled(standard_args, start_time, env_id):
         f"run_name={run_name}",
         f"env.id={env_id}",
         "algo.cnn_keys.encoder=[rgb]",
-        "algo.mlp_keys.encoder=[]",
+        "algo.mlp_keys.encoder=[state]",
     ]
 
     with mock.patch.object(sys, "argv", args):
@@ -249,6 +249,8 @@ def test_dreamer_v1(standard_args, env_id, start_time):
         "algo.world_model.recurrent_model.recurrent_state_size=8",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
     ]
 
     with mock.patch.object(sys, "argv", args):
@@ -283,6 +285,8 @@ def test_p2e_dv1(standard_args, env_id, start_time):
         "buffer.checkpoint=True",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "checkpoint.save_last=True",
     ]
 
@@ -324,6 +328,8 @@ def test_p2e_dv1(standard_args, env_id, start_time):
         "algo.world_model.transition_model.hidden_size=2",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
     ]
     with mock.patch.object(sys, "argv", args):
         run()
@@ -356,6 +362,8 @@ def test_dreamer_v2(standard_args, env_id, start_time):
         "algo.layer_norm=True",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
     ]
 
     with mock.patch.object(sys, "argv", args):
@@ -390,6 +398,8 @@ def test_p2e_dv2(standard_args, env_id, start_time):
         "buffer.checkpoint=True",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "checkpoint.save_last=True",
     ]
 
@@ -431,6 +441,8 @@ def test_p2e_dv2(standard_args, env_id, start_time):
         "algo.world_model.transition_model.hidden_size=2",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
     ]
     with mock.patch.object(sys, "argv", args):
         run()
@@ -460,6 +472,8 @@ def test_dreamer_v3(standard_args, env_id, start_time):
         "algo.world_model.transition_model.hidden_size=8",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
         "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
     ]
@@ -496,6 +510,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
         "buffer.checkpoint=True",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "checkpoint.save_last=True",
         "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
         "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
@@ -539,6 +555,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
         "algo.world_model.transition_model.hidden_size=8",
         "algo.cnn_keys.encoder=[rgb]",
         "algo.cnn_keys.decoder=[rgb]",
+        "algo.mlp_keys.encoder=[state]",
+        "algo.mlp_keys.decoder=[state]",
         "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
         "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
     ]
```
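For context, every test touched above follows the same pattern visible in the context lines: the Hydra overrides are assembled as fake CLI arguments and the entry point is called with `sys.argv` patched. A minimal self-contained sketch of that pattern, with a placeholder standing in for the real `run()` entry point:

```python
import sys
from unittest import mock

def run() -> None:
    # Placeholder for the sheeprl entry point that the real tests invoke.
    print("overrides:", sys.argv[1:])

args = [
    "sheeprl.py",
    "exp=dreamer_v3",
    "env=dummy",
    "algo.cnn_keys.encoder=[rgb]",
    "algo.mlp_keys.encoder=[state]",
]

with mock.patch.object(sys, "argv", args):
    run()  # sees the patched argv, exactly as the tests drive the CLI
```
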
tests/test_algos/test_cli.py (8 changes: 7 additions & 1 deletion)
```diff
@@ -125,6 +125,7 @@ def test_resume_from_checkpoint():
         sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
         "env.capture_video=False algo.dense_units=8 algo.horizon=8 "
         "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
         "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
         "algo.world_model.recurrent_model.recurrent_state_size=8 "
         "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
@@ -144,7 +145,9 @@ def test_resume_from_checkpoint():
     subprocess.run(
         sys.executable
         + f" sheeprl.py exp=dreamer_v3 env=dummy checkpoint.resume_from={ckpt_path} "
-        + "root_dir=pytest_resume_ckpt run_name=test_resume metric.log_level=0",
+        + "root_dir=pytest_resume_ckpt run_name=test_resume metric.log_level=0 "
+        + "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        + "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state]",
         shell=True,
         check=True,
     )
@@ -168,6 +171,7 @@ def test_resume_from_checkpoint_env_error():
         sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
         "env.capture_video=False algo.dense_units=8 algo.horizon=8 "
         "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
         "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
         "algo.world_model.recurrent_model.recurrent_state_size=8 "
         "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
@@ -221,6 +225,7 @@ def test_resume_from_checkpoint_algo_error():
         sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
         "env.capture_video=False algo.dense_units=8 algo.horizon=8 "
         "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
         "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
         "algo.world_model.recurrent_model.recurrent_state_size=8 "
         "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
@@ -276,6 +281,7 @@ def test_evaluate():
         sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True "
         "env.capture_video=False algo.dense_units=8 algo.horizon=8 "
         "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] "
+        "algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state] "
         "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 "
         "algo.world_model.recurrent_model.recurrent_state_size=8 "
         "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 "
```