[Feature] Add diff seeds to diff ranks. (open-mmlab#7432)
hhaAndroid authored and SakiRinn committed Mar 17, 2023
1 parent 7ed58db commit 7790bc5
Showing 6 changed files with 60 additions and 6 deletions.
4 changes: 2 additions & 2 deletions mmdet/core/utils/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads,
-                         reduce_mean)
+                         reduce_mean, sync_random_seed)
 from .misc import (center_of_mass, filter_scores_and_topk, flip_tensor,
                    generate_coordinate, mask2ndarray, multi_apply,
                    select_single_mlvl, unmap)
@@ -9,5 +9,5 @@
     'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
     'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict',
     'center_of_mass', 'generate_coordinate', 'select_single_mlvl',
-    'filter_scores_and_topk'
+    'filter_scores_and_topk', 'sync_random_seed'
 ]
34 changes: 34 additions & 0 deletions mmdet/core/utils/dist_utils.py
@@ -4,6 +4,7 @@
 import warnings
 from collections import OrderedDict
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from mmcv.runner import OptimizerHook, get_dist_info
@@ -151,3 +152,36 @@ def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
     if isinstance(py_dict, OrderedDict):
         out_dict = OrderedDict(out_dict)
     return out_dict
+
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed.
+
+    All workers must call this function, otherwise it will deadlock.
+    This method is generally used in `DistributedSampler`,
+    because the seed should be identical across all processes
+    in the distributed group.
+
+    Args:
+        seed (int, Optional): The seed. Default to None.
+        device (str): The device where the seed will be put on.
+            Default to 'cuda'.
+
+    Returns:
+        int: Seed to be used.
+    """
+    if seed is None:
+        seed = np.random.randint(2**31)
+    assert isinstance(seed, int)
+
+    rank, world_size = get_dist_info()
+
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
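
A minimal usage sketch (not part of the diff): every rank must call sync_random_seed, otherwise the dist.broadcast call blocks forever; rank 0 draws the seed and all other ranks receive the same value.

# Hedged sketch of how sync_random_seed might be consumed; assumes the
# distributed process group has already been initialized (e.g. via mmcv's
# init_dist) and that CUDA is available for the broadcast tensor.
import torch

from mmdet.core.utils import sync_random_seed

seed = sync_random_seed()   # rank 0 samples a seed, every rank receives it
g = torch.Generator()
g.manual_seed(seed)         # identical shuffling order on every rank
perm = torch.randperm(10, generator=g)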
2 changes: 2 additions & 0 deletions mmdet/datasets/builder.py
@@ -6,6 +6,7 @@
 from functools import partial
 
 import numpy as np
+import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
 from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
@@ -197,3 +198,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
     worker_seed = num_workers * rank + worker_id + seed
     np.random.seed(worker_seed)
     random.seed(worker_seed)
+    torch.manual_seed(worker_seed)
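
For context, a hedged sketch (not the exact build_dataloader code) of how a worker_init_fn like this is typically attached to a DataLoader through functools.partial, so each worker process gets a distinct but reproducible seed derived from (num_workers, rank, seed); the dataset and argument values below are assumed for illustration.

from functools import partial
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def worker_init_fn(worker_id, num_workers, rank, seed):
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)   # the seeding call this commit adds


dataset = TensorDataset(torch.arange(8))
init_fn = partial(worker_init_fn, num_workers=2, rank=0, seed=42)
# Worker processes run init_fn(worker_id) once iteration starts.
loader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=init_fn)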
8 changes: 6 additions & 2 deletions mmdet/datasets/samplers/distributed_sampler.py
@@ -4,6 +4,8 @@
 import torch
 from torch.utils.data import DistributedSampler as _DistributedSampler
 
+from mmdet.core.utils import sync_random_seed
+
 
 class DistributedSampler(_DistributedSampler):
 
@@ -15,8 +17,10 @@ def __init__(self,
                  seed=0):
         super().__init__(
             dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
-        # for the compatibility from PyTorch 1.3+
-        self.seed = seed if seed is not None else 0
+        # Must be the same across all workers. If None, will use a
+        # random seed shared among workers
+        # (require synchronization among all workers)
+        self.seed = sync_random_seed(seed)
 
     def __iter__(self):
         # deterministically shuffle based on epoch
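
Why the seed must match across ranks: in __iter__ every rank builds the same permutation from seed + epoch and then keeps only its own slice, so differing seeds would give overlapping or missing samples. A hedged sketch of that pattern (not the exact mmdet __iter__, which also pads indices up to total_size):

import torch


def _shuffled_indices_for_rank(num_samples, seed, epoch, rank, world_size):
    g = torch.Generator()
    g.manual_seed(seed + epoch)                      # same stream on every rank
    indices = torch.randperm(num_samples, generator=g).tolist()
    return indices[rank::world_size]                 # each rank keeps its slice


print(_shuffled_indices_for_rank(10, seed=42, epoch=0, rank=0, world_size=2))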
12 changes: 10 additions & 2 deletions mmdet/datasets/samplers/infinite_sampler.py
@@ -6,6 +6,8 @@
 from mmcv.runner import get_dist_info
 from torch.utils.data.sampler import Sampler
 
+from mmdet.core.utils import sync_random_seed
+
 
 class InfiniteGroupBatchSampler(Sampler):
     """Similar to `BatchSampler` warping a `GroupSampler. It is designed for
@@ -48,7 +50,10 @@ def __init__(self,
         self.world_size = world_size
         self.dataset = dataset
         self.batch_size = batch_size
-        self.seed = seed if seed is not None else 0
+        # Must be the same across all workers. If None, will use a
+        # random seed shared among workers
+        # (require synchronization among all workers)
+        self.seed = sync_random_seed(seed)
         self.shuffle = shuffle
 
         assert hasattr(self.dataset, 'flag')
@@ -133,7 +138,10 @@ def __init__(self,
         self.world_size = world_size
         self.dataset = dataset
         self.batch_size = batch_size
-        self.seed = seed if seed is not None else 0
+        # Must be the same across all workers. If None, will use a
+        # random seed shared among workers
+        # (require synchronization among all workers)
+        self.seed = sync_random_seed(seed)
         self.shuffle = shuffle
         self.size = len(dataset)
         self.indices = self._indices_of_rank()
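
The infinite samplers rely on the same property: with a shared seed, every rank regenerates the same infinite index stream and slices it by rank. A hedged sketch of that idea (details such as grouping by aspect-ratio flag and batching are omitted, and names here are assumptions, not the real methods):

import itertools

import torch


def _infinite_indices(size, seed, shuffle=True):
    g = torch.Generator()
    g.manual_seed(seed)                 # identical stream on every rank
    while True:
        if shuffle:
            yield from torch.randperm(size, generator=g).tolist()
        else:
            yield from torch.arange(size).tolist()


def _indices_of_rank(size, seed, rank, world_size):
    # Each rank takes every world_size-th index, starting at its own rank.
    yield from itertools.islice(
        _infinite_indices(size, seed), rank, None, world_size)


print(list(itertools.islice(_indices_of_rank(6, seed=0, rank=1, world_size=2), 6)))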
6 changes: 6 additions & 0 deletions tools/train.py
@@ -8,6 +8,7 @@
 
 import mmcv
 import torch
+import torch.distributed as dist
 from mmcv import Config, DictAction
 from mmcv.runner import get_dist_info, init_dist
 from mmcv.utils import get_git_hash
@@ -52,6 +53,10 @@ def parse_args():
         help='id of gpu to use '
         '(only applicable to non-distributed training)')
     parser.add_argument('--seed', type=int, default=None, help='random seed')
+    parser.add_argument(
+        '--diff-seed',
+        action='store_true',
+        help='Whether or not set different seeds for different ranks')
     parser.add_argument(
         '--deterministic',
         action='store_true',
@@ -169,6 +174,7 @@ def main():
 
     # set random seeds
     seed = init_random_seed(args.seed)
+    seed = seed + dist.get_rank() if args.diff_seed else seed
     logger.info(f'Set random seed to {seed}, '
                 f'deterministic: {args.deterministic}')
     set_random_seed(seed, deterministic=args.deterministic)
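
A hedged illustration of what the new flag changes (values assumed): with --diff-seed, each rank offsets the broadcast base seed by its rank index, so per-rank RNG such as data augmentation differs across GPUs, while the samplers above still shuffle with the shared base seed.

# Hedged sketch assuming a 2-GPU run where init_random_seed() returned 42;
# in a real run dist.get_rank() requires an initialized process group.
base_seed = 42
diff_seed = True                  # i.e. the new --diff-seed flag was passed

for rank in range(2):             # stand-in for dist.get_rank() on each process
    seed = base_seed + rank if diff_seed else base_seed
    print(f'rank {rank}: seed {seed}')   # rank 0 -> 42, rank 1 -> 43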
