
Commit ed656a1
[BugFix] Fix missing min/max alpha clamps in losses (#2684)
vmoens authored Jan 9, 2025
1 parent f672c70 commit ed656a1
Showing 4 changed files with 5 additions and 5 deletions.
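
All four modules share the same one-line fix: the clamp on log_alpha used to run only when min_log_alpha was set, so configuring only max_log_alpha silently skipped the bound. A minimal standalone sketch of the failure mode (a hypothetical repro, not code from this commit; Tensor.clamp_ accepts None for a missing bound):

import torch

log_alpha = torch.tensor(5.0)
min_log_alpha, max_log_alpha = None, 2.0  # only an upper bound is configured

# Old guard: the clamp is skipped entirely when min_log_alpha is None.
if min_log_alpha is not None:
    log_alpha.data.clamp_(min_log_alpha, max_log_alpha)
print(log_alpha)  # tensor(5.) -- the max bound was ignored

# Fixed guard: clamp whenever either bound is set.
if min_log_alpha is not None or max_log_alpha is not None:
    log_alpha.data.clamp_(min_log_alpha, max_log_alpha)
print(log_alpha)  # tensor(2.)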
torchrl/objectives/cql.py (2 changes: 1 addition & 1 deletion)

@@ -892,7 +892,7 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         alpha = self.log_alpha.data.exp()
         return alpha
torchrl/objectives/crossq.py (2 changes: 1 addition & 1 deletion)

@@ -677,7 +677,7 @@ def alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()
torchrl/objectives/decision_transformer.py (2 changes: 1 addition & 1 deletion)

@@ -171,7 +171,7 @@ def _forward_value_estimator_keys(self, **kwargs):
 
     @property
     def alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()
torchrl/objectives/sac.py (4 changes: 2 additions & 2 deletions)

@@ -846,7 +846,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()

@@ -1374,7 +1374,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data = self.log_alpha.data.clamp(
                 self.min_log_alpha, self.max_log_alpha
             )
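
One stylistic difference worth noting: the second sac.py hunk bounds log_alpha with an out-of-place clamp followed by reassigning .data, whereas the other sites use the in-place clamp_. Both leave log_alpha.data bounded; the in-place form just avoids allocating a fresh tensor. A small equivalence sketch (hypothetical, not from the commit):

import torch

a = torch.nn.Parameter(torch.tensor(5.0))
b = torch.nn.Parameter(torch.tensor(5.0))

a.data = a.data.clamp(None, 2.0)  # out-of-place, as in the second sac.py hunk
b.data.clamp_(None, 2.0)          # in-place, as in the other hunks

assert torch.equal(a.data, b.data)  # both end up at tensor(2.)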
