Merge pull request #45 from Media-Smart/fix_seed_bug
fix seed bug
hxcai authored Apr 7, 2021
2 parents 867df87 + dfc2884 commit 0c97e12
Showing 8 changed files with 81 additions and 21 deletions.
4 changes: 3 additions & 1 deletion vedacore/hooks/__init__.py
@@ -5,8 +5,10 @@
from .lr_scheduler import FixedLrSchedulerHook
from .optimizer import OptimizerHook
from .snapshot import SnapshotHook
from .sampler_seed import DistSamplerSeedHook
from .worker_init import WorkerInitHook

__all__ = [
'BaseHook', 'EvalHook', 'HookPool', 'LoggerHook', 'FixedLrSchedulerHook',
'OptimizerHook', 'SnapshotHook'
'OptimizerHook', 'SnapshotHook', 'DistSamplerSeedHook', 'WorkerInitHook'
]
24 changes: 24 additions & 0 deletions vedacore/hooks/sampler_seed.py
@@ -0,0 +1,24 @@
# Copyright (c) Open-MMLab. All rights reserved.
from vedacore.misc import registry
from .base_hook import BaseHook


@registry.register_module('hook')
class DistSamplerSeedHook(BaseHook):
"""Data-loading sampler for distributed training.
When distributed training, it is only useful in conjunction with
:obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
purpose with :obj:`IterLoader`.
"""

def before_train_epoch(self, looper):
if hasattr(looper.train_dataloader.sampler, 'set_epoch'):
# in case the data loader uses `SequentialSampler` in PyTorch
looper.train_dataloader.sampler.set_epoch(looper.epoch)
elif hasattr(looper.train_dataloader.batch_sampler.sampler, 'set_epoch'):
# the batch sampler in PyTorch wraps the sampler as one of its attributes.
looper.train_dataloader.batch_sampler.sampler.set_epoch(looper.epoch)

@property
def modes(self):
return ['train']
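For context, here is a minimal sketch, outside this diff and using PyTorch's built-in DistributedSampler rather than the vedadet subclass changed further down, of the mechanism the hook relies on: the sampler only reshuffles between epochs if set_epoch is called before each epoch, which is exactly what before_train_epoch does for the train dataloader.

```python
# Minimal sketch (plain PyTorch, not the vedadet sampler): without set_epoch
# the shuffle order repeats every epoch; with it, each epoch differs.
from torch.utils.data import DistributedSampler

dataset = list(range(8))                      # any object with __len__
sampler = DistributedSampler(dataset, num_replicas=2, rank=0,
                             shuffle=True, seed=0)

for epoch in range(2):
    sampler.set_epoch(epoch)                  # what the hook does once per epoch
    print(epoch, list(iter(sampler)))         # a different permutation each epoch
```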
17 changes: 17 additions & 0 deletions vedacore/hooks/worker_init.py
@@ -0,0 +1,17 @@
from vedacore.misc import registry
from .base_hook import BaseHook


@registry.register_module('hook')
class WorkerInitHook(BaseHook):
"""Worker init for training.
"""

def before_train_epoch(self, looper):
worker_init_fn = looper.train_dataloader.worker_init_fn
if worker_init_fn is not None and hasattr(worker_init_fn, 'set_epoch'):
worker_init_fn.set_epoch(looper.epoch)

@property
def modes(self):
return ['train']
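A rough sketch of the DataLoader contract this hook builds on; EpochAwareInit is an illustrative stand-in for the WorkerInit class added to vedadet/datasets/builder.py further down, and the names here are not part of the vedadet API.

```python
# Rough sketch: DataLoader calls worker_init_fn(worker_id) in every worker
# process, so a callable object with a set_epoch method lets the hook
# re-seed workers differently each epoch.
import random

import numpy as np
from torch.utils.data import DataLoader


class EpochAwareInit:                 # hypothetical stand-in for WorkerInit
    def __init__(self, seed):
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __call__(self, worker_id):
        worker_seed = self.seed + self.epoch + worker_id
        np.random.seed(worker_seed)
        random.seed(worker_seed)


init_fn = EpochAwareInit(seed=0)
loader = DataLoader(list(range(8)), num_workers=2, worker_init_fn=init_fn)

# what WorkerInitHook.before_train_epoch effectively does at epoch 1:
if hasattr(loader.worker_init_fn, 'set_epoch'):
    loader.worker_init_fn.set_epoch(1)
```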
6 changes: 1 addition & 5 deletions vedacore/misc/logging.py
@@ -27,6 +27,7 @@ def get_logger(name, log_file=None, log_level=logging.INFO):
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
logger.propagate = False
if name in logger_initialized:
return logger
# handle hierarchical names
@@ -36,11 +37,6 @@ def get_logger(name, log_file=None, log_level=logging.INFO):
if name.startswith(logger_name):
return logger

if logger.parent is not None:
logger.parent.handlers.clear()
else:
logger.handlers.clear()

stream_handler = logging.StreamHandler()
handlers = [stream_handler]

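For background on the change above, a small self-contained sketch, independent of vedacore, of what logger.propagate = False buys: records handled by the named logger no longer bubble up to ancestor loggers, so output is not duplicated when the root logger also has handlers, which makes the removed parent-handler clearing unnecessary.

```python
import logging

# Independent sketch: shows why propagate=False avoids duplicate log lines.
logging.basicConfig(level=logging.INFO)        # root logger gets a handler

logger = logging.getLogger('vedacore')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

logger.info('printed twice')                   # child handler + root handler
logger.propagate = False
logger.info('printed once')                    # root handlers are now skipped
```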
7 changes: 7 additions & 0 deletions vedadet/assembler/trainval.py
@@ -64,6 +64,13 @@ def trainval(cfg, distributed, logger):

looper = EpochBasedLooper(cfg.modes, dataloaders, engines, hook_pool,
logger, cfg.workdir)

if isinstance(looper, EpochBasedLooper):
looper.hook_pool.register_hook(dict(typename='WorkerInitHook'))
if distributed:
looper.hook_pool.register_hook(
dict(typename='DistSamplerSeedHook'))

if 'weights' in cfg:
looper.load_weights(**cfg.weights)
if 'train' in cfg.modes:
32 changes: 21 additions & 11 deletions vedadet/datasets/builder.py
@@ -93,20 +93,19 @@ def build_dataloader(dataset,
# that images on each GPU are in the same group
if shuffle:
sampler = DistributedGroupSampler(dataset, samples_per_gpu,
world_size, rank)
world_size, rank, seed=seed)
else:
sampler = DistributedSampler(
dataset, world_size, rank, shuffle=False)
dataset, world_size, rank, shuffle=False, seed=seed)
batch_size = samples_per_gpu
num_workers = workers_per_gpu
else:
sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
batch_size = num_gpus * samples_per_gpu
num_workers = num_gpus * workers_per_gpu

init_fn = partial(
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None
init_fn = WorkerInit(num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None

data_loader = DataLoader(
dataset,
@@ -121,9 +120,20 @@
return data_loader


def worker_init_fn(worker_id, num_workers, rank, seed):
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
class WorkerInit:
def __init__(self, num_workers, rank, seed):
# The seed of each worker equals
# num_workers * rank + worker_id + user_seed + epoch
self.num_workers = num_workers
self.rank = rank
self.seed = seed if seed is not None else 0
self.epoch = 0

def set_epoch(self, epoch):
self.epoch = epoch

def __call__(self, worker_id):
worker_seed = (self.num_workers * self.rank + worker_id + self.seed +
self.epoch)
np.random.seed(worker_seed)
random.seed(worker_seed)
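A small worked example of the formula above with hypothetical values (2 workers, rank 0, user seed 42), showing the effect of the fix: because the epoch is now part of the sum, workers are re-seeded with new values every epoch instead of replaying the same random state.

```python
# Worked example with hypothetical values, mirroring WorkerInit.__call__.
num_workers, rank, user_seed = 2, 0, 42

for epoch in range(2):
    for worker_id in range(num_workers):
        worker_seed = num_workers * rank + worker_id + user_seed + epoch
        print(f'epoch={epoch} worker={worker_id} seed={worker_seed}')

# epoch=0: seeds 42, 43
# epoch=1: seeds 43, 44
```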
6 changes: 4 additions & 2 deletions vedadet/datasets/samplers/distributed_sampler.py
@@ -6,15 +6,17 @@

class DistributedSampler(_DistributedSampler):

def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True,
seed=0):
super().__init__(dataset, num_replicas=num_replicas, rank=rank)
self.shuffle = shuffle
self.seed = seed if seed is not None else 0

def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
g = torch.Generator()
g.manual_seed(self.epoch)
g.manual_seed(self.epoch + self.seed)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = torch.arange(len(self.dataset)).tolist()
6 changes: 4 additions & 2 deletions vedadet/datasets/samplers/group_sampler.py
@@ -73,7 +73,8 @@ def __init__(self,
dataset,
samples_per_gpu=1,
num_replicas=None,
rank=None):
rank=None,
seed=0):
_rank, _num_replicas = get_dist_info()
if num_replicas is None:
num_replicas = _num_replicas
@@ -84,6 +85,7 @@ def __init__(self,
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.seed = seed if seed is not None else 0

assert hasattr(self.dataset, 'flag')
self.flag = self.dataset.flag
@@ -99,7 +101,7 @@ def __init__(self,
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
g.manual_seed(self.epoch + self.seed)

indices = []
for i, size in enumerate(self.group_sizes):
