diff --git a/src/gluonts/torch/model/samformer/estimator.py b/src/gluonts/torch/model/samformer/estimator.py index bf8bd614eb..e154005e9c 100644 --- a/src/gluonts/torch/model/samformer/estimator.py +++ b/src/gluonts/torch/model/samformer/estimator.py @@ -49,7 +49,7 @@ class SamFormerEstimator(PyTorchLightningEstimator): """ An estimator training the SamFormer model for multivariate forecasting - as described in TODO extended to be + as described in https://arxiv.org/abs/2402.10198 extended to be probabilistic. This class uses the model defined in ``SamFormerModel``, @@ -66,10 +66,16 @@ class SamFormerEstimator(PyTorchLightningEstimator): takes as inputs (default: ``10 * prediction_length``). hidden_dim Size of query and key projection (default: ``32``). + projection_dim + Size of the projection dimension (default: ``8``). + sam + Whether to use SAM optimizer (default: ``True``). + rho + Rho parameter for SAM optimizer (default: ``0.5``). lr Learning rate (default: ``1e-3``). weight_decay - Weight decay regularization parameter (default: ``1e-8``). + Weight decay regularization parameter (default: ``1e-5``). scaling Scaling parameter can be "mean", "std" or None. distr_output diff --git a/src/gluonts/torch/model/samformer/lightning_module.py b/src/gluonts/torch/model/samformer/lightning_module.py index 76b26eeb34..17c72f4012 100644 --- a/src/gluonts/torch/model/samformer/lightning_module.py +++ b/src/gluonts/torch/model/samformer/lightning_module.py @@ -39,6 +39,10 @@ class SamFormerLightningModule(pl.LightningModule): Learning rate. weight_decay Weight decay regularization parameter. + rho + Rho parameter for SAM optimizer. + sam + Whether to use SAM optimizer. """ @validated() diff --git a/src/gluonts/torch/model/samformer/module.py b/src/gluonts/torch/model/samformer/module.py index e2ab4a806b..0812f68e5c 100644 --- a/src/gluonts/torch/model/samformer/module.py +++ b/src/gluonts/torch/model/samformer/module.py @@ -20,14 +20,14 @@ from gluonts.core.component import validated from gluonts.model import Input, InputSpec from gluonts.torch.distributions import StudentTOutput -from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler +from gluonts.torch.scaler import MeanScaler, NOPScaler, StdScaler from gluonts.torch.util import weighted_average class SamFormerModel(nn.Module): """ Module implementing the SamFormer model for multivariate forecasting as - described in TODO extended to be probabilistic. + described in https://arxiv.org/abs/2402.10198 extended to be probabilistic. Parameters ---------- @@ -37,6 +37,8 @@ class SamFormerModel(nn.Module): Number of time steps prior to prediction time that the model. hidden_dim Dim of query and key projection. + projection_dim + Dim of projection layer. scaling Whether to scale the input using mean or std or None. distr_output @@ -78,6 +80,7 @@ def __init__( self.scaler = NOPScaler(keepdim=True, dim=1) self.nonnegative_pred_samples = nonnegative_pred_samples + # input is each variate together with the loc and scale self.compute_keys = nn.Linear(context_length + 2, hidden_dim) self.compute_queries = nn.Linear(context_length + 2, hidden_dim) self.compute_values = nn.Linear(context_length + 2, context_length) diff --git a/src/gluonts/torch/model/samformer/sam.py b/src/gluonts/torch/model/samformer/sam.py index 081c0e1de2..43c7dbda55 100644 --- a/src/gluonts/torch/model/samformer/sam.py +++ b/src/gluonts/torch/model/samformer/sam.py @@ -66,9 +66,9 @@ def step(self, closure=None): self.second_step() def _grad_norm(self): - shared_device = ( - self.param_groups[0]["params"][0].device - ) # put everything on the same device, in case of model parallelism + shared_device = self.param_groups[0]["params"][ + 0 + ].device # put everything on the same device, in case of model parallelism norm = torch.norm( torch.stack( [