Skip to content

Commit

Permalink
[BugFix] Fix discrete SAC log-prob (#1750)
Browse files Browse the repository at this point in the history
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
  • Loading branch information
vmoens and matteobettini authored Dec 17, 2023
1 parent 08f0bed commit 0e02132
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
8 changes: 2 additions & 6 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,9 +1892,7 @@ def call(data, params):
with params.to_module(training_model):
return training_model(data)

assert vmap(call, (None, 0))(data, params).shape == torch.Size(
(2, 50, 11)
)
assert vmap(call, (None, 0))(data, params).shape == torch.Size((2, 50, 11))


class TestGRUModule:
Expand Down Expand Up @@ -2221,9 +2219,7 @@ def call(data, params):
with params.to_module(training_model):
return training_model(data)

assert vmap(call, (None, 0))(data, params).shape == torch.Size(
(2, 50, 11)
)
assert vmap(call, (None, 0))(data, params).shape == torch.Size((2, 50, 11))


def test_safe_specs():
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ def _actor_loss(
with self.actor_network_params.to_module(self.actor_network):
dist = self.actor_network.get_dist(tensordict.clone(False))
prob = dist.probs
log_prob = prob.clamp_min(torch.finfo(prob.dtype).resolution)
log_prob = dist.logits

td_q = tensordict.select(*self.qvalue_network.in_keys)

Expand Down

0 comments on commit 0e02132

Please sign in to comment.