Support for setting the target entropy (#43)
* Support for setting the target entropy in TQC

* Support for setting the target entropy in CrossQ and SAC

* Made type hints for target_entropy more precise

* Update version and tests

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
jan1854 and araffin authored Apr 8, 2024
1 parent c8db73f commit fcd647e
Showing 5 changed files with 36 additions and 11 deletions.
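For context on the API this commit introduces: the new parameter is passed straight to the algorithm constructor. A minimal usage sketch, assuming an illustrative environment id ("Pendulum-v1") and hyperparameter values that are not part of this commit:

from sbx import SAC

# Override the automatic heuristic (-dim(action_space)) with a fixed target.
# With ent_coef="auto" (the default), the entropy coefficient is learned so
# that the policy entropy is driven toward target_entropy.
model = SAC("MlpPolicy", "Pendulum-v1", ent_coef="auto", target_entropy=-0.5)
model.learn(total_timesteps=10_000)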
15 changes: 11 additions & 4 deletions sbx/crossq/crossq.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
 
 import flax
 import flax.linen as nn
@@ -66,6 +66,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -103,6 +104,7 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
 
         if _init_setup_model:
             self._setup_model()
@@ -155,8 +157,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
         self,
@@ -251,7 +259,6 @@ def update_critic(
         def mse_loss(
             params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
         ) -> Tuple[jax.Array, jax.Array]:
-
             # Joint forward pass of obs/next_obs and actions/next_state_actions to have only
             # one forward pass with shape (n_critics, 2 * batch_size, 1).
             #
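When target_entropy is left at "auto", the code above falls back to the usual SAC heuristic of minus the action dimensionality. A quick worked example of that formula (the 6-dimensional action shape is purely illustrative):

import numpy as np

# The "auto" heuristic: one nat of entropy budget per action dimension.
action_shape = (6,)  # e.g. a 6-dimensional continuous action space
target_entropy = -np.prod(action_shape).astype(np.float32)
print(target_entropy)  # -6.0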
14 changes: 11 additions & 3 deletions sbx/sac/sac.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
 
 import flax
 import flax.linen as nn
@@ -67,6 +67,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -105,6 +106,7 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
 
         if _init_setup_model:
             self._setup_model()
@@ -157,8 +159,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
         self,
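The stored target only matters when the entropy coefficient is learned (ent_coef="auto"). As a reference for how it enters the optimization, here is a schematic of the standard SAC temperature objective; this is the textbook form, not code copied from sbx:

import jax.numpy as jnp

def temperature_loss(log_alpha: jnp.ndarray, log_prob: jnp.ndarray, target_entropy: float) -> jnp.ndarray:
    # J(alpha) = E[-alpha * (log pi(a|s) + target_entropy)]:
    # the gradient increases alpha when the policy entropy falls below the target.
    return -(jnp.exp(log_alpha) * (log_prob + target_entropy)).mean()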
15 changes: 12 additions & 3 deletions sbx/tqc/tqc.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
 
 import flax
 import flax.linen as nn
@@ -68,6 +68,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -106,6 +107,8 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
+
         self.policy_kwargs["top_quantiles_to_drop_per_net"] = top_quantiles_to_drop_per_net
 
         if _init_setup_model:
@@ -159,8 +162,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
        self,
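Note that the float(...) fallback in _setup_model doubles as input validation: any string other than "auto" raises a ValueError rather than being silently accepted. A quick illustration in plain Python (the misspelled value is hypothetical, not from the diff):

# Mirrors the else-branch above: float() rejects unexpected strings.
try:
    target_entropy = float("max_ent")  # e.g. a misspelled setting
except ValueError as err:
    print(err)  # could not convert string to float: 'max_ent'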
2 changes: 1 addition & 1 deletion sbx/version.txt
@@ -1 +1 @@
-0.13.0
+0.14.0
1 change: 1 addition & 0 deletions tests/test_run.py
@@ -69,6 +69,7 @@ def test_tqc(tmp_path) -> None:
         gradient_steps=1,
         use_sde=True,
         qf_learning_rate=1e-3,
+        target_entropy=-10,
     )
     model.learn(200)
     check_save_load(model, TQC, tmp_path)
