Feat/pytorch lightning #702
Merged
Changes from 3 commits
Commits
59 commits
f7afe81
first draft proposal for pytorch lightning integration
dennisbader c026bf7
further cleanup
dennisbader dbef703
removed unnecessary file
dennisbader 706fd35
fix for multiple-TS
dennisbader 3ee8873
moved prediction timeseries generation back to TorchForecastingModel
dennisbader 93b4bf2
support for custom trainer in fit()
dennisbader 8c1b231
removed unused methods from TorchForecastingModel
dennisbader 333cd5f
checkpoint loading now correctly resumes training
dennisbader ed42108
Merge branch 'master' into feat/pytorch_lightning
dennisbader facba8c
Merge branch 'master' into feat/pytorch_lightning
dennisbader df6b8d5
rewrote TorchForecastingModel
dennisbader 9939422
rewrote TFTModel
dennisbader 7d3f24d
rewrote rnn models
dennisbader e3bb9c2
rewrote nbeats models
dennisbader 4cb7ec5
rewrote tcn model
dennisbader 0ec8245
rewrote transformer model
dennisbader 8408273
removed unused import
dennisbader 2b36b84
resolve failing tests part 1
dennisbader d4db950
resolve failing tests part 2
dennisbader d3350d3
adapted the way how model parameters are saved
dennisbader 1c16b72
moved TFTModel predict method into TorchForecastingModel subclass
dennisbader 51b38b1
further simplification of model calls
dennisbader 2905024
integrated ProbabilisticTorchForecastingModel into PLForecastingModule
dennisbader 21ee164
integrated _produce_predict_output into PLForecastingModule
dennisbader 8c3e94c
reintegrated original random state handling
dennisbader 75f7194
removed unused pl random state wrapper function
dennisbader 7102c6d
use OrderedDict for savety in model parameter extraction
dennisbader f48b61c
made TFM and PLFM paramater extraction generic
dennisbader 7e7eccf
added types for variables in TFM init
dennisbader 93ec255
made predictions deterministic for same fit predict process for non-l…
dennisbader 83e67d6
Merge branch 'master' into feat/pytorch_lightning
dennisbader 0774c6e
fix flake8 issues
dennisbader e196c97
fix flake8 issues part 2
dennisbader 4c983fb
added pytorch-lightning to torch requirements
dennisbader 57d11cf
Merge branch 'master' into feat/pytorch_lightning
dennisbader 6277292
fixed loading models with wrong precision
dennisbader 84cbd61
fixed is_probabilistic()
dennisbader 4e0c8b5
fixed failing tests for epoch count tracker
dennisbader 1ddd5ec
Merge branch 'master' into feat/pytorch_lightning
dennisbader 4c3e268
removed input/output_chunk_length from TorchForecastingModel __init__
dennisbader 2eb240f
unit tests save models to temp dir
dennisbader 2dd7841
added documentation for ModelMeta
dennisbader 984ea88
apply suggestions from PR review part 1
dennisbader 771fbdf
deprecated `torch_device_str`
dennisbader 634524f
updated optimizer docs
dennisbader d6274a8
updated retrain warning
dennisbader fbf05d2
Merge branch 'master' into feat/pytorch_lightning
dennisbader 0ba7f65
made PLMixedCovariatesModule more generic
dennisbader 970db09
added docs
dennisbader 11e7681
added PTL trainer unit tests
dennisbader 7a084ac
update model docs
dennisbader c8b4bff
fixed broken url in TFM and covariates userguide
dennisbader f0e4e30
removed input/output chunk length from PL modules
dennisbader 3b7f846
relaxed pytorch-lightning requirement
dennisbader 7290775
Merge branch 'master' into feat/pytorch_lightning
dennisbader 43803b9
isort
dennisbader fa5a6db
Merge branch 'master' into feat/pytorch_lightning
dennisbader 73e3ff4
isort part 2
dennisbader 0d6c105
Merge branch 'master' into feat/pytorch_lightning
dennisbader
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,383 @@ | ||
""" | ||
This file contains abstract classes for deterministic and probabilistic pytorch-lightning modules | ||
""" | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from joblib import Parallel, delayed | ||
from typing import Any, Optional, Dict, Tuple, Union, Sequence | ||
from abc import ABC, abstractmethod | ||
import torch | ||
from torch import Tensor | ||
import torch.nn as nn | ||
|
||
from darts.timeseries import TimeSeries | ||
from darts.utils.timeseries_generation import _generate_index | ||
from darts.models.forecasting.ptl_torch_forecasting_model import TorchForecastingModel | ||
|
||
from darts.utils.likelihood_models import Likelihood | ||
from darts.logging import get_logger, raise_log, raise_if | ||
|
||
import pytorch_lightning as pl | ||
|
||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
# TODO: better names | ||
class PLTorchForecastingModel(pl.LightningModule, ABC): | ||
|
||
@abstractmethod | ||
def __init__(self, | ||
loss_fn: nn.modules.loss._Loss = nn.MSELoss(), | ||
optimizer_cls: torch.optim.Optimizer = torch.optim.Adam, | ||
optimizer_kwargs: Optional[Dict] = None, | ||
lr_scheduler_cls: torch.optim.lr_scheduler._LRScheduler = None, | ||
lr_scheduler_kwargs: Optional[Dict] = None) -> None: | ||
|
||
super(PLTorchForecastingModel, self).__init__() | ||
|
||
# Define the loss function | ||
self.criterion = loss_fn | ||
|
||
# Persist optimiser and LR scheduler parameters | ||
self.optimizer_cls = optimizer_cls | ||
self.optimizer_kwargs = dict() if optimizer_kwargs is None else optimizer_kwargs | ||
self.lr_scheduler_cls = lr_scheduler_cls | ||
self.lr_scheduler_kwargs = dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs | ||
|
||
# by default models are deterministic (i.e. not probabilistic) | ||
self.likelihood = None | ||
|
||
# TODO: make better | ||
# initialize prediction settings | ||
self.pred_n: Optional[int] = None | ||
self.pred_num_samples: Optional[int] = None | ||
self.pred_n_jobs: Optional[int] = None | ||
self.pred_roll_size: Optional[int] = None | ||
self.pred_batch_size: Optional[int] = None | ||
|
||
@property | ||
def first_prediction_index(self) -> int: | ||
""" | ||
Returns the index of the first predicted value within the output of self.model. | ||
""" | ||
return 0 | ||
|
||
@abstractmethod | ||
def forward(self, *args, **kwargs) -> Any: | ||
super(PLTorchForecastingModel, self).forward(*args, **kwargs) | ||
|
||
def training_step(self, train_batch, batch_idx) -> Any: | ||
output = self._produce_train_output(train_batch[:-1]) | ||
target = train_batch[-1] # By convention target is always the last element returned by datasets | ||
loss = self._compute_loss(output, target) | ||
self.log('train_loss', loss, batch_size=train_batch[0].shape[0]) | ||
return loss | ||
|
||
def validation_step(self, val_batch, batch_idx) -> Any: | ||
output = self._produce_train_output(val_batch[:-1]) | ||
target = val_batch[-1] | ||
loss = self._compute_loss(output, target) | ||
self.log('val_loss', loss, batch_size=val_batch[0].shape[0]) | ||
return loss | ||
|
||
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: | ||
input_data_tuple, batch_input_series = batch[:-1], batch[-1] | ||
|
||
# number of individual series to be predicted in current batch | ||
num_series = input_data_tuple[0].shape[0] | ||
|
||
# number of times the input tensor should be tiled to produce predictions for multiple samples | ||
# this variable is larger than 1 only if the batch_size is at least twice as large as the number | ||
# of individual time series being predicted in current batch (`num_series`) | ||
batch_sample_size = min(max(self.pred_batch_size // num_series, 1), self.pred_num_samples) | ||
|
||
# counts number of produced prediction samples for every series to be predicted in current batch | ||
sample_count = 0 | ||
|
||
# repeat prediction procedure for every needed sample | ||
batch_predictions = [] | ||
while sample_count < self.pred_num_samples: | ||
|
||
# make sure we don't produce too many samples | ||
if sample_count + batch_sample_size > self.pred_num_samples: | ||
batch_sample_size = self.pred_num_samples - sample_count | ||
|
||
# stack multiple copies of the tensors to produce probabilistic forecasts | ||
input_data_tuple_samples = self._sample_tiling(input_data_tuple, batch_sample_size) | ||
|
||
# get predictions for 1 whole batch (can include predictions of multiple series | ||
# and for multiple samples if a probabilistic forecast is produced) | ||
batch_prediction = self._get_batch_prediction(self.pred_n, input_data_tuple_samples, self.pred_roll_size) | ||
|
||
# reshape from 3d tensor (num_series x batch_sample_size, ...) | ||
# into 4d tensor (batch_sample_size, num_series, ...), where dim 0 represents the samples | ||
out_shape = batch_prediction.shape | ||
batch_prediction = batch_prediction.reshape((batch_sample_size, num_series,) + out_shape[1:]) | ||
|
||
# save all predictions and update the `sample_count` variable | ||
batch_predictions.append(batch_prediction) | ||
sample_count += batch_sample_size | ||
|
||
# concatenate the batch of samples, to form self.pred_num_samples samples | ||
batch_predictions = torch.cat(batch_predictions, dim=0) | ||
batch_predictions = batch_predictions.cpu().detach().numpy() | ||
|
||
# create `TimeSeries` objects from prediction tensors | ||
ts_forecasts = Parallel(n_jobs=self.pred_n_jobs)( | ||
delayed(self._build_forecast_series)( | ||
[batch_prediction[batch_idx] for batch_prediction in batch_predictions], input_series | ||
) | ||
for batch_idx, input_series in enumerate(batch_input_series) | ||
) | ||
return ts_forecasts | ||
|
||
def on_predict_end(self) -> None: | ||
self.pred_n = None | ||
self.pred_num_samples = None | ||
self.pred_n_jobs = None | ||
self.pred_roll_size = None | ||
self.pred_batch_size = None | ||
|
||
def _compute_loss(self, output, target): | ||
return self.criterion(output, target) | ||
|
||
def configure_optimizers(self): | ||
"""sets up optimizers""" | ||
|
||
# TODO: I think we can move this to pl.Trainer(), and it could probably be simplified | ||
|
||
# A utility function to create optimizer and lr scheduler from desired classes | ||
def _create_from_cls_and_kwargs(cls, kws): | ||
try: | ||
return cls(**kws) | ||
except (TypeError, ValueError) as e: | ||
raise_log(ValueError('Error when building the optimizer or learning rate scheduler; ' | ||
'please check the provided class and arguments' | ||
'\nclass: {}' | ||
'\narguments (kwargs): {}' | ||
'\nerror:\n{}'.format(cls, kws, e)), | ||
logger) | ||
|
||
# Create the optimizer and (optionally) the learning rate scheduler | ||
# we have to create copies because we cannot save model.parameters into object state (not serializable) | ||
optimizer_kws = {k: v for k, v in self.optimizer_kwargs.items()} | ||
optimizer_kws['params'] = self.parameters() | ||
|
||
optimizer = _create_from_cls_and_kwargs(self.optimizer_cls, optimizer_kws) | ||
|
||
if self.lr_scheduler_cls is not None: | ||
lr_sched_kws = {k: v for k, v in self.lr_scheduler_kwargs.items()} | ||
lr_sched_kws['optimizer'] = optimizer | ||
lr_scheduler = _create_from_cls_and_kwargs(self.lr_scheduler_cls, lr_sched_kws) | ||
return [optimizer], [lr_scheduler] | ||
else: | ||
return optimizer | ||
|
||
@abstractmethod | ||
def _produce_train_output(self, input_batch: Tuple) -> Tensor: | ||
pass | ||
|
||
@abstractmethod | ||
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor: | ||
""" | ||
In charge of applying the recurrent logic for non-recurrent models. | ||
Should be overridden by recurrent models. | ||
""" | ||
pass | ||
|
||
# TODO: had to copy these methods over from ForecastingModel as it is not a parent class. | ||
# Maybe let TorchForecastingModel handle the _build_forecast_series() | ||
def _build_forecast_series(self, | ||
points_preds: Union[np.ndarray, Sequence[np.ndarray]], | ||
input_series: TimeSeries) -> TimeSeries: | ||
""" | ||
Builds a forecast time series starting after the end of the training time series, with the | ||
correct time index (or after the end of the input series, if specified). | ||
""" | ||
|
||
time_index_length = len(points_preds) if isinstance(points_preds, np.ndarray) else len(points_preds[0]) | ||
time_index = self._generate_new_dates(time_index_length, input_series=input_series) | ||
if isinstance(points_preds, np.ndarray): | ||
return TimeSeries.from_times_and_values(time_index, | ||
points_preds, | ||
freq=input_series.freq_str, | ||
columns=input_series.columns) | ||
|
||
return TimeSeries.from_times_and_values(time_index, | ||
np.stack(points_preds, axis=2), | ||
freq=input_series.freq_str, | ||
columns=input_series.columns) | ||
|
||
@staticmethod | ||
def _generate_new_dates(n: int, | ||
input_series: TimeSeries) -> Union[pd.DatetimeIndex, pd.RangeIndex]: | ||
""" | ||
Generates `n` new dates after the end of the specified series | ||
""" | ||
last = input_series.end_time() | ||
start = last + input_series.freq if input_series.has_datetime_index else last + 1 | ||
return _generate_index(start=start, freq=input_series.freq, length=n) | ||
|
||
@staticmethod | ||
def _sample_tiling(input_data_tuple, batch_sample_size): | ||
tiled_input_data = [] | ||
for tensor in input_data_tuple: | ||
if tensor is not None: | ||
tiled_input_data.append(tensor.tile((batch_sample_size, 1, 1))) | ||
else: | ||
tiled_input_data.append(None) | ||
return tuple(tiled_input_data) | ||
|
||
|
||
class PLPastCovariatesTorchModel(PLTorchForecastingModel, ABC): | ||
def _produce_train_output(self, input_batch: Tuple): | ||
past_target, past_covariate = input_batch | ||
# Currently all our PastCovariates models require past target and covariates concatenated | ||
inpt = torch.cat([past_target, past_covariate], dim=2) if past_covariate is not None else past_target | ||
return self.model(inpt) | ||
|
||
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> torch.Tensor: | ||
""" | ||
Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset to forecast | ||
the next `n` target values per target variable. | ||
|
||
Parameters | ||
---------- | ||
n | ||
prediction length | ||
input_batch | ||
(past_target, past_covariates, future_past_covariates) | ||
roll_size | ||
roll input arrays after every sequence by `roll_size`. Initially, `roll_size` is equivalent to | ||
`self.output_chunk_length` | ||
""" | ||
dim_component = 2 | ||
past_target, past_covariates, future_past_covariates = input_batch | ||
|
||
n_targets = past_target.shape[dim_component] | ||
n_past_covs = past_covariates.shape[dim_component] if past_covariates is not None else 0 | ||
|
||
input_past = torch.cat( | ||
[ds for ds in [past_target, past_covariates] if ds is not None], | ||
dim=dim_component | ||
) | ||
|
||
out = self._produce_predict_output(input_past)[:, self.first_prediction_index:, :] | ||
|
||
batch_prediction = [out[:, :roll_size, :]] | ||
prediction_length = roll_size | ||
|
||
while prediction_length < n: | ||
# we want the last prediction to end exactly at `n` into the future. | ||
# this means we may have to truncate the previous prediction and step | ||
# back the roll size for the last chunk | ||
if prediction_length + self.output_chunk_length > n: | ||
spillover_prediction_length = prediction_length + self.output_chunk_length - n | ||
roll_size -= spillover_prediction_length | ||
prediction_length -= spillover_prediction_length | ||
batch_prediction[-1] = batch_prediction[-1][:, :roll_size, :] | ||
|
||
# ==========> PAST INPUT <========== | ||
# roll over input series to contain latest target and covariate | ||
input_past = torch.roll(input_past, -roll_size, 1) | ||
|
||
# update target input to include next `roll_size` predictions | ||
if self.input_chunk_length >= roll_size: | ||
input_past[:, -roll_size:, :n_targets] = out[:, :roll_size, :] | ||
else: | ||
input_past[:, :, :n_targets] = out[:, -self.input_chunk_length:, :] | ||
|
||
# set left and right boundaries for extracting future elements | ||
if self.input_chunk_length >= roll_size: | ||
left_past, right_past = prediction_length - roll_size, prediction_length | ||
else: | ||
left_past, right_past = prediction_length - self.input_chunk_length, prediction_length | ||
|
||
# update past covariates to include next `roll_size` future past covariates elements | ||
if n_past_covs and self.input_chunk_length >= roll_size: | ||
input_past[:, -roll_size:, n_targets:n_targets + n_past_covs] = ( | ||
future_past_covariates[:, left_past:right_past, :] | ||
) | ||
elif n_past_covs: | ||
input_past[:, :, n_targets:n_targets + n_past_covs] = ( | ||
future_past_covariates[:, left_past:right_past, :] | ||
) | ||
|
||
# take only last part of the output sequence where needed | ||
out = self._produce_predict_output(input_past)[:, self.first_prediction_index:, :] | ||
batch_prediction.append(out) | ||
prediction_length += self.output_chunk_length | ||
|
||
# bring predictions into desired format and drop unnecessary values | ||
batch_prediction = torch.cat(batch_prediction, dim=1) | ||
batch_prediction = batch_prediction[:, :n, :] | ||
return batch_prediction | ||
|
||
def _produce_predict_output(self, x): | ||
return self.model(x) | ||
|
||
|
||
class PLFutureCovariatesTorchModel(PLTorchForecastingModel, ABC): | ||
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor: | ||
raise NotImplementedError("TBD: Darts doesn't contain such a model yet.") | ||
|
||
|
||
class PLDualCovariatesTorchModel(PLTorchForecastingModel, ABC): | ||
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor: | ||
raise NotImplementedError("TBD: The only DualCovariatesModel is an RNN with a specific implementation.") | ||
|
||
|
||
class PLMixedCovariatesTorchModel(PLTorchForecastingModel, ABC): | ||
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor: | ||
raise NotImplementedError("TBD: Darts doesn't contain such a model yet.") | ||
|
||
|
||
class PLSplitCovariatesTorchModel(PLTorchForecastingModel, ABC): | ||
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor: | ||
raise NotImplementedError("TBD: Darts doesn't contain such a model yet.") | ||
|
||
|
||
# TODO: I think we could actually integrate probabilistic support already in the parent class and remove it from here? | ||
dennisbader marked this conversation as resolved.
|
||
class PLTorchParametricProbabilisticForecastingModel(PLTorchForecastingModel, ABC): | ||
def __init__(self, likelihood: Optional[Likelihood] = None, **kwargs): | ||
""" Pytorch Parametric Probabilistic Forecasting Model. | ||
|
||
This is a base class for PyTorch parametric probabilistic models. "Parametric" | ||
means that these models are based on some predefined parametric distribution, say Gaussian. | ||
Make sure that subclasses accept the *likelihood* parameter in their __init__ method | ||
and pass it to the superclass by calling super().__init__. If no likelihood is | ||
provided, the model is considered deterministic. | ||
|
||
All PLTorchParametricProbabilisticForecastingModel subclasses must produce outputs of shape | ||
(batch_size, n_timesteps, n_components, n_params). I.e., there's an extra dimension | ||
to store the distribution's parameters. | ||
|
||
Parameters | ||
---------- | ||
likelihood | ||
The likelihood model to be used for probabilistic forecasts. | ||
""" | ||
super().__init__(**kwargs) | ||
self.likelihood = likelihood | ||
|
||
def _is_probabilistic(self): | ||
return self.likelihood is not None | ||
|
||
def _compute_loss(self, output, target): | ||
# output is of shape (batch_size, n_timesteps, n_components, n_params) | ||
if self.likelihood: | ||
return self.likelihood.compute_loss(output, target) | ||
else: | ||
# If there's no likelihood, n_params=1 and we need to squeeze out the | ||
# last dimension of model output, for properly computing the loss. | ||
return super()._compute_loss(output.squeeze(dim=-1), target) | ||
|
||
@abstractmethod | ||
def _produce_predict_output(self, x): | ||
""" | ||
This method has to be implemented by all children. | ||
""" | ||
pass |
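For readers following the sample loop in `predict_step` in the diff above, the chunking arithmetic can be reproduced without any tensors. The following is a minimal sketch, not code from this PR: `sample_chunk_sizes` is a hypothetical helper name, and its three arguments stand in for `self.pred_batch_size`, `num_series` and `self.pred_num_samples`. It assumes all three are positive integers.

```python
from typing import List


def sample_chunk_sizes(pred_batch_size: int, num_series: int, num_samples: int) -> List[int]:
    """Return how many samples each pass of the while-loop would produce.

    Hypothetical helper mirroring the loop in predict_step; not part of the PR.
    """
    # same expression as in predict_step: tile as often as fits into one batch,
    # but at least once and never more often than the requested sample count
    batch_sample_size = min(max(pred_batch_size // num_series, 1), num_samples)

    sizes = []
    sample_count = 0
    while sample_count < num_samples:
        # shrink the last chunk so the total is exactly num_samples
        if sample_count + batch_sample_size > num_samples:
            batch_sample_size = num_samples - sample_count
        sizes.append(batch_sample_size)
        sample_count += batch_sample_size
    return sizes


print(sample_chunk_sizes(32, 10, 7))  # [3, 3, 1]
```

With a batch size of 32 and 10 series, three tiled copies fit into one forward pass, so seven samples take three passes; when the batch size is smaller than the number of series, the loop falls back to one sample per pass.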
+1
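The spillover handling in `PLPastCovariatesTorchModel._get_batch_prediction` is easier to check when only chunk lengths are tracked. Below is a rough sketch under stated assumptions: `rolled_chunk_lengths` is a hypothetical name, `roll_size` starts out equal to `output_chunk_length` (as the docstring in the diff describes), `first_prediction_index` is 0, and `n >= 1`. It tracks the time-axis length of each entry in `batch_prediction` instead of real tensors.

```python
from typing import List


def rolled_chunk_lengths(n: int, output_chunk_length: int) -> List[int]:
    """Lengths of the prediction chunks that end up in the final forecast.

    Hypothetical helper; it follows the control flow of _get_batch_prediction
    but replaces every tensor by its length along the time axis.
    """
    roll_size = output_chunk_length
    chunks = [roll_size]  # first model call, sliced to roll_size
    prediction_length = roll_size

    while prediction_length < n:
        # the last chunk must end exactly at n, so the previous chunk is
        # truncated and the roll size is stepped back by the spillover
        if prediction_length + output_chunk_length > n:
            spillover = prediction_length + output_chunk_length - n
            roll_size -= spillover
            prediction_length -= spillover
            chunks[-1] = min(chunks[-1], roll_size)

        # every later model call contributes a full output chunk
        chunks.append(output_chunk_length)
        prediction_length += output_chunk_length

    # equivalent of the final batch_prediction[:, :n, :] slice
    excess = sum(chunks) - n
    if excess > 0:
        chunks[-1] -= excess
    return chunks


print(rolled_chunk_lengths(10, 4))  # [4, 2, 4]
```

For `n=10` and an output chunk length of 4, the second chunk is cut to 2 steps so that the third, full-length chunk ends exactly at step 10. The final slice only matters when `n` is shorter than a single output chunk.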