Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CNN support for DQN #49

Merged
merged 3 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sbx/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
from sbx.dqn.policies import DQNPolicy
from sbx.dqn.policies import CNNPolicy, DQNPolicy


class DQN(OffPolicyAlgorithmJax):
policy_aliases: ClassVar[Dict[str, Type[DQNPolicy]]] = { # type: ignore[assignment]
"MlpPolicy": DQNPolicy,
"CnnPolicy": CNNPolicy,
}
# Linear schedule will be defined in `_setup_model()`
exploration_schedule: Schedule
Expand All @@ -36,6 +37,7 @@ def __init__(
exploration_fraction: float = 0.1,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.05,
optimize_memory_usage: bool = False, # Note: unused but to match SB3 API
# max_grad_norm: float = 10,
train_freq: Union[int, Tuple[int, str]] = 4,
gradient_steps: int = 1,
Expand Down
54 changes: 53 additions & 1 deletion sbx/dqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,32 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
return x


# Add CNN policy from DQN paper
class NatureCNN(nn.Module):
n_actions: int
n_units: int = 512
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
# Convert from channel-first (PyTorch) to channel-last (Jax)
x = jnp.transpose(x, (0, 2, 3, 1))
# Convert to float and normalize the image
x = x.astype(jnp.float32) / 255.0
x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
x = self.activation_fn(x)
x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
x = self.activation_fn(x)
x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
x = self.activation_fn(x)
# Flatten
x = x.reshape((x.shape[0], -1))
x = nn.Dense(self.n_units)(x)
x = self.activation_fn(x)
x = nn.Dense(self.n_actions)(x)
return x


class DQNPolicy(BaseJaxPolicy):
action_space: spaces.Discrete # type: ignore[assignment]

Expand Down Expand Up @@ -65,7 +91,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:

obs = jnp.array([self.observation_space.sample()])

self.qf = QNetwork(
self.qf: nn.Module = QNetwork(
n_actions=int(self.action_space.n),
n_units=self.n_units,
activation_fn=self.activation_fn,
Expand Down Expand Up @@ -97,3 +123,29 @@ def select_action(qf_state, observations):

def _predict(self, observation: np.ndarray, deterministic: bool = True) -> np.ndarray: # type: ignore[override]
return DQNPolicy.select_action(self.qf_state, observation)


class CNNPolicy(DQNPolicy):
def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
key, qf_key = jax.random.split(key, 2)

obs = jnp.array([self.observation_space.sample()])

self.qf = NatureCNN(
n_actions=int(self.action_space.n),
n_units=self.n_units,
activation_fn=self.activation_fn,
)

self.qf_state = RLTrainState.create(
apply_fn=self.qf.apply,
params=self.qf.init({"params": qf_key}, obs),
target_params=self.qf.init({"params": qf_key}, obs),
tx=self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
),
)
self.qf.apply = jax.jit(self.qf.apply) # type: ignore[method-assign]

return key
2 changes: 1 addition & 1 deletion sbx/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.16.0
0.17.0
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
## Example

```python
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

model = TQC("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)
Expand All @@ -40,7 +40,7 @@
packages=[package for package in find_packages() if package.startswith("sbx")],
package_data={"sbx": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.3.0",
"stable_baselines3>=2.4.0a4,<3.0",
"jax",
"jaxlib",
"flax",
Expand Down
44 changes: 44 additions & 0 deletions tests/test_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import pytest
from stable_baselines3.common.envs import FakeImageEnv

from sbx import DQN


@pytest.mark.parametrize("model_class", [DQN])
def test_cnn(tmp_path, model_class):
SAVE_NAME = "cnn_model.zip"
# Fake grayscale with frameskip
# Atari after preprocessing: 84x84x1, here we are using lower resolution
# to check that the network handle it automatically
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1)
model = model_class(
"CnnPolicy",
env,
buffer_size=250,
policy_kwargs=dict(net_arch=[64]),
learning_starts=100,
verbose=1,
)
model.learn(total_timesteps=250)

obs, _ = env.reset()

# Test stochastic predict with channel last input
if model_class == DQN:
model.exploration_rate = 0.9

for _ in range(10):
model.predict(obs, deterministic=False)

action, _ = model.predict(obs, deterministic=True)

model.save(tmp_path / SAVE_NAME)
del model

model = model_class.load(tmp_path / SAVE_NAME)

# Check that the prediction is the same
assert np.allclose(action, model.predict(obs, deterministic=True)[0])

(tmp_path / SAVE_NAME).unlink()
Loading