Skip to content

Commit

Permalink
[Feature] Add diff seeds to diff ranks and set torch seed in worker_i…
Browse files Browse the repository at this point in the history
…nit_fn (open-mmlab#113)

* add init_random_seed

* Set diff seed to diff workers
  • Loading branch information
pppppM committed Mar 27, 2022
1 parent 0cbe900 commit 2d42fbf
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 21 deletions.
4 changes: 3 additions & 1 deletion mmrazor/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
from .mmcls import * # noqa: F401,F403
from .mmdet import * # noqa: F401,F403
from .mmseg import * # noqa: F401,F403
from .utils import set_random_seed # noqa: F401
from .utils import init_random_seed, set_random_seed # noqa: F401

__all__ = ['init_random_seed', 'set_random_seed']
33 changes: 33 additions & 0 deletions mmrazor/apis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,39 @@

import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def init_random_seed(seed=None, device='cuda'):
    """Pick the random seed that every process should use.

    A caller-supplied ``seed`` is returned untouched. Otherwise a seed is
    drawn at random; when more than one process is running, the value held
    by rank 0 is broadcast so that all ranks end up with the same number.

    Args:
        seed (int, Optional): Seed chosen by the caller. Default to None.
        device (str): Device on which the tensor used for the broadcast
            is created. Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        rank, world_size = get_dist_info()
        seed = np.random.randint(2**31)
        if world_size != 1:
            # All ranks have to agree on one seed, otherwise subtle bugs
            # can appear. Please refer to
            # https://github.com/open-mmlab/mmdetection/issues/6339
            # Only the value on rank 0 matters; the other ranks supply a
            # placeholder that the broadcast overwrites.
            payload = seed if rank == 0 else 0
            seed_tensor = torch.tensor(
                payload, dtype=torch.int32, device=device)
            dist.broadcast(seed_tensor, src=0)
            seed = seed_tensor.item()
    return seed


def set_random_seed(seed: int, deterministic: bool = False) -> None:
Expand Down
21 changes: 14 additions & 7 deletions tools/mmcls/train_mmcls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

import mmcv
import torch
import torch.distributed as dist
from mmcls import __version__
from mmcls.datasets import build_dataset
from mmcls.utils import collect_env, get_root_logger
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist

# Differences from mmclassification
from mmrazor.apis import set_random_seed, train_mmcls_model
from mmrazor.apis import init_random_seed, set_random_seed, train_mmcls_model
from mmrazor.models import build_algorithm
from mmrazor.utils import setup_multi_processes

Expand Down Expand Up @@ -54,6 +55,10 @@ def parse_args():
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--diff_seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
Expand Down Expand Up @@ -154,12 +159,14 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
# Resolve the seed: use --seed when given, otherwise let
# init_random_seed() choose one.
seed = init_random_seed(args.seed)
if args.diff_seed:
    # Offset the seed by this process's rank so each rank differs.
    # get_dist_info() is used instead of dist.get_rank() because the
    # latter raises when no process group has been initialised
    # (non-distributed training).
    seed = seed + get_dist_info()[0]
logger.info(f'Set random seed to {seed}, '
            f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
# Record the seed actually used in the config and in the run metadata.
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)

# Difference from mmclassification
# replace `model` to `algorithm`
Expand Down
20 changes: 13 additions & 7 deletions tools/mmdet/train_mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

import mmcv
import torch
import torch.distributed as dist
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
from mmdet import __version__
from mmdet.datasets import build_dataset
from mmdet.utils import collect_env, get_root_logger

from mmrazor.apis import set_random_seed, train_mmdet_model
from mmrazor.apis import init_random_seed, set_random_seed, train_mmdet_model
from mmrazor.models import build_algorithm
from mmrazor.utils import setup_multi_processes

Expand Down Expand Up @@ -61,6 +62,10 @@ def parse_args():
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--diff_seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
Expand Down Expand Up @@ -166,12 +171,13 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
# Resolve the seed: use --seed when given, otherwise let
# init_random_seed() choose one.
seed = init_random_seed(args.seed)
if args.diff_seed:
    # Offset the seed by this process's rank so each rank differs.
    # get_dist_info() is used instead of dist.get_rank() because the
    # latter raises when no process group has been initialised
    # (non-distributed training).
    seed = seed + get_dist_info()[0]
logger.info(f'Set random seed to {seed}, '
            f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
# Record the seed actually used in the config and in the run metadata.
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)

algorithm = build_algorithm(cfg.algorithm)
Expand Down
18 changes: 12 additions & 6 deletions tools/mmseg/train_mmseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import mmcv
import torch
import torch.distributed as dist
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash
Expand All @@ -24,7 +25,7 @@
from mmseg.utils import collect_env, get_root_logger

# Differences from mmdetection
from mmrazor.apis import set_random_seed, train_mmseg_model
from mmrazor.apis import init_random_seed, set_random_seed, train_mmseg_model
from mmrazor.models.builder import build_algorithm
from mmrazor.utils import setup_multi_processes

Expand Down Expand Up @@ -64,6 +65,10 @@ def parse_args():
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--diff_seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
Expand Down Expand Up @@ -168,11 +173,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')

# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, deterministic: '
f'{args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
# Resolve the seed: use --seed when given, otherwise let
# init_random_seed() choose one.
seed = init_random_seed(args.seed)
if args.diff_seed:
    # Offset the seed by this process's rank so each rank differs.
    # get_dist_info() is used instead of dist.get_rank() because the
    # latter raises when no process group has been initialised
    # (non-distributed training).
    seed = seed + get_dist_info()[0]
# Use the resolved `seed` from here on, not `args.seed`: the latter is
# None when --seed is omitted and never includes the per-rank offset.
logger.info(f'Set random seed to {seed}, deterministic: '
            f'{args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
# Record the seed actually used in the config and in the run metadata.
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)

Expand Down

0 comments on commit 2d42fbf

Please sign in to comment.