Skip to content

Commit

Permalink
[BUG] Check environ before selecting a seed to prevent warning message (
Browse files Browse the repository at this point in the history
#4743)

* Check environment var independently to selecting a seed to prevent unnecessary warning message

* Add if statement to check if PL_GLOBAL_SEED has been set

* Added seed test to ensure that the seed stays the same, in case

* if

* Delete global seed after test has finished

* Fix code, add tests

* Ensure seed does not exist before tests start

* Refactor test based on review, add log call

* Ensure we clear the os environ in patched dict

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
(cherry picked from commit 635df27)
  • Loading branch information
SeanNaren authored and Borda committed Jan 25, 2021
1 parent fc58f66 commit e296f36
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
17 changes: 7 additions & 10 deletions pytorch_lightning/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn


def seed_everything(seed: Optional[int] = None) -> int:
Expand All @@ -41,18 +41,17 @@ def seed_everything(seed: Optional[int] = None) -> int:

try:
if seed is None:
seed = os.environ.get("PL_GLOBAL_SEED", _select_seed_randomly(min_seed_value, max_seed_value))
seed = os.environ.get("PL_GLOBAL_SEED")
seed = int(seed)
except (TypeError, ValueError):
seed = _select_seed_randomly(min_seed_value, max_seed_value)
rank_zero_warn(f"No correct seed found, seed set to {seed}")

if (seed > max_seed_value) or (seed < min_seed_value):
log.warning(
f"{seed} is not in bounds, \
numpy accepts from {min_seed_value} to {max_seed_value}"
)
if not (min_seed_value <= seed <= max_seed_value):
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = _select_seed_randomly(min_seed_value, max_seed_value)

log.info(f"Global seed set to {seed}")
os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
Expand All @@ -62,6 +61,4 @@ def seed_everything(seed: Optional[int] = None) -> int:


def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
seed = random.randint(min_seed_value, max_seed_value)
log.warning(f"No correct seed found, seed set to {seed}")
return seed
return random.randint(min_seed_value, max_seed_value)
55 changes: 55 additions & 0 deletions tests/utilities/test_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os

from unittest import mock
import pytest

import pytorch_lightning.utilities.seed as seed_utils


@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_stays_same_with_multiple_seed_everything_calls():
"""
Ensure that after the initial seed everything,
the seed stays the same for the same run.
"""
with pytest.warns(UserWarning, match="No correct seed found"):
seed_utils.seed_everything()
initial_seed = os.environ.get("PL_GLOBAL_SEED")

with pytest.warns(None) as record:
seed_utils.seed_everything()
assert not record # does not warn
seed = os.environ.get("PL_GLOBAL_SEED")

assert initial_seed == seed


@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
def test_correct_seed_with_environment_variable():
"""
Ensure that the PL_GLOBAL_SEED environment is read
"""
assert seed_utils.seed_everything() == 2020


@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123)
def test_invalid_seed():
"""
Ensure that we still fix the seed even if an invalid seed is given
"""
with pytest.warns(UserWarning, match="No correct seed found"):
seed = seed_utils.seed_everything()
assert seed == 123


@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123)
@pytest.mark.parametrize("seed", (10e9, -10e9))
def test_out_of_bounds_seed(seed):
"""
Ensure that we still fix the seed even if an out-of-bounds seed is given
"""
with pytest.warns(UserWarning, match="is not in bounds"):
actual = seed_utils.seed_everything(seed)
assert actual == 123

0 comments on commit e296f36

Please sign in to comment.