Skip to content

Commit

Permalink
[Enhancement] Pass custom_hooks to mmcv (open-mmlab#609)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaotongxiao authored Nov 30, 2021
1 parent 22e8c32 commit 227a0cc
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions mmocr/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner, get_dist_info)
from mmcv.utils import build_from_cfg
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
Expand Down Expand Up @@ -108,9 +107,13 @@ def train_detector(model,
optimizer_config = cfg.optimizer_config

# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
runner.register_training_hooks(
cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config,
cfg.get('momentum_config', None),
custom_hooks_config=cfg.get('custom_hooks', None))
if distributed:
if isinstance(runner, EpochBasedRunner):
runner.register_hook(DistSamplerSeedHook())
Expand Down Expand Up @@ -144,20 +147,6 @@ def train_detector(model,
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
assert isinstance(custom_hooks, list), \
f'custom_hooks expect list type, but got {type(custom_hooks)}'
for hook_cfg in cfg.custom_hooks:
assert isinstance(hook_cfg, dict), \
'Each item in custom_hooks expects dict type, but got ' \
f'{type(hook_cfg)}'
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
Expand Down

0 comments on commit 227a0cc

Please sign in to comment.