diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 8d6fd64a2cb47..469832b23c49b 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -174,6 +174,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed +- Removed deprecated `pytorch_lightning.core.decorators.parameter_validation` from `decorators` ([#13514](https://github.com/Lightning-AI/lightning/pull/13514)) + + - Removed the deprecated `Logger.close` method ([#13149](https://github.com/PyTorchLightning/pytorch-lightning/pull/13149)) diff --git a/src/pytorch_lightning/core/decorators.py b/src/pytorch_lightning/core/decorators.py deleted file mode 100644 index 33c83b4b10d6d..0000000000000 --- a/src/pytorch_lightning/core/decorators.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn - -rank_zero_deprecation( - "Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, " - "and will be removed in v1.7. It has been replaced by automatic parameters tying with " - "`pytorch_lightning.utilities.params_tying.set_shared_parameters`" -) - -from functools import wraps # noqa: E402 -from typing import Callable # noqa: E402 - - -def parameter_validation(fn: Callable) -> Callable: - """Validates that the module parameter lengths match after moving to the device. It is useful when tying - weights on TPU's. - - Args: - fn: ``model_to_device`` method - - Note: - TPU's require weights to be tied/shared after moving the module to the device. - Failure to do this results in the initialization of new weights which are not tied. - To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook - which is called after the module has been moved to the device. - - See Also: - - `XLA Documentation `_ - """ - - @wraps(fn) - def inner_fn(self, *args, **kwargs): - pre_layer_count = len(list(self.model.parameters())) - module = fn(self, *args, **kwargs) - self.model.on_post_move_to_device() - post_layer_count = len(list(self.model.parameters())) - - if not pre_layer_count == post_layer_count: - rank_zero_warn( - "The model layers do not match after moving to the target device." - " If your model employs weight sharing on TPU," - " please tie your weights using the `on_post_move_to_device` model hook.\n" - f"Layer count: [Before: {pre_layer_count} After: {post_layer_count}]" - ) - - return module - - return inner_fn diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-7.py b/tests/tests_pytorch/deprecated_api/test_remove_1-7.py index 2ae305d2c06b7..17cccbfa80a5e 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-7.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-7.py @@ -31,7 +31,6 @@ TorchElasticEnvironment, ) from pytorch_lightning.strategies import SingleDeviceStrategy -from tests_pytorch.deprecated_api import _soft_unimport_module from tests_pytorch.plugins.environments.test_lsf_environment import _make_rankfile @@ -76,15 +75,6 @@ def on_post_move_to_device(self): trainer.fit(model) -def test_v1_7_0_deprecate_parameter_validation(): - - _soft_unimport_module("pytorch_lightning.core.decorators") - with pytest.deprecated_call( - match="Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5" - ): - from pytorch_lightning.core.decorators import parameter_validation # noqa: F401 - - def test_v1_7_0_deprecated_slurm_job_id(): trainer = Trainer() with pytest.deprecated_call(match="Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0."):