diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index 354d6b36..8c2a41e7 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -377,13 +377,24 @@ def __init__( self.init_states() - def init_states(self) -> None: - """ - Initialize the states and the actions for the ended environments. + def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: + """Initialize the states and the actions for the ended environments. + + Args: + reset_envs (Optional[Sequence[int]], optional): which environments' states to reset. + If None, then all environments' states are reset. + Defaults to None. """ - self.actions = torch.zeros(1, self.num_envs, np.sum(self.actions_dim), device=self.device) - self.stochastic_state = torch.zeros(1, self.num_envs, self.stochastic_size, device=self.device) - self.recurrent_state = torch.zeros(1, self.num_envs, self.recurrent_state_size, device=self.device) + if reset_envs is None or len(reset_envs) == 0: + self.actions = torch.zeros(1, self.num_envs, np.sum(self.actions_dim), device=self.device) + self.recurrent_state = torch.zeros(1, self.num_envs, self.recurrent_state_size, device=self.device) + self.stochastic_state = torch.zeros( + 1, self.num_envs, self.stochastic_size, device=self.device + ) + else: + self.actions[:, reset_envs] = torch.zeros_like(self.actions[:, reset_envs]) + self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs]) + self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs]) def get_exploration_action(self, obs: Tensor, is_continuous: bool) -> Tensor: """ diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 5868d5fe..221e14fa 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -1,3 +1,4 @@ +import copy import os import pathlib import time @@ -23,8 +24,9 @@ from sheeprl.algos.dreamer_v1.agent import Player, WorldModel, build_models from sheeprl.algos.dreamer_v1.args import DreamerV1Args from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss -from sheeprl.algos.dreamer_v1.utils import cnn_forward, make_env, test -from sheeprl.data.buffers import SequentialReplayBuffer +from sheeprl.algos.dreamer_v1.utils import cnn_forward, test +from sheeprl.algos.dreamer_v2.utils import make_env +from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.parser import HfArgumentParser @@ -340,7 +342,6 @@ def train( def main(): parser = HfArgumentParser(DreamerV1Args) args: DreamerV1Args = parser.parse_args_into_dataclasses()[0] - args.num_envs = 1 torch.set_num_threads(1) # Initialize Fabric @@ -396,23 +397,31 @@ def main(): log_dir = data[0] os.makedirs(log_dir, exist_ok=True) - env: gym.Env = make_env( - args.env_id, - args.seed + rank * args.num_envs, - rank, - args, - logger.log_dir if rank == 0 else None, - "train", + # Environment setup + vectorized_env = gym.vector.SyncVectorEnv if args.sync_env else gym.vector.AsyncVectorEnv + envs = vectorized_env( + [ + make_env( + args.env_id, + args.seed + rank * args.num_envs, + rank, + args, + logger.log_dir if rank == 0 else None, + "train", + ) + for i in range(args.num_envs) + ] ) - is_continuous = isinstance(env.action_space, gym.spaces.Box) - is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + 
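# A minimal standalone sketch of the selective-reset pattern that init_states(reset_envs)
# implements above: only the [1, num_envs, dim] state columns belonging to environments that
# just finished are zeroed, so the other environments keep their recurrent context. The
# function name and sizes below are illustrative, not taken from the diff.
import torch

def partial_reset(state: torch.Tensor, reset_envs=None) -> torch.Tensor:
    # state has shape [1, num_envs, dim]; reset_envs lists the env indices to clear
    if reset_envs is None or len(reset_envs) == 0:
        return torch.zeros_like(state)   # no indices given: reset every environment
    state[:, reset_envs] = 0.0           # otherwise reset only the finished environments
    return state

recurrent_state = torch.randn(1, 4, 200)                             # 4 envs, recurrent size 200
recurrent_state = partial_reset(recurrent_state, reset_envs=[1, 3])  # envs 1 and 3 just ended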
action_space = envs.single_action_space + observation_space = envs.single_observation_space + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) actions_dim = ( - env.action_space.shape - if is_continuous - else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - observation_shape = env.observation_space.shape + observation_shape = observation_space["rgb"].shape clip_rewards_fn = lambda r: torch.tanh(r) if args.clip_rewards else r world_model, actor, critic = build_models( @@ -473,23 +482,24 @@ def main(): "Grads/critic": MeanMetric(sync_on_compute=False), } ) - aggregator.to(fabric.device) + aggregator.to(fabric.device) # Local data buffer_size = ( args.buffer_size // int(args.num_envs * fabric.world_size * args.action_repeat) if not args.dry_run else 2 ) - rb = SequentialReplayBuffer( + rb = AsyncReplayBuffer( buffer_size, args.num_envs, device=fabric.device if args.memmap_buffer else "cpu", memmap=args.memmap_buffer, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + sequential=True, ) if args.checkpoint_path and args.checkpoint_buffer: if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], SequentialReplayBuffer): + elif isinstance(state["rb"], AsyncReplayBuffer): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") @@ -499,11 +509,12 @@ def main(): # Global variables start_time = time.perf_counter() start_step = state["global_step"] // fabric.world_size if args.checkpoint_path else 1 - step_before_training = args.train_every // (fabric.world_size * args.action_repeat) if not args.dry_run else 0 - num_updates = int(args.total_steps // (fabric.world_size * args.action_repeat)) if not args.dry_run else 1 - learning_starts = (args.learning_starts // (fabric.world_size * args.action_repeat)) if not args.dry_run else 0 + single_global_step = int(args.num_envs * fabric.world_size * args.action_repeat) + step_before_training = args.train_every // single_global_step if not args.dry_run else 0 + num_updates = int(args.total_steps // single_global_step) if not args.dry_run else 1 + learning_starts = (args.learning_starts // single_global_step) if not args.dry_run else 0 if args.checkpoint_path and not args.checkpoint_buffer: - learning_starts = start_step + args.learning_starts // int(fabric.world_size * args.action_repeat) + learning_starts = start_step + args.learning_starts // single_global_step max_step_expl_decay = args.max_step_expl_decay // (args.gradient_steps * fabric.world_size) if args.checkpoint_path: player.expl_amount = polynomial_decay( @@ -514,7 +525,9 @@ def main(): ) # Get the first environment observation and start the optimization - obs = torch.from_numpy(env.reset(seed=args.seed)[0]).view(args.num_envs, *observation_shape) # [N_envs, N_obs] + obs = torch.from_numpy(envs.reset(seed=args.seed)[0]["rgb"]).view( + args.num_envs, *observation_shape + ) # [N_envs, N_obs] step_data["dones"] = torch.zeros(args.num_envs, 1) step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim)) step_data["rewards"] = torch.zeros(args.num_envs, 1) @@ -524,13 +537,13 @@ def main(): for global_step in range(start_step, num_updates + 1): # Sample an action given the observation 
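# Sketch of the `actions_dim` convention above and of how a batch of actions sampled from the
# vectorized action space is one-hot encoded against it during the random warm-up phase. The
# space and sizes below are examples only, not values taken from the diff.
import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F
from gymnasium.vector.utils import batch_space

single_action_space = gym.spaces.Discrete(6)
actions_dim = [single_action_space.n]               # Box -> shape, MultiDiscrete -> nvec.tolist()

num_envs = 4
vector_space = batch_space(single_action_space, num_envs)
sampled = np.array(vector_space.sample())           # shape (num_envs,)

one_hot = np.concatenate(
    [
        F.one_hot(torch.tensor(act), act_dim).numpy()
        for act, act_dim in zip(sampled.reshape(len(actions_dim), -1), actions_dim)
    ],
    axis=-1,
)
print(one_hot.shape)                                # (num_envs, sum(actions_dim)) = (4, 6)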
received by the environment - if global_step < learning_starts and args.checkpoint_path is None: - real_actions = actions = np.array(env.action_space.sample()) + if global_step <= learning_starts and args.checkpoint_path is None and "minedojo" not in args.env_id: + real_actions = actions = np.array(envs.action_space.sample()) if not is_continuous: actions = np.concatenate( [ F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim)), actions_dim) + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, ) @@ -543,14 +556,28 @@ def main(): if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax() for real_act in real_actions]) - next_obs, rewards, dones, truncated, infos = env.step(real_actions.reshape(env.action_space.shape)) + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + next_obs = next_obs["rgb"] dones = np.logical_or(dones, truncated) - if (dones or truncated) and "episode" in infos: - fabric.print(f"Rank-0: global_step={global_step}, reward_env_{0}={infos['episode']['r'][0]}") - aggregator.update("Rewards/rew_avg", infos["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", infos["episode"]["l"][0]) + if "final_info" in infos: + for i, agent_final_info in enumerate(infos["final_info"]): + if agent_final_info is not None and "episode" in agent_final_info: + fabric.print( + f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" + ) + aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) + aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + + # Save the real next observation + real_next_obs = next_obs.copy() + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + if k == "rgb": + real_next_obs[idx] = v next_obs = torch.from_numpy(next_obs).view(args.num_envs, *observation_shape) actions = torch.from_numpy(actions).view(args.num_envs, -1).float() @@ -562,20 +589,25 @@ def main(): step_data["dones"] = dones step_data["actions"] = actions - step_data["observations"] = obs + step_data["observations"] = real_next_obs step_data["rewards"] = clip_rewards_fn(rewards) rb.add(step_data[None, ...]) - if dones or truncated: - obs = torch.from_numpy(env.reset(seed=args.seed)[0]).view( - args.num_envs, *observation_shape - ) # [N_envs, N_obs] - step_data["dones"] = torch.zeros(args.num_envs, 1) - step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim)) - step_data["rewards"] = torch.zeros(args.num_envs, 1) - step_data["observations"] = obs - rb.add(step_data[None, ...]) - player.init_states() + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") + reset_data["observations"] = next_obs[dones_idxes] + reset_data["dones"] = torch.zeros(reset_envs, 1) + reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)) + reset_data["rewards"] = torch.zeros(reset_envs, 1) + rb.add(reset_data[None, ...], dones_idxes) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][d] = 
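# Sketch of the gymnasium (pre-1.0) vector-env convention that the "final_info" /
# "final_observation" handling above relies on: after a sub-environment terminates, `step`
# already returns the first observation of the automatically reset episode, while the true
# terminal observation and the episode statistics arrive through `infos`. The CartPole setup
# below is only an example.
import gymnasium as gym

envs = gym.vector.SyncVectorEnv(
    [lambda: gym.wrappers.RecordEpisodeStatistics(gym.make("CartPole-v1")) for _ in range(2)]
)
obs, _ = envs.reset(seed=0)
for _ in range(300):
    obs, rewards, dones, truncated, infos = envs.step(envs.action_space.sample())
    real_next_obs = obs.copy()
    if "final_observation" in infos:
        for idx, final_obs in enumerate(infos["final_observation"]):
            if final_obs is not None:
                real_next_obs[idx] = final_obs   # the observation the episode actually ended on
    if "final_info" in infos:
        for i, info in enumerate(infos["final_info"]):
            if info is not None and "episode" in info:
                print(f"env {i}: return={info['episode']['r'][0]}, length={info['episode']['l'][0]}")
envs.close()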
torch.zeros_like(step_data["dones"][d]) + # Reset internal agent states + player.init_states(dones_idxes) step_before_training -= 1 @@ -642,7 +674,7 @@ def main(): replay_buffer=rb if args.checkpoint_buffer else None, ) - env.close() + envs.close() if fabric.is_global_zero: test(player, fabric, args) diff --git a/sheeprl/algos/dreamer_v1/utils.py b/sheeprl/algos/dreamer_v1/utils.py index 4fcdeb28..f119f596 100644 --- a/sheeprl/algos/dreamer_v1/utils.py +++ b/sheeprl/algos/dreamer_v1/utils.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union import gymnasium as gym import numpy as np @@ -9,131 +9,14 @@ from torch import Tensor, nn from torch.distributions import Distribution, Independent, Normal -from sheeprl.utils.utils import get_dummy_env - if TYPE_CHECKING: from sheeprl.algos.dreamer_v1.agent import Player from sheeprl.algos.dreamer_v1.args import DreamerV1Args +from sheeprl.algos.dreamer_v2.utils import make_env from sheeprl.envs.wrappers import ActionRepeat -def make_env( - env_id: str, - seed: int, - rank: int, - args: DreamerV1Args, - run_name: Optional[str] = None, - prefix: str = "", -) -> gym.Env: - """ - Create the callable function to createenvironment and - force the environment to return only pixels observations. - - Args: - env_id (str): the id of the environment to initialize. - seed (int): the seed to use. - rank (int): the rank of the process. - args (DreamerV1Args): the configs of the experiment. - run_name (str, optional): the name of the run. - Default to None. - prefix (str): the prefix to add to the video folder. - Default to "". - - Returns: - The callable function that initializes the environment. - """ - _env_id = env_id.lower() - if "dummy" in _env_id: - env = get_dummy_env(_env_id) - elif "dmc" in _env_id: - from sheeprl.envs.dmc import DMCWrapper - - _, domain, task = _env_id.split("_") - env = DMCWrapper( - domain, - task, - from_pixels=True, - height=64, - width=64, - frame_skip=args.action_repeat, - seed=seed, - ) - elif "minedojo" in _env_id: - from sheeprl.envs.minedojo import MineDojoWrapper - - task_id = "_".join(env_id.split("_")[1:]) - start_position = ( - { - "x": args.mine_start_position[0], - "y": args.mine_start_position[1], - "z": args.mine_start_position[2], - "pitch": args.mine_start_position[3], - "yaw": args.mine_start_position[4], - } - if args.mine_start_position is not None - else None - ) - env = MineDojoWrapper( - task_id, - height=64, - width=64, - pitch_limits=(args.mine_min_pitch, args.mine_max_pitch), - seed=args.seed, - start_position=start_position, - ) - env = ActionRepeat(env, args.action_repeat) - else: - env_spec = gym.spec(env_id).entry_point - if "mujoco" in env_spec: - try: - env = gym.make(env_id, render_mode="rgb_array", terminate_when_unhealthy=False) - except: - env = gym.make(env_id, render_mode="rgb_array") - env.frame_skip = 0 - else: - env = gym.make(env_id, render_mode="rgb_array") - if "atari" in env_spec: - if args.atari_noop_max < 0: - raise ValueError( - f"Negative value of atart_noop_max parameter ({args.atari_noop_max}), the minimum value allowed is 0" - ) - env = gym.wrappers.AtariPreprocessing( - env, - noop_max=args.atari_noop_max, - frame_skip=args.action_repeat, - screen_size=64, - grayscale_obs=args.grayscale_obs, - scale_obs=False, - terminal_on_life_loss=True, - grayscale_newaxis=True, - ) - else: - env = ActionRepeat(env, args.action_repeat) - if isinstance(env.observation_space, gym.spaces.Box) or 
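# Sketch (toy shapes, a comment instead of the real buffer call) of how the reset entry above
# is built only for the environments that finished, so the buffer can write it into exactly
# those environments' columns.
import torch
from tensordict import TensorDict

num_envs, action_dim = 4, 6
next_obs = torch.zeros(num_envs, 3, 64, 64)
dones = torch.tensor([[1.0], [0.0], [0.0], [1.0]])

dones_idxes = dones.nonzero(as_tuple=True)[0].tolist()          # -> [0, 3]
if len(dones_idxes) > 0:
    reset_data = TensorDict({}, batch_size=[len(dones_idxes)], device="cpu")
    reset_data["observations"] = next_obs[dones_idxes]
    reset_data["dones"] = torch.zeros(len(dones_idxes), 1)
    reset_data["actions"] = torch.zeros(len(dones_idxes), action_dim)
    reset_data["rewards"] = torch.zeros(len(dones_idxes), 1)
    # rb.add(reset_data[None, ...], dones_idxes)  # as in the diff: only these env columns are written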
len(env.observation_space.shape) < 3: - env = gym.wrappers.PixelObservationWrapper(env) - env = gym.wrappers.TransformObservation(env, lambda obs: obs["pixels"]) - env.observation_space = env.observation_space["pixels"] - env = gym.wrappers.ResizeObservation(env, (64, 64)) - if args.grayscale_obs: - env = gym.wrappers.GrayScaleObservation(env, keep_dim=True) - env = gym.wrappers.TransformObservation(env, lambda obs: obs.transpose(2, 0, 1)) - env.observation_space = gym.spaces.Box( - 0, 255, (env.observation_space.shape[-1], *env.observation_space.shape[:2]), np.uint8 - ) - env.action_space.seed(seed) - env.observation_space.seed(seed) - if args.max_episode_steps > 0: - env = gym.wrappers.TimeLimit(env, max_episode_steps=args.max_episode_steps // args.action_repeat) - env = gym.wrappers.RecordEpisodeStatistics(env) - if args.capture_video and rank == 0 and run_name is not None: - env = gym.experimental.wrappers.RecordVideoV0( - env, os.path.join(run_name, prefix + "_videos" if prefix else "videos"), disable_logger=True - ) - env.metadata["render_fps"] = env.frames_per_sec - return env - - def compute_stochastic_state( state_information: Tensor, event_shape: Optional[int] = 1, @@ -236,16 +119,19 @@ def test(player: "Player", fabric: Fabric, args: DreamerV1Args, test_name: str = """ env: gym.Env = make_env( args.env_id, args.seed, 0, args, fabric.logger.log_dir, "test" + (f"_{test_name}" if test_name != "" else "") - ) + )() done = False cumulative_rew = 0 - next_obs = torch.tensor(env.reset(seed=args.seed)[0], device=fabric.device).view(1, 1, *env.observation_space.shape) + next_obs = torch.tensor(env.reset(seed=args.seed)[0]["rgb"], device=fabric.device).view( + 1, 1, *env.observation_space["rgb"].shape + ) + player.num_envs = 1 player.init_states() while not done: # Act greedly through the environment action = player.get_greedy_action(next_obs / 255 - 0.5, False) if not player.actor.is_continuous: - action = np.array([act.cpu().argmax() for act in action]) + action = np.array([act.cpu().argmax(dim=-1).numpy() for act in action]) else: action = action[0].cpu().numpy() @@ -253,7 +139,7 @@ def test(player: "Player", fabric: Fabric, args: DreamerV1Args, test_name: str = next_obs, reward, done, truncated, _ = env.step(action.reshape(env.action_space.shape)) done = done or truncated or args.dry_run cumulative_rew += reward - next_obs = torch.tensor(next_obs, device=fabric.device).view(1, 1, *env.observation_space.shape) + next_obs = torch.tensor(next_obs["rgb"], device=fabric.device).view(1, 1, *env.observation_space["rgb"].shape) fabric.print("Test - Reward:", cumulative_rew) fabric.logger.log_metrics({"Test/cumulative_reward": cumulative_rew}, 0) env.close() diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index 0e705ff0..398764f9 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -595,18 +595,30 @@ def __init__( self.num_envs = num_envs self.init_states() - def init_states(self) -> None: - """ - Initialize the states and the actions for the ended environments. + def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: + """Initialize the states and the actions for the ended environments. + + Args: + reset_envs (Optional[Sequence[int]], optional): which environments' states to reset. + If None, then all environments' states are reset. + Defaults to None. 
""" - self.actions = torch.zeros(1, self.num_envs, np.sum(self.actions_dim), device=self.device) - self.stochastic_state = torch.zeros( - 1, self.num_envs, self.stochastic_size * self.discrete_size, device=self.device - ) - self.recurrent_state = torch.zeros(1, self.num_envs, self.recurrent_state_size, device=self.device) + if reset_envs is None or len(reset_envs) == 0: + self.actions = torch.zeros(1, self.num_envs, np.sum(self.actions_dim), device=self.device) + self.recurrent_state = torch.zeros(1, self.num_envs, self.recurrent_state_size, device=self.device) + self.stochastic_state = torch.zeros( + 1, self.num_envs, self.stochastic_size * self.discrete_size, device=self.device + ) + else: + self.actions[:, reset_envs] = torch.zeros_like(self.actions[:, reset_envs]) + self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs]) + self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs]) def get_exploration_action( - self, obs: Dict[str, Tensor], is_continuous: bool, mask: Optional[Dict[str, np.ndarray]] = None + self, + obs: Dict[str, Tensor], + is_continuous: bool, + mask: Optional[Dict[str, np.ndarray]] = None, ) -> Tensor: """ Return the actions with a certain amount of noise for exploration. @@ -635,7 +647,10 @@ def get_exploration_action( return tuple(expl_actions) def get_greedy_action( - self, obs: Dict[str, Tensor], is_training: bool = True, mask: Optional[Dict[str, np.ndarray]] = None + self, + obs: Dict[str, Tensor], + is_training: bool = True, + mask: Optional[Dict[str, np.ndarray]] = None, ) -> Sequence[Tensor]: """ Return the greedy actions. diff --git a/sheeprl/algos/dreamer_v2/args.py b/sheeprl/algos/dreamer_v2/args.py index 9adac713..f98523b7 100644 --- a/sheeprl/algos/dreamer_v2/args.py +++ b/sheeprl/algos/dreamer_v2/args.py @@ -37,7 +37,7 @@ class DreamerV2Args(StandardArgs): horizon: int = Arg(default=15, help="the number of imagination step") gamma: float = Arg(default=0.99, help="the discount factor gamma") lmbda: float = Arg(default=0.95, help="the lambda for the TD lambda values") - use_continues: bool = Arg(default=False, help="wheter or not to use the continue predictor") + use_continues: bool = Arg(default=True, help="wheter or not to use the continue predictor") stochastic_size: int = Arg(default=32, help="the dimension of the stochastic state") discrete_size: int = Arg(default=32, help="the dimension of the discrete state") hidden_size: int = Arg(default=200, help="the hidden size for the transition and representation model") @@ -47,9 +47,6 @@ class DreamerV2Args(StandardArgs): kl_free_avg: bool = Arg(default=True, help="whether to apply free average") kl_regularizer: float = Arg(default=1.0, help="the scale factor for the kl divergence") continue_scale_factor: float = Arg(default=1.0, help="the scale factor for the continue loss") - min_std: float = Arg( - default=0.1, help="the minimum value of the standard deviation for the stochastic state distribution" - ) actor_ent_coef: float = Arg(default=1e-4, help="the entropy coefficient for the actor loss") actor_init_std: float = Arg( default=0.0, help="the amout to sum to the input of the function of the standard deviation of the actions" diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 2711654f..093f9509 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -1,3 +1,4 @@ +import copy import os import pathlib import time @@ -16,6 +17,7 @@ from 
lightning.fabric.wrappers import _FabricModule from tensordict import TensorDict from tensordict.tensordict import TensorDictBase +from torch import Tensor from torch.distributions import Bernoulli, Distribution, Independent, Normal, OneHotCategorical from torch.optim import Adam, Optimizer from torch.utils.data import BatchSampler @@ -25,7 +27,7 @@ from sheeprl.algos.dreamer_v2.args import DreamerV2Args from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, make_env, test -from sheeprl.data.buffers import EpisodeBuffer, SequentialReplayBuffer +from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.parser import HfArgumentParser @@ -254,9 +256,8 @@ def train( imagined_trajectories[i] = imagined_latent_state # predict values and rewards - with torch.no_grad(): - predicted_target_values = Independent(Normal(target_critic(imagined_trajectories), 1), 1).mean - predicted_rewards = Independent(Normal(world_model.reward_model(imagined_trajectories), 1), 1).mean + predicted_target_values = Independent(Normal(target_critic(imagined_trajectories), 1), 1).mode + predicted_rewards = Independent(Normal(world_model.reward_model(imagined_trajectories), 1), 1).mode if args.use_continues and world_model.continue_model: continues = Independent( Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=False), 1 @@ -320,7 +321,7 @@ def train( dynamics = lambda_values[1:] # Reinforce - advantage = (lambda_values[1:] - predicted_target_values).detach() + advantage = (lambda_values[1:] - predicted_target_values[:-2]).detach() reinforce = ( torch.stack( [ @@ -372,7 +373,6 @@ def train( def main(): parser = HfArgumentParser(DreamerV2Args) args: DreamerV2Args = parser.parse_args_into_dataclasses()[0] - args.num_envs = 1 torch.set_num_threads(1) # Initialize Fabric @@ -428,28 +428,36 @@ def main(): log_dir = data[0] os.makedirs(log_dir, exist_ok=True) - env: gym.Env = make_env( - args.env_id, - args.seed + rank * args.num_envs, - rank, - args, - logger.log_dir if rank == 0 else None, - "train", + # Environment setup + vectorized_env = gym.vector.SyncVectorEnv if args.sync_env else gym.vector.AsyncVectorEnv + envs = vectorized_env( + [ + make_env( + args.env_id, + args.seed + rank * args.num_envs, + rank, + args, + logger.log_dir if rank == 0 else None, + "train", + ) + for i in range(args.num_envs) + ] ) - is_continuous = isinstance(env.action_space, gym.spaces.Box) - is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + action_space = envs.single_action_space + observation_space = envs.single_observation_space + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) actions_dim = ( - env.action_space.shape - if is_continuous - else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if args.clip_rewards else r cnn_keys = [] mlp_keys = [] - if isinstance(env.observation_space, gym.spaces.Dict): + if isinstance(observation_space, gym.spaces.Dict): cnn_keys = [] - for k, v in env.observation_space.spaces.items(): + for k, v in observation_space.spaces.items(): if args.cnn_keys and ( k in args.cnn_keys or 
(len(args.cnn_keys) == 1 and args.cnn_keys[0].lower() == "all") ): @@ -461,7 +469,7 @@ def main(): "Try to transform the observation from the environment into a 3D image" ) mlp_keys = [] - for k, v in env.observation_space.spaces.items(): + for k, v in observation_space.spaces.items(): if args.mlp_keys and ( k in args.mlp_keys or (len(args.mlp_keys) == 1 and args.mlp_keys[0].lower() == "all") ): @@ -473,18 +481,19 @@ def main(): "Try to flatten the observation from the environment" ) else: - raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {env.observation_space}") + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") if cnn_keys == [] and mlp_keys == []: raise RuntimeError(f"There must be at least one valid observation.") fabric.print("CNN keys:", cnn_keys) fabric.print("MLP keys:", mlp_keys) + obs_keys = cnn_keys + mlp_keys world_model, actor, critic, target_critic = build_models( fabric, actions_dim, is_continuous, args, - env.observation_space, + observation_space, cnn_keys, mlp_keys, state["world_model"] if args.checkpoint_path else None, @@ -540,7 +549,7 @@ def main(): "Grads/critic": MeanMetric(sync_on_compute=False), } ) - aggregator.to(fabric.device) + aggregator.to(fabric.device) # Local data buffer_size = ( @@ -548,12 +557,13 @@ def main(): ) buffer_type = args.buffer_type.lower() if buffer_type == "sequential": - rb = SequentialReplayBuffer( + rb = AsyncReplayBuffer( buffer_size, args.num_envs, device="cpu", memmap=args.memmap_buffer, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + sequential=True, ) elif buffer_type == "episode": rb = EpisodeBuffer( @@ -568,7 +578,7 @@ def main(): if args.checkpoint_path and args.checkpoint_buffer: if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], (SequentialReplayBuffer, EpisodeBuffer)): + elif isinstance(state["rb"], (AsyncReplayBuffer, EpisodeBuffer)): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") @@ -578,11 +588,12 @@ def main(): # Global variables start_time = time.perf_counter() start_step = state["global_step"] // fabric.world_size if args.checkpoint_path else 1 - step_before_training = args.train_every // (fabric.world_size * args.action_repeat) if not args.dry_run else 0 - num_updates = int(args.total_steps // (fabric.world_size * args.action_repeat)) if not args.dry_run else 1 - learning_starts = args.learning_starts // (fabric.world_size * args.action_repeat) if not args.dry_run else 0 + single_global_step = int(args.num_envs * fabric.world_size * args.action_repeat) + step_before_training = args.train_every // single_global_step if not args.dry_run else 0 + num_updates = args.total_steps // single_global_step if not args.dry_run else 1 + learning_starts = args.learning_starts // single_global_step if not args.dry_run else 0 if args.checkpoint_path and not args.checkpoint_buffer: - learning_starts = start_step + args.learning_starts // int((fabric.world_size * args.action_repeat)) + learning_starts += start_step max_step_expl_decay = args.max_step_expl_decay // (args.gradient_steps * fabric.world_size) if args.checkpoint_path: player.expl_amount = polynomial_decay( @@ -593,13 +604,17 @@ def main(): ) # Get the first environment observation and start the optimization - episode_steps = [] - o = env.reset(seed=args.seed)[0] + episode_steps = [[] for _ in 
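# Sketch of the step bookkeeping introduced above (the hyper-parameter values are made-up
# examples): one call to envs.step advances num_envs * world_size * action_repeat environment
# frames, so frame-based settings are converted into numbers of policy steps by dividing by
# that quantity.
num_envs, world_size, action_repeat = 4, 1, 2
total_steps, train_every, learning_starts_frames = 1_000_000, 1_000, 65_536

single_global_step = num_envs * world_size * action_repeat       # frames per policy step
num_updates = total_steps // single_global_step                  # policy steps to run
step_before_training = train_every // single_global_step         # policy steps between trainings
learning_starts = learning_starts_frames // single_global_step   # warm-up policy steps
print(num_updates, step_before_training, learning_starts)        # 125000 125 8192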
range(args.num_envs)] + o = envs.reset(seed=args.seed)[0] obs = {} for k in o.keys(): - torch_obs = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape).float() - step_data[k] = torch_obs - obs[k] = torch_obs + if k in obs_keys: + torch_obs = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape[1:]) + if k in mlp_keys: + # Images stay uint8 to save space + torch_obs = torch_obs.float() + step_data[k] = torch_obs + obs[k] = torch_obs step_data["dones"] = torch.zeros(args.num_envs, 1) step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim)) step_data["rewards"] = torch.zeros(args.num_envs, 1) @@ -607,19 +622,20 @@ def main(): if buffer_type == "sequential": rb.add(step_data[None, ...]) else: - episode_steps.append(step_data[None, ...]) + for i, env_ep in enumerate(episode_steps): + env_ep.append(step_data[i : i + 1][None, ...]) player.init_states() gradient_steps = 0 for global_step in range(start_step, num_updates + 1): # Sample an action given the observation received by the environment if global_step <= learning_starts and args.checkpoint_path is None and "minedojo" not in args.env_id: - real_actions = actions = np.array(env.action_space.sample()) + real_actions = actions = np.array(envs.action_space.sample()) if not is_continuous: actions = np.concatenate( [ F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim)), actions_dim) + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, ) @@ -639,61 +655,78 @@ def main(): if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax() for real_act in real_actions]) + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - o, rewards, dones, truncated, infos = env.step(real_actions.reshape(env.action_space.shape)) + step_data["is_first"] = copy.deepcopy(step_data["dones"]) + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) dones = np.logical_or(dones, truncated) if args.dry_run and buffer_type == "episode": dones = np.ones_like(dones) - if (dones or truncated) and "episode" in infos: - fabric.print(f"Rank-0: global_step={global_step}, reward_env_{0}={infos['episode']['r'][0]}") - aggregator.update("Rewards/rew_avg", infos["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", infos["episode"]["l"][0]) - - next_obs = {} - for k in o.keys(): # [N_envs, N_obs] - torch_obs = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape).float() - step_data[k] = torch_obs - next_obs[k] = torch_obs + if "final_info" in infos: + for i, agent_final_info in enumerate(infos["final_info"]): + if agent_final_info is not None and "episode" in agent_final_info: + fabric.print( + f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" + ) + aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) + aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + + # Save the real next observation + real_next_obs = copy.deepcopy(o) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + + next_obs: Dict[str, Tensor] = {} + for k in real_next_obs.keys(): # [N_envs, N_obs] + if k in obs_keys: + next_obs[k] = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape[1:]) + step_data[k] = 
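# Sketch of the storage convention noted above (example key names and shapes): image
# observations stay uint8 in the buffer to save memory and are only cast and normalized right
# before the forward pass, while low-dimensional "mlp" observations are stored as float.
import torch

cnn_keys, mlp_keys = ["rgb"], ["state"]
o = {"rgb": torch.randint(0, 256, (4, 3, 64, 64), dtype=torch.uint8), "state": torch.randn(4, 7)}

stored = {k: (v.float() if k in mlp_keys else v) for k, v in o.items()}          # uint8 images kept
preprocessed = {k: (v.float() / 255 - 0.5 if k in cnn_keys else v) for k, v in stored.items()}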
torch.from_numpy(real_next_obs[k]).view(args.num_envs, *real_next_obs[k].shape[1:]) + if k in mlp_keys: + next_obs[k] = next_obs[k].float() + step_data[k] = step_data[k].float() actions = torch.from_numpy(actions).view(args.num_envs, -1).float() - rewards = torch.tensor([rewards]).view(args.num_envs, -1).float() - dones = torch.tensor([bool(dones)]).view(args.num_envs, -1).float() + rewards = torch.from_numpy(rewards).view(args.num_envs, -1).float() + dones = torch.from_numpy(dones).view(args.num_envs, -1).float() # next_obs becomes the new obs obs = next_obs step_data["dones"] = dones - step_data["is_first"] = torch.zeros_like(step_data["dones"]) step_data["actions"] = actions step_data["rewards"] = clip_rewards_fn(rewards) - data_to_add = step_data[None, ...] if buffer_type == "sequential": - rb.add(data_to_add) + rb.add(step_data[None, ...]) else: - episode_steps.append(data_to_add) - - if dones or truncated: - # Add entire episode if needed - if buffer_type == "episode" and len(episode_steps) >= args.per_rank_sequence_length: - rb.add(torch.cat(episode_steps, dim=0)) - episode_steps = [] - o = env.reset(seed=args.seed)[0] - obs = {} - for k in o.keys(): - torch_obs = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape).float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(args.num_envs, 1) - step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim)) - step_data["rewards"] = torch.zeros(args.num_envs, 1) - step_data["is_first"] = torch.ones_like(step_data["dones"]) - data_to_add = step_data[None, ...] - if buffer_type == "sequential": - rb.add(data_to_add) + for i, env_ep in enumerate(episode_steps): + env_ep.append(step_data[i : i + 1][None, ...]) + + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") + for k in next_obs.keys(): + reset_data[k] = next_obs[k][dones_idxes] + reset_data["dones"] = torch.zeros(reset_envs, 1) + reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)) + reset_data["rewards"] = torch.zeros(reset_envs, 1) + reset_data["is_first"] = torch.ones_like(reset_data["dones"]) + if buffer_type == "episode": + for i, d in enumerate(dones_idxes): + if len(episode_steps[d]) >= args.per_rank_sequence_length: + rb.add(torch.cat(episode_steps[d], dim=0)) + episode_steps[d] = [reset_data[i : i + 1][None, ...]] else: - episode_steps.append(data_to_add) - player.init_states() + rb.add(reset_data[None, ...], dones_idxes) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][d] = torch.zeros_like(step_data["dones"][d]) + # Reset internal agent states + player.init_states(dones_idxes) step_before_training -= 1 @@ -735,7 +768,7 @@ def main(): actions_dim, ) gradient_steps += 1 - step_before_training = args.train_every // (args.num_envs * (fabric.world_size * args.action_repeat)) + step_before_training = args.train_every // single_global_step if args.expl_decay: expl_decay_steps += 1 player.expl_amount = polynomial_decay( @@ -776,7 +809,7 @@ def main(): replay_buffer=rb if args.checkpoint_buffer else None, ) - env.close() + envs.close() if fabric.is_global_zero: test(player, fabric, args, cnn_keys, mlp_keys) diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index f0614b21..f431a24c 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ 
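# Sketch (plain lists, toy data) of the per-environment episode bookkeeping used with the
# episode buffer above: every environment accumulates its own steps, and when it finishes an
# episode that is at least the required sequence length the whole episode is flushed as one
# contiguous block; shorter episodes are dropped. min_sequence_length is an example value.
min_sequence_length = 3
num_envs = 2
episode_steps = [[] for _ in range(num_envs)]    # one in-progress episode per environment
buffer = []                                      # stand-in for the episode replay buffer

def on_env_step(step, dones):
    for env_id, done in enumerate(dones):
        episode_steps[env_id].append(step)
        if done:
            if len(episode_steps[env_id]) >= min_sequence_length:
                buffer.append(list(episode_steps[env_id]))   # flush the finished episode
            episode_steps[env_id] = []                       # start this env's next episode

for t, dones in enumerate([(0, 0), (0, 1), (1, 0), (0, 0), (0, 1)]):
    on_env_step(t, dones)
print(len(buffer))   # 2 episodes flushed; env 1's first, 2-step episode was dropped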
b/sheeprl/algos/dreamer_v2/utils.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import cv2 import gymnasium as gym @@ -26,7 +26,7 @@ def make_env( args: DreamerV2Args, run_name: Optional[str] = None, prefix: str = "", -) -> gym.Env: +) -> Callable[..., gym.Env]: """ Create the callable function to createenvironment and force the environment to return only pixels observations. @@ -44,144 +44,148 @@ def make_env( Returns: The callable function that initializes the environment. """ - env_spec = "" - _env_id = env_id.lower() - if "dummy" in _env_id: - env = get_dummy_env(_env_id) - elif "dmc" in _env_id: - from sheeprl.envs.dmc import DMCWrapper - - _, domain, task = _env_id.split("_") - env = DMCWrapper( - domain, - task, - from_pixels=True, - height=64, - width=64, - frame_skip=args.action_repeat, - seed=seed, - ) - elif "minedojo" in _env_id: - from sheeprl.envs.minedojo import MineDojoWrapper - - task_id = "_".join(env_id.split("_")[1:]) - start_position = ( - { - "x": float(args.mine_start_position[0]), - "y": float(args.mine_start_position[1]), - "z": float(args.mine_start_position[2]), - "pitch": float(args.mine_start_position[3]), - "yaw": float(args.mine_start_position[4]), - } - if args.mine_start_position is not None - else None - ) - env = MineDojoWrapper( - task_id, - height=64, - width=64, - pitch_limits=(args.mine_min_pitch, args.mine_max_pitch), - seed=args.seed, - start_position=start_position, - ) - args.action_repeat = 1 - elif "minerl" in _env_id: - from sheeprl.envs.minerl import MineRLWrapper - - task_id = "_".join(env_id.split("_")[1:]) - env = MineRLWrapper( - task_id, - height=64, - width=64, - pitch_limits=(args.mine_min_pitch, args.mine_max_pitch), - seed=args.seed, - break_speed_multiplier=args.mine_break_speed, - sticky_attack=args.mine_sticky_attack, - sticky_jump=args.mine_sticky_jump, - dense=args.minerl_dense, - extreme=args.minerl_extreme, - ) - else: - env_spec = gym.spec(env_id).entry_point - env = gym.make(env_id, render_mode="rgb_array") - if "mujoco" in env_spec: - env.frame_skip = 0 - elif "atari" in env_spec: - if args.atari_noop_max < 0: - raise ValueError( - f"Negative value of atart_noop_max parameter ({args.atari_noop_max}), the minimum value allowed is 0" - ) - env = gym.wrappers.AtariPreprocessing( - env, - noop_max=args.atari_noop_max, + + def thunk(): + env_spec = "" + _env_id = env_id.lower() + if "dummy" in _env_id: + env = get_dummy_env(_env_id) + elif "dmc" in _env_id: + from sheeprl.envs.dmc import DMCWrapper + + _, domain, task = _env_id.split("_") + env = DMCWrapper( + domain, + task, + from_pixels=True, + height=64, + width=64, frame_skip=args.action_repeat, - screen_size=64, - grayscale_obs=args.grayscale_obs, - scale_obs=False, - terminal_on_life_loss=False, - grayscale_newaxis=True, + seed=seed, + ) + elif "minedojo" in _env_id: + from sheeprl.envs.minedojo import MineDojoWrapper + + task_id = "_".join(env_id.split("_")[1:]) + start_position = ( + { + "x": float(args.mine_start_position[0]), + "y": float(args.mine_start_position[1]), + "z": float(args.mine_start_position[2]), + "pitch": float(args.mine_start_position[3]), + "yaw": float(args.mine_start_position[4]), + } + if args.mine_start_position is not None + else None + ) + env = MineDojoWrapper( + task_id, + height=64, + width=64, + pitch_limits=(args.mine_min_pitch, args.mine_max_pitch), + seed=args.seed, + start_position=start_position, ) + args.action_repeat = 1 + 
elif "minerl" in _env_id: + from sheeprl.envs.minerl import MineRLWrapper + + task_id = "_".join(env_id.split("_")[1:]) + env = MineRLWrapper( + task_id, + height=64, + width=64, + pitch_limits=(args.mine_min_pitch, args.mine_max_pitch), + seed=args.seed, + break_speed_multiplier=args.mine_break_speed, + sticky_attack=args.mine_sticky_attack, + sticky_jump=args.mine_sticky_jump, + dense=args.minerl_dense, + extreme=args.minerl_extreme, + ) + else: + env_spec = gym.spec(env_id).entry_point + env = gym.make(env_id, render_mode="rgb_array") + if "mujoco" in env_spec: + env.frame_skip = 0 + elif "atari" in env_spec: + if args.atari_noop_max < 0: + raise ValueError( + f"Negative value of atart_noop_max parameter ({args.atari_noop_max}), the minimum value allowed is 0" + ) + env = gym.wrappers.AtariPreprocessing( + env, + noop_max=args.atari_noop_max, + frame_skip=args.action_repeat, + screen_size=64, + grayscale_obs=args.grayscale_obs, + scale_obs=False, + terminal_on_life_loss=False, + grayscale_newaxis=True, + ) - # action repeat - if "atari" not in env_spec and "dmc" not in env_id: - env = ActionRepeat(env, args.action_repeat) + # action repeat + if "atari" not in env_spec and "dmc" not in env_id: + env = ActionRepeat(env, args.action_repeat) - # create dict - if isinstance(env.observation_space, gym.spaces.Box) and len(env.observation_space.shape) < 3: - env = gym.wrappers.PixelObservationWrapper( - env, pixels_only=len(env.observation_space.shape) == 2, pixel_keys=("rgb",) - ) - elif isinstance(env.observation_space, gym.spaces.Box) and len(env.observation_space.shape) == 3: - env = gym.wrappers.TransformObservation(env, lambda obs: {"rgb": obs}) - env.observation_space = gym.spaces.Dict({"rgb": env.observation_space}) - - shape = env.observation_space["rgb"].shape - is_3d = len(shape) == 3 - is_grayscale = not is_3d or shape[0] == 1 or shape[-1] == 1 - channel_first = not is_3d or shape[0] in (1, 3) - - def transform_obs(obs: Dict[str, Any]): - # to 3D image - if not is_3d: - obs.update({"rgb": np.expand_dims(obs["rgb"], axis=0)}) - - # channel last (opencv needs it) - if channel_first: - obs.update({"rgb": obs["rgb"].transpose(1, 2, 0)}) - - # resize - if obs["rgb"].shape[:-1] != (64, 64): - obs.update({"rgb": cv2.resize(obs["rgb"], (64, 64), interpolation=cv2.INTER_AREA)}) - - # to grayscale - if args.grayscale_obs and not is_grayscale: - obs.update({"rgb": cv2.cvtColor(obs["rgb"], cv2.COLOR_RGB2GRAY)}) - - # back to 3D - if len(obs["rgb"].shape) == 2: - obs.update({"rgb": np.expand_dims(obs["rgb"], axis=-1)}) - if not args.grayscale_obs: - obs.update({"rgb": np.repeat(obs["rgb"], 3, axis=-1)}) - - # channel first (PyTorch default) - obs.update({"rgb": obs["rgb"].transpose(2, 0, 1)}) - - return obs - - env = gym.wrappers.TransformObservation(env, transform_obs) - env.observation_space["rgb"] = gym.spaces.Box(0, 255, (1 if args.grayscale_obs else 3, 64, 64), np.uint8) - - env.action_space.seed(seed) - env.observation_space.seed(seed) - if args.max_episode_steps > 0: - env = gym.wrappers.TimeLimit(env, max_episode_steps=args.max_episode_steps // args.action_repeat) - env = gym.wrappers.RecordEpisodeStatistics(env) - if args.capture_video and rank == 0 and run_name is not None: - env = gym.experimental.wrappers.RecordVideoV0( - env, os.path.join(run_name, prefix + "_videos" if prefix else "videos"), disable_logger=True - ) - env.metadata["render_fps"] = env.frames_per_sec - return env + # create dict + if isinstance(env.observation_space, gym.spaces.Box) and 
len(env.observation_space.shape) < 3: + env = gym.wrappers.PixelObservationWrapper( + env, pixels_only=len(env.observation_space.shape) == 2, pixel_keys=("rgb",) + ) + elif isinstance(env.observation_space, gym.spaces.Box) and len(env.observation_space.shape) == 3: + env = gym.wrappers.TransformObservation(env, lambda obs: {"rgb": obs}) + env.observation_space = gym.spaces.Dict({"rgb": env.observation_space}) + + shape = env.observation_space["rgb"].shape + is_3d = len(shape) == 3 + is_grayscale = not is_3d or shape[0] == 1 or shape[-1] == 1 + channel_first = not is_3d or shape[0] in (1, 3) + + def transform_obs(obs: Dict[str, Any]): + # to 3D image + if not is_3d: + obs.update({"rgb": np.expand_dims(obs["rgb"], axis=0)}) + + # channel last (opencv needs it) + if channel_first: + obs.update({"rgb": obs["rgb"].transpose(1, 2, 0)}) + + # resize + if obs["rgb"].shape[:-1] != (64, 64): + obs.update({"rgb": cv2.resize(obs["rgb"], (64, 64), interpolation=cv2.INTER_AREA)}) + + # to grayscale + if args.grayscale_obs and not is_grayscale: + obs.update({"rgb": cv2.cvtColor(obs["rgb"], cv2.COLOR_RGB2GRAY)}) + + # back to 3D + if len(obs["rgb"].shape) == 2: + obs.update({"rgb": np.expand_dims(obs["rgb"], axis=-1)}) + if not args.grayscale_obs: + obs.update({"rgb": np.repeat(obs["rgb"], 3, axis=-1)}) + + # channel first (PyTorch default) + obs.update({"rgb": obs["rgb"].transpose(2, 0, 1)}) + + return obs + + env = gym.wrappers.TransformObservation(env, transform_obs) + env.observation_space["rgb"] = gym.spaces.Box(0, 255, (1 if args.grayscale_obs else 3, 64, 64), np.uint8) + + env.action_space.seed(seed) + env.observation_space.seed(seed) + if args.max_episode_steps > 0: + env = gym.wrappers.TimeLimit(env, max_episode_steps=args.max_episode_steps // args.action_repeat) + env = gym.wrappers.RecordEpisodeStatistics(env) + if args.capture_video and rank == 0 and run_name is not None: + env = gym.experimental.wrappers.RecordVideoV0( + env, os.path.join(run_name, prefix + "_videos" if prefix else "videos"), disable_logger=True + ) + env.metadata["render_fps"] = env.frames_per_sec + return env + + return thunk def compute_stochastic_state( @@ -251,15 +255,16 @@ def test( player (Player): the agent which contains all the models needed to play. fabric (Fabric): the fabric instance. 
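# Sketch of the factory ("thunk") pattern make_env now follows so it can feed gymnasium's
# vector envs: nothing is built up front, a zero-argument callable is returned, and the vector
# env (or the test code, via make_env(...)() ) constructs the wrapped environment when it
# calls it. The env id and wrappers below are examples only.
import gymnasium as gym

def make_env(env_id: str, seed: int):
    def thunk() -> gym.Env:
        env = gym.make(env_id, render_mode="rgb_array")
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk

envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", seed=5 + i) for i in range(4)])
single_env = make_env("CartPole-v1", seed=0)()   # the test() path calls the thunk directly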
""" - env: gym.Env = make_env( + env = make_env( args.env_id, args.seed, 0, args, fabric.logger.log_dir, "test" + (f"_{test_name}" if test_name != "" else "") - ) + )() done = False cumulative_rew = 0 device = fabric.device next_obs = env.reset(seed=args.seed)[0] for k in next_obs.keys(): - next_obs[k] = torch.from_numpy(next_obs[k]).view(args.num_envs, *next_obs[k].shape).float() + next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float() + player.num_envs = 1 player.init_states() while not done: # Act greedly through the environment @@ -275,12 +280,12 @@ def test( if player.actor.is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax() for real_act in real_actions]) + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) # Single environment step next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) for k in next_obs.keys(): - next_obs[k] = torch.from_numpy(next_obs[k]).view(args.num_envs, *next_obs[k].shape).float() + next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float() done = done or truncated or args.dry_run cumulative_rew += reward fabric.print("Test - Reward:", cumulative_rew) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index 3c56d85b..7d91d661 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -24,10 +24,11 @@ from sheeprl.algos.dreamer_v1.agent import Player, WorldModel from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss -from sheeprl.algos.dreamer_v1.utils import cnn_forward, make_env, test +from sheeprl.algos.dreamer_v1.utils import cnn_forward, test +from sheeprl.algos.dreamer_v2.utils import make_env from sheeprl.algos.p2e_dv1.agent import build_models from sheeprl.algos.p2e_dv1.args import P2EDV1Args -from sheeprl.data.buffers import SequentialReplayBuffer +from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.models.models import MLP from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.metric import MetricAggregator @@ -349,7 +350,6 @@ def train( def main(): parser = HfArgumentParser(P2EDV1Args) args: P2EDV1Args = parser.parse_args_into_dataclasses()[0] - args.num_envs = 1 torch.set_num_threads(1) # Initialize Fabric @@ -405,23 +405,31 @@ def main(): log_dir = data[0] os.makedirs(log_dir, exist_ok=True) - env: gym.Env = make_env( - args.env_id, - args.seed + rank * args.num_envs, - rank, - args, - logger.log_dir if rank == 0 else None, - "train", + # Environment setup + vectorized_env = gym.vector.SyncVectorEnv if args.sync_env else gym.vector.AsyncVectorEnv + envs = vectorized_env( + [ + make_env( + args.env_id, + args.seed + rank * args.num_envs, + rank, + args, + logger.log_dir if rank == 0 else None, + "train", + ) + for i in range(args.num_envs) + ] ) - is_continuous = isinstance(env.action_space, gym.spaces.Box) - is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + action_space = envs.single_action_space + observation_space = envs.single_observation_space + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) actions_dim = ( - env.action_space.shape - if is_continuous - else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + action_space.shape if is_continuous else (action_space.nvec.tolist() if 
is_multidiscrete else [action_space.n]) ) - observation_shape = env.observation_space.shape + observation_shape = observation_space["rgb"].shape clip_rewards_fn = lambda r: torch.tanh(r) if args.clip_rewards else r world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_models( @@ -528,22 +536,24 @@ def main(): "Grads/ensemble": MeanMetric(sync_on_compute=False), } ) + aggregator.to(device) # Local data buffer_size = ( args.buffer_size // int(args.num_envs * fabric.world_size * args.action_repeat) if not args.dry_run else 4 ) - rb = SequentialReplayBuffer( + rb = AsyncReplayBuffer( buffer_size, args.num_envs, device="cpu", memmap=args.memmap_buffer, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + sequential=True, ) if args.checkpoint_path and args.checkpoint_buffer: if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], SequentialReplayBuffer): + elif isinstance(state["rb"], AsyncReplayBuffer): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") @@ -553,13 +563,14 @@ def main(): # Global variables start_time = time.perf_counter() start_step = state["global_step"] // fabric.world_size if args.checkpoint_path else 1 - step_before_training = args.train_every // (fabric.world_size * args.action_repeat) if not args.dry_run else 0 - num_updates = int(args.total_steps // (fabric.world_size * args.action_repeat)) if not args.dry_run else 4 + single_global_step = int(args.num_envs * fabric.world_size * args.action_repeat) + step_before_training = args.train_every // single_global_step if not args.dry_run else 0 + num_updates = int(args.total_steps // single_global_step) if not args.dry_run else 1 + learning_starts = (args.learning_starts // single_global_step) if not args.dry_run else 0 exploration_updates = ( int(args.exploration_steps // (fabric.world_size * args.action_repeat)) if not args.dry_run else 4 ) exploration_updates = min(num_updates, exploration_updates) - learning_starts = (args.learning_starts // (fabric.world_size * args.action_repeat)) if not args.dry_run else 3 if args.checkpoint_path and not args.checkpoint_buffer: learning_starts = start_step + args.learning_starts // int(fabric.world_size * args.action_repeat) max_step_expl_decay = args.max_step_expl_decay // (args.gradient_steps * fabric.world_size) @@ -572,7 +583,7 @@ def main(): ) # Get the first environment observation and start the optimization - obs = torch.from_numpy(env.reset(seed=args.seed)[0].copy()).view( + obs = torch.from_numpy(envs.reset(seed=args.seed)[0]["rgb"]).view( args.num_envs, *observation_shape ) # [N_envs, N_obs] step_data["dones"] = torch.zeros(args.num_envs, 1) @@ -592,13 +603,13 @@ def main(): test(player, fabric, args, "zero-shot") # Sample an action given the observation received by the environment - if global_step <= learning_starts and args.checkpoint_path is None: - real_actions = actions = np.array(env.action_space.sample()) + if global_step <= learning_starts and args.checkpoint_path is None and "minedojo" not in args.env_id: + real_actions = actions = np.array(envs.action_space.sample()) if not is_continuous: actions = np.concatenate( [ F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim)), actions_dim) + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, ) @@ -611,16 +622,31 @@ def 
main(): if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax() for real_act in real_actions]) - next_obs, rewards, dones, truncated, infos = env.step(real_actions.reshape(env.action_space.shape)) - dones = np.logical_or(dones, truncated) + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - if (dones or truncated) and "episode" in infos: - fabric.print(f"Rank-0: global_step={global_step}, reward_env_{0}={infos['episode']['r'][0]}") - aggregator.update("Rewards/rew_avg", infos["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", infos["episode"]["l"][0]) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + next_obs = next_obs["rgb"] + dones = np.logical_or(dones, truncated) - next_obs = torch.from_numpy(next_obs.copy()).view(args.num_envs, *observation_shape) + if "final_info" in infos: + for i, agent_final_info in enumerate(infos["final_info"]): + if agent_final_info is not None and "episode" in agent_final_info: + fabric.print( + f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" + ) + aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) + aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + + # Save the real next observation + real_next_obs = next_obs.copy() + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + if k == "rgb": + real_next_obs[idx] = v + + next_obs = torch.from_numpy(next_obs).view(args.num_envs, *observation_shape) actions = torch.from_numpy(actions).view(args.num_envs, -1).float() rewards = torch.tensor([rewards]).view(args.num_envs, -1).float() dones = torch.tensor([bool(dones)]).view(args.num_envs, -1).float() @@ -634,16 +660,21 @@ def main(): step_data["rewards"] = clip_rewards_fn(rewards) rb.add(step_data[None, ...]) - if dones or truncated: - obs = torch.from_numpy(env.reset(seed=args.seed)[0].copy()).view( - args.num_envs, *observation_shape - ) # [N_envs, N_obs] - step_data["dones"] = torch.zeros(args.num_envs, 1) - step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim)) - step_data["rewards"] = torch.zeros(args.num_envs, 1) - step_data["observations"] = obs - rb.add(step_data[None, ...]) - player.init_states() + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") + reset_data["observations"] = next_obs[dones_idxes] + reset_data["dones"] = torch.zeros(reset_envs, 1) + reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)) + reset_data["rewards"] = torch.zeros(reset_envs, 1) + rb.add(reset_data[None, ...], dones_idxes) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][d] = torch.zeros_like(step_data["dones"][d]) + # Reset internal agent states + player.init_states(dones_idxes) step_before_training -= 1 @@ -723,7 +754,7 @@ def main(): replay_buffer=rb if args.checkpoint_buffer else None, ) - env.close() + envs.close() # task test few-shot if fabric.is_global_zero: player.actor = actor_task.module diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index dd7e3d42..08ee42c5 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py 
+++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -18,7 +18,7 @@ from lightning.pytorch.utilities.seed import isolate_rng from tensordict import TensorDict from tensordict.tensordict import TensorDictBase -from torch import nn +from torch import Tensor, nn from torch.distributions import Bernoulli, Distribution, Independent, Normal, OneHotCategorical from torch.optim import Adam from torch.utils.data import BatchSampler @@ -29,7 +29,7 @@ from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, init_weights, make_env, test from sheeprl.algos.p2e_dv2.agent import build_models from sheeprl.algos.p2e_dv2.args import P2EDV2Args -from sheeprl.data.buffers import EpisodeBuffer, SequentialReplayBuffer +from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer from sheeprl.models.models import MLP from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.metric import MetricAggregator @@ -453,7 +453,6 @@ def train( def main(): parser = HfArgumentParser(P2EDV2Args) args: P2EDV2Args = parser.parse_args_into_dataclasses()[0] - args.num_envs = 1 torch.set_num_threads(1) # Initialize Fabric @@ -509,28 +508,36 @@ def main(): log_dir = data[0] os.makedirs(log_dir, exist_ok=True) - env: gym.Env = make_env( - args.env_id, - args.seed + rank * args.num_envs, - rank, - args, - logger.log_dir if rank == 0 else None, - "train", + # Environment setup + vectorized_env = gym.vector.SyncVectorEnv if args.sync_env else gym.vector.AsyncVectorEnv + envs = vectorized_env( + [ + make_env( + args.env_id, + args.seed + rank * args.num_envs, + rank, + args, + logger.log_dir if rank == 0 else None, + "train", + ) + for i in range(args.num_envs) + ] ) - is_continuous = isinstance(env.action_space, gym.spaces.Box) - is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + action_space = envs.single_action_space + observation_space = envs.single_observation_space + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) actions_dim = ( - env.action_space.shape - if is_continuous - else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if args.clip_rewards else r cnn_keys = [] mlp_keys = [] - if isinstance(env.observation_space, gym.spaces.Dict): + if isinstance(observation_space, gym.spaces.Dict): cnn_keys = [] - for k, v in env.observation_space.spaces.items(): + for k, v in observation_space.spaces.items(): if args.cnn_keys and ( k in args.cnn_keys or (len(args.cnn_keys) == 1 and args.cnn_keys[0].lower() == "all") ): @@ -542,7 +549,7 @@ def main(): "Try to transform the observation from the environment into a 3D image" ) mlp_keys = [] - for k, v in env.observation_space.spaces.items(): + for k, v in observation_space.spaces.items(): if args.mlp_keys and ( k in args.mlp_keys or (len(args.mlp_keys) == 1 and args.mlp_keys[0].lower() == "all") ): @@ -554,11 +561,12 @@ def main(): "Try to flatten the observation from the environment" ) else: - raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {env.observation_space}") + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") if cnn_keys == [] and mlp_keys == []: raise RuntimeError(f"There must be at least one valid observation.") fabric.print("CNN keys:", cnn_keys) fabric.print("MLP keys:", mlp_keys) + 
obs_keys = cnn_keys + mlp_keys ( world_model, @@ -573,7 +581,7 @@ def main(): actions_dim, is_continuous, args, - env.observation_space, + observation_space, cnn_keys, mlp_keys, state["world_model"] if args.checkpoint_path else None, @@ -691,12 +699,13 @@ def main(): ) buffer_type = args.buffer_type.lower() if buffer_type == "sequential": - rb = SequentialReplayBuffer( + rb = AsyncReplayBuffer( buffer_size, args.num_envs, device="cpu", memmap=args.memmap_buffer, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + sequential=True, ) elif buffer_type == "episode": rb = EpisodeBuffer( @@ -711,7 +720,7 @@ def main(): if args.checkpoint_path and args.checkpoint_buffer: if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], SequentialReplayBuffer): + elif isinstance(state["rb"], AsyncReplayBuffer): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") @@ -721,15 +730,12 @@ def main(): # Global variables start_time = time.perf_counter() start_step = state["global_step"] // fabric.world_size if args.checkpoint_path else 1 - step_before_training = args.train_every // (fabric.world_size * args.action_repeat) if not args.dry_run else 0 - num_updates = int(args.total_steps // (fabric.world_size * args.action_repeat)) if not args.dry_run else 4 - exploration_updates = ( - int(args.exploration_steps // (fabric.world_size * args.action_repeat)) if not args.dry_run else 4 - ) - exploration_updates = min(num_updates, exploration_updates) - learning_starts = (args.learning_starts // (fabric.world_size * args.action_repeat)) if not args.dry_run else 3 + single_global_step = int(args.num_envs * fabric.world_size * args.action_repeat) + step_before_training = args.train_every // single_global_step if not args.dry_run else 0 + num_updates = args.total_steps // single_global_step if not args.dry_run else 1 + learning_starts = args.learning_starts // single_global_step if not args.dry_run else 0 if args.checkpoint_path and not args.checkpoint_buffer: - learning_starts = start_step + args.learning_starts // int(fabric.world_size * args.action_repeat) + learning_starts += start_step max_step_expl_decay = args.max_step_expl_decay // (args.gradient_steps * fabric.world_size) if args.checkpoint_path: player.expl_amount = polynomial_decay( @@ -739,22 +745,33 @@ def main(): max_decay_steps=max_step_expl_decay, ) + # Exploration + exploration_updates = ( + int(args.exploration_steps // (fabric.world_size * args.action_repeat)) if not args.dry_run else 4 + ) + exploration_updates = min(num_updates, exploration_updates) + # Get the first environment observation and start the optimization - episode_steps = [] - o, infos = env.reset(seed=args.seed) + episode_steps = [[] for _ in range(args.num_envs)] + o = envs.reset(seed=args.seed)[0] obs = {} for k in o.keys(): - torch_obs = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape).float() - step_data[k] = torch_obs - obs[k] = torch_obs + if k in obs_keys: + torch_obs = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape[1:]) + if k in mlp_keys: + # Images stay uint8 to save space + torch_obs = torch_obs.float() + step_data[k] = torch_obs + obs[k] = torch_obs step_data["dones"] = torch.zeros(args.num_envs, 1) step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim)) step_data["rewards"] = torch.zeros(args.num_envs, 1) - step_data["is_first"] = 
copy.deepcopy(step_data["dones"]) + step_data["is_first"] = torch.ones_like(step_data["dones"]) if buffer_type == "sequential": rb.add(step_data[None, ...]) else: - episode_steps.append(step_data[None, ...]) + for i, env_ep in enumerate(episode_steps): + env_ep.append(step_data[i : i + 1][None, ...]) player.init_states() gradient_steps = 0 @@ -769,12 +786,12 @@ def main(): # Sample an action given the observation received by the environment if global_step <= learning_starts and args.checkpoint_path is None and "minedojo" not in args.env_id: - real_actions = actions = np.array(env.action_space.sample()) + real_actions = actions = np.array(envs.action_space.sample()) if not is_continuous: actions = np.concatenate( [ F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim)), actions_dim) + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, ) @@ -786,68 +803,86 @@ def main(): preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 else: preprocessed_obs[k] = v[None, ...].to(device) - real_actions = actions = player.get_exploration_action( - preprocessed_obs, is_continuous, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - ) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax() for real_act in real_actions]) + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) step_data["is_first"] = copy.deepcopy(step_data["dones"]) - o, rewards, dones, truncated, infos = env.step(real_actions.reshape(env.action_space.shape)) + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) dones = np.logical_or(dones, truncated) - - if (dones or truncated) and "episode" in infos: - fabric.print(f"Rank-0: global_step={global_step}, reward_env_{0}={infos['episode']['r'][0]}") - aggregator.update("Rewards/rew_avg", infos["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", infos["episode"]["l"][0]) - - next_obs = {} - for k in o.keys(): # [N_envs, N_obs] - torch_obs = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape).float() - step_data[k] = torch_obs - next_obs[k] = torch_obs - actions = torch.from_numpy(actions).view(args.num_envs, -1).float() - rewards = torch.tensor([rewards]).view(args.num_envs, -1).float() - dones = torch.tensor([bool(dones)]).view(args.num_envs, -1).float() if args.dry_run and buffer_type == "episode": dones = np.ones_like(dones) + if "final_info" in infos: + for i, agent_final_info in enumerate(infos["final_info"]): + if agent_final_info is not None and "episode" in agent_final_info: + fabric.print( + f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" + ) + aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) + aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + + # Save the real next observation + real_next_obs = copy.deepcopy(o) + if "final_observation" in infos: + for idx, final_obs in enumerate(infos["final_observation"]): + if final_obs is not None: + for k, v in final_obs.items(): + real_next_obs[k][idx] = v + + next_obs: Dict[str, Tensor] = {} + for k in real_next_obs.keys(): # 
[N_envs, N_obs] + if k in obs_keys: + next_obs[k] = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape[1:]) + step_data[k] = torch.from_numpy(real_next_obs[k]).view(args.num_envs, *real_next_obs[k].shape[1:]) + if k in mlp_keys: + next_obs[k] = next_obs[k].float() + step_data[k] = step_data[k].float() + actions = torch.from_numpy(actions).view(args.num_envs, -1).float() + rewards = torch.from_numpy(rewards).view(args.num_envs, -1).float() + dones = torch.from_numpy(dones).view(args.num_envs, -1).float() + # next_obs becomes the new obs obs = next_obs step_data["dones"] = dones step_data["actions"] = actions step_data["rewards"] = clip_rewards_fn(rewards) - data_to_add = step_data[None, ...] if buffer_type == "sequential": - rb.add(data_to_add) + rb.add(step_data[None, ...]) else: - episode_steps.append(data_to_add) - - if dones or truncated: - # Add entire episode if needed - if buffer_type == "episode" and len(episode_steps) >= args.per_rank_sequence_length: - rb.add(torch.cat(episode_steps, dim=0)) - episode_steps = [] - o = env.reset(seed=args.seed)[0] - obs = {} - for k in o.keys(): - torch_obs = torch.from_numpy(o[k]).view(args.num_envs, *o[k].shape).float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(args.num_envs, 1) - step_data["actions"] = torch.zeros(args.num_envs, np.sum(actions_dim)) - step_data["rewards"] = torch.zeros(args.num_envs, 1) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - data_to_add = step_data[None, ...] - if buffer_type == "sequential": - rb.add(data_to_add) + for i, env_ep in enumerate(episode_steps): + env_ep.append(step_data[i : i + 1][None, ...]) + + # Reset and save the observation coming from the automatic reset + dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + reset_envs = len(dones_idxes) + if reset_envs > 0: + reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") + for k in next_obs.keys(): + reset_data[k] = next_obs[k][dones_idxes] + reset_data["dones"] = torch.zeros(reset_envs, 1) + reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)) + reset_data["rewards"] = torch.zeros(reset_envs, 1) + reset_data["is_first"] = torch.ones_like(reset_data["dones"]) + if buffer_type == "episode": + for i, d in enumerate(dones_idxes): + if len(episode_steps[d]) >= args.per_rank_sequence_length: + rb.add(torch.cat(episode_steps[d], dim=0)) + episode_steps[d] = [reset_data[i : i + 1][None, ...]] else: - episode_steps.append(data_to_add) - player.init_states() + rb.add(reset_data[None, ...], dones_idxes) + # Reset dones so that `is_first` is updated + for d in dones_idxes: + step_data["dones"][d] = torch.zeros_like(step_data["dones"][d]) + # Reset internal agent states + player.init_states(dones_idxes) step_before_training -= 1 @@ -898,7 +933,7 @@ def main(): cnn_keys=cnn_keys, mlp_keys=mlp_keys, ) - step_before_training = args.train_every // (args.num_envs * fabric.world_size * args.action_repeat) + step_before_training = args.train_every // single_global_step if args.expl_decay: expl_decay_steps += 1 player.expl_amount = polynomial_decay( @@ -945,7 +980,7 @@ def main(): replay_buffer=rb if args.checkpoint_buffer else None, ) - env.close() + envs.close() # task test few-shot if fabric.is_global_zero: player.actor = actor_task.module diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 3f665fad..dabff008 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -3,7 +3,7 @@ import uuid import warnings from pathlib import Path -from typing import List, 
Optional, Union
+from typing import List, Optional, Sequence, Union
 
 import numpy as np
 import torch
@@ -66,7 +66,7 @@ def buffer_size(self) -> int:
         return self._buffer_size
 
     @property
-    def full(self) -> int:
+    def full(self) -> bool:
         return self._full
 
     @property
@@ -148,7 +148,7 @@ def add(self, data: Union["ReplayBuffer", TensorDictBase]) -> None:
             self._full = True
         self._pos = next_pos
 
-    def sample(self, batch_size: int, sample_next_obs: bool = False, clone: bool = False) -> TensorDictBase:
+    def sample(self, batch_size: int, sample_next_obs: bool = False, clone: bool = False, **kwargs) -> TensorDictBase:
         """Sample elements from the replay buffer.
 
         Custom sampling when using memory efficient variant,
@@ -526,3 +526,168 @@ def sample(
         if clone:
             return samples.clone()
         return samples
+
+
+class AsyncReplayBuffer:
+    def __init__(
+        self,
+        buffer_size: int,
+        n_envs: int = 1,
+        device: Union[device, str] = "cpu",
+        memmap: bool = False,
+        memmap_dir: Optional[Union[str, os.PathLike]] = None,
+        sequential: bool = False,
+    ):
+        """An async replay buffer which internally keeps one replay buffer per environment, each
+        backed by a TensorDict, so that the experiences of every environment are saved independently.
+        When new data has to be added, the TensorDict is expected to have two batch dimensions,
+        [sequence_length, n_envs]: one for the sequence length and one for the number of
+        independent environments.
+
+        Args:
+            buffer_size (int): The buffer size.
+            n_envs (int, optional): The number of environments. Defaults to 1.
+            device (Union[torch.device, str], optional): The device where the buffer is created. Defaults to "cpu".
+            memmap (bool, optional): Whether to memory-map the buffer. Defaults to False.
+            memmap_dir (Optional[Union[str, os.PathLike]], optional): The directory where the buffer is
+                memory-mapped. Defaults to None.
+            sequential (bool, optional): Whether the underlying per-environment buffers are sequential
+                (SequentialReplayBuffer) or standard (ReplayBuffer). Defaults to False.
+        """
+        if buffer_size <= 0:
+            raise ValueError(f"The buffer size must be greater than zero, got: {buffer_size}")
+        if n_envs <= 0:
+            raise ValueError(f"The number of environments must be greater than zero, got: {n_envs}")
+        self._buffer_size = buffer_size
+        self._n_envs = n_envs
+        if isinstance(device, str):
+            device = torch.device(device=device)
+        self._device = device
+        self._memmap = memmap
+        self._memmap_dir = memmap_dir
+        self._sequential = sequential
+        self._buf: Optional[Sequence[ReplayBuffer]] = None
+        if self._memmap_dir is not None:
+            self._memmap_dir = Path(self._memmap_dir)
+        if self._memmap:
+            if memmap_dir is None:
+                warnings.warn(
+                    "The buffer will be memory-mapped into the `/tmp` folder, this means that there is the"
+                    " possibility to lose the saved files. Set the `memmap_dir` to a known directory.",
+                    UserWarning,
+                )
+            else:
+                self._memmap_dir.mkdir(parents=True, exist_ok=True)
+
+    @property
+    def buffer(self) -> Optional[Sequence[ReplayBuffer]]:
+        if self._buf is None:
+            return None
+        return tuple(self._buf)
+
+    @property
+    def buffer_size(self) -> int:
+        return self._buffer_size
+
+    @property
+    def full(self) -> Optional[Sequence[bool]]:
+        if self.buffer is None:
+            return None
+        return tuple([b.full for b in self.buffer])
+
+    @property
+    def n_envs(self) -> int:
+        return self._n_envs
+
+    @property
+    def shape(self) -> Optional[Sequence[Size]]:
+        if self.buffer is None:
+            return None
+        return tuple([b.shape for b in self.buffer])
+
+    @property
+    def device(self) -> Optional[Sequence[device]]:
+        if self.buffer is None:
+            return None
+        return self._device
+
+    def __len__(self) -> int:
+        return self.buffer_size
+
+    def add(self, data: TensorDictBase, indices: Optional[Sequence[int]] = None) -> None:
+        """Add data to the buffer.
+
+        Args:
+            data (TensorDictBase): the data to add.
+            indices (Optional[Sequence[int]], optional): the environment indices where the data has to be added.
+                If None, the data is added to every environment.
+                Defaults to None.
+
+        Raises:
+            RuntimeError: if `data` does not have two batch dimensions, one for the sequence length
+                and one for the number of environments.
+        """
+        if data is None:
+            raise RuntimeError("The `data` parameter must not be None")
+        if not isinstance(data, TensorDictBase):
+            raise TypeError("`data` must be a TensorDictBase")
+        if len(data.shape) != 2:
+            raise RuntimeError(
+                "`data` must have 2 batch dimensions: [sequence_length, n_envs]. Shape is: {}".format(data.shape)
+            )
+        if self._buf is None:
+            buf_cls = SequentialReplayBuffer if self._sequential else ReplayBuffer
+            self._buf = [
+                buf_cls(
+                    self.buffer_size,
+                    n_envs=1,
+                    device=self._device,
+                    memmap=self._memmap,
+                    memmap_dir=self._memmap_dir / f"env_{i}" if self._memmap_dir is not None else None,
+                )
+                for i in range(self._n_envs)
+            ]
+        if indices is None:
+            indices = tuple(range(self.n_envs))
+        for env_data_idx, env_idx in enumerate(indices):
+            self._buf[env_idx].add(data[:, env_data_idx : env_data_idx + 1])
+
+    def sample(
+        self,
+        batch_size: int,
+        sample_next_obs: bool = False,
+        clone: bool = False,
+        sequence_length: int = 1,
+        n_samples: int = 1,
+    ) -> TensorDictBase:
+        """Sample elements from the replay buffer, each one being a sequence of consecutive items
+        when the underlying buffers are sequential.
+
+        The sampling is delegated to the per-environment buffers: the requested batch size is split
+        at random among the environments and each underlying buffer samples its own share.
+
+        Args:
+            batch_size (int): the number of elements to sample.
+            sample_next_obs (bool): whether to sample the next observations from the 'observations' key.
+                Defaults to False.
+            clone (bool): whether to clone the sampled TensorDict. Defaults to False.
+            sequence_length (int): the length of the sequence of each element. Defaults to 1.
+            n_samples (int): the number of samples to perform. Defaults to 1.
+
+        Returns:
+            TensorDictBase: the sampled TensorDictBase with a `batch_size` of [n_samples, sequence_length, batch_size].
+        """
+        if batch_size <= 0 or n_samples <= 0:
+            raise ValueError(f"`batch_size` ({batch_size}) and `n_samples` ({n_samples}) must be both greater than 0")
+        if self._buf is None:
+            raise RuntimeError("The buffer has not been initialized. Try to add some data first.")
+
+        bs_per_buf = torch.bincount(torch.randint(0, self._n_envs, (batch_size,)))
+        samples = [
+            b.sample(
+                batch_size=bs,
+                sample_next_obs=sample_next_obs,
+                clone=clone,
+                n_samples=n_samples,
+                sequence_length=sequence_length,
+            )
+            for b, bs in zip(self._buf, bs_per_buf)
+            if bs > 0
+        ]
+        return torch.cat(samples, dim=2 if self._sequential else 0)
diff --git a/sheeprl/utils/callback.py b/sheeprl/utils/callback.py
index ee74bb9e..aa055134 100644
--- a/sheeprl/utils/callback.py
+++ b/sheeprl/utils/callback.py
@@ -4,7 +4,7 @@
 from lightning.fabric import Fabric
 from lightning.fabric.plugins.collectives import TorchCollective
 
-from sheeprl.data.buffers import EpisodeBuffer, ReplayBuffer
+from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer, ReplayBuffer
 
 
 class CheckpointCallback:
@@ -24,7 +24,7 @@ def on_checkpoint_coupled(
         fabric: Fabric,
         ckpt_path: str,
         state: Dict[str, Any],
-        replay_buffer: Optional[Union["ReplayBuffer", "EpisodeBuffer"]] = None,
+        replay_buffer: Optional[Union["AsyncReplayBuffer", "ReplayBuffer", "EpisodeBuffer"]] = None,
     ):
         if replay_buffer is not None:
             if isinstance(replay_buffer, ReplayBuffer):
@@ -32,6 +32,11 @@ def on_checkpoint_coupled(
                 true_done = replay_buffer["dones"][(replay_buffer._pos - 1) % replay_buffer.buffer_size, :].clone()
                 # substitute the last done with all True values (all the environment are truncated)
                 replay_buffer["dones"][(replay_buffer._pos - 1) % replay_buffer.buffer_size, :] = True
+            elif isinstance(replay_buffer, AsyncReplayBuffer):
+                true_dones = []
+                for b in replay_buffer.buffer:
+                    true_dones.append(b["dones"][(b._pos - 1) % b.buffer_size, :].clone())
+                    b["dones"][(b._pos - 1) % b.buffer_size, :] = True
             state["rb"] = replay_buffer
             if fabric.world_size > 1:
                 # We need to collect the buffers from all the ranks
@@ -53,6 +58,9 @@ def on_checkpoint_coupled(
         if replay_buffer is not None and isinstance(replay_buffer, ReplayBuffer):
             # reinsert the true dones in the buffer
             replay_buffer["dones"][(replay_buffer._pos - 1) % replay_buffer.buffer_size, :] = true_done
+        elif isinstance(replay_buffer, AsyncReplayBuffer):
+            for i, b in enumerate(replay_buffer.buffer):
+                b["dones"][(b._pos - 1) % b.buffer_size, :] = true_dones[i]
 
     def on_checkpoint_player(
         self,
diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py
index 85a2c39b..3233348b 100644
--- a/tests/test_algos/test_algos.py
+++ b/tests/test_algos/test_algos.py
@@ -479,7 +479,7 @@ def test_p2e_dv1(standard_args, env_id, checkpoint_buffer, start_time):
         f"--buffer_size={int(os.environ['LT_DEVICES'])}",
         "--learning_starts=0",
         "--gradient_steps=1",
-        "--horizon=2",
+        "--horizon=8",
         "--env_id=" + env_id,
         "--root_dir=" + root_dir,
         "--run_name=" + run_name,
@@ -535,7 +535,7 @@ def test_dreamer_v2(standard_args, env_id, checkpoint_buffer, start_time):
         f"--buffer_size={int(os.environ['LT_DEVICES'])}",
         "--learning_starts=0",
         "--gradient_steps=1",
-        "--horizon=2",
+        "--horizon=8",
         "--env_id=" + env_id,
         "--root_dir=" + root_dir,
         "--run_name=" + run_name,
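
Note (illustrative, not part of the patch): the training loops above now rely on the autoreset behaviour of gymnasium's vectorized environments, where step() already returns the first observation of the new episode for finished environments and moves the terminal observation and episode statistics into infos["final_observation"] and infos["final_info"]. The sketch below shows that pattern in isolation; the CartPole environment and the RecordEpisodeStatistics wrapper are stand-ins chosen for the example, only the infos keys and the partial player.init_states(...) call come from this patch.

import gymnasium as gym
import numpy as np

# Two toy environments as a stand-in for the sheeprl `make_env` factories
envs = gym.vector.SyncVectorEnv(
    [lambda: gym.wrappers.RecordEpisodeStatistics(gym.make("CartPole-v1")) for _ in range(2)]
)
obs, _ = envs.reset(seed=0)

for _ in range(200):
    actions = envs.action_space.sample()
    next_obs, rewards, terminated, truncated, infos = envs.step(actions)
    dones = np.logical_or(terminated, truncated)

    # For finished environments, `next_obs` already holds the first observation of the new
    # episode; the terminal observation must be recovered from the infos to build correct
    # (obs, next_obs) transitions for the replay buffer.
    real_next_obs = next_obs.copy()
    if "final_observation" in infos:
        for idx, final_obs in enumerate(infos["final_observation"]):
            if final_obs is not None:
                real_next_obs[idx] = final_obs

    # Episode statistics are reported per environment through "final_info".
    if "final_info" in infos:
        for i, final_info in enumerate(infos["final_info"]):
            if final_info is not None and "episode" in final_info:
                print(f"env {i} finished with return {final_info['episode']['r']}")

    # In the patched loops only the finished environments are reset on the agent side,
    # e.g. player.init_states(np.flatnonzero(dones).tolist()).
    obs = next_obs

envs.close()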
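Note (illustrative, not part of the patch): the new AsyncReplayBuffer keeps one (optionally sequential) buffer per environment, which is what makes the partial resets above possible, since a reset transition can be appended only to the environments listed in `indices`. A minimal usage sketch follows, assuming the TensorDict layout used by the training loops; the key names, shapes and sizes are made up for the example.

import torch
from tensordict import TensorDict

from sheeprl.data.buffers import AsyncReplayBuffer

n_envs, obs_shape, action_dim = 2, (3, 64, 64), 4
rb = AsyncReplayBuffer(buffer_size=16, n_envs=n_envs, device="cpu", sequential=True)

# Regular steps: one column per environment, added to every per-environment buffer.
for _ in range(8):
    step_data = TensorDict(
        {
            "observations": torch.zeros(n_envs, *obs_shape, dtype=torch.uint8),
            "actions": torch.zeros(n_envs, action_dim),
            "rewards": torch.zeros(n_envs, 1),
            "dones": torch.zeros(n_envs, 1),
        },
        batch_size=[n_envs],
    )
    # `[None, ...]` adds the sequence dimension: the batch size becomes [1, n_envs]
    rb.add(step_data[None, ...])

# Suppose only environment 1 was automatically reset: its reset transition is appended
# to that environment's buffer alone by passing the indices of the done environments.
dones_idxes = [1]
reset_data = TensorDict(
    {
        "observations": torch.zeros(len(dones_idxes), *obs_shape, dtype=torch.uint8),
        "actions": torch.zeros(len(dones_idxes), action_dim),
        "rewards": torch.zeros(len(dones_idxes), 1),
        "dones": torch.zeros(len(dones_idxes), 1),
    },
    batch_size=[len(dones_idxes)],
)
rb.add(reset_data[None, ...], dones_idxes)

# With `sequential=True` the sampled TensorDict has batch size [n_samples, sequence_length, batch_size]
sample = rb.sample(batch_size=4, sequence_length=4, n_samples=1)
print(sample.shape)  # torch.Size([1, 4, 4])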