diff --git a/pyproject.toml b/pyproject.toml
index ff03ec6..30fffa8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,13 +29,9 @@ disable = []
 ignore_missing_imports = true
 follow_imports = "silent"
 show_error_codes = true
-exclude = """(?x)(
-    sbx/common/on_policy_algorithm.py$
-    | sbx/common/off_policy_algorithm.py$
-    | sbx/ppo/ppo.py$
-    | sbx/dqn/dqn.py$
-    | sbx/common/policies.py$
-  )"""
+# exclude = """(?x)(
+#     sbx/common/policies.py$
+#   )"""

 [tool.pytest.ini_options]
 # Deterministic ordering for tests; useful for pytest-xdist.
@@ -46,12 +42,6 @@ env = [
 filterwarnings = [
     # Tensorboard warnings
     "ignore::DeprecationWarning:tensorboard",
-    # Gym warnings
-    "ignore:Parameters to load are deprecated.:DeprecationWarning",
-    "ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
-    "ignore::UserWarning:gym",
-    "ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning",
-    "ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning",
 ]
 markers = [
     "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py
index 0540c18..6069bc9 100644
--- a/sbx/common/off_policy_algorithm.py
+++ b/sbx/common/off_policy_algorithm.py
@@ -1,9 +1,10 @@
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

-import gym
 import jax
 import numpy as np
-from stable_baselines3.common.buffers import ReplayBuffer
+from gymnasium import spaces
+from stable_baselines3 import HerReplayBuffer
+from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
@@ -39,7 +40,7 @@ def __init__(
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
         sde_support: bool = True,
-        supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
+        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
     ):
         super().__init__(
             policy=policy,
@@ -52,6 +53,8 @@ def __init__(
             gamma=gamma,
             train_freq=train_freq,
             gradient_steps=gradient_steps,
+            replay_buffer_class=replay_buffer_class,
+            replay_buffer_kwargs=replay_buffer_kwargs,
             action_noise=action_noise,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
@@ -77,7 +80,7 @@ def _excluded_save_params(self) -> List[str]:
         excluded.remove("policy")
         return excluded

-    def set_random_seed(self, seed: int) -> None:
+    def set_random_seed(self, seed: Optional[int]) -> None:  # type: ignore[override]
         super().set_random_seed(seed)
         if seed is None:
             # Sample random seed
@@ -85,20 +88,31 @@ def set_random_seed(self, seed: int) -> None:
             self.key = jax.random.PRNGKey(seed)

     def _setup_model(self) -> None:
+        if self.replay_buffer_class is None:  # type: ignore[has-type]
+            if isinstance(self.observation_space, spaces.Dict):
+                self.replay_buffer_class = DictReplayBuffer
+            else:
+                self.replay_buffer_class = ReplayBuffer
+
         self._setup_lr_schedule()
         # By default qf_learning_rate = pi_learning_rate
         self.qf_learning_rate = self.qf_learning_rate or self.lr_schedule(1)
         self.set_random_seed(self.seed)
+        # Make a local copy as we should not pickle
+        # the environment when using HerReplayBuffer
+        replay_buffer_kwargs = self.replay_buffer_kwargs.copy()
+        if issubclass(self.replay_buffer_class, HerReplayBuffer):  # type: ignore[arg-type]
+            assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"
+            replay_buffer_kwargs["env"] = self.env

-        self.replay_buffer_class = ReplayBuffer
-        self.replay_buffer = self.replay_buffer_class(
+        self.replay_buffer = self.replay_buffer_class(  # type: ignore[misc]
             self.buffer_size,
             self.observation_space,
             self.action_space,
             device="cpu",  # force cpu device to easy torch -> numpy conversion
             n_envs=self.n_envs,
             optimize_memory_usage=self.optimize_memory_usage,
-            **self.replay_buffer_kwargs,
+            **replay_buffer_kwargs,
         )
         # Convert train freq parameter to TrainFreq object
         self._convert_train_freq()
diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py
index b556885..4b75411 100644
--- a/sbx/common/on_policy_algorithm.py
+++ b/sbx/common/on_policy_algorithm.py
@@ -1,9 +1,10 @@
 from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

-import gym
+import gymnasium as gym
 import jax
 import numpy as np
 import torch as th
+from gymnasium import spaces
 from stable_baselines3.common.buffers import RolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
@@ -11,10 +12,16 @@
 from stable_baselines3.common.type_aliases import GymEnv, Schedule
 from stable_baselines3.common.vec_env import VecEnv

+from sbx.ppo.policies import Actor, Critic, PPOPolicy
+
 OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithmJax")


 class OnPolicyAlgorithmJax(OnPolicyAlgorithm):
+    policy: PPOPolicy  # type: ignore[assignment]
+    actor: Actor
+    vf: Critic
+
     def __init__(
         self,
         policy: Union[str, Type[BasePolicy]],
@@ -35,10 +42,10 @@ def __init__(
         seed: Optional[int] = None,
         device: str = "auto",
         _init_setup_model: bool = True,
-        supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
+        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
     ):
         super().__init__(
-            policy=policy,
+            policy=policy,  # type: ignore[arg-type]
             env=env,
             learning_rate=learning_rate,
             n_steps=n_steps,
@@ -68,7 +75,7 @@ def _excluded_save_params(self) -> List[str]:
         excluded.remove("policy")
         return excluded

-    def set_random_seed(self, seed: int) -> None:
+    def set_random_seed(self, seed: Optional[int]) -> None:  # type: ignore[override]
         super().set_random_seed(seed)
         if seed is None:
             # Sample random seed
@@ -167,22 +174,32 @@ def collect_rollouts(
                     and infos[idx].get("TimeLimit.truncated", False)
                 ):
                     terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0]
-                    terminal_value = np.array(self.vf.apply(self.policy.vf_state.params, terminal_obs).flatten())
+                    terminal_value = np.array(
+                        self.vf.apply(  # type: ignore[union-attr]
+                            self.policy.vf_state.params,
+                            terminal_obs,
+                        ).flatten()
+                    )
                     rewards[idx] += self.gamma * terminal_value

             rollout_buffer.add(
-                self._last_obs,  # type: ignore[has-type]
+                self._last_obs,  # type: ignore
                 actions,
                 rewards,
-                self._last_episode_starts,  # type: ignore[has-type]
+                self._last_episode_starts,  # type: ignore
                 th.as_tensor(values),
                 th.as_tensor(log_probs),
             )
-            self._last_obs = new_obs
+            self._last_obs = new_obs  # type: ignore[assignment]
             self._last_episode_starts = dones

-        values = np.array(self.vf.apply(self.policy.vf_state.params, self.policy.prepare_obs(new_obs)[0]).flatten())
+        values = np.array(
+            self.vf.apply(  # type: ignore[union-attr]
+                self.policy.vf_state.params,
+                self.policy.prepare_obs(new_obs)[0],  # type: ignore[arg-type]
+            ).flatten()
+        )

         rollout_buffer.compute_returns_and_advantage(last_values=th.as_tensor(values), dones=dones)
diff --git a/sbx/common/policies.py b/sbx/common/policies.py
index 9c7f9ba..66ac1f9 100644
--- a/sbx/common/policies.py
+++ b/sbx/common/policies.py
@@ -1,9 +1,9 @@
 # import copy
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union, no_type_check

-import gym
 import jax
 import numpy as np
+from gymnasium import spaces
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import is_image_space, maybe_transpose
 from stable_baselines3.common.utils import is_vectorized_observation
@@ -28,6 +28,7 @@ def sample_action(actor_state, obervations, key):
     def select_action(actor_state, obervations):
         return actor_state.apply_fn(actor_state.params, obervations).mode()

+    @no_type_check
     def predict(
         self,
         observation: Union[np.ndarray, Dict[str, np.ndarray]],
@@ -44,7 +45,7 @@ def predict(

         # Convert to numpy, and reshape to the original action shape
         actions = np.array(actions).reshape((-1, *self.action_space.shape))
-        if isinstance(self.action_space, gym.spaces.Box):
+        if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
                 # Clip due to numerical instability
                 actions = np.clip(actions, -1, 1)
@@ -57,15 +58,24 @@ def predict(

         # Remove batch dimension if needed
         if not vectorized_env:
-            actions = actions.squeeze(axis=0)
+            actions = actions.squeeze(axis=0)  # type: ignore[call-overload]

         return actions, state

     def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[np.ndarray, bool]:
         vectorized_env = False
         if isinstance(observation, dict):
-            raise NotImplementedError()
-            # # need to copy the dict as the dict in VecFrameStack will become a torch tensor
+            assert isinstance(self.observation_space, spaces.Dict)
+            # Minimal dict support: flatten
+            keys = list(self.observation_space.keys())
+            vectorized_env = is_vectorized_observation(observation[keys[0]], self.observation_space[keys[0]])
+
+            # Add batch dim and concatenate
+            observation = np.concatenate(
+                [observation[key].reshape(-1, *self.observation_space[key].shape) for key in keys],
+                axis=1,
+            )
+            # need to copy the dict as the dict in VecFrameStack will become a torch tensor
             # observation = copy.deepcopy(observation)
             # for key, obs in observation.items():
             #     obs_space = self.observation_space.spaces[key]
@@ -75,7 +85,7 @@ def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) ->
             #     obs_ = np.array(obs)
             #     vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
             #     # Add batch dimension if needed
-            #     observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape)
+            #     observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))

         elif is_image_space(self.observation_space):
             # Handle the different cases for images
@@ -85,11 +95,11 @@ def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) ->
         else:
             observation = np.array(observation)

-        if not isinstance(observation, dict):
-            # Dict obs need to be handled separately
+        if not isinstance(self.observation_space, spaces.Dict):
+            assert isinstance(observation, np.ndarray)
             vectorized_env = is_vectorized_observation(observation, self.observation_space)
             # Add batch dimension if needed
-            observation = observation.reshape((-1, *self.observation_space.shape))
+            observation = observation.reshape((-1, *self.observation_space.shape))  # type: ignore[misc]

         assert isinstance(observation, np.ndarray)
         return observation, vectorized_env
diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py
index 2eea976..47d0e6d 100644
--- a/sbx/dqn/dqn.py
+++ b/sbx/dqn/dqn.py
@@ -1,14 +1,13 @@
 import warnings
 from typing import Any, Dict, Optional, Tuple, Type, Union

-import gym
+import gymnasium as gym
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
-from stable_baselines3.common.preprocessing import maybe_transpose
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation
+from stable_baselines3.common.utils import get_linear_fn

 from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
 from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
@@ -21,6 +20,7 @@ class DQN(OffPolicyAlgorithmJax):
     }
     # Linear schedule will be defined in `_setup_model()`
     exploration_schedule: Schedule
+    policy: DQNPolicy

     def __init__(
         self,
@@ -62,7 +62,7 @@ def __init__(
             verbose=verbose,
             seed=seed,
             sde_support=False,
-            supported_action_spaces=(gym.spaces.Discrete),
+            supported_action_spaces=(gym.spaces.Discrete,),
             support_multi_env=True,
         )

@@ -99,14 +99,16 @@ def _setup_model(self) -> None:

             self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

-        if self.policy is None:  # type: ignore[has-type]
-            self.policy = self.policy_class(  # pytype:disable=not-instantiable
+        if not hasattr(self, "policy") or self.policy is None:
+            # pytype:disable=not-instantiable
+            self.policy = self.policy_class(  # type: ignore[assignment]
                 self.observation_space,
                 self.action_space,
                 self.lr_schedule,
-                **self.policy_kwargs,  # pytype:disable=not-instantiable
+                **self.policy_kwargs,
             )
-            assert isinstance(self.policy, DQNPolicy)
+            # pytype:enable=not-instantiable
+
         self.key = self.policy.build(self.key, self.lr_schedule)

         self.qf = self.policy.qf
@@ -244,7 +246,7 @@ def predict(
             (used in recurrent policies)
         """
         if not deterministic and np.random.rand() < self.exploration_rate:
-            if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
+            if self.policy.is_vectorized_observation(observation):
                 if isinstance(observation, dict):
                     n_batch = observation[list(observation.keys())[0]].shape[0]
                 else:
diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py
index b486a65..1905d5a 100644
--- a/sbx/dqn/policies.py
+++ b/sbx/dqn/policies.py
@@ -1,11 +1,11 @@
 from typing import Any, Callable, Dict, List, Optional, Union

 import flax.linen as nn
-import gym
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
+from gymnasium import spaces
 from stable_baselines3.common.type_aliases import Schedule

 from sbx.common.policies import BaseJaxPolicy
@@ -27,10 +27,12 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:


 class DQNPolicy(BaseJaxPolicy):
+    action_space: spaces.Discrete  # type: ignore[assignment]
+
     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Discrete,
         lr_schedule: Schedule,
         net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         features_extractor_class=None,
@@ -54,12 +56,12 @@ def __init__(
         else:
             self.n_units = 256

-    def build(self, key, lr_schedule: Schedule) -> None:
+    def build(self, key: jax.random.KeyArray, lr_schedule: Schedule) -> jax.random.KeyArray:
         key, qf_key = jax.random.split(key, 2)

         obs = jnp.array([self.observation_space.sample()])

-        self.qf = QNetwork(n_actions=self.action_space.n, n_units=self.n_units)
+        self.qf = QNetwork(n_actions=int(self.action_space.n), n_units=self.n_units)

         self.qf_state = RLTrainState.create(
             apply_fn=self.qf.apply,
diff --git a/sbx/droq/droq.py b/sbx/droq/droq.py
index 49c9a3f..bc2529a 100644
--- a/sbx/droq/droq.py
+++ b/sbx/droq/droq.py
@@ -1,5 +1,6 @@
 from typing import Any, Dict, Optional, Tuple, Type, Union

+from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.type_aliases import GymEnv, Schedule

@@ -31,6 +32,8 @@ def __init__(
         dropout_rate: float = 0.01,
         layer_norm: bool = True,
         action_noise: Optional[ActionNoise] = None,
+        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
@@ -56,6 +59,8 @@ def __init__(
             gradient_steps=gradient_steps,
             policy_delay=policy_delay,
             action_noise=action_noise,
+            replay_buffer_class=replay_buffer_class,
+            replay_buffer_kwargs=replay_buffer_kwargs,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_sde_at_warmup=use_sde_at_warmup,
diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py
index 16c9648..850e54a 100644
--- a/sbx/ppo/policies.py
+++ b/sbx/ppo/policies.py
@@ -1,7 +1,7 @@
 from typing import Any, Callable, Dict, List, Optional, Union

 import flax.linen as nn
-import gym
+import gymnasium as gym
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -9,7 +9,7 @@
 import tensorflow_probability
 from flax.linen.initializers import constant
 from flax.training.train_state import TrainState
-from gym import spaces
+from gymnasium import spaces
 from stable_baselines3.common.type_aliases import Schedule

 from sbx.common.policies import BaseJaxPolicy
@@ -109,7 +109,7 @@ def __init__(

         self.key = self.noise_key = jax.random.PRNGKey(0)

-    def build(self, key, lr_schedule: Schedule, max_grad_norm: float) -> None:
+    def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, max_grad_norm: float) -> jax.random.KeyArray:
         key, actor_key, vf_key = jax.random.split(key, 3)
         # Keep a key for the actor
         key, self.key = jax.random.split(key, 2)
@@ -120,12 +120,12 @@ def build(self, key, lr_schedule: Schedule, max_grad_norm: float) -> None:

         if isinstance(self.action_space, spaces.Box):
             actor_kwargs = {
-                "action_dim": np.prod(self.action_space.shape),
+                "action_dim": int(np.prod(self.action_space.shape)),
                 "continuous": True,
             }
         elif isinstance(self.action_space, spaces.Discrete):
             actor_kwargs = {
-                "action_dim": self.action_space.n,
+                "action_dim": int(self.action_space.n),
                 "continuous": False,
             }
         else:
@@ -135,7 +135,7 @@ def build(self, key, lr_schedule: Schedule, max_grad_norm: float) -> None:
             n_units=self.n_units,
             log_std_init=self.log_std_init,
             activation_fn=self.activation_fn,
-            **actor_kwargs,
+            **actor_kwargs,  # type: ignore[arg-type]
         )
         # Hack to make gSDE work without modifying internal SB3 code
         self.actor.reset_noise = self.reset_noise
diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py
index ed8df59..c44a36e 100644
--- a/sbx/ppo/ppo.py
+++ b/sbx/ppo/ppo.py
@@ -6,7 +6,7 @@
 import jax.numpy as jnp
 import numpy as np
 from flax.training.train_state import TrainState
-from gym import spaces
+from gymnasium import spaces
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import explained_variance, get_schedule_fn

@@ -73,6 +73,7 @@ class PPO(OnPolicyAlgorithmJax):
         # "CnnPolicy": ActorCriticCnnPolicy,
         # "MultiInputPolicy": MultiInputActorCriticPolicy,
     }
+    policy: PPOPolicy  # type: ignore[assignment]

     def __init__(
         self,
@@ -166,14 +167,15 @@
     def _setup_model(self) -> None:
         super()._setup_model()

-        if self.policy is None:  # type: ignore[has-type]
-            self.policy = self.policy_class(  # pytype:disable=not-instantiable
+        if not hasattr(self, "policy") or self.policy is None:  # type: ignore[has-type]
+            # pytype:disable=not-instantiable
+            self.policy = self.policy_class(  # type: ignore[assignment]
                 self.observation_space,
                 self.action_space,
                 self.lr_schedule,
-                **self.policy_kwargs,  # pytype:disable=not-instantiable
+                **self.policy_kwargs,
             )
-            assert isinstance(self.policy, PPOPolicy)
+            # pytype:enable=not-instantiable

         self.key = self.policy.build(self.key, self.lr_schedule, self.max_grad_norm)

@@ -257,7 +259,7 @@ def train(self) -> None:
         # train for n_epochs epochs
         for _ in range(self.n_epochs):
             # JIT only one update
-            for rollout_data in self.rollout_buffer.get(self.batch_size):
+            for rollout_data in self.rollout_buffer.get(self.batch_size):  # type: ignore[attr-defined]
                 if isinstance(self.action_space, spaces.Discrete):
                     # Convert discrete action from float to int
                     actions = rollout_data.actions.flatten().numpy().astype(np.int32)
@@ -279,7 +281,10 @@ def train(self) -> None:
                 )

         self._n_updates += self.n_epochs
-        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
+        explained_var = explained_variance(
+            self.rollout_buffer.values.flatten(),  # type: ignore[attr-defined]
+            self.rollout_buffer.returns.flatten(),  # type: ignore[attr-defined]
+        )

         # Logs
         # self.logger.record("train/entropy_loss", np.mean(entropy_losses))
diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py
index 6ab4023..7370b4e 100644
--- a/sbx/sac/policies.py
+++ b/sbx/sac/policies.py
@@ -1,13 +1,13 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union

 import flax.linen as nn
-import gym
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
 import tensorflow_probability
 from flax.training.train_state import TrainState
+from gymnasium import spaces
 from stable_baselines3.common.type_aliases import Schedule

 from sbx.common.distributions import TanhTransformedDistribution
@@ -19,33 +19,28 @@


 class Critic(nn.Module):
+    net_arch: Sequence[int]
     use_layer_norm: bool = False
     dropout_rate: Optional[float] = None
-    n_units: int = 256

     @nn.compact
     def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
         x = jnp.concatenate([x, action], -1)
-        x = nn.Dense(self.n_units)(x)
-        if self.dropout_rate is not None and self.dropout_rate > 0:
-            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
-        if self.use_layer_norm:
-            x = nn.LayerNorm()(x)
-        x = nn.relu(x)
-        x = nn.Dense(self.n_units)(x)
-        if self.dropout_rate is not None and self.dropout_rate > 0:
-            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
-        if self.use_layer_norm:
-            x = nn.LayerNorm()(x)
-        x = nn.relu(x)
+        for n_units in self.net_arch:
+            x = nn.Dense(n_units)(x)
+            if self.dropout_rate is not None and self.dropout_rate > 0:
+                x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
+            if self.use_layer_norm:
+                x = nn.LayerNorm()(x)
+            x = nn.relu(x)
         x = nn.Dense(1)(x)
         return x


 class VectorCritic(nn.Module):
+    net_arch: Sequence[int]
     use_layer_norm: bool = False
     dropout_rate: Optional[float] = None
-    n_units: int = 256
     n_critics: int = 2

     @nn.compact
@@ -63,14 +58,14 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
         q_values = vmap_critic(
             use_layer_norm=self.use_layer_norm,
             dropout_rate=self.dropout_rate,
-            n_units=self.n_units,
+            net_arch=self.net_arch,
         )(obs, action)
         return q_values


 class Actor(nn.Module):
+    net_arch: Sequence[int]
     action_dim: int
-    n_units: int = 256
     log_std_min: float = -20
     log_std_max: float = 2

@@ -80,10 +75,9 @@ def get_std(self):

     @nn.compact
     def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
-        x = nn.Dense(self.n_units)(x)
-        x = nn.relu(x)
-        x = nn.Dense(self.n_units)(x)
-        x = nn.relu(x)
+        for n_units in self.net_arch:
+            x = nn.Dense(n_units)(x)
+            x = nn.relu(x)
         mean = nn.Dense(self.action_dim)(x)
         log_std = nn.Dense(self.action_dim)(x)
         log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
@@ -94,10 +88,12 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-def


 class SACPolicy(BaseJaxPolicy):
+    action_space: spaces.Box  # type: ignore[assignment]
+
     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Box,
         lr_schedule: Schedule,
         net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         dropout_rate: float = 0.0,
@@ -129,28 +125,34 @@ def __init__(
         self.dropout_rate = dropout_rate
         self.layer_norm = layer_norm
         if net_arch is not None:
-            assert isinstance(net_arch, list)
-            self.n_units = net_arch[0]
+            if isinstance(net_arch, list):
+                self.net_arch_pi = self.net_arch_qf = net_arch
+            else:
+                self.net_arch_pi = net_arch["pi"]
+                self.net_arch_qf = net_arch["qf"]
         else:
-            self.n_units = 256
+            self.net_arch_pi = self.net_arch_qf = [256, 256]
         self.n_critics = n_critics
         self.use_sde = use_sde

         self.key = self.noise_key = jax.random.PRNGKey(0)

-    def build(self, key, lr_schedule: Schedule, qf_learning_rate: float) -> None:
+    def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, qf_learning_rate: float) -> jax.random.KeyArray:
         key, actor_key, qf_key, dropout_key = jax.random.split(key, 4)
         # Keep a key for the actor
         key, self.key = jax.random.split(key, 2)
         # Initialize noise
         self.reset_noise()

-        obs = jnp.array([self.observation_space.sample()])
+        if isinstance(self.observation_space, spaces.Dict):
+            obs = jnp.array([spaces.flatten(self.observation_space, self.observation_space.sample())])
+        else:
+            obs = jnp.array([self.observation_space.sample()])
         action = jnp.array([self.action_space.sample()])

         self.actor = Actor(
-            action_dim=np.prod(self.action_space.shape),
-            n_units=self.n_units,
+            action_dim=int(np.prod(self.action_space.shape)),
+            net_arch=self.net_arch_pi,
         )
         # Hack to make gSDE work without modifying internal SB3 code
         self.actor.reset_noise = self.reset_noise
@@ -167,7 +169,7 @@ def build(self, key, lr_schedule: Schedule, qf_learning_rate: float) -> None:
         self.qf = VectorCritic(
             dropout_rate=self.dropout_rate,
             use_layer_norm=self.layer_norm,
-            n_units=self.n_units,
+            net_arch=self.net_arch_qf,
             n_critics=self.n_critics,
         )

diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py
index 98157d3..60d67da 100644
--- a/sbx/sac/sac.py
+++ b/sbx/sac/sac.py
@@ -3,12 +3,13 @@
 import flax
 import flax.linen as nn
-import gym
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
 from flax.training.train_state import TrainState
+from gymnasium import spaces
+from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule

@@ -40,8 +41,13 @@ def __call__(self) -> float:
 class SAC(OffPolicyAlgorithmJax):
     policy_aliases: Dict[str, Type[SACPolicy]] = {  # type: ignore[assignment]
         "MlpPolicy": SACPolicy,
+        # Minimal dict support using flatten()
+        "MultiInputPolicy": SACPolicy,
     }

+    policy: SACPolicy
+    action_space: spaces.Box  # type: ignore[assignment]
+
     def __init__(
         self,
         policy,
@@ -57,6 +63,8 @@ def __init__(
         gradient_steps: int = 1,
         policy_delay: int = 1,
         action_noise: Optional[ActionNoise] = None,
+        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
@@ -81,6 +89,8 @@ def __init__(
             train_freq=train_freq,
             gradient_steps=gradient_steps,
             action_noise=action_noise,
+            replay_buffer_class=replay_buffer_class,
+            replay_buffer_kwargs=replay_buffer_kwargs,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_sde_at_warmup=use_sde_at_warmup,
@@ -88,7 +98,7 @@ def __init__(
             tensorboard_log=tensorboard_log,
             verbose=verbose,
             seed=seed,
-            supported_action_spaces=(gym.spaces.Box),
+            supported_action_spaces=(spaces.Box,),
             support_multi_env=True,
         )

@@ -101,7 +111,7 @@ def __init__(
     def _setup_model(self) -> None:
         super()._setup_model()

-        if self.policy is None:  # type: ignore[has-type]
+        if not hasattr(self, "policy") or self.policy is None:
             # pytype: disable=not-instantiable
             self.policy = self.policy_class(  # type: ignore[assignment]
                 self.observation_space,
@@ -111,14 +121,14 @@ def _setup_model(self) -> None:
             )
             # pytype: enable=not-instantiable

-        assert isinstance(self.policy, SACPolicy)
+        assert isinstance(self.qf_learning_rate, float)

         self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)

         self.key, ent_key = jax.random.split(self.key, 2)

-        self.actor = self.policy.actor
-        self.qf = self.policy.qf
+        self.actor = self.policy.actor  # type: ignore[assignment]
+        self.qf = self.policy.qf  # type: ignore[assignment]

         # The entropy coefficient or entropy can be learned automatically
         # see Automating Entropy Adjustment for Maximum Entropy RL section
@@ -134,10 +144,11 @@ def _setup_model(self) -> None:
             # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
             self.ent_coef = EntropyCoef(ent_coef_init)
         else:
-            # Force conversion to float
-            # this will throw an error if a malformed string (different from 'auto')
-            # is passed
-            self.ent_coef = ConstantEntropyCoef(self.ent_coef_init)
+            # This will throw an error if a malformed string (different from 'auto') is passed
+            assert isinstance(
+                self.ent_coef_init, float
+            ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}"
+            self.ent_coef = ConstantEntropyCoef(self.ent_coef_init)  # type: ignore[assignment]

         self.ent_coef_state = TrainState.create(
             apply_fn=self.ent_coef.apply,
@@ -176,11 +187,20 @@ def train(self, batch_size, gradient_steps):
            # It will compile once per value of policy_delay_indices
            policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0}
            policy_delay_indices = flax.core.FrozenDict(policy_delay_indices)
+
+        if isinstance(data.observations, dict):
+            keys = list(self.observation_space.keys())
+            obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
+            next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
+        else:
+            obs = data.observations.numpy()
+            next_obs = data.next_observations.numpy()
+
         # Convert to numpy
         data = ReplayBufferSamplesNp(
-            data.observations.numpy(),
+            obs,
             data.actions.numpy(),
-            data.next_observations.numpy(),
+            next_obs,
             data.dones.numpy().flatten(),
             data.rewards.numpy().flatten(),
         )
diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py
index fa87aec..763ae17 100644
--- a/sbx/tqc/policies.py
+++ b/sbx/tqc/policies.py
@@ -1,13 +1,13 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union

 import flax.linen as nn
-import gym
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
 import tensorflow_probability
 from flax.training.train_state import TrainState
+from gymnasium import spaces
 from stable_baselines3.common.type_aliases import Schedule

 from sbx.common.distributions import TanhTransformedDistribution
@@ -19,33 +19,28 @@


 class Critic(nn.Module):
+    net_arch: Sequence[int]
     use_layer_norm: bool = False
     dropout_rate: Optional[float] = None
     n_quantiles: int = 25
-    n_units: int = 256

     @nn.compact
     def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False) -> jnp.ndarray:
         x = jnp.concatenate([x, a], -1)
-        x = nn.Dense(self.n_units)(x)
-        if self.dropout_rate is not None and self.dropout_rate > 0:
-            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
-        if self.use_layer_norm:
-            x = nn.LayerNorm()(x)
-        x = nn.relu(x)
-        x = nn.Dense(self.n_units)(x)
-        if self.dropout_rate is not None and self.dropout_rate > 0:
-            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
-        if self.use_layer_norm:
-            x = nn.LayerNorm()(x)
-        x = nn.relu(x)
+        for n_units in self.net_arch:
+            x = nn.Dense(n_units)(x)
+            if self.dropout_rate is not None and self.dropout_rate > 0:
+                x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
+            if self.use_layer_norm:
+                x = nn.LayerNorm()(x)
+            x = nn.relu(x)
         x = nn.Dense(self.n_quantiles)(x)
         return x


 class Actor(nn.Module):
+    net_arch: Sequence[int]
     action_dim: int
-    n_units: int = 256
     log_std_min: float = -20
     log_std_max: float = 2

@@ -55,10 +50,9 @@ def get_std(self):

     @nn.compact
     def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
-        x = nn.Dense(self.n_units)(x)
-        x = nn.relu(x)
-        x = nn.Dense(self.n_units)(x)
-        x = nn.relu(x)
+        for n_units in self.net_arch:
+            x = nn.Dense(n_units)(x)
+            x = nn.relu(x)
         mean = nn.Dense(self.action_dim)(x)
         log_std = nn.Dense(self.action_dim)(x)
         log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
@@ -69,10 +63,12 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-def


 class TQCPolicy(BaseJaxPolicy):
+    action_space: spaces.Box  # type: ignore[assignment]
+
     def __init__(
         self,
-        observation_space: gym.spaces.Space,
-        action_space: gym.spaces.Space,
+        observation_space: spaces.Space,
+        action_space: spaces.Box,
         lr_schedule: Schedule,
         net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
         dropout_rate: float = 0.0,
@@ -106,10 +102,13 @@ def __init__(
         self.dropout_rate = dropout_rate
         self.layer_norm = layer_norm
         if net_arch is not None:
-            assert isinstance(net_arch, list)
-            self.n_units = net_arch[0]
+            if isinstance(net_arch, list):
+                self.net_arch_pi = self.net_arch_qf = net_arch
+            else:
+                self.net_arch_pi = net_arch["pi"]
+                self.net_arch_qf = net_arch["qf"]
         else:
-            self.n_units = 256
+            self.net_arch_pi = self.net_arch_qf = [256, 256]
         self.n_quantiles = n_quantiles
         self.n_critics = n_critics
         self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
@@ -121,18 +120,22 @@ def __init__(

         self.key = self.noise_key = jax.random.PRNGKey(0)

-    def build(self, key, lr_schedule: Schedule, qf_learning_rate: float) -> None:
+    def build(self, key: jax.random.KeyArray, lr_schedule: Schedule, qf_learning_rate: float) -> jax.random.KeyArray:
         key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4)
         key, dropout_key1, dropout_key2, self.key = jax.random.split(key, 4)
         # Initialize noise
         self.reset_noise()

-        obs = jnp.array([self.observation_space.sample()])
+        if isinstance(self.observation_space, spaces.Dict):
+            obs = jnp.array([spaces.flatten(self.observation_space, self.observation_space.sample())])
+        else:
+            obs = jnp.array([self.observation_space.sample()])
+
         action = jnp.array([self.action_space.sample()])

         self.actor = Actor(
-            action_dim=np.prod(self.action_space.shape),
-            n_units=self.n_units,
+            action_dim=int(np.prod(self.action_space.shape)),
+            net_arch=self.net_arch_pi,
         )
         # Hack to make gSDE work without modifying internal SB3 code
         self.actor.reset_noise = self.reset_noise
@@ -149,7 +152,7 @@ def build(self, key, lr_schedule: Schedule, qf_learning_rate: float) -> None:
         self.qf = Critic(
             dropout_rate=self.dropout_rate,
             use_layer_norm=self.layer_norm,
-            n_units=self.n_units,
+            net_arch=self.net_arch_qf,
             n_quantiles=self.n_quantiles,
         )

diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py
index 427fa5d..f297811 100644
--- a/sbx/tqc/tqc.py
+++ b/sbx/tqc/tqc.py
@@ -3,12 +3,13 @@
 import flax
 import flax.linen as nn
-import gym
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
 from flax.training.train_state import TrainState
+from gymnasium import spaces
+from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.noise import ActionNoise
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule

@@ -40,8 +41,13 @@ def __call__(self) -> float:
 class TQC(OffPolicyAlgorithmJax):
     policy_aliases: Dict[str, Type[TQCPolicy]] = {  # type: ignore[assignment]
         "MlpPolicy": TQCPolicy,
+        # Minimal dict support using flatten()
+        "MultiInputPolicy": TQCPolicy,
     }

+    policy: TQCPolicy
+    action_space: spaces.Box  # type: ignore[assignment]
+
     def __init__(
         self,
         policy,
@@ -58,6 +64,8 @@ def __init__(
         policy_delay: int = 1,
         top_quantiles_to_drop_per_net: int = 2,
         action_noise: Optional[ActionNoise] = None,
+        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
@@ -82,6 +90,8 @@ def __init__(
             train_freq=train_freq,
             gradient_steps=gradient_steps,
             action_noise=action_noise,
+            replay_buffer_class=replay_buffer_class,
+            replay_buffer_kwargs=replay_buffer_kwargs,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_sde_at_warmup=use_sde_at_warmup,
@@ -89,7 +99,7 @@ def __init__(
             tensorboard_log=tensorboard_log,
             verbose=verbose,
             seed=seed,
-            supported_action_spaces=(gym.spaces.Box),
+            supported_action_spaces=(spaces.Box,),
             support_multi_env=True,
         )

@@ -103,7 +113,7 @@ def __init__(
     def _setup_model(self) -> None:
         super()._setup_model()

-        if self.policy is None:  # type: ignore[has-type]
+        if not hasattr(self, "policy") or self.policy is None:
             # pytype: disable=not-instantiable
             self.policy = self.policy_class(  # type: ignore[assignment]
                 self.observation_space,
@@ -112,13 +122,13 @@ def _setup_model(self) -> None:
                 **self.policy_kwargs,
             )
             # pytype: enable=not-instantiable
-        assert isinstance(self.policy, TQCPolicy)
+        assert isinstance(self.qf_learning_rate, float)

         self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)

         self.key, ent_key = jax.random.split(self.key, 2)

-        self.actor = self.policy.actor
+        self.actor = self.policy.actor  # type: ignore[assignment]
         self.qf = self.policy.qf

         # The entropy coefficient or entropy can be learned automatically
@@ -135,10 +145,11 @@ def _setup_model(self) -> None:
             # as discussed in https://github.com/rail-berkeley/softlearning/issues/37
             self.ent_coef = EntropyCoef(ent_coef_init)
         else:
-            # Force conversion to float
-            # this will throw an error if a malformed string (different from 'auto')
-            # is passed
-            self.ent_coef = ConstantEntropyCoef(self.ent_coef_init)
+            # This will throw an error if a malformed string (different from 'auto') is passed
+            assert isinstance(
+                self.ent_coef_init, float
+            ), f"Entropy coef must be float when not equal to 'auto', actual: {self.ent_coef_init}"
+            self.ent_coef = ConstantEntropyCoef(self.ent_coef_init)  # type: ignore[assignment]

         self.ent_coef_state = TrainState.create(
             apply_fn=self.ent_coef.apply,
@@ -178,15 +189,22 @@ def train(self, batch_size, gradient_steps):
            policy_delay_indices = {i: True for i in range(gradient_steps) if ((self._n_updates + i + 1) % self.policy_delay) == 0}
            policy_delay_indices = flax.core.FrozenDict(policy_delay_indices)

+        if isinstance(data.observations, dict):
+            keys = list(self.observation_space.keys())
+            obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
+            next_obs = np.concatenate([data.next_observations[key].numpy() for key in keys], axis=1)
+        else:
+            obs = data.observations.numpy()
+            next_obs = data.next_observations.numpy()
+
         # Convert to numpy
         data = ReplayBufferSamplesNp(
-            data.observations.numpy(),
+            obs,
             data.actions.numpy(),
-            data.next_observations.numpy(),
+            next_obs,
             data.dones.numpy().flatten(),
             data.rewards.numpy().flatten(),
         )
-
         (
             self.policy.qf1_state,
             self.policy.qf2_state,
diff --git a/sbx/version.txt b/sbx/version.txt
index a918a2a..faef31a 100644
--- a/sbx/version.txt
+++ b/sbx/version.txt
@@ -1 +1 @@
-0.6.0
+0.7.0
diff --git a/setup.py b/setup.py
index 2705597..35e4bc1 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@
     packages=[package for package in find_packages() if package.startswith("sbx")],
     package_data={"sbx": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=1.8.0a9",
+        "stable_baselines3>=2.0.0a4",
         "jax",
         "jaxlib",
         "flax",
diff --git a/tests/test_run.py b/tests/test_run.py
index b89c3b2..bcf2f12 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -1,6 +1,10 @@
+from typing import Optional, Type
+
 import numpy as np
 import pytest
+from stable_baselines3 import HerReplayBuffer
 from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.envs import BitFlippingEnv
 from stable_baselines3.common.evaluation import evaluate_policy

 from sbx import DQN, PPO, SAC, TQC, DroQ
@@ -57,7 +61,7 @@ def test_tqc() -> None:


 @pytest.mark.parametrize("model_class", [SAC])
-def test_sac(model_class):
+def test_sac(model_class: Type[SAC]) -> None:
     model = model_class(
         "MlpPolicy",
         "Pendulum-v1",
@@ -69,7 +73,7 @@ def test_sac(model_class):


 @pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
-def test_ppo(env_id):
+def test_ppo(env_id: str) -> None:
     model = PPO(
         "MlpPolicy",
         env_id,
@@ -89,3 +93,11 @@ def test_dqn() -> None:
         target_update_interval=10,
     )
     model.learn(128)
+
+
+@pytest.mark.parametrize("replay_buffer_class", [None, HerReplayBuffer])
+def test_dict(replay_buffer_class: Optional[Type[HerReplayBuffer]]) -> None:
+    env = BitFlippingEnv(n_bits=2, continuous=True)
+    model = SAC("MultiInputPolicy", env, replay_buffer_class=replay_buffer_class)
+
+    model.learn(int(200), progress_bar=True)
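Not part of the patch itself: below is a minimal usage sketch of the options this diff introduces (the "MultiInputPolicy" alias with flattened dict observations, the replay_buffer_class / replay_buffer_kwargs pass-through, and per-network net_arch dicts). It mirrors the new test_dict test; the HerReplayBuffer keyword arguments are standard Stable-Baselines3 options, and the specific hyperparameter values are illustrative assumptions only.

# Usage sketch (illustrative, not part of the patch), assuming sbx 0.7.0
# and stable_baselines3 >= 2.0.0a4 as pinned in setup.py.
from stable_baselines3 import HerReplayBuffer
from stable_baselines3.common.envs import BitFlippingEnv

from sbx import SAC

# Dict observation space, handled by the minimal flatten() support added above
env = BitFlippingEnv(n_bits=4, continuous=True)

model = SAC(
    "MultiInputPolicy",  # new policy alias registered in this diff
    env,
    replay_buffer_class=HerReplayBuffer,  # forwarded via the new kwargs
    replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
    policy_kwargs=dict(net_arch=dict(pi=[64, 64], qf=[256, 256])),  # per-network net_arch
)
model.learn(total_timesteps=500)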