Skip to content

Commit

Permalink
[Feature] Use --diff_seed to set different torch seed on different …
Browse files Browse the repository at this point in the history
…rank (#781)

* [Feature]: Add diff seeds to diff ranks and set torch seed in worker_init_fn

* [Feature]: Add diff seeds to diff ranks and set torch seed in worker_init_fn

* Update train.py
  • Loading branch information
Yshuo-Li authored Mar 23, 2022
1 parent 7ced34a commit 4ff208e
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 10 deletions.
5 changes: 3 additions & 2 deletions mmedit/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from .restoration_inference import restoration_inference
from .restoration_video_inference import restoration_video_inference
from .test import multi_gpu_test, single_gpu_test
from .train import set_random_seed, train_model
from .train import init_random_seed, set_random_seed, train_model
from .video_interpolation_inference import video_interpolation_inference

__all__ = [
'train_model', 'set_random_seed', 'init_model', 'matting_inference',
'inpainting_inference', 'restoration_inference', 'generation_inference',
'multi_gpu_test', 'single_gpu_test', 'restoration_video_inference',
'restoration_face_inference', 'video_interpolation_inference'
'restoration_face_inference', 'video_interpolation_inference',
'init_random_seed'
]
33 changes: 32 additions & 1 deletion mmedit/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel
from mmcv.runner import HOOKS, IterBasedRunner
from mmcv.runner import HOOKS, IterBasedRunner, get_dist_info
from mmcv.utils import build_from_cfg

from mmedit.core import DistEvalIterHook, EvalIterHook, build_optimizers
Expand All @@ -17,6 +18,36 @@
from mmedit.utils import get_root_logger


def init_random_seed(seed=None, device='cuda'):
    """Choose the random seed to use for this run.

    An explicitly given seed is returned untouched. Otherwise a seed is
    drawn at random; when more than one process is running, the value held
    by rank 0 is broadcast so that every process ends up with the same
    seed.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device on which the tensor used for the
            broadcast is created. Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        rank, world_size = get_dist_info()
        # Every rank draws a candidate, but only rank 0's value is kept.
        candidate = np.random.randint(2**31)
        if world_size == 1:
            return candidate

        # All ranks must share one seed to prevent some potential bugs.
        # Please refer to
        # https://github.com/open-mmlab/mmdetection/issues/6339
        payload = candidate if rank == 0 else 0
        seed_tensor = torch.tensor(payload, dtype=torch.int32, device=device)
        dist.broadcast(seed_tensor, src=0)
        seed = seed_tensor.item()
    return seed


def set_random_seed(seed, deterministic=False):
"""Set random seed.
Expand Down
1 change: 1 addition & 0 deletions mmedit/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
18 changes: 18 additions & 0 deletions tests/test_runtime/test_apis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmedit.apis.train import init_random_seed, set_random_seed


def test_init_random_seed():
    """Check the seed returned by ``init_random_seed`` on each device.

    An explicit seed must come back unchanged, and an automatically
    generated seed must lie in ``[0, 2**31)``, the range it is drawn from.
    The CUDA variant only runs when a GPU is available.
    """
    devices = ['cpu']
    # test on gpu
    if torch.cuda.is_available():
        devices.append('cuda')

    for device in devices:
        # A user-supplied seed is passed through as is.
        assert init_random_seed(0, device=device) == 0
        # Without a seed, a fresh one is generated within the int32 range.
        generated = init_random_seed(device=device)
        assert 0 <= generated < 2**31


def test_set_random_seed():
    """Smoke-test ``set_random_seed`` with ``deterministic`` off, then on."""
    for deterministic in (False, True):
        set_random_seed(0, deterministic=deterministic)
20 changes: 13 additions & 7 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

import mmcv
import torch
import torch.distributed as dist
from mmcv import Config
from mmcv.runner import init_dist

from mmedit import __version__
from mmedit.apis import set_random_seed, train_model
from mmedit.apis import init_random_seed, set_random_seed, train_model
from mmedit.datasets import build_dataset
from mmedit.models import build_model
from mmedit.utils import collect_env, get_root_logger, setup_multi_processes
Expand All @@ -34,6 +35,10 @@ def parse_args():
help='number of gpus to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--diff_seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
Expand Down Expand Up @@ -104,11 +109,12 @@ def main():
logger.info('Config:\n{}'.format(cfg.text))

# set random seeds
if args.seed is not None:
logger.info('Set random seed to {}, deterministic: {}'.format(
args.seed, args.deterministic))
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
seed = init_random_seed(args.seed)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info('Set random seed to {}, deterministic: {}'.format(
seed, args.deterministic))
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed

model = build_model(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
Expand All @@ -132,7 +138,7 @@ def main():
cfg['exp_name'] = osp.splitext(osp.basename(cfg.work_dir))[0]
meta['exp_name'] = cfg.exp_name
meta['mmedit Version'] = __version__
meta['seed'] = args.seed
meta['seed'] = seed
meta['env_info'] = env_info

# add an attribute for visualization convenience
Expand Down

0 comments on commit 4ff208e

Please sign in to comment.