Add QR-DQN #13

Merged
merged 34 commits from feat/qrdqn into master on Dec 21, 2020
Changes from 9 commits
Commits (34)
ad4f445 - Add QR-DQN(WIP) (toshikwa, Dec 8, 2020)
0a84573 - Update docstring (toshikwa, Dec 8, 2020)
5a72eba - Add quantile_huber_loss (toshikwa, Dec 8, 2020)
671d328 - Fix typo (toshikwa, Dec 8, 2020)
04d0612 - Merge branch 'master' into feat/qrdqn (araffin, Dec 8, 2020)
51a52cf - Remove unnecessary lines (toshikwa, Dec 8, 2020)
d94d583 - Update variable names and comments in quantile_huber_loss (toshikwa, Dec 8, 2020)
8b36a21 - Fix mutable arguments (toshikwa, Dec 8, 2020)
50e7e8d - Update variable names (toshikwa, Dec 8, 2020)
d456bc0 - Merge branch 'master' into feat/qrdqn (araffin, Dec 8, 2020)
f55b8ad - Ignore import not used warnings (toshikwa, Dec 9, 2020)
f4ece75 - Fix default parameter of optimizer in QR-DQN (toshikwa, Dec 9, 2020)
d67d5e8 - Update quantile_huber_loss to have more reasonable interface (toshikwa, Dec 9, 2020)
92c8d10 - Merge branch 'feat/qrdqn' of https://github.com/ku2482/stable-baselin… (toshikwa, Dec 9, 2020)
39d5bc7 - update tests (toshikwa, Dec 9, 2020)
b335b37 - Add assertion to quantile_huber_loss (toshikwa, Dec 9, 2020)
62c336a - Update variable names of quantile regression (toshikwa, Dec 9, 2020)
2f350e5 - Update comments (toshikwa, Dec 10, 2020)
bedbc80 - Reduce the number of quantiles during test (toshikwa, Dec 10, 2020)
11ae6b0 - Update comment (toshikwa, Dec 10, 2020)
faeda56 - Merge branch 'master' into feat/qrdqn (araffin, Dec 13, 2020)
d2b1ab7 - Update quantile_huber_loss (toshikwa, Dec 13, 2020)
cd419da - Fix isort (toshikwa, Dec 13, 2020)
5449171 - Add document of QR-DQN without results (toshikwa, Dec 13, 2020)
e0de065 - Update docs (toshikwa, Dec 13, 2020)
147d3e8 - Fix bugs (toshikwa, Dec 13, 2020)
4f31b17 - Update doc (araffin, Dec 19, 2020)
b54b5d6 - Add comments about shape (araffin, Dec 19, 2020)
29d1912 - Minor edits (araffin, Dec 19, 2020)
eac6080 - Update comments (toshikwa, Dec 19, 2020)
b27ed43 - Add benchmark (araffin, Dec 20, 2020)
fe9f015 - Doc fixes (araffin, Dec 20, 2020)
f213f5e - Update doc (araffin, Dec 21, 2020)
3a53ee1 - Bug fix in saving/loading + update tests (araffin, Dec 21, 2020)
2 changes: 1 addition & 1 deletion sb3_contrib/__init__.py
@@ -1,6 +1,6 @@
import os

# from sb3_contrib.cmaes import CMAES
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC

# Read version from file
31 changes: 31 additions & 0 deletions sb3_contrib/common/utils.py
@@ -0,0 +1,31 @@
import torch as th


def quantile_huber_loss(current_quantile: th.Tensor, target_quantile: th.Tensor) -> th.Tensor:
"""
The quantile-regression loss, as described in the QR-DQN and TQC papers.
    Partially taken from https://github.com/bayesgroup/tqc_pytorch

    :param current_quantile: current estimate of quantile values,
        of shape (batch_size, n_quantiles) for QR-DQN or (batch_size, n_critics, n_quantiles) for TQC
    :param target_quantile: target quantile values, of shape (batch_size, n_target_quantiles)
:return: the loss
"""
n_quantiles = current_quantile.shape[-1]
# Cumulative probabilities to calculate quantile values.
cum_prob = (th.arange(n_quantiles, device=current_quantile.device).float() + 0.5) / n_quantiles

if current_quantile.ndim == 3:
        # For TQC, current_quantile has shape (batch_size, n_critics, n_quantiles).
cum_prob = cum_prob.view(1, 1, -1, 1)
pairwise_delta = target_quantile[:, None, None, :] - current_quantile[:, :, :, None]
elif current_quantile.ndim == 2:
        # For QR-DQN, current_quantile has shape (batch_size, n_quantiles).
cum_prob = cum_prob.view(1, -1, 1)
pairwise_delta = target_quantile[:, None, :] - current_quantile[:, :, None]
    else:
        raise NotImplementedError(f"Unsupported input dimension {current_quantile.ndim}, expected 2 or 3")

abs_pairwise_delta = th.abs(pairwise_delta)
huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5)
loss = (th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss).mean()
return loss
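
For reference, a minimal usage sketch of quantile_huber_loss as defined above. It is illustrative only: the batch size and quantile counts are arbitrary, and it assumes the function is importable as sb3_contrib.common.utils.quantile_huber_loss.

import torch as th

from sb3_contrib.common.utils import quantile_huber_loss

batch_size, n_quantiles, n_critics = 32, 200, 2

# QR-DQN case: both tensors have shape (batch_size, n_quantiles)
current = th.randn(batch_size, n_quantiles)
target = th.randn(batch_size, n_quantiles)
loss = quantile_huber_loss(current, target)  # scalar tensor

# TQC case: current estimate has shape (batch_size, n_critics, n_quantiles),
# the target has shape (batch_size, n_target_quantiles)
current_tqc = th.randn(batch_size, n_critics, n_quantiles)
target_tqc = th.randn(batch_size, n_quantiles)
loss_tqc = quantile_huber_loss(current_tqc, target_tqc)
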
2 changes: 2 additions & 0 deletions sb3_contrib/qrdqn/__init__.py
@@ -0,0 +1,2 @@
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy
from sb3_contrib.qrdqn.qrdqn import QRDQN
248 changes: 248 additions & 0 deletions sb3_contrib/qrdqn/policies.py
@@ -0,0 +1,248 @@
from typing import Any, Dict, List, Optional, Type

import gym
import torch as th
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from stable_baselines3.common.type_aliases import Schedule
from torch import nn


class QuantileNetwork(BasePolicy):
"""
Action-Quantile (Q-Value) network for QR-DQN
:param observation_space: Observation space
:param action_space: Action space
:param n_quantiles: Number of quantiles
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
"""

def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
features_extractor: nn.Module,
features_dim: int,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
super(QuantileNetwork, self).__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
)

if net_arch is None:
net_arch = [64, 64]

self.net_arch = net_arch
self.activation_fn = activation_fn
self.features_extractor = features_extractor
self.features_dim = features_dim
self.n_quantiles = n_quantiles
self.normalize_images = normalize_images
action_dim = self.action_space.n # number of actions
quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn)
self.quantile_net = nn.Sequential(*quantile_net)

def forward(self, obs: th.Tensor) -> th.Tensor:
"""
Predict the quantile-values.
:param obs: Observation
:return: The estimated Quantile-Value for each action.
"""
quantile_values = self.quantile_net(self.extract_features(obs))
return quantile_values.view(-1, self.n_quantiles, self.action_space.n)

def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
q_values = self.forward(observation).mean(dim=1)
# Greedy action
action = q_values.argmax(dim=1).reshape(-1)
return action

def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()

data.update(
dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
n_quantiles=self.n_quantiles,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor,
)
)
return data


class QRDQNPolicy(BasePolicy):
"""
    Policy class with quantile and target networks for QR-DQN
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param n_quantiles: Number of quantiles
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):

super(QRDQNPolicy, self).__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
)

if net_arch is None:
if features_extractor_class == FlattenExtractor:
net_arch = [64, 64]
else:
net_arch = []

self.n_quantiles = n_quantiles
self.net_arch = net_arch
self.activation_fn = activation_fn
self.normalize_images = normalize_images

self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"n_quantiles": self.n_quantiles,
"net_arch": self.net_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}

if optimizer_class is th.optim.Adam and "eps" not in self.optimizer_kwargs:
            self.optimizer_kwargs.update(eps=0.01 / 32)  # eps = 0.01 / batch_size, as in the QR-DQN paper (batch size 32)

self.quantile_net, self.quantile_net_target = None, None
self._build(lr_schedule)

def _build(self, lr_schedule: Schedule) -> None:
"""
Create the network and the optimizer.
:param lr_schedule: Learning rate schedule
lr_schedule(1) is the initial learning rate
"""
self.quantile_net = self.make_quantile_net()
self.quantile_net_target = self.make_quantile_net()
self.quantile_net_target.load_state_dict(self.quantile_net.state_dict())

# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

def make_quantile_net(self) -> QuantileNetwork:
# Make sure we always have separate networks for features extractors etc
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
return QuantileNetwork(**net_args).to(self.device)

def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)

def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self.quantile_net._predict(obs, deterministic=deterministic)

def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()

data.update(
dict(
n_quantiles=self.net_args["n_quantiles"],
net_arch=self.net_args["net_arch"],
activation_fn=self.net_args["activation_fn"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
)
)
return data


MlpPolicy = QRDQNPolicy


class CnnPolicy(QRDQNPolicy):
"""
Policy class for QR-DQN when using images as input.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param n_quantiles: Number of quantiles
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
n_quantiles: int = 200,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(CnnPolicy, self).__init__(
observation_space,
action_space,
lr_schedule,
n_quantiles,
net_arch,
activation_fn,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)


register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)