Skip to content
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))


- Avoid raising the sampler warning if num_replicas=1 ([#14097](https://github.com/Lightning-AI/lightning/pull/14097))


- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))


Expand Down
10 changes: 7 additions & 3 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,14 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional

# update docs too once this is resolved
trainer_fn = self.trainer.state.fn
if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
if (
isinstance(sampler, DistributedSampler)
and sampler.num_replicas > 1
and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING)
):
rank_zero_warn(
f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
" it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`, it is"
" recommended to use `Trainer(devices=1, num_nodes=1)` to ensure each sample/batch gets evaluated"
" exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
" some samples to make sure all devices have same batch size in case of uneven inputs.",
category=PossibleUserWarning,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,19 +526,20 @@ def test_invalid_hook_passed_in_datahook_selector():
dh_selector.get_instance("setup")


def test_eval_distributed_sampler_warning(tmpdir):
@pytest.mark.parametrize("devices, warn_context", [(1, no_warning_call), (2, pytest.warns)])
def test_eval_distributed_sampler_warning(devices, warn_context):
"""Test that a warning is raised when `DistributedSampler` is used with evaluation."""

model = BoringModel()
trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
trainer = Trainer(strategy="ddp", devices=devices, accelerator="cpu")
trainer._data_connector.attach_data(model)

trainer.state.fn = TrainerFn.VALIDATING
with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
trainer.reset_val_dataloader(model)

trainer.state.fn = TrainerFn.TESTING
with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
trainer.reset_test_dataloader(model)


Expand Down