From 7cdc9b12f726e2608a869d5e1a74472a63bdde99 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Thu, 23 Nov 2023 11:58:49 +0100 Subject: [PATCH 1/8] fix custom module for RNNModel and add tests --- darts/models/forecasting/rnn_model.py | 11 +- darts/tests/models/forecasting/test_RNN.py | 119 +++++++++++++++++++++ 2 files changed, 120 insertions(+), 10 deletions(-) create mode 100644 darts/tests/models/forecasting/test_RNN.py diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index 22a4f25cec..4d16ff8435 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -478,16 +478,7 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: **self.pl_module_params, ) else: - model = self.rnn_type_or_module( - name="custom_module", - input_size=input_dim, - target_size=output_dim, - nr_params=nr_params, - hidden_dim=self.hidden_dim, - dropout=self.dropout, - num_layers=self.n_rnn_layers, - **self.pl_module_params, - ) + model = self.rnn_type_or_module return model def _build_train_dataset( diff --git a/darts/tests/models/forecasting/test_RNN.py b/darts/tests/models/forecasting/test_RNN.py new file mode 100644 index 0000000000..83cb3a5a8a --- /dev/null +++ b/darts/tests/models/forecasting/test_RNN.py @@ -0,0 +1,119 @@ +import numpy as np +import pandas as pd +import pytest + +from darts import TimeSeries +from darts.logging import get_logger +from darts.tests.conftest import tfm_kwargs + +logger = get_logger(__name__) + +try: + from darts.models.forecasting.rnn_model import RNNModel, _RNNModule + + TORCH_AVAILABLE = True +except ImportError: + logger.warning("Torch not available. RNN tests will be skipped.") + TORCH_AVAILABLE = False + + +if TORCH_AVAILABLE: + + class TestRNNModel: + times = pd.date_range("20130101", "20130410") + pd_series = pd.Series(range(100), index=times) + series: TimeSeries = TimeSeries.from_series(pd_series) + module = _RNNModule( + name="RNN", + input_chunk_length=1, + output_chunk_length=1, + input_size=1, + hidden_dim=25, + num_layers=1, + target_size=1, + nr_params=1, + dropout=0, + ) + + def test_creation(self): + with pytest.raises(ValueError): + # cannot choose any string + RNNModel( + input_chunk_length=1, output_chunk_length=1, model="UnknownRNN?" + ) + # can give a custom module + model1 = RNNModel( + input_chunk_length=1, output_chunk_length=1, model=self.module + ) + model2 = RNNModel(input_chunk_length=1, output_chunk_length=1, model="RNN") + assert model1.model.__repr__() == model2.model.__repr__() + + def test_fit(self, tmpdir_module): + # Test basic fit() + model = RNNModel( + input_chunk_length=1, output_chunk_length=1, n_epochs=2, **tfm_kwargs + ) + model.fit(self.series) + + # Test fit-save-load cycle + model2 = RNNModel( + input_chunk_length=1, + output_chunk_length=1, + model="LSTM", + n_epochs=1, + model_name="unittest-model-lstm", + work_dir=tmpdir_module, + save_checkpoints=True, + force_reset=True, + **tfm_kwargs + ) + model2.fit(self.series) + model_loaded = model2.load_from_checkpoint( + model_name="unittest-model-lstm", + work_dir=tmpdir_module, + best=False, + map_location="cpu", + ) + pred1 = model2.predict(n=6) + pred2 = model_loaded.predict(n=6) + + # Two models with the same parameters should deterministically yield the same output + np.testing.assert_array_equal(pred1.values(), pred2.values()) + + # Another random model should not + model3 = RNNModel( + input_chunk_length=1, + output_chunk_length=1, + model="RNN", + n_epochs=2, + **tfm_kwargs + ) + model3.fit(self.series) + pred3 = model3.predict(n=6) + assert not np.array_equal(pred1.values(), pred3.values()) + + # test short predict + pred4 = model3.predict(n=1) + assert len(pred4) == 1 + + # test validation series input + model3.fit(self.series[:60], val_series=self.series[60:]) + pred4 = model3.predict(n=6) + assert len(pred4) == 6 + + def helper_test_pred_length(self, pytorch_model, series): + model = pytorch_model( + input_chunk_length=1, output_chunk_length=3, n_epochs=1, **tfm_kwargs + ) + model.fit(series) + pred = model.predict(7) + assert len(pred) == 7 + pred = model.predict(2) + assert len(pred) == 2 + assert pred.width == 1 + pred = model.predict(4) + assert len(pred) == 4 + assert pred.width == 1 + + def test_pred_length(self): + self.helper_test_pred_length(RNNModel, self.series) From a248ef10eefc404bb2ddf03bb8b4ef3cc24da095 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Thu, 23 Nov 2023 12:03:50 +0100 Subject: [PATCH 2/8] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6646ee60f3..d0729e0af6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** **Fixed** +- Fixed a bug when creating a `RNNModel` with a custom `model`. [#2088](https://github.com/unit8co/darts/pull/2088) by [Dennis Bader](https://github.com/dennisbader). ### For developers of the library: - ## [0.27.0](https://github.com/unit8co/darts/tree/0.27.0) (2023-11-18) ### For users of the library: **Improved** From b2de60ea4bf5c4f6ecb3a73e2f340fd09a54b6c3 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Fri, 1 Dec 2023 11:36:22 +0100 Subject: [PATCH 3/8] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index da547f029a..d2d90f1374 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co ### For developers of the library: + ## [0.27.0](https://github.com/unit8co/darts/tree/0.27.0) (2023-11-18) ### For users of the library: **Improved** From ec7e4ff4d16f3d90cf89d98742308f11557f0c9a Mon Sep 17 00:00:00 2001 From: dennisbader Date: Fri, 1 Dec 2023 15:26:12 +0100 Subject: [PATCH 4/8] make custom rnn module --- darts/models/forecasting/rnn_model.py | 161 ++++++++++++++------- darts/tests/models/forecasting/test_RNN.py | 77 +++++++++- 2 files changed, 175 insertions(+), 63 deletions(-) diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index 4d16ff8435..eeae3fd2c4 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -3,12 +3,14 @@ ------------------------- """ -from typing import Optional, Sequence, Tuple, Union +import inspect +from abc import ABC, abstractmethod +from typing import Optional, Sequence, Tuple, Type, Union import torch import torch.nn as nn -from darts.logging import get_logger, raise_if_not +from darts.logging import get_logger, raise_if_not, raise_log from darts.models.forecasting.pl_forecasting_module import ( PLDualCovariatesModule, io_processor, @@ -20,11 +22,9 @@ logger = get_logger(__name__) -# TODO add batch norm -class _RNNModule(PLDualCovariatesModule): +class CustomRNNModule(PLDualCovariatesModule, ABC): def __init__( self, - name: str, input_size: int, hidden_dim: int, num_layers: int, @@ -33,7 +33,6 @@ def __init__( dropout: float = 0.0, **kwargs, ): - """PyTorch module implementing an RNN to be used in `RNNModel`. PyTorch module implementing a simple RNN with the specified `name` type. @@ -43,8 +42,6 @@ def __init__( Parameters ---------- - name - The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM"). input_size The dimensionality of the input time series. hidden_dim @@ -72,42 +69,23 @@ def __init__( During training the whole tensor is used as output, whereas during prediction we only use y[:, -1, :]. However, this module always returns the whole Tensor. """ - # RNNModule doesn't really need input and output_chunk_length for PLModule super().__init__(**kwargs) # Defining parameters + self.input_size = input_size + self.hidden_dim = hidden_dim + self.num_layers = num_layers self.target_size = target_size self.nr_params = nr_params - self.name = name - - # Defining the RNN module - self.rnn = getattr(nn, name)( - input_size, hidden_dim, num_layers, batch_first=True, dropout=dropout - ) - - # The RNN module needs a linear layer V that transforms hidden states into outputs, individually - self.V = nn.Linear(hidden_dim, target_size * nr_params) + self.dropout = dropout @io_processor + @abstractmethod def forward( self, x_in: Tuple, h: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: - x, _ = x_in - # data is of size (batch_size, input_length, input_size) - batch_size = x.shape[0] - - # out is of size (batch_size, input_length, hidden_dim) - out, last_hidden_state = self.rnn(x) if h is None else self.rnn(x, h) - - # Here, we apply the V matrix to every hidden state to produce the outputs - predictions = self.V(out) - - # predictions is of size (batch_size, input_length, target_size) - predictions = predictions.view(batch_size, -1, self.target_size, self.nr_params) - - # returns outputs for all inputs, only the last one is needed for prediction time - return predictions, last_hidden_state + pass def _produce_train_output(self, input_batch: Tuple) -> torch.Tensor: ( @@ -204,11 +182,82 @@ def _get_batch_prediction( return batch_prediction +# TODO add batch norm +class _RNNModule(CustomRNNModule): + def __init__( + self, + name: str, + **kwargs, + ): + """PyTorch module implementing an RNN to be used in `RNNModel`. + + PyTorch module implementing a simple RNN with the specified `name` type. + This module combines a PyTorch RNN module, together with one fully connected layer which + maps the hidden state of the RNN at each step to the output value of the model at that + time step. + + Parameters + ---------- + name + The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM"). + **kwargs + all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class. + + Inputs + ------ + x of shape `(batch_size, input_length, input_size)` + Tensor containing the features of the input sequence. The `input_length` is not fixed. + + Outputs + ------- + y of shape `(batch_size, output_chunk_length, target_size, nr_params)` + Tensor containing the outputs of the RNN at every time step of the input sequence. + During training the whole tensor is used as output, whereas during prediction we only use y[:, -1, :]. + However, this module always returns the whole Tensor. + """ + + # RNNModule doesn't really need input and output_chunk_length for PLModule + super().__init__(**kwargs) + self.name = name + + # Defining the RNN module + self.rnn = getattr(nn, name)( + self.input_size, + self.hidden_dim, + self.num_layers, + batch_first=True, + dropout=self.dropout, + ) + + # The RNN module needs a linear layer V that transforms hidden states into outputs, individually + self.V = nn.Linear(self.hidden_dim, self.target_size * self.nr_params) + + @io_processor + def forward( + self, x_in: Tuple, h: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + x, _ = x_in + # data is of size (batch_size, input_length, input_size) + batch_size = x.shape[0] + + # out is of size (batch_size, input_length, hidden_dim) + out, last_hidden_state = self.rnn(x) if h is None else self.rnn(x, h) + + # Here, we apply the V matrix to every hidden state to produce the outputs + predictions = self.V(out) + + # predictions is of size (batch_size, input_length, target_size) + predictions = predictions.view(batch_size, -1, self.target_size, self.nr_params) + + # returns outputs for all inputs, only the last one is needed for prediction time + return predictions, last_hidden_state + + class RNNModel(DualCovariatesTorchModel): def __init__( self, input_chunk_length: int, - model: Union[str, nn.Module] = "RNN", + model: Union[str, Type[CustomRNNModule]] = "RNN", hidden_dim: int = 25, n_rnn_layers: int = 1, dropout: float = 0.0, @@ -442,14 +491,14 @@ def encode_year(idx): # check we got right model type specified: if model not in ["RNN", "LSTM", "GRU"]: - raise_if_not( - isinstance(model, nn.Module), - '{} is not a valid RNN model.\n Please specify "RNN", "LSTM", ' - '"GRU", or give your own PyTorch nn.Module'.format( - model.__class__.__name__ - ), - logger, - ) + if not inspect.isclass(model) or not issubclass(model, CustomRNNModule): + raise_log( + ValueError( + "`model` is not a valid RNN model. Please specify RNN', 'LSTM', 'GRU', or give a subclass " + "(not an instance) of darts.models.forecasting.rnn_model.CustomRNNModule." + ), + logger=logger, + ) self.rnn_type_or_module = model self.dropout = dropout @@ -466,20 +515,22 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: output_dim = train_sample[-1].shape[1] nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters - if self.rnn_type_or_module in ["RNN", "LSTM", "GRU"]: - model = _RNNModule( - name=self.rnn_type_or_module, - input_size=input_dim, - target_size=output_dim, - nr_params=nr_params, - hidden_dim=self.hidden_dim, - dropout=self.dropout, - num_layers=self.n_rnn_layers, - **self.pl_module_params, - ) + kwargs = {} + if isinstance(self.rnn_type_or_module, str): + model_cls = _RNNModule + kwargs["name"] = self.rnn_type_or_module else: - model = self.rnn_type_or_module - return model + model_cls = self.rnn_type_or_module + return model_cls( + input_size=input_dim, + target_size=output_dim, + nr_params=nr_params, + hidden_dim=self.hidden_dim, + dropout=self.dropout, + num_layers=self.n_rnn_layers, + **self.pl_module_params, + **kwargs, + ) def _build_train_dataset( self, diff --git a/darts/tests/models/forecasting/test_RNN.py b/darts/tests/models/forecasting/test_RNN.py index 83cb3a5a8a..3508cb9e3d 100644 --- a/darts/tests/models/forecasting/test_RNN.py +++ b/darts/tests/models/forecasting/test_RNN.py @@ -9,7 +9,9 @@ logger = get_logger(__name__) try: - from darts.models.forecasting.rnn_model import RNNModel, _RNNModule + import torch.nn as nn + + from darts.models.forecasting.rnn_model import CustomRNNModule, RNNModel, _RNNModule TORCH_AVAILABLE = True except ImportError: @@ -19,11 +21,28 @@ if TORCH_AVAILABLE: + class ModuleValid1(_RNNModule): + """Wrapper around the _RNNModule""" + + def __init__(self, **kwargs): + super().__init__(name="RNN", **kwargs) + + class ModuleValid2(CustomRNNModule): + """Just a linear layer.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.linear = nn.Linear(self.input_size, self.target_size) + + def forward(self, x_in, h=None): + x = self.linear(x_in[0]) + return x.view(len(x), -1, self.target_size, self.nr_params) + class TestRNNModel: times = pd.date_range("20130101", "20130410") pd_series = pd.Series(range(100), index=times) series: TimeSeries = TimeSeries.from_series(pd_series) - module = _RNNModule( + module_invalid = _RNNModule( name="RNN", input_chunk_length=1, output_chunk_length=1, @@ -36,17 +55,59 @@ class TestRNNModel: ) def test_creation(self): - with pytest.raises(ValueError): - # cannot choose any string + # cannot choose any string + with pytest.raises(ValueError) as msg: RNNModel( input_chunk_length=1, output_chunk_length=1, model="UnknownRNN?" ) - # can give a custom module + assert str(msg.value).startswith("`model` is not a valid RNN model.") + + # cannot create from a class instance + with pytest.raises(ValueError) as msg: + _ = RNNModel( + input_chunk_length=1, + output_chunk_length=1, + model=self.module_invalid, + ) + assert str(msg.value).startswith("`model` is not a valid RNN model.") + + # can create from valid module name model1 = RNNModel( - input_chunk_length=1, output_chunk_length=1, model=self.module + input_chunk_length=1, + output_chunk_length=1, + model="RNN", + n_epochs=1, + random_state=42, + **tfm_kwargs ) - model2 = RNNModel(input_chunk_length=1, output_chunk_length=1, model="RNN") - assert model1.model.__repr__() == model2.model.__repr__() + model1.fit(self.series) + preds1 = model1.predict(n=3) + + # can create from a custom class itself + model2 = RNNModel( + input_chunk_length=1, + output_chunk_length=1, + model=ModuleValid1, + n_epochs=1, + random_state=42, + **tfm_kwargs + ) + model2.fit(self.series) + preds2 = model2.predict(n=3) + np.testing.assert_array_equal(preds1.all_values(), preds2.all_values()) + + model3 = RNNModel( + input_chunk_length=1, + output_chunk_length=1, + model=ModuleValid2, + n_epochs=1, + random_state=42, + **tfm_kwargs + ) + model3.fit(self.series) + preds3 = model2.predict(n=3) + assert preds3.all_values().shape == preds2.all_values().shape + assert preds3.time_index.equals(preds2.time_index) def test_fit(self, tmpdir_module): # Test basic fit() From fd84f51128a24bc2142d58dd16afb1aaf983b5e6 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Fri, 1 Dec 2023 15:40:40 +0100 Subject: [PATCH 5/8] update rnn docs --- darts/models/forecasting/block_rnn_model.py | 1 - darts/models/forecasting/dlinear.py | 1 - darts/models/forecasting/nbeats.py | 1 - darts/models/forecasting/nhits.py | 1 - darts/models/forecasting/rnn_model.py | 20 ++++++++++--------- darts/models/forecasting/tcn_model.py | 1 - darts/models/forecasting/tft_model.py | 1 - darts/models/forecasting/tide_model.py | 1 - .../forecasting/torch_forecasting_model.py | 1 - darts/models/forecasting/transformer_model.py | 1 - 10 files changed, 11 insertions(+), 18 deletions(-) diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index afd3cbf503..489a80453f 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -276,7 +276,6 @@ def encode_year(idx): "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. diff --git a/darts/models/forecasting/dlinear.py b/darts/models/forecasting/dlinear.py index 4673550a41..f4f013ad6e 100644 --- a/darts/models/forecasting/dlinear.py +++ b/darts/models/forecasting/dlinear.py @@ -373,7 +373,6 @@ def encode_year(idx): "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. diff --git a/darts/models/forecasting/nbeats.py b/darts/models/forecasting/nbeats.py index 76daceac7a..662fc3c9a5 100644 --- a/darts/models/forecasting/nbeats.py +++ b/darts/models/forecasting/nbeats.py @@ -698,7 +698,6 @@ def encode_year(idx): "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. diff --git a/darts/models/forecasting/nhits.py b/darts/models/forecasting/nhits.py index 94adc5f0ef..7b3ce2daa3 100644 --- a/darts/models/forecasting/nhits.py +++ b/darts/models/forecasting/nhits.py @@ -634,7 +634,6 @@ def encode_year(idx): "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index eeae3fd2c4..cea93c2cbd 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -33,12 +33,16 @@ def __init__( dropout: float = 0.0, **kwargs, ): - """PyTorch module implementing an RNN to be used in `RNNModel`. + """This class allows to create custom RNN modules that can later be used with Darts' `RNNModel`. + It adds the backbone that is required to be used with Darts' `TorchForecastingModel` and `RNNModel`. - PyTorch module implementing a simple RNN with the specified `name` type. - This module combines a PyTorch RNN module, together with one fully connected layer which - maps the hidden state of the RNN at each step to the output value of the model at that - time step. + To create a new module, create a subclass and: + + * Define the architecture in the module constructor (`__init__()`) + + * Add the `forward()` method and define the logic of your module's forward pass + + You can use `darts.models.forecasting.rnn_model._RNNModule` as an example. Parameters ---------- @@ -295,9 +299,8 @@ def __init__( input_chunk_length Number of past time steps that are fed to the forecasting module at prediction time. model - Either a string specifying the RNN module type ("RNN", "LSTM" or "GRU"), - or a PyTorch module with the same specifications as - `darts.models.rnn_model._RNNModule`. + Either a string specifying the RNN module type ("RNN", "LSTM" or "GRU"), or a subclass of + :class:`CustomRNNModule` (the class itself, not an object of the class) with a custom logic. hidden_dim Size for feature maps for each hidden RNN layer (:math:`h_n`). n_rnn_layers @@ -402,7 +405,6 @@ def encode_year(idx): "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. diff --git a/darts/models/forecasting/tcn_model.py b/darts/models/forecasting/tcn_model.py index 7f53b8781c..69119500f6 100644 --- a/darts/models/forecasting/tcn_model.py +++ b/darts/models/forecasting/tcn_model.py @@ -400,7 +400,6 @@ def encode_year(idx): dict:rgs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. diff --git a/darts/models/forecasting/tft_model.py b/darts/models/forecasting/tft_model.py index e6e503f974..2a1cada3e8 100644 --- a/darts/models/forecasting/tft_model.py +++ b/darts/models/forecasting/tft_model.py @@ -849,7 +849,6 @@ def encode_year(idx): "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. diff --git a/darts/models/forecasting/tide_model.py b/darts/models/forecasting/tide_model.py index 3005020268..604aef7eba 100644 --- a/darts/models/forecasting/tide_model.py +++ b/darts/models/forecasting/tide_model.py @@ -523,7 +523,6 @@ def encode_year(idx): "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index c37cc2e7e6..fc7f92dd4f 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -250,7 +250,6 @@ def encode_year(idx): "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index 775c2e05ca..7113b44336 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -487,7 +487,6 @@ def encode_year(idx): "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. From 69ad82b59d9943e5de424293c7ff104b51cb8bd5 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Sun, 3 Dec 2023 14:08:10 +0100 Subject: [PATCH 6/8] add custom block rnn module --- darts/models/forecasting/block_rnn_model.py | 161 ++++++++++++------ darts/models/forecasting/rnn_model.py | 13 +- .../models/forecasting/test_block_RNN.py | 79 ++++++++- 3 files changed, 187 insertions(+), 66 deletions(-) diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index 489a80453f..923f882f65 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -3,12 +3,14 @@ ------------------------------- """ -from typing import List, Optional, Tuple, Union +import inspect +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Type, Union import torch import torch.nn as nn -from darts.logging import get_logger, raise_if_not +from darts.logging import get_logger, raise_log from darts.models.forecasting.pl_forecasting_module import ( PLPastCovariatesModule, io_processor, @@ -18,11 +20,9 @@ logger = get_logger(__name__) -# TODO add batch norm -class _BlockRNNModule(PLPastCovariatesModule): +class CustomBlockRNNModule(PLPastCovariatesModule, ABC): def __init__( self, - name: str, input_size: int, hidden_dim: int, num_layers: int, @@ -32,24 +32,22 @@ def __init__( dropout: float = 0.0, **kwargs, ): + """This class allows to create custom block RNN modules that can later be used with Darts' + :class:`BlockRNNModel`. It adds the backbone that is required to be used with Darts' + :class:`TorchForecastingModel` and :class:`BlockRNNModel`. - """PyTorch module implementing a block RNN to be used in `BlockRNNModel`. + To create a new module, subclass from :class:`CustomBlockRNNModule` and: - PyTorch module implementing a simple block RNN with the specified `name` layer. - This module combines a PyTorch RNN module, together with a fully connected network, which maps the - last hidden layers to output of the desired size `output_chunk_length` and makes it compatible with - `BlockRNNModel`s. + * Define the architecture in the module constructor (`__init__()`) - This module uses an RNN to encode the input sequence, and subsequently uses a fully connected - network as the decoder which takes as input the last hidden state of the encoder RNN. - The final output of the decoder is a sequence of length `output_chunk_length`. In this sense, - the `_BlockRNNModule` produces 'blocks' of forecasts at a time (which is different - from `_RNNModule` used by the `RNNModel`). + * Add the `forward()` method and define the logic of your module's forward pass + + * Use the custom module class when creating a new :class:`BlockRNNModel` with parameter `model`. + + You can use `darts.models.forecasting.rnn_model._RNNModule` as an example. Parameters ---------- - name - The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM"). input_size The dimensionality of the input time series. hidden_dim @@ -78,28 +76,83 @@ def __init__( y of shape `(batch_size, output_chunk_length, target_size, nr_params)` Tensor containing the prediction at the last time step of the sequence. """ - super().__init__(**kwargs) # Defining parameters + self.input_size = input_size self.hidden_dim = hidden_dim - self.n_layers = num_layers + self.num_layers = num_layers self.target_size = target_size self.nr_params = nr_params - num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc + self.num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc + self.dropout = dropout self.out_len = self.output_chunk_length + + @io_processor + @abstractmethod + def forward(self, x_in: Tuple) -> torch.Tensor: + pass + + +# TODO add batch norm +class _BlockRNNModule(CustomBlockRNNModule): + def __init__( + self, + name: str, + **kwargs, + ): + + """PyTorch module implementing a block RNN to be used in `BlockRNNModel`. + + PyTorch module implementing a simple block RNN with the specified `name` layer. + This module combines a PyTorch RNN module, together with a fully connected network, which maps the + last hidden layers to output of the desired size `output_chunk_length` and makes it compatible with + `BlockRNNModel`s. + + This module uses an RNN to encode the input sequence, and subsequently uses a fully connected + network as the decoder which takes as input the last hidden state of the encoder RNN. + The final output of the decoder is a sequence of length `output_chunk_length`. In this sense, + the `_BlockRNNModule` produces 'blocks' of forecasts at a time (which is different + from `_RNNModule` used by the `RNNModel`). + + Parameters + ---------- + name + The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM"). + **kwargs + all parameters required for the :class:`darts.model.forecasting_models.CustomBlockRNNModule` base class. + + Inputs + ------ + x of shape `(batch_size, input_chunk_length, input_size, nr_params)` + Tensor containing the features of the input sequence. + + Outputs + ------- + y of shape `(batch_size, output_chunk_length, target_size, nr_params)` + Tensor containing the prediction at the last time step of the sequence. + """ + + super().__init__(**kwargs) + self.name = name # Defining the RNN module - self.rnn = getattr(nn, name)( - input_size, hidden_dim, num_layers, batch_first=True, dropout=dropout + self.rnn = getattr(nn, self.name)( + self.input_size, + self.hidden_dim, + self.num_layers, + batch_first=True, + dropout=self.dropout, ) # The RNN module is followed by a fully connected layer, which maps the last hidden layer # to the output of desired length - last = hidden_dim + last = self.hidden_dim feats = [] - for feature in num_layers_out_fc + [self.out_len * target_size * nr_params]: + for feature in self.num_layers_out_fc + [ + self.out_len * self.target_size * self.nr_params + ]: feats.append(nn.Linear(last, feature)) last = feature self.fc = nn.Sequential(*feats) @@ -131,7 +184,7 @@ def __init__( self, input_chunk_length: int, output_chunk_length: int, - model: Union[str, nn.Module] = "RNN", + model: Union[str, Type[CustomBlockRNNModule]] = "RNN", hidden_dim: int = 25, n_rnn_layers: int = 1, hidden_fc_sizes: Optional[List] = None, @@ -168,9 +221,8 @@ def __init__( the model from using future values of past and / or future covariates for prediction (depending on the model's covariate support). model - Either a string specifying the RNN module type ("RNN", "LSTM" or "GRU"), - or a PyTorch module with the same specifications as - :class:`darts.models.block_rnn_model._BlockRNNModule`. + Either a string specifying the RNN module type ("RNN", "LSTM" or "GRU"), or a subclass of + :class:`CustomBlockRNNModule` (the class itself, not an object of the class) with a custom logic. hidden_dim Size for feature maps for each hidden RNN layer (:math:`h_n`). In Darts version <= 0.21, hidden_dim was referred as hidden_size. @@ -356,14 +408,16 @@ def encode_year(idx): # check we got right model type specified: if model not in ["RNN", "LSTM", "GRU"]: - raise_if_not( - isinstance(model, nn.Module), - '{} is not a valid RNN model.\n Please specify "RNN", "LSTM", ' - '"GRU", or give your own PyTorch nn.Module'.format( - model.__class__.__name__ - ), - logger, - ) + if not inspect.isclass(model) or not issubclass( + model, CustomBlockRNNModule + ): + raise_log( + ValueError( + "`model` is not a valid RNN model. Please specify 'RNN', 'LSTM', 'GRU', or give a subclass " + "(not an instance) of darts.models.forecasting.rnn_model.CustomBlockRNNModule." + ), + logger=logger, + ) self.rnn_type_or_module = model self.hidden_fc_sizes = hidden_fc_sizes @@ -383,21 +437,22 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: output_dim = train_sample[-1].shape[1] nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters - if self.rnn_type_or_module in ["RNN", "LSTM", "GRU"]: - hidden_fc_sizes = ( - [] if self.hidden_fc_sizes is None else self.hidden_fc_sizes - ) - model = _BlockRNNModule( - name=self.rnn_type_or_module, - input_size=input_dim, - target_size=output_dim, - nr_params=nr_params, - hidden_dim=self.hidden_dim, - num_layers=self.n_rnn_layers, - num_layers_out_fc=hidden_fc_sizes, - dropout=self.dropout, - **self.pl_module_params, - ) + hidden_fc_sizes = [] if self.hidden_fc_sizes is None else self.hidden_fc_sizes + + kwargs = {} + if isinstance(self.rnn_type_or_module, str): + model_cls = _BlockRNNModule + kwargs["name"] = self.rnn_type_or_module else: - model = self.rnn_type_or_module - return model + model_cls = self.rnn_type_or_module + return model_cls( + input_size=input_dim, + target_size=output_dim, + nr_params=nr_params, + hidden_dim=self.hidden_dim, + num_layers=self.n_rnn_layers, + num_layers_out_fc=hidden_fc_sizes, + dropout=self.dropout, + **self.pl_module_params, + **kwargs, + ) diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index cea93c2cbd..d4d90a2bc1 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -33,15 +33,18 @@ def __init__( dropout: float = 0.0, **kwargs, ): - """This class allows to create custom RNN modules that can later be used with Darts' `RNNModel`. - It adds the backbone that is required to be used with Darts' `TorchForecastingModel` and `RNNModel`. + """This class allows to create custom RNN modules that can later be used with Darts' :class:`RNNModel`. + It adds the backbone that is required to be used with Darts' :class:`TorchForecastingModel` and + :class:`RNNModel`. - To create a new module, create a subclass and: + To create a new module, subclass from :class:`CustomRNNModule` and: * Define the architecture in the module constructor (`__init__()`) * Add the `forward()` method and define the logic of your module's forward pass + * Use the custom module class when creating a new :class:`RNNModel` with parameter `model`. + You can use `darts.models.forecasting.rnn_model._RNNModule` as an example. Parameters @@ -205,7 +208,7 @@ def __init__( name The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM"). **kwargs - all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class. + all parameters required for the :class:`darts.model.forecasting_models.CustomRNNModule` base class. Inputs ------ @@ -496,7 +499,7 @@ def encode_year(idx): if not inspect.isclass(model) or not issubclass(model, CustomRNNModule): raise_log( ValueError( - "`model` is not a valid RNN model. Please specify RNN', 'LSTM', 'GRU', or give a subclass " + "`model` is not a valid RNN model. Please specify 'RNN', 'LSTM', 'GRU', or give a subclass " "(not an instance) of darts.models.forecasting.rnn_model.CustomRNNModule." ), logger=logger, diff --git a/darts/tests/models/forecasting/test_block_RNN.py b/darts/tests/models/forecasting/test_block_RNN.py index 676252d18c..1aa8a6ff2d 100644 --- a/darts/tests/models/forecasting/test_block_RNN.py +++ b/darts/tests/models/forecasting/test_block_RNN.py @@ -9,7 +9,13 @@ logger = get_logger(__name__) try: - from darts.models.forecasting.block_rnn_model import BlockRNNModel, _BlockRNNModule + import torch.nn as nn + + from darts.models.forecasting.block_rnn_model import ( + BlockRNNModel, + CustomBlockRNNModule, + _BlockRNNModule, + ) TORCH_AVAILABLE = True except ImportError: @@ -19,11 +25,28 @@ if TORCH_AVAILABLE: + class ModuleValid1(_BlockRNNModule): + """Wrapper around the _BlockRNNModule""" + + def __init__(self, **kwargs): + super().__init__(name="RNN", **kwargs) + + class ModuleValid2(CustomBlockRNNModule): + """Just a linear layer.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.linear = nn.Linear(self.input_size, self.target_size) + + def forward(self, x_in): + x = self.linear(x_in[0]) + return x.view(len(x), -1, self.target_size, self.nr_params) + class TestBlockRNNModel: times = pd.date_range("20130101", "20130410") pd_series = pd.Series(range(100), index=times) series: TimeSeries = TimeSeries.from_series(pd_series) - module = _BlockRNNModule( + module_invalid = _BlockRNNModule( "RNN", input_size=1, input_chunk_length=1, @@ -37,19 +60,59 @@ class TestBlockRNNModel: ) def test_creation(self): - with pytest.raises(ValueError): - # cannot choose any string + # cannot choose any string + with pytest.raises(ValueError) as msg: BlockRNNModel( input_chunk_length=1, output_chunk_length=1, model="UnknownRNN?" ) - # can give a custom module + assert str(msg.value).startswith("`model` is not a valid RNN model.") + + # cannot create from a class instance + with pytest.raises(ValueError) as msg: + _ = BlockRNNModel( + input_chunk_length=1, + output_chunk_length=1, + model=self.module_invalid, + ) + assert str(msg.value).startswith("`model` is not a valid RNN model.") + + # can create from valid module name model1 = BlockRNNModel( - input_chunk_length=1, output_chunk_length=1, model=self.module + input_chunk_length=1, + output_chunk_length=1, + model="RNN", + n_epochs=1, + random_state=42, + **tfm_kwargs ) + model1.fit(self.series) + preds1 = model1.predict(n=3) + + # can create from a custom class itself model2 = BlockRNNModel( - input_chunk_length=1, output_chunk_length=1, model="RNN" + input_chunk_length=1, + output_chunk_length=1, + model=ModuleValid1, + n_epochs=1, + random_state=42, + **tfm_kwargs ) - assert model1.model.__repr__() == model2.model.__repr__() + model2.fit(self.series) + preds2 = model2.predict(n=3) + np.testing.assert_array_equal(preds1.all_values(), preds2.all_values()) + + model3 = BlockRNNModel( + input_chunk_length=1, + output_chunk_length=1, + model=ModuleValid2, + n_epochs=1, + random_state=42, + **tfm_kwargs + ) + model3.fit(self.series) + preds3 = model2.predict(n=3) + assert preds3.all_values().shape == preds2.all_values().shape + assert preds3.time_index.equals(preds2.time_index) def test_fit(self, tmpdir_module): # Test basic fit() From facbc4e9a5561bcc069c5c576def77002d7eeedf Mon Sep 17 00:00:00 2001 From: dennisbader Date: Sun, 3 Dec 2023 14:23:31 +0100 Subject: [PATCH 7/8] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2d90f1374..fdf92e0e88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co ### For users of the library: **Improved** +- 🔴 Added `CustomRNNModule` and `CustomBlockRNNModule` for defining custom RNN modules that can be used with `RNNModel` and `BlockRNNModel`. The custom `model` must now be a subclass of the custom modules. [#2088](https://github.com/unit8co/darts/pull/2088) by [Dennis Bader](https://github.com/dennisbader). **Fixed** - Fixed an import error when trying to create a `TorchForecastingModel` with PyTorch Lightning v<2.0.0. [#2087](https://github.com/unit8co/darts/pull/2087) by [Eschibli](https://github.com/eschibli). From 4c7ed2e130b3dc31cd31659d25681270cd08943e Mon Sep 17 00:00:00 2001 From: dennisbader Date: Sun, 10 Dec 2023 13:42:43 +0100 Subject: [PATCH 8/8] fix docs for BlockRNNModel --- darts/models/forecasting/block_rnn_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index edb5c248ea..db41dd9455 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -44,7 +44,7 @@ def __init__( * Use the custom module class when creating a new :class:`BlockRNNModel` with parameter `model`. - You can use `darts.models.forecasting.rnn_model._RNNModule` as an example. + You can use `darts.models.forecasting.block_rnn_model._BlockRNNModule` as an example. Parameters ----------