Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid printing the seed info message multiple times #20108

Merged
merged 4 commits into from
Jul 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -11,7 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Made saving non-distributed checkpoints fully atomic ([#20011](https://github.com/Lightning-AI/pytorch-lightning/pull/20011))

-
- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))


### Changed

4 changes: 2 additions & 2 deletions src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
@@ -909,7 +909,7 @@ def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> No
logger.log_metrics(metrics=metrics, step=step)

@staticmethod
def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None, verbose: bool = True) -> int:
r"""Helper function to seed everything without explicitly importing Lightning.

See :func:`~lightning.fabric.utilities.seed.seed_everything` for more details.
@@ -919,7 +919,7 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
# Lightning sets `workers=False` by default to avoid breaking reproducibility, but since this is a new
# release, we can afford to set `workers=True` here.
workers = True
return seed_everything(seed=seed, workers=workers)
return seed_everything(seed=seed, workers=workers, verbose=verbose)

def _wrap_and_launch(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any:
self._launched = True
9 changes: 6 additions & 3 deletions src/lightning/fabric/utilities/seed.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
min_seed_value = 0


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: bool = True) -> int:
r"""Function that sets the seed for pseudo-random number generators in: torch, numpy, and Python's random module.
In addition, sets the following environment variables:

@@ -32,6 +32,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
:func:`~lightning.fabric.utilities.seed.pl_worker_init_function`.
verbose: Whether to print a message on each rank with the seed being set.

"""
if seed is None:
@@ -52,7 +53,9 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = 0

log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
if verbose:
log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))

os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
if _NUMPY_AVAILABLE:
@@ -76,7 +79,7 @@ def reset_seed() -> None:
if seed is None:
return
workers = os.environ.get("PL_SEED_WORKERS", "0")
seed_everything(int(seed), workers=bool(int(workers)))
seed_everything(int(seed), workers=bool(int(workers)), verbose=False)


def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
5 changes: 3 additions & 2 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -13,7 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `dump_stats` flag to `AdvancedProfiler` ([#19703](https://github.com/Lightning-AI/pytorch-lightning/issues/19703))

-
- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))


### Changed

@@ -41,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Avoid LightningCLI saving hyperparameters with `class_path` and `init_args` since this would be a breaking change ([#20068](https://github.com/Lightning-AI/pytorch-lightning/pull/20068))

-
- Fixed an issue that would cause too many printouts of the seed info when using `seed_everything()` ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))



28 changes: 18 additions & 10 deletions tests/tests_fabric/utilities/test_seed.py
Original file line number Diff line number Diff line change
@@ -3,34 +3,34 @@
from unittest import mock
from unittest.mock import Mock

import lightning.fabric.utilities
import numpy
import pytest
import torch
from lightning.fabric.utilities.seed import (
_collect_rng_states,
_set_rng_states,
pl_worker_init_function,
reset_seed,
seed_everything,
)


@mock.patch.dict(os.environ, clear=True)
def test_default_seed():
"""Test that the default seed is 0 when no seed is provided and no environment variable is set."""
assert lightning.fabric.utilities.seed.seed_everything() == 0
assert seed_everything() == 0
assert os.environ["PL_GLOBAL_SEED"] == "0"


@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_stays_same_with_multiple_seed_everything_calls():
"""Ensure that after the initial seed everything, the seed stays the same for the same run."""
with pytest.warns(UserWarning, match="No seed found"):
lightning.fabric.utilities.seed.seed_everything()
seed_everything()
initial_seed = os.environ.get("PL_GLOBAL_SEED")

with pytest.warns(None) as record:
lightning.fabric.utilities.seed.seed_everything()
seed_everything()
assert not record # does not warn
seed = os.environ.get("PL_GLOBAL_SEED")

@@ -40,14 +40,14 @@ def test_seed_stays_same_with_multiple_seed_everything_calls():
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
def test_correct_seed_with_environment_variable():
"""Ensure that the PL_GLOBAL_SEED environment is read."""
assert lightning.fabric.utilities.seed.seed_everything() == 2020
assert seed_everything() == 2020


@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
def test_invalid_seed():
"""Ensure that we still fix the seed even if an invalid seed is given."""
with pytest.warns(UserWarning, match="Invalid seed found"):
seed = lightning.fabric.utilities.seed.seed_everything()
seed = seed_everything()
assert seed == 0


@@ -56,15 +56,15 @@ def test_invalid_seed():
def test_out_of_bounds_seed(seed):
"""Ensure that we still fix the seed even if an out-of-bounds seed is given."""
with pytest.warns(UserWarning, match="is not in bounds"):
actual = lightning.fabric.utilities.seed.seed_everything(seed)
actual = seed_everything(seed)
assert actual == 0


def test_reset_seed_no_op():
"""Test that the reset_seed function is a no-op when seed_everything() was not used."""
assert "PL_GLOBAL_SEED" not in os.environ
seed_before = torch.initial_seed()
lightning.fabric.utilities.seed.reset_seed()
reset_seed()
assert torch.initial_seed() == seed_before
assert "PL_GLOBAL_SEED" not in os.environ

@@ -75,18 +75,26 @@ def test_reset_seed_everything(workers):
assert "PL_GLOBAL_SEED" not in os.environ
assert "PL_SEED_WORKERS" not in os.environ

lightning.fabric.utilities.seed.seed_everything(123, workers)
seed_everything(123, workers)
before = torch.rand(1)
assert os.environ["PL_GLOBAL_SEED"] == "123"
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))

lightning.fabric.utilities.seed.reset_seed()
reset_seed()
after = torch.rand(1)
assert os.environ["PL_GLOBAL_SEED"] == "123"
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
assert torch.allclose(before, after)


def test_reset_seed_non_verbose(caplog):
    """Test that ``reset_seed()`` re-applies the seed without logging the seed info message again.

    An explicit ``seed_everything(123)`` call (default ``verbose=True``) must log the
    "Seed set to ..." message once; the subsequent ``reset_seed()`` must stay silent.

    Args:
        caplog: pytest's log-capturing fixture.

    """
    seed_everything(123)
    # Check the content as well as the count, so an unrelated single record cannot satisfy the test.
    assert len(caplog.records) == 1
    assert "Seed set to 123" in caplog.records[0].getMessage()
    caplog.clear()
    reset_seed()  # should call `seed_everything(..., verbose=False)`
    assert len(caplog.records) == 0


def test_backward_compatibility_rng_states_dict():
"""Test that an older rng_states_dict without the "torch.cuda" key does not crash."""
states = _collect_rng_states()