Feature/replay ratio #247

Merged
merged 38 commits on Mar 29, 2024

Commits (38)

3711d05
Decoupled RSSM for DV3 agent
belerico Feb 8, 2024
e80e9d5
Initialize posterior with prior if is_first is True
belerico Feb 8, 2024
b23112a
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Feb 12, 2024
f47b8f9
Fix PlayerDV3 creation in evaluation
belerico Feb 12, 2024
e42c83d
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Feb 26, 2024
2ec4fbb
Fix representation_model
belerico Feb 26, 2024
3a5380b
Fix compute first prior state with a zero posterior
belerico Feb 27, 2024
42d9433
DV3 replay ratio conversion
belerico Feb 29, 2024
750f671
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Feb 29, 2024
b06433b
Removed expl parameters dependent on old per_Rank_gradient_steps
belerico Feb 29, 2024
20cc43e
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Mar 4, 2024
37d0e86
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into feature…
michele-milesi Mar 18, 2024
704b0ce
feat: update repeats computation
michele-milesi Mar 18, 2024
20905f0
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into feature…
michele-milesi Mar 28, 2024
e1290ee
feat: update learning starts in config
michele-milesi Mar 28, 2024
1f0c0ef
fix: remove files
michele-milesi Mar 28, 2024
cd4a4c4
feat: update repeats
michele-milesi Mar 28, 2024
e8c9049
feat: added replay ratio and update exploration
michele-milesi Mar 28, 2024
88c6968
Fix exploration actions computation on DV1
belerico Mar 28, 2024
a5c957c
Fix naming
belerico Mar 28, 2024
c36577d
Add replay-ratio to SAC
belerico Mar 28, 2024
0bc9f07
feat: added replay ratio to p2e algos
michele-milesi Mar 28, 2024
b5fbe5d
feat: update configs and utils of p2e algos
michele-milesi Mar 28, 2024
24c9352
Add replay-ratio to SAC-AE
belerico Mar 28, 2024
a11b558
Merge branch 'feature/replay-ratio' of https://github.com/Eclectic-Sh…
belerico Mar 28, 2024
32b89b4
Add DrOQ replay ratio
belerico Mar 29, 2024
d057886
Fix tests
belerico Mar 29, 2024
b9044a3
Fix mispelled
belerico Mar 29, 2024
5bd7d75
Fix wrong attribute accesing
belerico Mar 29, 2024
8d94f68
FIx naming and configs
belerico Mar 29, 2024
e361a17
Ratio: account for pretrain steps
belerico Mar 29, 2024
7b143ed
Fix dreamer-vq actor naming
belerico Mar 29, 2024
beff471
feat: added ratio state to checkpoint in sac decoupled
michele-milesi Mar 29, 2024
9be5304
feat: added typing in Ratio class
michele-milesi Mar 29, 2024
082a650
Move ratio.py to examples
belerico Mar 29, 2024
edd06d8
Log dreamer-v1 exploration amount
belerico Mar 29, 2024
fea31fc
Fix DV1 log expl amount
belerico Mar 29, 2024
db62f90
Fix DV2 replay ratios
belerico Mar 29, 2024
17 changes: 7 additions & 10 deletions howto/configs.md
@@ -126,6 +126,8 @@ In the `algo` folder one can find all the configurations for every algorithm imp

```yaml
# sheeprl/configs/algo/dreamer_v3.yaml
# Dreamer-V3 XL configuration

defaults:
- default
- /optim@world_model.optimizer: adam
@@ -139,10 +141,8 @@ lmbda: 0.95
horizon: 15

# Training recipe
train_every: 16
learning_starts: 65536
per_rank_pretrain_steps: 1
per_rank_gradient_steps: 1
replay_ratio: 1
learning_starts: 1024
per_rank_sequence_length: ???

# Encoder and decoder keys
Expand All @@ -159,6 +159,7 @@ dense_act: torch.nn.SiLU
cnn_act: torch.nn.SiLU
unimix: 0.01
hafner_initialization: True
decoupled_rssm: False

# World model
world_model:
@@ -241,10 +242,6 @@ actor:
layer_norm: ${algo.layer_norm}
dense_units: ${algo.dense_units}
clip_gradients: 100.0
expl_amount: 0.0
expl_min: 0.0
expl_decay: False
max_step_expl_decay: 0

# Disttributed percentile model (used to scale the values)
moments:
@@ -266,7 +263,7 @@ critic:
mlp_layers: ${algo.mlp_layers}
layer_norm: ${algo.layer_norm}
dense_units: ${algo.dense_units}
target_network_update_freq: 1
per_rank_target_network_update_freq: 1
tau: 0.02
bins: 255
clip_gradients: 100.0
@@ -410,7 +407,7 @@ buffer:
algo:
learning_starts: 1024
total_steps: 100000
train_every: 1

dense_units: 512
mlp_layers: 2
world_model:
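For anyone migrating an existing config, the relationship between the removed `train_every` / `per_rank_gradient_steps` pair and the new `replay_ratio` can be estimated as follows. This is an illustrative helper, not part of the PR; it assumes the old recipe ran `per_rank_gradient_steps` gradient steps on each rank once every `train_every` global policy steps.

```python
# Illustrative only (not part of this PR): rough conversion from the old
# train_every / per_rank_gradient_steps recipe to the new replay_ratio,
# understood here as per-rank gradient steps per per-rank policy step.
def approx_replay_ratio(per_rank_gradient_steps: int, train_every: int, world_size: int = 1) -> float:
    per_rank_policy_steps_between_trainings = train_every / world_size
    return per_rank_gradient_steps / per_rank_policy_steps_between_trainings


# Old Dreamer-V3 XL defaults on a single rank: train_every: 16, per_rank_gradient_steps: 1
print(approx_replay_ratio(per_rank_gradient_steps=1, train_every=16))  # 0.0625
```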
2 changes: 1 addition & 1 deletion notebooks/dreamer_v3_imagination.ipynb
@@ -230,7 +230,7 @@
" mask = {k: v for k, v in preprocessed_obs.items() if k.startswith(\"mask\")}\n",
" if len(mask) == 0:\n",
" mask = None\n",
" real_actions = actions = player.get_exploration_action(preprocessed_obs, mask)\n",
" real_actions = actions = player.get_actions(preprocessed_obs, mask)\n",
" actions = torch.cat(actions, -1).cpu().numpy()\n",
" if is_continuous:\n",
" real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()\n",
65 changes: 65 additions & 0 deletions ratio.py
@@ -0,0 +1,65 @@
import warnings


class Ratio:
    """Directly taken from Hafner et al. (2023) implementation:
    https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/embodied/core/when.py#L26
    """

    def __init__(self, ratio: float, pretrain_steps: int = 0):
        if pretrain_steps < 0:
            raise ValueError(f"'pretrain_steps' must be non-negative, got {pretrain_steps}")
        if ratio < 0:
            raise ValueError(f"'ratio' must be non-negative, got {ratio}")
        self._pretrain_steps = pretrain_steps
        self._ratio = ratio
        self._prev = None

    def __call__(self, step) -> int:
        step = int(step)
        if self._ratio == 0:
            return 0
        if self._prev is None:
            self._prev = step
            if self._pretrain_steps > 0:
                if step < self._pretrain_steps:
                    warnings.warn(
                        "The number of pretrain steps is greater than the number of current steps. This could lead to "
                        f"a higher ratio than the one specified ({self._ratio}). Setting the 'pretrain_steps' equal to "
                        "the number of current steps."
                    )
                    self._pretrain_steps = step
                return round(self._pretrain_steps * self._ratio)
            else:
                return 1
        repeats = round((step - self._prev) * self._ratio)
        self._prev += repeats / self._ratio
        return repeats


if __name__ == "__main__":
    num_envs = 1
    world_size = 1
    replay_ratio = 0.5
    per_rank_batch_size = 16
    per_rank_sequence_length = 64
    replayed_steps = world_size * per_rank_batch_size * per_rank_sequence_length
    train_steps = 0
    gradient_steps = 0
    total_policy_steps = 2**10
    r = Ratio(ratio=replay_ratio, pretrain_steps=256)
    policy_steps = num_envs * world_size
    printed = False
    for i in range(0, total_policy_steps, policy_steps):
        if i >= 128:
            per_rank_repeats = r(i / world_size)
            if per_rank_repeats > 0 and not printed:
                print(
                    f"Training the agent with {per_rank_repeats} repeats on every rank "
                    f"({per_rank_repeats * world_size} global repeats) at global iteration {i}"
                )
                printed = True
            gradient_steps += per_rank_repeats * world_size
    print("Replay ratio", replay_ratio)
    print("Hafner train ratio", replay_ratio * replayed_steps)
    print("Final ratio", gradient_steps / total_policy_steps)
27 changes: 18 additions & 9 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -241,7 +241,7 @@ def __init__(
encoder: nn.Module | _FabricModule,
recurrent_model: nn.Module | _FabricModule,
representation_model: nn.Module | _FabricModule,
actor: nn.Module | _FabricModule,
actor: Actor | _FabricModule,
actions_dim: Sequence[int],
num_envs: int,
stochastic_size: int,
@@ -288,32 +288,38 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs])
self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs])

def get_exploration_action(self, obs: Tensor, mask: Optional[Dict[str, Tensor]] = None) -> Sequence[Tensor]:
def get_exploration_actions(
self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, step: int = 0
) -> Sequence[Tensor]:
"""Return the actions with a certain amount of noise for exploration.

Args:
obs (Tensor): the current observations.
sample_actions (bool): whether or not to sample the actions.
Default to True.
mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed).
Defaults to None.
step (int): the step of the training, used for the exploration amount.
Default to 0.

Returns:
The actions the agent has to perform (Sequence[Tensor]).
"""
actions = self.get_greedy_action(obs, mask=mask)
actions = self.get_actions(obs, sample_actions=sample_actions, mask=mask)
expl_actions = None
if self.actor.expl_amount > 0:
expl_actions = self.actor.add_exploration_noise(actions, mask=mask)
if self.actor._expl_amount > 0:
expl_actions = self.actor.add_exploration_noise(actions, step=step, mask=mask)
self.actions = torch.cat(expl_actions, dim=-1)
return expl_actions or actions

def get_greedy_action(
self, obs: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None
def get_actions(
self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None
) -> Sequence[Tensor]:
"""Return the greedy actions.

Args:
obs (Tensor): the current observations.
is_training (bool): whether it is training.
sample_actions (bool): whether or not to sample the actions.
Default to True.
mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed).
Defaults to None.
@@ -329,7 +335,7 @@ def get_greedy_action(
self.representation_model(torch.cat((self.recurrent_state, embedded_obs), -1)),
validate_args=self.validate_args,
)
actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), is_training, mask)
actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask)
self.actions = torch.cat(actions, -1)
return actions

@@ -488,6 +494,9 @@ def build_agent(
activation=eval(actor_cfg.dense_act),
distribution_cfg=cfg.distribution,
layer_norm=False,
expl_amount=actor_cfg.expl_amount,
expl_decay=actor_cfg.expl_decay,
expl_min=actor_cfg.expl_min,
)
critic = MLP(
input_dims=latent_state_size,
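The hunk above shows that the actor now owns `expl_amount`, `expl_decay`, and `expl_min` and that `add_exploration_noise` receives the current `step`, but the decay schedule itself is defined elsewhere in the PR. As a purely hypothetical illustration of a step-dependent exploration amount (names and schedule are assumptions, not the PR's implementation):

```python
# Hypothetical sketch: linearly anneal the exploration amount from expl_amount
# down to expl_min over decay_steps training steps. The actual Actor may differ.
def linear_expl_amount(step: int, expl_amount: float, expl_min: float, decay_steps: int) -> float:
    if decay_steps <= 0:
        return expl_amount
    fraction = min(max(step / decay_steps, 0.0), 1.0)
    return expl_amount + fraction * (expl_min - expl_amount)


print(linear_expl_amount(0, 1.0, 0.1, 10_000))       # 1.0
print(linear_expl_amount(5_000, 1.0, 0.1, 10_000))   # 0.55
print(linear_expl_amount(20_000, 1.0, 0.1, 10_000))  # 0.1
```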
79 changes: 35 additions & 44 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -27,7 +27,7 @@
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import polynomial_decay, save_configs
from sheeprl.utils.utils import Ratio, save_configs

# Decomment the following two lines if you cannot start an experiment with DMC environments
# os.environ["PYOPENGL_PLATFORM"] = ""
@@ -547,22 +547,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_update = int(cfg.env.num_envs * world_size)
updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0
num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1
learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0
expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0
max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size)
if cfg.checkpoint.resume_from:
cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
actor.expl_amount = polynomial_decay(
expl_decay_steps,
initial=cfg.algo.actor.expl_amount,
final=cfg.algo.actor.expl_min,
max_decay_steps=max_step_expl_decay,
)
if not cfg.buffer.checkpoint:
learning_starts += start_step

# Create Ratio class
ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)
if cfg.checkpoint.resume_from:
ratio.load_state_dict(state["ratio"])

# Warning for log and checkpoint every
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
warnings.warn(
@@ -592,6 +588,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
rb.add(step_data, validate_args=cfg.buffer.validate_args)
player.init_states()

cumulative_per_rank_gradient_steps = 0
for update in range(start_step, num_updates + 1):
policy_step += cfg.env.num_envs * world_size

@@ -624,7 +621,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
real_actions = actions = player.get_exploration_action(normalized_obs, mask)
real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
@@ -684,46 +681,35 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Reset internal agent states
player.init_states(reset_envs=dones_idxes)

updates_before_training -= 1

# Train the agent
if update > learning_starts and updates_before_training <= 0:
# Start training
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(cfg.algo.per_rank_gradient_steps):
if update >= learning_starts:
per_rank_gradient_steps = ratio(policy_step / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
sequence_length=cfg.algo.per_rank_sequence_length,
n_samples=1,
n_samples=per_rank_gradient_steps,
dtype=None,
device=device,
from_numpy=cfg.buffer.from_numpy,
) # [N_samples, Seq_len, Batch_size, ...]
batch = {k: v[0].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor,
critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
)
train_step += world_size
updates_before_training = cfg.algo.train_every // policy_steps_per_update
if cfg.algo.actor.expl_decay:
expl_decay_steps += 1
actor.expl_amount = polynomial_decay(
expl_decay_steps,
initial=cfg.algo.actor.expl_amount,
final=cfg.algo.actor.expl_min,
max_decay_steps=max_step_expl_decay,
)
if aggregator:
aggregator.update("Params/exploration_amount", actor.expl_amount)
for i in range(per_rank_gradient_steps):
batch = {k: v[i].float() for k, v in sample.items()}
train(
fabric,
world_model,
actor,
critic,
world_optimizer,
actor_optimizer,
critic_optimizer,
batch,
aggregator,
cfg,
)
cumulative_per_rank_gradient_steps += 1
train_step += world_size

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
@@ -733,6 +719,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.log_dict(metrics_dict, policy_step)
aggregator.reset()

# Log replay ratio
fabric.log(
"Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step
)

# Sync distributed timers
if not timer.disabled:
timer_metrics = timer.compute()
@@ -767,7 +758,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
"world_optimizer": world_optimizer.state_dict(),
"actor_optimizer": actor_optimizer.state_dict(),
"critic_optimizer": critic_optimizer.state_dict(),
"expl_decay_steps": expl_decay_steps,
"ratio": ratio.state_dict(),
"update": update * world_size,
"batch_size": cfg.algo.per_rank_batch_size * world_size,
"last_log": last_log,
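Condensed, the new Dreamer-V1 loop replaces the `updates_before_training` counter with a single question per iteration: how many per-rank gradient steps does the `Ratio` scheduler grant at the current policy step? The toy simulation below (illustrative numbers, reusing the `Ratio` class from the `ratio.py` shown above) shows the measured ratio converging toward the configured `replay_ratio` once training starts.

```python
from ratio import Ratio  # assumes the ratio.py above is importable

num_envs, world_size = 4, 1
replay_ratio, learning_starts = 0.5, 64
scheduler = Ratio(replay_ratio)

policy_step, cumulative_per_rank_gradient_steps = 0, 0
for update in range(1, 257):
    policy_step += num_envs * world_size  # environment interaction
    if policy_step >= learning_starts:
        cumulative_per_rank_gradient_steps += scheduler(policy_step / world_size)

# Slightly below replay_ratio because no gradient steps are taken before learning_starts
print(cumulative_per_rank_gradient_steps * world_size / policy_step)  # ~0.47
```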
1 change: 0 additions & 1 deletion sheeprl/algos/dreamer_v1/utils.py
@@ -31,7 +31,6 @@
"State/post_entropy",
"State/prior_entropy",
"State/kl",
"Params/exploration_amount",
"Grads/world_model",
"Grads/actor",
"Grads/critic",