fix custom module for RNNModel and add tests #2088

Merged · 10 commits · Dec 10, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,10 +10,12 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

### For users of the library:
**Improved**
- 🔴 Added `CustomRNNModule` and `CustomBlockRNNModule` for defining custom RNN modules that can be used with `RNNModel` and `BlockRNNModel`. The custom `model` must now be a subclass of the custom modules. [#2088](https://github.com/unit8co/darts/pull/2088) by [Dennis Bader](https://github.com/dennisbader).

**Fixed**
- Fixed a bug in historical forecasts, where some `fit/predict_kwargs` were not passed to the underlying model's fit/predict methods. [#2103](https://github.com/unit8co/darts/pull/2103) by [Dennis Bader](https://github.com/dennisbader).
- Fixed an import error when trying to create a `TorchForecastingModel` with PyTorch Lightning v<2.0.0. [#2087](https://github.com/unit8co/darts/pull/2087) by [Eschibli](https://github.com/eschibli).
- Fixed a bug when creating a `RNNModel` with a custom `model`. [#2088](https://github.com/unit8co/darts/pull/2088) by [Dennis Bader](https://github.com/dennisbader).

### For developers of the library:

162 changes: 108 additions & 54 deletions darts/models/forecasting/block_rnn_model.py
@@ -3,12 +3,14 @@
-------------------------------
"""

from typing import List, Optional, Tuple, Union
import inspect
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn

from darts.logging import get_logger, raise_if_not
from darts.logging import get_logger, raise_log
from darts.models.forecasting.pl_forecasting_module import (
PLPastCovariatesModule,
io_processor,
@@ -18,11 +20,9 @@
logger = get_logger(__name__)


# TODO add batch norm
class _BlockRNNModule(PLPastCovariatesModule):
class CustomBlockRNNModule(PLPastCovariatesModule, ABC):
def __init__(
self,
name: str,
input_size: int,
hidden_dim: int,
num_layers: int,
@@ -32,24 +32,22 @@ def __init__(
dropout: float = 0.0,
**kwargs,
):
"""This class allows creating custom block RNN modules that can later be used with Darts'
:class:`BlockRNNModel`. It adds the backbone required for use with Darts'
:class:`TorchForecastingModel` and :class:`BlockRNNModel`.

To create a new module, subclass from :class:`CustomBlockRNNModule` and:

* Define the architecture in the module constructor (`__init__()`)

* Add the `forward()` method and define the logic of your module's forward pass

* Use the custom module class when creating a new :class:`BlockRNNModel` with parameter `model`.

You can use `darts.models.forecasting.block_rnn_model._BlockRNNModule` as an example.

Parameters
----------
input_size
The dimensionality of the input time series.
hidden_dim
Expand Down Expand Up @@ -78,28 +76,83 @@ def __init__(
y of shape `(batch_size, output_chunk_length, target_size, nr_params)`
Tensor containing the prediction at the last time step of the sequence.
"""

super().__init__(**kwargs)

# Defining parameters
self.input_size = input_size
self.hidden_dim = hidden_dim
self.n_layers = num_layers
self.num_layers = num_layers
self.target_size = target_size
self.nr_params = nr_params
num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc
self.num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc
self.dropout = dropout
self.out_len = self.output_chunk_length

@io_processor
@abstractmethod
def forward(self, x_in: Tuple) -> torch.Tensor:
pass
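
A subclass only needs to define its layers in `__init__()` and implement `forward()`. Below is a minimal sketch of such a module; the class name `MyBlockRNNModule` is hypothetical, the inherited attributes (`input_size`, `hidden_dim`, `num_layers`, `target_size`, `nr_params`, `dropout`, `out_len`) are those set by the base constructor above, and the unpacking of `x_in` assumes the same tuple layout as the built-in `_BlockRNNModule` (input sequence first):

```python
import torch.nn as nn

from darts.models.forecasting.block_rnn_model import CustomBlockRNNModule
from darts.models.forecasting.pl_forecasting_module import io_processor


class MyBlockRNNModule(CustomBlockRNNModule):
    """Hypothetical example: a GRU encoder followed by a single linear decoder."""

    def __init__(self, **kwargs):
        # the base constructor sets `self.input_size`, `self.hidden_dim`,
        # `self.num_layers`, `self.target_size`, `self.nr_params`,
        # `self.dropout` and `self.out_len` from the kwargs passed by the model
        super().__init__(**kwargs)
        self.rnn = nn.GRU(
            self.input_size,
            self.hidden_dim,
            self.num_layers,
            batch_first=True,
            dropout=self.dropout,
        )
        self.fc = nn.Linear(
            self.hidden_dim, self.out_len * self.target_size * self.nr_params
        )

    @io_processor
    def forward(self, x_in):
        x, _ = x_in  # assumed layout: past features first, second element unused here
        out, _ = self.rnn(x)
        # map the encoder output at the last time step to a block of predictions
        y = self.fc(out[:, -1, :])
        return y.view(x.size(0), self.out_len, self.target_size, self.nr_params)
```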


# TODO add batch norm
class _BlockRNNModule(CustomBlockRNNModule):
def __init__(
self,
name: str,
**kwargs,
):

"""PyTorch module implementing a block RNN to be used in `BlockRNNModel`.

PyTorch module implementing a simple block RNN with the specified `name` layer.
This module combines a PyTorch RNN module, together with a fully connected network, which maps the
last hidden layers to output of the desired size `output_chunk_length` and makes it compatible with
`BlockRNNModel`s.

This module uses an RNN to encode the input sequence, and subsequently uses a fully connected
network as the decoder which takes as input the last hidden state of the encoder RNN.
The final output of the decoder is a sequence of length `output_chunk_length`. In this sense,
the `_BlockRNNModule` produces 'blocks' of forecasts at a time (which is different
from `_RNNModule` used by the `RNNModel`).

Parameters
----------
name
The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM").
**kwargs
All parameters required for the :class:`darts.models.forecasting.block_rnn_model.CustomBlockRNNModule` base class.

Inputs
------
x of shape `(batch_size, input_chunk_length, input_size, nr_params)`
Tensor containing the features of the input sequence.

Outputs
-------
y of shape `(batch_size, output_chunk_length, target_size, nr_params)`
Tensor containing the prediction at the last time step of the sequence.
"""

super().__init__(**kwargs)

self.name = name

# Defining the RNN module
self.rnn = getattr(nn, name)(
input_size, hidden_dim, num_layers, batch_first=True, dropout=dropout
self.rnn = getattr(nn, self.name)(
self.input_size,
self.hidden_dim,
self.num_layers,
batch_first=True,
dropout=self.dropout,
)

# The RNN module is followed by a fully connected layer, which maps the last hidden layer
# to the output of desired length
last = hidden_dim
last = self.hidden_dim
feats = []
for feature in num_layers_out_fc + [self.out_len * target_size * nr_params]:
for feature in self.num_layers_out_fc + [
self.out_len * self.target_size * self.nr_params
]:
feats.append(nn.Linear(last, feature))
last = feature
self.fc = nn.Sequential(*feats)
@@ -131,7 +184,7 @@ def __init__(
self,
input_chunk_length: int,
output_chunk_length: int,
model: Union[str, nn.Module] = "RNN",
model: Union[str, Type[CustomBlockRNNModule]] = "RNN",
hidden_dim: int = 25,
n_rnn_layers: int = 1,
hidden_fc_sizes: Optional[List] = None,
@@ -168,9 +221,8 @@ def __init__(
the model from using future values of past and / or future covariates for prediction (depending on the
model's covariate support).
model
Either a string specifying the RNN module type ("RNN", "LSTM" or "GRU"),
or a PyTorch module with the same specifications as
:class:`darts.models.block_rnn_model._BlockRNNModule`.
Either a string specifying the RNN module type ("RNN", "LSTM" or "GRU"), or a subclass of
:class:`CustomBlockRNNModule` (the class itself, not an instance of the class) with custom logic.
hidden_dim
Size for feature maps for each hidden RNN layer (:math:`h_n`).
In Darts version <= 0.21, hidden_dim was referred as hidden_size.
@@ -276,7 +328,6 @@ def encode_year(idx):
"devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs``
dict:


- ``{"accelerator": "cpu"}`` for CPU,
- ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer),
- ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS.
@@ -357,14 +408,16 @@ def encode_year(idx):

# check we got right model type specified:
if model not in ["RNN", "LSTM", "GRU"]:
raise_if_not(
isinstance(model, nn.Module),
'{} is not a valid RNN model.\n Please specify "RNN", "LSTM", '
'"GRU", or give your own PyTorch nn.Module'.format(
model.__class__.__name__
),
logger,
)
if not inspect.isclass(model) or not issubclass(
model, CustomBlockRNNModule
):
raise_log(
ValueError(
"`model` is not a valid RNN model. Please specify 'RNN', 'LSTM', 'GRU', or give a subclass "
"(not an instance) of darts.models.forecasting.block_rnn_model.CustomBlockRNNModule."
),
logger=logger,
)

self.rnn_type_or_module = model
self.hidden_fc_sizes = hidden_fc_sizes
@@ -384,21 +437,22 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
output_dim = train_sample[-1].shape[1]
nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters

if self.rnn_type_or_module in ["RNN", "LSTM", "GRU"]:
hidden_fc_sizes = (
[] if self.hidden_fc_sizes is None else self.hidden_fc_sizes
)
model = _BlockRNNModule(
name=self.rnn_type_or_module,
input_size=input_dim,
target_size=output_dim,
nr_params=nr_params,
hidden_dim=self.hidden_dim,
num_layers=self.n_rnn_layers,
num_layers_out_fc=hidden_fc_sizes,
dropout=self.dropout,
**self.pl_module_params,
)
hidden_fc_sizes = [] if self.hidden_fc_sizes is None else self.hidden_fc_sizes

kwargs = {}
if isinstance(self.rnn_type_or_module, str):
model_cls = _BlockRNNModule
kwargs["name"] = self.rnn_type_or_module
else:
model = self.rnn_type_or_module
return model
model_cls = self.rnn_type_or_module
return model_cls(
input_size=input_dim,
target_size=output_dim,
nr_params=nr_params,
hidden_dim=self.hidden_dim,
num_layers=self.n_rnn_layers,
num_layers_out_fc=hidden_fc_sizes,
dropout=self.dropout,
**self.pl_module_params,
**kwargs,
)
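
With this dispatch, a string still selects the built-in `_BlockRNNModule` (with `name` added to the kwargs), while a custom class is instantiated with the same keyword arguments. A usage sketch, assuming the hypothetical `MyBlockRNNModule` from the earlier example is in scope (hyperparameter values are illustrative):

```python
import numpy as np

from darts import TimeSeries
from darts.models import BlockRNNModel

# a small synthetic series to fit on
series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 200)))

# pass the class itself, not an instance; Darts instantiates it internally with
# `input_size`, `target_size`, `nr_params`, `hidden_dim`, `num_layers`,
# `num_layers_out_fc`, `dropout`, and the PL module parameters
model = BlockRNNModel(
    input_chunk_length=12,
    output_chunk_length=6,
    model=MyBlockRNNModule,
    hidden_dim=25,
    n_rnn_layers=1,
    n_epochs=5,
)
model.fit(series)
pred = model.predict(n=6)
```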
1 change: 0 additions & 1 deletion darts/models/forecasting/dlinear.py
@@ -373,7 +373,6 @@ def encode_year(idx):
"devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs``
dict:


- ``{"accelerator": "cpu"}`` for CPU,
- ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer),
- ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS.
1 change: 0 additions & 1 deletion darts/models/forecasting/nbeats.py
@@ -698,7 +698,6 @@ def encode_year(idx):
"devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs``
dict:


- ``{"accelerator": "cpu"}`` for CPU,
- ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer),
- ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS.
1 change: 0 additions & 1 deletion darts/models/forecasting/nhits.py
@@ -634,7 +634,6 @@ def encode_year(idx):
"devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs``
dict:


- ``{"accelerator": "cpu"}`` for CPU,
- ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer),
- ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS.