diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a05430bbeb06..ae5d90eae5341 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -454,6 +454,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102)) +- Deprecated `LoggerCollection` in favor of `trainer.loggers` ([#12147](https://github.com/PyTorchLightning/pytorch-lightning/pull/12147)) + + - Deprecated `PrecisionPlugin.on_{save,load}_checkpoint` in favor of `PrecisionPlugin.{state_dict,load_state_dict}` ([#11978](https://github.com/PyTorchLightning/pytorch-lightning/pull/11978)) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index f0a8ba13dbdcd..dc50d48b23670 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -221,6 +221,10 @@ def version(self) -> Union[int, str]: class LoggerCollection(LightningLoggerBase): """The :class:`LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`. + .. deprecated:: v1.6 + `LoggerCollection` is deprecated in v1.6 and will be removed in v1.8. + Directly pass a list of loggers to the Trainer and access the list via the `trainer.loggers` attribute. + Args: logger_iterable: An iterable collection of loggers """ @@ -228,6 +232,10 @@ class LoggerCollection(LightningLoggerBase): def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): super().__init__() self._logger_iterable = logger_iterable + rank_zero_deprecation( + "`LoggerCollection` is deprecated in v1.6 and will be removed in v1.8. Directly pass a list of loggers" + " to the Trainer and access the list via the `trainer.loggers` attribute." + ) def __getitem__(self, index: int) -> LightningLoggerBase: return list(self._logger_iterable)[index] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4c01e58e9e3d5..881b62271ec73 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2623,7 +2623,9 @@ def logger(self) -> Optional[LightningLoggerBase]: " This behavior will change in v1.8 when LoggerCollection is removed, and" " trainer.logger will return the first logger in trainer.loggers" ) - return LoggerCollection(self.loggers) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return LoggerCollection(self.loggers) @logger.setter def logger(self, logger: Optional[LightningLoggerBase]) -> None: diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 70cf8865bed6c..3c8923a96c3ad 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -21,7 +21,7 @@ from torch import optim from pytorch_lightning import Callback, Trainer -from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase +from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin @@ -662,6 +662,23 @@ def _get_python_cprofile_total_duration(profile): np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2) +def test_v1_8_0_logger_collection(tmpdir): + logger1 = CSVLogger(tmpdir) + logger2 = CSVLogger(tmpdir) + + trainer1 = Trainer(logger=logger1) + trainer2 = Trainer(logger=[logger1, logger2]) + + # Should have no deprecation warning + trainer1.logger + trainer1.loggers + trainer2.loggers + trainer2.logger + + with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): + LoggerCollection([logger1, logger2]) + + def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir): class PrecisionPluginSaveHook(PrecisionPlugin): def on_save_checkpoint(self, checkpoint): diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index cd7eec14eec77..ff93f9933d74b 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -34,7 +34,8 @@ def test_logger_collection(): mock1 = MagicMock() mock2 = MagicMock() - logger = LoggerCollection([mock1, mock2]) + with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): + logger = LoggerCollection([mock1, mock2]) assert logger[0] == mock1 assert logger[1] == mock2 @@ -62,14 +63,16 @@ def test_logger_collection_unique_names(): logger1 = CustomLogger(name=unique_name) logger2 = CustomLogger(name=unique_name) - logger = LoggerCollection([logger1, logger2]) + with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): + logger = LoggerCollection([logger1, logger2]) assert logger.name == unique_name def test_logger_collection_names_order(): loggers = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")] - logger = LoggerCollection(loggers) + with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): + logger = LoggerCollection(loggers) assert logger.name == f"{loggers[0].name}_{loggers[1].name}_{loggers[3].name}" @@ -78,14 +81,16 @@ def test_logger_collection_unique_versions(): logger1 = CustomLogger(version=unique_version) logger2 = CustomLogger(version=unique_version) - logger = LoggerCollection([logger1, logger2]) + with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): + logger = LoggerCollection([logger1, logger2]) assert logger.version == unique_version def test_logger_collection_versions_order(): loggers = [CustomLogger(version=v) for v in ("1", "2", "1", "3")] - logger = LoggerCollection(loggers) + with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): + logger = LoggerCollection(loggers) assert logger.version == f"{loggers[0].version}_{loggers[1].version}_{loggers[3].version}" diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 161c9e6e35670..2971ba7bc2c19 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -24,7 +24,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging -from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger +from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -450,9 +450,9 @@ def test_pytorch_profiler_nested(tmpdir): assert events_name == expected, (events_name, torch.__version__, platform.system()) -def test_pytorch_profiler_logger_collection(tmpdir): - """Tests whether the PyTorch profiler is able to write its trace locally when the Trainer's logger is an - instance of LoggerCollection. +def test_pytorch_profiler_multiple_loggers(tmpdir): + """Tests whether the PyTorch profiler is able to write its trace locally when the Trainer is configured with + multiple loggers. See issue #8157. """ @@ -465,10 +465,9 @@ def look_for_trace(trace_dir): assert not look_for_trace(tmpdir) model = BoringModel() - # Wrap the logger in a list so it becomes a LoggerCollection - logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)] - trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1) - assert isinstance(trainer.logger, LoggerCollection) + loggers = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)] + trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=loggers, limit_train_batches=5, max_epochs=1) + assert len(trainer.loggers) == 2 trainer.fit(model) assert look_for_trace(tmpdir) diff --git a/tests/trainer/properties/test_log_dir.py b/tests/trainer/properties/test_log_dir.py index 6777ec8183737..db2f862f82e7b 100644 --- a/tests/trainer/properties/test_log_dir.py +++ b/tests/trainer/properties/test_log_dir.py @@ -15,7 +15,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger +from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger from tests.helpers.boring_model import BoringModel @@ -109,8 +109,8 @@ def test_logdir_custom_logger(tmpdir): assert trainer.log_dir == expected -def test_logdir_logger_collection(tmpdir): - """Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection.""" +def test_logdir_multiple_loggers(tmpdir): + """Tests that the logdir equals the default_root_dir when trainer has multiple loggers.""" default_root_dir = tmpdir / "default_root_dir" save_dir = tmpdir / "save_dir" model = TestModel(default_root_dir) @@ -119,7 +119,6 @@ def test_logdir_logger_collection(tmpdir): max_steps=2, logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)], ) - assert isinstance(trainer.logger, LoggerCollection) assert trainer.log_dir == default_root_dir trainer.fit(model) diff --git a/tests/trainer/properties/test_loggers.py b/tests/trainer/properties/test_loggers.py index d3db78986f361..b48f3f2768175 100644 --- a/tests/trainer/properties/test_loggers.py +++ b/tests/trainer/properties/test_loggers.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from pytorch_lightning import Trainer from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from tests.loggers.test_base import CustomLogger @@ -50,8 +52,10 @@ def test_trainer_loggers_setters(): """Test the behavior of setters for trainer.logger and trainer.loggers.""" logger1 = CustomLogger() logger2 = CustomLogger() - logger_collection = LoggerCollection([logger1, logger2]) - logger_collection_2 = LoggerCollection([logger2]) + with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): + logger_collection = LoggerCollection([logger1, logger2]) + with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"): + logger_collection_2 = LoggerCollection([logger2]) trainer = Trainer() assert type(trainer.logger) == TensorBoardLogger