From ab2df1ab44701668dbd5673a9b3875cb0711e49d Mon Sep 17 00:00:00 2001
From: Dennis Bader
Date: Wed, 4 Jan 2023 16:04:55 +0100
Subject: [PATCH] fix deep copying of TFM trainer parameters (#1459)

* fix deep copying of TFM trainer parameters

* fix failing tests
---
 darts/models/forecasting/torch_forecasting_model.py | 9 +++++++--
 darts/tests/models/forecasting/test_ptl_trainer.py  | 4 +++-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py
index a7a26dbc8f..d07ffec20e 100644
--- a/darts/models/forecasting/torch_forecasting_model.py
+++ b/darts/models/forecasting/torch_forecasting_model.py
@@ -475,11 +475,16 @@ def _init_trainer(
         trainer_params: dict, max_epochs: Optional[int] = None
     ) -> pl.Trainer:
         """Initializes a PyTorch-Lightning trainer for training or prediction
         from `trainer_params`."""
-        trainer_params_copy = {param: val for param, val in trainer_params.items()}
+        trainer_params_copy = {key: val for key, val in trainer_params.items()}
         if max_epochs is not None:
             trainer_params_copy["max_epochs"] = max_epochs
-        return pl.Trainer(**trainer_params_copy)
+        # prevent lightning from adding callbacks to the callbacks list in `self.trainer_params`
+        callbacks = trainer_params_copy.pop("callbacks", None)
+        return pl.Trainer(
+            callbacks=[cb for cb in callbacks] if callbacks is not None else callbacks,
+            **trainer_params_copy,
+        )
 
     @abstractmethod
     def _create_model(self, train_sample: Tuple[Tensor]) -> torch.nn.Module:
diff --git a/darts/tests/models/forecasting/test_ptl_trainer.py b/darts/tests/models/forecasting/test_ptl_trainer.py
index 8f357ee3c2..b714ae3207 100644
--- a/darts/tests/models/forecasting/test_ptl_trainer.py
+++ b/darts/tests/models/forecasting/test_ptl_trainer.py
@@ -151,7 +151,9 @@ def on_train_epoch_end(self, *args, **kwargs):
 
         # check if callbacks were added
         self.assertEqual(len(model.trainer_params["callbacks"]), 2)
-        model.fit(self.series, epochs=2)
+        model.fit(self.series, epochs=2, verbose=True)
+        # check that lightning did not mutate callbacks (verbosity adds a progress bar callback)
+        self.assertEqual(len(model.trainer_params["callbacks"]), 2)
 
         self.assertEqual(my_counter_0.counter, model.epochs_trained)
         self.assertEqual(my_counter_2.counter, model.epochs_trained + 2)
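
The root cause is list aliasing: the dict comprehension in `_init_trainer` only shallow-copies `trainer_params`, so the `callbacks` list inside the copy is still the very list stored on the model, and Lightning appends its own callbacks (e.g. a progress bar when `verbose=True`) to it. Below is a minimal plain-Python sketch of the mechanism; the `params_copy` name and the string stand-ins for callback objects are illustrative, not from the darts codebase:

# Minimal sketch of the aliasing bug fixed by this patch; plain Python,
# no Lightning required. Strings play the role of callback objects.

# before the fix: a dict comprehension copies the dict, but not the list inside
trainer_params = {"callbacks": ["my_counter_0", "my_counter_2"]}
params_copy = {key: val for key, val in trainer_params.items()}
params_copy["callbacks"].append("progress_bar")  # what Lightning effectively does
assert len(trainer_params["callbacks"]) == 3     # the model's own list was mutated

# after the fix: pop the list and hand Lightning a fresh copy instead
trainer_params = {"callbacks": ["my_counter_0", "my_counter_2"]}
params_copy = {key: val for key, val in trainer_params.items()}
callbacks = params_copy.pop("callbacks", None)
fresh = [cb for cb in callbacks] if callbacks is not None else callbacks
fresh.append("progress_bar")                     # Lightning mutates its copy only
assert len(trainer_params["callbacks"]) == 2     # the original list is untouched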