@@ -4,7 +4,8 @@
 import numpy as np
 import torch
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import build_optimizer, build_runner
+from mmcv.runner import HOOKS, build_optimizer, build_runner
+from mmcv.utils import build_from_cfg

 from mmseg.core import DistEvalHook, EvalHook
 from mmseg.datasets import build_dataloader, build_dataset
@@ -109,6 +110,20 @@ def train_segmentor(model,
         eval_hook = DistEvalHook if distributed else EvalHook
         runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

+    # user-defined hooks
+    if cfg.get('custom_hooks', None):
+        custom_hooks = cfg.custom_hooks
+        assert isinstance(custom_hooks, list), \
+            f'custom_hooks expect list type, but got {type(custom_hooks)}'
+        for hook_cfg in cfg.custom_hooks:
+            assert isinstance(hook_cfg, dict), \
+                'Each item in custom_hooks expects dict type, but got ' \
+                f'{type(hook_cfg)}'
+            hook_cfg = hook_cfg.copy()
+            priority = hook_cfg.pop('priority', 'NORMAL')
+            hook = build_from_cfg(hook_cfg, HOOKS)
+            runner.register_hook(hook, priority=priority)
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
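For context, a minimal sketch of how the new custom_hooks option is consumed, assuming mmcv's standard HOOKS registry and Hook base class; the CheckInvalidLossHook name and its interval parameter are illustrative, not part of this patch:

    import torch
    from mmcv.runner import HOOKS, Hook

    @HOOKS.register_module()
    class CheckInvalidLossHook(Hook):
        """Hypothetical hook: abort training if the loss turns NaN/Inf."""

        def __init__(self, interval=50):
            self.interval = interval

        def after_train_iter(self, runner):
            # every_n_iters and runner.outputs come from mmcv's Hook/Runner API.
            if self.every_n_iters(runner, self.interval):
                assert torch.isfinite(runner.outputs['loss']), \
                    'loss became infinite or NaN!'

The hook is then enabled from the experiment config; note that the patch above pops the priority key before passing the remaining dict to build_from_cfg:

    custom_hooks = [
        dict(type='CheckInvalidLossHook', interval=50, priority='VERY_LOW')
    ]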