Make numpy an optional dependency in utilities/seed.py (#20055)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
01AbhiSingh and awaelchli authored Jul 12, 2024
1 parent 9987d99 commit d5ae9ec
Showing 5 changed files with 67 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- Changed the implementation of how seeds are chosen for dataloader workers when using `seed_everything(..., workers=True)` ([#20055](https://github.com/Lightning-AI/pytorch-lightning/pull/20055))

-

5 changes: 4 additions & 1 deletion src/lightning/fabric/utilities/imports.py
@@ -17,7 +17,10 @@
import platform
import sys

from lightning_utilities.core.imports import compare_version
from lightning_utilities.core.imports import RequirementCache, compare_version

_NUMPY_AVAILABLE = RequirementCache("numpy")


_IS_WINDOWS = platform.system() == "Windows"

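`RequirementCache` checks whether a requirement string is satisfied in the current environment and memoizes the answer, so repeated truthiness checks cost nothing. A minimal sketch of how such a flag is consumed (the `print` fallback is illustrative, not from this commit):

```python
from lightning_utilities.core.imports import RequirementCache

_NUMPY_AVAILABLE = RequirementCache("numpy")

if _NUMPY_AVAILABLE:
    # Safe to import: the requirement check already succeeded.
    import numpy as np
else:
    # str() of an unmet RequirementCache explains what is missing.
    print(str(_NUMPY_AVAILABLE))
```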
44 changes: 29 additions & 15 deletions src/lightning/fabric/utilities/seed.py
@@ -3,17 +3,19 @@
import random
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import numpy as np
import torch

from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE
from lightning.fabric.utilities.rank_zero import _get_rank, rank_prefixed_message, rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)

max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min

max_seed_value = 4294967295 # 2^32 - 1 (uint32)
min_seed_value = 0
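The hard-coded bounds are drop-in replacements for the previous `np.iinfo(np.uint32)` lookups; a quick equivalence check, assuming numpy happens to be installed:

```python
import numpy as np

# The new literals match the old numpy-derived bounds exactly.
assert np.iinfo(np.uint32).max == 4294967295 == 2**32 - 1
assert np.iinfo(np.uint32).min == 0
```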


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
@@ -54,7 +56,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
    log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    if _NUMPY_AVAILABLE:
        np.random.seed(seed)
    torch.manual_seed(seed)

    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
@@ -91,24 +94,34 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
    log.debug(
        f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
    )
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    # use 128 bits (4 x 32-bit words)
    np.random.seed(ss.generate_state(4))
    # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
    torch_ss, stdlib_ss = ss.spawn(2)
    torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
    # use 128 bits expressed as an integer
    stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
    random.seed(stdlib_seed)
    seed_sequence = _generate_seed_sequence(base_seed, worker_id, global_rank, count=4)
    torch.manual_seed(seed_sequence[0])  # torch takes a 64-bit seed
    random.seed((seed_sequence[1] << 32) | seed_sequence[2])  # combine two 64-bit seeds
    if _NUMPY_AVAILABLE:
        np.random.seed(seed_sequence[3] & 0xFFFFFFFF)  # numpy takes 32-bit seed only
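`pl_worker_init_function` is the hook that gets attached to dataloader workers when `seed_everything(..., workers=True)` is in effect; wiring it up by hand looks roughly like this (a sketch; the `TensorDataset` and the rank value are illustrative):

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything

seed_everything(42, workers=True)
dataset = TensorDataset(torch.arange(100))

# Each worker process re-seeds torch, stdlib random, and (if installed)
# numpy from a seed derived from (base seed, worker id, rank).
loader = DataLoader(
    dataset,
    num_workers=2,
    worker_init_fn=partial(pl_worker_init_function, rank=0),
)
```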


def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]:
    """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG)
    algorithm."""
    # Combine base seed, worker id and rank into a unique 64-bit number
    combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank
    seeds = []
    for _ in range(count):
        # x_(n+1) = (a * x_n + c) mod m. With c=1, m=2^64 and a is D. Knuth's constant
        combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1)
        seeds.append(combined_seed)
    return seeds
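Because the helper is a pure function of its inputs, the derived seeds are reproducible, and changing the worker id or rank changes the entire sequence. A quick demonstration against the private helper (subject to change without notice):

```python
from lightning.fabric.utilities.seed import _generate_seed_sequence

a = _generate_seed_sequence(base_seed=42, worker_id=0, global_rank=0, count=4)
b = _generate_seed_sequence(base_seed=42, worker_id=1, global_rank=0, count=4)

assert a == _generate_seed_sequence(42, 0, 0, 4)  # deterministic
assert a != b  # neighboring workers diverge immediately
assert all(0 <= s < 2**64 for s in a)  # every seed fits in 64 bits
```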


def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
    r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
    states = {
        "torch": torch.get_rng_state(),
        "numpy": np.random.get_state(),
        "python": python_get_rng_state(),
    }
    if _NUMPY_AVAILABLE:
        states["numpy"] = np.random.get_state()
    if include_cuda:
        states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
    return states
@@ -121,6 +134,7 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    # torch.cuda rng_state is only included since v1.8.
    if "torch.cuda" in rng_state_dict:
        torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
    np.random.set_state(rng_state_dict["numpy"])
    if _NUMPY_AVAILABLE and "numpy" in rng_state_dict:
        np.random.set_state(rng_state_dict["numpy"])
    version, state, gauss = rng_state_dict["python"]
    python_set_rng_state((version, tuple(state), gauss))
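`_collect_rng_states` and `_set_rng_states` form a snapshot/restore pair; a small sketch of that contract (again using private helpers, shown here with CUDA excluded):

```python
import random

import torch

from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states

snapshot = _collect_rng_states(include_cuda=False)
before = (random.random(), torch.rand(1).item())

random.random(), torch.rand(1)  # advance the generators past the snapshot

_set_rng_states(snapshot)  # rewind to the snapshot...
after = (random.random(), torch.rand(1).item())
assert before == after  # ...and the same draws replay exactly
```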
3 changes: 2 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -17,7 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Triggering KeyboardInterrupt (Ctrl+C) during `.fit()`, `.evaluate()`, `.test()` or `.predict()` now terminates all processes launched by the Trainer and exits the program ([#19976](https://github.com/Lightning-AI/pytorch-lightning/pull/19976))

-
- Changed the implementation of how seeds are chosen for dataloader workers when using `seed_everything(..., workers=True)` ([#20055](https://github.com/Lightning-AI/pytorch-lightning/pull/20055))


### Deprecated

32 changes: 31 additions & 1 deletion tests/tests_fabric/utilities/test_seed.py
@@ -1,11 +1,18 @@
import os
import random
from unittest import mock
from unittest.mock import Mock

import lightning.fabric.utilities
import numpy
import pytest
import torch
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
from lightning.fabric.utilities.seed import (
    _collect_rng_states,
    _set_rng_states,
    pl_worker_init_function,
    seed_everything,
)


@mock.patch.dict(os.environ, clear=True)
@@ -95,3 +102,26 @@ def test_collect_rng_states_if_cuda_init_fails(get_rng_state_all_mock):
    get_rng_state_all_mock.side_effect = RuntimeError("The NVIDIA driver on your system is too old")
    states = _collect_rng_states()
    assert states["torch.cuda"] == []


@pytest.mark.parametrize(("num_workers", "num_ranks"), [(64, 64)])
@pytest.mark.parametrize("base_seed", [100, 1024, 2**32 - 1])
def test_pl_worker_init_function(base_seed, num_workers, num_ranks):
    """Test that Lightning's `worker_init_fn` sets unique seeds per worker/rank derived from the base seed."""
    torch_rands = set()
    stdlib_rands = set()
    numpy_rands = set()

    for worker_id in range(num_workers):
        for rank in range(num_ranks):
            seed_everything(base_seed)
            pl_worker_init_function(worker_id, rank)
            torch_rands.add(tuple(torch.randint(0, 1_000_000, (100,)).tolist()))
            stdlib_rands.add(tuple(random.randint(0, 1_000_000) for _ in range(100)))
            numpy_rands.add(tuple(numpy.random.randint(0, 1_000_000, (100,)).tolist()))

    # Assert there are no duplicates (no collisions)
    assert len(torch_rands) == num_ranks * num_workers
    assert len(stdlib_rands) == num_ranks * num_workers
    assert len(numpy_rands) == num_ranks * num_workers
    assert len(torch_rands | stdlib_rands | numpy_rands) == 3 * num_workers * num_ranks
