Basic dict obs support #11
Merged · 14 commits · Apr 13, 2023
16 changes: 3 additions & 13 deletions pyproject.toml
@@ -29,13 +29,9 @@ disable = []
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
sbx/common/on_policy_algorithm.py$
| sbx/common/off_policy_algorithm.py$
| sbx/ppo/ppo.py$
| sbx/dqn/dqn.py$
| sbx/common/policies.py$
)"""
# exclude = """(?x)(
# sbx/common/policies.py$
# )"""

[tool.pytest.ini_options]
# Deterministic ordering for tests; useful for pytest-xdist.
@@ -46,12 +42,6 @@ env = [
filterwarnings = [
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
# Gym warnings
"ignore:Parameters to load are deprecated.:DeprecationWarning",
"ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning",
"ignore::UserWarning:gym",
"ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning",
"ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning",
]
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
28 changes: 21 additions & 7 deletions sbx/common/off_policy_algorithm.py
@@ -1,9 +1,10 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gym
import jax
import numpy as np
from stable_baselines3.common.buffers import ReplayBuffer
from gymnasium import spaces
from stable_baselines3 import HerReplayBuffer
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
@@ -39,7 +40,7 @@ def __init__(
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
sde_support: bool = True,
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy,
@@ -52,6 +53,8 @@ def __init__(
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
action_noise=action_noise,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
@@ -77,28 +80,39 @@ def _excluded_save_params(self) -> List[str]:
excluded.remove("policy")
return excluded

def set_random_seed(self, seed: int) -> None:
def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override]
super().set_random_seed(seed)
if seed is None:
# Sample random seed
seed = np.random.randint(2**14)
self.key = jax.random.PRNGKey(seed)

def _setup_model(self) -> None:
if self.replay_buffer_class is None: # type: ignore[has-type]
if isinstance(self.observation_space, spaces.Dict):
self.replay_buffer_class = DictReplayBuffer
else:
self.replay_buffer_class = ReplayBuffer

self._setup_lr_schedule()
# By default qf_learning_rate = pi_learning_rate
self.qf_learning_rate = self.qf_learning_rate or self.lr_schedule(1)
self.set_random_seed(self.seed)
# Make a local copy as we should not pickle
# the environment when using HerReplayBuffer
replay_buffer_kwargs = self.replay_buffer_kwargs.copy()
if issubclass(self.replay_buffer_class, HerReplayBuffer): # type: ignore[arg-type]
assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"
replay_buffer_kwargs["env"] = self.env

self.replay_buffer_class = ReplayBuffer
self.replay_buffer = self.replay_buffer_class(
self.replay_buffer = self.replay_buffer_class( # type: ignore[misc]
self.buffer_size,
self.observation_space,
self.action_space,
device="cpu", # force cpu device to easy torch -> numpy conversion
n_envs=self.n_envs,
optimize_memory_usage=self.optimize_memory_usage,
**self.replay_buffer_kwargs,
**replay_buffer_kwargs,
)
# Convert train freq parameter to TrainFreq object
self._convert_train_freq()
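
For reference, here is a minimal standalone sketch of the buffer dispatch that `_setup_model` now performs: a `Dict` observation space gets SB3's `DictReplayBuffer`, anything else the flat `ReplayBuffer`. The spaces and buffer size below are made up for illustration.

```python
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer

# Hypothetical goal-conditioned observation space (not taken from sbx tests)
obs_space = spaces.Dict(
    {
        "achieved_goal": spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
        "observation": spaces.Box(-1.0, 1.0, shape=(6,), dtype=np.float32),
    }
)
action_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)

# Same dispatch as in `_setup_model` above: Dict spaces get a DictReplayBuffer
buffer_class = DictReplayBuffer if isinstance(obs_space, spaces.Dict) else ReplayBuffer
buffer = buffer_class(
    1_000,
    obs_space,
    action_space,
    device="cpu",  # keep buffers on CPU so torch -> numpy conversion stays cheap
    n_envs=1,
)
```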
35 changes: 26 additions & 9 deletions sbx/common/on_policy_algorithm.py
@@ -1,20 +1,27 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import gym
import gymnasium as gym
import jax
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.common.vec_env import VecEnv

from sbx.ppo.policies import Actor, Critic, PPOPolicy

OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithmJax")


class OnPolicyAlgorithmJax(OnPolicyAlgorithm):
policy: PPOPolicy # type: ignore[assignment]
actor: Actor
vf: Critic

def __init__(
self,
policy: Union[str, Type[BasePolicy]],
@@ -35,10 +42,10 @@ def __init__(
seed: Optional[int] = None,
device: str = "auto",
_init_setup_model: bool = True,
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy,
policy=policy, # type: ignore[arg-type]
env=env,
learning_rate=learning_rate,
n_steps=n_steps,
@@ -68,7 +75,7 @@ def _excluded_save_params(self) -> List[str]:
excluded.remove("policy")
return excluded

def set_random_seed(self, seed: int) -> None:
def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override]
super().set_random_seed(seed)
if seed is None:
# Sample random seed
@@ -167,22 +174,32 @@ def collect_rollouts(
and infos[idx].get("TimeLimit.truncated", False)
):
terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0]
terminal_value = np.array(self.vf.apply(self.policy.vf_state.params, terminal_obs).flatten())
terminal_value = np.array(
self.vf.apply( # type: ignore[union-attr]
self.policy.vf_state.params,
terminal_obs,
).flatten()
)

rewards[idx] += self.gamma * terminal_value

rollout_buffer.add(
self._last_obs, # type: ignore[has-type]
self._last_obs, # type: ignore
actions,
rewards,
self._last_episode_starts, # type: ignore[has-type]
self._last_episode_starts, # type: ignore
th.as_tensor(values),
th.as_tensor(log_probs),
)
self._last_obs = new_obs
self._last_obs = new_obs # type: ignore[assignment]
self._last_episode_starts = dones

values = np.array(self.vf.apply(self.policy.vf_state.params, self.policy.prepare_obs(new_obs)[0]).flatten())
values = np.array(
self.vf.apply( # type: ignore[union-attr]
self.policy.vf_state.params,
self.policy.prepare_obs(new_obs)[0], # type: ignore[arg-type]
).flatten()
)

rollout_buffer.compute_returns_and_advantage(last_values=th.as_tensor(values), dones=dones)

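The `rewards[idx] += self.gamma * terminal_value` line above is the standard time-limit bootstrapping trick: when an episode ends only because of truncation, the critic's value of the terminal observation is added back, discounted by `gamma`. A tiny numpy sketch with made-up numbers:

```python
import numpy as np

gamma = 0.99
rewards = np.array([1.0, 0.5])           # rewards of two parallel envs at this step
truncated = np.array([True, False])      # env 0 hit the time limit, env 1 did not
terminal_values = np.array([10.0, 0.0])  # critic's V(terminal_obs) per env

# Bootstrap only the truncated env, exactly as in `collect_rollouts` above
rewards = rewards + gamma * terminal_values * truncated
print(rewards)  # [10.9  0.5]
```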
30 changes: 20 additions & 10 deletions sbx/common/policies.py
@@ -1,9 +1,9 @@
# import copy
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union, no_type_check

import gym
import jax
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import is_image_space, maybe_transpose
from stable_baselines3.common.utils import is_vectorized_observation
@@ -28,6 +28,7 @@ def sample_action(actor_state, obervations, key):
def select_action(actor_state, obervations):
return actor_state.apply_fn(actor_state.params, obervations).mode()

@no_type_check
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
@@ -44,7 +45,7 @@ def predict(
# Convert to numpy, and reshape to the original action shape
actions = np.array(actions).reshape((-1, *self.action_space.shape))

if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
if self.squash_output:
# Clip due to numerical instability
actions = np.clip(actions, -1, 1)
@@ -57,15 +58,24 @@ def predict(

# Remove batch dimension if needed
if not vectorized_env:
actions = actions.squeeze(axis=0)
actions = actions.squeeze(axis=0) # type: ignore[call-overload]

return actions, state

def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[np.ndarray, bool]:
vectorized_env = False
if isinstance(observation, dict):
raise NotImplementedError()
# # need to copy the dict as the dict in VecFrameStack will become a torch tensor
assert isinstance(self.observation_space, spaces.Dict)
# Minimal dict support: flatten
keys = list(self.observation_space.keys())
vectorized_env = is_vectorized_observation(observation[keys[0]], self.observation_space[keys[0]])

# Add batch dim and concatenate
observation = np.concatenate(
[observation[key].reshape(-1, *self.observation_space[key].shape) for key in keys],
axis=1,
)
# need to copy the dict as the dict in VecFrameStack will become a torch tensor
# observation = copy.deepcopy(observation)
# for key, obs in observation.items():
# obs_space = self.observation_space.spaces[key]
@@ -75,7 +85,7 @@ def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) ->
# obs_ = np.array(obs)
# vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
# # Add batch dimension if needed
# observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape)
# observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))

elif is_image_space(self.observation_space):
# Handle the different cases for images
@@ -85,11 +95,11 @@ def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) ->
else:
observation = np.array(observation)

if not isinstance(observation, dict):
# Dict obs need to be handled separately
if not isinstance(self.observation_space, spaces.Dict):
assert isinstance(observation, np.ndarray)
vectorized_env = is_vectorized_observation(observation, self.observation_space)
# Add batch dimension if needed
observation = observation.reshape((-1, *self.observation_space.shape))
observation = observation.reshape((-1, *self.observation_space.shape)) # type: ignore[misc]

assert isinstance(observation, np.ndarray)
return observation, vectorized_env
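To make the new `prepare_obs` branch concrete, here is a small sketch of the same flattening on a made-up `Dict` space; reshaping each entry to `(batch, *subspace.shape)` and concatenating along `axis=1` assumes vector (1D) sub-spaces:

```python
import numpy as np
from gymnasium import spaces

# Hypothetical Dict observation space with two vector entries
obs_space = spaces.Dict(
    {
        "position": spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32),
        "velocity": spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32),
    }
)
obs = obs_space.sample()  # a single, unbatched observation

# Add a batch dimension per key and concatenate, as `prepare_obs` does above
keys = list(obs_space.keys())
flat = np.concatenate(
    [obs[key].reshape(-1, *obs_space[key].shape) for key in keys],
    axis=1,
)
print(flat.shape)  # (1, 5)
```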
20 changes: 11 additions & 9 deletions sbx/dqn/dqn.py
@@ -1,14 +1,13 @@
import warnings
from typing import Any, Dict, Optional, Tuple, Type, Union

import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation
from stable_baselines3.common.utils import get_linear_fn

from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
@@ -21,6 +20,7 @@ class DQN(OffPolicyAlgorithmJax):
}
# Linear schedule will be defined in `_setup_model()`
exploration_schedule: Schedule
policy: DQNPolicy

def __init__(
self,
@@ -62,7 +62,7 @@ def __init__(
verbose=verbose,
seed=seed,
sde_support=False,
supported_action_spaces=(gym.spaces.Discrete),
supported_action_spaces=(gym.spaces.Discrete,),
support_multi_env=True,
)

@@ -99,14 +99,16 @@ def _setup_model(self) -> None:

self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

if self.policy is None: # type: ignore[has-type]
self.policy = self.policy_class( # pytype:disable=not-instantiable
if not hasattr(self, "policy") or self.policy is None:
# pytype:disable=not-instantiable
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs, # pytype:disable=not-instantiable
**self.policy_kwargs,
)
assert isinstance(self.policy, DQNPolicy)
# pytype:enable=not-instantiable

self.key = self.policy.build(self.key, self.lr_schedule)
self.qf = self.policy.qf

@@ -244,7 +246,7 @@ def predict(
(used in recurrent policies)
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
if self.policy.is_vectorized_observation(observation):
if isinstance(observation, dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
else:
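The exploration branch of `DQN.predict` shown above is plain vectorized epsilon-greedy; a toy sketch with a hypothetical batch size and exploration rate:

```python
import numpy as np
from gymnasium import spaces

rng = np.random.default_rng(0)
action_space = spaces.Discrete(4)
n_batch, exploration_rate = 8, 0.1

# Stand-in for the greedy actions the Q-network would return
greedy_actions = np.zeros(n_batch, dtype=np.int64)

if rng.random() < exploration_rate:
    # Sample one random action per environment in the batch
    actions = np.array([action_space.sample() for _ in range(n_batch)])
else:
    actions = greedy_actions
```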
12 changes: 7 additions & 5 deletions sbx/dqn/policies.py
@@ -1,11 +1,11 @@
from typing import Any, Callable, Dict, List, Optional, Union

import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.policies import BaseJaxPolicy
@@ -27,10 +27,12 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:


class DQNPolicy(BaseJaxPolicy):
action_space: spaces.Discrete # type: ignore[assignment]

def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Discrete,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
features_extractor_class=None,
@@ -54,12 +56,12 @@ def __init__(
else:
self.n_units = 256

def build(self, key, lr_schedule: Schedule) -> None:
def build(self, key: jax.random.KeyArray, lr_schedule: Schedule) -> jax.random.KeyArray:
key, qf_key = jax.random.split(key, 2)

obs = jnp.array([self.observation_space.sample()])

self.qf = QNetwork(n_actions=self.action_space.n, n_units=self.n_units)
self.qf = QNetwork(n_actions=int(self.action_space.n), n_units=self.n_units)

self.qf_state = RLTrainState.create(
apply_fn=self.qf.apply,
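Finally, a self-contained sketch of the PRNG-splitting and initialization pattern that `DQNPolicy.build` follows: split the key, then initialize the Q-network's parameters with a dummy batched observation. The two-layer `QNetwork` body and the observation size here are illustrative assumptions, not necessarily the exact sbx architecture.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class QNetwork(nn.Module):
    n_actions: int
    n_units: int = 256

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = nn.relu(nn.Dense(self.n_units)(x))
        x = nn.relu(nn.Dense(self.n_units)(x))
        return nn.Dense(self.n_actions)(x)


key = jax.random.PRNGKey(0)
key, qf_key = jax.random.split(key, 2)

obs = jnp.zeros((1, 4))  # dummy batched observation, e.g. a 4-dim CartPole state
params = QNetwork(n_actions=2).init(qf_key, obs)
q_values = QNetwork(n_actions=2).apply(params, obs)  # shape (1, 2)
```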