diff --git a/README.md b/README.md index 81349fd0..703bcd5f 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ See documentation for the full list of included features. **RL Algorithms**: - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) +- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) **Gym Wrappers**: - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 1d5ce800..f1b0ed55 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -9,6 +9,7 @@ along with some useful characteristics: support for discrete/continuous actions, Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing ============ =========== ============ ================= =============== ================ TQC ✔️ ❌ ❌ ❌ ❌ +QR-DQN ❌ ✔️ ❌ ❌ ❌ ============ =========== ============ ================= =============== ================ diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index c535bfd4..12b7d718 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -16,6 +16,21 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment. model.learn(total_timesteps=10000, log_interval=4) model.save("tqc_pendulum") +QR-DQN +------ + +Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment. + +.. code-block:: python + + from sb3_contrib import QRDQN + + policy_kwargs = dict(n_quantiles=50) + model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("qrdqn_cartpole") + + .. PyBullet: Normalizing input features .. ------------------------------------ .. diff --git a/docs/index.rst b/docs/index.rst index c198d4d5..f86c47ee 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d :caption: RL Algorithms modules/tqc + modules/qrdqn .. toctree:: :maxdepth: 1 diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1584caa1..6989a3fa 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -13,9 +13,11 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added ``TimeFeatureWrapper`` to the wrappers +- Added ``QR-DQN`` algorithm (`@ku2482`_) Bug Fixes: ^^^^^^^^^^ +- Fixed a bug in ``TQC`` when saving/loading only the policy with a non-default number of quantiles Deprecations: ^^^^^^^^^^^^^ @@ -24,6 +26,7 @@ Others: ^^^^^^^ - Updated ``TQC`` to match new SB3 version - Updated SB3 min version +- Moved ``quantile_huber_loss`` to ``common/utils.py`` (`@ku2482`_) Documentation: ^^^^^^^^^^^^^^ @@ -62,13 +65,19 @@ Maintainers ----------- Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a), -`Maximilian Ernestus`_ (aka @erniejunior), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_). +`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_). .. _Ashley Hill: https://github.com/hill-a .. _Antonin Raffin: https://araffin.github.io/ -.. _Maximilian Ernestus: https://github.com/erniejunior +.. _Maximilian Ernestus: https://github.com/ernestum .. _Adam Gleave: https://gleave.me/ .. _@araffin: https://github.com/araffin .. _@AdamGleave: https://github.com/adamgleave .. _Anssi Kanervisto: https://github.com/Miffyli .. _@Miffyli: https://github.com/Miffyli +.. 
_@ku2482: https://github.com/ku2482 + +Contributors: +------------- + +@ku2482 diff --git a/docs/modules/qrdqn.rst b/docs/modules/qrdqn.rst new file mode 100644 index 00000000..f0944295 --- /dev/null +++ b/docs/modules/qrdqn.rst @@ -0,0 +1,150 @@ +.. _qrdqn: + +.. automodule:: sb3_contrib.qrdqn + + +QR-DQN +====== + +`Quantile Regression DQN (QR-DQN) `_ builds on `Deep Q-Network (DQN) `_ +and makes use of quantile regression to explicitly model the `distribution over returns `_, +instead of predicting the mean return (DQN). + + +.. rubric:: Available Policies + +.. autosummary:: + :nosignatures: + + MlpPolicy + CnnPolicy + + +Notes +----- + +- Original paper: https://arxiv.org/abs/1710.10044 +- Distributional RL (C51): https://arxiv.org/abs/1707.06887 + + +Can I use? +---------- + +- Recurrent policies: ❌ +- Multi processing: ❌ +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ✔ ✔ +Box ❌ ✔ +MultiDiscrete ❌ ✔ +MultiBinary ❌ ✔ +============= ====== =========== + + +Example +------- + +.. code-block:: python + + import gym + + from sb3_contrib import QRDQN + + env = gym.make("CartPole-v1") + + policy_kwargs = dict(n_quantiles=50) + model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("qrdqn_cartpole") + + del model # remove to demonstrate saving and loading + + model = QRDQN.load("qrdqn_cartpole") + + obs = env.reset() + while True: + action, _states = model.predict(obs, deterministic=True) + obs, reward, done, info = env.step(action) + env.render() + if done: + obs = env.reset() + + +Results +------- + +Results on Atari environments (10M steps, Pong and Breakout) and classic control tasks, using 3 and 5 seeds. + +The complete learning curves are available in the `associated PR `_. + + +.. note:: + + The QR-DQN implementation was validated against the `Intel Coach `_ one, + which roughly compares to the original paper results (we trained the agent with a smaller budget). + + +============ ========== =========== +Environments QR-DQN DQN +============ ========== =========== +Breakout 413 +/- 21 ~300 +Pong 20 +/- 0 ~20 +CartPole 386 +/- 64 500 +/- 0 +MountainCar -111 +/- 4 -107 +/- 4 +LunarLander 168 +/- 39 195 +/- 28 +Acrobot -73 +/- 2 -74 +/- 2 +============ ========== =========== + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the RL-Zoo fork and checkout the branch ``feat/qrdqn``: + +.. code-block:: bash + + git clone https://github.com/ku2482/rl-baselines3-zoo/ + cd rl-baselines3-zoo/ + git checkout feat/qrdqn + +Run the benchmark (replace ``$ENV_ID`` with the environments mentioned above): + +.. code-block:: bash + + python train.py --algo qrdqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a qrdqn -e Breakout Pong -f logs/ -o logs/qrdqn_results + python scripts/plot_from_file.py -i logs/qrdqn_results.pkl -latex -l QR-DQN + + + +Parameters +---------- + +.. autoclass:: QRDQN + :members: + :inherited-members: + +.. _qrdqn_policies: + +QR-DQN Policies +--------------- + +.. autoclass:: MlpPolicy + :members: + :inherited-members: + +.. autoclass:: sb3_contrib.qrdqn.policies.QRDQNPolicy + :members: + :noindex: + +.. 
autoclass:: CnnPolicy + :members: diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index c66cfec5..8f253e12 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,6 +1,6 @@ import os -# from sb3_contrib.cmaes import CMAES +from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC # Read version from file diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py new file mode 100644 index 00000000..4a9e522d --- /dev/null +++ b/sb3_contrib/common/utils.py @@ -0,0 +1,69 @@ +from typing import Optional + +import torch as th + + +def quantile_huber_loss( + current_quantiles: th.Tensor, + target_quantiles: th.Tensor, + cum_prob: Optional[th.Tensor] = None, + sum_over_quantiles: bool = True, +) -> th.Tensor: + """ + The quantile-regression loss, as described in the QR-DQN and TQC papers. + Partially taken from https://github.com/bayesgroup/tqc_pytorch. + + :param current_quantiles: current estimate of quantiles, must be either + (batch_size, n_quantiles) or (batch_size, n_critics, n_quantiles) + :param target_quantiles: target quantiles, must be either (batch_size, n_target_quantiles), + (batch_size, 1, n_target_quantiles), or (batch_size, n_critics, n_target_quantiles) + :param cum_prob: cumulative probabilities to calculate quantiles (also called midpoints in QR-DQN paper), + must be either (batch_size, n_quantiles), (batch_size, 1, n_quantiles), or (batch_size, n_critics, n_quantiles). + (if None, uniform quantile midpoints are computed) + :param sum_over_quantiles: whether to sum over the quantile dimension or not + :return: the loss + """ + if current_quantiles.ndim != target_quantiles.ndim: + raise ValueError( + f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to match " + f"the dimension of target_quantiles ({target_quantiles.ndim})." + ) + if current_quantiles.shape[0] != target_quantiles.shape[0]: + raise ValueError( + f"Error: The batch size of current_quantiles ({current_quantiles.shape[0]}) needs to match " + f"the batch size of target_quantiles ({target_quantiles.shape[0]})." + ) + if current_quantiles.ndim not in (2, 3): + raise ValueError(f"Error: The dimension of current_quantiles ({current_quantiles.ndim}) needs to be either 2 or 3.") + + if cum_prob is None: + n_quantiles = current_quantiles.shape[-1] + # Cumulative probabilities to calculate quantiles. 
+ cum_prob = (th.arange(n_quantiles, device=current_quantiles.device, dtype=th.float) + 0.5) / n_quantiles + if current_quantiles.ndim == 2: + # For QR-DQN, current_quantiles has a shape (batch_size, n_quantiles), so we make cum_prob + # broadcastable to (batch_size, n_quantiles, n_target_quantiles) + cum_prob = cum_prob.view(1, -1, 1) + elif current_quantiles.ndim == 3: + # For TQC, current_quantiles has a shape (batch_size, n_critics, n_quantiles), so we make cum_prob + # broadcastable to (batch_size, n_critics, n_quantiles, n_target_quantiles) + cum_prob = cum_prob.view(1, 1, -1, 1) + + # QR-DQN + # target_quantiles: (batch_size, n_target_quantiles) -> (batch_size, 1, n_target_quantiles) + # current_quantiles: (batch_size, n_quantiles) -> (batch_size, n_quantiles, 1) + # pairwise_delta: (batch_size, n_quantiles, n_target_quantiles) + # TQC + # target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles) + # current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1) + # pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles) + # Note: in both cases, the loss has the same shape as pairwise_delta + pairwise_delta = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1) + abs_pairwise_delta = th.abs(pairwise_delta) + huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5) + loss = th.abs(cum_prob - (pairwise_delta.detach() < 0).float()) * huber_loss + if sum_over_quantiles: + loss = loss.sum(dim=-2).mean() + else: + loss = loss.mean() + return loss diff --git a/sb3_contrib/qrdqn/__init__.py b/sb3_contrib/qrdqn/__init__.py new file mode 100644 index 00000000..1d16fae3 --- /dev/null +++ b/sb3_contrib/qrdqn/__init__.py @@ -0,0 +1,2 @@ +from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy +from sb3_contrib.qrdqn.qrdqn import QRDQN diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py new file mode 100644 index 00000000..db95299b --- /dev/null +++ b/sb3_contrib/qrdqn/policies.py @@ -0,0 +1,249 @@ +from typing import Any, Dict, List, Optional, Type + +import gym +import torch as th +from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp +from stable_baselines3.common.type_aliases import Schedule +from torch import nn + + +class QuantileNetwork(BasePolicy): + """ + Quantile network for QR-DQN + + :param observation_space: Observation space + :param action_space: Action space + :param n_quantiles: Number of quantiles + :param net_arch: The specification of the network architecture. 
+ :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + features_extractor: nn.Module, + features_dim: int, + n_quantiles: int = 200, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + ): + super(QuantileNetwork, self).__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + ) + + if net_arch is None: + net_arch = [64, 64] + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.features_extractor = features_extractor + self.features_dim = features_dim + self.n_quantiles = n_quantiles + self.normalize_images = normalize_images + action_dim = self.action_space.n # number of actions + quantile_net = create_mlp(self.features_dim, action_dim * self.n_quantiles, self.net_arch, self.activation_fn) + self.quantile_net = nn.Sequential(*quantile_net) + + def forward(self, obs: th.Tensor) -> th.Tensor: + """ + Predict the quantiles. + + :param obs: Observation + :return: The estimated quantiles for each action. + """ + quantiles = self.quantile_net(self.extract_features(obs)) + return quantiles.view(-1, self.n_quantiles, self.action_space.n) + + def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor: + q_values = self.forward(observation).mean(dim=1) + # Greedy action + action = q_values.argmax(dim=1).reshape(-1) + return action + + def _get_data(self) -> Dict[str, Any]: + data = super()._get_data() + + data.update( + dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + n_quantiles=self.n_quantiles, + activation_fn=self.activation_fn, + features_extractor=self.features_extractor, + ) + ) + return data + + +class QRDQNPolicy(BasePolicy): + """ + Policy class with quantile and target networks for QR-DQN. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param n_quantiles: Number of quantiles + :param net_arch: The specification of the network architecture. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + n_quantiles: int = 200, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + + super(QRDQNPolicy, self).__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + ) + + if net_arch is None: + if features_extractor_class == FlattenExtractor: + net_arch = [64, 64] + else: + net_arch = [] + + self.n_quantiles = n_quantiles + self.net_arch = net_arch + self.activation_fn = activation_fn + self.normalize_images = normalize_images + + self.net_args = { + "observation_space": self.observation_space, + "action_space": self.action_space, + "n_quantiles": self.n_quantiles, + "net_arch": self.net_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + } + + self.quantile_net, self.quantile_net_target = None, None + self._build(lr_schedule) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the network and the optimizer. + + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self.quantile_net = self.make_quantile_net() + self.quantile_net_target = self.make_quantile_net() + self.quantile_net_target.load_state_dict(self.quantile_net.state_dict()) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def make_quantile_net(self) -> QuantileNetwork: + # Make sure we always have separate networks for features extractors etc + net_args = self._update_features_extractor(self.net_args, features_extractor=None) + return QuantileNetwork(**net_args).to(self.device) + + def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: + return self._predict(obs, deterministic=deterministic) + + def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: + return self.quantile_net._predict(obs, deterministic=deterministic) + + def _get_data(self) -> Dict[str, Any]: + data = super()._get_data() + + data.update( + dict( + n_quantiles=self.net_args["n_quantiles"], + net_arch=self.net_args["net_arch"], + activation_fn=self.net_args["activation_fn"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + +MlpPolicy = QRDQNPolicy + + +class CnnPolicy(QRDQNPolicy): + """ + Policy class for QR-DQN when using images as input. 
+ + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param n_quantiles: Number of quantiles + :param net_arch: The specification of the network architecture. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + n_quantiles: int = 200, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(CnnPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + n_quantiles, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +register_policy("MlpPolicy", MlpPolicy) +register_policy("CnnPolicy", CnnPolicy) diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py new file mode 100644 index 00000000..155199a4 --- /dev/null +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -0,0 +1,253 @@ +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th +from stable_baselines3.common import logger +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update + +from sb3_contrib.common.utils import quantile_huber_loss +from sb3_contrib.qrdqn.policies import QRDQNPolicy + + +class QRDQN(OffPolicyAlgorithm): + """ + Quantile Regression Deep Q-Network (QR-DQN) + Paper: https://arxiv.org/abs/1710.10044 + Default hyperparameters are taken from the paper and are tuned for Atari games. + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Set to `-1` to disable. + :param gradient_steps: How many gradient steps to do after each rollout + (see ``train_freq`` and ``n_episodes_rollout``) + Set to ``-1`` means to do as many gradient steps as steps done in the environment + during the rollout. + :param n_episodes_rollout: Update the model every ``n_episodes_rollout`` episodes. + Note that this cannot be used at the same time as ``train_freq``. 
Set to `-1` to disable. + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + :param target_update_interval: update the target network every ``target_update_interval`` + environment steps. + :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced + :param exploration_initial_eps: initial value of random action probability + :param exploration_final_eps: final value of random action probability + :param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping) + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + def __init__( + self, + policy: Union[str, Type[QRDQNPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 5e-5, + buffer_size: int = 1000000, + learning_starts: int = 50000, + batch_size: Optional[int] = 32, + tau: float = 1.0, + gamma: float = 0.99, + train_freq: int = 4, + gradient_steps: int = 1, + n_episodes_rollout: int = -1, + optimize_memory_usage: bool = False, + target_update_interval: int = 10000, + exploration_fraction: float = 0.005, + exploration_initial_eps: float = 1.0, + exploration_final_eps: float = 0.01, + max_grad_norm: Optional[float] = None, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(QRDQN, self).__init__( + policy, + env, + QRDQNPolicy, + learning_rate, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + n_episodes_rollout, + action_noise=None, # No action noise + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + sde_support=False, + optimize_memory_usage=optimize_memory_usage, + supported_action_spaces=(gym.spaces.Discrete,), + ) + + self.exploration_initial_eps = exploration_initial_eps + self.exploration_final_eps = exploration_final_eps + self.exploration_fraction = exploration_fraction + self.target_update_interval = target_update_interval + self.max_grad_norm = max_grad_norm + # "epsilon" for the epsilon-greedy exploration + self.exploration_rate = 0.0 + # Linear schedule will be defined in `_setup_model()` + self.exploration_schedule = None + self.quantile_net, self.quantile_net_target = None, None + + if "optimizer_class" not in self.policy_kwargs: + self.policy_kwargs["optimizer_class"] = th.optim.Adam + # Proposed in the QR-DQN paper where `batch_size = 32` + self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size) + + if _init_setup_model: + self._setup_model() + + def 
_setup_model(self) -> None: + super(QRDQN, self)._setup_model() + self._create_aliases() + self.exploration_schedule = get_linear_fn( + self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction + ) + + def _create_aliases(self) -> None: + self.quantile_net = self.policy.quantile_net + self.quantile_net_target = self.policy.quantile_net_target + self.n_quantiles = self.policy.n_quantiles + + def _on_step(self) -> None: + """ + Update the exploration rate and target network if needed. + This method is called in ``collect_rollouts()`` after each step in the environment. + """ + if self.num_timesteps % self.target_update_interval == 0: + polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau) + + self.exploration_rate = self.exploration_schedule(self._current_progress_remaining) + logger.record("rollout/exploration rate", self.exploration_rate) + + def train(self, gradient_steps: int, batch_size: int = 100) -> None: + # Update learning rate according to schedule + self._update_learning_rate(self.policy.optimizer) + + losses = [] + for gradient_step in range(gradient_steps): + # Sample replay buffer + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + + with th.no_grad(): + # Compute the quantiles of next observation + next_quantiles = self.quantile_net_target(replay_data.next_observations) + # Compute the greedy actions which maximize the next Q-values (mean over the quantiles) + next_greedy_actions = next_quantiles.mean(dim=1, keepdim=True).argmax(dim=2, keepdim=True) + # Make "n_quantiles" copies of the greedy actions, shape (batch_size, n_quantiles, 1) + next_greedy_actions = next_greedy_actions.expand(batch_size, self.n_quantiles, 1) + # Follow greedy policy: use the quantiles of the greedy actions + next_quantiles = next_quantiles.gather(dim=2, index=next_greedy_actions).squeeze(dim=2) + # 1-step TD target + target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles + + # Get current quantile estimates + current_quantiles = self.quantile_net(replay_data.observations) + + # Make "n_quantiles" copies of actions, and reshape to (batch_size, n_quantiles, 1). + actions = replay_data.actions[..., None].long().expand(batch_size, self.n_quantiles, 1) + # Retrieve the quantiles for the actions from the replay buffer + current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2) + + # Compute the quantile Huber loss, summing over the quantile dimension as in the paper. + loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True) + losses.append(loss.item()) + + # Optimize the policy + self.policy.optimizer.zero_grad() + loss.backward() + # Clip gradient norm + if self.max_grad_norm is not None: + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + # Increase update counter + self._n_updates += gradient_steps + + logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + logger.record("train/loss", np.mean(losses)) + + def predict( + self, + observation: np.ndarray, + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Overrides the base_class predict function to include epsilon-greedy exploration. + + :param observation: the input observation + :param state: The last states (can be None, used in recurrent policies) + :param mask: The last masks (can be None, used in recurrent policies) + :param deterministic: Whether or not to return deterministic actions. 
+ :return: the model's action and the next state + (used in recurrent policies) + """ + if not deterministic and np.random.rand() < self.exploration_rate: + if is_vectorized_observation(observation, self.observation_space): + n_batch = observation.shape[0] + action = np.array([self.action_space.sample() for _ in range(n_batch)]) + else: + action = np.array(self.action_space.sample()) + else: + action, state = self.policy.predict(observation, state, mask, deterministic) + return action, state + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "QRDQN", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OffPolicyAlgorithm: + + return super(QRDQN, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) + + def _excluded_save_params(self) -> List[str]: + return super(QRDQN, self)._excluded_save_params() + ["quantile_net", "quantile_net_target"] + + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + state_dicts = ["policy", "policy.optimizer"] + + return state_dicts, [] diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index 4ad4b15c..3807d44c 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -391,6 +391,8 @@ def _get_data(self) -> Dict[str, Any]: optimizer_kwargs=self.optimizer_kwargs, features_extractor_class=self.features_extractor_class, features_extractor_kwargs=self.features_extractor_kwargs, + n_quantiles=self.critic_kwargs["n_quantiles"], + n_critics=self.critic_kwargs["n_critics"], ) ) return data diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index d6d1de9c..5bab2bd5 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -9,6 +9,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.utils import polyak_update +from sb3_contrib.common.utils import quantile_huber_loss from sb3_contrib.tqc.policies import TQCPolicy @@ -171,26 +172,6 @@ def _create_aliases(self) -> None: self.critic = self.policy.critic self.critic_target = self.policy.critic_target - @staticmethod - def quantile_huber_loss(quantiles: th.Tensor, samples: th.Tensor) -> th.Tensor: - """ - The quantile-regression loss, as described in the QR-DQN and TQC papers. 
- Taken from https://github.com/bayesgroup/tqc_pytorch - - :param quantiles: - :param samples: - :return: the loss - """ - # batch x nets x quantiles x samples - pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None] - abs_pairwise_delta = th.abs(pairwise_delta) - huber_loss = th.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta ** 2 * 0.5) - - n_quantiles = quantiles.shape[2] - tau = th.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles - loss = (th.abs(tau[None, None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean() - return loss - def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Update optimizers learning rate optimizers = [self.actor.optimizer, self.critic.optimizer] @@ -237,24 +218,27 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: self.ent_coef_optimizer.step() with th.no_grad(): - top_quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.critic.n_critics # Select action according to policy next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) # Compute and cut quantiles at the next state # batch x nets x quantiles - next_z = self.critic_target(replay_data.next_observations, next_actions) - sorted_z, _ = th.sort(next_z.reshape(batch_size, -1)) - sorted_z_part = sorted_z[:, : self.critic.quantiles_total - top_quantiles_to_drop] + next_quantiles = self.critic_target(replay_data.next_observations, next_actions) - target_q = sorted_z_part - ent_coef * next_log_prob.reshape(-1, 1) - # td error + entropy term - q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + # Sort and drop top k quantiles to control overestimation. + n_target_quantiles = self.critic.quantiles_total - self.top_quantiles_to_drop_per_net * self.critic.n_critics + next_quantiles, _ = th.sort(next_quantiles.reshape(batch_size, -1)) + next_quantiles = next_quantiles[:, :n_target_quantiles] - # Get current Q estimates - # using action from the replay buffer - current_z = self.critic(replay_data.observations, replay_data.actions) - # Compute critic loss - critic_loss = self.quantile_huber_loss(current_z, q_backup) + # td error + entropy term + target_quantiles = next_quantiles - ent_coef * next_log_prob.reshape(-1, 1) + target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_quantiles + # Make target_quantiles broadcastable to (batch_size, n_critics, n_target_quantiles). + target_quantiles.unsqueeze_(dim=1) + + # Get current Quantile estimates using action from the replay buffer + current_quantiles = self.critic(replay_data.observations, replay_data.actions) + # Compute critic loss, not summing over the quantile dimension as in the paper. 
+ critic_loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=False) critic_losses.append(critic_loss.item()) # Optimize the critic diff --git a/setup.cfg b/setup.cfg index 08320b37..1f3fd5a7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ ignore = W503,W504,E203,E231 # line breaks before and after binary operators # Ignore import not used when aliases are defined per-file-ignores = ./sb3_contrib/__init__.py:F401 + ./sb3_contrib/qrdqn/__init__.py:F401 ./sb3_contrib/tqc/__init__.py:F401 ./sb3_contrib/common/wrappers/__init__.py:F401 exclude = diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 0e390120..91b6c80f 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -7,10 +7,10 @@ from stable_baselines3.common.identity_env import FakeImageEnv from stable_baselines3.common.utils import zip_strict -from sb3_contrib import TQC +from sb3_contrib import QRDQN, TQC -@pytest.mark.parametrize("model_class", [TQC]) +@pytest.mark.parametrize("model_class", [TQC, QRDQN]) def test_cnn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip @@ -18,10 +18,13 @@ def test_cnn(tmp_path, model_class): # to check that the network handle it automatically env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC}) kwargs = {} - if model_class in {TQC}: + if model_class in {TQC, QRDQN}: # Avoid memory error when using replay buffer - # Reduce the size of the features - kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) + # Reduce the size of the features and the number of quantiles + kwargs = dict( + buffer_size=250, + policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)), + ) model = model_class("CnnPolicy", env, **kwargs).learn(250) obs = env.reset() @@ -39,6 +42,13 @@ def test_cnn(tmp_path, model_class): os.remove(str(tmp_path / SAVE_NAME)) +def patch_qrdqn_names_(model): + # Small hack to make the test work with QRDQN + if isinstance(model, QRDQN): + model.critic = model.quantile_net + model.critic_target = model.quantile_net_target + + def params_should_match(params, other_params): for param, other_param in zip_strict(params, other_params): assert th.allclose(param, other_param) @@ -49,28 +59,36 @@ def params_should_differ(params, other_params): assert not th.allclose(param, other_param) -@pytest.mark.parametrize("model_class", [TQC]) +@pytest.mark.parametrize("model_class", [TQC, QRDQN]) @pytest.mark.parametrize("share_features_extractor", [True, False]) def test_feature_extractor_target_net(model_class, share_features_extractor): + if model_class == QRDQN and share_features_extractor: + pytest.skip() + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC}) - # Avoid memory error when using replay buffer - # Reduce the size of the features - kwargs = dict( - buffer_size=250, - learning_starts=100, - policy_kwargs=dict( - features_extractor_kwargs=dict(features_dim=32), - share_features_extractor=share_features_extractor, - ), - ) + + if model_class in {TQC, QRDQN}: + # Avoid memory error when using replay buffer + # Reduce the size of the features and the number of quantiles + kwargs = dict( + buffer_size=250, + learning_starts=100, + policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)), + ) + if model_class != QRDQN: + kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor + model = model_class("CnnPolicy", 
env, seed=0, **kwargs) + patch_qrdqn_names_(model) + if share_features_extractor: # Check that the objects are the same and not just copied assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor) else: # Check that the objects differ - assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor) + if model_class != QRDQN: + assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor) # Critic and target should be equal at the begginning of training params_should_match(model.critic.parameters(), model.critic_target.parameters()) @@ -83,6 +101,8 @@ def test_feature_extractor_target_net(model_class, share_features_extractor): # Re-initialize and collect some random data (without doing gradient steps) model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10) + patch_qrdqn_names_(model) + original_param = deepcopy(list(model.critic.parameters())) original_target_param = deepcopy(list(model.critic_target.parameters())) @@ -103,6 +123,11 @@ def test_feature_extractor_target_net(model_class, share_features_extractor): model.lr_schedule = lambda _: 0.0 # Re-activate polyak update model.tau = 0.01 + # Special case for QRDQN: target net is updated in the `collect_rollouts()` + # not the `train()` method + if model_class == QRDQN: + model.target_update_interval = 1 + model._on_step() model.train(gradient_steps=1) diff --git a/tests/test_run.py b/tests/test_run.py index 8d599764..195d0114 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,6 @@ import pytest -from sb3_contrib import TQC +from sb3_contrib import QRDQN, TQC @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) @@ -21,7 +21,11 @@ def test_tqc(ent_coef): def test_n_critics(n_critics): # Test TQC with different number of critics model = TQC( - "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64], n_critics=n_critics), learning_starts=100, verbose=1 + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[64], n_critics=n_critics), + learning_starts=100, + verbose=1, ) model.learn(total_timesteps=300) @@ -38,3 +42,17 @@ def test_sde(): model.learn(total_timesteps=300) model.policy.reset_noise() model.policy.actor.get_std() + + +def test_qrdqn(): + model = QRDQN( + "MlpPolicy", + "CartPole-v1", + policy_kwargs=dict(n_quantiles=25, net_arch=[64, 64]), + learning_starts=100, + buffer_size=500, + learning_rate=3e-4, + verbose=1, + create_eval_env=True, + ) + model.learn(total_timesteps=500, eval_freq=250) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 396bc624..bc16f16a 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -7,22 +7,21 @@ import numpy as np import pytest import torch as th -from stable_baselines3 import DQN from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv -from sb3_contrib import TQC +from sb3_contrib import QRDQN, TQC -MODEL_LIST = [TQC] +MODEL_LIST = [TQC, QRDQN] def select_env(model_class: BaseAlgorithm) -> gym.Env: """ - Selects an environment with the correct action space as DQN only supports discrete action space + Selects an environment with the correct action space as QRDQN only supports discrete action space """ - if model_class == DQN: + if model_class == QRDQN: return IdentityEnv(10) else: return IdentityEnvBox(10) @@ 
-41,8 +40,13 @@ def test_save_load(tmp_path, model_class): env = DummyVecEnv([lambda: select_env(model_class)]) + policy_kwargs = dict(net_arch=[16]) + + if model_class in {QRDQN, TQC}: + policy_kwargs.update(dict(n_quantiles=20)) + # create model - model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), verbose=1) + model = model_class("MlpPolicy", env, verbose=1, policy_kwargs=policy_kwargs) model.learn(total_timesteps=300) env.reset() @@ -167,13 +171,18 @@ def test_set_env(model_class): :param model_class: (BaseAlgorithm) A RL model """ - # use discrete for DQN + # use discrete for QRDQN env = DummyVecEnv([lambda: select_env(model_class)]) env2 = DummyVecEnv([lambda: select_env(model_class)]) env3 = select_env(model_class) + kwargs = dict(policy_kwargs=dict(net_arch=[16])) + if model_class in {TQC, QRDQN}: + kwargs.update(dict(learning_starts=100)) + kwargs["policy_kwargs"].update(dict(n_quantiles=20)) + # create model - model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16])) + model = model_class("MlpPolicy", env, **kwargs) # learn model.learn(total_timesteps=300) @@ -219,7 +228,7 @@ def test_exclude_include_saved_params(tmp_path, model_class): os.remove(tmp_path / "test_save.zip") -@pytest.mark.parametrize("model_class", [TQC]) +@pytest.mark.parametrize("model_class", [TQC, QRDQN]) def test_save_load_replay_buffer(tmp_path, model_class): path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl") path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning @@ -254,20 +263,28 @@ def test_save_load_policy(tmp_path, model_class, policy_str): :param model_class: (BaseAlgorithm) A RL model :param policy_str: (str) Name of the policy. """ - kwargs = {} + kwargs = dict(policy_kwargs=dict(net_arch=[16])) if policy_str == "MlpPolicy": env = select_env(model_class) else: - if model_class in [TQC]: + if model_class in [TQC, QRDQN]: # Avoid memory error when using replay buffer # Reduce the size of the features - kwargs = dict(buffer_size=250) - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN) + kwargs = dict( + buffer_size=250, + learning_starts=100, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + ) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN) + + # Reduce number of quantiles for faster tests + if model_class in [TQC, QRDQN]: + kwargs["policy_kwargs"].update(dict(n_quantiles=20)) env = DummyVecEnv([lambda: env]) # create model - model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs) + model = model_class(policy_str, env, verbose=1, **kwargs) model.learn(total_timesteps=300) env.reset() @@ -334,3 +351,83 @@ def test_save_load_policy(tmp_path, model_class, policy_str): os.remove(tmp_path / "policy.pkl") if actor_class is not None: os.remove(tmp_path / "actor.pkl") + + +@pytest.mark.parametrize("model_class", [QRDQN]) +@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) +def test_save_load_q_net(tmp_path, model_class, policy_str): + """ + Test saving and loading q-network/quantile net only. + + :param model_class: (BaseAlgorithm) A RL model + :param policy_str: (str) Name of the policy. 
+ """ + kwargs = dict(policy_kwargs=dict(net_arch=[16])) + if policy_str == "MlpPolicy": + env = select_env(model_class) + else: + if model_class in [QRDQN]: + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs = dict( + buffer_size=250, + learning_starts=100, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + ) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN) + + # Reduce number of quantiles for faster tests + if model_class in [QRDQN]: + kwargs["policy_kwargs"].update(dict(n_quantiles=20)) + + env = DummyVecEnv([lambda: env]) + + # create model + model = model_class(policy_str, env, verbose=1, **kwargs) + model.learn(total_timesteps=300) + + env.reset() + observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) + + q_net = model.quantile_net + q_net_class = q_net.__class__ + + # Get dictionary of current parameters + params = deepcopy(q_net.state_dict()) + + # Modify all parameters to be random values + random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + + # Update model parameters with the new random values + q_net.load_state_dict(random_params) + + new_params = q_net.state_dict() + # Check that all params are different now + for k in params: + assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." + + params = new_params + + # get selected actions + selected_actions, _ = q_net.predict(observations, deterministic=True) + + # Save and load q_net + q_net.save(tmp_path / "q_net.pkl") + + del q_net + + q_net = q_net_class.load(tmp_path / "q_net.pkl") + + # check if params are still the same after load + new_params = q_net.state_dict() + + # Check that all params are the same as before save load procedure now + for key in params: + assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load." + + # check if model still selects the same actions + new_selected_actions, _ = q_net.predict(observations, deterministic=True) + assert np.allclose(selected_actions, new_selected_actions, 1e-4) + + # clear file from os + os.remove(tmp_path / "q_net.pkl") diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..97b740f4 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,19 @@ +import numpy as np +import pytest +import torch as th + +from sb3_contrib.common.utils import quantile_huber_loss + + +def test_quantile_huber_loss(): + assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10)), 2.5) + assert np.isclose(quantile_huber_loss(th.zeros(1, 10), th.ones(1, 10), sum_over_quantiles=False), 0.25) + + with pytest.raises(ValueError): + quantile_huber_loss(th.zeros(1, 4, 4), th.zeros(1, 4)) + with pytest.raises(ValueError): + quantile_huber_loss(th.zeros(1, 4), th.zeros(1, 1, 4)) + with pytest.raises(ValueError): + quantile_huber_loss(th.zeros(4, 4), th.zeros(3, 4)) + with pytest.raises(ValueError): + quantile_huber_loss(th.zeros(4, 4, 4, 4), th.zeros(4, 4, 4, 4))