Skip to content

Commit

Permalink
Add make_samples tests
Browse files Browse the repository at this point in the history
  • Loading branch information
malodetz committed Jul 9, 2023
1 parent 7bf70f0 commit 1c21459
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tests/test_models/nn/test_patchts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from unittest.mock import MagicMock

import numpy as np
import pytest

from etna.metrics import MAE
from etna.models.nn import PatchTSModel
from etna.models.nn.patchts import PatchTSNet
from etna.transforms import StandardScalerTransform
from tests.test_models.utils import assert_sampling_is_valid

Expand Down Expand Up @@ -61,6 +65,26 @@ def test_patchts_model_run_weekly_overfit_with_scaler_medium_patch(ts_dataset_we
assert mae(ts_test, future) < 1.3


def test_patchts_make_samples(example_df):
rnn_module = MagicMock()
encoder_length = 8
decoder_length = 4

ts_samples = list(
PatchTSNet.make_samples(rnn_module, df=example_df, encoder_length=encoder_length, decoder_length=decoder_length)
)
first_sample = ts_samples[0]
second_sample = ts_samples[1]

assert first_sample["segment"] == "segment_1"
assert first_sample["encoder_real"].shape == (encoder_length, 1)
assert first_sample["decoder_real"].shape == (decoder_length, 1)
assert first_sample["encoder_target"].shape == (encoder_length, 1)
assert first_sample["decoder_target"].shape == (decoder_length, 1)
np.testing.assert_equal(example_df[["target"]].iloc[:encoder_length], first_sample["encoder_real"])
np.testing.assert_equal(example_df[["target"]].iloc[1:encoder_length + 1], second_sample["encoder_real"])


def test_params_to_tune(example_tsds):
ts = example_tsds
model = PatchTSModel(encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1))
Expand Down

0 comments on commit 1c21459

Please sign in to comment.