From 172f3ed6699a119eb4cf6d62377746196bafcba3 Mon Sep 17 00:00:00 2001 From: pennfranc <42946363+pennfranc@users.noreply.github.com> Date: Mon, 22 Jun 2020 11:58:22 +0200 Subject: [PATCH] Fix/predict single value (#108) * fix(TorchForecastingModel): solved bug at TorchForecastingModel.predict(n) with n = 1 * feature(testing): added tests for length 1 predictions for RNN and TCN, set torch random seed for TCN test Co-authored-by: pennfranc Co-authored-by: TheMP --- darts/models/torch_forecasting_model.py | 3 +-- darts/tests/test_RNN.py | 4 ++++ darts/tests/test_TCN.py | 5 +++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/darts/models/torch_forecasting_model.py b/darts/models/torch_forecasting_model.py index e4c2644496..5059a23cdd 100644 --- a/darts/models/torch_forecasting_model.py +++ b/darts/models/torch_forecasting_model.py @@ -320,8 +320,7 @@ def predict(self, n: int) -> TimeSeries: pred_in[:, -1, :] = out[:, self.first_prediction_index] test_out.append(out.cpu().detach().numpy()[0, self.first_prediction_index]) test_out = np.stack(test_out) - - return self._build_forecast_series(test_out.squeeze()) + return self._build_forecast_series(test_out.squeeze(1)) @property def first_prediction_index(self) -> int: diff --git a/darts/tests/test_RNN.py b/darts/tests/test_RNN.py index 45b82af435..79da20f9b3 100644 --- a/darts/tests/test_RNN.py +++ b/darts/tests/test_RNN.py @@ -45,4 +45,8 @@ def test_fit(self): pred3 = model3.predict(n=6) self.assertNotEqual(sum(pred1.values() - pred3.values()), 0.) + # test short predict + pred4 = model3.predict(n=1) + self.assertEqual(len(pred4), 1) + shutil.rmtree('.darts') diff --git a/darts/tests/test_TCN.py b/darts/tests/test_TCN.py index 05e4a209b8..d274b16db2 100644 --- a/darts/tests/test_TCN.py +++ b/darts/tests/test_TCN.py @@ -33,7 +33,12 @@ def test_fit(self): pred2 = model2.predict(n=2).values()[0] self.assertTrue(abs(pred2 - 10) < abs(pred - 10)) + # test short predict + pred3 = model2.predict(n=1) + self.assertEqual(len(pred3), 1) + def test_coverage(self): + torch.manual_seed(0) input_lengths = range(20, 50) kernel_sizes = range(2, 5) dilation_bases = range(2, 5)