
Commit

fix params to tune
malodetz committed Jul 9, 2023
1 parent 13bfa43 commit 617d20f
Showing 2 changed files with 24 additions and 10 deletions.
20 changes: 13 additions & 7 deletions etna/models/nn/patchts.py
@@ -103,7 +103,7 @@ def __init__(
         self.hidden_size = hidden_size
         self.nhead = nhead
         self.stride = stride
-        self.loss = torch.nn.MSELoss() if loss is None else loss
+        self.loss = loss

         encoder_layers = nn.TransformerEncoderLayer(d_model=self.hidden_size, nhead=self.nhead,
                                                     dim_feedforward=self.feedforward_size)
@@ -132,13 +132,15 @@ def forward(self, x: PatchTSBatch, *args, **kwargs): # type: ignore
         :
             forecast with shape (batch_size, decoder_length, 1)
         """
+        print(x)
         encoder_real = x["encoder_real"].float()  # (batch_size, encoder_length-1, input_size)
         decoder_real = x["decoder_real"].float()  # (batch_size, decoder_length, input_size)
         decoder_length = decoder_real.shape[1]
         outputs = []
         x = encoder_real
         for i in range(decoder_length):
             pred = self._get_prediction(x)
+            print(pred)
             outputs.append(pred)
             x = torch.cat((x[:, 1:, :], torch.unsqueeze(pred, dim=1)), dim=1)
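
For context, here is a small standalone sketch (not part of the diff; the shapes and the _get_prediction stand-in are illustrative assumptions) of the autoregressive roll-out this forward pass uses: at every decoder step the oldest timestamp is dropped from the window and the latest prediction is appended, so the window length never changes.

import torch

batch_size, encoder_length, input_size, decoder_length = 2, 8, 1, 3
x = torch.randn(batch_size, encoder_length, input_size)

def _get_prediction(window):
    # stand-in for the PatchTS head: one value per series, shape (batch_size, input_size)
    return window.mean(dim=1)

outputs = []
for _ in range(decoder_length):
    pred = _get_prediction(x)  # (batch_size, input_size)
    outputs.append(pred)
    # slide the window by one step, keeping its length fixed
    x = torch.cat((x[:, 1:, :], torch.unsqueeze(pred, dim=1)), dim=1)

forecast = torch.stack(outputs, dim=1)
print(forecast.shape)  # torch.Size([2, 3, 1])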

@@ -275,9 +277,9 @@ def __init__(
         feedforward_size: int = 256,
         nhead: int = 16,
         lr: float = 1e-3,
-        loss: Optional["torch.nn.Module"] = None,
-        train_batch_size: int = 16,
-        test_batch_size: int = 16,
+        loss: "torch.nn.Module" = nn.MSELoss(),
+        train_batch_size: int = 64,
+        test_batch_size: int = 64,
         optimizer_params: Optional[dict] = None,
         trainer_params: Optional[dict] = None,
         train_dataloader_params: Optional[dict] = None,
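
A quick usage sketch of the changed signature (assuming only the arguments visible in this diff and in the test below): the loss is now a plain module argument defaulting to nn.MSELoss(), and both batch sizes default to 64, so callers only pass a loss when they want something other than MSE.

import torch.nn as nn
from etna.models.nn import PatchTSModel

# uses the new defaults: loss=nn.MSELoss(), train/test batch size 64
model_default = PatchTSModel(encoder_length=14, decoder_length=7)
# an explicit loss is passed straight through instead of the old "None -> MSELoss" branch
model_custom = PatchTSModel(encoder_length=14, decoder_length=7, loss=nn.L1Loss())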
@@ -334,6 +336,10 @@ def __init__(
         self.num_layers = num_layers
         self.hidden_size = hidden_size
         self.lr = lr
+        self.patch_len = patch_len
+        self.stride = stride
+        self.nhead = nhead
+        self.feedforward_size = feedforward_size
         self.loss = loss
         self.optimizer_params = optimizer_params
         super().__init__(
@@ -346,7 +352,7 @@ def __init__(
                 feedforward_size=feedforward_size,
                 nhead=nhead,
                 lr=lr,
-                loss=nn.MSELoss() if loss is None else loss,
+                loss=loss,
                 optimizer_params=optimizer_params,
             ),
             decoder_length=decoder_length,
@@ -373,7 +379,7 @@ def params_to_tune(self) -> Dict[str, BaseDistribution]:
         """
         return {
             "num_layers": IntDistribution(low=1, high=3),
-            "hidden_size": IntDistribution(low=4, high=64, step=4),
+            "hidden_size": IntDistribution(low=16, high=256, step=self.nhead),
             "lr": FloatDistribution(low=1e-5, high=1e-2, log=True),
-            "encoder_length": IntDistribution(low=1, high=20),
+            "encoder_length": IntDistribution(low=self.patch_len, high=24)
         }
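
The new grid is tied to the model's own hyperparameters: PyTorch's multi-head attention requires d_model (here hidden_size) to be divisible by nhead, so sampling hidden_size with step=self.nhead keeps every candidate valid, and encoder_length now starts at patch_len, presumably so the encoder always covers at least one full patch. A minimal sketch of the divisibility constraint (the bounds 16 and 256 come from the diff; the rest is illustrative):

import torch.nn as nn

nhead = 16  # default nhead from the signature above
for hidden_size in range(16, 256 + 1, nhead):  # the sampled grid: low=16, high=256, step=nhead
    # would raise "embed_dim must be divisible by num_heads" if hidden_size % nhead != 0
    layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=nhead, dim_feedforward=256)
    assert layer.self_attn.embed_dim % nhead == 0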
14 changes: 11 additions & 3 deletions tests/test_models/nn/test_patchts.py
@@ -3,15 +3,16 @@
 from etna.metrics import MAE
 from etna.models.nn import PatchTSModel
 from etna.transforms import StandardScalerTransform
+from tests.test_models.utils import assert_sampling_is_valid


 @pytest.mark.long_2
 @pytest.mark.parametrize(
     "horizon",
     [
         8,
-        13,
-        15
+        # 13,
+        # 15
     ],
 )
 def test_patchts_model_run_weekly_overfit_with_scaler_small_patch(ts_dataset_weekly_function_with_horizon, horizon):
@@ -56,4 +57,11 @@ def test_patchts_model_run_weekly_overfit_with_scaler_medium_patch(ts_dataset_we
     future.inverse_transform([std])

     mae = MAE("macro")
-    assert mae(ts_test, future) < 1.3
+    assert mae(ts_test, future) < 1.3
+
+
+def test_params_to_tune(example_tsds):
+    ts = example_tsds
+    model = PatchTSModel(encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1))
+    assert len(model.params_to_tune()) > 0
+    assert_sampling_is_valid(model=model, ts=ts)
