feat: added prepare obs to all the algorithms #267

Merged 4 commits on Apr 20, 2024
51 changes: 27 additions & 24 deletions howto/register_external_algorithm.md
@@ -715,11 +715,16 @@ where `log_models`, `test` and `normalize_obs` have to be defined in the `my_awe
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Sequence

import gymnasium as gym
import numpy as np
import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor

from sheeprl.utils.env import make_env
from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE
from sheeprl.utils.utils import unwrap_fabric

@@ -729,43 +734,41 @@ if TYPE_CHECKING:
from mlflow.models.model import ModelInfo


def prepare_obs(
    fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
    torch_obs = {}
    for k in obs.keys():
        torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).float()
        if k in cnn_keys:
            torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:])
        else:
            torch_obs[k] = torch_obs[k].reshape(num_envs, -1)
    return normalize_obs(torch_obs, cnn_keys, obs.keys())


@torch.no_grad()
def test(agent: SOTAAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
cumulative_rew = 0
o = env.reset(seed=cfg.seed)[0]
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs
obs = env.reset(seed=cfg.seed)[0]

while not done:
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder)

# Act greedily through the environment
if agent.is_continuous:
actions = torch.cat(agent.get_greedy_actions(obs), dim=-1)
actions = agent.get_actions(torch_obs, greedy=True)
if agent.actor.is_continuous:
actions = torch.cat(actions, dim=-1)
else:
actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1)
actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

# Single environment step
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs

if cfg.dry_run:
done = True
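For readers following the howto, here is a self-contained sketch of what the new helper produces. It is only an illustration: `fabric.device` is replaced by a plain `torch.device`, the `rgb`/`state` keys and shapes are invented, and `normalize_obs` is assumed to do the usual `/ 255 - 0.5` image scaling described earlier in the guide.

```python
from typing import Dict, Sequence

import numpy as np
import torch
from torch import Tensor


def normalize_obs(obs: Dict[str, Tensor], cnn_keys: Sequence[str], obs_keys: Sequence[str]) -> Dict[str, Tensor]:
    # Assumed behaviour of the helper referenced above: scale image keys to [-0.5, 0.5).
    return {k: obs[k] / 255 - 0.5 if k in cnn_keys else obs[k] for k in obs_keys}


def prepare_obs(
    device: torch.device, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = (), num_envs: int = 1
) -> Dict[str, Tensor]:
    # Same logic as the diff above, with `fabric.device` swapped for a plain device.
    torch_obs = {}
    for k in obs.keys():
        torch_obs[k] = torch.from_numpy(obs[k].copy()).to(device).float()
        if k in cnn_keys:
            torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:])
        else:
            torch_obs[k] = torch_obs[k].reshape(num_envs, -1)
    return normalize_obs(torch_obs, cnn_keys, obs.keys())


# Made-up observation dict, as a 2-env vectorized environment might return it.
obs = {
    "rgb": np.random.randint(0, 256, size=(2, 3, 64, 64), dtype=np.uint8),
    "state": np.random.rand(2, 7).astype(np.float32),
}
torch_obs = prepare_obs(torch.device("cpu"), obs, cnn_keys=["rgb"], num_envs=2)
print(torch_obs["rgb"].shape)    # torch.Size([2, 3, 64, 64]), values scaled to [-0.5, 0.5)
print(torch_obs["state"].shape)  # torch.Size([2, 7])
```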
49 changes: 24 additions & 25 deletions howto/register_new_algorithm.md
@@ -713,59 +713,58 @@ where `log_models`, `test` and `normalize_obs` have to be defined in the `sheepr
```python
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Sequence

import gymnasium as gym
import numpy as np
import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor

from sheeprl.algos.sota.agent import SOTAAgentPlayer
from sheeprl.utils.env import make_env
from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE
from sheeprl.utils.utils import unwrap_fabric

if TYPE_CHECKING:
from mlflow.models.model import ModelInfo


def prepare_obs(
    fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
    torch_obs = {}
    for k in obs.keys():
        torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).float()
        if k in cnn_keys:
            torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:])
        else:
            torch_obs[k] = torch_obs[k].reshape(num_envs, -1)
    return normalize_obs(torch_obs, cnn_keys, obs.keys())


@torch.no_grad()
def test(agent: SOTAAgentPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
cumulative_rew = 0
o = env.reset(seed=cfg.seed)[0]
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs
obs = env.reset(seed=cfg.seed)[0]

while not done:
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder)

# Act greedily through the environment
actions = agent.get_actions(obs, greedy=True)
if agent.is_continuous:
actions = agent.get_actions(torch_obs, greedy=True)
if agent.actor.is_continuous:
actions = torch.cat(actions, dim=-1)
else:
actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

# Single environment step
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs

if cfg.dry_run:
done = True
8 changes: 4 additions & 4 deletions sheeprl/algos/a2c/a2c.py
@@ -12,7 +12,7 @@

from sheeprl.algos.a2c.agent import A2CAgent, build_agent
from sheeprl.algos.a2c.loss import policy_loss, value_loss
from sheeprl.algos.a2c.utils import test
from sheeprl.algos.a2c.utils import prepare_obs, test
from sheeprl.data import ReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -236,7 +236,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sample an action given the observation received by the environment
# This calls the `forward` method of the PyTorch module, escaping from Fabric
# because we don't want this to be a synchronization point
torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
actions, _, values = player(torch_obs)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
@@ -272,7 +272,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Update the step data
step_data["dones"] = dones[np.newaxis]
step_data["values"] = values.cpu().numpy()[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1)
step_data["rewards"] = rewards[np.newaxis]
if cfg.buffer.memmap:
step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
@@ -304,7 +304,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.inference_mode():
torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
next_values = player.get_values(torch_obs)
returns, advantages = gae(
local_data["rewards"].to(torch.float64),
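For context, the A2C variant of `prepare_obs` (see the next file) handles vector observations only, so the new call is roughly the old dict comprehension plus an explicit `(num_envs, -1)` flatten. A minimal sketch under that assumption, with a made-up `state` key and shape:

```python
import numpy as np
import torch

device, num_envs = torch.device("cpu"), 4
next_obs = {"state": np.random.rand(4, 11).astype(np.float32)}  # invented key and shape

# Old conversion: dtype/device move only.
old = {k: torch.as_tensor(v, dtype=torch.float32, device=device) for k, v in next_obs.items()}
# New helper (A2C variant): same move, plus an explicit (num_envs, -1) flatten.
new = {k: torch.from_numpy(v.copy()).to(device).float().reshape(num_envs, -1) for k, v in next_obs.items()}

# Identical here because the observation is already flat per environment.
assert torch.equal(old["state"], new["state"])
```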
28 changes: 13 additions & 15 deletions sheeprl/algos/a2c/utils.py
@@ -2,47 +2,45 @@

from typing import Any, Dict

import numpy as np
import torch
from lightning import Fabric
from torch import Tensor

from sheeprl.algos.ppo.agent import PPOPlayer
from sheeprl.utils.env import make_env

AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Dict[str, Tensor]:
    torch_obs = {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()}
    return torch_obs


@torch.no_grad()
def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
cumulative_rew = 0
o = env.reset(seed=cfg.seed)[0]
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
torch_obs = torch_obs.float()
obs[k] = torch_obs
obs = env.reset(seed=cfg.seed)[0]

while not done:
# Convert observations to tensors
torch_obs = prepare_obs(fabric, obs)

# Act greedily through the environment
actions = agent.get_actions(obs, greedy=True)
actions = agent.get_actions(torch_obs, greedy=True)
if agent.actor.is_continuous:
actions = torch.cat(actions, dim=-1)
else:
actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

# Single environment step
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
torch_obs = torch_obs.float()
obs[k] = torch_obs

if cfg.dry_run:
done = True
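The greedy-action handling that follows each `prepare_obs` call is the same in every test loop touched by this PR: the list of per-sub-space tensors returned by the agent is concatenated for continuous actors and argmax-ed for discrete ones. A tiny self-contained illustration with invented shapes:

```python
import torch

# Pretend the agent returned one tensor per action sub-space, batch size 1.
continuous_out = [torch.randn(1, 3), torch.randn(1, 2)]  # e.g. two Box sub-spaces
discrete_out = [torch.randn(1, 5), torch.randn(1, 4)]    # e.g. two Discrete sub-spaces (logits)

continuous_actions = torch.cat(continuous_out, dim=-1)                          # shape (1, 5)
discrete_actions = torch.cat([a.argmax(dim=-1) for a in discrete_out], dim=-1)  # shape (2,)
print(continuous_actions.shape, discrete_actions.shape)
```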
13 changes: 4 additions & 9 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -20,7 +20,7 @@
from sheeprl.algos.dreamer_v1.agent import WorldModel, build_agent
from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
from sheeprl.algos.dreamer_v1.utils import compute_lambda_values
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.algos.dreamer_v2.utils import prepare_obs, test
from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -574,16 +574,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
axis=-1,
)
else:
normalized_obs = {}
for k in obs_keys:
torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs / 255 - 0.5
normalized_obs[k] = torch_obs
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step)
real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
13 changes: 4 additions & 9 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -24,7 +24,7 @@

from sheeprl.algos.dreamer_v2.agent import WorldModel, build_agent
from sheeprl.algos.dreamer_v2.loss import reconstruction_loss
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, prepare_obs, test
from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -599,16 +599,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
axis=-1,
)
else:
normalized_obs = {}
for k in obs_keys:
torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs / 255 - 0.5
normalized_obs[k] = torch_obs
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
real_actions = actions = player.get_actions(normalized_obs, mask=mask)
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
33 changes: 19 additions & 14 deletions sheeprl/algos/dreamer_v2/utils.py
@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from lightning import Fabric
@@ -101,6 +102,20 @@ def compute_lambda_values(
return torch.cat(list(reversed(lv)), dim=0)


def prepare_obs(
    fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
    torch_obs = {}
    for k, v in obs.items():
        torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).float()
        if k in cnn_keys:
            torch_obs[k] = torch_obs[k].view(1, num_envs, -1, *v.shape[-2:]) / 255 - 0.5
        else:
            torch_obs[k] = torch_obs[k].view(1, num_envs, -1)

    return torch_obs


@torch.no_grad()
def test(
player: "PlayerDV2" | "PlayerDV1",
@@ -125,32 +140,22 @@ def test(
env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
done = False
cumulative_rew = 0
device = fabric.device
next_obs = env.reset(seed=cfg.seed)[0]
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
obs = env.reset(seed=cfg.seed)[0]
player.num_envs = 1
player.init_states()
while not done:
# Act greedily through the environment
preprocessed_obs = {}
for k, v in next_obs.items():
if k in cfg.algo.cnn_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5
elif k in cfg.algo.mlp_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device)
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder)
real_actions = player.get_actions(
preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
else:
real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
done = done or truncated or cfg.dry_run
cumulative_rew += reward
fabric.print("Test - Reward:", cumulative_rew)
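Unlike the flat A2C/PPO variants, the Dreamer helper keeps a leading dimension of 1 in front of the environment dimension, which appears to be the (sequence, envs, ...) layout the recurrent players consume, and folds the `/ 255 - 0.5` image scaling into the same pass. A self-contained sketch of the resulting shapes (the `rgb`/`state` keys and sizes are invented for illustration):

```python
import numpy as np
import torch

num_envs, cnn_keys = 2, ["rgb"]
obs = {
    "rgb": np.random.randint(0, 256, size=(2, 3, 64, 64), dtype=np.uint8),
    "state": np.random.rand(2, 7).astype(np.float32),
}

torch_obs = {}
for k, v in obs.items():
    t = torch.from_numpy(v.copy()).float()
    if k in cnn_keys:
        # (1, num_envs, C, H, W), scaled from [0, 255] to [-0.5, 0.5)
        t = t.view(1, num_envs, -1, *v.shape[-2:]) / 255 - 0.5
    else:
        # (1, num_envs, features)
        t = t.view(1, num_envs, -1)
    torch_obs[k] = t

print(torch_obs["rgb"].shape)    # torch.Size([1, 2, 3, 64, 64])
print(torch_obs["state"].shape)  # torch.Size([1, 2, 7])
```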