Skip to content

Commit 7aba32a

Browse files
authored
Merge a7f1ded into 0c31afe
2 parents 0c31afe + a7f1ded commit 7aba32a

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

mmseg/apis/train.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import numpy as np
55
import torch
66
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
7-
from mmcv.runner import build_optimizer, build_runner
7+
from mmcv.runner import HOOKS, build_optimizer, build_runner
8+
from mmcv.utils import build_from_cfg
89

910
from mmseg.core import DistEvalHook, EvalHook
1011
from mmseg.datasets import build_dataloader, build_dataset
@@ -109,6 +110,20 @@ def train_segmentor(model,
109110
eval_hook = DistEvalHook if distributed else EvalHook
110111
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
111112

113+
# user-defined hooks
114+
if cfg.get('custom_hooks', None):
115+
custom_hooks = cfg.custom_hooks
116+
assert isinstance(custom_hooks, list), \
117+
f'custom_hooks expect list type, but got {type(custom_hooks)}'
118+
for hook_cfg in cfg.custom_hooks:
119+
assert isinstance(hook_cfg, dict), \
120+
'Each item in custom_hooks expects dict type, but got ' \
121+
f'{type(hook_cfg)}'
122+
hook_cfg = hook_cfg.copy()
123+
priority = hook_cfg.pop('priority', 'NORMAL')
124+
hook = build_from_cfg(hook_cfg, HOOKS)
125+
runner.register_hook(hook, priority=priority)
126+
112127
if cfg.resume_from:
113128
runner.resume(cfg.resume_from)
114129
elif cfg.load_from:

0 commit comments

Comments
 (0)