Skip to content

Commit

Permalink
return the new updated key in _train
Browse files Browse the repository at this point in the history
  • Loading branch information
theovincent committed Apr 12, 2024
1 parent fcd647e commit 84feff5
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sbx/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,6 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
update_carry["qf_state"],
update_carry["actor_state"],
update_carry["ent_coef_state"],
key,
update_carry["key"],
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),
)
2 changes: 1 addition & 1 deletion sbx/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,6 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
update_carry["qf_state"],
update_carry["actor_state"],
update_carry["ent_coef_state"],
key,
update_carry["key"],
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"], update_carry["info"]["ent_coef_loss"]),
)
2 changes: 1 addition & 1 deletion sbx/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,6 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
return (
update_carry["qf_state"],
update_carry["actor_state"],
key,
update_carry["key"],
(update_carry["info"]["actor_loss"], update_carry["info"]["qf_loss"]),
)
2 changes: 1 addition & 1 deletion sbx/tqc/tqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
update_carry["qf2_state"],
update_carry["actor_state"],
update_carry["ent_coef_state"],
key,
update_carry["key"],
(
update_carry["info"]["qf1_loss"],
update_carry["info"]["qf2_loss"],
Expand Down

0 comments on commit 84feff5

Please sign in to comment.