
Commit 9905d36

docs
kashif committed Jun 16, 2024
1 parent 13ec2da commit 9905d36
Showing 4 changed files with 20 additions and 7 deletions.
10 changes: 8 additions & 2 deletions src/gluonts/torch/model/samformer/estimator.py
@@ -49,7 +49,7 @@
class SamFormerEstimator(PyTorchLightningEstimator):
"""
An estimator training the SamFormer model for multivariate forecasting
- as described in TODO extended to be
+ as described in https://arxiv.org/abs/2402.10198 extended to be
probabilistic.
This class uses the model defined in ``SamFormerModel``,
@@ -66,10 +66,16 @@ class SamFormerEstimator(PyTorchLightningEstimator):
takes as inputs (default: ``10 * prediction_length``).
hidden_dim
Size of query and key projection (default: ``32``).
+ projection_dim
+ Size of the projection dimension (default: ``8``).
+ sam
+ Whether to use SAM optimizer (default: ``True``).
+ rho
+ Rho parameter for SAM optimizer (default: ``0.5``).
lr
Learning rate (default: ``1e-3``).
weight_decay
- Weight decay regularization parameter (default: ``1e-8``).
+ Weight decay regularization parameter (default: ``1e-5``).
scaling
Scaling parameter can be "mean", "std" or None.
distr_output
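For reference, a minimal sketch of how the estimator might be constructed with the defaults documented above. It assumes the constructor accepts the parameters listed in the docstring plus a prediction_length; any additional required arguments (for example a frequency) are not shown. This is an illustration, not code from this commit.

from gluonts.torch.distributions import StudentTOutput
from gluonts.torch.model.samformer.estimator import SamFormerEstimator

estimator = SamFormerEstimator(
    prediction_length=24,        # forecast horizon (parameter name assumed)
    context_length=10 * 24,      # default is ``10 * prediction_length``
    hidden_dim=32,               # size of query and key projection
    projection_dim=8,            # size of the projection dimension
    sam=True,                    # use the SAM optimizer
    rho=0.5,                     # SAM rho parameter
    lr=1e-3,
    weight_decay=1e-5,           # default documented in this commit
    scaling="mean",              # "mean", "std" or None
    distr_output=StudentTOutput(),
)
# predictor = estimator.train(training_data)  # training_data: a multivariate GluonTS dataset (not defined here)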
4 changes: 4 additions & 0 deletions src/gluonts/torch/model/samformer/lightning_module.py
@@ -39,6 +39,10 @@ class SamFormerLightningModule(pl.LightningModule):
Learning rate.
weight_decay
Weight decay regularization parameter.
+ rho
+ Rho parameter for SAM optimizer.
+ sam
+ Whether to use SAM optimizer.
"""

@validated()
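The rho and sam options refer to Sharpness-Aware Minimization, which takes two forward/backward passes per batch. The sketch below shows that usual pattern; the actual wiring inside SamFormerLightningModule is not shown in this diff, and the SAM class name, its constructor signature, and first_step are assumptions (only step(closure=None), second_step, and _grad_norm appear in the sam.py diff further down).

import torch
from torch import nn

from gluonts.torch.model.samformer.sam import SAM  # class name and import path assumed

model = nn.Linear(16, 1)          # placeholder for the forecasting network
criterion = nn.MSELoss()          # placeholder for the model's loss

# SAM wraps a base optimizer; extra kwargs are assumed to be forwarded to it.
optimizer = SAM(model.parameters(), torch.optim.Adam, rho=0.5, lr=1e-3, weight_decay=1e-5)

x, y = torch.randn(8, 16), torch.randn(8, 1)   # dummy batch

loss = criterion(model(x), y)
loss.backward()
optimizer.first_step(zero_grad=True)    # perturb weights towards the worst point in the rho-ball
criterion(model(x), y).backward()       # second pass at the perturbed weights
optimizer.second_step(zero_grad=True)   # restore the weights and apply the base optimizer's update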
7 changes: 5 additions & 2 deletions src/gluonts/torch/model/samformer/module.py
@@ -20,14 +20,14 @@
from gluonts.core.component import validated
from gluonts.model import Input, InputSpec
from gluonts.torch.distributions import StudentTOutput
- from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
+ from gluonts.torch.scaler import MeanScaler, NOPScaler, StdScaler
from gluonts.torch.util import weighted_average


class SamFormerModel(nn.Module):
"""
Module implementing the SamFormer model for multivariate forecasting as
- described in TODO extended to be probabilistic.
+ described in https://arxiv.org/abs/2402.10198 extended to be probabilistic.
Parameters
----------
@@ -37,6 +37,8 @@ class SamFormerModel(nn.Module):
Number of time steps prior to prediction time that the model.
hidden_dim
Dim of query and key projection.
+ projection_dim
+ Dim of projection layer.
scaling
Whether to scale the input using mean or std or None.
distr_output
@@ -78,6 +80,7 @@ def __init__(
self.scaler = NOPScaler(keepdim=True, dim=1)
self.nonnegative_pred_samples = nonnegative_pred_samples

+ # input is each variate together with the loc and scale
self.compute_keys = nn.Linear(context_length + 2, hidden_dim)
self.compute_queries = nn.Linear(context_length + 2, hidden_dim)
self.compute_values = nn.Linear(context_length + 2, context_length)
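The context_length + 2 input size above comes from treating each variate as one token: its (scaled) history concatenated with its own loc and scale before the query/key/value projections. A sketch of that shape bookkeeping follows; the tensor layout, the standardization, and the attention call are illustrative assumptions, not code from SamFormerModel.

import torch
from torch import nn

batch, context_length, num_variates, hidden_dim = 4, 96, 7, 32

past_target = torch.randn(batch, context_length, num_variates)
loc = past_target.mean(dim=1, keepdim=True)            # (batch, 1, num_variates)
scale = past_target.std(dim=1, keepdim=True) + 1e-6    # (batch, 1, num_variates), illustrative standardization
scaled = (past_target - loc) / scale

# one token per variate: its scaled history plus that variate's loc and scale
tokens = torch.cat([scaled, loc, scale], dim=1)         # (batch, context_length + 2, num_variates)
tokens = tokens.transpose(1, 2)                         # (batch, num_variates, context_length + 2)

compute_queries = nn.Linear(context_length + 2, hidden_dim)
compute_keys = nn.Linear(context_length + 2, hidden_dim)
compute_values = nn.Linear(context_length + 2, context_length)

q, k, v = compute_queries(tokens), compute_keys(tokens), compute_values(tokens)
out = nn.functional.scaled_dot_product_attention(q, k, v)  # attention across variates
print(out.shape)                                            # torch.Size([4, 7, 96])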
6 changes: 3 additions & 3 deletions src/gluonts/torch/model/samformer/sam.py
@@ -66,9 +66,9 @@ def step(self, closure=None):
self.second_step()

def _grad_norm(self):
- shared_device = (
-     self.param_groups[0]["params"][0].device
- ) # put everything on the same device, in case of model parallelism
+ shared_device = self.param_groups[0]["params"][
+     0
+ ].device # put everything on the same device, in case of model parallelism
norm = torch.norm(
torch.stack(
[
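For context on the reformatted _grad_norm above: the global gradient norm it computes is what scales SAM's weight perturbation. Below is a sketch of that use, following the reference SAM recipe rather than the exact GluonTS code; the helper name first_step_sketch is hypothetical, rho is assumed to live in the param-group defaults, and the adaptive-SAM variant is omitted.

import torch


@torch.no_grad()
def first_step_sketch(sam_optimizer, zero_grad: bool = False) -> None:
    # global L2 norm of all gradients, gathered on one device (as in _grad_norm)
    shared_device = sam_optimizer.param_groups[0]["params"][0].device
    grad_norm = torch.norm(
        torch.stack(
            [
                p.grad.norm(p=2).to(shared_device)
                for group in sam_optimizer.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
        ),
        p=2,
    )
    for group in sam_optimizer.param_groups:
        scale = group["rho"] / (grad_norm + 1e-12)
        for p in group["params"]:
            if p.grad is None:
                continue
            e_w = p.grad * scale                   # epsilon(w): ascent step of size rho * g / ||g||
            p.add_(e_w)                            # w <- w + e_w
            sam_optimizer.state[p]["e_w"] = e_w    # remembered so second_step can subtract it again
    if zero_grad:
        sam_optimizer.zero_grad()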

