Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/pytorch lightning #702

Merged
merged 59 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
f7afe81
first draft proposal for pytorch lightning integration
dennisbader Dec 17, 2021
c026bf7
further cleanup
dennisbader Dec 20, 2021
dbef703
removed unnecessary file
dennisbader Dec 21, 2021
706fd35
fix for multiple-TS
dennisbader Dec 22, 2021
3ee8873
moved prediction timeseries generation back to TorchForecastingModel
dennisbader Dec 22, 2021
93b4bf2
support for custom trainer in fit()
dennisbader Dec 22, 2021
8c1b231
removed unused methods from TorchForecastingModel
dennisbader Dec 24, 2021
333cd5f
checkpoint loading now correctly resumes training
dennisbader Jan 11, 2022
ed42108
Merge branch 'master' into feat/pytorch_lightning
dennisbader Jan 12, 2022
facba8c
Merge branch 'master' into feat/pytorch_lightning
dennisbader Jan 15, 2022
df6b8d5
rewrote TorchForecastingModel
dennisbader Jan 15, 2022
9939422
rewrote TFTModel
dennisbader Jan 15, 2022
7d3f24d
rewrote rnn models
dennisbader Jan 15, 2022
e3bb9c2
rewrote nbeats models
dennisbader Jan 15, 2022
4cb7ec5
rewrote tcn model
dennisbader Jan 15, 2022
0ec8245
rewrote transformer model
dennisbader Jan 15, 2022
8408273
removed unused import
dennisbader Jan 15, 2022
2b36b84
resolve failing tests part 1
dennisbader Jan 17, 2022
d4db950
resolve failing tests part 2
dennisbader Jan 18, 2022
d3350d3
adapted the way how model parameters are saved
dennisbader Jan 23, 2022
1c16b72
moved TFTModel predict method into TorchForecastingModel subclass
dennisbader Jan 23, 2022
51b38b1
further simplification of model calls
dennisbader Jan 23, 2022
2905024
integrated ProbabilisticTorchForecastingModel into PLForecastingModule
dennisbader Jan 23, 2022
21ee164
integrated _produce_predict_output into PLForecastingModule
dennisbader Jan 23, 2022
8c3e94c
reintegrated original random state handling
dennisbader Jan 23, 2022
75f7194
removed unused pl random state wrapper function
dennisbader Jan 23, 2022
7102c6d
use OrderedDict for savety in model parameter extraction
dennisbader Jan 23, 2022
f48b61c
made TFM and PLFM paramater extraction generic
dennisbader Jan 29, 2022
7e7eccf
added types for variables in TFM init
dennisbader Jan 29, 2022
93ec255
made predictions deterministic for same fit predict process for non-l…
dennisbader Jan 29, 2022
83e67d6
Merge branch 'master' into feat/pytorch_lightning
dennisbader Jan 29, 2022
0774c6e
fix flake8 issues
dennisbader Jan 29, 2022
e196c97
fix flake8 issues part 2
dennisbader Jan 29, 2022
4c983fb
added pytorch-lightning to torch requirements
dennisbader Feb 2, 2022
57d11cf
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 2, 2022
6277292
fixed loading models with wrong precision
dennisbader Feb 5, 2022
84cbd61
fixed is_probabilistic()
dennisbader Feb 5, 2022
4e0c8b5
fixed failing tests for epoch count tracker
dennisbader Feb 6, 2022
1ddd5ec
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 6, 2022
4c3e268
removed input/output_chunk_length from TorchForecastingModel __init__
dennisbader Feb 9, 2022
2eb240f
unit tests save models to temp dir
dennisbader Feb 9, 2022
2dd7841
added documentation for ModelMeta
dennisbader Feb 9, 2022
984ea88
apply suggestions from PR review part 1
dennisbader Feb 9, 2022
771fbdf
deprecated `torch_device_str`
dennisbader Feb 9, 2022
634524f
updated optimizer docs
dennisbader Feb 9, 2022
d6274a8
updated retrain warning
dennisbader Feb 9, 2022
fbf05d2
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 13, 2022
0ba7f65
made PLMixedCovariatesModule more generic
dennisbader Feb 13, 2022
970db09
added docs
dennisbader Feb 13, 2022
11e7681
added PTL trainer unit tests
dennisbader Feb 13, 2022
7a084ac
update model docs
dennisbader Feb 13, 2022
c8b4bff
fixed broken url in TFM and covariates userguide
dennisbader Feb 13, 2022
f0e4e30
removed input/output chunk length from PL modules
dennisbader Feb 13, 2022
3b7f846
relaxed pytorch-lightning requirement
dennisbader Feb 13, 2022
7290775
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 13, 2022
43803b9
isort
dennisbader Feb 13, 2022
fa5a6db
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 15, 2022
73e3ff4
isort part 2
dennisbader Feb 15, 2022
0d6c105
Merge branch 'master' into feat/pytorch_lightning
dennisbader Feb 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
383 changes: 383 additions & 0 deletions darts/models/forecasting/ptl_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,383 @@
"""
This file contains abstract classes for deterministic and probabilistic pytorch-lightning modules
"""

