Remove rank 0 restrictions from logger (#8608)
edward-io authored Aug 6, 2021
1 parent 4928dc5 commit 8473cf4
Showing 3 changed files with 49 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -53,6 +53,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
 
+- Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608](https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))
+
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
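Note on the entry above: the trainer itself no longer restricts logging to the global-zero process, and the built-in loggers keep their old behavior because they already guard their own methods (for example with `rank_zero_only`). A custom logger that should still write only from the main process can do the same. A minimal sketch, assuming pytorch-lightning 1.4-era import paths; the class name `RankZeroOnlyDictLogger` is illustrative and not part of the commit:

```python
# Hypothetical custom logger (not from this commit): it keeps the previous
# "log only from rank 0" behavior by guarding its methods itself, since the
# trainer no longer does so. Assumes pytorch-lightning ~1.4 import paths.
from typing import Any, Dict, Optional, Union

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only


class RankZeroOnlyDictLogger(LightningLoggerBase):
    """Stores metrics in a dict, but only on the global-zero process."""

    def __init__(self):
        super().__init__()
        self.history: Dict[str, float] = {}

    @property
    def experiment(self) -> Any:
        return self.history

    @rank_zero_only  # no-op on every rank except global rank 0
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        self.history.update(metrics)

    @rank_zero_only
    def log_hyperparams(self, params, *args, **kwargs) -> None:
        pass

    @property
    def name(self) -> str:
        return "rank_zero_only_dict"

    @property
    def version(self) -> Union[int, str]:
        return 0
```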
@@ -102,9 +102,8 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -
             step = self.trainer.global_step
 
         # log actual metrics
-        if self.trainer.is_global_zero:
-            self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
-            self.trainer.logger.save()
+        self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
+        self.trainer.logger.save()
 
         self._logged_metrics.update(scalar_metrics)

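Context for the hunk above (the trainer's `log_metrics`, going by the hunk header): `agg_and_log_metrics` and `save` were previously called only when `trainer.is_global_zero`; after this change the connector calls the logger on every process. One thing this enables is loggers that take part in collective communication. A hedged sketch (not from this commit) of a logger that averages each scalar across ranks, assuming `torch.distributed` is initialized with a CPU-capable backend such as gloo and that every rank logs the same keys in the same order:

```python
# Hypothetical logger (not from this commit) that relies on log_metrics being
# invoked on every rank: it all-reduces each scalar and stores the mean.
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed as dist

from pytorch_lightning.loggers import LightningLoggerBase


class CrossRankMeanLogger(LightningLoggerBase):
    """Records the mean of each logged scalar across all processes."""

    def __init__(self):
        super().__init__()
        self.history: Dict[str, float] = {}

    @property
    def experiment(self) -> Any:
        return self.history

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        for key, value in metrics.items():
            tensor = torch.tensor(float(value))
            if dist.is_available() and dist.is_initialized():
                # collective call: works only because every rank reaches this
                # point with the same keys in the same order
                dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
                tensor /= dist.get_world_size()
            self.history[key] = tensor.item()

    def log_hyperparams(self, params, *args, **kwargs) -> None:
        pass

    @property
    def name(self) -> str:
        return "cross_rank_mean"

    @property
    def version(self) -> Union[int, str]:
        return 0
```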
56 changes: 43 additions & 13 deletions tests/trainer/logging_/test_distributed_logging.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 from typing import Any, Dict, Optional, Union
-from unittest import mock
 from unittest.mock import Mock
 
 import pytorch_lightning as pl
@@ -23,23 +22,50 @@
 from tests.helpers.runif import RunIf
 
 
+class AllRankLogger(LightningLoggerBase):
+    """
+    Logger to test all-rank logging (i.e. not just rank 0).
+    Logs are saved to local variable `logs`.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.logs = {}
+        self.exp = object()
+
+    def experiment(self) -> Any:
+        return self.exp
+
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+        self.logs.update(metrics)
+
+    def version(self) -> Union[int, str]:
+        return 1
+
+    def name(self) -> str:
+        return "AllRank"
+
+    def log_hyperparams(self, *args, **kwargs) -> None:
+        pass
+
+
 class TestModel(BoringModel):
-    def on_pretrain_routine_end(self) -> None:
-        with mock.patch("pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics") as m:
-            self.trainer.logger_connector.log_metrics({"a": 2})
-            logged_times = m.call_count
-            expected = int(self.trainer.is_global_zero)
-            msg = f"actual logger called from non-global zero, logged_times: {logged_times}, expected: {expected}"
-            assert logged_times == expected, msg
+    log_name = "rank-{rank}"
+
+    def on_train_start(self):
+        self.log(self.log_name.format(rank=self.local_rank), 0)
+
+    def on_train_end(self):
+        assert self.log_name.format(rank=self.local_rank) in self.logger.logs, "Expected rank to be logged"
 
 
 @RunIf(skip_windows=True)
-def test_global_zero_only_logging_ddp_cpu(tmpdir):
+def test_all_rank_logging_ddp_cpu(tmpdir):
     """
-    Makes sure logging only happens from root zero
+    Check that all ranks can be logged from
     """
     model = TestModel()
     model.training_epoch_end = None
+    all_rank_logger = AllRankLogger()
     trainer = Trainer(
         accelerator="ddp_cpu",
         num_processes=2,
@@ -48,16 +74,19 @@ def test_global_zero_only_logging_ddp_cpu(tmpdir):
         limit_val_batches=1,
         max_epochs=1,
         weights_summary=None,
+        logger=all_rank_logger,
+        log_every_n_steps=1,
     )
     trainer.fit(model)
 
 
 @RunIf(min_gpus=2)
-def test_global_zero_only_logging_ddp_spawn(tmpdir):
+def test_all_rank_logging_ddp_spawn(tmpdir):
     """
-    Makes sure logging only happens from root zero
+    Check that all ranks can be logged from
     """
     model = TestModel()
+    all_rank_logger = AllRankLogger()
     model.training_epoch_end = None
     trainer = Trainer(
         accelerator="ddp_spawn",
@@ -66,6 +95,7 @@ def test_global_zero_only_logging_ddp_spawn(tmpdir):
         limit_train_batches=1,
         limit_val_batches=1,
         max_epochs=1,
+        logger=all_rank_logger,
         weights_summary=None,
     )
     trainer.fit(model)
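For completeness, a hedged sketch of exercising a per-rank logger like the test's `AllRankLogger` outside the test suite. It assumes the repository root is on the Python path so the test helpers can be imported, and it reuses the same 1.4-era `Trainer` arguments as the test; since `ddp_cpu` spawns worker processes, the check happens inside each process rather than in the parent after `fit()`:

```python
# Hypothetical standalone script (not part of the commit).
from pytorch_lightning import Trainer
from tests.helpers import BoringModel
from tests.trainer.logging_.test_distributed_logging import AllRankLogger


class PerRankModel(BoringModel):
    def on_train_start(self):
        # every process logs under its own key, e.g. "rank-0" and "rank-1"
        self.log(f"rank-{self.local_rank}", float(self.local_rank))

    def on_train_end(self):
        # each spawned process inspects its own copy of the logger
        assert f"rank-{self.local_rank}" in self.logger.logs


if __name__ == "__main__":
    trainer = Trainer(
        accelerator="ddp_cpu",
        num_processes=2,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        log_every_n_steps=1,
        logger=AllRankLogger(),
        weights_summary=None,
    )
    trainer.fit(PerRankModel())
```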