add utils.set_determinism for reproducibility (#576)
This PR:
1 - adds a set_determinism function that sets the seed for Python, PyTorch, and
CUDA, and applies deterministic settings for cuDNN.
2 - if seed is None, no deterministic settings are applied. This matters because
turning off cuDNN benchmarking to ensure determinism can also hurt performance.
3 - note that in the None case we explicitly revert cuDNN to non-deterministic
algorithms and re-enable benchmarking/tuning, in case people are toggling between modes.

The lack of determinism negatively impacted work with AWS: while running fp8 for
our joint blog, we saw variations in loss curves that appeared to come from fp8
but were instead likely due to the missing determinism in titan.

Testing - I ran multiple small 7B runs while rotating between three seeds and saw
consistent ending loss values matching the respective seeds.

This PR does not set per-worker seeds for the dataloader since we do not shuffle
at the moment, but that could become a future source of randomness that will need
to be seeded if we add shuffling later (see the sketch below).
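
Should shuffling be added later, here is a minimal sketch of the standard PyTorch
worker-seeding pattern that could cover that case; the seed_worker helper, the toy
dataset, and the DataLoader arguments are illustrative only and not part of this PR:

import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id: int) -> None:
    # derive a per-worker seed from the base seed already set via torch.manual_seed
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


if __name__ == "__main__":
    # toy dataset; a seeded generator plus worker_init_fn keeps shuffling and
    # worker-side RNG reproducible across runs
    dataset = TensorDataset(torch.arange(16).float())
    generator = torch.Generator()
    generator.manual_seed(42)
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        worker_init_fn=seed_worker,
        generator=generator,
    )
    for batch in loader:
        print(batch)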
lessw2020 authored Oct 1, 2024
1 parent eef8bb2 commit 9ccc161
Showing 3 changed files with 34 additions and 2 deletions.
7 changes: 6 additions & 1 deletion torchtitan/config_manager.py
@@ -357,7 +357,12 @@ def __init__(self):
            default=50,
            help="Python garbage control scheduling interval, in steps",
        )

        self.parser.add_argument(
            "--training.seed",
            type=int,
            default=None,
            help="Implement reproducibility by setting a Python, PyTorch and CUDA seed",
        )
        # checkpointing configs
        self.parser.add_argument(
            "--checkpoint.enable_checkpoint",
20 changes: 19 additions & 1 deletion torchtitan/utils.py
@@ -9,7 +9,7 @@
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Union
from typing import Optional, Union

import torch
import torch.distributed._functional_collectives as funcol
@@ -36,6 +36,24 @@ def _warn_overwrite_env(env, val):
    os.environ[env] = val


def set_determinism(seed: Optional[int]) -> None:
    """
    Set Python, PyTorch, CUDA seeds and cudnn settings for reproducibility
    """
    if seed is not None:
        # CPU and GPU determinism
        torch.manual_seed(seed)
        # set deterministic cudnn algorithms
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # set Python seed
        os.environ["PYTHONHASHSEED"] = str(seed)
    else:
        # ensure we turn off deterministic cudnn algorithms
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


def set_pg_timeouts(timeout, world_mesh):
    """
    Sets the timeout for all PGs in the provided mesh, and the default (world) group.
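
As an aside (not part of the diff), a small usage sketch of the new helper,
assuming utils is importable as torchtitan.utils, illustrating the toggle behavior
described in the PR notes: a seed enables deterministic cuDNN algorithms, while
None reverts to non-deterministic algorithms with benchmarking/tuning:

import torch

from torchtitan import utils

# seeded: deterministic cuDNN algorithms, benchmarking disabled
utils.set_determinism(42)
assert torch.backends.cudnn.deterministic and not torch.backends.cudnn.benchmark

# seed is None: revert to non-deterministic algorithms and re-enable benchmarking
utils.set_determinism(None)
assert not torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark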
9 changes: 9 additions & 0 deletions train.py
@@ -56,6 +56,15 @@ def main(job_config: JobConfig):
    # take control of garbage collection to avoid stragglers
    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    # set determinism; use seed == None to skip deterministic training
    utils.set_determinism(job_config.training.seed)
    if job_config.training.seed is None:
        logger.info("Deterministic training off")
    else:
        logger.info(
            f"Deterministic training on. Using seed: {job_config.training.seed}"
        )

    # init distributed
    world_size = int(os.environ["WORLD_SIZE"])
    parallel_dims = ParallelDims(
