From 1c21459bb92098464e2e9b8accd9893fc920f7b4 Mon Sep 17 00:00:00 2001 From: malodetz Date: Sun, 9 Jul 2023 21:16:26 +0300 Subject: [PATCH] Add make_samples tests --- tests/test_models/nn/test_patchts.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_models/nn/test_patchts.py b/tests/test_models/nn/test_patchts.py index a34671f41..1c7086c43 100644 --- a/tests/test_models/nn/test_patchts.py +++ b/tests/test_models/nn/test_patchts.py @@ -1,7 +1,11 @@ +from unittest.mock import MagicMock + +import numpy as np import pytest from etna.metrics import MAE from etna.models.nn import PatchTSModel +from etna.models.nn.patchts import PatchTSNet from etna.transforms import StandardScalerTransform from tests.test_models.utils import assert_sampling_is_valid @@ -61,6 +65,26 @@ def test_patchts_model_run_weekly_overfit_with_scaler_medium_patch(ts_dataset_we assert mae(ts_test, future) < 1.3 +def test_patchts_make_samples(example_df): + rnn_module = MagicMock() + encoder_length = 8 + decoder_length = 4 + + ts_samples = list( + PatchTSNet.make_samples(rnn_module, df=example_df, encoder_length=encoder_length, decoder_length=decoder_length) + ) + first_sample = ts_samples[0] + second_sample = ts_samples[1] + + assert first_sample["segment"] == "segment_1" + assert first_sample["encoder_real"].shape == (encoder_length, 1) + assert first_sample["decoder_real"].shape == (decoder_length, 1) + assert first_sample["encoder_target"].shape == (encoder_length, 1) + assert first_sample["decoder_target"].shape == (decoder_length, 1) + np.testing.assert_equal(example_df[["target"]].iloc[:encoder_length], first_sample["encoder_real"]) + np.testing.assert_equal(example_df[["target"]].iloc[1:encoder_length + 1], second_sample["encoder_real"]) + + def test_params_to_tune(example_tsds): ts = example_tsds model = PatchTSModel(encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1))