Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate LoggerCollection in favor of trainer.loggers #12147

Merged
merged 12 commits into from
Mar 4, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))


- Deprecated `LoggerCollection` in favor of `trainer.loggers` ([#12147](https://github.com/PyTorchLightning/pytorch-lightning/pull/12147))


- Deprecated `PrecisionPlugin.on_{save,load}_checkpoint` in favor of `PrecisionPlugin.{state_dict,load_state_dict}` ([#11978](https://github.com/PyTorchLightning/pytorch-lightning/pull/11978))


Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,19 @@ def version(self) -> Union[int, str]:
class LoggerCollection(LightningLoggerBase):
"""The :class:`LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`.

.. deprecated:: v1.6
`LoggerCollection` is deprecated in v1.6 in favor of `trainer.loggers` and will be removed in v1.8.
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

Args:
logger_iterable: An iterable collection of loggers
"""

def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
super().__init__()
self._logger_iterable = logger_iterable
rank_zero_deprecation(
"`LoggerCollection` is deprecated in v1.6 in favor of `trainer.loggers` and will be removed in v1.8."
akashkw marked this conversation as resolved.
Show resolved Hide resolved
)

def __getitem__(self, index: int) -> LightningLoggerBase:
return list(self._logger_iterable)[index]
Expand Down
21 changes: 20 additions & 1 deletion tests/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch import optim

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
Expand Down Expand Up @@ -662,6 +662,25 @@ def _get_python_cprofile_total_duration(profile):
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)


def test_v1_8_0_logger_collection(tmpdir):
logger1 = CSVLogger(tmpdir)
logger2 = CSVLogger(tmpdir)

trainer1 = Trainer(logger=logger1)
trainer2 = Trainer(logger=[logger1, logger2])

# Should have no deprecation warning
trainer1.logger
trainer1.loggers
trainer2.loggers

with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
trainer2.logger

with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
LoggerCollection([logger1, logger2])


def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
class PrecisionPluginSaveHook(PrecisionPlugin):
def on_save_checkpoint(self, checkpoint):
Expand Down
15 changes: 10 additions & 5 deletions tests/loggers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def test_logger_collection():
mock1 = MagicMock()
mock2 = MagicMock()

logger = LoggerCollection([mock1, mock2])
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection([mock1, mock2])

assert logger[0] == mock1
assert logger[1] == mock2
Expand Down Expand Up @@ -62,14 +63,16 @@ def test_logger_collection_unique_names():
logger1 = CustomLogger(name=unique_name)
logger2 = CustomLogger(name=unique_name)

logger = LoggerCollection([logger1, logger2])
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection([logger1, logger2])

assert logger.name == unique_name


def test_logger_collection_names_order():
loggers = [CustomLogger(name=n) for n in ("name1", "name2", "name1", "name3")]
logger = LoggerCollection(loggers)
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection(loggers)
assert logger.name == f"{loggers[0].name}_{loggers[1].name}_{loggers[3].name}"


Expand All @@ -78,14 +81,16 @@ def test_logger_collection_unique_versions():
logger1 = CustomLogger(version=unique_version)
logger2 = CustomLogger(version=unique_version)

logger = LoggerCollection([logger1, logger2])
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection([logger1, logger2])

assert logger.version == unique_version


def test_logger_collection_versions_order():
loggers = [CustomLogger(version=v) for v in ("1", "2", "1", "3")]
logger = LoggerCollection(loggers)
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger = LoggerCollection(loggers)
assert logger.version == f"{loggers[0].version}_{loggers[1].version}_{loggers[3].version}"


Expand Down
15 changes: 7 additions & 8 deletions tests/profiler/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -450,9 +450,9 @@ def test_pytorch_profiler_nested(tmpdir):
assert events_name == expected, (events_name, torch.__version__, platform.system())


def test_pytorch_profiler_logger_collection(tmpdir):
"""Tests whether the PyTorch profiler is able to write its trace locally when the Trainer's logger is an
instance of LoggerCollection.
def test_pytorch_profiler_multiple_loggers(tmpdir):
"""Tests whether the PyTorch profiler is able to write its trace locally when the Trainer is configured with
multiple loggers.

See issue #8157.
"""
Expand All @@ -465,10 +465,9 @@ def look_for_trace(trace_dir):
assert not look_for_trace(tmpdir)

model = BoringModel()
# Wrap the logger in a list so it becomes a LoggerCollection
logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
assert isinstance(trainer.logger, LoggerCollection)
loggers = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=loggers, limit_train_batches=5, max_epochs=1)
assert len(trainer.loggers) == 2
trainer.fit(model)
assert look_for_trace(tmpdir)

Expand Down
7 changes: 3 additions & 4 deletions tests/trainer/properties/test_log_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from tests.helpers.boring_model import BoringModel


Expand Down Expand Up @@ -109,8 +109,8 @@ def test_logdir_custom_logger(tmpdir):
assert trainer.log_dir == expected


def test_logdir_logger_collection(tmpdir):
"""Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection."""
def test_logdir_multiple_loggers(tmpdir):
"""Tests that the logdir equals the default_root_dir when trainer has multiple loggers."""
default_root_dir = tmpdir / "default_root_dir"
save_dir = tmpdir / "save_dir"
model = TestModel(default_root_dir)
Expand All @@ -119,7 +119,6 @@ def test_logdir_logger_collection(tmpdir):
max_steps=2,
logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)],
)
assert isinstance(trainer.logger, LoggerCollection)
assert trainer.log_dir == default_root_dir

trainer.fit(model)
Expand Down
14 changes: 10 additions & 4 deletions tests/trainer/properties/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from tests.loggers.test_base import CustomLogger
Expand Down Expand Up @@ -50,8 +52,10 @@ def test_trainer_loggers_setters():
"""Test the behavior of setters for trainer.logger and trainer.loggers."""
logger1 = CustomLogger()
logger2 = CustomLogger()
logger_collection = LoggerCollection([logger1, logger2])
logger_collection_2 = LoggerCollection([logger2])
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger_collection = LoggerCollection([logger1, logger2])
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
logger_collection_2 = LoggerCollection([logger2])
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

trainer = Trainer()
assert type(trainer.logger) == TensorBoardLogger
Expand All @@ -63,7 +67,8 @@ def test_trainer_loggers_setters():
assert trainer.loggers == [logger1]

trainer.logger = logger_collection
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
assert trainer.loggers == [logger1, logger2]

# LoggerCollection of size 1 should result in trainer.logger becoming the contained logger.
Expand All @@ -78,7 +83,8 @@ def test_trainer_loggers_setters():
# Test setters for trainer.loggers
trainer.loggers = [logger1, logger2]
assert trainer.loggers == [logger1, logger2]
assert trainer.logger._logger_iterable == logger_collection._logger_iterable
with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
assert trainer.logger._logger_iterable == logger_collection._logger_iterable

trainer.loggers = [logger1]
assert trainer.loggers == [logger1]
Expand Down