import numpy as np
import pandas as pd

from joblib import Parallel, delayed
from typing import Any, Optional, Dict, Tuple, Union, Sequence
from abc import ABC, abstractmethod
import torch
from torch import Tensor
import torch.nn as nn

from darts.timeseries import TimeSeries
from darts.utils.timeseries_generation import _generate_index
from darts.models.forecasting.ptl_torch_forecasting_model import TorchForecastingModel

from darts.utils.likelihood_models import Likelihood
from darts.logging import get_logger, raise_log, raise_if

import pytorch_lightning as pl


logger = get_logger(__name__)


# TODO: better names
class PLTorchForecastingModel(pl.LightningModule, ABC):

@abstractmethod
def __init__(self,
loss_fn: nn.modules.loss._Loss = nn.MSELoss(),
optimizer_cls: torch.optim.Optimizer = torch.optim.Adam,
optimizer_kwargs: Optional[Dict] = None,
lr_scheduler_cls: torch.optim.lr_scheduler._LRScheduler = None,
lr_scheduler_kwargs: Optional[Dict] = None) -> None:

super(PLTorchForecastingModel, self).__init__()

# Define the loss function
self.criterion = loss_fn

# Persist optimiser and LR scheduler parameters
self.optimizer_cls = optimizer_cls
self.optimizer_kwargs = dict() if optimizer_kwargs is None else optimizer_kwargs
self.lr_scheduler_cls = lr_scheduler_cls
self.lr_scheduler_kwargs = dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs

# by default models are deterministic (i.e. not probabilistic)
self.likelihood = None

# TODO: make better
# initialize prediction settings
self.pred_n: Optional[int] = None
self.pred_num_samples: Optional[int] = None
self.pred_n_jobs: Optional[int] = None
self.pred_roll_size: Optional[int] = None
self.pred_batch_size: Optional[int] = None

@property
def first_prediction_index(self) -> int:
"""
Returns the index of the first predicted within the output of self.model.
"""
return 0

@abstractmethod
def forward(self, *args, **kwargs) -> Any:
super(PLTorchForecastingModel, self).forward(*args, **kwargs)

def training_step(self, train_batch, batch_idx) -> Any:
output = self._produce_train_output(train_batch[:-1])
target = train_batch[-1] # By convention target is always the last element returned by datasets
loss = self._compute_loss(output, target)
self.log('train_loss', loss, batch_size=train_batch[0].shape[0])
return loss

def validation_step(self, val_batch, batch_idx) -> Any:
output = self._produce_train_output(val_batch[:-1])
target = val_batch[-1]
loss = self._compute_loss(output, target)
self.log('val_loss', loss, batch_size=val_batch[0].shape[0])
return loss

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
input_data_tuple, batch_input_series = batch[:-1], batch[-1]

# number of individual series to be predicted in current batch
num_series = input_data_tuple[0].shape[0]

