5 changes: 2 additions & 3 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -646,15 +646,14 @@ def _update_metrics(
         current: Optional[int] = None,
         total_batches: bool = False,
     ) -> None:
-        if not self.is_enabled or self._metric_component is None:
-            return
-
         if current is not None and not total_batches:
             total = self.total_train_batches
             if not self._should_update(current, total):
                 return
 
         metrics = self.get_metrics(trainer, pl_module)
+        if not self.is_enabled or self._metric_component is None:
+            return
         if self._metric_component:
             self._metric_component.update(metrics)

3 changes: 2 additions & 1 deletion src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -282,8 +282,9 @@ def on_train_batch_end(

     @override
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        metrics = self.get_metrics(trainer, pl_module)
         if not self.train_progress_bar.disable:
-            self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+            self.train_progress_bar.set_postfix(metrics)
         if self._leave:
             self.train_progress_bar.close()

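Note: the ordering change in both hunks above matters because get_metrics can trigger a distributed all-reduce (for metrics logged with sync_dist=True); if ranks with a disabled progress bar return before that call, rank 0 blocks forever inside the collective. Below is a minimal standalone sketch of the pattern, not Lightning code: it assumes a 2-process gloo group on CPU, and the worker name and TCP port are placeholders.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _epoch_end(rank: int, world_size: int) -> None:
    dist.init_process_group(
        "gloo", rank=rank, world_size=world_size, init_method="tcp://127.0.0.1:29501"
    )
    loss = torch.tensor(float(rank))

    # Buggy ordering (what the fix removes): bail out on ranks whose bar is
    # disabled *before* the collective, so rank 0 waits forever in all_reduce:
    #   if rank != 0:
    #       return
    #   dist.all_reduce(loss)

    # Fixed ordering: every rank participates in the collective first ...
    dist.all_reduce(loss)
    if rank == 0:
        # ... and only then does rank 0 do the rank-local UI update.
        print("synced loss:", loss.item() / world_size)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_epoch_end, args=(2,), nprocs=2)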
tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import pickle
 from collections import defaultdict
 from unittest import mock
-from unittest.mock import DEFAULT, Mock
+from unittest.mock import DEFAULT, Mock, patch
 
 import pytest
 from tests_pytorch.helpers.runif import RunIf
@@ -26,6 +27,7 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.loggers.logger import DummyLogger
+from lightning.pytorch.strategies import DDPStrategy


@RunIf(rich=True)
@@ -605,3 +607,55 @@ def val_dataloader(self):

    # This should not raise an AssertionError
    trainer.fit(model)


@RunIf(rich=True)
def test_rich_progress_bar_ddp_deadlock(tmp_path):
    """Tests that RichProgressBar doesn't deadlock when using DDP on train epoch end.

    We used to have a bug where the progress-bar metrics were synced only on the rank 0 process, leaving the
    other ranks stuck in the collective. See https://github.com/Lightning-AI/pytorch-lightning/issues/21264
    for more details.

    """
    pbar = RichProgressBar()

    # We need a LightningModule that logs a metric with on_epoch=True, sync_dist=True
    class MyModel(BoringModel):
        def training_step(self, batch, batch_idx):
            loss = super().training_step(batch, batch_idx)["loss"]
            self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
            return {"loss": loss}

    model = MyModel()

    # We need to mock these logger connector hooks, since they also attempt to sync metrics
    # and could mask an otherwise incorrect implementation of RichProgressBar's epoch-end hooks.
    def mock_on_epoch_end(self):
        pass

    def mock_update_train_epoch_metrics(self):
        pass

    with (
        patch("lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.on_epoch_end", mock_on_epoch_end),
        patch(
            "lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.update_train_epoch_metrics",
            mock_update_train_epoch_metrics,
        ),
    ):
        trainer = Trainer(
            default_root_dir=tmp_path,
            num_sanity_val_steps=0,
            max_epochs=1,
            val_check_interval=1,
            accelerator="cpu",
            devices=2,
            strategy=DDPStrategy(
                process_group_backend="gloo",  # run on CPU
                timeout=datetime.timedelta(seconds=5),  # time out quickly so the test fails instead of hanging
            ),
            callbacks=[pbar],
            enable_progress_bar=True,
            enable_model_summary=False,
            enable_checkpointing=False,
        )
        trainer.fit(model)
tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import math
 import os
 import pickle
@@ -32,6 +33,7 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.loggers.logger import DummyLogger
+from lightning.pytorch.strategies import DDPStrategy
 from lightning.pytorch.utilities.exceptions import MisconfigurationException


@@ -859,3 +861,57 @@ def reset(self, total=None):
    assert 2 in val_bar.total_values, (
        f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
    )


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_tqdm_progress_bar_ddp_deadlock(tmp_path):
    """Tests that TQDMProgressBar doesn't deadlock when using DDP on train epoch end.

    We used to have a bug where the progress-bar metrics were synced only on the rank 0 process, leaving the
    other ranks stuck in the collective. See https://github.com/Lightning-AI/pytorch-lightning/issues/21264
    for more details.

    """
    pbar = TQDMProgressBar()

    # We need a LightningModule that logs a metric with on_epoch=True, sync_dist=True
    class MyModel(BoringModel):
        def training_step(self, batch, batch_idx):
            loss = super().training_step(batch, batch_idx)["loss"]
            self.log("loss", loss, on_step=False, on_epoch=True, sync_dist=True)
            return {"loss": loss}

    model = MyModel()

    # We need to mock these logger connector hooks, since they also attempt to sync metrics
    # and could mask an otherwise incorrect implementation of TQDMProgressBar.on_train_epoch_end.
    def mock_on_epoch_end(self):
        pass

    def mock_update_train_epoch_metrics(self):
        pass

    with (
        patch("lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.on_epoch_end", mock_on_epoch_end),
        patch(
            "lightning.pytorch.trainer.connectors.logger_connector._LoggerConnector.update_train_epoch_metrics",
            mock_update_train_epoch_metrics,
        ),
    ):
        trainer = Trainer(
            default_root_dir=tmp_path,
            num_sanity_val_steps=0,
            max_epochs=1,
            val_check_interval=1,
            accelerator="cpu",
            devices=2,
            strategy=DDPStrategy(
                process_group_backend="gloo",  # run on CPU
                timeout=datetime.timedelta(seconds=5),  # time out quickly so the test fails instead of hanging
            ),
            callbacks=[pbar],
            enable_progress_bar=True,
            enable_model_summary=False,
            enable_checkpointing=False,
        )
        trainer.fit(model)