add proj dim and optimizer flag
kashif committed Jun 16, 2024
1 parent 3f15795 commit f2616b4
Showing 4 changed files with 59 additions and 27 deletions.
11 changes: 8 additions & 3 deletions src/gluonts/torch/model/samformer/estimator.py
@@ -101,9 +101,11 @@ def __init__(
prediction_length: int,
context_length: Optional[int] = None,
hidden_dim: int = 32,
projection_dim: int = 8,
lr: float = 1e-3,
weight_decay: float = 1e-5,
rho: float = 0.5,
sam: bool = True,
scaling: Optional[str] = "mean",
distr_output: Output = StudentTOutput(),
num_parallel_samples: int = 100,
@@ -114,9 +116,8 @@ def __init__(
validation_sampler: Optional[InstanceSampler] = None,
nonnegative_pred_samples: bool = False,
) -> None:
default_trainer_kwargs = {
"max_epochs": 100,
}
default_trainer_kwargs = {"max_epochs": 100}

if trainer_kwargs is not None:
default_trainer_kwargs.update(trainer_kwargs)
super().__init__(trainer_kwargs=default_trainer_kwargs)
@@ -132,6 +133,8 @@ def __init__(
self.num_parallel_samples = num_parallel_samples
self.scaling = scaling
self.hidden_dim = hidden_dim
self.projection_dim = projection_dim
self.sam = sam
self.batch_size = batch_size
self.num_batches_per_epoch = num_batches_per_epoch
self.nonnegative_pred_samples = nonnegative_pred_samples
@@ -167,10 +170,12 @@ def create_lightning_module(self) -> pl.LightningModule:
weight_decay=self.weight_decay,
rho=self.rho,
num_parallel_samples=self.num_parallel_samples,
sam=self.sam,
model_kwargs={
"prediction_length": self.prediction_length,
"context_length": self.context_length,
"hidden_dim": self.hidden_dim,
"projection_dim": self.projection_dim,
"distr_output": self.distr_output,
"scaling": self.scaling,
"nonnegative_pred_samples": self.nonnegative_pred_samples,
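
For illustration, here is a minimal sketch of how the new projection_dim and sam estimator arguments might be used. The concrete values and the dataset are placeholders, not part of this commit:

from gluonts.torch.model.samformer import SamFormerEstimator

# Hypothetical instantiation; all numeric values below are placeholders.
estimator = SamFormerEstimator(
    prediction_length=24,
    context_length=96,
    hidden_dim=32,
    projection_dim=8,   # new: size of the per-step latent fed to the output head
    sam=True,           # new: use the SAM optimizer; set False for plain Adam
    lr=1e-3,
    weight_decay=1e-5,
    rho=0.5,            # SAM neighborhood radius; unused when sam=False
    trainer_kwargs={"max_epochs": 100},
)
# predictor = estimator.train(training_data)  # training_data: a multivariate GluonTS dataset
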
52 changes: 33 additions & 19 deletions src/gluonts/torch/model/samformer/lightning_module.py
@@ -49,6 +49,7 @@ def __init__(
lr: float = 1e-3,
weight_decay: float = 1e-5,
rho: float = 0.5,
sam: bool = True,
):
super().__init__()
self.save_hyperparameters()
@@ -57,6 +58,7 @@ def __init__(
self.lr = lr
self.weight_decay = weight_decay
self.rho = rho
self.sam = sam

self.automatic_optimization = False

@@ -83,18 +85,23 @@ def training_step(self, batch, batch_idx: int): # type: ignore
future_observed_values=batch["future_observed_values"],
).mean()

# Ascent Step
self.manual_backward(train_loss)
opt.first_step(zero_grad=True)

# Descent Step
train_loss = self.model.loss(
**select(self.inputs, batch),
future_target=batch["future_target"],
future_observed_values=batch["future_observed_values"],
).mean()
self.manual_backward(train_loss)
opt.second_step(zero_grad=True)
if self.sam:
# Ascent Step
self.manual_backward(train_loss)
opt.first_step(zero_grad=True)

# Descent Step
train_loss_2 = self.model.loss(
**select(self.inputs, batch),
future_target=batch["future_target"],
future_observed_values=batch["future_observed_values"],
).mean()
self.manual_backward(train_loss_2)
opt.second_step(zero_grad=True)
else:
opt.zero_grad()
self.manual_backward(train_loss)
opt.step()

self.log(
"train_loss",
@@ -124,10 +131,17 @@ def configure_optimizers(self):
"""
Returns the optimizer to use.
"""
return SAM(
self.model.parameters(),
base_optimizer=torch.optim.Adam,
lr=self.lr,
rho=self.rho,
weight_decay=self.weight_decay,
)
if self.sam:
return SAM(
self.model.parameters(),
base_optimizer=torch.optim.Adam,
lr=self.lr,
rho=self.rho,
weight_decay=self.weight_decay,
)
else:
return torch.optim.Adam(
self.model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
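
For context on the two backward passes in training_step above: a SAM optimizer wraps a base optimizer and exposes first_step (perturb the weights toward the locally worst case within a radius of rho) and second_step (undo the perturbation and apply the base update using the gradient computed at the perturbed weights). The following is a simplified, self-contained sketch of that idea, not the exact SAM class imported by this module:

import torch


class MinimalSAM(torch.optim.Optimizer):
    """Simplified sketch of a sharpness-aware minimization (SAM) wrapper."""

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # The base optimizer (e.g. Adam) shares our parameter groups.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        # Ascent step: move each parameter to the approximate worst case
        # within an L2 ball of radius rho around the current weights.
        for group in self.param_groups:
            grads = [p.grad for p in group["params"] if p.grad is not None]
            grad_norm = torch.norm(torch.stack([g.norm(p=2) for g in grads]), p=2)
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)                  # w <- w + e(w)
                self.state[p]["e_w"] = e_w   # remember the perturbation
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # Descent step: undo the perturbation, then update with the gradient
        # that was computed at the perturbed weights.
        for group in self.param_groups:
            for p in group["params"]:
                if "e_w" in self.state[p]:
                    p.sub_(self.state[p]["e_w"])
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()
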
11 changes: 6 additions & 5 deletions src/gluonts/torch/model/samformer/module.py
@@ -54,6 +54,7 @@ def __init__(
prediction_length: int,
context_length: int,
hidden_dim: int,
projection_dim: int,
scaling: Optional[str],
distr_output=StudentTOutput(),
nonnegative_pred_samples: bool = False,
@@ -65,7 +66,7 @@ def __init__(

self.prediction_length = prediction_length
self.context_length = context_length
self.hidden_dim = hidden_dim
self.projection_dim = projection_dim

self.distr_output = distr_output

@@ -83,11 +84,11 @@ def __init__(

# project each variate to prediction length number of latent variables
self.projection = nn.Linear(
context_length, prediction_length * hidden_dim
context_length, prediction_length * projection_dim
)

# project each prediction length latent to distribution parameters
self.args_proj = self.distr_output.get_args_proj(hidden_dim)
self.args_proj = self.distr_output.get_args_proj(projection_dim)

def describe_inputs(self, batch_size=1) -> InputSpec:
return InputSpec(
@@ -132,12 +133,12 @@ def forward(
att_score = F.scaled_dot_product_attention(queries, keys, values)
out = past_target_scaled + att_score

# project to prediction length * hidden_dim and reshape
# project to prediction length * projection_dim and reshape
projection_out = self.projection(out).reshape(
-1,
past_target.shape[2],
self.prediction_length,
self.hidden_dim,
self.projection_dim,
)

# transpose to prediction length first
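
To make the reshape above concrete, a small sketch of the tensor shapes produced by the new projection. The sizes are placeholders, and out stands in for the attention output, assumed here to have shape (batch, num_variates, context_length):

import torch
import torch.nn as nn

# Placeholder sizes for illustration only.
batch, num_variates = 4, 3
context_length, prediction_length, projection_dim = 96, 24, 8

projection = nn.Linear(context_length, prediction_length * projection_dim)

out = torch.randn(batch, num_variates, context_length)
projection_out = projection(out).reshape(
    -1, num_variates, prediction_length, projection_dim
)
print(projection_out.shape)  # torch.Size([4, 3, 24, 8])
# Each (variate, time step) now carries a projection_dim-sized latent that the
# distribution head (args_proj) maps to distribution parameters.
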
12 changes: 12 additions & 0 deletions test/torch/model/test_multivariate_estimators.py
@@ -24,6 +24,7 @@
from gluonts.model.predictor import Predictor
from gluonts.torch.model.forecast import DistributionForecast
from gluonts.torch.model.i_transformer import ITransformerEstimator
from gluonts.torch.model.samformer import SamFormerEstimator


@pytest.mark.parametrize(
@@ -35,6 +36,12 @@
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda dataset: SamFormerEstimator(
prediction_length=dataset.metadata.prediction_length,
batch_size=4,
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
],
)
@pytest.mark.parametrize("use_validation_data", [False, True])
@@ -83,6 +90,11 @@ def test_multivariate_estimator_constant_dataset(
batch_size=4,
trainer_kwargs=dict(max_epochs=2),
),
lambda freq, prediction_length: SamFormerEstimator(
prediction_length=prediction_length,
batch_size=4,
trainer_kwargs=dict(max_epochs=2),
),
],
)
def test_multivariate_estimator_with_features(estimator_constructor):
