diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d086a82e0ba5..3e1ada8db632c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -57,6 +57,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Removed

+- Removed the deprecated `stochastic_weight_avg` argument from the `Trainer` constructor ([#12535](https://github.com/PyTorchLightning/pytorch-lightning/pull/12535))
+
+
 - Removed the deprecated `progress_bar_refresh_rate` argument from the `Trainer` constructor ([#12514](https://github.com/PyTorchLightning/pytorch-lightning/pull/12514))


diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index 08b180b174980..50f0839663c40 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -46,7 +46,6 @@ def on_trainer_init(
         weights_save_path: Optional[str],
         enable_model_summary: bool,
         weights_summary: Optional[str],
-        stochastic_weight_avg: bool,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
     ):
@@ -59,13 +58,6 @@ def on_trainer_init(
         )
         self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir

-        if stochastic_weight_avg:
-            rank_zero_deprecation(
-                "Setting `Trainer(stochastic_weight_avg=True)` is deprecated in v1.5 and will be removed in v1.7."
-                " Please pass `pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging`"
-                " directly to the Trainer's `callbacks` argument instead."
-            )
-        self.trainer._stochastic_weight_avg = stochastic_weight_avg

         # init callbacks
         if isinstance(callbacks, Callback):
@@ -76,9 +68,6 @@ def on_trainer_init(
         # pass through the required args to figure out defaults
         self._configure_checkpoint_callbacks(checkpoint_callback, enable_checkpointing)

-        # configure swa callback
-        self._configure_swa_callbacks()
-
         # configure the timer callback.
         # responsible to stop the training when max_time is reached.
         self._configure_timer_callback(max_time)
@@ -201,16 +190,6 @@ def _configure_model_summary_callback(
             self.trainer.callbacks.append(model_summary)
             self.trainer._weights_summary = weights_summary

-    def _configure_swa_callbacks(self):
-        if not self.trainer._stochastic_weight_avg:
-            return
-
-        from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
-
-        existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
-        if not existing_swa:
-            self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks
-
     def _configure_progress_bar(self, process_position: int = 0, enable_progress_bar: bool = True) -> None:
         progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
         if len(progress_bars) > 1:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index cfaa6ef59879c..309e481e1045b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -187,7 +187,6 @@ def __init__(
         amp_level: Optional[str] = None,
         move_metrics_to_cpu: bool = False,
         multiple_trainloader_mode: str = "max_size_cycle",
-        stochastic_weight_avg: bool = False,
         terminate_on_nan: Optional[bool] = None,
     ) -> None:
         r"""
@@ -452,15 +451,6 @@ def __init__(
                 and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
                 reload when reaching the minimum length of datasets.
                 Default: ``"max_size_cycle"``.
-
-            stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
-                <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`_.
-                Default: ``False``.
-
-                .. deprecated:: v1.5
-                    ``stochastic_weight_avg`` has been deprecated in v1.5 and will be removed in v1.7.
-                    Please pass :class:`~pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging`
-                    directly to the Trainer's ``callbacks`` argument instead.
         """
         super().__init__()
         Trainer._log_api_event("init")
@@ -540,7 +530,6 @@ def __init__(
             weights_save_path,
             enable_model_summary,
             weights_summary,
-            stochastic_weight_avg,
             max_time,
             accumulate_grad_batches,
         )
diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py
index 0abac46732e97..4a79c462f2780 100644
--- a/tests/callbacks/test_stochastic_weight_avg.py
+++ b/tests/callbacks/test_stochastic_weight_avg.py
@@ -195,38 +195,6 @@ def test_swa_raises():
         StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])


-@pytest.mark.parametrize("stochastic_weight_avg", [False, True])
-@pytest.mark.parametrize("use_callbacks", [False, True])
-def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks: bool, stochastic_weight_avg: bool):
-    """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer."""
-
-    class TestModel(BoringModel):
-        def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
-            return optimizer
-
-    model = TestModel()
-    kwargs = {
-        "default_root_dir": tmpdir,
-        "callbacks": StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
-        "stochastic_weight_avg": stochastic_weight_avg,
-        "limit_train_batches": 4,
-        "limit_val_batches": 4,
-        "max_epochs": 2,
-    }
-    if stochastic_weight_avg:
-        with pytest.deprecated_call(match=r"stochastic_weight_avg=True\)` is deprecated in v1.5"):
-            trainer = Trainer(**kwargs)
-    else:
-        trainer = Trainer(**kwargs)
-    trainer.fit(model)
-    if use_callbacks or stochastic_weight_avg:
-        assert sum(1 for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)) == 1
-        assert trainer.callbacks[0]._swa_lrs == [1e-3 if use_callbacks else 0.1]
-    else:
-        assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)
-
-
 def test_swa_deepcopy(tmpdir):
     """Test to ensure SWA Callback doesn't deepcopy dataloaders and datamodule potentially leading to OOM."""

diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index e2bbcd70e53c2..ecd890e6b6291 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -130,11 +130,6 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
     _ = Trainer(prepare_data_per_node=False)


-def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir):
-    with pytest.deprecated_call(match=r"Setting `Trainer\(stochastic_weight_avg=True\)` is deprecated in v1.5"):
-        _ = Trainer(stochastic_weight_avg=True)
-
-
 @pytest.mark.parametrize("terminate_on_nan", [True, False])
 def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
     with pytest.deprecated_call(
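Migration note: the removed deprecation message points users at passing the callback explicitly. Below is a minimal sketch of the post-removal usage; the `swa_lrs` and `max_epochs` values are illustrative placeholders, not part of this diff.

# Migration sketch: instead of the removed `Trainer(stochastic_weight_avg=True)`
# flag, construct the SWA callback explicitly and hand it to `callbacks`.
# Hyperparameter values below are assumed for illustration.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

trainer = Trainer(
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],  # explicit SWA callback
    max_epochs=2,  # placeholder value
)

With the auto-injection in `_configure_swa_callbacks` gone, the `callbacks` list is the single source of truth: no deduplication between a flag-based and a callback-based SWA configuration happens anymore.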