Skip to content

Commit

Permalink
fix deep copying of TFM trainer parameters (#1459)
Browse files Browse the repository at this point in the history
* fix deep copying of TFM trainer parameters

* fix failing tests
  • Loading branch information
dennisbader authored Jan 4, 2023
1 parent fab7ddf commit ab2df1a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 7 additions & 2 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,16 @@ def _init_trainer(
trainer_params: dict, max_epochs: Optional[int] = None
) -> pl.Trainer:
"""Initializes a PyTorch-Lightning trainer for training or prediction from `trainer_params`."""
trainer_params_copy = {param: val for param, val in trainer_params.items()}
trainer_params_copy = {key: val for key, val in trainer_params.items()}
if max_epochs is not None:
trainer_params_copy["max_epochs"] = max_epochs

return pl.Trainer(**trainer_params_copy)
# prevent lightning from adding callbacks to the callbacks list in `self.trainer_params`
callbacks = trainer_params_copy.pop("callbacks", None)
return pl.Trainer(
callbacks=[cb for cb in callbacks] if callbacks is not None else callbacks,
**trainer_params_copy,
)

@abstractmethod
def _create_model(self, train_sample: Tuple[Tensor]) -> torch.nn.Module:
Expand Down
4 changes: 3 additions & 1 deletion darts/tests/models/forecasting/test_ptl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def on_train_epoch_end(self, *args, **kwargs):

# check if callbacks were added
self.assertEqual(len(model.trainer_params["callbacks"]), 2)
model.fit(self.series, epochs=2)
model.fit(self.series, epochs=2, verbose=True)
# check that lightning did not mutate callbacks (verbosity adds a progress bar callback)
self.assertEqual(len(model.trainer_params["callbacks"]), 2)

self.assertEqual(my_counter_0.counter, model.epochs_trained)
self.assertEqual(my_counter_2.counter, model.epochs_trained + 2)
Expand Down

0 comments on commit ab2df1a

Please sign in to comment.