diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py
index d4c3c916c7ed0..fe0c5b2d5e46f 100644
--- a/src/lightning/pytorch/callbacks/progress/rich_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -646,15 +646,14 @@ def _update_metrics(
         current: Optional[int] = None,
         total_batches: bool = False,
     ) -> None:
-        if not self.is_enabled or self._metric_component is None:
-            return
-
         if current is not None and not total_batches:
             total = self.total_train_batches
             if not self._should_update(current, total):
                 return
 
         metrics = self.get_metrics(trainer, pl_module)
+        if not self.is_enabled or self._metric_component is None:
+            return
         if self._metric_component:
             self._metric_component.update(metrics)
 
diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
index 74abb8ecd850c..228362800fff3 100644
--- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -282,8 +282,9 @@ def on_train_batch_end(
 
     @override
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        metrics = self.get_metrics(trainer, pl_module)
         if not self.train_progress_bar.disable:
-            self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+            self.train_progress_bar.set_postfix(metrics)
         if self._leave:
             self.train_progress_bar.close()
 
diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
index 9d74871ce84e4..28d34270c3668 100644
--- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import pickle
 from collections import defaultdict
 from unittest import mock
-from unittest.mock import DEFAULT, Mock
+from unittest.mock import DEFAULT, Mock, patch
 
 import pytest
 from tests_pytorch.helpers.runif import RunIf
@@ -26,6 +27,7 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.loggers.logger import DummyLogger
+from lightning.pytorch.strategies import DDPStrategy
 
 
 @RunIf(rich=True)
@@ -605,3 +607,57 @@ def val_dataloader(self):
 
     # This should not raise an AssertionError
     trainer.fit(model)
+
+
+@RunIf(rich=True)
+def test_rich_progress_bar_ddp_deadlock(tmp_path):
+    """Test that RichProgressBar does not deadlock on train epoch end when using DDP.
+
+    We used to have a bug where metrics were synced only on the rank 0 process. See
+    https://github.com/Lightning-AI/pytorch-lightning/issues/21264
+    for more details.
+
+    """
+    pbar = RichProgressBar()
+
+    # We need a LightningModule that logs a metric with on_epoch=True and sync_dist=True
+    class MyModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)["loss"]
+            self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
+            return {"loss": loss}
+
+    model = MyModel()
+
+    # We mock these logger connector hooks since they also sync metrics across ranks and
+    # would otherwise mask an incorrect implementation of RichProgressBar._update_metrics.
+    def mock_on_epoch_end(self):
+        pass
+
+    def mock_update_train_epoch_metrics(self):
+        pass
+
+    with (
+        patch("lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.on_epoch_end", mock_on_epoch_end),
+        patch(
+            "lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.update_train_epoch_metrics",
+            mock_update_train_epoch_metrics,
+        ),
+    ):
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            num_sanity_val_steps=0,
+            max_epochs=1,
+            val_check_interval=1,
+            accelerator="cpu",
+            devices=2,
+            strategy=DDPStrategy(
+                process_group_backend="gloo",  # run on CPU
+                timeout=datetime.timedelta(seconds=5),  # time out quickly so a deadlock fails the test
+            ),
+            callbacks=[pbar],
+            enable_progress_bar=True,
+            enable_model_summary=False,
+            enable_checkpointing=False,
+        )
+        trainer.fit(model)
diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
index 0bd29b998c598..617fb5f71c5de 100644
--- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import math
 import os
 import pickle
@@ -32,6 +33,7 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.loggers.logger import DummyLogger
+from lightning.pytorch.strategies import DDPStrategy
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 
 
@@ -859,3 +861,57 @@ def reset(self, total=None):
     assert 2 in val_bar.total_values, (
         f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
     )
+
+
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+def test_tqdm_progress_bar_ddp_deadlock(tmp_path):
+    """Test that TQDMProgressBar does not deadlock on train epoch end when using DDP.
+
+    We used to have a bug where metrics were synced only on the rank 0 process. See
+    https://github.com/Lightning-AI/pytorch-lightning/issues/21264
+    for more details.
+
+    """
+    pbar = TQDMProgressBar()
+
+    # We need a LightningModule that logs a metric with on_epoch=True and sync_dist=True
+    class MyModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)["loss"]
+            self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
+            return {"loss": loss}
+
+    model = MyModel()
+
+    # We mock these logger connector hooks since they also sync metrics across ranks and
+    # would otherwise mask an incorrect implementation of TQDMProgressBar.on_train_epoch_end.
+    def mock_on_epoch_end(self):
+        pass
+
+    def mock_update_train_epoch_metrics(self):
+        pass
+
+    with (
+        patch("lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.on_epoch_end", mock_on_epoch_end),
+        patch(
+            "lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.update_train_epoch_metrics",
+            mock_update_train_epoch_metrics,
+        ),
+    ):
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            num_sanity_val_steps=0,
+            max_epochs=1,
+            val_check_interval=1,
+            accelerator="cpu",
+            devices=2,
+            strategy=DDPStrategy(
+                process_group_backend="gloo",  # run on CPU
+                timeout=datetime.timedelta(seconds=5),  # time out quickly so a deadlock fails the test
+            ),
+            callbacks=[pbar],
+            enable_progress_bar=True,
+            enable_model_summary=False,
+            enable_checkpointing=False,
+        )
+        trainer.fit(model)
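
Supplementary sketch (not part of the patch): the deadlock fixed above follows from the rule that a torch.distributed collective must be entered by every rank. Because get_metrics can trigger a cross-rank reduction for metrics logged with sync_dist=True, gating it behind the rank-0-only is_enabled/disable check left rank 0 alone inside the collective. The standalone Python sketch below illustrates the corrected ordering outside Lightning; the helper names sync_and_report and _worker and the port choice are made up for illustration.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def sync_and_report(rank: int, world_size: int) -> None:
    # Mirrors the patched progress bars: every rank performs the metric sync...
    loss = torch.tensor([float(rank)])
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)  # collective: all ranks must reach this line
    # ...and only the rank that owns the progress bar consumes the result.
    if rank == 0:
        print(f"epoch loss (mean over ranks): {loss.item() / world_size:.3f}")
    # The pre-fix ordering returned early on non-zero ranks *before* the sync,
    # which left rank 0 blocked inside the collective until it timed out.


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    sync_and_report(rank, world_size)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)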