Feature/split p2e (#151)
* p2e dv1 and p2e dv2 split into exploration and finetuning

* fix: exploration amount

* fix: change actor from exploration to task when starting training

* fix: from __future__ import annotations

* fix: exploration amount

* docs: added p2e readme

* Feature/p2e dv3 (#113)

* feat: implemented p2e_dv3

* feat: added the possibility to have more critics for exploration

* tests: added p2e_dv3 test

* docs: update p2e_dv3 docs

* docs: update

* fix: p2e_dv3 refactoring

* fix: checkpoint

* Fix missing 0.5 value

* feat: add validate args to p2e_dv3

* feat: uniform p2e_dv3 with last improvements

* fix: ppo tests

* feat: split exploration and finetuning

* fix: resume from checkpoint controls

* fix: bugs

* tests: added p2e_dv3 and resume from checkpoint tests

* fix: p2e dv3 resume from checkpoint

* tests: update p2e dv3 test

* feat: added p2e_dv3 evaluation

* fix: evaluate and __init__

* fix: cli controls

* fix: added detach() when learning world model in exploration

* fix: checks in cli

* fix: exploration amount

* fix: removed minedojo test cfgs

---------

Co-authored-by: belerico_t <federico.belotti@orobix.com>

* fix: buffer load

---------

Co-authored-by: belerico_t <federico.belotti@orobix.com>
michele-milesi and belerico authored Nov 13, 2023
1 parent 933f4b9 commit 9b68f22
Showing 47 changed files with 4,286 additions and 387 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -96,6 +96,7 @@ The algorithms sheeped by sheeprl out-of-the-box are:
| Dreamer-V3 | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Plan2Explore (Dreamer V1) | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Plan2Explore (Dreamer V2) | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Plan2Explore (Dreamer V3) | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

and more are coming soon! [Open a PR](https://github.com/Eclectic-Sheep/sheeprl/pulls) if you have any particular request :sheep:

@@ -115,6 +116,7 @@ The actions supported by sheeprl agents are:
| Dreamer-V3 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Plan2Explore (Dreamer V1) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Plan2Explore (Dreamer V2) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Plan2Explore (Dreamer V3) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

The environments supported by sheeprl are:
| Algorithm | Installation command | More info | Status |
3 changes: 2 additions & 1 deletion examples/observation_space.py
@@ -14,6 +14,7 @@ def main(cfg: DictConfig) -> None:
"dreamer_v3",
"p2e_dv1",
"p2e_dv2",
"p2e_dv3",
"sac_ae",
"ppo",
"ppo_decoupled",
@@ -25,7 +26,7 @@ def main(cfg: DictConfig) -> None:
env: gym.Env = make_env(cfg, cfg.seed, 0)()
else:
raise ValueError(
"Invalid selected agent: check the available agents with the command `python sheeprl.py --sheeprl_help`"
"Invalid selected agent: check the available agents with the command `python sheeprl/available_agents.py`"
)

print()
1 change: 1 addition & 0 deletions howto/learn_in_atari.md
@@ -16,6 +16,7 @@ The list of selectable algorithms is given below:
* `dreamer_v3`
* `p2e_dv1`
* `p2e_dv2`
* `p2e_dv3`
* `ppo`
* `ppo_decoupled`
* `sac_ae`
2 changes: 2 additions & 0 deletions howto/register_new_algorithm.md
@@ -422,6 +422,7 @@ from sheeprl.algos.dreamer_v3 import dreamer_v3 as dreamer_v3
from sheeprl.algos.droq import droq as droq
from sheeprl.algos.p2e_dv1 import p2e_dv1 as p2e_dv1
from sheeprl.algos.p2e_dv2 import p2e_dv2 as p2e_dv2
from sheeprl.algos.p2e_dv3 import p2e_dv3 as p2e_dv3
from sheeprl.algos.ppo import ppo as ppo
from sheeprl.algos.ppo import ppo_decoupled as ppo_decoupled
from sheeprl.algos.ppo_recurrent import ppo_recurrent as ppo_recurrent
@@ -452,6 +453,7 @@ SheepRL Agents
│ sheeprl.algos.droq │ droq │ main │ False │
│ sheeprl.algos.p2e_dv1 │ p2e_dv1 │ main │ False │
│ sheeprl.algos.p2e_dv2 │ p2e_dv2 │ main │ False │
│ sheeprl.algos.p2e_dv3 │ p2e_dv3 │ main │ False │
│ sheeprl.algos.ppo │ ppo │ main │ False │
│ sheeprl.algos.ppo │ ppo_decoupled │ main │ True │
│ sheeprl.algos.ppo_recurrent │ ppo_recurrent │ main │ False │
3 changes: 2 additions & 1 deletion howto/select_observations.md
@@ -15,7 +15,8 @@ The algorithms that can work with both image and vector observations are specifi
* Dreamer-V2
* Dreamer-V3
* Plan2Explore (Dreamer-V1)
-* Plan2Explore (Dreamer-V1)
+* Plan2Explore (Dreamer-V2)
+* Plan2Explore (Dreamer-V3)

To run one of these algorithms, it is necessary to specify which observations to use: it is possible to select all the vector observations or only some of them or none of them. Moreover, you can select all/some/none of the image observations.
You just need to pass the `mlp_keys` and `cnn_keys` of the encoder and the decoder to the script to select the vector observations and the image observations, respectively.
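For instance, a run that keeps the `rgb` image observation and the `state` vector observation could look like the sketch below. It assumes the Hydra-style overrides used by sheeprl experiments; the experiment, the environment, and the exact `algo.*_keys.*` config paths are illustrative assumptions, not taken from this diff.

```bash
# Hypothetical sketch: select one image and one vector observation for both
# the encoder and the decoder. Config paths may differ between sheeprl versions.
python sheeprl.py exp=dreamer_v3 env=dmc env.id=walker_walk \
  algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] \
  algo.mlp_keys.encoder=[state] algo.mlp_keys.decoder=[state]
```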
9 changes: 7 additions & 2 deletions sheeprl/__init__.py
@@ -18,8 +18,12 @@
from sheeprl.algos.dreamer_v2 import dreamer_v2 # noqa: F401
from sheeprl.algos.dreamer_v3 import dreamer_v3 # noqa: F401
from sheeprl.algos.droq import droq # noqa: F401
-from sheeprl.algos.p2e_dv1 import p2e_dv1 # noqa: F401
-from sheeprl.algos.p2e_dv2 import p2e_dv2 # noqa: F401
+from sheeprl.algos.p2e_dv1 import p2e_dv1_exploration # noqa: F401
+from sheeprl.algos.p2e_dv1 import p2e_dv1_finetuning # noqa: F401
+from sheeprl.algos.p2e_dv2 import p2e_dv2_exploration # noqa: F401
+from sheeprl.algos.p2e_dv2 import p2e_dv2_finetuning # noqa: F401
+from sheeprl.algos.p2e_dv3 import p2e_dv3_exploration # noqa: F401
+from sheeprl.algos.p2e_dv3 import p2e_dv3_finetuning # noqa: F401
from sheeprl.algos.ppo import ppo # noqa: F401
from sheeprl.algos.ppo import ppo_decoupled # noqa: F401
from sheeprl.algos.ppo_recurrent import ppo_recurrent # noqa: F401
@@ -33,6 +37,7 @@
from sheeprl.algos.droq import evaluate as droq_evaluate # noqa: F401, isort:skip
from sheeprl.algos.p2e_dv1 import evaluate as p2e_dv1_evaluate # noqa: F401, isort:skip
from sheeprl.algos.p2e_dv2 import evaluate as p2e_dv2_evaluate # noqa: F401, isort:skip
from sheeprl.algos.p2e_dv3 import evaluate as p2e_dv3_evaluate # noqa: F401, isort:skip
from sheeprl.algos.ppo import evaluate as ppo_evaluate # noqa: F401, isort:skip
from sheeprl.algos.ppo_recurrent import evaluate as ppo_recurrent_evaluate # noqa: F401, isort:skip
from sheeprl.algos.sac import evaluate as sac_evaluate # noqa: F401, isort:skip
6 changes: 6 additions & 0 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Sequence, Tuple

import gymnasium
@@ -228,6 +230,8 @@ class PlayerDV1(nn.Module):
stochastic_size (int): the size of the stochastic state.
recurrent_state_size (int): the size of the recurrent state.
device (torch.device): the device to work on.
actor_type (str, optional): which actor the player is using ('task' or 'exploration').
Default to None.
"""

def __init__(
@@ -241,6 +245,7 @@ def __init__(
stochastic_size: int,
recurrent_state_size: int,
device: torch.device,
actor_type: str | None = None,
) -> None:
super().__init__()
self.encoder = encoder
@@ -254,6 +259,7 @@ def __init__(
self.num_envs = num_envs
self.validate_args = self.actor.distribution_cfg.validate_args
self.init_states()
self.actor_type = actor_type

def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
"""Initialize the states and the actions for the ended environments.
8 changes: 7 additions & 1 deletion sheeprl/algos/dreamer_v2/agent.py
@@ -451,7 +451,9 @@ def __init__(
expl_amount: float = 0.0,
) -> None:
super().__init__()
self.distribution_cfg = distribution_cfg
self.distribution = distribution_cfg.pop("type", "auto").lower()
self.distribution_cfg.type = self.distribution
if self.distribution not in ("auto", "normal", "tanh_normal", "discrete", "trunc_normal"):
raise ValueError(
"The distribution must be on of: `auto`, `discrete`, `normal`, `tanh_normal` and `trunc_normal`. "
@@ -739,7 +741,7 @@

class PlayerDV2(nn.Module):
"""
-The model of the Dreamer_v1 player.
+The model of the Dreamer_v2 player.
Args:
encoder (nn.Module): the encoder.
@@ -754,6 +756,8 @@ class PlayerDV2(nn.Module):
discrete_size (int): the dimension of a single Categorical variable in the
stochastic state (prior or posterior).
Defaults to 32.
actor_type (str, optional): which actor the player is using ('task' or 'exploration').
Default to None.
"""

def __init__(
Expand All @@ -768,6 +772,7 @@ def __init__(
recurrent_state_size: int,
device: torch.device,
discrete_size: int = 32,
actor_type: str | None = None,
) -> None:
super().__init__()
self.encoder = encoder
Expand All @@ -781,6 +786,7 @@ def __init__(
self.recurrent_state_size = recurrent_state_size
self.num_envs = num_envs
self.validate_args = self.actor.distribution_cfg.validate_args
self.actor_type = actor_type

def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
"""Initialize the states and the actions for the ended environments.
7 changes: 6 additions & 1 deletion sheeprl/algos/dreamer_v3/agent.py
@@ -475,6 +475,8 @@ class PlayerDV3(nn.Module):
discrete_size (int): the dimension of a single Categorical variable in the
stochastic state (prior or posterior).
Defaults to 32.
actor_type (str, optional): which actor the player is using ('task' or 'exploration').
Default to None.
"""

def __init__(
@@ -488,6 +490,7 @@ def __init__(
recurrent_state_size: int,
device: device = "cpu",
discrete_size: int = 32,
actor_type: str | None = None,
) -> None:
super().__init__()
self.encoder = encoder
@@ -507,6 +510,7 @@ def __init__(
self.recurrent_state_size = recurrent_state_size
self.num_envs = num_envs
self.validate_args = self.actor.distribution_cfg.validate_args
self.actor_type = actor_type

@torch.no_grad()
def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
@@ -629,7 +633,8 @@ def __init__(
) -> None:
super().__init__()
self.distribution_cfg = distribution_cfg
self.distribution = distribution_cfg.pop("type").lower()
self.distribution = distribution_cfg.pop("type", "auto").lower()
self.distribution_cfg.type = self.distribution
if self.distribution not in ("auto", "normal", "tanh_normal", "discrete", "trunc_normal"):
raise ValueError(
"The distribution must be on of: `auto`, `discrete`, `normal`, `tanh_normal` and `trunc_normal`. "
20 changes: 10 additions & 10 deletions sheeprl/algos/dreamer_v3/utils.py
@@ -4,7 +4,7 @@
import numpy as np
import torch
from lightning import Fabric
-from torch import Tensor
+from torch import Tensor, nn

from sheeprl.utils.env import make_env

@@ -31,7 +31,7 @@
}


-class Moments(torch.nn.Module):
+class Moments(nn.Module):
def __init__(
self,
fabric: Fabric,
@@ -133,26 +133,26 @@ def test(

# Adapted from: https://github.com/NM512/dreamerv3-torch/blob/main/tools.py#L929
def init_weights(m):
-if isinstance(m, torch.nn.Linear):
+if isinstance(m, nn.Linear):
in_num = m.in_features
out_num = m.out_features
denoms = (in_num + out_num) / 2.0
scale = 1.0 / denoms
std = np.sqrt(scale) / 0.87962566103423978
-torch.nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)
+nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
-elif isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
+elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
space = m.kernel_size[0] * m.kernel_size[1]
in_num = space * m.in_channels
out_num = space * m.out_channels
denoms = (in_num + out_num) / 2.0
scale = 1.0 / denoms
std = np.sqrt(scale) / 0.87962566103423978
-torch.nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
+nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
-elif isinstance(m, torch.nn.LayerNorm):
+elif isinstance(m, nn.LayerNorm):
m.weight.data.fill_(1.0)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
@@ -161,16 +161,16 @@ def init_weights(m):
# Adapted from: https://github.com/NM512/dreamerv3-torch/blob/main/tools.py#L957
def uniform_init_weights(given_scale):
def f(m):
-if isinstance(m, torch.nn.Linear):
+if isinstance(m, nn.Linear):
in_num = m.in_features
out_num = m.out_features
denoms = (in_num + out_num) / 2.0
scale = given_scale / denoms
limit = np.sqrt(3 * scale)
-torch.nn.init.uniform_(m.weight.data, a=-limit, b=limit)
+nn.init.uniform_(m.weight.data, a=-limit, b=limit)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
-elif isinstance(m, torch.nn.LayerNorm):
+elif isinstance(m, nn.LayerNorm):
m.weight.data.fill_(1.0)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
49 changes: 49 additions & 0 deletions sheeprl/algos/p2e_dv1/README.md
@@ -0,0 +1,49 @@
# Plan2Explore
## Algorithm Overview

The Plan2Explore algorithm is designed to efficiently learn and exploit the dynamics of the environment for accomplishing multiple tasks. The algorithm employs two actors: one for exploration and one for learning the task. During the exploratory phase, the exploration actor focuses on discovering new states by selecting actions that lead to unexplored regions. Simultaneously, the task actor learns from the experiences gathered by the exploration actor in a zero-shot manner. Following the exploration phase, the agent can be fine-tuned with experiences collected by the task actor in a few-shot fashion, enhancing its performance on specific tasks.

## Implementation Details

### Scripts

The algorithm implementation is organized into two scripts:

1. **Exploration Script (`p2e_dv1_exploration.py`):**
- Used for the exploratory phase to learn the dynamics of the environment.
- Trains the exploration actor to select actions leading to new states.

2. **Fine-tuning Script (`p2e_dv1_finetuning.py`):**
- Utilized for fine-tuning the agent after the exploration phase.
- Starts with a trained agent and refines its performance or learns new tasks.

### Configuration Constraints

To ensure the proper functioning of the algorithm, the following constraints must be observed:

- **Environment Configuration:** The fine-tuning must be executed with the same environment configurations used during exploration.

- **Hyper-parameter Consistency:** Hyper-parameters of the agent should remain consistent between the exploration and fine-tuning phases.

### Experience Collection

The implementation supports flexibility in experience collection during fine-tuning:

- **Buffer Options:** Fine-tuning can start from the buffer collected during exploration or a new one (`buffer.load_from_exploration` parameter).

- **Initial Experiences:** If using a new buffer, users can decide whether to collect initial experiences (until `learning_start`) with the `actor_exploration` or the `actor_task`. After `learning_start`, only the `actor_task` collects experiences. (`player.actor_type` parameter, can be either `exploration` or `task`).

> **Note**
>
> When exploring, the only valid choice of the `player.actor_type` parameter is `exploration`.
## Usage

To use the Plan2Explore framework, follow these steps:

1. Run the exploration script to learn the dynamics of the environment.
2. Execute the fine-tuning script with the same environment configurations and consistent hyper-parameters.

> **Note**
>
> Choose whether to start fine-tuning from the exploration buffer or create a new buffer, and specify the actor for initial experience collection accordingly.
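A minimal sketch of this two-phase workflow is shown below. The experiment names mirror the new script names introduced in this PR, while the environment values, the checkpoint override, and the paths are illustrative assumptions.

```bash
# 1. Exploration phase: learn the world model with the exploration actor.
python sheeprl.py exp=p2e_dv1_exploration env=dmc env.id=walker_walk

# 2. Fine-tuning phase: reuse the same environment configuration and
#    hyper-parameters. `buffer.load_from_exploration` and `player.actor_type`
#    are the parameters described above; the checkpoint override and path
#    are assumptions.
python sheeprl.py exp=p2e_dv1_finetuning env=dmc env.id=walker_walk \
  checkpoint.resume_from=/path/to/exploration/checkpoint.ckpt \
  buffer.load_from_exploration=True \
  player.actor_type=task
```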
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/evaluate.py
@@ -13,7 +13,7 @@
from sheeprl.utils.registry import register_evaluation


-@register_evaluation(algorithms="p2e_dv1")
+@register_evaluation(algorithms=["p2e_dv1_exploration", "p2e_dv1_finetuning"])
def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
logger = create_tensorboard_logger(fabric, cfg)
if logger and fabric.is_global_zero:
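With both scripts registered for evaluation, a trained agent could then be evaluated from its checkpoint, for example as sketched below. The evaluation script name and its `checkpoint_path` argument are assumptions about the sheeprl evaluation CLI and are not part of this diff.

```bash
# Hypothetical sketch: evaluate a checkpoint produced by either the
# exploration or the fine-tuning script.
python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt
```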