From bee213bfe811cac51fc3abe333d0837b45db1382 Mon Sep 17 00:00:00 2001
From: Ekagra Ranjan
Date: Wed, 1 Jun 2022 13:26:34 +0530
Subject: [PATCH 1/8] log_rank_zero_only flag

---
 src/pytorch_lightning/callbacks/early_stopping.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py
index 51bab7f98d941..94fd574feb7aa 100644
--- a/src/pytorch_lightning/callbacks/early_stopping.py
+++ b/src/pytorch_lightning/callbacks/early_stopping.py
@@ -63,6 +63,7 @@ class EarlyStopping(Callback):
         divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
         check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
             If this is ``False``, then the check runs at the end of the validation.
+        log_rank_zero_only: When set ``True``, logs the status of the early stopping callback only for rank 0 process.
 
     Raises:
         MisconfigurationException:
@@ -100,6 +101,7 @@ def __init__(
         stopping_threshold: Optional[float] = None,
         divergence_threshold: Optional[float] = None,
         check_on_train_epoch_end: Optional[bool] = None,
+        log_rank_zero_only: bool = False,
     ):
         super().__init__()
         self.monitor = monitor
@@ -114,6 +116,7 @@ def __init__(
         self.wait_count = 0
         self.stopped_epoch = 0
         self._check_on_train_epoch_end = check_on_train_epoch_end
+        self.log_rank_zero_only = log_rank_zero_only
 
         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -202,7 +205,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         if should_stop:
             self.stopped_epoch = trainer.current_epoch
         if reason and self.verbose:
-            self._log_info(trainer, reason)
+            self._log_info(trainer, reason, self.log_rank_zero_only)
 
     def _evaluate_stopping_criteria(self, current: Tensor) -> Tuple[bool, Optional[str]]:
         should_stop = False
@@ -255,8 +258,8 @@ def _improvement_message(self, current: Tensor) -> str:
         return msg
 
     @staticmethod
-    def _log_info(trainer: Optional["pl.Trainer"], message: str) -> None:
-        if trainer is not None and trainer.world_size > 1:
+    def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
+        if not log_rank_zero_only and trainer is not None and trainer.world_size > 1:
             log.info(f"[rank: {trainer.global_rank}] {message}")
         else:
             log.info(message)

From 8eebb4f63bc6c08974fd6570f8ac5cd8eec3d63e Mon Sep 17 00:00:00 2001
From: Ekagra Ranjan
Date: Mon, 6 Jun 2022 00:00:40 +0530
Subject: [PATCH 2/8] fix log_rank_zero_only flag

---
 src/pytorch_lightning/callbacks/early_stopping.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py
index 94fd574feb7aa..9627f2d9cd13d 100644
--- a/src/pytorch_lightning/callbacks/early_stopping.py
+++ b/src/pytorch_lightning/callbacks/early_stopping.py
@@ -259,7 +259,10 @@ def _improvement_message(self, current: Tensor) -> str:
 
     @staticmethod
     def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
-        if not log_rank_zero_only and trainer is not None and trainer.world_size > 1:
+        # ignore logging in non-zero ranks if log_rank_zero_only flag is enabled
+        if log_rank_zero_only and trainer.global_rank != 0:
+            return
+        if trainer is not None and trainer.world_size > 1:
             log.info(f"[rank: {trainer.global_rank}] {message}")
         else:
             log.info(message)

From f8713a58928f86b7bc75ac817c38bf54fa9bcd8e Mon Sep 17 00:00:00 2001
From: Ekagra Ranjan
Date: Mon, 6 Jun 2022 00:00:57 +0530
Subject: [PATCH 3/8] add test for log_rank_zero_only

---
 .../callbacks/test_early_stopping.py | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index f9e55e059d226..eaae0a4b97142 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -456,3 +456,31 @@ def test_early_stopping_squeezes():
         early_stopping._run_early_stopping_check(trainer)
 
     es_mock.assert_called_once_with(torch.tensor(0))
+
+
+@pytest.mark.parametrize(
+    "log_rank_zero_only, world_size, global_rank, expected_log",
+    [
+        (False, 1, 0, f"bar"),
+        (False, 2, 0, f"[rank: 0] bar"),
+        (False, 2, 1, f"[rank: 1] bar"),
+        (True, 1, 0, f"bar"),
+        (True, 2, 0, f"[rank: 0] bar"),
+        (True, 2, 1, None),
+    ],
+)
+def test_early_stopping_log_info(tmpdir, log_rank_zero_only, world_size, global_rank, expected_log):
+    """checks if log.info() gets called with expected message when used within EarlyStopping"""
+
+    early_stopping = EarlyStopping(monitor="foo")
+    trainer = Trainer()
+    trainer.strategy.global_rank = global_rank
+    trainer.strategy.world_size = world_size
+
+    with mock.patch("pytorch_lightning.callbacks.early_stopping.log.info") as log_mock:
+        early_stopping._log_info(trainer, "bar", log_rank_zero_only)
+
+    if expected_log:
+        log_mock.assert_called_once_with(expected_log)
+    else:
+        log_mock.assert_not_called()
\ No newline at end of file

From ae11a63e12ffeed64d10c2cf205be85233a7995e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 5 Jun 2022 18:45:30 +0000
Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/callbacks/early_stopping.py    |  2 +-
 tests/tests_pytorch/callbacks/test_early_stopping.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py
index 9627f2d9cd13d..e9a24ed758c9e 100644
--- a/src/pytorch_lightning/callbacks/early_stopping.py
+++ b/src/pytorch_lightning/callbacks/early_stopping.py
@@ -259,7 +259,7 @@ def _improvement_message(self, current: Tensor) -> str:
 
     @staticmethod
     def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
-        # ignore logging in non-zero ranks if log_rank_zero_only flag is enabled 
+        # ignore logging in non-zero ranks if log_rank_zero_only flag is enabled
         if log_rank_zero_only and trainer.global_rank != 0:
             return
         if trainer is not None and trainer.world_size > 1:
diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index eaae0a4b97142..e81dafb506682 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -470,17 +470,17 @@ def test_early_stopping_squeezes():
     ],
 )
 def test_early_stopping_log_info(tmpdir, log_rank_zero_only, world_size, global_rank, expected_log):
-    """checks if log.info() gets called with expected message when used within EarlyStopping"""
-
+    """checks if log.info() gets called with expected message when used within EarlyStopping."""
+
     early_stopping = EarlyStopping(monitor="foo")
     trainer = Trainer()
     trainer.strategy.global_rank = global_rank
     trainer.strategy.world_size = world_size
-
+
     with mock.patch("pytorch_lightning.callbacks.early_stopping.log.info") as log_mock:
         early_stopping._log_info(trainer, "bar", log_rank_zero_only)
-
+
     if expected_log:
         log_mock.assert_called_once_with(expected_log)
     else:
-        log_mock.assert_not_called()
\ No newline at end of file
+        log_mock.assert_not_called()

From ef83958742f406aa6f09fd7b32c24e3d80b77c3d Mon Sep 17 00:00:00 2001
From: Ekagra Ranjan
Date: Mon, 6 Jun 2022 00:17:06 +0530
Subject: [PATCH 5/8] update changelog.md

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f4623e03599c8..2f76472f9d1af 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added a flag named `log_rank_zero_only` to `EarlyStopping` to disable logging to non-zero rank processes ([#13233](https://github.com/PyTorchLightning/pytorch-lightning/pull/13233))
+
+
 - Added support for reloading the last checkpoint saved by passing `ckpt_path="last"` ([#12816](https://github.com/PyTorchLightning/pytorch-lightning/pull/12816))
 
 

From 79afdaeed2ccc8709da41d36165a40f812fa94c2 Mon Sep 17 00:00:00 2001
From: Ekagra Ranjan
Date: Mon, 6 Jun 2022 00:57:52 +0530
Subject: [PATCH 6/8] add test for Trainer=None

---
 .../callbacks/early_stopping.py      | 16 ++++++------
 .../callbacks/test_early_stopping.py | 25 ++++++++++++-------
 2 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py
index e9a24ed758c9e..f58d2cd39ce4c 100644
--- a/src/pytorch_lightning/callbacks/early_stopping.py
+++ b/src/pytorch_lightning/callbacks/early_stopping.py
@@ -259,10 +259,12 @@ def _improvement_message(self, current: Tensor) -> str:
 
     @staticmethod
     def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
-        # ignore logging in non-zero ranks if log_rank_zero_only flag is enabled
-        if log_rank_zero_only and trainer.global_rank != 0:
-            return
-        if trainer is not None and trainer.world_size > 1:
-            log.info(f"[rank: {trainer.global_rank}] {message}")
-        else:
-            log.info(message)
+        if trainer:
+            # ignore logging in non-zero ranks if log_rank_zero_only flag is enabled
+            if log_rank_zero_only and trainer.global_rank != 0:
+                return
+            # if world size is more than one then specify the rank of the processed being logged
+            if trainer.world_size > 1:
+                log.info(f"[rank: {trainer.global_rank}] {message}")
+                return
+        log.info(message)
diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index e81dafb506682..b6035b5e646a0 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -458,28 +458,35 @@ def test_early_stopping_squeezes():
     es_mock.assert_called_once_with(torch.tensor(0))
 
 
+@pytest.mark.parametrize("trainer", [Trainer(), None])
 @pytest.mark.parametrize(
     "log_rank_zero_only, world_size, global_rank, expected_log",
     [
-        (False, 1, 0, f"bar"),
-        (False, 2, 0, f"[rank: 0] bar"),
-        (False, 2, 1, f"[rank: 1] bar"),
-        (True, 1, 0, f"bar"),
-        (True, 2, 0, f"[rank: 0] bar"),
+        (False, 1, 0, "bar"),
+        (False, 2, 0, "[rank: 0] bar"),
+        (False, 2, 1, "[rank: 1] bar"),
+        (True, 1, 0, "bar"),
+        (True, 2, 0, "[rank: 0] bar"),
         (True, 2, 1, None),
     ],
 )
-def test_early_stopping_log_info(tmpdir, log_rank_zero_only, world_size, global_rank, expected_log):
+def test_early_stopping_log_info(tmpdir, trainer, log_rank_zero_only, world_size, global_rank, expected_log):
     """checks if log.info() gets called with expected message when used within EarlyStopping."""
 
     early_stopping = EarlyStopping(monitor="foo")
-    trainer = Trainer()
-    trainer.strategy.global_rank = global_rank
-    trainer.strategy.world_size = world_size
+
+    # set the global_rank and world_size if trainer is not None
+    # or else always expect the simple logging message
+    if trainer:
+        trainer.strategy.global_rank = global_rank
+        trainer.strategy.world_size = world_size
+    else:
+        expected_log = "bar"
 
     with mock.patch("pytorch_lightning.callbacks.early_stopping.log.info") as log_mock:
         early_stopping._log_info(trainer, "bar", log_rank_zero_only)
 
+    # check log.info() was called or not with expected arg
     if expected_log:
         log_mock.assert_called_once_with(expected_log)
     else:

From 139c6713561f02755acf99f67381525e68f592c5 Mon Sep 17 00:00:00 2001
From: Ekagra Ranjan
Date: Tue, 7 Jun 2022 08:44:23 +0530
Subject: [PATCH 7/8] _log_info is static

---
 tests/tests_pytorch/callbacks/test_early_stopping.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index b6035b5e646a0..96a6f29be86d9 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -473,8 +473,6 @@ def test_early_stopping_squeezes():
 def test_early_stopping_log_info(tmpdir, trainer, log_rank_zero_only, world_size, global_rank, expected_log):
     """checks if log.info() gets called with expected message when used within EarlyStopping."""
 
-    early_stopping = EarlyStopping(monitor="foo")
-
     # set the global_rank and world_size if trainer is not None
     # or else always expect the simple logging message
     if trainer:
@@ -484,7 +482,7 @@ def test_early_stopping_log_info(tmpdir, trainer, log_rank_zero_only, world_size
         expected_log = "bar"
 
     with mock.patch("pytorch_lightning.callbacks.early_stopping.log.info") as log_mock:
-        early_stopping._log_info(trainer, "bar", log_rank_zero_only)
+        EarlyStopping._log_info(trainer, "bar", log_rank_zero_only)
 
     # check log.info() was called or not with expected arg
     if expected_log:

From 69e1007f426a4421f19f8eb942c38142f43c7bc2 Mon Sep 17 00:00:00 2001
From: Ekagra Ranjan
Date: Tue, 7 Jun 2022 08:48:52 +0530
Subject: [PATCH 8/8] fix typo comment

---
 src/pytorch_lightning/callbacks/early_stopping.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py
index f58d2cd39ce4c..2fd730482fcc4 100644
--- a/src/pytorch_lightning/callbacks/early_stopping.py
+++ b/src/pytorch_lightning/callbacks/early_stopping.py
@@ -263,8 +263,10 @@ def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only:
             # ignore logging in non-zero ranks if log_rank_zero_only flag is enabled
             if log_rank_zero_only and trainer.global_rank != 0:
                 return
-            # if world size is more than one then specify the rank of the processed being logged
+            # if world size is more than one then specify the rank of the process being logged
             if trainer.world_size > 1:
                 log.info(f"[rank: {trainer.global_rank}] {message}")
                 return
+
+        # if above conditions don't meet and we have to log
        log.info(message)
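
Usage sketch (a minimal illustration, not taken from the patches themselves): once the series is applied, the new `log_rank_zero_only` flag is passed to `EarlyStopping` next to its existing arguments. The metric name "val_loss" and the `model`/`datamodule` names are placeholders assumed to be defined by the user; the `Trainer` arguments are ordinary Lightning options picked only to produce a multi-rank run.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # With verbose=True, EarlyStopping logs its status whenever the check runs.
    # Under DDP every rank would normally emit its own "[rank: N] ..." line;
    # log_rank_zero_only=True keeps only the rank 0 message.
    early_stop = EarlyStopping(
        monitor="val_loss",  # placeholder metric name
        mode="min",
        patience=3,
        verbose=True,
        log_rank_zero_only=True,
    )

    trainer = Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp",
        max_epochs=10,
        callbacks=[early_stop],
    )
    # trainer.fit(model, datamodule=datamodule)  # model/datamodule assumed to exist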