feat: added prepare obs to all the algorithms #267

Merged
merged 4 commits on Apr 20, 2024
Changes from 2 commits
22 changes: 10 additions & 12 deletions sheeprl/algos/a2c/utils.py
@@ -2,30 +2,34 @@

from typing import Any, Dict

import numpy as np
import torch
from lightning import Fabric
from torch import Tensor

from sheeprl.algos.ppo.agent import PPOPlayer
from sheeprl.utils.env import make_env

AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *args, **kwargs) -> Dict[str, Tensor]:
    torch_obs = {k: torch.from_numpy(v[np.newaxis]).to(fabric.device).float() for k, v in obs.items()}
    return torch_obs


@torch.no_grad()
def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
    agent.eval()
    done = False
    cumulative_rew = 0
    o = env.reset(seed=cfg.seed)[0]
    obs = {}
    for k in o.keys():
        if k in cfg.algo.mlp_keys.encoder:
            torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
            torch_obs = torch_obs.float()
            obs[k] = torch_obs

    while not done:
        # Convert observations to tensors
        obs = prepare_obs(fabric, o)

        # Act greedly through the environment
        actions = agent.get_actions(obs, greedy=True)
        if agent.actor.is_continuous:
@@ -37,12 +41,6 @@ def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
        o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
        done = done or truncated
        cumulative_rew += reward
        obs = {}
        for k in o.keys():
            if k in cfg.algo.mlp_keys.encoder:
                torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
                torch_obs = torch_obs.float()
                obs[k] = torch_obs

        if cfg.dry_run:
            done = True
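A minimal usage sketch of the new A2C prepare_obs (not part of the diff; it assumes this branch of sheeprl is installed, and the observation key and shape are invented for illustration): the helper prepends a batch dimension with np.newaxis, moves the array onto the Fabric device, and casts it to float32.

# Illustrative sketch, not part of the PR.
import numpy as np
import torch
from lightning import Fabric

from sheeprl.algos.a2c.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)  # single-process CPU Fabric, for illustration
obs = {"state": np.zeros(4, dtype=np.float32)}  # hypothetical 4-dimensional vector observation
torch_obs = prepare_obs(fabric, obs)
assert torch_obs["state"].shape == (1, 4)  # batch dimension prepended
assert torch_obs["state"].dtype == torch.float32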
31 changes: 17 additions & 14 deletions sheeprl/algos/dreamer_v2/utils.py
@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from lightning import Fabric
@@ -101,6 +102,18 @@ def compute_lambda_values(
    return torch.cat(list(reversed(lv)), dim=0)


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str] = []) -> Dict[str, Tensor]:
    torch_obs = {}
    for k, v in obs.items():
        torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).view(1, *v.shape).float()
        if k in cnn_keys:
            torch_obs[k] = torch_obs[k][None, ...] / 255 - 0.5
        else:
            torch_obs[k] = torch_obs[k][None, ...]

    return torch_obs


@torch.no_grad()
def test(
player: "PlayerDV2" | "PlayerDV1",
Expand All @@ -125,32 +138,22 @@ def test(
env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
done = False
cumulative_rew = 0
device = fabric.device
next_obs = env.reset(seed=cfg.seed)[0]
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
o = env.reset(seed=cfg.seed)[0]
player.num_envs = 1
player.init_states()
while not done:
# Act greedly through the environment
preprocessed_obs = {}
for k, v in next_obs.items():
if k in cfg.algo.cnn_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5
elif k in cfg.algo.mlp_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device)
torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)
real_actions = player.get_actions(
preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
else:
real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
o, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
        done = done or truncated or cfg.dry_run
        cumulative_rew += reward
    fabric.print("Test - Reward:", cumulative_rew)
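A short sketch of what the Dreamer-V2 prepare_obs produces (illustrative only; the keys and shapes are invented): every observation gains a time and a batch dimension, and keys listed in cnn_keys are additionally rescaled from [0, 255] to [-0.5, 0.5].

# Illustrative sketch, not part of the PR.
import numpy as np
from lightning import Fabric

from sheeprl.algos.dreamer_v2.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
obs = {
    "rgb": np.zeros((3, 64, 64), dtype=np.uint8),  # hypothetical image observation
    "state": np.zeros(5, dtype=np.float32),  # hypothetical vector observation
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"])
assert torch_obs["rgb"].shape == (1, 1, 3, 64, 64)  # (sequence, batch, C, H, W)
assert torch_obs["rgb"].min().item() == -0.5  # pixels rescaled to [-0.5, 0.5]
assert torch_obs["state"].shape == (1, 1, 5)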
30 changes: 16 additions & 14 deletions sheeprl/algos/dreamer_v3/utils.py
@@ -77,6 +77,18 @@ def compute_lambda_values(
    return ret


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str] = []) -> Dict[str, Tensor]:
    torch_obs = {}
    for k, v in obs.items():
        torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).view(1, *v.shape).float()
        if k in cnn_keys:
            torch_obs[k] = torch_obs[k][None, ...] / 255 - 0.5
        else:
            torch_obs[k] = torch_obs[k][None, ...]

    return torch_obs


@torch.no_grad()
def test(
player: "PlayerDV3",
Expand All @@ -101,32 +113,22 @@ def test(
env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
done = False
cumulative_rew = 0
device = fabric.device
next_obs = env.reset(seed=cfg.seed)[0]
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
o = env.reset(seed=cfg.seed)[0]
player.num_envs = 1
player.init_states()
while not done:
# Act greedly through the environment
preprocessed_obs = {}
for k, v in next_obs.items():
if k in cfg.algo.cnn_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5
elif k in cfg.algo.mlp_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device)
torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)
real_actions = player.get_actions(
preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
else:
real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
o, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
done = done or truncated or cfg.dry_run
cumulative_rew += reward
fabric.print("Test - Reward:", cumulative_rew)
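The Dreamer-V3 helper is identical to the Dreamer-V2 one; the sketch below (again illustrative, with an invented key) only shows the default cnn_keys=[] path, where every key is treated as a vector observation and no pixel rescaling is applied.

# Illustrative sketch, not part of the PR.
import numpy as np
from lightning import Fabric

from sheeprl.algos.dreamer_v3.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
obs = {"state": np.zeros(5, dtype=np.float32)}  # hypothetical vector observation
torch_obs = prepare_obs(fabric, obs)  # cnn_keys defaults to []
assert torch_obs["state"].shape == (1, 1, 5)  # (sequence, batch, features), unscaled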
31 changes: 12 additions & 19 deletions sheeprl/algos/ppo/utils.py
@@ -22,26 +22,28 @@
MODELS_TO_REGISTER = {"agent"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str]) -> Dict[str, Tensor]:
    torch_obs = {}
    for k in obs.keys():
        torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).unsqueeze(0).float()
        if k in cnn_keys:
            torch_obs[k] = torch_obs[k].reshape(1, -1, *torch_obs[k].shape[-2:]) / 255 - 0.5
    return torch_obs


@torch.no_grad()
def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
    agent.eval()
    done = False
    cumulative_rew = 0
    o = env.reset(seed=cfg.seed)[0]
    obs = {}
    for k in o.keys():
        if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
            torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
            if k in cfg.algo.cnn_keys.encoder:
                torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
            if k in cfg.algo.mlp_keys.encoder:
                torch_obs = torch_obs.float()
            obs[k] = torch_obs

    while not done:
        torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)

        # Act greedly through the environment
        actions = agent.get_actions(obs, greedy=True)
        actions = agent.get_actions(torch_obs, greedy=True)
        if agent.actor.is_continuous:
            actions = torch.cat(actions, dim=-1)
        else:
@@ -51,15 +53,6 @@ def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
        o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
        done = done or truncated
        cumulative_rew += reward
        obs = {}
        for k in o.keys():
            if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
                torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
                if k in cfg.algo.cnn_keys.encoder:
                    torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
                if k in cfg.algo.mlp_keys.encoder:
                    torch_obs = torch_obs.float()
                obs[k] = torch_obs

        if cfg.dry_run:
            done = True
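A sketch of the PPO prepare_obs (illustrative keys and shapes): vector keys get a batch dimension and a float cast, while keys in cnn_keys also have any leading frame-stack dimension folded into the channel axis and their pixels rescaled to [-0.5, 0.5].

# Illustrative sketch, not part of the PR.
import numpy as np
from lightning import Fabric

from sheeprl.algos.ppo.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
obs = {
    "rgb": np.zeros((4, 3, 64, 64), dtype=np.uint8),  # hypothetical stack of 4 RGB frames
    "state": np.zeros(7, dtype=np.float32),  # hypothetical vector observation
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"])
assert torch_obs["rgb"].shape == (1, 12, 64, 64)  # frames folded into channels, batch dim added
assert torch_obs["state"].shape == (1, 7)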
37 changes: 17 additions & 20 deletions sheeprl/algos/ppo_recurrent/utils.py
@@ -3,8 +3,10 @@
from typing import TYPE_CHECKING, Any, Dict, Sequence

import gymnasium as gym
import numpy as np
import torch
from lightning import Fabric
from torch import Tensor

from sheeprl.algos.ppo.utils import AGGREGATOR_KEYS as ppo_aggregator_keys
from sheeprl.algos.ppo.utils import MODELS_TO_REGISTER as ppo_models_to_register
@@ -21,33 +23,36 @@
MODELS_TO_REGISTER = ppo_models_to_register


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str]) -> Dict[str, Tensor]:
    torch_obs = {}
    with fabric.device:
        for k, v in obs.items():
            torch_obs[k] = torch.as_tensor(v.copy(), dtype=torch.float32, device=fabric.device)
            if k in cnn_keys:
                torch_obs[k] = torch_obs[k].view(1, 1, -1, *v.shape[-2:]) / 255 - 0.5
            else:
                torch_obs[k] = torch_obs[k].view(1, 1, -1)
    return torch_obs


@torch.no_grad()
def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
    agent.eval()
    done = False
    cumulative_rew = 0
    agent.num_envs = 1
    o = env.reset(seed=cfg.seed)[0]
    with fabric.device:
        o = env.reset(seed=cfg.seed)[0]
        next_obs = {
            k: torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1, *o[k].shape[-2:]) / 255
            for k in cfg.algo.cnn_keys.encoder
        }
        next_obs.update(
            {
                k: torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1)
                for k in cfg.algo.mlp_keys.encoder
            }
        )
        state = (
            torch.zeros(1, 1, agent.rnn_hidden_size, device=fabric.device),
            torch.zeros(1, 1, agent.rnn_hidden_size, device=fabric.device),
        )
        actions = torch.zeros(1, 1, sum(agent.actions_dim), device=fabric.device)
    while not done:
        torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)
        # Act greedly through the environment
        actions, state = agent.get_actions(next_obs, actions, state, greedy=True)
        actions, state = agent.get_actions(torch_obs, actions, state, greedy=True)
        if agent.actor.is_continuous:
            real_actions = torch.cat(actions, -1)
            actions = torch.cat(actions, dim=-1).view(1, 1, -1)
@@ -59,14 +64,6 @@ def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_d
        o, reward, done, truncated, info = env.step(real_actions.cpu().numpy().reshape(env.action_space.shape))
        done = done or truncated
        cumulative_rew += reward
        with fabric.device:
            next_obs = {
                k: torch.as_tensor(o[k], dtype=torch.float32).view(1, 1, -1, *o[k].shape[-2:]) / 255
                for k in cfg.algo.cnn_keys.encoder
            }
            next_obs.update(
                {k: torch.as_tensor(o[k], dtype=torch.float32).view(1, 1, -1) for k in cfg.algo.mlp_keys.encoder}
            )

        if cfg.dry_run:
            done = True
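A sketch of the recurrent-PPO prepare_obs (illustrative keys and shapes): because the policy is recurrent, every key gets both a sequence-length and a batch dimension, i.e. shape (1, 1, ...), with image keys rescaled to [-0.5, 0.5].

# Illustrative sketch, not part of the PR.
import numpy as np
from lightning import Fabric

from sheeprl.algos.ppo_recurrent.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
obs = {
    "rgb": np.zeros((3, 64, 64), dtype=np.uint8),  # hypothetical image observation
    "state": np.zeros(7, dtype=np.float32),  # hypothetical vector observation
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"])
assert torch_obs["rgb"].shape == (1, 1, 3, 64, 64)  # (sequence, batch, C, H, W), rescaled
assert torch_obs["state"].shape == (1, 1, 7)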
25 changes: 12 additions & 13 deletions sheeprl/algos/sac/utils.py
@@ -4,9 +4,11 @@
from typing import TYPE_CHECKING, Any, Dict, Sequence

import gymnasium as gym
import numpy as np
import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor

from sheeprl.algos.sac.agent import SACPlayer, build_agent
from sheeprl.utils.env import make_env
@@ -26,31 +28,28 @@
MODELS_TO_REGISTER = {"agent"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *args, **kwargs) -> Tensor:
    with fabric.device:
        torch_obs = torch.cat([torch.as_tensor(obs[k].copy(), dtype=torch.float32) for k in obs.keys()], dim=-1)
    return torch_obs.unsqueeze(0)


@torch.no_grad()
def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
    actor.eval()
    done = False
    cumulative_rew = 0
    with fabric.device:
        o = env.reset(seed=cfg.seed)[0]
        next_obs = torch.cat(
            [torch.as_tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1
        ).unsqueeze(
            0
        )  # [N_envs, N_obs]
    o = env.reset(seed=cfg.seed)[0]
    while not done:
        # Act greedly through the environment
        action = actor.get_actions(next_obs, greedy=True)
        torch_obs = prepare_obs(fabric, o)
        action = actor.get_actions(torch_obs, greedy=True)

        # Single environment step
        next_obs, reward, done, truncated, info = env.step(action.cpu().numpy().reshape(env.action_space.shape))
        o, reward, done, truncated, info = env.step(action.cpu().numpy().reshape(env.action_space.shape))
        done = done or truncated
        cumulative_rew += reward
        with fabric.device:
            next_obs = torch.cat(
                [torch.as_tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1
            )

        if cfg.dry_run:
            done = True
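A sketch of the SAC prepare_obs (illustrative keys and shapes): unlike the other helpers it returns a single tensor rather than a dict, concatenating all vector observations along the last dimension and adding a batch dimension.

# Illustrative sketch, not part of the PR.
import numpy as np
from lightning import Fabric

from sheeprl.algos.sac.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)
obs = {
    "position": np.zeros(3, dtype=np.float32),  # hypothetical vector observations
    "velocity": np.zeros(4, dtype=np.float32),
}
torch_obs = prepare_obs(fabric, obs)
assert torch_obs.shape == (1, 7)  # keys concatenated along the feature axis, batch dim added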