diff --git a/howto/register_external_algorithm.md b/howto/register_external_algorithm.md index 92643399..b01c7163 100644 --- a/howto/register_external_algorithm.md +++ b/howto/register_external_algorithm.md @@ -715,11 +715,16 @@ where `log_models`, `test` and `normalize_obs` have to be defined in the `my_awe from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Sequence +import gymnasium as gym +import numpy as np import torch from lightning import Fabric from lightning.fabric.wrappers import _FabricModule +from torch import Tensor + +from sheeprl.utils.env import make_env from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE from sheeprl.utils.utils import unwrap_fabric @@ -729,43 +734,41 @@ if TYPE_CHECKING: from mlflow.models.model import ModelInfo +def prepare_obs( + fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs +) -> Dict[str, Tensor]: + torch_obs = {} + for k in obs.keys(): + torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).float() + if k in cnn_keys: + torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:]) + else: + torch_obs[k] = torch_obs[k].reshape(num_envs, -1) + return normalize_obs(torch_obs, cnn_keys, obs.keys()) + + @torch.no_grad() def test(agent: SOTAAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False cumulative_rew = 0 - o = env.reset(seed=cfg.seed)[0] - obs = {} - for k in o.keys(): - if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - obs[k] = torch_obs + obs = env.reset(seed=cfg.seed)[0] while not done: + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder) + # Act greedly through the environment - if agent.is_continuous: - actions = torch.cat(agent.get_greedy_actions(obs), dim=-1) + actions = agent.get_actions(torch_obs, greedy=True) + if agent.actor.is_continuous: + actions = torch.cat(actions, dim=-1) else: - actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1) + actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) # Single environment step - o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) + obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) done = done or truncated cumulative_rew += reward - obs = {} - for k in o.keys(): - if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - obs[k] = torch_obs if cfg.dry_run: done = True diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index ddf2d54c..3ed52fd0 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -713,14 +713,16 @@ where `log_models`, `test` and `normalize_obs` have to be defined in the `sheepr ```python from __future__ import annotations -import warnings -from typing import TYPE_CHECKING, Any, Dict 
+from typing import TYPE_CHECKING, Any, Dict, Sequence +import gymnasium as gym +import numpy as np import torch from lightning import Fabric from lightning.fabric.wrappers import _FabricModule +from torch import Tensor -from sheeprl.algos.sota.agent import SOTAAgentPlayer +from sheeprl.utils.env import make_env from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE from sheeprl.utils.utils import unwrap_fabric @@ -728,44 +730,41 @@ if TYPE_CHECKING: from mlflow.models.model import ModelInfo +def prepare_obs( + fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs +) -> Dict[str, Tensor]: + torch_obs = {} + for k in obs.keys(): + torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).float() + if k in cnn_keys: + torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:]) + else: + torch_obs[k] = torch_obs[k].reshape(num_envs, -1) + return normalize_obs(torch_obs, cnn_keys, obs.keys()) + + @torch.no_grad() def test(agent: SOTAAgentPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False cumulative_rew = 0 - o = env.reset(seed=cfg.seed)[0] - obs = {} - for k in o.keys(): - if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - obs[k] = torch_obs + obs = env.reset(seed=cfg.seed)[0] while not done: + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder) + # Act greedly through the environment - actions = agent.get_actions(obs, greedy=True) - if agent.is_continuous: + actions = agent.get_actions(torch_obs, greedy=True) + if agent.actor.is_continuous: actions = torch.cat(actions, dim=-1) else: actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) # Single environment step - o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) + obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) done = done or truncated cumulative_rew += reward - obs = {} - for k in o.keys(): - if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - obs[k] = torch_obs if cfg.dry_run: done = True diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index 07241c80..bc1b4cfb 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -12,7 +12,7 @@ from sheeprl.algos.a2c.agent import A2CAgent, build_agent from sheeprl.algos.a2c.loss import policy_loss, value_loss -from sheeprl.algos.a2c.utils import test +from sheeprl.algos.a2c.utils import prepare_obs, test from sheeprl.data import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -236,7 +236,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Sample an action given the observation received by the environment # This calls the `forward` method of the PyTorch module, escaping from Fabric # because we don't want this to be a synchronization point - torch_obs = {k: torch.as_tensor(next_obs[k], 
dtype=torch.float32, device=device) for k in obs_keys} + torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs) actions, _, values = player(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() @@ -272,7 +272,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Update the step data step_data["dones"] = dones[np.newaxis] step_data["values"] = values.cpu().numpy()[np.newaxis] - step_data["actions"] = actions[np.newaxis] + step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) step_data["rewards"] = rewards[np.newaxis] if cfg.buffer.memmap: step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) @@ -304,7 +304,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.inference_mode(): - torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys} + torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs) next_values = player.get_values(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), diff --git a/sheeprl/algos/a2c/utils.py b/sheeprl/algos/a2c/utils.py index 39fb3e93..c26fbbaf 100644 --- a/sheeprl/algos/a2c/utils.py +++ b/sheeprl/algos/a2c/utils.py @@ -2,8 +2,10 @@ from typing import Any, Dict +import numpy as np import torch from lightning import Fabric +from torch import Tensor from sheeprl.algos.ppo.agent import PPOPlayer from sheeprl.utils.env import make_env @@ -11,38 +13,34 @@ AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"} +def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Dict[str, Tensor]: + torch_obs = {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()} + return torch_obs + + @torch.no_grad() def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False cumulative_rew = 0 - o = env.reset(seed=cfg.seed)[0] - obs = {} - for k in o.keys(): - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - torch_obs = torch_obs.float() - obs[k] = torch_obs + obs = env.reset(seed=cfg.seed)[0] while not done: + # Convert observations to tensors + torch_obs = prepare_obs(fabric, obs) + # Act greedly through the environment - actions = agent.get_actions(obs, greedy=True) + actions = agent.get_actions(torch_obs, greedy=True) if agent.actor.is_continuous: actions = torch.cat(actions, dim=-1) else: actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) # Single environment step - o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) + obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) done = done or truncated cumulative_rew += reward - obs = {} - for k in o.keys(): - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - torch_obs = torch_obs.float() - obs[k] = torch_obs if cfg.dry_run: done = True diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 664769b6..d6a1bb4d 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -20,7 +20,7 @@ from sheeprl.algos.dreamer_v1.agent import WorldModel, build_agent from sheeprl.algos.dreamer_v1.loss import 
actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v1.utils import compute_lambda_values -from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.algos.dreamer_v2.utils import prepare_obs, test from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -574,16 +574,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): axis=-1, ) else: - normalized_obs = {} - for k in obs_keys: - torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs / 255 - 0.5 - normalized_obs[k] = torch_obs - mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) + mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step) + real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index ab50f618..ebfaf83c 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -24,7 +24,7 @@ from sheeprl.algos.dreamer_v2.agent import WorldModel, build_agent from sheeprl.algos.dreamer_v2.loss import reconstruction_loss -from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test +from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, prepare_obs, test from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -599,16 +599,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): axis=-1, ) else: - normalized_obs = {} - for k in obs_keys: - torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs / 255 - 0.5 - normalized_obs[k] = torch_obs - mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) + mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(normalized_obs, mask=mask) + real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index ed9debb0..4abfa58f 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ b/sheeprl/algos/dreamer_v2/utils.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence import gymnasium as gym +import numpy as np import torch import torch.nn as nn from lightning import Fabric @@ -101,6 +102,20 @@ def compute_lambda_values( return torch.cat(list(reversed(lv)), dim=0) +def prepare_obs( + fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs +) 
-> Dict[str, Tensor]: + torch_obs = {} + for k, v in obs.items(): + torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).float() + if k in cnn_keys: + torch_obs[k] = torch_obs[k].view(1, num_envs, -1, *v.shape[-2:]) / 255 - 0.5 + else: + torch_obs[k] = torch_obs[k].view(1, num_envs, -1) + + return torch_obs + + @torch.no_grad() def test( player: "PlayerDV2" | "PlayerDV1", @@ -125,22 +140,14 @@ def test( env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))() done = False cumulative_rew = 0 - device = fabric.device - next_obs = env.reset(seed=cfg.seed)[0] - for k in next_obs.keys(): - next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float() + obs = env.reset(seed=cfg.seed)[0] player.num_envs = 1 player.init_states() while not done: # Act greedly through the environment - preprocessed_obs = {} - for k, v in next_obs.items(): - if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 - elif k in cfg.algo.mlp_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder) real_actions = player.get_actions( - preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -148,9 +155,7 @@ def test( real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() # Single environment step - next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) - for k in next_obs.keys(): - next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float() + obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) done = done or truncated or cfg.dry_run cumulative_rew += reward fabric.print("Test - Reward:", cumulative_rew) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 4d058ce2..97706af9 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -24,7 +24,7 @@ from sheeprl.algos.dreamer_v3.agent import WorldModel, build_agent from sheeprl.algos.dreamer_v3.loss import reconstruction_loss -from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test +from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, prepare_obs, test from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.envs.wrappers import RestartOnException from sheeprl.utils.distribution import ( @@ -566,15 +566,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): axis=-1, ) else: - preprocessed_obs = {} - for k, v in obs.items(): - preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = preprocessed_obs[k] / 255.0 - 0.5 - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) + mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(preprocessed_obs, mask=mask) + real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).cpu().numpy() 
if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index 42d79060..1b73e60c 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -77,6 +77,20 @@ def compute_lambda_values( return ret +def prepare_obs( + fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs +) -> Dict[str, Tensor]: + torch_obs = {} + for k, v in obs.items(): + torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).float() + if k in cnn_keys: + torch_obs[k] = torch_obs[k].view(1, num_envs, -1, *v.shape[-2:]) / 255 - 0.5 + else: + torch_obs[k] = torch_obs[k].view(1, num_envs, -1) + + return torch_obs + + @torch.no_grad() def test( player: "PlayerDV3", @@ -101,22 +115,14 @@ def test( env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))() done = False cumulative_rew = 0 - device = fabric.device - next_obs = env.reset(seed=cfg.seed)[0] - for k in next_obs.keys(): - next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float() + obs = env.reset(seed=cfg.seed)[0] player.num_envs = 1 player.init_states() while not done: # Act greedly through the environment - preprocessed_obs = {} - for k, v in next_obs.items(): - if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 - elif k in cfg.algo.mlp_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder) real_actions = player.get_actions( - preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -124,9 +130,7 @@ def test( real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() # Single environment step - next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) - for k in next_obs.keys(): - next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float() + obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) done = done or truncated or cfg.dry_run cumulative_rew += reward fabric.print("Test - Reward:", cumulative_rew) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index a917684b..22ee6b7b 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -18,7 +18,7 @@ from sheeprl.algos.droq.agent import DROQAgent, build_agent from sheeprl.algos.sac.loss import entropy_loss, policy_loss -from sheeprl.algos.sac.sac import test +from sheeprl.algos.sac.utils import prepare_obs, test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -289,8 +289,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data = {} # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = np.concatenate([o[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype(np.float32) + obs = envs.reset(seed=cfg.seed)[0] per_rank_gradient_steps = 0 cumulative_per_rank_gradient_steps = 0 @@ -305,7 +304,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else: with torch.inference_mode(): # Sample an action given the 
observation received by the environment - actions = player(torch.from_numpy(obs).to(device)) + torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs) + actions = player(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) @@ -326,13 +326,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if final_obs is not None: for k, v in final_obs.items(): real_next_obs[k][idx] = v - - next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype(np.float32) real_next_obs = np.concatenate([real_next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype( np.float32 ) - step_data["observations"] = obs[np.newaxis] + step_data["observations"] = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype( + np.float32 + )[np.newaxis] if not cfg.buffer.sample_next_obs: step_data["next_observations"] = real_next_obs[np.newaxis] step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 4f31249a..8ae7c5ae 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -20,7 +20,7 @@ from sheeprl.algos.dreamer_v1.agent import WorldModel from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v1.utils import compute_lambda_values -from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.algos.dreamer_v2.utils import prepare_obs, test from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env @@ -598,16 +598,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): axis=-1, ) else: - normalized_obs = {} - for k in obs_keys: - torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs / 255 - 0.5 - normalized_obs[k] = torch_obs - mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) + mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step) + real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 68e6e75d..827f5b4b 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -14,7 +14,7 @@ from torchmetrics import SumMetric from sheeprl.algos.dreamer_v1.dreamer_v1 import train -from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.algos.dreamer_v2.utils import prepare_obs, test from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env @@ -253,16 +253,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get 
the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - normalized_obs = {} - for k in obs_keys: - torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs / 255 - 0.5 - normalized_obs[k] = torch_obs - mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) + mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step) + real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index e0f5a121..61ab6ccc 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -19,7 +19,7 @@ from sheeprl.algos.dreamer_v2.agent import WorldModel from sheeprl.algos.dreamer_v2.loss import reconstruction_loss -from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test +from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, prepare_obs, test from sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env @@ -735,16 +735,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): axis=-1, ) else: - normalized_obs = {} - for k in obs_keys: - torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs / 255 - 0.5 - normalized_obs[k] = torch_obs - mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) + mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(normalized_obs, mask=mask) + real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index bd9c548e..b2c4fafe 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -14,7 +14,7 @@ from torchmetrics import SumMetric from sheeprl.algos.dreamer_v2.dreamer_v2 import train -from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.algos.dreamer_v2.utils import prepare_obs, test from sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env @@ -273,16 +273,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment with 
timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - normalized_obs = {} - for k in obs_keys: - torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs / 255 - 0.5 - normalized_obs[k] = torch_obs - mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) + mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(normalized_obs, mask=mask) + real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index da0ea8fe..c3406aa9 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -17,7 +17,7 @@ from sheeprl.algos.dreamer_v3.agent import WorldModel from sheeprl.algos.dreamer_v3.loss import reconstruction_loss -from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test +from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, prepare_obs, test from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.distribution import ( @@ -807,15 +807,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): axis=-1, ) else: - preprocessed_obs = {} - for k, v in obs.items(): - preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = preprocessed_obs[k] / 255.0 - 0.5 - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) + mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(preprocessed_obs, mask=mask) + real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 2f7b272f..97db2c3c 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -12,7 +12,7 @@ from torchmetrics import SumMetric from sheeprl.algos.dreamer_v3.dreamer_v3 import train -from sheeprl.algos.dreamer_v3.utils import Moments, test +from sheeprl.algos.dreamer_v3.utils import Moments, prepare_obs, test from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env @@ -255,15 +255,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - preprocessed_obs = {} - for k, v in obs.items(): - preprocessed_obs[k] = 
torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device) - if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = preprocessed_obs[k] / 255.0 - 0.5 - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) + mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_actions(preprocessed_obs, mask=mask) + real_actions = actions = player.get_actions(torch_obs, mask=mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index d2757486..23d9b436 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -17,7 +17,7 @@ from sheeprl.algos.ppo.agent import build_agent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss -from sheeprl.algos.ppo.utils import normalize_obs, test +from sheeprl.algos.ppo.utils import normalize_obs, prepare_obs, test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -271,10 +271,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment - normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) - torch_obs = { - k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys - } + torch_obs = prepare_obs( + fabric, next_obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs + ) actions, logprobs, values = player(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() @@ -345,8 +344,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.inference_mode(): - normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) - torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) next_values = player.get_values(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index a1d8bb1c..7ee1ecc0 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -18,7 +18,7 @@ from sheeprl.algos.ppo.agent import build_agent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss -from sheeprl.algos.ppo.utils import normalize_obs, test +from sheeprl.algos.ppo.utils import normalize_obs, prepare_obs, test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.fabric import get_single_device_fabric @@ -204,10 +204,7 @@ def player( # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment - normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) - torch_obs = { - k: 
torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys - } + torch_obs = prepare_obs(fabric, next_obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) actions, logprobs, values = agent(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() @@ -243,7 +240,7 @@ def player( # Update the step data step_data["dones"] = dones[np.newaxis] step_data["values"] = values.cpu().numpy()[np.newaxis] - step_data["actions"] = actions[np.newaxis] + step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) step_data["logprobs"] = logprobs.cpu().numpy()[np.newaxis] step_data["rewards"] = rewards[np.newaxis] if cfg.buffer.memmap: @@ -276,8 +273,7 @@ def player( local_data = rb.to_tensor(dtype=None, device=device, from_numpy=cfg.buffer.from_numpy) # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) - normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) - torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) next_values = agent.get_values(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), diff --git a/sheeprl/algos/ppo/utils.py b/sheeprl/algos/ppo/utils.py index e3b55340..4b5e4634 100644 --- a/sheeprl/algos/ppo/utils.py +++ b/sheeprl/algos/ppo/utils.py @@ -22,44 +22,41 @@ MODELS_TO_REGISTER = {"agent"} +def prepare_obs( + fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs +) -> Dict[str, Tensor]: + torch_obs = {} + for k in obs.keys(): + torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).float() + if k in cnn_keys: + torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:]) + else: + torch_obs[k] = torch_obs[k].reshape(num_envs, -1) + return normalize_obs(torch_obs, cnn_keys, obs.keys()) + + @torch.no_grad() def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False cumulative_rew = 0 - o = env.reset(seed=cfg.seed)[0] - obs = {} - for k in o.keys(): - if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - obs[k] = torch_obs + obs = env.reset(seed=cfg.seed)[0] while not done: + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder) + # Act greedly through the environment - actions = agent.get_actions(obs, greedy=True) + actions = agent.get_actions(torch_obs, greedy=True) if agent.actor.is_continuous: actions = torch.cat(actions, dim=-1) else: actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) # Single environment step - o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) + obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) done = done or truncated cumulative_rew += reward - obs = {} - for k in o.keys(): - if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 - if 
k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - obs[k] = torch_obs if cfg.dry_run: done = True diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index dd2a33a7..bebc8a7b 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -17,9 +17,8 @@ from torchmetrics import SumMetric from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss -from sheeprl.algos.ppo.utils import normalize_obs from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent, build_agent -from sheeprl.algos.ppo_recurrent.utils import test +from sheeprl.algos.ppo_recurrent.utils import prepare_obs, test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -295,8 +294,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment # [Seq_len, Batch_size, D] --> [1, num_envs, D] - normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) - torch_obs = {k: torch.as_tensor(v, device=device).float() for k, v in normalized_obs.items()} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) actions, logprobs, values, states = player( torch_obs, prev_actions=torch_prev_actions, prev_states=prev_states ) @@ -387,8 +385,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.inference_mode(): - normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) - torch_obs = {k: torch.as_tensor(v, device=device).float() for k, v in normalized_obs.items()} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) next_values, _ = player.get_values(torch_obs, torch_actions, states) returns, advantages = gae( local_data["rewards"].to(torch.float64), diff --git a/sheeprl/algos/ppo_recurrent/utils.py b/sheeprl/algos/ppo_recurrent/utils.py index 4e643d6d..90ccdaed 100644 --- a/sheeprl/algos/ppo_recurrent/utils.py +++ b/sheeprl/algos/ppo_recurrent/utils.py @@ -3,11 +3,14 @@ from typing import TYPE_CHECKING, Any, Dict, Sequence import gymnasium as gym +import numpy as np import torch from lightning import Fabric +from torch import Tensor from sheeprl.algos.ppo.utils import AGGREGATOR_KEYS as ppo_aggregator_keys from sheeprl.algos.ppo.utils import MODELS_TO_REGISTER as ppo_models_to_register +from sheeprl.algos.ppo.utils import normalize_obs from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOPlayer, build_agent from sheeprl.utils.env import make_env from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE @@ -21,6 +24,20 @@ MODELS_TO_REGISTER = ppo_models_to_register +def prepare_obs( + fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs +) -> Dict[str, Tensor]: + torch_obs = {} + with fabric.device: + for k, v in obs.items(): + torch_obs[k] = torch.as_tensor(v.copy(), dtype=torch.float32, device=fabric.device) + if k in cnn_keys: + torch_obs[k] = torch_obs[k].view(1, num_envs, -1, *v.shape[-2:]) + else: + torch_obs[k] = torch_obs[k].view(1, num_envs, -1) + return normalize_obs(torch_obs, cnn_keys, torch_obs.keys()) + + @torch.no_grad() def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, 
None, 0, log_dir, "test", vector_env_idx=0)() @@ -28,26 +45,17 @@ def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_d done = False cumulative_rew = 0 agent.num_envs = 1 + obs = env.reset(seed=cfg.seed)[0] with fabric.device: - o = env.reset(seed=cfg.seed)[0] - next_obs = { - k: torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1, *o[k].shape[-2:]) / 255 - for k in cfg.algo.cnn_keys.encoder - } - next_obs.update( - { - k: torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1) - for k in cfg.algo.mlp_keys.encoder - } - ) state = ( torch.zeros(1, 1, agent.rnn_hidden_size, device=fabric.device), torch.zeros(1, 1, agent.rnn_hidden_size, device=fabric.device), ) actions = torch.zeros(1, 1, sum(agent.actions_dim), device=fabric.device) while not done: + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder) # Act greedly through the environment - actions, state = agent.get_actions(next_obs, actions, state, greedy=True) + actions, state = agent.get_actions(torch_obs, actions, state, greedy=True) if agent.actor.is_continuous: real_actions = torch.cat(actions, -1) actions = torch.cat(actions, dim=-1).view(1, 1, -1) @@ -56,17 +64,9 @@ def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_d actions = torch.cat([act for act in actions], dim=-1).view(1, 1, -1) # Single environment step - o, reward, done, truncated, info = env.step(real_actions.cpu().numpy().reshape(env.action_space.shape)) + obs, reward, done, truncated, info = env.step(real_actions.cpu().numpy().reshape(env.action_space.shape)) done = done or truncated cumulative_rew += reward - with fabric.device: - next_obs = { - k: torch.as_tensor(o[k], dtype=torch.float32).view(1, 1, -1, *o[k].shape[-2:]) / 255 - for k in cfg.algo.cnn_keys.encoder - } - next_obs.update( - {k: torch.as_tensor(o[k], dtype=torch.float32).view(1, 1, -1) for k in cfg.algo.mlp_keys.encoder} - ) if cfg.dry_run: done = True diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 774754fc..72f623c5 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -19,7 +19,7 @@ from sheeprl.algos.sac.agent import SACAgent, build_agent from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss -from sheeprl.algos.sac.utils import test +from sheeprl.algos.sac.utils import prepare_obs, test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -242,7 +242,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data = {} # Get the first environment observation and start the optimization obs = envs.reset(seed=cfg.seed)[0] - obs = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) per_rank_gradient_steps = 0 cumulative_per_rank_gradient_steps = 0 @@ -257,11 +256,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else: # Sample an action given the observation received by the environment with torch.inference_mode(): - torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) + torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs) actions = player(torch_obs) actions = actions.cpu().numpy() - next_obs, rewards, terminated, truncated, infos = envs.step(actions) - next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) + next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) rewards = rewards.reshape(cfg.env.num_envs, -1) 
if cfg.metric.log_level > 0 and "final_info" in infos: @@ -279,14 +277,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: - real_next_obs[idx] = np.concatenate( - [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], axis=-1 - ) + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + real_next_obs = np.concatenate([real_next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype( + np.float32 + ) step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) - step_data["actions"] = actions[np.newaxis] - step_data["observations"] = obs[np.newaxis] + step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) + step_data["observations"] = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1)[np.newaxis] if not cfg.buffer.sample_next_obs: step_data["next_observations"] = real_next_obs[np.newaxis] step_data["rewards"] = rewards[np.newaxis] diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index b76337db..ee0b81e8 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -18,7 +18,7 @@ from sheeprl.algos.sac.agent import SACAgent, SACCritic, build_agent from sheeprl.algos.sac.sac import train -from sheeprl.algos.sac.utils import test +from sheeprl.algos.sac.utils import prepare_obs, test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.fabric import get_single_device_fabric @@ -172,7 +172,6 @@ def player( step_data = {} # Get the first environment observation and start the optimization obs = envs.reset(seed=cfg.seed)[0] - obs = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) per_rank_gradient_steps = 0 cumulative_per_rank_gradient_steps = 0 @@ -186,11 +185,10 @@ def player( actions = envs.action_space.sample() else: # Sample an action given the observation received by the environment - torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) + torch_obs = prepare_obs(fabric, obs, num_envs=cfg.env.num_envs) actions = actor(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions) - next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) rewards = rewards.reshape(cfg.env.num_envs, -1) if cfg.metric.log_level > 0 and "final_info" in infos: @@ -208,14 +206,16 @@ def player( if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: - real_next_obs[idx] = np.concatenate( - [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], axis=-1 - ) + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + real_next_obs = np.concatenate([real_next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype( + np.float32 + ) step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) - step_data["actions"] = actions[np.newaxis] - step_data["observations"] = obs[np.newaxis] + step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) + step_data["observations"] = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1)[np.newaxis] if not cfg.buffer.sample_next_obs: step_data["next_observations"] = 
real_next_obs[np.newaxis] step_data["rewards"] = rewards[np.newaxis] diff --git a/sheeprl/algos/sac/utils.py b/sheeprl/algos/sac/utils.py index ae624cf1..6912f278 100644 --- a/sheeprl/algos/sac/utils.py +++ b/sheeprl/algos/sac/utils.py @@ -4,9 +4,11 @@ from typing import TYPE_CHECKING, Any, Dict, Sequence import gymnasium as gym +import numpy as np import torch from lightning import Fabric from lightning.fabric.wrappers import _FabricModule +from torch import Tensor from sheeprl.algos.sac.agent import SACPlayer, build_agent from sheeprl.utils.env import make_env @@ -26,31 +28,28 @@ MODELS_TO_REGISTER = {"agent"} +def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Tensor: + with fabric.device: + torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in obs.keys()], dim=-1) + return torch_obs.reshape(num_envs, -1) + + @torch.no_grad() def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() actor.eval() done = False cumulative_rew = 0 - with fabric.device: - o = env.reset(seed=cfg.seed)[0] - next_obs = torch.cat( - [torch.as_tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ).unsqueeze( - 0 - ) # [N_envs, N_obs] + obs = env.reset(seed=cfg.seed)[0] while not done: # Act greedly through the environment - action = actor.get_actions(next_obs, greedy=True) + torch_obs = prepare_obs(fabric, obs) + action = actor.get_actions(torch_obs, greedy=True) # Single environment step - next_obs, reward, done, truncated, info = env.step(action.cpu().numpy().reshape(env.action_space.shape)) + obs, reward, done, truncated, info = env.step(action.cpu().numpy().reshape(env.action_space.shape)) done = done or truncated cumulative_rew += reward - with fabric.device: - next_obs = torch.cat( - [torch.as_tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) if cfg.dry_run: done = True diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 183b5297..23aad98d 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -21,7 +21,7 @@ from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent -from sheeprl.algos.sac_ae.utils import preprocess_obs, test +from sheeprl.algos.sac_ae.utils import prepare_obs, preprocess_obs, test from sheeprl.data.buffers import ReplayBuffer from sheeprl.models.models import MultiDecoder, MultiEncoder from sheeprl.utils.env import make_env @@ -329,8 +329,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions = envs.action_space.sample() else: with torch.inference_mode(): - normalized_obs = {k: v / 255 if k in cfg.algo.cnn_keys.encoder else v for k, v in obs.items()} - torch_obs = {k: torch.from_numpy(v).to(device).float() for k, v in normalized_obs.items()} + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs) actions = player(torch_obs).cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) diff --git a/sheeprl/algos/sac_ae/utils.py b/sheeprl/algos/sac_ae/utils.py index 07891f02..680f5ee1 100644 --- a/sheeprl/algos/sac_ae/utils.py +++ b/sheeprl/algos/sac_ae/utils.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, Sequence import gymnasium as gym +import numpy as np import torch import torch.nn as nn from lightning import 
Fabric @@ -24,44 +25,38 @@ MODELS_TO_REGISTER = {"agent", "encoder", "decoder"} +def prepare_obs( + fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs +) -> Dict[str, Tensor]: + torch_obs = {} + for k in obs.keys(): + torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).float() + if k in cnn_keys: + torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:]) / 255 + else: + torch_obs[k] = torch_obs[k].reshape(num_envs, -1) + + return torch_obs + + @torch.no_grad() def test(actor: "SACAEPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, cfg.seed, 0, log_dir, "test", vector_env_idx=0)() - cnn_keys = actor.encoder.cnn_keys - mlp_keys = actor.encoder.mlp_keys actor.eval() done = False cumulative_rew = 0 - next_obs = {} - o = env.reset(seed=cfg.seed)[0] # [N_envs, N_obs] - for k in o.keys(): - if k in mlp_keys + cnn_keys: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cnn_keys: - torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - if k in mlp_keys: - torch_obs = torch_obs.float() - next_obs[k] = torch_obs + obs = env.reset(seed=cfg.seed)[0] # [N_envs, N_obs] while not done: + torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder) # Act greedly through the environment - action = actor.get_actions(next_obs, greedy=True) + action = actor.get_actions(torch_obs, greedy=True) # Single environment step - o, reward, done, truncated, _ = env.step(action.cpu().numpy().reshape(env.action_space.shape)) + obs, reward, done, truncated, _ = env.step(action.cpu().numpy().reshape(env.action_space.shape)) done = done or truncated cumulative_rew += reward - next_obs = {} - for k in o.keys(): - if k in mlp_keys + cnn_keys: - torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cnn_keys: - torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - if k in mlp_keys: - torch_obs = torch_obs.float() - next_obs[k] = torch_obs - if cfg.dry_run: done = True fabric.print("Test - Reward:", cumulative_rew)
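
For readers reviewing this patch, a minimal usage sketch of the dict-returning `prepare_obs` helper added to `sheeprl/algos/ppo/utils.py`. It assumes the patch is applied; the observation keys (`rgb`, `state`), shapes, and the CPU `Fabric` instance are invented for illustration, and the expected scaling of CNN keys to roughly `[-0.5, 0.5)` relies on the existing `normalize_obs` behaviour.

```python
# Illustrative sketch only (assumes this patch is applied to sheeprl).
# Observation keys and shapes below are invented for the example.
import numpy as np
import torch
from lightning import Fabric

from sheeprl.algos.ppo.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

num_envs = 4
obs = {
    "rgb": np.random.randint(0, 256, size=(num_envs, 3, 64, 64), dtype=np.uint8),  # CNN key
    "state": np.random.randn(num_envs, 8).astype(np.float32),  # MLP key
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"], num_envs=num_envs)

# CNN keys keep their image layout and are rescaled by `normalize_obs`;
# MLP keys are flattened to (num_envs, -1).
assert torch_obs["rgb"].shape == (num_envs, 3, 64, 64)
assert torch_obs["state"].shape == (num_envs, 8)
assert all(isinstance(v, torch.Tensor) for v in torch_obs.values())
```

This is the call shape the updated `test()` loops above rely on: `agent.get_actions(torch_obs, greedy=True)` now receives batched tensors instead of hand-built per-key dictionaries.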
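
The Dreamer-family helpers (`sheeprl/algos/dreamer_v2/utils.py` and `sheeprl/algos/dreamer_v3/utils.py`) differ in that they keep a leading sequence dimension of size 1 and fold the `/ 255 - 0.5` scaling into the helper itself. A hedged sketch, again with invented keys and shapes:

```python
# Illustrative sketch only (assumes this patch is applied to sheeprl).
import numpy as np
from lightning import Fabric

from sheeprl.algos.dreamer_v3.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

num_envs = 2
obs = {
    "rgb": np.random.randint(0, 256, size=(num_envs, 3, 64, 64), dtype=np.uint8),
    "state": np.random.randn(num_envs, 5).astype(np.float32),
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"], num_envs=num_envs)

# Outputs are shaped (sequence_len=1, num_envs, ...), matching what the
# Dreamer players expect; CNN keys are rescaled to roughly [-0.5, 0.5].
assert torch_obs["rgb"].shape == (1, num_envs, 3, 64, 64)
assert torch_obs["state"].shape == (1, num_envs, 5)
```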
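
Finally, the SAC-family helper in `sheeprl/algos/sac/utils.py` returns a single concatenated tensor rather than a dictionary, since the SAC actor consumes flat vector observations. A sketch with invented keys:

```python
# Illustrative sketch only (assumes this patch is applied to sheeprl).
import numpy as np
from lightning import Fabric

from sheeprl.algos.sac.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

num_envs = 2
obs = {
    "position": np.random.randn(num_envs, 3).astype(np.float32),
    "velocity": np.random.randn(num_envs, 3).astype(np.float32),
}
torch_obs = prepare_obs(fabric, obs, num_envs=num_envs)

# All keys are concatenated along the last dimension into (num_envs, obs_dim).
assert torch_obs.shape == (num_envs, 6)
```

This is why the SAC-style scripts above (`droq.py`, `sac.py`, `sac_decoupled.py`) can drop the manual `np.concatenate` over `cfg.algo.mlp_keys.encoder` before the policy call.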