diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py
index c3d656f..01a2d77 100644
--- a/sbx/crossq/crossq.py
+++ b/sbx/crossq/crossq.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
 
 import flax
 import flax.linen as nn
@@ -66,6 +66,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -103,6 +104,7 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
 
         if _init_setup_model:
             self._setup_model()
@@ -155,8 +157,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
         self,
@@ -251,7 +259,6 @@ def update_critic(
         def mse_loss(
             params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict
         ) -> Tuple[jax.Array, jax.Array]:
-
             # Joint forward pass of obs/next_obs and actions/next_state_actions to have only
             # one forward pass with shape (n_critics, 2 * batch_size, 1).
             #
diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py
index aab5e59..2639f2b 100644
--- a/sbx/sac/sac.py
+++ b/sbx/sac/sac.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
 
 import flax
 import flax.linen as nn
@@ -67,6 +67,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -105,6 +106,7 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
 
         if _init_setup_model:
             self._setup_model()
@@ -157,8 +159,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
         self,
diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py
index 4c7b9e6..1c6f92a 100644
--- a/sbx/tqc/tqc.py
+++ b/sbx/tqc/tqc.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
 
 import flax
 import flax.linen as nn
@@ -68,6 +68,7 @@ def __init__(
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
+        target_entropy: Union[Literal["auto"], float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
@@ -106,6 +107,8 @@ def __init__(
 
         self.policy_delay = policy_delay
         self.ent_coef_init = ent_coef
+        self.target_entropy = target_entropy
+
         self.policy_kwargs["top_quantiles_to_drop_per_net"] = top_quantiles_to_drop_per_net
 
         if _init_setup_model:
@@ -159,8 +162,14 @@ def _setup_model(self) -> None:
             ),
         )
 
-        # automatically set target entropy if needed
-        self.target_entropy = -np.prod(self.action_space.shape).astype(np.float32)
+        # Target entropy is used when learning the entropy coefficient
+        if self.target_entropy == "auto":
+            # automatically set target entropy if needed
+            self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)  # type: ignore
+        else:
+            # Force conversion
+            # this will also throw an error for unexpected string
+            self.target_entropy = float(self.target_entropy)
 
     def learn(
         self,
diff --git a/sbx/version.txt b/sbx/version.txt
index 54d1a4f..a803cc2 100644
--- a/sbx/version.txt
+++ b/sbx/version.txt
@@ -1 +1 @@
-0.13.0
+0.14.0
diff --git a/tests/test_run.py b/tests/test_run.py
index f4253e5..3ef8e9a 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -69,6 +69,7 @@ def test_tqc(tmp_path) -> None:
         gradient_steps=1,
         use_sde=True,
         qf_learning_rate=1e-3,
+        target_entropy=-10,
     )
     model.learn(200)
     check_save_load(model, TQC, tmp_path)
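Usage note: a minimal sketch of the new target_entropy parameter added by this diff, shown with SAC (TQC and CrossQ accept the same argument). The environment and the -5.0 target below are illustrative assumptions, not recommended settings.

    from sbx import SAC

    # Default: target_entropy="auto" resolves to -np.prod(action_space.shape),
    # e.g. -1.0 for the one-dimensional action space of Pendulum-v1.
    model = SAC("MlpPolicy", "Pendulum-v1", ent_coef="auto")

    # With ent_coef="auto", passing a float instead tunes the entropy
    # coefficient so that the policy entropy approaches this target.
    model = SAC("MlpPolicy", "Pendulum-v1", ent_coef="auto", target_entropy=-5.0)
    model.learn(total_timesteps=1_000)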