From 9fde26ae6593ff91a524c3f1ec1ecdb136abe875 Mon Sep 17 00:00:00 2001
From: belerico_t
Date: Thu, 23 Nov 2023 18:08:33 +0100
Subject: [PATCH 1/7] Moved per_rank_batch_size, per_rank_sequence_length and cnn/mlp_keys to algo config

---
 howto/configs.md | 4 +-
 howto/register_new_algorithm.md | 10 +-
 sheeprl/algos/dreamer_v1/agent.py | 28 ++---
 sheeprl/algos/dreamer_v1/dreamer_v1.py | 53 ++++-----
 sheeprl/algos/dreamer_v1/evaluate.py | 10 +-
 sheeprl/algos/dreamer_v2/agent.py | 28 ++---
 sheeprl/algos/dreamer_v2/dreamer_v2.py | 60 +++++-----
 sheeprl/algos/dreamer_v2/evaluate.py | 10 +-
 sheeprl/algos/dreamer_v2/utils.py | 4 +-
 sheeprl/algos/dreamer_v3/agent.py | 28 ++---
 sheeprl/algos/dreamer_v3/dreamer_v3.py | 61 +++++-----
 sheeprl/algos/dreamer_v3/evaluate.py | 10 +-
 sheeprl/algos/dreamer_v3/utils.py | 4 +-
 sheeprl/algos/droq/droq.py | 32 +++---
 sheeprl/algos/droq/evaluate.py | 8 +-
 sheeprl/algos/p2e_dv1/evaluate.py | 10 +-
 sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 54 +++++----
 sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 48 ++++----
 sheeprl/algos/p2e_dv2/evaluate.py | 10 +-
 sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 60 +++++-----
 sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 54 +++++----
 sheeprl/algos/p2e_dv3/evaluate.py | 10 +-
 sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 61 +++++-----
 sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 50 ++++-----
 sheeprl/algos/ppo/evaluate.py | 10 +-
 sheeprl/algos/ppo/ppo.py | 38 +++----
 sheeprl/algos/ppo/ppo_decoupled.py | 41 +++----
 sheeprl/algos/ppo/utils.py | 12 +-
 sheeprl/algos/ppo_recurrent/evaluate.py | 10 +-
 sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 49 +++++----
 sheeprl/algos/ppo_recurrent/utils.py | 8 +-
 sheeprl/algos/sac/evaluate.py | 8 +-
 sheeprl/algos/sac/sac.py | 30 ++---
 sheeprl/algos/sac/sac_decoupled.py | 37 ++++---
 sheeprl/algos/sac/utils.py | 8 +-
 sheeprl/algos/sac_ae/evaluate.py | 22 ++--
 sheeprl/algos/sac_ae/sac_ae.py | 84 +++++++-------
 sheeprl/configs/algo/default.yaml | 8 ++
 sheeprl/configs/algo/dreamer_v1.yaml | 9 +-
 sheeprl/configs/algo/dreamer_v2.yaml | 11 +-
 sheeprl/configs/algo/dreamer_v3.yaml | 11 +-
 sheeprl/configs/algo/ppo.yaml | 2 +-
 sheeprl/configs/algo/ppo_recurrent.yaml | 1 +
 sheeprl/configs/algo/sac_ae.yaml | 6 +
 sheeprl/configs/config.yaml | 9 --
 sheeprl/configs/env_config.yaml | 4 +-
 sheeprl/configs/exp/dreamer_v1.yaml | 18 +--
 sheeprl/configs/exp/dreamer_v2.yaml | 44 ++++----
 sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml | 8 +-
 sheeprl/configs/exp/dreamer_v3.yaml | 44 ++++----
 .../configs/exp/dreamer_v3_100k_boxing.yaml | 6 +-
 .../exp/dreamer_v3_100k_ms_pacman.yaml | 2 +-
 sheeprl/configs/exp/dreamer_v3_L_doapp.yaml | 34 +++---
 ..._v3_L_doapp_128px_gray_combo_discrete.yaml | 64 ++++++-----
 .../configs/exp/dreamer_v3_L_navigate.yaml | 32 +++---
 .../configs/exp/dreamer_v3_XL_crafter.yaml | 20 ++--
 .../exp/dreamer_v3_dmc_walker_walk.yaml | 14 +--
 sheeprl/configs/exp/p2e_dv1_finetuning.yaml | 7 +-
 sheeprl/configs/exp/p2e_dv2_finetuning.yaml | 35 +++---
 ...x_gray_combo_discrete_15Mexpl_20Mstps.yaml | 68 ++++++------
 sheeprl/configs/exp/p2e_dv3_finetuning.yaml | 3 +-
 ...doapp_64px_gray_combo_discrete_5Mstps.yaml | 68 ++++++------
 sheeprl/configs/exp/ppo.yaml | 12 +-
 sheeprl/configs/exp/ppo_recurrent.yaml | 8 +-
 sheeprl/configs/exp/sac.yaml | 14 +--
 sheeprl/configs/exp/sac_ae.yaml | 9 +-
 sheeprl/utils/env.py | 77 ++++++++-----
 tests/run_tests.py | 2 +-
 tests/test_algos/test_algos.py | 104 +++++++++--------
 tests/test_algos/test_cli.py | 22 ++--
 70 files changed, 908 insertions(+), 932 deletions(-)

diff --git
a/howto/configs.md b/howto/configs.md index de6031ed..2f37c559 100644 --- a/howto/configs.md +++ b/howto/configs.md @@ -123,10 +123,10 @@ root_dir: ${algo.name}/${env.id} # Encoder and decoder keys cnn_keys: encoder: [] - decoder: ${cnn_keys.encoder} + decoder: ${algo.cnn_keys.encoder} mlp_keys: encoder: [] - decoder: ${mlp_keys.encoder} + decoder: ${algo.mlp_keys.encoder} ``` ### Algorithms diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index f90f1c7e..533efb6f 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -146,7 +146,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): policy_step = 0 last_checkpoint = 0 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) - num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: @@ -170,9 +170,9 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): for k in o.keys(): if k in obs_keys: torch_obs = torch.from_numpy(o[k]).to(fabric.device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() step_data[k] = torch_obs next_obs[k] = torch_obs @@ -212,9 +212,9 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): for k in o.keys(): if k in obs_keys: torch_obs = torch.from_numpy(o[k]).to(fabric.device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() step_data[k] = torch_obs obs[k] = torch_obs diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index db9f8db8..9b2ea9c0 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -365,26 +365,26 @@ def build_models( # Define models cnn_encoder = ( CNNEncoder( - keys=cfg.cnn_keys.encoder, - input_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.cnn_keys.encoder], - image_size=obs_space[cfg.cnn_keys.encoder[0]].shape[-2:], + keys=cfg.algo.cnn_keys.encoder, + input_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.algo.cnn_keys.encoder], + image_size=obs_space[cfg.algo.cnn_keys.encoder[0]].shape[-2:], channels_multiplier=world_model_cfg.encoder.cnn_channels_multiplier, layer_norm=False, activation=eval(world_model_cfg.encoder.cnn_act), ) - if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 + if cfg.algo.cnn_keys.encoder is not None and len(cfg.algo.cnn_keys.encoder) > 0 else None ) mlp_encoder = ( MLPEncoder( - keys=cfg.mlp_keys.encoder, - input_dims=[obs_space[k].shape[0] for k in cfg.mlp_keys.encoder], + keys=cfg.algo.mlp_keys.encoder, + input_dims=[obs_space[k].shape[0] for k in cfg.algo.mlp_keys.encoder], mlp_layers=world_model_cfg.encoder.mlp_layers, dense_units=world_model_cfg.encoder.dense_units, activation=eval(world_model_cfg.encoder.dense_act), layer_norm=False, ) - if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 + if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0 else None ) encoder = MultiEncoder(cnn_encoder, mlp_encoder) @@ -418,29 +418,29 @@ def build_models( ) 
cnn_decoder = ( CNNDecoder( - keys=cfg.cnn_keys.decoder, - output_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.cnn_keys.decoder], + keys=cfg.algo.cnn_keys.decoder, + output_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.algo.cnn_keys.decoder], channels_multiplier=world_model_cfg.observation_model.cnn_channels_multiplier, latent_state_size=latent_state_size, cnn_encoder_output_dim=cnn_encoder.output_dim, - image_size=obs_space[cfg.cnn_keys.decoder[0]].shape[-2:], + image_size=obs_space[cfg.algo.cnn_keys.decoder[0]].shape[-2:], activation=eval(world_model_cfg.observation_model.cnn_act), layer_norm=False, ) - if cfg.cnn_keys.decoder is not None and len(cfg.cnn_keys.decoder) > 0 + if cfg.algo.cnn_keys.decoder is not None and len(cfg.algo.cnn_keys.decoder) > 0 else None ) mlp_decoder = ( MLPDecoder( - keys=cfg.mlp_keys.decoder, - output_dims=[obs_space[k].shape[0] for k in cfg.mlp_keys.decoder], + keys=cfg.algo.mlp_keys.decoder, + output_dims=[obs_space[k].shape[0] for k in cfg.algo.mlp_keys.decoder], latent_state_size=latent_state_size, mlp_layers=world_model_cfg.observation_model.mlp_layers, dense_units=world_model_cfg.observation_model.dense_units, activation=eval(world_model_cfg.observation_model.dense_act), layer_norm=False, ) - if cfg.mlp_keys.decoder is not None and len(cfg.mlp_keys.decoder) > 0 + if cfg.algo.mlp_keys.decoder is not None and len(cfg.algo.mlp_keys.decoder) > 0 else None ) observation_model = MultiDecoder(cnn_decoder, mlp_decoder) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 39723ae5..1b382852 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -100,14 +100,14 @@ def train( aggregator (MetricAggregator, optional): the aggregator to print the metrics. cfg (DictConfig): the configs. """ - batch_size = cfg.per_rank_batch_size - sequence_length = cfg.per_rank_sequence_length + batch_size = cfg.algo.per_rank_batch_size + sequence_length = cfg.algo.per_rank_sequence_length validate_args = cfg.distribution.validate_args recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size stochastic_size = cfg.algo.world_model.stochastic_size device = fabric.device - batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.cnn_keys.encoder} - batch_obs.update({k: data[k] for k in cfg.mlp_keys.encoder}) + batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder} + batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) # Dynamic Learning # initialize the recurrent_state that must be a tuple of tensors (one for GRU or RNN). 
@@ -408,7 +408,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // world_size # These arguments cannot be changed cfg.env.screen_size = 64 @@ -448,32 +448,27 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys of the encoder and decoder must not be disjointed") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. 
" - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder world_model, actor, critic = build_models( fabric, @@ -543,7 +538,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -577,7 +572,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} for k in obs_keys: torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() step_data[k] = torch_obs obs[k] = torch_obs @@ -612,7 +607,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 else: preprocessed_obs[k] = v[None, ...].to(device) @@ -652,7 +647,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for k in obs_keys: # [N_envs, N_obs] next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = step_data[k].float() actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() @@ -689,8 +684,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if update > learning_starts and updates_before_training <= 0: local_data = rb.sample( - cfg.per_rank_batch_size, - sequence_length=cfg.per_rank_sequence_length, + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, n_samples=cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) @@ -705,7 +700,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_optimizer, critic_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), aggregator, 
cfg, ) @@ -766,7 +761,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "critic_optimizer": critic_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * world_size, + "batch_size": cfg.algo.per_rank_batch_size * world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py index 6e0829ad..1b4a30cc 100644 --- a/sheeprl/algos/dreamer_v1/evaluate.py +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -33,13 +33,9 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index cb01db30..34a94bd0 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -908,26 +908,26 @@ def build_models( # Define models cnn_encoder = ( CNNEncoder( - keys=cfg.cnn_keys.encoder, - input_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.cnn_keys.encoder], - image_size=obs_space[cfg.cnn_keys.encoder[0]].shape[-2:], + keys=cfg.algo.cnn_keys.encoder, + input_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.algo.cnn_keys.encoder], + image_size=obs_space[cfg.algo.cnn_keys.encoder[0]].shape[-2:], channels_multiplier=world_model_cfg.encoder.cnn_channels_multiplier, layer_norm=world_model_cfg.encoder.layer_norm, activation=eval(world_model_cfg.encoder.cnn_act), ) - if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 + if cfg.algo.cnn_keys.encoder is not None and len(cfg.algo.cnn_keys.encoder) > 0 else None ) mlp_encoder = ( MLPEncoder( - keys=cfg.mlp_keys.encoder, - input_dims=[obs_space[k].shape[0] for k in cfg.mlp_keys.encoder], + keys=cfg.algo.mlp_keys.encoder, + input_dims=[obs_space[k].shape[0] for k in cfg.algo.mlp_keys.encoder], mlp_layers=world_model_cfg.encoder.mlp_layers, dense_units=world_model_cfg.encoder.dense_units, activation=eval(world_model_cfg.encoder.dense_act), layer_norm=world_model_cfg.encoder.layer_norm, ) - if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 + if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0 else None ) encoder = MultiEncoder(cnn_encoder, mlp_encoder) @@ -968,29 +968,29 @@ def build_models( ) cnn_decoder = ( CNNDecoder( - keys=cfg.cnn_keys.decoder, - output_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.cnn_keys.decoder], + keys=cfg.algo.cnn_keys.decoder, + output_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.algo.cnn_keys.decoder], channels_multiplier=world_model_cfg.observation_model.cnn_channels_multiplier, latent_state_size=latent_state_size, cnn_encoder_output_dim=cnn_encoder.output_dim, - 
image_size=obs_space[cfg.cnn_keys.decoder[0]].shape[-2:], + image_size=obs_space[cfg.algo.cnn_keys.decoder[0]].shape[-2:], activation=eval(world_model_cfg.observation_model.cnn_act), layer_norm=world_model_cfg.observation_model.layer_norm, ) - if cfg.cnn_keys.decoder is not None and len(cfg.cnn_keys.decoder) > 0 + if cfg.algo.cnn_keys.decoder is not None and len(cfg.algo.cnn_keys.decoder) > 0 else None ) mlp_decoder = ( MLPDecoder( - keys=cfg.mlp_keys.decoder, - output_dims=[obs_space[k].shape[0] for k in cfg.mlp_keys.decoder], + keys=cfg.algo.mlp_keys.decoder, + output_dims=[obs_space[k].shape[0] for k in cfg.algo.mlp_keys.decoder], latent_state_size=latent_state_size, mlp_layers=world_model_cfg.observation_model.mlp_layers, dense_units=world_model_cfg.observation_model.dense_units, activation=eval(world_model_cfg.observation_model.dense_act), layer_norm=world_model_cfg.observation_model.layer_norm, ) - if cfg.mlp_keys.decoder is not None and len(cfg.mlp_keys.decoder) > 0 + if cfg.algo.mlp_keys.decoder is not None and len(cfg.algo.mlp_keys.decoder) > 0 else None ) observation_model = MultiDecoder(cnn_decoder, mlp_decoder) diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index deea6b4e..326a27db 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -113,15 +113,15 @@ def train( # Dones: 0 d1 d2 d3 # Is-first 1 i1 i2 i3 - batch_size = cfg.per_rank_batch_size - sequence_length = cfg.per_rank_sequence_length + batch_size = cfg.algo.per_rank_batch_size + sequence_length = cfg.algo.per_rank_sequence_length validate_args = cfg.distribution.validate_args recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size stochastic_size = cfg.algo.world_model.stochastic_size discrete_size = cfg.algo.world_model.discrete_size device = fabric.device - batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.cnn_keys.encoder} - batch_obs.update({k: data[k] for k in cfg.mlp_keys.encoder}) + batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder} + batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) # Given how the environment interaction works, we assume that the first element in a sequence # is the first one, as if the environment has been reset @@ -432,7 +432,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // world_size # These arguments cannot be changed cfg.env.screen_size = 64 @@ -472,32 +472,28 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) + if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys 
of the encoder and decoder must not be disjointed") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder world_model, actor, critic, target_critic = build_models( fabric, @@ -555,7 +551,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): elif buffer_type == "episode": rb = EpisodeBuffer( buffer_size, - sequence_length=cfg.per_rank_sequence_length, + sequence_length=cfg.algo.per_rank_sequence_length, device="cpu", memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), @@ -581,7 +577,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 - num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -616,7 +612,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} for k in obs_keys: torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: # Images stay uint8 to save space torch_obs = torch_obs.float() step_data[k] = torch_obs @@ -658,7 +654,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 else: preprocessed_obs[k] = v[None, ...].to(device) @@ -703,7 +699,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in obs_keys: next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) step_data[k] = 
torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = step_data[k].float() actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() @@ -735,7 +731,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data["is_first"] = torch.ones_like(reset_data["dones"]) if buffer_type == "episode": for i, d in enumerate(dones_idxes): - if len(episode_steps[d]) >= cfg.per_rank_sequence_length: + if len(episode_steps[d]) >= cfg.algo.per_rank_sequence_length: rb.add(torch.cat(episode_steps[d], dim=0)) episode_steps[d] = [reset_data[i : i + 1][None, ...]] else: @@ -752,15 +748,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if update >= learning_starts and updates_before_training <= 0: if buffer_type == "sequential": local_data = rb.sample( - cfg.per_rank_batch_size, - sequence_length=cfg.per_rank_sequence_length, + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, n_samples=cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps, ).to(device) else: local_data = rb.sample( - cfg.per_rank_batch_size, + cfg.algo.per_rank_batch_size, n_samples=cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps, @@ -781,7 +777,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_optimizer, critic_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), aggregator, cfg, actions_dim, @@ -845,7 +841,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "critic_optimizer": critic_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * world_size, + "batch_size": cfg.algo.per_rank_batch_size * world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py index f32da474..a320324e 100644 --- a/sheeprl/algos/dreamer_v2/evaluate.py +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -33,13 +33,9 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index 910197a5..29153658 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ b/sheeprl/algos/dreamer_v2/utils.py @@ -133,9 +133,9 @@ def test( # Act greedly through the environment preprocessed_obs = {} for k, v in next_obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 
- elif k in cfg.mlp_keys.encoder: + elif k in cfg.algo.mlp_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) real_actions = player.get_greedy_action( preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 0d712b2e..4174eccf 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -945,27 +945,27 @@ def build_models( cnn_stages = int(np.log2(cfg.env.screen_size) - np.log2(4)) cnn_encoder = ( CNNEncoder( - keys=cfg.cnn_keys.encoder, - input_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.cnn_keys.encoder], - image_size=obs_space[cfg.cnn_keys.encoder[0]].shape[-2:], + keys=cfg.algo.cnn_keys.encoder, + input_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.algo.cnn_keys.encoder], + image_size=obs_space[cfg.algo.cnn_keys.encoder[0]].shape[-2:], channels_multiplier=world_model_cfg.encoder.cnn_channels_multiplier, layer_norm=world_model_cfg.encoder.layer_norm, activation=eval(world_model_cfg.encoder.cnn_act), stages=cnn_stages, ) - if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 + if cfg.algo.cnn_keys.encoder is not None and len(cfg.algo.cnn_keys.encoder) > 0 else None ) mlp_encoder = ( MLPEncoder( - keys=cfg.mlp_keys.encoder, - input_dims=[obs_space[k].shape[0] for k in cfg.mlp_keys.encoder], + keys=cfg.algo.mlp_keys.encoder, + input_dims=[obs_space[k].shape[0] for k in cfg.algo.mlp_keys.encoder], mlp_layers=world_model_cfg.encoder.mlp_layers, dense_units=world_model_cfg.encoder.dense_units, activation=eval(world_model_cfg.encoder.dense_act), layer_norm=world_model_cfg.encoder.layer_norm, ) - if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 + if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0 else None ) encoder = MultiEncoder(cnn_encoder, mlp_encoder) @@ -1007,30 +1007,30 @@ def build_models( ) cnn_decoder = ( CNNDecoder( - keys=cfg.cnn_keys.decoder, - output_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.cnn_keys.decoder], + keys=cfg.algo.cnn_keys.decoder, + output_channels=[int(np.prod(obs_space[k].shape[:-2])) for k in cfg.algo.cnn_keys.decoder], channels_multiplier=world_model_cfg.observation_model.cnn_channels_multiplier, latent_state_size=latent_state_size, cnn_encoder_output_dim=cnn_encoder.output_dim, - image_size=obs_space[cfg.cnn_keys.decoder[0]].shape[-2:], + image_size=obs_space[cfg.algo.cnn_keys.decoder[0]].shape[-2:], activation=eval(world_model_cfg.observation_model.cnn_act), layer_norm=world_model_cfg.observation_model.layer_norm, stages=cnn_stages, ) - if cfg.cnn_keys.decoder is not None and len(cfg.cnn_keys.decoder) > 0 + if cfg.algo.cnn_keys.decoder is not None and len(cfg.algo.cnn_keys.decoder) > 0 else None ) mlp_decoder = ( MLPDecoder( - keys=cfg.mlp_keys.decoder, - output_dims=[obs_space[k].shape[0] for k in cfg.mlp_keys.decoder], + keys=cfg.algo.mlp_keys.decoder, + output_dims=[obs_space[k].shape[0] for k in cfg.algo.mlp_keys.decoder], latent_state_size=latent_state_size, mlp_layers=world_model_cfg.observation_model.mlp_layers, dense_units=world_model_cfg.observation_model.dense_units, activation=eval(world_model_cfg.observation_model.dense_act), layer_norm=world_model_cfg.observation_model.layer_norm, ) - if cfg.mlp_keys.decoder is not None and len(cfg.mlp_keys.decoder) > 0 + if cfg.algo.mlp_keys.decoder is not None and len(cfg.algo.mlp_keys.decoder) > 0 else None ) observation_model = 
MultiDecoder(cnn_decoder, mlp_decoder) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index d99dcd5e..a631ae1f 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -90,15 +90,15 @@ def train( # Dones: 0 d1 d2 d3 # Is-first 1 i1 i2 i3 - batch_size = cfg.per_rank_batch_size - sequence_length = cfg.per_rank_sequence_length + batch_size = cfg.algo.per_rank_batch_size + sequence_length = cfg.algo.per_rank_sequence_length validate_args = cfg.distribution.validate_args recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size stochastic_size = cfg.algo.world_model.stochastic_size discrete_size = cfg.algo.world_model.discrete_size device = fabric.device - batch_obs = {k: data[k] / 255.0 for k in cfg.cnn_keys.encoder} - batch_obs.update({k: data[k] for k in cfg.mlp_keys.encoder}) + batch_obs = {k: data[k] / 255.0 for k in cfg.algo.cnn_keys.encoder} + batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) data["is_first"][0, :] = torch.tensor([1.0], device=fabric.device).expand_as(data["is_first"][0, :]) # Given how the environment interaction works, we remove the last actions @@ -132,12 +132,13 @@ def train( # Compute the distribution over the reconstructed observations po = { - k: MSEDistribution(reconstructed_obs[k], dims=len(reconstructed_obs[k].shape[2:])) for k in cfg.cnn_keys.decoder + k: MSEDistribution(reconstructed_obs[k], dims=len(reconstructed_obs[k].shape[2:])) + for k in cfg.algo.cnn_keys.decoder } po.update( { k: SymlogDistribution(reconstructed_obs[k], dims=len(reconstructed_obs[k].shape[2:])) - for k in cfg.mlp_keys.decoder + for k in cfg.algo.mlp_keys.decoder } ) @@ -362,7 +363,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size # These arguments cannot be changed cfg.env.frame_stack = -1 @@ -406,32 +407,28 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) + if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys of the encoder and decoder must not be disjointed") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. 
" - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder world_model, actor, critic, target_critic = build_models( fabric, @@ -511,7 +508,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -545,7 +542,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} for k in obs_keys: torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: # Images stay uint8 to save space torch_obs = torch_obs.float() step_data[k] = torch_obs @@ -581,7 +578,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255.0 else: preprocessed_obs[k] = v[None, ...].to(device) @@ -637,7 +634,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in obs_keys: next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) step_data[k] = next_obs[k] - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = step_data[k].float() @@ -655,7 +652,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") for k in obs_keys: reset_data[k] = real_next_obs[k][dones_idxes] - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: reset_data[k] = reset_data[k].float() reset_data["dones"] = torch.ones(reset_envs, 1).float() reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)).float() @@ -674,8 +671,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent 
if update >= learning_starts and updates_before_training <= 0: local_data = rb.sample( - cfg.per_rank_batch_size, - sequence_length=cfg.per_rank_sequence_length, + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, n_samples=cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps, @@ -696,7 +693,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_optimizer, critic_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), aggregator, cfg, is_continuous, @@ -763,7 +760,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "expl_decay_steps": expl_decay_steps, "moments": moments.state_dict(), "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index e9f298d9..63a84dd5 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -33,13 +33,9 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index 4a589e0d..7c31278f 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -107,9 +107,9 @@ def test( # Act greedly through the environment preprocessed_obs = {} for k, v in next_obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255 - elif k in cfg.mlp_keys.encoder: + elif k in cfg.algo.mlp_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) real_actions = player.get_greedy_action( preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 609de45d..f7d8fcee 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -43,7 +43,7 @@ def train( # Sample a minibatch in a distributed way: Line 5 - Algorithm 2 # We sample one time to reduce the communications between processes sample = rb.sample( - cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs + cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs ) critic_data = fabric.all_gather(sample.to_dict()) critic_data = make_tensordict(critic_data).view(-1) @@ -57,15 +57,15 @@ def train( drop_last=False, ) critic_sampler: BatchSampler = BatchSampler( - 
sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False + sampler=dist_sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) else: critic_sampler = BatchSampler( - sampler=range(len(critic_data)), batch_size=cfg.per_rank_batch_size, drop_last=False + sampler=range(len(critic_data)), batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) # Sample a different minibatch in a distributed way to update actor and alpha parameter - sample = rb.sample(cfg.per_rank_batch_size) + sample = rb.sample(cfg.algo.per_rank_batch_size) actor_data = fabric.all_gather(sample.to_dict()) actor_data = make_tensordict(actor_data).view(-1) if fabric.world_size > 1: @@ -146,11 +146,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Resume from checkpoint if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size - if len(cfg.cnn_keys.encoder) > 0: + if len(cfg.algo.cnn_keys.encoder) > 0: warnings.warn("DroQ algorithm cannot allow to use images as observations, the CNN keys will be ignored") - cfg.cnn_keys.encoder = [] + cfg.algo.cnn_keys.encoder = [] # Create TensorBoardLogger. This will create the logger only on the # rank-0 process @@ -181,20 +181,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): raise ValueError("Only continuous action space is supported for the DroQ agent") if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if len(cfg.mlp_keys.encoder) == 0: + if len(cfg.algo.mlp_keys.encoder) == 0: raise RuntimeError("You should specify at least one MLP key for the encoder: `mlp_keys.encoder=[state]`") - for k in cfg.mlp_keys.encoder: + for k in cfg.algo.mlp_keys.encoder: if len(observation_space[k].shape) > 1: raise ValueError( "Only environments with vector-only observations are supported by the DroQ agent. 
" f"Provided environment: {cfg.env.id}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) # Define the agent and the optimizer and setup them with Fabric act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) actor = SACActor( observation_dim=obs_dim, action_dim=act_dim, @@ -264,7 +264,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -289,7 +289,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Get the first environment observation and start the optimization o = envs.reset(seed=cfg.seed)[0] obs = torch.cat( - [torch.tensor(o[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1 + [torch.tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) # [N_envs, N_obs] for update in range(start_step, num_updates + 1): @@ -325,10 +325,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with device: next_obs = torch.cat( - [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1 + [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) # [N_envs, N_obs] real_next_obs = torch.cat( - [torch.tensor(real_next_obs[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1 + [torch.tensor(real_next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) # [N_envs, N_obs] actions = torch.tensor(actions, dtype=torch.float32).view(cfg.env.num_envs, -1) rewards = torch.tensor(rewards, dtype=torch.float32).view(cfg.env.num_envs, -1) # [N_envs, 1] @@ -391,7 +391,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/droq/evaluate.py b/sheeprl/algos/droq/evaluate.py index c192345a..5738bafc 100644 --- a/sheeprl/algos/droq/evaluate.py +++ b/sheeprl/algos/droq/evaluate.py @@ -36,19 +36,19 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): raise ValueError("Only continuous action space is supported for the DroQ agent") if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if len(cfg.mlp_keys.encoder) == 0: + if len(cfg.algo.mlp_keys.encoder) == 0: raise RuntimeError("You should specify at least one MLP key for the encoder: `mlp_keys.encoder=[state]`") - for k in cfg.mlp_keys.encoder: + for k in cfg.algo.mlp_keys.encoder: if len(observation_space[k].shape) > 1: raise ValueError( "Only 
environments with vector-only observations are supported by the DroQ agent. " f"Provided environment: {cfg.env.id}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) actor = SACActor( observation_dim=obs_dim, action_dim=act_dim, diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index 8c6bf7ba..27b1396a 100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -34,13 +34,9 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index b2ec8e5f..edc4e8d2 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -99,14 +99,14 @@ def train( actor_exploration_optimizer (_FabricOptimizer): the optimizer of the actor for exploration. critic_exploration_optimizer (_FabricOptimizer): the optimizer of the critic for exploration. 
""" - batch_size = cfg.per_rank_batch_size - sequence_length = cfg.per_rank_sequence_length + batch_size = cfg.algo.per_rank_batch_size + sequence_length = cfg.algo.per_rank_sequence_length validate_args = cfg.distribution.validate_args recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size stochastic_size = cfg.algo.world_model.stochastic_size device = fabric.device - batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.cnn_keys.encoder} - batch_obs.update({k: data[k] for k in cfg.mlp_keys.encoder}) + batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder} + batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) # Dynamic Learning recurrent_state = torch.zeros(1, batch_size, recurrent_state_size, device=device) @@ -419,7 +419,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // world_size # These arguments cannot be changed cfg.env.screen_size = 64 @@ -460,32 +460,28 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) + if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys of the encoder and decoder must not be disjointed") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. 
" - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_models( fabric, @@ -603,7 +599,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -643,7 +639,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} for k in obs_keys: torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() step_data[k] = torch_obs obs[k] = torch_obs @@ -678,7 +674,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 else: preprocessed_obs[k] = v[None, ...].to(device) @@ -720,7 +716,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for k in obs_keys: # [N_envs, N_obs] next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = step_data[k].float() actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() @@ -757,8 +753,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: local_data = rb.sample( - cfg.per_rank_batch_size, - sequence_length=cfg.per_rank_sequence_length, + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, n_samples=cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) @@ -773,7 +769,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + 
local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), aggregator, cfg, ensembles=ensembles, @@ -849,7 +845,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "ensemble_optimizer": ensemble_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * world_size, + "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "critic_exploration": critic_exploration.state_dict(), "actor_exploration_optimizer": actor_exploration_optimizer.state_dict(), diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index ca4a96f6..b595789e 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -44,7 +44,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Finetuning that was interrupted for some reason if resume_from_checkpoint: state = fabric.load(pathlib.Path(cfg.checkpoint.resume_from)) - cfg.per_rank_batch_size = state["batch_size"] // world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // world_size else: state = fabric.load(ckpt_path) @@ -69,7 +69,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if cfg.buffer.load_from_exploration and exploration_cfg.buffer.checkpoint: cfg.env.num_envs = exploration_cfg.env.num_envs # There must be the same cnn and mlp keys during exploration and finetuning - cfg.cnn_keys = exploration_cfg.cnn_keys + cfg.algo.cnn_keys = exploration_cfg.algo.cnn_keys - cfg.mlp_keys = exploration_cfg.mlp_keys + cfg.algo.mlp_keys = exploration_cfg.algo.mlp_keys # These arguments cannot be changed @@ -110,32 +110,28 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) + if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys of the encoder and decoder must not be disjointed") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. 
" - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder world_model, actor_task, critic_task, actor_exploration, _ = build_models( fabric, @@ -207,7 +203,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 if resume_from_checkpoint and not cfg.buffer.checkpoint: learning_starts += start_step @@ -247,7 +243,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} for k in obs_keys: torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() step_data[k] = torch_obs obs[k] = torch_obs @@ -266,7 +262,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 else: preprocessed_obs[k] = v[None, ...].to(device) @@ -308,7 +304,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): for k in obs_keys: # [N_envs, N_obs] next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = step_data[k].float() actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() @@ -348,8 +344,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.actor = actor_task.module player.actor_type = "task" local_data = rb.sample( - cfg.per_rank_batch_size, - sequence_length=cfg.per_rank_sequence_length, + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, n_samples=cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) @@ -364,7 +360,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): world_optimizer, 
actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), aggregator, cfg, ) @@ -432,7 +428,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "critic_task_optimizer": critic_task_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * world_size, + "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "last_log": last_log, "last_checkpoint": last_checkpoint, diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index e3b1509f..c2ccf666 100644 --- a/sheeprl/algos/p2e_dv2/evaluate.py +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -34,13 +34,9 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 41057dca..ab2599f5 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -108,15 +108,15 @@ def train( actions_dim (Sequence[int]): the actions dimension. is_exploring (bool): whether the agent is exploring. 
""" - batch_size = cfg.per_rank_batch_size - sequence_length = cfg.per_rank_sequence_length + batch_size = cfg.algo.per_rank_batch_size + sequence_length = cfg.algo.per_rank_sequence_length validate_args = cfg.distribution.validate_args recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size stochastic_size = cfg.algo.world_model.stochastic_size discrete_size = cfg.algo.world_model.discrete_size device = fabric.device - batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.cnn_keys.encoder} - batch_obs.update({k: data[k] for k in cfg.mlp_keys.encoder}) + batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder} + batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) data["is_first"][0, :] = torch.tensor([1.0], device=fabric.device).expand_as(data["is_first"][0, :]) # Dynamic Learning @@ -533,7 +533,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // world_size # These arguments cannot be changed cfg.env.screen_size = 64 @@ -574,32 +574,28 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) + if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys of the encoder and decoder must not be disjointed") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. 
" - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder ( world_model, @@ -727,7 +723,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): elif buffer_type == "episode": rb = EpisodeBuffer( buffer_size, - sequence_length=cfg.per_rank_sequence_length, + sequence_length=cfg.algo.per_rank_sequence_length, device="cpu", memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), @@ -753,7 +749,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 - num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -794,7 +790,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} for k in obs_keys: torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: # Images stay uint8 to save space torch_obs = torch_obs.float() step_data[k] = torch_obs @@ -836,7 +832,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 else: preprocessed_obs[k] = v[None, ...].to(device) @@ -880,7 +876,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for k in obs_keys: # [N_envs, N_obs] next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = step_data[k].float() actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() @@ -912,7 +908,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data["is_first"] = torch.ones_like(reset_data["dones"]) if buffer_type == "episode": for i, d in enumerate(dones_idxes): - if len(episode_steps[d]) >= cfg.per_rank_sequence_length: + if len(episode_steps[d]) >= cfg.algo.per_rank_sequence_length: rb.add(torch.cat(episode_steps[d], dim=0)) episode_steps[d] = [reset_data[i : i + 1][None, ...]] else: @@ -929,15 +925,15 @@ def 
main(fabric: Fabric, cfg: Dict[str, Any]): if update >= learning_starts and updates_before_training <= 0: if buffer_type == "sequential": local_data = rb.sample( - cfg.per_rank_batch_size, - sequence_length=cfg.per_rank_sequence_length, + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, n_samples=cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps, ).to(device) else: local_data = rb.sample( - cfg.per_rank_batch_size, + cfg.algo.per_rank_batch_size, n_samples=cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps, @@ -963,7 +959,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), aggregator, cfg, ensembles=ensembles, @@ -1043,7 +1039,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "ensemble_optimizer": ensemble_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * world_size, + "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "critic_exploration": critic_exploration.state_dict(), "target_critic_exploration": target_critic_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 6af7996c..72aff80f 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -46,7 +46,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Finetuning that was interrupted for some reason if resume_from_checkpoint: state = fabric.load(pathlib.Path(cfg.checkpoint.resume_from)) - cfg.per_rank_batch_size = state["batch_size"] // world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // world_size else: state = fabric.load(ckpt_path) @@ -73,7 +73,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if cfg.buffer.load_from_exploration and exploration_cfg.buffer.checkpoint: cfg.env.num_envs = exploration_cfg.env.num_envs # There must be the same cnn and mlp keys during exploration and finetuning - cfg.cnn_keys = exploration_cfg.cnn_keys + cfg.algo.cnn_keys = exploration_cfg.algo.cnn_keys - cfg.mlp_keys = exploration_cfg.mlp_keys + cfg.algo.mlp_keys = exploration_cfg.algo.mlp_keys # These arguments cannot be changed @@ -114,32 +114,28 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) + if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys of the encoder and decoder must not be 
disjointed") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder world_model, actor_task, critic_task, target_critic_task, actor_exploration, _, _ = build_models( fabric, @@ -199,7 +195,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): elif buffer_type == "episode": rb = EpisodeBuffer( buffer_size, - sequence_length=cfg.per_rank_sequence_length, + sequence_length=cfg.algo.per_rank_sequence_length, device="cpu", memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), @@ -225,7 +221,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 - num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if resume_from_checkpoint and not cfg.buffer.checkpoint: learning_starts += start_step @@ -266,7 +262,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} for k in obs_keys: torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: # Images stay uint8 to save space torch_obs = torch_obs.float() step_data[k] = torch_obs @@ -292,7 +288,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 else: preprocessed_obs[k] = v[None, ...].to(device) @@ -336,7 +332,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], 
exploration_cfg: Dict[str, Any]): for k in obs_keys: # [N_envs, N_obs] next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = step_data[k].float() actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() @@ -368,7 +364,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): reset_data["is_first"] = torch.ones_like(reset_data["dones"]) if buffer_type == "episode": for i, d in enumerate(dones_idxes): - if len(episode_steps[d]) >= cfg.per_rank_sequence_length: + if len(episode_steps[d]) >= cfg.algo.per_rank_sequence_length: rb.add(torch.cat(episode_steps[d], dim=0)) episode_steps[d] = [reset_data[i : i + 1][None, ...]] else: @@ -388,15 +384,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.actor_type = "task" if buffer_type == "sequential": local_data = rb.sample( - cfg.per_rank_batch_size, - sequence_length=cfg.per_rank_sequence_length, + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, n_samples=cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps, ).to(device) else: local_data = rb.sample( - cfg.per_rank_batch_size, + cfg.algo.per_rank_batch_size, n_samples=cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps, @@ -418,7 +414,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), aggregator, cfg, actions_dim=actions_dim, @@ -488,7 +484,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "critic_task_optimizer": critic_task_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * world_size, + "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "last_log": last_log, "last_checkpoint": last_checkpoint, diff --git a/sheeprl/algos/p2e_dv3/evaluate.py b/sheeprl/algos/p2e_dv3/evaluate.py index 910f57db..97b86112 100644 --- a/sheeprl/algos/p2e_dv3/evaluate.py +++ b/sheeprl/algos/p2e_dv3/evaluate.py @@ -34,13 +34,9 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 
10db1464..a5f1b1a6 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -109,15 +109,15 @@ def train( is_continuous (bool): whether or not are continuous actions. actions_dim (Sequence[int]): the actions dimension. """ - batch_size = cfg.per_rank_batch_size - sequence_length = cfg.per_rank_sequence_length + batch_size = cfg.algo.per_rank_batch_size + sequence_length = cfg.algo.per_rank_sequence_length validate_args = cfg.distribution.validate_args recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size stochastic_size = cfg.algo.world_model.stochastic_size discrete_size = cfg.algo.world_model.discrete_size device = fabric.device - batch_obs = {k: data[k] / 255.0 for k in cfg.cnn_keys.encoder} - batch_obs.update({k: data[k] for k in cfg.mlp_keys.encoder}) + batch_obs = {k: data[k] / 255.0 for k in cfg.algo.cnn_keys.encoder} + batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) data["is_first"][0, :] = torch.tensor([1.0], device=fabric.device).expand_as(data["is_first"][0, :]) # Given how the environment interaction works, we remove the last actions @@ -151,12 +151,13 @@ def train( # compute the distribution over the reconstructed observations po = { - k: MSEDistribution(reconstructed_obs[k], dims=len(reconstructed_obs[k].shape[2:])) for k in cfg.cnn_keys.decoder + k: MSEDistribution(reconstructed_obs[k], dims=len(reconstructed_obs[k].shape[2:])) + for k in cfg.algo.cnn_keys.decoder } po.update( { k: SymlogDistribution(reconstructed_obs[k], dims=len(reconstructed_obs[k].shape[2:])) - for k in cfg.mlp_keys.decoder + for k in cfg.algo.mlp_keys.decoder } ) # Compute the distribution over the rewards @@ -563,7 +564,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // world_size # These arguments cannot be changed cfg.env.frame_stack = 1 @@ -603,32 +604,28 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) + if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys of the encoder and decoder must not be disjointed") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. 
" - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder ( world_model, @@ -813,7 +810,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -853,7 +850,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} for k in obs_keys: torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: # Images stay uint8 to save space torch_obs = torch_obs.float() step_data[k] = torch_obs @@ -889,7 +886,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255.0 else: preprocessed_obs[k] = v[None, ...].to(device) @@ -945,7 +942,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in obs_keys: next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) step_data[k] = next_obs[k] - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = step_data[k].float() @@ -964,7 +961,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for k in real_next_obs.keys(): if k in obs_keys: reset_data[k] = real_next_obs[k][dones_idxes] - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: reset_data[k] = reset_data[k].float() reset_data["dones"] = torch.ones(reset_envs, 1).float() reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)).float() @@ -983,8 +980,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: local_data = rb.sample( - 
cfg.per_rank_batch_size, - sequence_length=cfg.per_rank_sequence_length, + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, n_samples=cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps, @@ -1012,7 +1009,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), aggregator, cfg, ensembles=ensembles, @@ -1100,7 +1097,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "ensemble_optimizer": ensemble_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * world_size, + "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "actor_exploration_optimizer": actor_exploration_optimizer.state_dict(), "last_log": last_log, diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index a28b4eab..633556de 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -41,7 +41,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Finetuning that was interrupted for some reason if resume_from_checkpoint: state = fabric.load(pathlib.Path(cfg.checkpoint.resume_from)) - cfg.per_rank_batch_size = state["batch_size"] // world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // world_size else: state = fabric.load(ckpt_path) @@ -68,7 +68,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if cfg.buffer.load_from_exploration and exploration_cfg.buffer.checkpoint: cfg.env.num_envs = exploration_cfg.env.num_envs # There must be the same cnn and mlp keys during exploration and finetuning - cfg.cnn_keys = exploration_cfg.cnn_keys + cfg.algo.cnn_keys = exploration_cfg.algo.cnn_keys - cfg.mlp_keys = exploration_cfg.mlp_keys + cfg.algo.mlp_keys = exploration_cfg.algo.mlp_keys # These arguments cannot be changed @@ -108,32 +108,28 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) + if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys of the encoder and decoder must not be disjointed") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. 
" - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder ( world_model, @@ -223,7 +219,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if resume_from_checkpoint and not cfg.buffer.checkpoint: learning_starts += start_step @@ -263,7 +259,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} for k in obs_keys: torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: # Images stay uint8 to save space torch_obs = torch_obs.float() step_data[k] = torch_obs @@ -283,7 +279,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) / 255.0 else: preprocessed_obs[k] = v[None, ...].to(device) @@ -339,7 +335,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if k in obs_keys: next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) step_data[k] = next_obs[k] - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = step_data[k].float() @@ -358,7 +354,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): for k in real_next_obs.keys(): if k in obs_keys: reset_data[k] = real_next_obs[k][dones_idxes] - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: reset_data[k] = reset_data[k].float() reset_data["dones"] = torch.ones(reset_envs, 1).float() reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)).float() @@ -380,8 +376,8 
@@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.actor = actor_task.module player.actor_type = "task" local_data = rb.sample( - cfg.per_rank_batch_size, - sequence_length=cfg.per_rank_sequence_length, + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, n_samples=cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps, @@ -403,7 +399,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), aggregator, cfg, is_continuous=is_continuous, @@ -475,7 +471,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "critic_task_optimizer": critic_task_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * world_size, + "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "last_log": last_log, "last_checkpoint": last_checkpoint, diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py index bf69fa4a..35220f80 100644 --- a/sheeprl/algos/ppo/evaluate.py +++ b/sheeprl/algos/ppo/evaluate.py @@ -32,13 +32,13 @@ def evaluate_ppo(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []: + if cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder == []: raise RuntimeError( "You should specify at least one CNN keys or MLP keys from the cli: " "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" ) - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) is_continuous = isinstance(env.action_space, gym.spaces.Box) is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) @@ -54,8 +54,8 @@ def evaluate_ppo(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): encoder_cfg=cfg.algo.encoder, actor_cfg=cfg.algo.actor, critic_cfg=cfg.algo.critic, - cnn_keys=cfg.cnn_keys.encoder, - mlp_keys=cfg.mlp_keys.encoder, + cnn_keys=cfg.algo.cnn_keys.encoder, + mlp_keys=cfg.algo.mlp_keys.encoder, screen_size=cfg.env.screen_size, distribution_cfg=cfg.distribution, is_continuous=is_continuous, diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 3d247695..49e62ece 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -49,7 +49,7 @@ def train( ) else: sampler = RandomSampler(indexes) - sampler = BatchSampler(sampler, batch_size=cfg.per_rank_batch_size, drop_last=False) + sampler = BatchSampler(sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False) for epoch in range(cfg.algo.update_epochs): if cfg.buffer.share_data: @@ -57,8 +57,8 @@ def train( for batch_idxes in sampler: batch = data[batch_idxes] normalized_obs = { - k: batch[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else batch[k] - for k in cfg.mlp_keys.encoder + cfg.cnn_keys.encoder + k: batch[k] / 255 - 0.5 if k in cfg.algo.cnn_keys.encoder else batch[k] + for k in 
cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder } _, logprobs, entropy, new_values = agent( normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1) @@ -128,7 +128,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Resume from checkpoint if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size # Create TensorBoardLogger. This will create the logger only on the # rank-0 process @@ -157,15 +157,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []: + if cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder == []: raise RuntimeError( "You should specify at least one CNN keys or MLP keys from the cli: " "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) @@ -181,8 +181,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): encoder_cfg=cfg.algo.encoder, actor_cfg=cfg.algo.actor, critic_cfg=cfg.algo.critic, - cnn_keys=cfg.cnn_keys.encoder, - mlp_keys=cfg.mlp_keys.encoder, + cnn_keys=cfg.algo.cnn_keys.encoder, + mlp_keys=cfg.algo.mlp_keys.encoder, screen_size=cfg.env.screen_size, distribution_cfg=cfg.distribution, is_continuous=is_continuous, @@ -229,7 +229,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) - num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: @@ -260,9 +260,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): next_obs = {} for k in obs_keys: torch_obs = torch.as_tensor(obs[k]).to(fabric.device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.mlp_keys.encoder: + elif k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() step_data[k] = torch_obs next_obs[k] = torch_obs @@ -277,7 +277,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = { - k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys + k: next_obs[k] / 255 - 0.5 if k in cfg.algo.cnn_keys.encoder else next_obs[k] for k in obs_keys } actions, logprobs, _, values = agent.module(normalized_obs) if is_continuous: @@ -302,7 +302,7 @@ def main(fabric: Fabric, cfg: Dict[str, 
Any]): for i, truncated_env in enumerate(truncated_envs): for k, v in info["final_observation"][truncated_env].items(): torch_v = torch.as_tensor(v, dtype=torch.float32, device=device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_v = torch_v.view(len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5 real_next_obs[k][i] = torch_v with torch.no_grad(): @@ -328,10 +328,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Update the observation and dones next_obs = {} for k in obs_keys: - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch.as_tensor(obs[k], device=device) torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.mlp_keys.encoder: + elif k in cfg.algo.mlp_keys.encoder: torch_obs = torch.as_tensor(obs[k], device=device, dtype=torch.float32) step_data[k] = torch_obs next_obs[k] = torch_obs @@ -350,7 +350,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): normalized_obs = { - k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys + k: next_obs[k] / 255 - 0.5 if k in cfg.algo.cnn_keys.encoder else next_obs[k] for k in obs_keys } next_values = agent.module.get_value(normalized_obs) returns, advantages = gae( @@ -442,7 +442,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 9af077b6..98883e91 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -43,7 +43,7 @@ def player( # Resume from checkpoint if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // (world_collective.world_size - 1) + cfg.algo.per_rank_batch_size = state["batch_size"] // (world_collective.world_size - 1) # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -64,15 +64,15 @@ def player( if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []: + if cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder == []: raise RuntimeError( "You should specify at least one CNN keys or MLP keys from the cli: " "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) @@ -93,8 +93,8 @@ def player( "encoder_cfg": cfg.algo.encoder, "actor_cfg": cfg.algo.actor, "critic_cfg": cfg.algo.critic, - "cnn_keys": cfg.cnn_keys.encoder, - "mlp_keys": 
cfg.mlp_keys.encoder, + "cnn_keys": cfg.algo.cnn_keys.encoder, + "mlp_keys": cfg.algo.mlp_keys.encoder, "screen_size": cfg.env.screen_size, "distribution_cfg": cfg.distribution, "is_continuous": is_continuous, @@ -135,7 +135,7 @@ def player( last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps) - num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: @@ -173,9 +173,9 @@ def player( next_obs = {} for k in obs_keys: torch_obs = torch.as_tensor(obs[k]).to(fabric.device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.mlp_keys.encoder: + elif k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() step_data[k] = torch_obs next_obs[k] = torch_obs @@ -192,7 +192,8 @@ def player( with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = { - k: next_obs[k] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys + k: next_obs[k] / 255.0 - 0.5 if k in cfg.algo.cnn_keys.encoder else next_obs[k] + for k in obs_keys } actions, logprobs, _, values = agent(normalized_obs) if is_continuous: @@ -217,7 +218,7 @@ def player( for i, truncated_env in enumerate(truncated_envs): for k, v in info["final_observation"][truncated_env].items(): torch_v = torch.as_tensor(v, dtype=torch.float32, device=device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_v = torch_v.view(len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5 real_next_obs[k][i] = torch_v with torch.no_grad(): @@ -243,10 +244,10 @@ def player( # Update the observation and dones next_obs = {} for k in obs_keys: - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch.as_tensor(obs[k], device=device) torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.mlp_keys.encoder: + elif k in cfg.algo.mlp_keys.encoder: torch_obs = torch.as_tensor(obs[k], device=device, dtype=torch.float32) step_data[k] = torch_obs next_obs[k] = torch_obs @@ -262,7 +263,9 @@ def player( fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) - normalized_obs = {k: next_obs[k] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys} + normalized_obs = { + k: next_obs[k] / 255.0 - 0.5 if k in cfg.algo.cnn_keys.encoder else next_obs[k] for k in obs_keys + } next_values = agent.get_value(normalized_obs) returns, advantages = gae( rb["rewards"], @@ -443,7 +446,7 @@ def trainer( "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None, "update": update, - "batch_size": cfg.per_rank_batch_size * (world_collective.world_size - 1), + "batch_size": cfg.algo.per_rank_batch_size * (world_collective.world_size - 1), "last_log": last_log, "last_checkpoint": last_checkpoint, } @@ -462,7 +465,7 @@ def trainer( # Prepare sampler indexes = list(range(data.shape[0])) - sampler = BatchSampler(RandomSampler(indexes), batch_size=cfg.per_rank_batch_size, 
drop_last=False) + sampler = BatchSampler(RandomSampler(indexes), batch_size=cfg.algo.per_rank_batch_size, drop_last=False) # Start training with timer( @@ -476,7 +479,7 @@ def trainer( batch = data[batch_idxes] normalized_obs = { k: batch[k] / 255.0 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k] - for k in cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + for k in cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder } _, logprobs, entropy, new_values = agent( normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1) @@ -578,7 +581,7 @@ def trainer( "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None, "update": update, - "batch_size": cfg.per_rank_batch_size * (world_collective.world_size - 1), + "batch_size": cfg.algo.per_rank_batch_size * (world_collective.world_size - 1), "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/ppo/utils.py b/sheeprl/algos/ppo/utils.py index 7f52a669..a923a808 100644 --- a/sheeprl/algos/ppo/utils.py +++ b/sheeprl/algos/ppo/utils.py @@ -18,11 +18,11 @@ def test(agent: PPOAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): o = env.reset(seed=cfg.seed)[0] obs = {} for k in o.keys(): - if k in cfg.mlp_keys.encoder + cfg.cnn_keys.encoder: + if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() obs[k] = torch_obs @@ -39,11 +39,11 @@ def test(agent: PPOAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): cumulative_rew += reward obs = {} for k in o.keys(): - if k in cfg.mlp_keys.encoder + cfg.cnn_keys.encoder: + if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() obs[k] = torch_obs diff --git a/sheeprl/algos/ppo_recurrent/evaluate.py b/sheeprl/algos/ppo_recurrent/evaluate.py index 9ea34a01..43321fd0 100644 --- a/sheeprl/algos/ppo_recurrent/evaluate.py +++ b/sheeprl/algos/ppo_recurrent/evaluate.py @@ -32,13 +32,13 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []: + if cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder == []: raise RuntimeError( "You should specify at least one CNN keys or MLP keys from the cli: " "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" ) - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) is_continuous = isinstance(env.action_space, gym.spaces.Box) is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) @@ -55,8 +55,8 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): rnn_cfg=cfg.algo.rnn, actor_cfg=cfg.algo.actor, 
critic_cfg=cfg.algo.critic, - cnn_keys=cfg.cnn_keys.encoder, - mlp_keys=cfg.mlp_keys.encoder, + cnn_keys=cfg.algo.cnn_keys.encoder, + mlp_keys=cfg.algo.mlp_keys.encoder, is_continuous=is_continuous, distribution_cfg=cfg.distribution, num_envs=cfg.env.num_envs, diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 646e894f..e270c673 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -39,8 +39,8 @@ def train( cfg: Dict[str, Any], ): num_sequences = data.shape[1] - if cfg.per_rank_num_batches > 0: - batch_size = num_sequences // cfg.per_rank_num_batches + if cfg.algo.per_rank_num_batches > 0: + batch_size = num_sequences // cfg.algo.per_rank_num_batches batch_size = batch_size if batch_size > 0 else num_sequences else: batch_size = 1 @@ -54,11 +54,11 @@ def train( for idxes in sampler: batch = data[:, idxes] mask = batch["mask"].unsqueeze(-1) - for k in cfg.cnn_keys.encoder: + for k in cfg.algo.cnn_keys.encoder: batch[k] = batch[k] / 255.0 - 0.5 _, logprobs, entropies, values, _ = agent( - {k: batch[k] for k in set(cfg.cnn_keys.encoder + cfg.mlp_keys.encoder)}, + {k: batch[k] for k in set(cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder)}, prev_actions=batch["prev_actions"], prev_states=(batch["prev_hx"][:1], batch["prev_cx"][:1]), actions=torch.split(batch["actions"], agent.actions_dim, dim=-1), @@ -134,7 +134,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Resume from checkpoint if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size # Create TensorBoardLogger. 
This will create the logger only on the # rank-0 process @@ -163,15 +163,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []: + if cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder == []: raise RuntimeError( "You should specify at least one CNN keys or MLP keys from the cli: " "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) @@ -189,8 +189,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rnn_cfg=cfg.algo.rnn, actor_cfg=cfg.algo.actor, critic_cfg=cfg.algo.critic, - cnn_keys=cfg.cnn_keys.encoder, - mlp_keys=cfg.mlp_keys.encoder, + cnn_keys=cfg.algo.cnn_keys.encoder, + mlp_keys=cfg.algo.mlp_keys.encoder, is_continuous=is_continuous, distribution_cfg=cfg.distribution, num_envs=cfg.env.num_envs, @@ -224,7 +224,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data = TensorDict({}, batch_size=[1, cfg.env.num_envs], device=device) # Check that `rollout_steps` = k * `per_rank_sequence_length` - if cfg.algo.rollout_steps % cfg.per_rank_sequence_length != 0: + if cfg.algo.rollout_steps % cfg.algo.per_rank_sequence_length != 0: pass # Global variables @@ -235,7 +235,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) - num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: @@ -266,9 +266,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = {} for k in obs_keys: torch_obs = torch.as_tensor(o[k], device=fabric.device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.mlp_keys.encoder: + elif k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() step_data[k] = torch_obs[None] # [Seq_len, Batch_size, D] --> [1, num_envs, D] obs[k] = torch_obs @@ -287,7 +287,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = { - k: obs[k][None] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys + k: obs[k][None] / 255.0 - 0.5 if k in cfg.algo.cnn_keys.encoder else obs[k][None] + for k in obs_keys } # [Seq_len, Batch_size, D] --> [1, num_envs, D] actions, logprobs, _, values, states = agent.module( normalized_obs, prev_actions=prev_actions, prev_states=prev_states @@ -315,7 +316,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for i, 
truncated_env in enumerate(truncated_envs): for k, v in info["final_observation"][truncated_env].items(): torch_v = torch.as_tensor(v, dtype=torch.float32, device=device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_v = torch_v.view(1, len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5 real_next_obs[k][0, i] = torch_v with torch.no_grad(): @@ -351,10 +352,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Update the observation obs = {} for k in obs_keys: - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch.as_tensor(next_obs[k], device=device) torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.mlp_keys.encoder: + elif k in cfg.algo.mlp_keys.encoder: torch_obs = torch.as_tensor(next_obs[k], device=device, dtype=torch.float32) step_data[k] = torch_obs[None] obs[k] = torch_obs @@ -378,7 +379,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): normalized_obs = { - k: obs[k][None] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys + k: obs[k][None] / 255.0 - 0.5 if k in cfg.algo.cnn_keys.encoder else obs[k][None] for k in obs_keys } feat = agent.module.feature_extractor(normalized_obs) rnn_out, _ = agent.module.rnn(torch.cat((feat, actions), dim=-1), states) @@ -418,8 +419,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): episodes.append(episode) start = stop + 1 # 2. Split every episode into sequences of length `per_rank_sequence_length` - if cfg.per_rank_sequence_length is not None and cfg.per_rank_sequence_length > 0: - sequences = list(itertools.chain.from_iterable([ep.split(cfg.per_rank_sequence_length) for ep in episodes])) + if cfg.algo.per_rank_sequence_length is not None and cfg.algo.per_rank_sequence_length > 0: + sequences = list( + itertools.chain.from_iterable([ep.split(cfg.algo.per_rank_sequence_length) for ep in episodes]) + ) else: sequences = episodes padded_sequences = pad_sequence(sequences, batch_first=False, return_mask=True) # [Seq_len, Num_seq, *] @@ -487,7 +490,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None, "update": update * world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/ppo_recurrent/utils.py b/sheeprl/algos/ppo_recurrent/utils.py index d363d297..18b0cc4e 100644 --- a/sheeprl/algos/ppo_recurrent/utils.py +++ b/sheeprl/algos/ppo_recurrent/utils.py @@ -24,12 +24,12 @@ def test(agent: "RecurrentPPOAgent", fabric: Fabric, cfg: Dict[str, Any], log_di o = env.reset(seed=cfg.seed)[0] next_obs = { k: torch.tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1, *o[k].shape[-2:]) / 255 - for k in cfg.cnn_keys.encoder + for k in cfg.algo.cnn_keys.encoder } next_obs.update( { k: torch.tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1) - for k in cfg.mlp_keys.encoder + for k in cfg.algo.mlp_keys.encoder } ) state = ( @@ -54,10 +54,10 @@ def test(agent: "RecurrentPPOAgent", fabric: Fabric, cfg: Dict[str, Any], log_di with fabric.device: next_obs = { k: torch.as_tensor(o[k], dtype=torch.float32).view(1, 1, -1, *o[k].shape[-2:]) / 255 - for k in cfg.cnn_keys.encoder + for k in cfg.algo.cnn_keys.encoder } next_obs.update( - {k: 
torch.as_tensor(o[k], dtype=torch.float32).view(1, 1, -1) for k in cfg.mlp_keys.encoder} + {k: torch.as_tensor(o[k], dtype=torch.float32).view(1, 1, -1) for k in cfg.algo.mlp_keys.encoder} ) if cfg.dry_run: diff --git a/sheeprl/algos/sac/evaluate.py b/sheeprl/algos/sac/evaluate.py index 4e7a5861..3fbbdc35 100644 --- a/sheeprl/algos/sac/evaluate.py +++ b/sheeprl/algos/sac/evaluate.py @@ -35,18 +35,18 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): raise ValueError("Only continuous action space is supported for the SAC agent") if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if len(cfg.mlp_keys.encoder) == 0: + if len(cfg.algo.mlp_keys.encoder) == 0: raise RuntimeError("You should specify at least one MLP key for the encoder: `mlp_keys.encoder=[state]`") - for k in cfg.mlp_keys.encoder: + for k in cfg.algo.mlp_keys.encoder: if len(observation_space[k].shape) > 1: raise ValueError( "Only environments with vector-only observations are supported by the SAC agent. " f"Provided environment: {cfg.env.id}" ) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) actor = SACActor( observation_dim=obs_dim, action_dim=act_dim, diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 9b08fd85..96894664 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -98,11 +98,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Resume from checkpoint if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size - if len(cfg.cnn_keys.encoder) > 0: + if len(cfg.algo.cnn_keys.encoder) > 0: warnings.warn("SAC algorithm cannot allow to use images as observations, the CNN keys will be ignored") - cfg.cnn_keys.encoder = [] + cfg.algo.cnn_keys.encoder = [] # Create TensorBoardLogger. This will create the logger only on the # rank-0 process @@ -133,20 +133,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): raise ValueError("Only continuous action space is supported for the SAC agent") if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if len(cfg.mlp_keys.encoder) == 0: + if len(cfg.algo.mlp_keys.encoder) == 0: raise RuntimeError("You should specify at least one MLP key for the encoder: `mlp_keys.encoder=[state]`") - for k in cfg.mlp_keys.encoder: + for k in cfg.algo.mlp_keys.encoder: if len(observation_space[k].shape) > 1: raise ValueError( "Only environments with vector-only observations are supported by the SAC agent. 
" f"Provided environment: {cfg.env.id}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) # Define the agent and the optimizer and setup sthem with Fabric act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) actor = SACActor( observation_dim=obs_dim, action_dim=act_dim, @@ -210,7 +210,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) policy_steps_per_update = int(cfg.env.num_envs * world_size) - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -235,7 +235,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Get the first environment observation and start the optimization o = envs.reset(seed=cfg.seed)[0] obs = torch.cat( - [torch.tensor(o[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1 + [torch.tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) # [N_envs, N_obs] for update in range(start_step, num_updates + 1): @@ -274,10 +274,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with device: next_obs = torch.cat( - [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1 + [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) # [N_envs, N_obs] real_next_obs = torch.cat( - [torch.tensor(real_next_obs[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1 + [torch.tensor(real_next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) # [N_envs, N_obs] actions = torch.tensor(actions, dtype=torch.float32).view(cfg.env.num_envs, -1) rewards = torch.tensor(rewards, dtype=torch.float32).view(cfg.env.num_envs, -1) @@ -300,7 +300,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # We sample one time to reduce the communications between processes sample = rb.sample( - training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, + training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs, ) # [G*B, 1] gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] @@ -315,11 +315,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): drop_last=False, ) sampler: BatchSampler = BatchSampler( - sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False + sampler=dist_sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) else: sampler = BatchSampler( - sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False + sampler=range(len(gathered_data)), batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) # Start training @@ -380,7 +380,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "batch_size": 
cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index ae565c45..2752057f 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -42,11 +42,11 @@ def player( # Resume from checkpoint if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size - if len(cfg.cnn_keys.encoder) > 0: + if len(cfg.algo.cnn_keys.encoder) > 0: warnings.warn("SAC algorithm cannot allow to use images as observations, the CNN keys will be ignored") - cfg.cnn_keys.encoder = [] + cfg.algo.cnn_keys.encoder = [] # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -69,16 +69,16 @@ def player( raise ValueError("Only continuous action space is supported for the SAC agent") if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if len(cfg.mlp_keys.encoder) == 0: + if len(cfg.algo.mlp_keys.encoder) == 0: raise RuntimeError("You should specify at least one MLP key for the encoder: `mlp_keys.encoder=[state]`") - for k in cfg.mlp_keys.encoder: + for k in cfg.algo.mlp_keys.encoder: if len(observation_space[k].shape) > 1: raise ValueError( "Only environments with vector-only observations are supported by the SAC agent. " f"Provided environment: {cfg.env.id}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) # Send (possibly updated, by the make_env method for example) cfg to the trainers cfg.checkpoint.log_dir = log_dir @@ -86,7 +86,7 @@ def player( # Define the agent and the optimizer and setup them with Fabric act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) actor = SACActor( observation_dim=obs_dim, action_dim=act_dim, @@ -134,7 +134,7 @@ def player( last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs) - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -159,7 +159,7 @@ def player( # Get the first environment observation and start the optimization o = envs.reset(seed=cfg.seed)[0] obs = torch.cat( - [torch.tensor(o[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1 + [torch.tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) # [N_envs, N_obs] for update in range(start_step, num_updates + 1): @@ -198,10 +198,10 @@ def player( with device: next_obs = torch.cat( - [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1 + [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) # [N_envs, N_obs] real_next_obs = 
torch.cat( - [torch.tensor(real_next_obs[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1 + [torch.tensor(real_next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) # [N_envs, N_obs] actions = torch.tensor(actions, dtype=torch.float32).view(cfg.env.num_envs, -1) rewards = torch.tensor(rewards, dtype=torch.float32).view(cfg.env.num_envs, -1) # [N_envs, 1] @@ -230,9 +230,12 @@ def player( # Sample data to be sent to the trainers training_steps = learning_starts if update == learning_starts else 1 chunks = rb.sample( - training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size * (fabric.world_size - 1), + training_steps + * cfg.algo.per_rank_gradient_steps + * cfg.algo.per_rank_batch_size + * (fabric.world_size - 1), sample_next_obs=cfg.buffer.sample_next_obs, - ).split(training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size) + ).split(training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size) world_collective.scatter_object_list([None], [None] + chunks, src=0) # Wait the trainers to finish @@ -338,7 +341,7 @@ def trainer( # Define the agent and the optimizer and setup them with Fabric act_dim = prod(envs.single_action_space.shape) - obs_dim = sum([prod(envs.single_observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + obs_dim = sum([prod(envs.single_observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) actor = SACActor( observation_dim=obs_dim, @@ -414,7 +417,7 @@ def trainer( "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), "update": update, - "batch_size": cfg.per_rank_batch_size * (world_collective.world_size - 1), + "batch_size": cfg.algo.per_rank_batch_size * (world_collective.world_size - 1), "last_log": last_log, "last_checkpoint": last_checkpoint, } @@ -428,7 +431,7 @@ def trainer( ) return data = make_tensordict(data, device=device) - sampler = BatchSampler(range(len(data)), batch_size=cfg.per_rank_batch_size, drop_last=False) + sampler = BatchSampler(range(len(data)), batch_size=cfg.algo.per_rank_batch_size, drop_last=False) # Start training with timer( @@ -486,7 +489,7 @@ def trainer( "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), "update": update, - "batch_size": cfg.per_rank_batch_size * (world_collective.world_size - 1), + "batch_size": cfg.algo.per_rank_batch_size * (world_collective.world_size - 1), "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/sac/utils.py b/sheeprl/algos/sac/utils.py index ca11198b..d10122e2 100644 --- a/sheeprl/algos/sac/utils.py +++ b/sheeprl/algos/sac/utils.py @@ -23,7 +23,9 @@ def test(actor: SACActor, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): cumulative_rew = 0 with fabric.device: o = env.reset(seed=cfg.seed)[0] - next_obs = torch.cat([torch.tensor(o[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1).unsqueeze( + next_obs = torch.cat( + [torch.tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 + ).unsqueeze( 0 ) # [N_envs, N_obs] while not done: @@ -35,7 +37,9 @@ def test(actor: SACActor, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): done = done or truncated cumulative_rew += reward with fabric.device: - next_obs = torch.cat([torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.mlp_keys.encoder], dim=-1) + next_obs = torch.cat( + [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 + ) if 
cfg.dry_run: done = True diff --git a/sheeprl/algos/sac_ae/evaluate.py b/sheeprl/algos/sac_ae/evaluate.py index f2acb58f..fc87e50c 100644 --- a/sheeprl/algos/sac_ae/evaluate.py +++ b/sheeprl/algos/sac_ae/evaluate.py @@ -43,13 +43,9 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space = env.action_space if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) act_dim = prod(action_space.shape) target_entropy = -act_dim @@ -57,29 +53,29 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): # Define the encoder and decoder and setup them with fabric. # Then we will set the critic encoder and actor decoder as the unwrapped encoder module: # we do not need it wrapped with the strategy inside actor and critic - cnn_channels = [prod(observation_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder] - mlp_dims = [observation_space[k].shape[0] for k in cfg.mlp_keys.encoder] + cnn_channels = [prod(observation_space[k].shape[:-2]) for k in cfg.algo.cnn_keys.encoder] + mlp_dims = [observation_space[k].shape[0] for k in cfg.algo.mlp_keys.encoder] cnn_encoder = ( CNNEncoder( in_channels=sum(cnn_channels), features_dim=cfg.algo.encoder.features_dim, - keys=cfg.cnn_keys.encoder, + keys=cfg.algo.cnn_keys.encoder, screen_size=cfg.env.screen_size, cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, ) - if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 + if cfg.algo.cnn_keys.encoder is not None and len(cfg.algo.cnn_keys.encoder) > 0 else None ) mlp_encoder = ( MLPEncoder( sum(mlp_dims), - cfg.mlp_keys.encoder, + cfg.algo.mlp_keys.encoder, cfg.algo.encoder.dense_units, cfg.algo.encoder.mlp_layers, eval(cfg.algo.encoder.dense_act), cfg.algo.encoder.layer_norm, ) - if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 + if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0 else None ) encoder = MultiEncoder(cnn_encoder, mlp_encoder) diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 5b8a471b..9e18abbe 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -66,8 +66,8 @@ def train( data = data.to(fabric.device) normalized_obs = {} normalized_next_obs = {} - for k in cfg.cnn_keys.encoder + cfg.mlp_keys.encoder: - if k in cfg.cnn_keys.encoder: + for k in cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: normalized_obs[k] = data[k] / 255.0 normalized_next_obs[k] = data[f"next_{k}"] / 255.0 else: @@ -117,8 +117,8 @@ def train( hidden = encoder(normalized_obs) reconstruction = decoder(hidden) reconstruction_loss = 0 - for k in cfg.cnn_keys.decoder + cfg.mlp_keys.decoder: - target = preprocess_obs(data[k], bits=5) if k in cfg.cnn_keys.decoder else data[k] + for k in cfg.algo.cnn_keys.decoder + cfg.algo.mlp_keys.decoder: + target = preprocess_obs(data[k], bits=5) if k in cfg.algo.cnn_keys.decoder else data[k] reconstruction_loss += ( 
F.mse_loss(target, reconstruction[k]) # Reconstruction + cfg.algo.decoder.l2_lambda * (0.5 * hidden.pow(2).sum(1)).mean() # L2 penalty on the hidden state @@ -151,7 +151,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Resume from checkpoint if cfg.checkpoint.resume_from: state = fabric.load(cfg.checkpoint.resume_from) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size # These arguments cannot be changed cfg.env.screen_size = 64 @@ -183,31 +183,27 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") - if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: - raise RuntimeError( - "You should specify at least one CNN keys or MLP keys from the cli: " - "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" - ) + if ( - len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0 - and len(set(cfg.mlp_keys.encoder).intersection(set(cfg.mlp_keys.decoder))) == 0 + len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 + and len(set(cfg.algo.mlp_keys.encoder).intersection(set(cfg.algo.mlp_keys.decoder))) == 0 ): raise RuntimeError("The CNN keys or the MLP keys of the encoder and decoder must not be disjoint") - if len(set(cfg.cnn_keys.decoder) - set(cfg.cnn_keys.encoder)) > 0: + if len(set(cfg.algo.cnn_keys.decoder) - set(cfg.algo.cnn_keys.encoder)) > 0: raise RuntimeError( "The CNN keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.cnn_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.cnn_keys.decoder))}" ) - if len(set(cfg.mlp_keys.decoder) - set(cfg.mlp_keys.encoder)) > 0: + if len(set(cfg.algo.mlp_keys.decoder) - set(cfg.algo.mlp_keys.encoder)) > 0: raise RuntimeError( "The MLP keys of the decoder must be contained in the encoder ones. " - f"Those keys are decoded without being encoded: {list(set(cfg.mlp_keys.decoder))}" + f"Those keys are decoded without being encoded: {list(set(cfg.algo.mlp_keys.decoder))}" ) if cfg.metric.log_level > 0: - fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) - fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - fabric.print("Decoder CNN keys:", cfg.cnn_keys.decoder) - fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) + fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) # Define the agent and the optimizer and setup them with Fabric act_dim = prod(envs.single_action_space.shape) @@ -216,29 +212,29 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Define the encoder and decoder and setup them with fabric. 
# Then we will set the critic encoder and actor decoder as the unwrapped encoder module: # we do not need it wrapped with the strategy inside actor and critic - cnn_channels = [prod(envs.single_observation_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder] - mlp_dims = [envs.single_observation_space[k].shape[0] for k in cfg.mlp_keys.encoder] + cnn_channels = [prod(envs.single_observation_space[k].shape[:-2]) for k in cfg.algo.cnn_keys.encoder] + mlp_dims = [envs.single_observation_space[k].shape[0] for k in cfg.algo.mlp_keys.encoder] cnn_encoder = ( CNNEncoder( in_channels=sum(cnn_channels), features_dim=cfg.algo.encoder.features_dim, - keys=cfg.cnn_keys.encoder, + keys=cfg.algo.cnn_keys.encoder, screen_size=cfg.env.screen_size, cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, ) - if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 + if cfg.algo.cnn_keys.encoder is not None and len(cfg.algo.cnn_keys.encoder) > 0 else None ) mlp_encoder = ( MLPEncoder( sum(mlp_dims), - cfg.mlp_keys.encoder, + cfg.algo.mlp_keys.encoder, cfg.algo.encoder.dense_units, cfg.algo.encoder.mlp_layers, eval(cfg.algo.encoder.dense_act), cfg.algo.encoder.layer_norm, ) - if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 + if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0 else None ) encoder = MultiEncoder(cnn_encoder, mlp_encoder) @@ -246,25 +242,25 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): CNNDecoder( cnn_encoder.conv_output_shape, features_dim=encoder.output_dim, - keys=cfg.cnn_keys.decoder, + keys=cfg.algo.cnn_keys.decoder, channels=cnn_channels, screen_size=cfg.env.screen_size, cnn_channels_multiplier=cfg.algo.decoder.cnn_channels_multiplier, ) - if cfg.cnn_keys.decoder is not None and len(cfg.cnn_keys.decoder) > 0 + if cfg.algo.cnn_keys.decoder is not None and len(cfg.algo.cnn_keys.decoder) > 0 else None ) mlp_decoder = ( MLPDecoder( encoder.output_dim, mlp_dims, - cfg.mlp_keys.decoder, + cfg.algo.mlp_keys.decoder, cfg.algo.decoder.dense_units, cfg.algo.decoder.mlp_layers, eval(cfg.algo.decoder.dense_act), cfg.algo.decoder.layer_norm, ) - if cfg.mlp_keys.decoder is not None and len(cfg.mlp_keys.decoder) > 0 + if cfg.algo.mlp_keys.decoder is not None and len(cfg.algo.mlp_keys.decoder) > 0 else None ) decoder = MultiDecoder(cnn_decoder, mlp_decoder) @@ -338,7 +334,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device=fabric.device if cfg.buffer.memmap else "cpu", memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - obs_keys=cfg.cnn_keys.encoder + cfg.mlp_keys.encoder, + obs_keys=cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder, ) if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): @@ -358,7 +354,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 time.time() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 + num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step @@ -383,11 +379,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): o = envs.reset(seed=cfg.seed)[0] # [N_envs, 
N_obs] obs = {} for k in o.keys(): - if k in cfg.cnn_keys.encoder + cfg.mlp_keys.encoder: + if k in cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder: torch_obs = torch.from_numpy(o[k]).to(fabric.device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() obs[k] = torch_obs @@ -401,7 +397,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions = envs.action_space.sample() else: with torch.no_grad(): - normalized_obs = {k: v / 255 if k in cfg.cnn_keys.encoder else v for k, v in obs.items()} + normalized_obs = {k: v / 255 if k in cfg.algo.cnn_keys.encoder else v for k, v in obs.items()} actions, _ = agent.actor.module(normalized_obs) actions = actions.cpu().numpy() o, rewards, dones, truncated, infos = envs.step(actions) @@ -428,19 +424,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): next_obs = {} for k in real_next_obs.keys(): next_obs[k] = torch.from_numpy(o[k]).to(fabric.device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: next_obs[k] = next_obs[k].view(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: next_obs[k] = next_obs[k].float() step_data[k] = obs[k] if not cfg.buffer.sample_next_obs: step_data[f"next_{k}"] = torch.from_numpy(real_next_obs[k]).to(fabric.device) - if k in cfg.cnn_keys.encoder: + if k in cfg.algo.cnn_keys.encoder: step_data[f"next_{k}"] = step_data[f"next_{k}"].view( cfg.env.num_envs, -1, *step_data[f"next_{k}"].shape[-2:] ) - if k in cfg.mlp_keys.encoder: + if k in cfg.algo.mlp_keys.encoder: step_data[f"next_{k}"] = step_data[f"next_{k}"].float() actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float().to(fabric.device) rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float().to(fabric.device) @@ -460,7 +456,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # We sample one time to reduce the communications between processes sample = rb.sample( - training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, + training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs, ) # [G*B, 1] gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] @@ -475,11 +471,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): drop_last=False, ) sampler: BatchSampler = BatchSampler( - sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False + sampler=dist_sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) else: sampler = BatchSampler( - sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False + sampler=range(len(gathered_data)), batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) # Start training @@ -548,7 +544,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "encoder_optimizer": encoder_optimizer.state_dict(), "decoder_optimizer": decoder_optimizer.state_dict(), "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/configs/algo/default.yaml b/sheeprl/configs/algo/default.yaml index 35298376..b4b1c476 100644 --- a/sheeprl/configs/algo/default.yaml +++ b/sheeprl/configs/algo/default.yaml @@ -1 +1,9 @@ name: ??? 
+total_steps: ??? +per_rank_batch_size: ??? + +# Encoder and decoder keys +cnn_keys: + encoder: [] +mlp_keys: + encoder: [] diff --git a/sheeprl/configs/algo/dreamer_v1.yaml b/sheeprl/configs/algo/dreamer_v1.yaml index 06cd789f..212668f5 100644 --- a/sheeprl/configs/algo/dreamer_v1.yaml +++ b/sheeprl/configs/algo/dreamer_v1.yaml @@ -11,9 +11,16 @@ horizon: 15 name: dreamer_v1 # Training recipe +train_every: 1000 learning_starts: 5000 per_rank_gradient_steps: 100 -train_every: 1000 +per_rank_sequence_length: ??? + +# Encoder and decoder keys +cnn_keys: + decoder: ${algo.cnn_keys.encoder} +mlp_keys: + decoder: ${algo.mlp_keys.encoder} # Model related parameters dense_units: 400 diff --git a/sheeprl/configs/algo/dreamer_v2.yaml b/sheeprl/configs/algo/dreamer_v2.yaml index dd8f2a9c..719d91ec 100644 --- a/sheeprl/configs/algo/dreamer_v2.yaml +++ b/sheeprl/configs/algo/dreamer_v2.yaml @@ -11,10 +11,17 @@ lmbda: 0.95 horizon: 15 # Training recipe +train_every: 5 learning_starts: 1000 -per_rank_pretrain_steps: 100 per_rank_gradient_steps: 1 -train_every: 5 +per_rank_pretrain_steps: 100 +per_rank_sequence_length: ??? + +# Encoder and decoder keys +cnn_keys: + decoder: ${algo.cnn_keys.encoder} +mlp_keys: + decoder: ${algo.mlp_keys.encoder} # Model related parameters layer_norm: False diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index 0a37436e..2fc0bace 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -11,10 +11,17 @@ lmbda: 0.95 horizon: 15 # Training recipe +train_every: 16 learning_starts: 65536 per_rank_pretrain_steps: 1 per_rank_gradient_steps: 1 -train_every: 16 +per_rank_sequence_length: ??? + +# Encoder and decoder keys +cnn_keys: + decoder: ${algo.cnn_keys.encoder} +mlp_keys: + decoder: ${algo.mlp_keys.encoder} # Model related parameters layer_norm: True @@ -110,7 +117,7 @@ actor: expl_min: 0.0 expl_decay: False max_step_expl_decay: 0 - + # Disttributed percentile model (used to scale the values) moments: decay: 0.99 diff --git a/sheeprl/configs/algo/ppo.yaml b/sheeprl/configs/algo/ppo.yaml index 9a8a317f..bb3de2b3 100644 --- a/sheeprl/configs/algo/ppo.yaml +++ b/sheeprl/configs/algo/ppo.yaml @@ -50,4 +50,4 @@ critic: # Single optimizer for both actor and critic optimizer: lr: 1e-3 - eps: 1e-4 \ No newline at end of file + eps: 1e-4 diff --git a/sheeprl/configs/algo/ppo_recurrent.yaml b/sheeprl/configs/algo/ppo_recurrent.yaml index cdd7da90..d4bac2cd 100644 --- a/sheeprl/configs/algo/ppo_recurrent.yaml +++ b/sheeprl/configs/algo/ppo_recurrent.yaml @@ -13,6 +13,7 @@ max_grad_norm: 0.5 anneal_ent_coef: True normalize_advantages: True reset_recurrent_state_on_done: True +per_rank_sequence_length: ??? # Model related parameters mlp_layers: 1 diff --git a/sheeprl/configs/algo/sac_ae.yaml b/sheeprl/configs/algo/sac_ae.yaml index 3d38156f..e7dfd94b 100644 --- a/sheeprl/configs/algo/sac_ae.yaml +++ b/sheeprl/configs/algo/sac_ae.yaml @@ -16,6 +16,12 @@ mlp_layers: 2 dense_act: torch.nn.ReLU layer_norm: False +# Encoder and decoder keys +cnn_keys: + decoder: ${algo.cnn_keys.encoder} +mlp_keys: + decoder: ${algo.mlp_keys.encoder} + # Encoder encoder: tau: 0.05 diff --git a/sheeprl/configs/config.yaml b/sheeprl/configs/config.yaml index 7b1afa16..437fdfb7 100644 --- a/sheeprl/configs/config.yaml +++ b/sheeprl/configs/config.yaml @@ -14,7 +14,6 @@ defaults: - exp: ??? num_threads: 1 -total_steps: ??? 
# Set it to True to run a single optimization step dry_run: False @@ -27,11 +26,3 @@ torch_deterministic: False exp_name: "default" run_name: ${now:%Y-%m-%d_%H-%M-%S}_${exp_name}_${seed} root_dir: ${algo.name}/${env.id} - -# Encoder and decoder keys -cnn_keys: - encoder: [] - decoder: ${cnn_keys.encoder} -mlp_keys: - encoder: [] - decoder: ${mlp_keys.encoder} diff --git a/sheeprl/configs/env_config.yaml b/sheeprl/configs/env_config.yaml index 5f7fdb73..d36ba66b 100644 --- a/sheeprl/configs/env_config.yaml +++ b/sheeprl/configs/env_config.yaml @@ -16,7 +16,7 @@ run_name: ${env.id} agent: ??? cnn_keys: encoder: [] - decoder: ${cnn_keys.encoder} + decoder: ${algo.cnn_keys.encoder} mlp_keys: encoder: [] - decoder: ${mlp_keys.encoder} + decoder: ${algo.mlp_keys.encoder} diff --git a/sheeprl/configs/exp/dreamer_v1.yaml b/sheeprl/configs/exp/dreamer_v1.yaml index 0b91c8e5..75723073 100644 --- a/sheeprl/configs/exp/dreamer_v1.yaml +++ b/sheeprl/configs/exp/dreamer_v1.yaml @@ -5,10 +5,14 @@ defaults: - override /env: atari - _self_ -# Experiment -total_steps: 5000000 -per_rank_batch_size: 50 -per_rank_sequence_length: 50 +# Algorithm +algo: + total_steps: 5000000 + per_rank_batch_size: 50 + per_rank_sequence_length: 50 + cnn_keys: + encoder: [rgb] + decoder: [rgb] # Checkpoint checkpoint: @@ -25,7 +29,7 @@ distribution: metric: aggregator: - metrics: + metrics: Loss/world_model_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} @@ -68,7 +72,3 @@ metric: Grads/critic: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - -cnn_keys: - encoder: [rgb] - decoder: [rgb] \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v2.yaml b/sheeprl/configs/exp/dreamer_v2.yaml index ab61df1d..c06b01ad 100644 --- a/sheeprl/configs/exp/dreamer_v2.yaml +++ b/sheeprl/configs/exp/dreamer_v2.yaml @@ -5,10 +5,14 @@ defaults: - override /env: atari - _self_ -# Experiment -total_steps: 5000000 -per_rank_batch_size: 16 -per_rank_sequence_length: 50 +# Algorithm +algo: + total_steps: 5000000 + per_rank_batch_size: 16 + per_rank_sequence_length: 50 + cnn_keys: + encoder: [rgb] + decoder: [rgb] # Checkpoint checkpoint: @@ -28,49 +32,45 @@ distribution: metric: aggregator: metrics: - Loss/world_model_loss: + Loss/world_model_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/value_loss: + Loss/value_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/policy_loss: + Loss/policy_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/observation_loss: + Loss/observation_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/reward_loss: + Loss/reward_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/state_loss: + Loss/state_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/continue_loss: + Loss/continue_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - State/post_entropy: + State/post_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - State/prior_entropy: + State/prior_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - State/kl: + State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount: + Params/exploration_amount: _target_: torchmetrics.MeanMetric sync_on_compute: 
${metric.sync_on_compute} - Grads/world_model: + Grads/world_model: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Grads/actor: + Grads/actor: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Grads/critic: + Grads/critic: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - -cnn_keys: - encoder: [rgb] - decoder: [rgb] \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml b/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml index e784b2e3..dc8b146b 100644 --- a/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml +++ b/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml @@ -7,10 +7,6 @@ defaults: seed: 5 -# Experiment -total_steps: 200000000 -per_rank_batch_size: 32 - # Environment env: max_episode_steps: 27000 @@ -31,9 +27,11 @@ buffer: algo: gamma: 0.995 train_every: 16 + total_steps: 200000000 + learning_starts: 200000 + per_rank_batch_size: 32 per_rank_pretrain_steps: 1 per_rank_gradient_steps: 1 - learning_starts: 200000 world_model: use_continues: True kl_free_nats: 0.0 diff --git a/sheeprl/configs/exp/dreamer_v3.yaml b/sheeprl/configs/exp/dreamer_v3.yaml index 91b5cd94..ea108d6c 100644 --- a/sheeprl/configs/exp/dreamer_v3.yaml +++ b/sheeprl/configs/exp/dreamer_v3.yaml @@ -5,10 +5,14 @@ defaults: - override /env: atari - _self_ -# Experiment -total_steps: 5000000 -per_rank_batch_size: 16 -per_rank_sequence_length: 64 +# Algorithm +algo: + total_steps: 5000000 + per_rank_batch_size: 16 + per_rank_sequence_length: 64 + cnn_keys: + encoder: [rgb] + decoder: [rgb] # Checkpoint checkpoint: @@ -26,49 +30,45 @@ distribution: metric: aggregator: metrics: - Loss/world_model_loss: + Loss/world_model_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/value_loss: + Loss/value_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/policy_loss: + Loss/policy_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/observation_loss: + Loss/observation_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/reward_loss: + Loss/reward_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/state_loss: + Loss/state_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/continue_loss: + Loss/continue_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - State/kl: + State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - State/post_entropy: + State/post_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - State/prior_entropy: + State/prior_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount: + Params/exploration_amount: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Grads/world_model: + Grads/world_model: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Grads/actor: + Grads/actor: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Grads/critic: + Grads/critic: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - -cnn_keys: - encoder: [rgb] - decoder: [rgb] \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml b/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml index b090ce42..1ac25653 100644 
--- a/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml +++ b/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml @@ -7,7 +7,6 @@ defaults: # Experiment seed: 5 -total_steps: 100000 # Environment env: @@ -30,10 +29,11 @@ buffer: # Algorithm algo: - learning_starts: 1024 + mlp_layers: 2 train_every: 1 dense_units: 512 - mlp_layers: 2 + total_steps: 100000 + learning_starts: 1024 world_model: encoder: cnn_channels_multiplier: 32 diff --git a/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml b/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml index 852a2356..651cd9b7 100644 --- a/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml +++ b/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml @@ -7,7 +7,6 @@ defaults: # Experiment seed: 5 -total_steps: 100000 # Environment env: @@ -27,6 +26,7 @@ buffer: # Algorithm algo: learning_starts: 1024 + total_steps: 100000 train_every: 1 dense_units: 512 mlp_layers: 2 diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml index ed40f8fe..26566b6c 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml @@ -26,24 +26,6 @@ checkpoint: buffer: checkpoint: True -# The CNN and MLP keys of the decoder are the same as those of the encoder by default -cnn_keys: - encoder: - - frame -mlp_keys: - encoder: - - P1_actions_attack - - P1_actions_move - - P1_oppChar - - P1_oppHealth - - P1_oppSide - - P1_oppWins - - P1_ownChar - - P1_ownHealth - - P1_ownSide - - P1_ownWins - - stage - # Algorithm algo: learning_starts: 65536 @@ -59,3 +41,19 @@ algo: hidden_size: 768 representation_model: hidden_size: 768 + cnn_keys: + encoder: + - frame + mlp_keys: + encoder: + - P1_actions_attack + - P1_actions_move + - P1_oppChar + - P1_oppHealth + - P1_oppSide + - P1_oppWins + - P1_ownChar + - P1_ownHealth + - P1_ownSide + - P1_ownWins + - stage diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml index 53e6bf86..68bf3b41 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml @@ -7,8 +7,6 @@ defaults: # Experiment seed: 0 -total_steps: 10000000 -per_rank_batch_size: 8 # Environment env: @@ -34,39 +32,10 @@ checkpoint: buffer: checkpoint: True -# The CNN and MLP keys of the decoder are the same as those of the encoder by default -cnn_keys: - encoder: - - frame -mlp_keys: - encoder: - - own_character - - own_health - - own_side - - own_wins - - opp_character - - opp_health - - opp_side - - opp_wins - - stage - - timer - - action - - reward - decoder: - - own_character - - own_health - - own_side - - own_wins - - opp_character - - opp_health - - opp_side - - opp_wins - - stage - - timer - - action - # Algorithm algo: + total_steps: 10000000 + per_rank_batch_size: 8 learning_starts: 65536 train_every: 8 dense_units: 768 @@ -80,6 +49,35 @@ algo: hidden_size: 768 representation_model: hidden_size: 768 + cnn_keys: + encoder: + - frame + mlp_keys: + encoder: + - own_character + - own_health + - own_side + - own_wins + - opp_character + - opp_health + - opp_side + - opp_wins + - stage + - timer + - action + - reward + decoder: + - own_character + - own_health + - own_side + - own_wins + - opp_character + - opp_health + - opp_side + - opp_wins + - stage + - timer + - action # Metric metric: diff --git a/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml b/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml 
index 7680f147..d10c9619 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml @@ -25,23 +25,6 @@ checkpoint: buffer: checkpoint: True -# The CNN and MLP keys of the decoder are the same as those of the encoder by default -cnn_keys: - encoder: - - rgb -mlp_keys: - encoder: - - life_stats - - inventory - - max_inventory - - compass - - reward - decoder: - - life_stats - - inventory - - max_inventory - - compass - # Algorithm algo: train_every: 16 @@ -57,3 +40,18 @@ algo: hidden_size: 768 representation_model: hidden_size: 768 + cnn_keys: + encoder: + - rgb + mlp_keys: + encoder: + - life_stats + - inventory + - max_inventory + - compass + - reward + decoder: + - life_stats + - inventory + - max_inventory + - compass diff --git a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml index 7cdfd560..3b6967c3 100644 --- a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml +++ b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml @@ -22,17 +22,6 @@ checkpoint: buffer: checkpoint: True -# The CNN and MLP keys of the decoder are the same as those of the encoder by default -cnn_keys: - encoder: - - rgb - decoder: - - rgb -mlp_keys: - encoder: - - reward - decoder: [] - # Algorithm algo: train_every: 2 @@ -48,3 +37,12 @@ algo: hidden_size: 1024 representation_model: hidden_size: 1024 + cnn_keys: + encoder: + - rgb + decoder: + - rgb + mlp_keys: + encoder: + - reward + decoder: [] diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml index 66bcc776..b97262b9 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml @@ -7,13 +7,6 @@ defaults: # Experiment seed: 5 -total_steps: 1000000 -cnn_keys: - encoder: - - rgb -mlp_keys: - encoder: - - state # Environment env: @@ -36,6 +29,13 @@ buffer: # Algorithm algo: + total_steps: 1000000 + cnn_keys: + encoder: + - rgb + mlp_keys: + encoder: + - state learning_starts: 8000 train_every: 2 dense_units: 512 diff --git a/sheeprl/configs/exp/p2e_dv1_finetuning.yaml b/sheeprl/configs/exp/p2e_dv1_finetuning.yaml index 054c9aa5..70ca4e61 100644 --- a/sheeprl/configs/exp/p2e_dv1_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv1_finetuning.yaml @@ -5,11 +5,10 @@ defaults: - override /algo: p2e_dv1 - _self_ -total_steps: 1000000 - algo: name: p2e_dv1_finetuning learning_starts: 5000 + total_steps: 1000000 player: actor_type: exploration @@ -21,7 +20,7 @@ checkpoint: metric: aggregator: - metrics: + metrics: Loss/world_model_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} @@ -66,4 +65,4 @@ metric: sync_on_compute: ${metric.sync_on_compute} Grads/critic: _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} \ No newline at end of file + sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/p2e_dv2_finetuning.yaml b/sheeprl/configs/exp/p2e_dv2_finetuning.yaml index 7635f3c0..f3ff1d46 100644 --- a/sheeprl/configs/exp/p2e_dv2_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv2_finetuning.yaml @@ -5,11 +5,10 @@ defaults: - override /algo: p2e_dv2 - _self_ -total_steps: 1000000 - algo: name: p2e_dv2_finetuning learning_starts: 10000 + total_steps: 1000000 player: actor_type: exploration @@ -22,48 +21,48 @@ checkpoint: metric: aggregator: metrics: - Loss/world_model_loss: + Loss/world_model_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - 
Loss/value_loss: + Loss/value_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/policy_loss: + Loss/policy_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/observation_loss: + Loss/observation_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/reward_loss: + Loss/reward_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/state_loss: + Loss/state_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Loss/continue_loss: + Loss/continue_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - State/post_entropy: + State/post_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - State/prior_entropy: + State/prior_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - State/kl: + State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_task: + Params/exploration_amount_task: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_exploration: + Params/exploration_amount_exploration: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Grads/world_model: + Grads/world_model: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Grads/actor: + Grads/actor: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Grads/critic: + Grads/critic: _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} \ No newline at end of file + sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml b/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml index 2c8e3b80..df1f356b 100644 --- a/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml +++ b/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml @@ -7,9 +7,6 @@ defaults: # Experiment seed: 0 -total_steps: 20000000 -per_rank_batch_size: 4 -per_rank_sequence_length: 64 # Environment env: @@ -35,37 +32,6 @@ checkpoint: buffer: checkpoint: True -# The CNN and MLP keys of the decoder are the same as those of the encoder by default -cnn_keys: - encoder: [frame] - decoder: [frame] -mlp_keys: - encoder: - - own_character - - own_health - - own_side - - own_wins - - opp_character - - opp_health - - opp_side - - opp_wins - - stage - - timer - - action - - reward - decoder: - - own_character - - own_health - - own_side - - own_wins - - opp_character - - opp_health - - opp_side - - opp_wins - - stage - - timer - - action - # Algorithm algo: learning_starts: 131072 @@ -81,6 +47,38 @@ algo: hidden_size: 768 representation_model: hidden_size: 768 + total_steps: 20000000 + per_rank_batch_size: 4 + per_rank_sequence_length: 64 + cnn_keys: + encoder: [frame] + decoder: [frame] + mlp_keys: + encoder: + - own_character + - own_health + - own_side + - own_wins + - opp_character + - opp_health + - opp_side + - opp_wins + - stage + - timer + - action + - reward + decoder: + - own_character + - own_health + - own_side + - own_wins + - opp_character + - opp_health + - opp_side + - opp_wins + - stage + - timer + - action # Metric metric: @@ -88,4 +86,4 @@ metric: fabric: precision: bf16 - accelerator: gpu \ No newline 
at end of file + accelerator: gpu diff --git a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml index 996e1ab9..502f8fcd 100644 --- a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml @@ -5,11 +5,10 @@ defaults: - override /algo: p2e_dv3 - _self_ -total_steps: 1000000 - algo: name: p2e_dv3_finetuning learning_starts: 65536 + total_steps: 1000000 player: actor_type: exploration diff --git a/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml b/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml index a0f736f6..c35a9a98 100644 --- a/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml +++ b/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml @@ -7,9 +7,6 @@ defaults: # Experiment seed: 0 -total_steps: 5000000 -per_rank_batch_size: 16 -per_rank_sequence_length: 64 # Environment env: @@ -37,37 +34,6 @@ checkpoint: buffer: checkpoint: True -# The CNN and MLP keys of the decoder are the same as those of the encoder by default -cnn_keys: - encoder: [frame] - decoder: [frame] -mlp_keys: - encoder: - - own_character - - own_health - - own_side - - own_wins - - opp_character - - opp_health - - opp_side - - opp_wins - - stage - - timer - - action - - reward - decoder: - - own_character - - own_health - - own_side - - own_wins - - opp_character - - opp_health - - opp_side - - opp_wins - - stage - - timer - - action - # Algorithm algo: learning_starts: 65536 @@ -83,6 +49,38 @@ algo: hidden_size: 768 representation_model: hidden_size: 768 + total_steps: 5000000 + per_rank_batch_size: 16 + per_rank_sequence_length: 64 + cnn_keys: + encoder: [frame] + decoder: [frame] + mlp_keys: + encoder: + - own_character + - own_health + - own_side + - own_wins + - opp_character + - opp_health + - opp_side + - opp_wins + - stage + - timer + - action + - reward + decoder: + - own_character + - own_health + - own_side + - own_wins + - opp_character + - opp_health + - opp_side + - opp_wins + - stage + - timer + - action # Metric metric: @@ -90,4 +88,4 @@ metric: fabric: precision: bf16 - accelerator: gpu \ No newline at end of file + accelerator: gpu diff --git a/sheeprl/configs/exp/ppo.yaml b/sheeprl/configs/exp/ppo.yaml index 37819413..c671bb31 100644 --- a/sheeprl/configs/exp/ppo.yaml +++ b/sheeprl/configs/exp/ppo.yaml @@ -5,9 +5,12 @@ defaults: - override /env: gym - _self_ -# Experiment -total_steps: 65536 -per_rank_batch_size: 64 +# Algorithm +algo: + total_steps: 65536 + per_rank_batch_size: 64 + mlp_keys: + encoder: [state] # Buffer buffer: @@ -26,6 +29,3 @@ metric: Loss/entropy_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - -mlp_keys: - encoder: [state] \ No newline at end of file diff --git a/sheeprl/configs/exp/ppo_recurrent.yaml b/sheeprl/configs/exp/ppo_recurrent.yaml index 8ee55be7..56a3720a 100644 --- a/sheeprl/configs/exp/ppo_recurrent.yaml +++ b/sheeprl/configs/exp/ppo_recurrent.yaml @@ -6,14 +6,12 @@ defaults: - _self_ algo: + per_rank_num_batches: 4 + per_rank_sequence_length: 8 + total_steps: 409600 rollout_steps: 256 update_epochs: 4 -# Experiment -per_rank_num_batches: 4 -per_rank_sequence_length: 8 -total_steps: 409600 - # Environment env: id: CartPole-v1 diff --git a/sheeprl/configs/exp/sac.yaml b/sheeprl/configs/exp/sac.yaml index 55713c26..065612e0 100644 --- a/sheeprl/configs/exp/sac.yaml +++ b/sheeprl/configs/exp/sac.yaml @@ -5,9 +5,13 @@ defaults: - override /env: gym - 
_self_ -# Experiment -total_steps: 1000000 -per_rank_batch_size: 256 +# Algorithm +algo: + total_steps: 1000000 + per_rank_batch_size: 256 + mlp_keys: + encoder: [state] + decoder: [state] # Checkpoint checkpoint: @@ -35,7 +39,3 @@ metric: Loss/alpha_loss: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - -mlp_keys: - encoder: [state] - decoder: [state] \ No newline at end of file diff --git a/sheeprl/configs/exp/sac_ae.yaml b/sheeprl/configs/exp/sac_ae.yaml index d54aeddd..de7900ba 100644 --- a/sheeprl/configs/exp/sac_ae.yaml +++ b/sheeprl/configs/exp/sac_ae.yaml @@ -5,8 +5,9 @@ defaults: - override /algo: sac_ae - _self_ -# Experiment -per_rank_batch_size: 128 +# Algorithm +algo: + per_rank_batch_size: 128 # Environmment env: @@ -15,6 +16,6 @@ env: metric: aggregator: metrics: - Loss/reconstruction_loss: + Loss/reconstruction_loss: _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} \ No newline at end of file + sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py index 82a96a3c..4d5e5318 100644 --- a/sheeprl/utils/env.py +++ b/sheeprl/utils/env.py @@ -82,56 +82,77 @@ def thunk() -> gym.Env: if "mask_velocities" in cfg.env and cfg.env.mask_velocities: env = MaskVelocityWrapper(env) + if not ( + isinstance(cfg.algo.mlp_keys.encoder, list) + and isinstance(cfg.algo.cnn_keys.encoder, list) + and len(cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder) > 0 + ): + raise ValueError( + "`cnn_keys.encoder` and `mlp_keys.encoder` must be lists of strings, got: " + f"cnn encoder keys `{cfg.algo.mlp_keys.encoder}` and mlp encoder keys `{cfg.algo.cnn_keys.encoder}`. " + "Both lists must not be empty." + ) + # Create observation dict if isinstance(env.observation_space, gym.spaces.Box) and len(env.observation_space.shape) < 2: - if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0: - if len(cfg.cnn_keys.encoder) > 1: + # Vector only observation + if len(cfg.algo.cnn_keys.encoder) > 0: + if len(cfg.algo.cnn_keys.encoder) > 1: warnings.warn( "Multiple cnn keys have been specified and only one pixel observation " f"is allowed in {cfg.env.id}, " - f"only the first one is kept: {cfg.cnn_keys.encoder[0]}" + f"only the first one is kept: {cfg.algo.cnn_keys.encoder[0]}" ) - if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0: - gym.wrappers.pixel_observation.STATE_KEY = cfg.mlp_keys.encoder[0] + if len(cfg.algo.mlp_keys.encoder) > 0: + gym.wrappers.pixel_observation.STATE_KEY = cfg.algo.mlp_keys.encoder[0] env = gym.wrappers.PixelObservationWrapper( - env, pixels_only=len(cfg.mlp_keys.encoder) == 0, pixel_keys=(cfg.cnn_keys.encoder[0],) + env, pixels_only=len(cfg.algo.mlp_keys.encoder) == 0, pixel_keys=(cfg.algo.cnn_keys.encoder[0],) ) else: - if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0: - if len(cfg.mlp_keys.encoder) > 1: - warnings.warn( - "Multiple mlp keys have been specified and only one pixel observation " - f"is allowed in {cfg.env.id}, " - f"only the first one is kept: {cfg.mlp_keys.encoder[0]}" - ) - mlp_key = cfg.mlp_keys.encoder[0] - else: - mlp_key = "state" - cfg.mlp_keys.encoder = [mlp_key] + if len(cfg.algo.mlp_keys.encoder) > 1: + warnings.warn( + "Multiple mlp keys have been specified and only one pixel observation " + f"is allowed in {cfg.env.id}, " + f"only the first one is kept: {cfg.algo.mlp_keys.encoder[0]}" + ) + mlp_key = cfg.algo.mlp_keys.encoder[0] env = gym.wrappers.TransformObservation(env, lambda obs: {mlp_key: obs}) 
env.observation_space = gym.spaces.Dict({mlp_key: env.observation_space}) elif isinstance(env.observation_space, gym.spaces.Box) and 2 <= len(env.observation_space.shape) <= 3: - if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 1: + # Pixel only observation + if len(cfg.algo.cnn_keys.encoder) > 1: warnings.warn( "Multiple cnn keys have been specified and only one pixel observation " f"is allowed in {cfg.env.id}, " - f"only the first one is kept: {cfg.cnn_keys.encoder[0]}" + f"only the first one is kept: {cfg.algo.cnn_keys.encoder[0]}" ) - cnn_key = cfg.cnn_keys.encoder[0] - else: - cnn_key = "rgb" - cfg.cnn_keys.encoder = [cnn_key] + elif len(cfg.algo.cnn_keys.encoder) == 0: + raise ValueError( + "You have selected a pixel observation but no cnn key has been specified. " + "Please set at least one cnn key in the config file: `algo.cnn_keys.encoder=[your_cnn_key]`" + ) + cnn_key = cfg.algo.cnn_keys.encoder[0] env = gym.wrappers.TransformObservation(env, lambda obs: {cnn_key: obs}) env.observation_space = gym.spaces.Dict({cnn_key: env.observation_space}) + if ( + len( + set(k for k in env.observation_space.keys()).intersection( + set(cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder) + ) + ) + == 0 + ): + raise ValueError( + f"The user specified keys `{cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder}` " + "are not a subset of the " + f"environment `{env.observation_space.keys()}` observation keys. Please check your config file." + ) + env_cnn_keys = set( [k for k in env.observation_space.spaces.keys() if len(env.observation_space[k].shape) in {2, 3}] ) - if cfg.cnn_keys.encoder is None: - user_cnn_keys = set() - else: - user_cnn_keys = set(cfg.cnn_keys.encoder) - cnn_keys = env_cnn_keys.intersection(user_cnn_keys) + cnn_keys = env_cnn_keys.intersection(set(cfg.algo.cnn_keys.encoder)) def transform_obs(obs: Dict[str, Any]): for k in cnn_keys: diff --git a/tests/run_tests.py b/tests/run_tests.py index 3daed75a..e1098964 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -3,4 +3,4 @@ import pytest if __name__ == "__main__": - sys.exit(pytest.main(["-s", "--cov=sheeprl", "-vv"])) + sys.exit(pytest.main(["-s", "--cov=sheeprl", "-vv", "tests/test_algos/test_cli.py"])) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 8a8d1105..fb57d814 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -63,7 +63,7 @@ def test_droq(standard_args, start_time): run_name = "test_droq" args = standard_args + [ "exp=droq", - "per_rank_batch_size=1", + "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -82,7 +82,7 @@ def test_sac(standard_args, start_time): run_name = "test_sac" args = standard_args + [ "exp=sac", - "per_rank_batch_size=1", + "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -101,14 +101,14 @@ def test_sac_ae(standard_args, start_time): run_name = "test_sac_ae" args = standard_args + [ "exp=sac_ae", - "per_rank_batch_size=1", + "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", f"root_dir={root_dir}", f"run_name={run_name}", - "mlp_keys.encoder=[state]", - "cnn_keys.encoder=[rgb]", + "algo.mlp_keys.encoder=[state]", + "algo.cnn_keys.encoder=[rgb]", "env.screen_size=64", "algo.hidden_size=4", "algo.dense_units=4", @@ -128,7 +128,7 @@ def 
test_sac_decoupled(standard_args, start_time): run_name = "test_sac_decoupled" args = standard_args + [ "exp=sac_decoupled", - "per_rank_batch_size=1", + "algo.per_rank_batch_size=1", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", f"fabric.devices={os.environ['LT_DEVICES']}", @@ -153,12 +153,12 @@ def test_ppo(standard_args, start_time, env_id): "exp=ppo", "env=dummy", f"algo.rollout_steps={os.environ['LT_DEVICES']}", - "per_rank_batch_size=1", + "algo.per_rank_batch_size=1", f"root_dir={root_dir}", f"run_name={run_name}", f"env.id={env_id}", - "cnn_keys.encoder=[rgb]", - "mlp_keys.encoder=[]", + "algo.cnn_keys.encoder=[rgb]", + "algo.mlp_keys.encoder=[]", ] with mock.patch.object(sys, "argv", args): @@ -176,13 +176,13 @@ def test_ppo_decoupled(standard_args, start_time, env_id): "env=dummy", f"fabric.devices={os.environ['LT_DEVICES']}", f"algo.rollout_steps={os.environ['LT_DEVICES']}", - "per_rank_batch_size=1", + "algo.per_rank_batch_size=1", "algo.update_epochs=1", f"root_dir={root_dir}", f"run_name={run_name}", f"env.id={env_id}", - "cnn_keys.encoder=[rgb]", - "mlp_keys.encoder=[]", + "algo.cnn_keys.encoder=[rgb]", + "algo.mlp_keys.encoder=[]", ] with mock.patch.object(sys, "argv", args): @@ -200,8 +200,8 @@ def test_ppo_recurrent(standard_args, start_time): args = standard_args + [ "exp=ppo_recurrent", "algo.rollout_steps=2", - "per_rank_batch_size=1", - "per_rank_sequence_length=2", + "algo.per_rank_batch_size=1", + "algo.per_rank_sequence_length=2", "algo.update_epochs=2", f"root_dir={root_dir}", f"run_name={run_name}", @@ -220,8 +220,8 @@ def test_dreamer_v1(standard_args, env_id, start_time): args = standard_args + [ "exp=dreamer_v1", "env=dummy", - "per_rank_batch_size=1", - "per_rank_sequence_length=1", + "algo.per_rank_batch_size=1", + "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -232,8 +232,8 @@ def test_dreamer_v1(standard_args, env_id, start_time): "algo.dense_units=8", "algo.world_model.encoder.cnn_channels_multiplier=2", "algo.world_model.recurrent_model.recurrent_state_size=8", - "cnn_keys.encoder=[rgb]", - "cnn_keys.decoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", + "algo.cnn_keys.decoder=[rgb]", ] with mock.patch.object(sys, "argv", args): @@ -252,8 +252,8 @@ def test_p2e_dv1(standard_args, env_id, start_time): args = standard_args + [ "exp=p2e_dv1_exploration", "env=dummy", - "per_rank_batch_size=2", - "per_rank_sequence_length=2", + "algo.per_rank_batch_size=2", + "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -267,8 +267,8 @@ def test_p2e_dv1(standard_args, env_id, start_time): "algo.world_model.representation_model.hidden_size=2", "algo.world_model.transition_model.hidden_size=2", "buffer.checkpoint=True", - "cnn_keys.encoder=[rgb]", - "cnn_keys.decoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", + "algo.cnn_keys.decoder=[rgb]", "checkpoint.save_last=True", ] @@ -293,8 +293,8 @@ def test_p2e_dv1(standard_args, env_id, start_time): args = standard_args + [ "exp=p2e_dv1_finetuning", f"checkpoint.exploration_ckpt_path={ckpt_path}", - "per_rank_batch_size=2", - "per_rank_sequence_length=2", + "algo.per_rank_batch_size=2", + "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -308,8 +308,8 @@ def test_p2e_dv1(standard_args, env_id, start_time): 
"algo.world_model.recurrent_model.recurrent_state_size=2", "algo.world_model.representation_model.hidden_size=2", "algo.world_model.transition_model.hidden_size=2", - "cnn_keys.encoder=[rgb]", - "cnn_keys.decoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", + "algo.cnn_keys.decoder=[rgb]", ] with mock.patch.object(sys, "argv", args): run() @@ -324,8 +324,8 @@ def test_dreamer_v2(standard_args, env_id, start_time): args = standard_args + [ "exp=dreamer_v2", "env=dummy", - "per_rank_batch_size=1", - "per_rank_sequence_length=1", + "algo.per_rank_batch_size=1", + "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -338,11 +338,11 @@ def test_dreamer_v2(standard_args, env_id, start_time): "algo.world_model.recurrent_model.recurrent_state_size=8", "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", - "cnn_keys.encoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", "algo.per_rank_pretrain_steps=1", "algo.layer_norm=True", - "cnn_keys.encoder=[rgb]", - "cnn_keys.decoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", + "algo.cnn_keys.decoder=[rgb]", ] with mock.patch.object(sys, "argv", args): @@ -361,8 +361,8 @@ def test_p2e_dv2(standard_args, env_id, start_time): args = standard_args + [ "exp=p2e_dv2_exploration", "env=dummy", - "per_rank_batch_size=2", - "per_rank_sequence_length=2", + "algo.per_rank_batch_size=2", + "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -376,8 +376,8 @@ def test_p2e_dv2(standard_args, env_id, start_time): "algo.world_model.representation_model.hidden_size=2", "algo.world_model.transition_model.hidden_size=2", "buffer.checkpoint=True", - "cnn_keys.encoder=[rgb]", - "cnn_keys.decoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", + "algo.cnn_keys.decoder=[rgb]", "checkpoint.save_last=True", ] @@ -402,8 +402,8 @@ def test_p2e_dv2(standard_args, env_id, start_time): args = standard_args + [ "exp=p2e_dv2_finetuning", f"checkpoint.exploration_ckpt_path={ckpt_path}", - "per_rank_batch_size=2", - "per_rank_sequence_length=2", + "algo.per_rank_batch_size=2", + "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -417,8 +417,8 @@ def test_p2e_dv2(standard_args, env_id, start_time): "algo.world_model.recurrent_model.recurrent_state_size=2", "algo.world_model.representation_model.hidden_size=2", "algo.world_model.transition_model.hidden_size=2", - "cnn_keys.encoder=[rgb]", - "cnn_keys.decoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", + "algo.cnn_keys.decoder=[rgb]", ] with mock.patch.object(sys, "argv", args): run() @@ -433,8 +433,8 @@ def test_dreamer_v3(standard_args, env_id, start_time): args = standard_args + [ "exp=dreamer_v3", "env=dummy", - "per_rank_batch_size=1", - "per_rank_sequence_length=1", + "algo.per_rank_batch_size=1", + "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -447,11 +447,11 @@ def test_dreamer_v3(standard_args, env_id, start_time): "algo.world_model.recurrent_model.recurrent_state_size=8", "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", - "cnn_keys.encoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", "algo.layer_norm=True", "algo.train_every=1", - "cnn_keys.encoder=[rgb]", - 
"cnn_keys.decoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", + "algo.cnn_keys.decoder=[rgb]", ] with mock.patch.object(sys, "argv", args): @@ -470,8 +470,8 @@ def test_p2e_dv3(standard_args, env_id, start_time): args = standard_args + [ "exp=p2e_dv3_exploration", "env=dummy", - "per_rank_batch_size=1", - "per_rank_sequence_length=1", + "algo.per_rank_batch_size=1", + "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -487,8 +487,8 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.layer_norm=True", "algo.train_every=1", "buffer.checkpoint=True", - "cnn_keys.encoder=[rgb]", - "cnn_keys.decoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", + "algo.cnn_keys.decoder=[rgb]", "checkpoint.save_last=True", ] @@ -513,8 +513,8 @@ def test_p2e_dv3(standard_args, env_id, start_time): args = standard_args + [ "exp=p2e_dv3_finetuning", f"checkpoint.exploration_ckpt_path={ckpt_path}", - "per_rank_batch_size=1", - "per_rank_sequence_length=1", + "algo.per_rank_batch_size=1", + "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", "algo.per_rank_gradient_steps=1", @@ -530,8 +530,8 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.world_model.transition_model.hidden_size=8", "algo.layer_norm=True", "algo.train_every=1", - "cnn_keys.encoder=[rgb]", - "cnn_keys.decoder=[rgb]", + "algo.cnn_keys.encoder=[rgb]", + "algo.cnn_keys.decoder=[rgb]", ] with mock.patch.object(sys, "argv", args): run() diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py index 780b8976..a1dc9a23 100644 --- a/tests/test_algos/test_cli.py +++ b/tests/test_algos/test_cli.py @@ -99,7 +99,7 @@ def test_strategy_warning(): def test_run_decoupled_algo(): subprocess.run( sys.executable + " sheeprl.py exp=ppo_decoupled fabric.strategy=ddp fabric.devices=2 " - "dry_run=True algo.rollout_steps=1 cnn_keys.encoder=[rgb] mlp_keys.encoder=[state] " + "dry_run=True algo.rollout_steps=1 algo.cnn_keys.encoder=[rgb] algo.mlp_keys.encoder=[state] " "env.capture_video=False checkpoint.save_last=False metric.log_level=0 " "metric.disable_timer=True", shell=True, @@ -109,8 +109,8 @@ def test_run_decoupled_algo(): def test_run_algo(): subprocess.run( - sys.executable - + " sheeprl.py exp=ppo dry_run=True algo.rollout_steps=1 cnn_keys.encoder=[rgb] mlp_keys.encoder=[state] " + sys.executable + " sheeprl.py exp=ppo dry_run=True algo.rollout_steps=1 " + "algo.cnn_keys.encoder=[rgb] algo.mlp_keys.encoder=[state] " "env.capture_video=False checkpoint.save_last=False metric.log_level=0 " "metric.disable_timer=True", shell=True, @@ -124,12 +124,12 @@ def test_resume_from_checkpoint(): subprocess.run( sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " - "cnn_keys.encoder=[rgb] cnn_keys.decoder=[rgb] " + "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " - "algo.layer_norm=True per_rank_batch_size=1 per_rank_sequence_length=1 " + "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 
metric.disable_timer=True", shell=True, @@ -167,12 +167,12 @@ def test_resume_from_checkpoint_env_error(): subprocess.run( sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " - "cnn_keys.encoder=[rgb] cnn_keys.decoder=[rgb] " + "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " - "algo.layer_norm=True per_rank_batch_size=1 per_rank_sequence_length=1 " + "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, @@ -220,12 +220,12 @@ def test_resume_from_checkpoint_algo_error(): subprocess.run( sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " - "cnn_keys.encoder=[rgb] cnn_keys.decoder=[rgb] " + "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " - "algo.layer_norm=True per_rank_batch_size=1 per_rank_sequence_length=1 " + "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, @@ -275,12 +275,12 @@ def test_evaluate(): subprocess.run( sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " - "cnn_keys.encoder=[rgb] cnn_keys.decoder=[rgb] " + "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " - "algo.layer_norm=True per_rank_batch_size=1 per_rank_sequence_length=1 " + "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, From 0ad40a20f5241e81ef575b09dba99b67b062acc1 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Thu, 23 Nov 2023 18:21:10 +0100 Subject: [PATCH 2/7] Fix cnn_keys and mlp_keys when loading from checkpoint --- sheeprl/cli.py | 1 + sheeprl/utils/env.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 16d17e42..3ca73e5f 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -24,6 +24,7 @@ def resume_from_checkpoint(cfg: DictConfig) -> Dict[str, Any]: run_name = cfg.run_name ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) old_cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") + old_cfg = dotdict(OmegaConf.to_container(old_cfg, resolve=True, throw_on_missing=True)) if 
old_cfg.env.id != cfg.env.id: raise ValueError( "This experiment is run with a different environment from the one of the experiment you want to restart. " diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py index 4d5e5318..cf0d21f1 100644 --- a/sheeprl/utils/env.py +++ b/sheeprl/utils/env.py @@ -89,7 +89,7 @@ def thunk() -> gym.Env: ): raise ValueError( "`cnn_keys.encoder` and `mlp_keys.encoder` must be lists of strings, got: " - f"cnn encoder keys `{cfg.algo.mlp_keys.encoder}` and mlp encoder keys `{cfg.algo.cnn_keys.encoder}`. " + f"cnn encoder keys `{cfg.algo.cnn_keys.encoder}` and mlp encoder keys `{cfg.algo.mlp_keys.encoder}`. " "Both lists must not be empty." ) From fd8f6d0a47bd1787050da985dc5a4e6ce33880f1 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 24 Nov 2023 09:20:06 +0100 Subject: [PATCH 3/7] Update tests --- tests/run_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/run_tests.py b/tests/run_tests.py index e1098964..3daed75a 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -3,4 +3,4 @@ import pytest if __name__ == "__main__": - sys.exit(pytest.main(["-s", "--cov=sheeprl", "-vv", "tests/test_algos/test_cli.py"])) + sys.exit(pytest.main(["-s", "--cov=sheeprl", "-vv"])) From 8956ce78779f1b851fe865ff67f10fceb7cc64ae Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 24 Nov 2023 10:08:01 +0100 Subject: [PATCH 4/7] Update howtos --- howto/configs.md | 31 +++++++++---------- howto/learn_in_atari.md | 2 +- howto/learn_in_diambra.md | 6 ++-- howto/learn_in_dmc.md | 4 +-- howto/learn_in_minedojo.md | 4 +-- howto/learn_in_minerl.md | 2 +- howto/register_new_algorithm.md | 6 ++-- howto/select_observations.md | 14 ++++----- .../work_with_multi-encoder_multi-decoder.md | 4 +-- sheeprl/utils/env.py | 23 ++++++++------ 10 files changed, 49 insertions(+), 47 deletions(-) diff --git a/howto/configs.md b/howto/configs.md index 2f37c559..88693b18 100644 --- a/howto/configs.md +++ b/howto/configs.md @@ -106,7 +106,6 @@ defaults: - exp: ??? num_threads: 1 -total_steps: ??? # Set it to True to run a single optimization step dry_run: False @@ -119,14 +118,6 @@ torch_deterministic: False exp_name: "default" run_name: ${now:%Y-%m-%d_%H-%M-%S}_${exp_name}_${seed} root_dir: ${algo.name}/${env.id} - -# Encoder and decoder keys -cnn_keys: - encoder: [] - decoder: ${algo.cnn_keys.encoder} -mlp_keys: - encoder: [] - decoder: ${algo.mlp_keys.encoder} ``` ### Algorithms @@ -148,10 +139,17 @@ lmbda: 0.95 horizon: 15 # Training recipe +train_every: 16 learning_starts: 65536 per_rank_pretrain_steps: 1 per_rank_gradient_steps: 1 -train_every: 16 +per_rank_sequence_length: ??? 
+ +# Encoder and decoder keys +cnn_keys: + decoder: ${algo.cnn_keys.encoder} +mlp_keys: + decoder: ${algo.mlp_keys.encoder} # Model related parameters layer_norm: True @@ -237,14 +235,17 @@ actor: ent_coef: 3e-4 min_std: 0.1 init_std: 0.0 - distribution: "auto" objective_mix: 1.0 dense_act: ${algo.dense_act} mlp_layers: ${algo.mlp_layers} layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} clip_gradients: 100.0 - + expl_amount: 0.0 + expl_min: 0.0 + expl_decay: False + max_step_expl_decay: 0 + # Disttributed percentile model (used to scale the values) moments: decay: 0.99 @@ -278,10 +279,6 @@ critic: # Player agent (it interacts with the environment) player: - expl_min: 0.0 - expl_amount: 0.0 - expl_decay: False - max_step_expl_decay: 0 discrete_size: ${algo.world_model.discrete_size} ``` @@ -379,7 +376,6 @@ defaults: # Experiment seed: 5 -total_steps: 100000 # Environment env: @@ -399,6 +395,7 @@ buffer: # Algorithm algo: learning_starts: 1024 + total_steps: 100000 train_every: 1 dense_units: 512 mlp_layers: 2 diff --git a/howto/learn_in_atari.md b/howto/learn_in_atari.md index 74cb4c23..79ec33f7 100644 --- a/howto/learn_in_atari.md +++ b/howto/learn_in_atari.md @@ -24,5 +24,5 @@ The list of selectable algorithms is given below: Once you have chosen the algorithm you want to train, you can start the train, for instance, of the ppo agent by running: ```bash -python sheeprl.py exp=ppo env=atari env.id=PongNoFrameskip-v4 cnn_keys.encoder=[rgb] fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 +python sheeprl.py exp=ppo env=atari env.id=PongNoFrameskip-v4 algo.cnn_keys.encoder=[rgb] fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 ``` \ No newline at end of file diff --git a/howto/learn_in_diambra.md b/howto/learn_in_diambra.md index bea5a704..5cea01e7 100644 --- a/howto/learn_in_diambra.md +++ b/howto/learn_in_diambra.md @@ -54,14 +54,14 @@ The observation space is slightly modified to be compatible with our algorithms, ## Multi-environments / Distributed training In order to train your agent with multiple environments or to perform distributed training, you have to specify to the `diambra run` command the number of environments you want to instantiate (through the `-s` cli argument). So, you have to multiply the number of environments per single process and the number of processes you want to launch (the number of *player* processes for decoupled algorithms). Thus, in the case of coupled algorithms (e.g., `dreamer_v2`), if you want to distribute your training among $2$ processes each one containing $4$ environments, the total number of environments will be: $2 \cdot 4 = 8$. The command will be: ```bash -diambra run -s=8 python sheeprl.py exp=dreamer_v3 env=diambra env.id=doapp env.num_envs=4 env.sync_env=True cnn_keys.encoder=[frame] fabric.devices=2 +diambra run -s=8 python sheeprl.py exp=dreamer_v3 env=diambra env.id=doapp env.num_envs=4 env.sync_env=True algo.cnn_keys.encoder=[frame] fabric.devices=2 ``` ## Args The IDs of the DIAMBRA environments are specified [here](https://docs.diambra.ai/envs/games/). To train your agent on a DIAMBRA environment you have to select the DIAMBRA configs with the argument `env=diambra`, then set the `env.id` argument to the environment ID, e.g., to train your agent on the *Dead Or Alive ++* game, you have to set the `env.id` argument to `doapp` (i.e., `env.id=doapp`). 
```bash -diambra run -s=4 python sheeprl.py exp=dreamer_v3 env=diambra env.id=doapp env.num_envs=4 cnn_keys.encoder=[frame] +diambra run -s=4 python sheeprl.py exp=dreamer_v3 env=diambra env.id=doapp env.num_envs=4 algo.cnn_keys.encoder=[frame] ``` Another possibility is to create a new config file in the `sheeprl/configs/exp` folder, where you specify all the configs you want to use in your experiment. An example of a custom configuration file is available [here](../sheeprl/configs/exp/dreamer_v3_L_doapp.yaml). @@ -120,5 +120,5 @@ diambra run -s=4 python sheeprl.py exp=custom_exp env.num_envs=4 ## Headless machines If you work on a headless machine, you need to software renderer. We recommend to adopt one of the following solutions: -1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the training command with `xvfb-run`. For instance, to train DreamerV2 on the navigate task on a headless machine, you need to run the following command: `xvfb-run diambra run python sheeprl.py exp=dreamer_v3 env=diambra env.id=doapp env.sync_env=True env.num_envs=1 cnn_keys.encoder=[frame] fabric.devices=1` +1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the training command with `xvfb-run`. For instance, to train DreamerV2 on the navigate task on a headless machine, you need to run the following command: `xvfb-run diambra run python sheeprl.py exp=dreamer_v3 env=diambra env.id=doapp env.sync_env=True env.num_envs=1 algo.cnn_keys.encoder=[frame] fabric.devices=1` 2. Exploit the [PyVirtualDisplay](https://github.com/ponty/PyVirtualDisplay) package. \ No newline at end of file diff --git a/howto/learn_in_dmc.md b/howto/learn_in_dmc.md index 3ec39cbe..65624070 100644 --- a/howto/learn_in_dmc.md +++ b/howto/learn_in_dmc.md @@ -23,12 +23,12 @@ For more information: [https://github.com/deepmind/dm_control](https://github.co In order to train your agents on the [MuJoCo environments](https://gymnasium.farama.org/environments/mujoco/) provided by Gymnasium, it is sufficient to select the *GYM* environment (`env=gym`) and set the `env.id` to the name of the environment you want to use. For instance, `"Walker2d-v4"` if you want to train your agent in the *walker walk* environment. ```bash -python sheeprl.py exp=dreamer_v3 env=gym env.id=Walker2d-v4 cnn_keys.encoder=[rgb] +python sheeprl.py exp=dreamer_v3 env=gym env.id=Walker2d-v4 algo.cnn_keys.encoder=[rgb] ``` ## DeepMind Control In order to train your agents on the [DeepMind control suite](https://github.com/deepmind/dm_control/blob/main/dm_control/suite/README.md), you have to select the *DMC* environment (`env=dmc`) and to set the id of the environment you want to use. A list of the available environments can be found [here](https://arxiv.org/abs/1801.00690). For instance, if you want to train your agent on the *walker walk* environment, you need to set the `env.id` to `"walker_walk"`. ```bash -python sheeprl.py exp=dreamer_v3 env=dmc env.id=walker_walk cnn_keys.encoder=[rgb] +python sheeprl.py exp=dreamer_v3 env=dmc env.id=walker_walk algo.cnn_keys.encoder=[rgb] ``` \ No newline at end of file diff --git a/howto/learn_in_minedojo.md b/howto/learn_in_minedojo.md index 57118260..8a93c7a6 100644 --- a/howto/learn_in_minedojo.md +++ b/howto/learn_in_minedojo.md @@ -34,7 +34,7 @@ It is possible to train your agents on all the tasks provided by MineDojo. You n For instance, you can use the following command to select the MineDojo open-ended environment. 
```bash -python sheeprl.py exp=p2e_dv2 env=minedojo env.id=open-ended algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor cnn_keys.encoder=[rgb] +python sheeprl.py exp=p2e_dv2 env=minedojo env.id=open-ended algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor algo.cnn_keys.encoder=[rgb] ``` ### Observation Space @@ -73,5 +73,5 @@ For more information about the MineDojo action space, check [here](https://docs. If you work on a headless machine, you need to software renderer. We recommend to adopt one of the following solutions: -1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the training command with `xvfb-run`. For instance, to train DreamerV2 on the navigate task on a headless machine, you need to run the following command: `xvfb-run python sheeprl.py exp=p2e_dv2 fabric.devices=1 env=minedojo env.id=open-ended cnn_keys.encoder=[rgb] algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor`, or `MINEDOJO_HEADLESS=1 python sheeprl.py exp=p2e_dv2 fabric.devices=1 env=minedojo env.id=open-ended cnn_keys.encoder=[rgb] algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor`. +1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the training command with `xvfb-run`. For instance, to train DreamerV2 on the navigate task on a headless machine, you need to run the following command: `xvfb-run python sheeprl.py exp=p2e_dv2 fabric.devices=1 env=minedojo env.id=open-ended algo.cnn_keys.encoder=[rgb] algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor`, or `MINEDOJO_HEADLESS=1 python sheeprl.py exp=p2e_dv2 fabric.devices=1 env=minedojo env.id=open-ended algo.cnn_keys.encoder=[rgb] algo.actor.cls=sheeprl.algos.p2e_dv2.agent.MinedojoActor`. 2. Exploit the [PyVirtualDisplay](https://github.com/ponty/PyVirtualDisplay) package. diff --git a/howto/learn_in_minerl.md b/howto/learn_in_minerl.md index c9bc2082..03889607 100644 --- a/howto/learn_in_minerl.md +++ b/howto/learn_in_minerl.md @@ -54,5 +54,5 @@ Finally, we added sticky actions for the `jump` and `attack` actions. You can se ## Headless machines If you work on a headless machine, you need to software renderer. We recommend to adopt one of the following solutions: -1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the training command with `xvfb-run`. For instance, to train DreamerV2 on the navigate task on a headless machine, you need to run the following command: `xvfb-run python sheeprl.py exp=dreamer_v3 fabric.devices=1 env=minerl env.id=custom_navigate cnn_keys.encoder=[rgb]`. +1. Install the `xvfb` software with the `sudo apt install xvfb` command and prefix the training command with `xvfb-run`. For instance, to train DreamerV2 on the navigate task on a headless machine, you need to run the following command: `xvfb-run python sheeprl.py exp=dreamer_v3 fabric.devices=1 env=minerl env.id=custom_navigate algo.cnn_keys.encoder=[rgb]`. 2. Exploit the [PyVirtualDisplay](https://github.com/ponty/PyVirtualDisplay) package. 
\ No newline at end of file diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index 533efb6f..cc6ee326 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -382,8 +382,10 @@ defaults: - override /env: atari - _self_ -total_steps: 65536 -per_rank_batch_size: 64 +algo: + total_steps: 65536 + per_rank_batch_size: 64 + buffer: share_data: False diff --git a/howto/select_observations.md b/howto/select_observations.md index 227014b1..69c24ef8 100644 --- a/howto/select_observations.md +++ b/howto/select_observations.md @@ -30,7 +30,7 @@ You just need to pass the `mlp_keys` and `cnn_keys` of the encoder and the decod For instance, to train the ppo algorithm on the *doapp* task provided by *DIAMBRA* using image observations and only the `opp_health` and `own_health` as vector observation, you have to run the following command: ```bash -diambra run python sheeprl.py exp=ppo env=diambra env.id=doapp env.num_envs=1 cnn_keys.encoder=[frame] mlp_keys.encoder=[opp_health,own_health] +diambra run python sheeprl.py exp=ppo env=diambra env.id=doapp env.num_envs=1 algo.cnn_keys.encoder=[frame] algo.mlp_keys.encoder=[opp_health,own_health] ``` > **Note** @@ -41,13 +41,13 @@ It is important to know the observations the environment provides, for instance, > **Note** > > For some environments provided by Gymnasium, e.g. `LunarLander-v2` or `CartPole-v1`, only vector observations are returned, but it is possible to extract the image observation from the render. To do this, it is sufficient to specify the `rgb` key to the `cnn_keys` args: -> `python sheeprl.py exp=... cnn_keys.encoder=[rgb]` +> `python sheeprl.py exp=... algo.cnn_keys.encoder=[rgb]` #### Frame Stack For image observations, it is possible to stack the last $n$ observations with the argument `frame_stack`. All the observations specified in the `cnn_keys` argument are stacked. ```bash -python sheeprl.py exp=... env=dmc cnn_keys.encoder=[rgb] env.frame_stack=3 +python sheeprl.py exp=... env=dmc algo.cnn_keys.encoder=[rgb] env.frame_stack=3 ``` #### How to choose the correct keys @@ -71,7 +71,7 @@ You can specify different observations for the encoder and the decoder, but ther You can specify the *mlp* and *cnn* keys of the decoder as follows: ```bash -python sheeprl.py exp=dreamer_v3 env=minerl env.id=custom_navigate mlp_keys.encoder=[life_stats,inventory,max_inventory] mlp_keys.decoder=[life_stats,inventory] +python sheeprl.py exp=dreamer_v3 env=minerl env.id=custom_navigate algo.mlp_keys.encoder=[life_stats,inventory,max_inventory] algo.mlp_keys.decoder=[life_stats,inventory] ``` ### Vector observations algorithms @@ -85,7 +85,7 @@ For these algorithms, you have to specify the *mlp* keys you want to encode. As For instance, you can train a SAC agent on the `LunarLanderContinuous-v2` with the following command: ```bash -python sheeprl.py exp=sac env=gym env.id=LunarLanderContinuous-v2 mlp_keys.encoder=[state] +python sheeprl.py exp=sac env=gym env.id=LunarLanderContinuous-v2 algo.mlp_keys.encoder=[state] ``` @@ -111,9 +111,9 @@ python examples/observation_space.py env=atari agent=dreamer_v3 env.id=MsPacmanN > **Note** > -> You can try to override some *cnn* or *mlp* keys by specifying the `cnn_keys.encoder` and the `mlp_keys.encoder` arguments. **Not all** environments allow it. +> You can try to override some *cnn* or *mlp* keys by specifying the `algo.cnn_keys.encoder` and the `algo.mlp_keys.encoder` arguments. **Not all** environments allow it. 
> -> For instance, the `python examples/observation_space.py env=gym agent=dreamer_v3 env.id=LunarLander-v2 cnn_keys.encoder=[custom_cnn_key] mlp_keys.encoder=[custom_mlp_key]` command will return the following observation space: +> For instance, the `python examples/observation_space.py env=gym agent=dreamer_v3 env.id=LunarLander-v2 algo.cnn_keys.encoder=[custom_cnn_key] algo.mlp_keys.encoder=[custom_mlp_key]` command will return the following observation space: >``` > Observation space of `LunarLander-v2` environment for `dreamer_v3` agent: > Dict( diff --git a/howto/work_with_multi-encoder_multi-decoder.md b/howto/work_with_multi-encoder_multi-decoder.md index f46a88cf..2f885873 100644 --- a/howto/work_with_multi-encoder_multi-decoder.md +++ b/howto/work_with_multi-encoder_multi-decoder.md @@ -9,14 +9,14 @@ In order to work with both image observations and vector observations, it is nec Another **mandatory** attribute of the *cnn* and *mpl* encoders/decoders is the attribute `keys`. This attribute indicates the *cnn*/*mlp* keys that the encoder/decoder encodes/reconstructs. ### Multi-Encoder -The observations are encoded by the `cnn_encoder` and `mlp_encoder` and then the embeddings are concatenated on the last dimension (the *cnn* encoder encodes the observations defined by the `cnn_keys.encoder` and the *mlp* encoder encodes the observations defined by the `mlp_keys.encoder`). If one between the *cnn* or *mlp* encoder is not present, then the output will be equal to the output of the *mlp* or *cnn* encoder, respectively. So the `cnn_encoder` and the `mlp_encoder` must return a `Tensor`. +The observations are encoded by the `cnn_encoder` and `mlp_encoder` and then the embeddings are concatenated on the last dimension (the *cnn* encoder encodes the observations defined by the `algo.cnn_keys.encoder` and the *mlp* encoder encodes the observations defined by the `algo.mlp_keys.encoder`). If one between the *cnn* or *mlp* encoder is not present, then the output will be equal to the output of the *mlp* or *cnn* encoder, respectively. So the `cnn_encoder` and the `mlp_encoder` must return a `Tensor`. > **Note** > > From our experience, we prefer to concatenate the images on the channel dimension and the vectors on the last dimension and then compute the embeddings with the `cnn_encoder` and the `mlp_encoder`, which take in input the concatenated images and the concatenated vectors, respectively. ### Multi-Decoder -The Multi-Decoder takes in input the features/states and tries to reconstruct the observations. The same features are passed in input to both the `cnn_decoder` and the `mlp_decoder`. Each of them outputs the reconstructed observations defined by the `cnn_keys.decoder` and `mlp_keys.decoder`, respectively. So the two decoders must return a python dictionary in the form: `Dict[key, rec_obs]`, where `key` is either a *cnn* or *mlp* key. +The Multi-Decoder takes in input the features/states and tries to reconstruct the observations. The same features are passed in input to both the `cnn_decoder` and the `mlp_decoder`. Each of them outputs the reconstructed observations defined by the `algo.cnn_keys.decoder` and `algo.mlp_keys.decoder`, respectively. So the two decoders must return a python dictionary in the form: `Dict[key, rec_obs]`, where `key` is either a *cnn* or *mlp* key. 
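
As a rough illustration of the contract described above (not the actual SheepRL modules), the following is a minimal PyTorch sketch of a multi-encoder that concatenates the cnn and mlp embeddings on the last dimension and of an mlp decoder that returns a `Dict[key, rec_obs]`; the class names and constructor arguments are invented for the example.

```python
from typing import Dict, Optional

import torch
import torch.nn as nn


class MultiEncoderSketch(nn.Module):
    """Concatenates the embeddings of an optional cnn encoder and an optional mlp encoder."""

    def __init__(self, cnn_encoder: Optional[nn.Module] = None, mlp_encoder: Optional[nn.Module] = None):
        super().__init__()
        self.cnn_encoder = cnn_encoder
        self.mlp_encoder = mlp_encoder

    def forward(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        embeddings = []
        if self.cnn_encoder is not None:
            # the cnn encoder is assumed to return a single Tensor for the image observations
            embeddings.append(self.cnn_encoder(obs))
        if self.mlp_encoder is not None:
            # the mlp encoder is assumed to return a single Tensor for the vector observations
            embeddings.append(self.mlp_encoder(obs))
        # with a single encoder this is simply that encoder's output
        return torch.cat(embeddings, dim=-1)


class MLPDecoderSketch(nn.Module):
    """Reconstructs one vector observation per decoder key and returns them as a dict."""

    def __init__(self, feature_dim: int, output_dims: Dict[str, int]):
        super().__init__()
        self.heads = nn.ModuleDict({k: nn.Linear(feature_dim, d) for k, d in output_dims.items()})

    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Dict[key, rec_obs], where every key is one of the mlp decoder keys
        return {k: head(features) for k, head in self.heads.items()}
```
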
> **Note** > diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py index cf0d21f1..741d2545 100644 --- a/sheeprl/utils/env.py +++ b/sheeprl/utils/env.py @@ -88,28 +88,31 @@ def thunk() -> gym.Env: and len(cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder) > 0 ): raise ValueError( - "`cnn_keys.encoder` and `mlp_keys.encoder` must be lists of strings, got: " - f"cnn encoder keys `{cfg.algo.cnn_keys.encoder}` and mlp encoder keys `{cfg.algo.mlp_keys.encoder}`. " - "Both lists must not be empty." + "`algo.cnn_keys.encoder` and `algo.mlp_keys.encoder` must be lists of strings, got: " + f"cnn encoder keys `{cfg.algo.cnn_keys.encoder}` of type `{type(cfg.algo.cnn_keys.encoder)}` " + f"and mlp encoder keys `{cfg.algo.mlp_keys.encoder}` of type `{type(cfg.algo.mlp_keys.encoder)}`. " + "Both must be non-empty lists." ) # Create observation dict + encoder_cnn_keys_length = len(cfg.algo.cnn_keys.encoder) + encoder_mlp_keys_length = len(cfg.algo.mlp_keys.encoder) if isinstance(env.observation_space, gym.spaces.Box) and len(env.observation_space.shape) < 2: # Vector only observation - if len(cfg.algo.cnn_keys.encoder) > 0: - if len(cfg.algo.cnn_keys.encoder) > 1: + if encoder_cnn_keys_length > 0: + if encoder_cnn_keys_length > 1: warnings.warn( "Multiple cnn keys have been specified and only one pixel observation " f"is allowed in {cfg.env.id}, " f"only the first one is kept: {cfg.algo.cnn_keys.encoder[0]}" ) - if len(cfg.algo.mlp_keys.encoder) > 0: + if encoder_mlp_keys_length > 0: gym.wrappers.pixel_observation.STATE_KEY = cfg.algo.mlp_keys.encoder[0] env = gym.wrappers.PixelObservationWrapper( - env, pixels_only=len(cfg.algo.mlp_keys.encoder) == 0, pixel_keys=(cfg.algo.cnn_keys.encoder[0],) + env, pixels_only=encoder_mlp_keys_length == 0, pixel_keys=(cfg.algo.cnn_keys.encoder[0],) ) else: - if len(cfg.algo.mlp_keys.encoder) > 1: + if encoder_mlp_keys_length > 1: warnings.warn( "Multiple mlp keys have been specified and only one pixel observation " f"is allowed in {cfg.env.id}, " @@ -120,13 +123,13 @@ def thunk() -> gym.Env: env.observation_space = gym.spaces.Dict({mlp_key: env.observation_space}) elif isinstance(env.observation_space, gym.spaces.Box) and 2 <= len(env.observation_space.shape) <= 3: # Pixel only observation - if len(cfg.algo.cnn_keys.encoder) > 1: + if encoder_cnn_keys_length > 1: warnings.warn( "Multiple cnn keys have been specified and only one pixel observation " f"is allowed in {cfg.env.id}, " f"only the first one is kept: {cfg.algo.cnn_keys.encoder[0]}" ) - elif len(cfg.algo.cnn_keys.encoder) == 0: + elif encoder_cnn_keys_length == 0: raise ValueError( "You have selected a pixel observation but no cnn key has been specified. " "Please set at least one cnn key in the config file: `algo.cnn_keys.encoder=[your_cnn_key]`" From aa221f2ebf12a60e9f3236efcc74e60479fe388b Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 24 Nov 2023 10:28:05 +0100 Subject: [PATCH 5/7] Update howto/configs.md --- howto/configs.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/howto/configs.md b/howto/configs.md index 88693b18..280f90e4 100644 --- a/howto/configs.md +++ b/howto/configs.md @@ -298,9 +298,23 @@ is: * the content of the `sheeprl/configs/algo/default.yaml` config will be inserted in the current config and whenever a naming collision happens, for example when the same field is defined in both configurations, those will be resolved by keeping the value defined in the current config. 
This behaviour is specified by letting the `_self_` keyword be the last one in the `defaults` list * `/optim@world_model.optimizer: adam` (and similar) means that the `adam` config, found in the `sheeprl/configs/optim` folder, will be inserted in this config under the `world_model.optimizer` field, so that one can access it at runtime as `cfg.algo.world_model.optimizer`. As in the previous point, the fields `lr`, `eps`, and `weight_decay` will be overwritten by the one specified in this config +The default configuration for all the algorithms is the following: + +```yaml +name: ??? +total_steps: ??? +per_rank_batch_size: ??? + +# Encoder and decoder keys +cnn_keys: + encoder: [] +mlp_keys: + encoder: [] +``` + > **Warning** > -> Every algorithm config **must** contain the field `name` +> Every algorithm config **must** contain the field `name`, the total number of steps `total_steps` and the batch size `per_rank_batch_size` ### Environment From 8fc25ee2f7c12312a8422bfaa90a5a973f80d762 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 24 Nov 2023 16:05:39 +0100 Subject: [PATCH 6/7] Update after conversation on gh --- examples/observation_space.py | 4 +++- howto/select_observations.md | 26 ++++++++++----------- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 2 +- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 2 +- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 2 +- sheeprl/configs/env_config.yaml | 14 ++++++----- 6 files changed, 27 insertions(+), 23 deletions(-) diff --git a/examples/observation_space.py b/examples/observation_space.py index 9a96b7a2..ae070fcc 100644 --- a/examples/observation_space.py +++ b/examples/observation_space.py @@ -1,8 +1,9 @@ import gymnasium as gym import hydra -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from sheeprl.utils.env import make_env +from sheeprl.utils.utils import dotdict @hydra.main(version_base="1.3", config_path="../sheeprl/configs", config_name="env_config") @@ -23,6 +24,7 @@ def main(cfg: DictConfig) -> None: "droq", "ppo_recurrent", }: + cfg = dotdict(OmegaConf.to_container(cfg, resolve=True)) env: gym.Env = make_env(cfg, cfg.seed, 0)() else: raise ValueError( diff --git a/howto/select_observations.md b/howto/select_observations.md index 69c24ef8..6eb2a5b7 100644 --- a/howto/select_observations.md +++ b/howto/select_observations.md @@ -19,10 +19,10 @@ The algorithms that can work with both image and vector observations are specifi * Plan2Explore (Dreamer-V3) To run one of these algorithms, it is necessary to specify which observations to use: it is possible to select all the vector observations or only some of them or none of them. Moreover, you can select all/some/none of the image observations. -You just need to pass the `mlp_keys` and `cnn_keys` of the encoder and the decoder to the script to select the vector observations and the image observations, respectively. +You just need to pass the `algo.mlp_keys` and `algo.cnn_keys` of the encoder and the decoder to the script to select the vector observations and the image observations, respectively. > **Note** > -> The `mlp_keys` and the `cnn_keys` specified for the encoder are used by default as `mlp_keys` and `cnn_keys` of the decoder, respectively. +> The `algo.mlp_keys` and the `algo.cnn_keys` specified for the encoder are used by default as `algo.mlp_keys` and `algo.cnn_keys` of the decoder, respectively. 
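
As a small illustration of this default behaviour (a sketch using plain OmegaConf, with made-up key values), the decoder keys resolve to whatever is set for the encoder unless they are explicitly overridden:

```python
from omegaconf import OmegaConf

# Hypothetical key values, only to show the `${algo.*_keys.encoder}` interpolation used by the configs
cfg = OmegaConf.create(
    {
        "algo": {
            "cnn_keys": {"encoder": ["rgb"], "decoder": "${algo.cnn_keys.encoder}"},
            "mlp_keys": {"encoder": ["state"], "decoder": "${algo.mlp_keys.encoder}"},
        }
    }
)
resolved = OmegaConf.to_container(cfg, resolve=True)
print(resolved["algo"]["cnn_keys"]["decoder"])  # ['rgb']: the decoder falls back to the encoder keys
```
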
> **Recommended** > @@ -35,34 +35,34 @@ diambra run python sheeprl.py exp=ppo env=diambra env.id=doapp env.num_envs=1 al > **Note** > -> By default the `mlp_keys` and `cnn_keys` arguments are set to `[]` (empty list), so no observations are selected for the training. This might raise an exception: the agent tries to automatically set the *mlp* or *cnn* keys, but it is not always possible, so it is **strongly recommended to properly set them**. +> By default the `algo.mlp_keys` and `algo.cnn_keys` arguments are set to `[]` (empty list), so no observations are selected for the training. This will raise an exception: if fact, **every algorithm must specify at least one of them**. It is important to know the observations the environment provides, for instance, the *DIAMBRA* environments provide both vector observations and image observations, whereas all the atari environments provide only the image observations. > **Note** > -> For some environments provided by Gymnasium, e.g. `LunarLander-v2` or `CartPole-v1`, only vector observations are returned, but it is possible to extract the image observation from the render. To do this, it is sufficient to specify the `rgb` key to the `cnn_keys` args: +> For some environments provided by Gymnasium, e.g. `LunarLander-v2` or `CartPole-v1`, only vector observations are returned, but it is possible to extract the image observation from the render. To do this, it is sufficient to specify the `rgb` key to the `algo.cnn_keys` args: > `python sheeprl.py exp=... algo.cnn_keys.encoder=[rgb]` #### Frame Stack -For image observations, it is possible to stack the last $n$ observations with the argument `frame_stack`. All the observations specified in the `cnn_keys` argument are stacked. +For image observations, it is possible to stack the last $n$ observations with the argument `frame_stack`. All the observations specified in the `algo.cnn_keys` argument are stacked. ```bash python sheeprl.py exp=... env=dmc algo.cnn_keys.encoder=[rgb] env.frame_stack=3 ``` #### How to choose the correct keys -When the environment provides both the vector and image observations, you just need to specify which observations you want to use with the `mlp_keys` and `cnn_keys`, respectively. +When the environment provides both the vector and image observations, you just need to specify which observations you want to use with the `algo.mlp_keys` and `algo.cnn_keys`, respectively. Instead, for those environments that natively do not support both types of observations, we provide a method to obtain the **image observations from the vector observations (NOT VICE VERSA)**. It means that if you choose an environment with only vector observations, you can get also the image observations, but if you choose an environment with only image observations, you **cannot** get the vector observations. There can be three possible scenarios: -1. You do **not** want to **use** the **image** observations: you don't have to specify any `cnn_keys` while you have to select the `mlp_keys`: - 1. if the environment provides more than one vector observation, then you have to choose between them; - 2. if the environment provides only one vector observation, you can choose the name of the *mlp key* or use the default one (`state`, used when you do not specify any *mlp keys*). -2. 
You want to **use only** the **image** observation: you don't have to specify any `mlp_keys` while **you must specify the name of the *cnn key*** (if the image observation has to be created from the vector one, the `make_env` function will automatically bind the observation with the specified key, otherwise you must choose a valid one). -3. You want to **use both** the **vector** and **image** observations: You must specify the *cnn key* (as point 2). Instead, for the vector observations, you have two possibilities: - 1. if the environment provides more than one vector observation, then you **must choose between them**; - 2. if the environment provides only one vector observation, you **must specify** the default vector observation key, i.e., **`state`**. +1. You do **not** want to **use** the **image** observations: you don't have to specify any `algo.cnn_keys` while you have to select the `algo.mlp_keys`: + 1. if the environment provides more than one vector observation, then you **must choose between them**; + 2. if the environment provides only one vector observation, you can choose the name of the *mlp key*. +2. You want to **use only** the **image** observation: you don't have to specify any `algo.mlp_keys` while **you must specify the name of the *cnn key*** (if the image observation has to be created from the vector one, the `make_env` function will automatically bind the observation with the specified key, otherwise you must choose a valid one). +3. You want to **use both** the **vector** and **image** observations: you must specify the *cnn key* (as point 2). Instead, for the vector observations, you have two possibilities: + 1. if the environment provides more than one vector observation, then you **must choose between them**; + 2. if the environment provides only one vector observation, you can choose the name of the *mlp key*. 
#### Different observations for the Encoder and the Decoder You can specify different observations for the encoder and the decoder, but there are some constraints: diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index b595789e..109f9c83 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -70,7 +70,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): cfg.env.num_envs = exploration_cfg.env.num_envs # There must be the same cnn and mlp keys during exploration and finetuning cfg.algo.cnn_keys = exploration_cfg.algo.cnn_keys - cfg.mlp_keys = exploration_cfg.mlp_keys + cfg.algo.mlp_keys = exploration_cfg.algo.mlp_keys # These arguments cannot be changed cfg.env.screen_size = 64 diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 72aff80f..54d5fec1 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -74,7 +74,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): cfg.env.num_envs = exploration_cfg.env.num_envs # There must be the same cnn and mlp keys during exploration and finetuning cfg.algo.cnn_keys = exploration_cfg.algo.cnn_keys - cfg.mlp_keys = exploration_cfg.mlp_keys + cfg.algo.mlp_keys = exploration_cfg.algo.mlp_keys # These arguments cannot be changed cfg.env.screen_size = 64 diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 633556de..83debfe3 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -69,7 +69,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): cfg.env.num_envs = exploration_cfg.env.num_envs # There must be the same cnn and mlp keys during exploration and finetuning cfg.algo.cnn_keys = exploration_cfg.algo.cnn_keys - cfg.mlp_keys = exploration_cfg.mlp_keys + cfg.algo.mlp_keys = exploration_cfg.algo.mlp_keys # These arguments cannot be changed cfg.env.frame_stack = 1 diff --git a/sheeprl/configs/env_config.yaml b/sheeprl/configs/env_config.yaml index d36ba66b..8f17821b 100644 --- a/sheeprl/configs/env_config.yaml +++ b/sheeprl/configs/env_config.yaml @@ -14,9 +14,11 @@ exp_name: "default" root_dir: $env_logs run_name: ${env.id} agent: ??? -cnn_keys: - encoder: [] - decoder: ${algo.cnn_keys.encoder} -mlp_keys: - encoder: [] - decoder: ${algo.mlp_keys.encoder} + +algo: + cnn_keys: + encoder: [] + decoder: ${algo.cnn_keys.encoder} + mlp_keys: + encoder: [] + decoder: ${algo.mlp_keys.encoder} From 1b15465c7ab4ec85f6f0b833ab3b22b014a6ef1c Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 24 Nov 2023 16:15:26 +0100 Subject: [PATCH 7/7] Update howto/select_observations.md --- howto/select_observations.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/howto/select_observations.md b/howto/select_observations.md index 6eb2a5b7..3c12c316 100644 --- a/howto/select_observations.md +++ b/howto/select_observations.md @@ -93,7 +93,7 @@ python sheeprl.py exp=sac env=gym env.id=LunarLanderContinuous-v2 algo.mlp_keys. It is possible to retrieve the observation space of a specific environment to easily select the observation keys you want to use in your training. ```bash -python examples/observation_space.py env=... env.id=... agent=dreamer_v3 +python examples/observation_space.py env=... env.id=... 
agent=dreamer_v3 algo.cnn_keys.encoder=[...] algo.mlp_keys.encoder=[...] ``` or for *DIAMBRA* environments: @@ -102,11 +102,11 @@ or for *DIAMBRA* environments: diambra run python examples/observation_space.py env=diambra agent=dreamer_v3 env.id=doapp ``` -The env argument is the same one you use for training your agent, so it refers to the config folder `sheeprl/configs/env`, more over you can override the environment id and modify its parameters, such as the frame stack or whether or not to use grayscale observations. +The env argument is the same one you use for training your agent, so it refers to the config folder `sheeprl/configs/env`, moreover you can override the environment id and modify its parameters, such as the frame stack or whether or not to use grayscale observations. You can modify the parameters as usual by specifying them as cli arguments: ```bash -python examples/observation_space.py env=atari agent=dreamer_v3 env.id=MsPacmanNoFrameskip-v4 env.frame_stack=5 env.grayscale=True +python examples/observation_space.py env=atari agent=dreamer_v3 env.id=MsPacmanNoFrameskip-v4 env.frame_stack=5 env.grayscale=True algo.cnn_keys.encoder=[frame] ``` > **Note**