# number of of times the input tensor should be tiled to produce predictions for multiple samples
# this variable is larger than 1 only if the batch_size is at least twice as large as the number
# of individual time series being predicted in current batch (`num_series`)
batch_sample_size = min(max(self.pred_batch_size // num_series, 1), self.pred_num_samples)

# counts number of produced prediction samples for every series to be predicted in current batch
sample_count = 0

# repeat prediction procedure for every needed sample
batch_predictions = []
while sample_count < self.pred_num_samples:

# make sure we don't produce too many samples
if sample_count + batch_sample_size > self.pred_num_samples:
batch_sample_size = self.pred_num_samples - sample_count

# stack multiple copies of the tensors to produce probabilistic forecasts
input_data_tuple_samples = self._sample_tiling(input_data_tuple, batch_sample_size)

# get predictions for 1 whole batch (can include predictions of multiple series
# and for multiple samples if a probabilistic forecast is produced)
batch_prediction = self._get_batch_prediction(self.pred_n, input_data_tuple_samples, self.pred_roll_size)

# reshape from 3d tensor (num_series x batch_sample_size, ...)
# into 4d tensor (batch_sample_size, num_series, ...), where dim 0 represents the samples
out_shape = batch_prediction.shape
batch_prediction = batch_prediction.reshape((batch_sample_size, num_series,) + out_shape[1:])

# save all predictions and update the `sample_count` variable
batch_predictions.append(batch_prediction)
sample_count += batch_sample_size

# concatenate the batch of samples, to form self.pred_num_samples samples
batch_predictions = torch.cat(batch_predictions, dim=0)
batch_predictions = batch_predictions.cpu().detach().numpy()

# create `TimeSeries` objects from prediction tensors
ts_forecasts = Parallel(n_jobs=self.pred_n_jobs)(
delayed(self._build_forecast_series)(
[batch_prediction[batch_idx] for batch_prediction in batch_predictions], input_series
)
for batch_idx, input_series in enumerate(batch_input_series)
)
return ts_forecasts

def on_predict_end(self) -> None:
self.pred_n = None
self.pred_num_samples = None
self.pred_n_jobs = None
self.pred_roll_size = None
self.pred_batch_size = None

def _compute_loss(self, output, target):
return self.criterion(output, target)

def configure_optimizers(self):
"""sets up optimizers"""

# TODO: i think we can move this to to pl.Trainer(). and could probably be simplified
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1


# A utility function to create optimizer and lr scheduler from desired classes
def _create_from_cls_and_kwargs(cls, kws):
try:
return cls(**kws)
except (TypeError, ValueError) as e:
raise_log(ValueError('Error when building the optimizer or learning rate scheduler;'
'please check the provided class and arguments'
'\nclass: {}'
'\narguments (kwargs): {}'
'\nerror:\n{}'.format(cls, kws, e)),
logger)

# Create the optimizer and (optionally) the learning rate scheduler
# we have to create copies because we cannot save model.parameters into object state (not serializable)
optimizer_kws = {k: v for k, v in self.optimizer_kwargs.items()}
optimizer_kws['params'] = self.parameters()

optimizer = _create_from_cls_and_kwargs(self.optimizer_cls, optimizer_kws)

if self.lr_scheduler_cls is not None:
lr_sched_kws = {k: v for k, v in self.lr_scheduler_kwargs.items()}
lr_sched_kws['optimizer'] = optimizer
lr_scheduler = _create_from_cls_and_kwargs(self.lr_scheduler_cls, lr_sched_kws)
return [optimizer], [lr_scheduler]
else:
return optimizer

@abstractmethod
def _produce_train_output(self, input_batch: Tuple) -> Tensor:
pass

@abstractmethod
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor:
"""
In charge of apply the recurrent logic for non-recurrent models.
Should be overwritten by recurrent models.
"""
pass

# TODO: had to copy these methods over from ForecastingModel as it is not a parent class.
# Maybe let TorchForecastingModel handle the _build_forecast_series()
def _build_forecast_series(self,
points_preds: Union[np.ndarray, Sequence[np.ndarray]],
input_series: TimeSeries) -> TimeSeries:
"""
Builds a forecast time series starting after the end of the training time series, with the
correct time index (or after the end of the input series, if specified).
"""

time_index_length = len(points_preds) if isinstance(points_preds, np.ndarray) else len(points_preds[0])
time_index = self._generate_new_dates(time_index_length, input_series=input_series)
if isinstance(points_preds, np.ndarray):
return TimeSeries.from_times_and_values(time_index,
points_preds,
freq=input_series.freq_str,
columns=input_series.columns)

return TimeSeries.from_times_and_values(time_index,
np.stack(points_preds, axis=2),
freq=input_series.freq_str,
columns=input_series.columns)

@staticmethod
def _generate_new_dates(n: int,
input_series: TimeSeries) -> Union[pd.DatetimeIndex, pd.RangeIndex]:
"""
Generates `n` new dates after the end of the specified series
"""
last = input_series.end_time()
start = last + input_series.freq if input_series.has_datetime_index else last + 1
return _generate_index(start=start, freq=input_series.freq, length=n)

@staticmethod
def _sample_tiling(input_data_tuple, batch_sample_size):
tiled_input_data = []
for tensor in input_data_tuple:
if tensor is not None:
tiled_input_data.append(tensor.tile((batch_sample_size, 1, 1)))
else:
tiled_input_data.append(None)
return tuple(tiled_input_data)


class PLPastCovariatesTorchModel(PLTorchForecastingModel, ABC):
def _produce_train_output(self, input_batch: Tuple):
past_target, past_covariate = input_batch
# Currently all our PastCovariates models require past target and covariates concatenated
inpt = torch.cat([past_target, past_covariate], dim=2) if past_covariate is not None else past_target
return self.model(inpt)

def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> torch.Tensor:
"""
Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset to farecast
the next `n` target values per target variable.

Parameters:
----------
n
prediction length
input_batch
(past_target, past_covariates, future_past_covariates)
roll_size
roll input arrays after every sequence by `roll_size`. Initially, `roll_size` is equivalent to
`self.output_chunk_length`
"""
dim_component = 2
past_target, past_covariates, future_past_covariates = input_batch

n_targets = past_target.shape[dim_component]
n_past_covs = past_covariates.shape[dim_component] if not past_covariates is None else 0

input_past = torch.cat(
[ds for ds in [past_target, past_covariates] if ds is not None],
dim=dim_component
)

out = self._produce_predict_output(input_past)[:, self.first_prediction_index:, :]

batch_prediction = [out[:, :roll_size, :]]
prediction_length = roll_size

while prediction_length < n:
# we want the last prediction to end exactly at `n` into the future.
# this means we may have to truncate the previous prediction and step
# back the roll size for the last chunk
if prediction_length + self.output_chunk_length > n:
spillover_prediction_length = prediction_length + self.output_chunk_length - n
roll_size -= spillover_prediction_length
prediction_length -= spillover_prediction_length
batch_prediction[-1] = batch_prediction[-1][:, :roll_size, :]

# ==========> PAST INPUT <==========
# roll over input series to contain latest target and covariate
input_past = torch.roll(input_past, -roll_size, 1)

# update target input to include next `roll_size` predictions
if self.input_chunk_length >= roll_size:
input_past[:, -roll_size:, :n_targets] = out[:, :roll_size, :]
else:
input_past[:, :, :n_targets] = out[:, -self.input_chunk_length:, :]

# set left and right boundaries for extracting future elements
if self.input_chunk_length >= roll_size:
left_past, right_past = prediction_length - roll_size, prediction_length
else:
left_past, right_past = prediction_length - self.input_chunk_length, prediction_length

# update past covariates to include next `roll_size` future past covariates elements
if n_past_covs and self.input_chunk_length >= roll_size:
input_past[:, -roll_size:, n_targets:n_targets + n_past_covs] = (
future_past_covariates[:, left_past:right_past, :]
)
elif n_past_covs:
input_past[:, :, n_targets:n_targets + n_past_covs] = (
future_past_covariates[:, left_past:right_past, :]
)

# take only last part of the output sequence where needed
out = self._produce_predict_output(input_past)[:, self.first_prediction_index:, :]
batch_prediction.append(out)
prediction_length += self.output_chunk_length

# bring predictions into desired format and drop unnecessary values
batch_prediction = torch.cat(batch_prediction, dim=1)
batch_prediction = batch_prediction[:, :n, :]
return batch_prediction

def _produce_predict_output(self, x):
return self.model(x)


class PLFutureCovariatesTorchModel(PLTorchForecastingModel, ABC):
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor:
raise NotImplementedError("TBD: Darts doesn't contain such a model yet.")


class PLDualCovariatesTorchModel(PLTorchForecastingModel, ABC):
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor:
raise NotImplementedError("TBD: The only DualCovariatesModel is an RNN with a specific implementation.")


class PLMixedCovariatesTorchModel(PLTorchForecastingModel, ABC):
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor:
raise NotImplementedError("TBD: Darts doesn't contain such a model yet.")


class PLSplitCovariatesTorchModel(PLTorchForecastingModel, ABC):
def _get_batch_prediction(self, n: int, input_batch: Tuple, roll_size: int) -> Tensor:
raise NotImplementedError("TBD: Darts doesn't contain such a model yet.")


# TODO: I think we could actually integrate probabilistic support already in the parent class and remove it from here?
dennisbader marked this conversation as resolved.
Show resolved Hide resolved
class PLTorchParametricProbabilisticForecastingModel(PLTorchForecastingModel, ABC):
def __init__(self, likelihood: Optional[Likelihood] = None, **kwargs):
""" Pytorch Parametric Probabilistic Forecasting Model.

This is a base class for pytroch parametric probabilistic models. "Parametric"
means that these models are based on some predefined parametric distribution, say Gaussian.
Make sure that subclasses contain the *likelihood* parameter in __init__ method
and it is passed to the superclass via calling super().__init__. If the likelihood is not
provided, the model is considered as deterministic.

All TorchParametricProbabilisticForecastingModel's must produce outputs of shape
(batch_size, n_timesteps, n_components, n_params). I.e., there's an extra dimension
to store the distribution's parameters.

Parameters
----------
likelihood
The likelihood model to be used for probabilistic forecasts.
"""
super().__init__(**kwargs)
self.likelihood = likelihood

def _is_probabilistic(self):
return self.likelihood is not None

def _compute_loss(self, output, target):
# output is of shape (batch_size, n_timesteps, n_components, n_params)
if self.likelihood:
return self.likelihood.compute_loss(output, target)
else:
# If there's no likelihood, nr_params=1 and we need to squeeze out the
# last dimension of model output, for properly computing the loss.
return super()._compute_loss(output.squeeze(dim=-1), target)

@abstractmethod
def _produce_predict_output(self, x):
"""
This method has to be implemented by all children.
"""
pass
Loading