From 6259cb32c6fb14f0056d5dfc76bd8babea48f96b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 30 Mar 2022 20:58:37 +0530 Subject: [PATCH 1/4] Remove deprecated stochastic_weight_averaging flag from Trainer --- .../trainer/connectors/callback_connector.py | 21 ------------ pytorch_lightning/trainer/trainer.py | 11 ------- tests/callbacks/test_stochastic_weight_avg.py | 32 ------------------- tests/deprecated_api/test_remove_1-7.py | 5 --- 4 files changed, 69 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 897bcd03396d1..38c139d7a7f09 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -47,7 +47,6 @@ def on_trainer_init( weights_save_path: Optional[str], enable_model_summary: bool, weights_summary: Optional[str], - stochastic_weight_avg: bool, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None, ): @@ -60,13 +59,6 @@ def on_trainer_init( ) self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir - if stochastic_weight_avg: - rank_zero_deprecation( - "Setting `Trainer(stochastic_weight_avg=True)` is deprecated in v1.5 and will be removed in v1.7." - " Please pass `pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging`" - " directly to the Trainer's `callbacks` argument instead." - ) - self.trainer._stochastic_weight_avg = stochastic_weight_avg # init callbacks if isinstance(callbacks, Callback): @@ -77,9 +69,6 @@ def on_trainer_init( # pass through the required args to figure out defaults self._configure_checkpoint_callbacks(checkpoint_callback, enable_checkpointing) - # configure swa callback - self._configure_swa_callbacks() - # configure the timer callback. # responsible to stop the training when max_time is reached. self._configure_timer_callback(max_time) @@ -210,16 +199,6 @@ def _configure_model_summary_callback( self.trainer.callbacks.append(model_summary) self.trainer._weights_summary = weights_summary - def _configure_swa_callbacks(self): - if not self.trainer._stochastic_weight_avg: - return - - from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging - - existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)] - if not existing_swa: - self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks - def _configure_progress_bar( self, refresh_rate: Optional[int] = None, process_position: int = 0, enable_progress_bar: bool = True ) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c0ea6f6f38dbd..e39c806cbbec1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -188,7 +188,6 @@ def __init__( amp_level: Optional[str] = None, move_metrics_to_cpu: bool = False, multiple_trainloader_mode: str = "max_size_cycle", - stochastic_weight_avg: bool = False, terminate_on_nan: Optional[bool] = None, ) -> None: r""" @@ -463,15 +462,6 @@ def __init__( and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets reload when reaching the minimum length of datasets. Default: ``"max_size_cycle"``. - - stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA) - `_. - Default: ``False``. - - .. deprecated:: v1.5 - ``stochastic_weight_avg`` has been deprecated in v1.5 and will be removed in v1.7. - Please pass :class:`~pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging` - directly to the Trainer's ``callbacks`` argument instead. """ super().__init__() Trainer._log_api_event("init") @@ -552,7 +542,6 @@ def __init__( weights_save_path, enable_model_summary, weights_summary, - stochastic_weight_avg, max_time, accumulate_grad_batches, ) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 0abac46732e97..4a79c462f2780 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -195,38 +195,6 @@ def test_swa_raises(): StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1]) -@pytest.mark.parametrize("stochastic_weight_avg", [False, True]) -@pytest.mark.parametrize("use_callbacks", [False, True]) -def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks: bool, stochastic_weight_avg: bool): - """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer.""" - - class TestModel(BoringModel): - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - return optimizer - - model = TestModel() - kwargs = { - "default_root_dir": tmpdir, - "callbacks": StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None, - "stochastic_weight_avg": stochastic_weight_avg, - "limit_train_batches": 4, - "limit_val_batches": 4, - "max_epochs": 2, - } - if stochastic_weight_avg: - with pytest.deprecated_call(match=r"stochastic_weight_avg=True\)` is deprecated in v1.5"): - trainer = Trainer(**kwargs) - else: - trainer = Trainer(**kwargs) - trainer.fit(model) - if use_callbacks or stochastic_weight_avg: - assert sum(1 for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)) == 1 - assert trainer.callbacks[0]._swa_lrs == [1e-3 if use_callbacks else 0.1] - else: - assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks) - - def test_swa_deepcopy(tmpdir): """Test to ensure SWA Callback doesn't deepcopy dataloaders and datamodule potentially leading to OOM.""" diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 548e45683c13f..b7f14cf510bf2 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -131,11 +131,6 @@ def test_v1_7_0_trainer_prepare_data_per_node(tmpdir): _ = Trainer(prepare_data_per_node=False) -def test_v1_7_0_stochastic_weight_avg_trainer_constructor(tmpdir): - with pytest.deprecated_call(match=r"Setting `Trainer\(stochastic_weight_avg=True\)` is deprecated in v1.5"): - _ = Trainer(stochastic_weight_avg=True) - - @pytest.mark.parametrize("terminate_on_nan", [True, False]) def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan): with pytest.deprecated_call( From f146648349a94a93d9f7faf0427a1f48819df6cd Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 30 Mar 2022 21:02:55 +0530 Subject: [PATCH 2/4] chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b0d288f7efc29..872757a1ab887 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,7 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -- +- Remove deprecated `stochastic_weight_averaging` flag from `Trainer` ([#12535](https://github.com/PyTorchLightning/pytorch-lightning/pull/12535)) - From 724500bee294b7e739a47b7b75c607b71d08a4d8 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 31 Mar 2022 08:37:16 +0530 Subject: [PATCH 3/4] Update CHANGELOG.md Co-authored-by: ananthsub <2382532+ananthsub@users.noreply.github.com> --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 507f678188927..a5167e1d4354e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,7 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -- Remove deprecated `stochastic_weight_averaging` flag from `Trainer` ([#12535](https://github.com/PyTorchLightning/pytorch-lightning/pull/12535)) +- Removed the deprecated `stochastic_weight_averaging` argument from the `Trainer` constructor ([#12535](https://github.com/PyTorchLightning/pytorch-lightning/pull/12535)) - Removed the deprecated `progress_bar_refresh_rate` argument from the `Trainer` constructor ([#12514](https://github.com/PyTorchLightning/pytorch-lightning/pull/12514)) From d30f90a0d7f14435eea79780d1f161fcda9395cc Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 31 Mar 2022 14:47:50 +0530 Subject: [PATCH 4/4] Update CHANGELOG.md Co-authored-by: ananthsub <2382532+ananthsub@users.noreply.github.com> --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5167e1d4354e..3e1ada8db632c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,7 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -- Removed the deprecated `stochastic_weight_averaging` argument from the `Trainer` constructor ([#12535](https://github.com/PyTorchLightning/pytorch-lightning/pull/12535)) +- Removed the deprecated `stochastic_weight_avg` argument from the `Trainer` constructor ([#12535](https://github.com/PyTorchLightning/pytorch-lightning/pull/12535)) - Removed the deprecated `progress_bar_refresh_rate` argument from the `Trainer` constructor ([#12514](https://github.com/PyTorchLightning/pytorch-lightning/pull/12514))