Skip to content

Commit

Permalink
Fix missing KL
Browse files: browse the repository at this point in the history
  • Loading branch information
belerico committed Jul 16, 2023
1 parent ceb3642 commit 3f4be4d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sheeprl/algos/p2e/p2e_dv1/p2e_dv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def train(
q = Independent(Normal(priors_mean, priors_std), 1)

world_optimizer.zero_grad(set_to_none=True)
- rec_loss, state_loss, reward_loss, observation_loss, continue_loss = reconstruction_loss(
+ rec_loss, kl, state_loss, reward_loss, observation_loss, continue_loss = reconstruction_loss(
qo,
batch_obs,
qr,
Expand All @@ -165,6 +165,7 @@ def train(
aggregator.update("Loss/reward_loss", reward_loss.detach())
aggregator.update("Loss/state_loss", state_loss.detach())
aggregator.update("Loss/continue_loss", continue_loss.detach())
aggregator.update("State/kl", kl.mean().detach())
aggregator.update("State/p_entropy", p.entropy().mean().detach())
aggregator.update("State/q_entropy", q.entropy().mean().detach())

Expand Down Expand Up @@ -513,6 +514,7 @@ def main():
"Loss/state_loss": MeanMetric(sync_on_compute=False),
"Loss/continue_loss": MeanMetric(sync_on_compute=False),
"Loss/ensemble_loss": MeanMetric(sync_on_compute=False),
"State/kl": MeanMetric(sync_on_compute=False),
"State/p_entropy": MeanMetric(sync_on_compute=False),
"State/q_entropy": MeanMetric(sync_on_compute=False),
"Params/exploration_amout": MeanMetric(sync_on_compute=False),
Expand Down

0 comments on commit 3f4be4d

Please sign in to comment.