From f9336032a556bd7e1286d49fe04295c0069d369f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Thu, 6 May 2021 07:19:54 +0800 Subject: [PATCH] Use MMCV's EvalHook in MMSegmentation (#438) * mmcv eval hook * mmcv evalhook compatible * add warnings * inherit from base class * fix unitest * adapt to mmcv 1.3.1 * fixed unittest * set by_epoch=False * fixed efficient test * update docstring Co-authored-by: Jiarui XU --- mmseg/__init__.py | 2 +- mmseg/core/evaluation/eval_hooks.py | 106 ++++++++++++++-------------- tests/test_eval_hook.py | 4 +- 3 files changed, 57 insertions(+), 55 deletions(-) diff --git a/mmseg/__init__.py b/mmseg/__init__.py index d1f472c044..96a8ca14fe 100644 --- a/mmseg/__init__.py +++ b/mmseg/__init__.py @@ -2,7 +2,7 @@ from .version import __version__, version_info -MMCV_MIN = '1.1.4' +MMCV_MIN = '1.3.1' MMCV_MAX = '1.4.0' diff --git a/mmseg/core/evaluation/eval_hooks.py b/mmseg/core/evaluation/eval_hooks.py index 09c6265ece..34c44c7fe3 100644 --- a/mmseg/core/evaluation/eval_hooks.py +++ b/mmseg/core/evaluation/eval_hooks.py @@ -1,37 +1,49 @@ import os.path as osp -from mmcv.runner import Hook -from torch.utils.data import DataLoader +from mmcv.runner import DistEvalHook as _DistEvalHook +from mmcv.runner import EvalHook as _EvalHook -class EvalHook(Hook): - """Evaluation hook. +class EvalHook(_EvalHook): + """Single GPU EvalHook, with efficient test support. - Attributes: - dataloader (DataLoader): A PyTorch dataloader. - interval (int): Evaluation interval (by epochs). Default: 1. + Args: + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + Default: False. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + Returns: + list: The prediction results. """ - def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs): - if not isinstance(dataloader, DataLoader): - raise TypeError('dataloader must be a pytorch DataLoader, but got ' - f'{type(dataloader)}') - self.dataloader = dataloader - self.interval = interval - self.by_epoch = by_epoch - self.eval_kwargs = eval_kwargs + greater_keys = ['mIoU', 'mAcc', 'aAcc'] + + def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): + super().__init__(*args, by_epoch=by_epoch, **kwargs) + self.efficient_test = efficient_test def after_train_iter(self, runner): - """After train epoch hook.""" + """After train epoch hook. + + Override default ``single_gpu_test``. + """ if self.by_epoch or not self.every_n_iters(runner, self.interval): return from mmseg.apis import single_gpu_test runner.log_buffer.clear() - results = single_gpu_test(runner.model, self.dataloader, show=False) + results = single_gpu_test( + runner.model, + self.dataloader, + show=False, + efficient_test=self.efficient_test) self.evaluate(runner, results) def after_train_epoch(self, runner): - """After train epoch hook.""" + """After train epoch hook. + + Override default ``single_gpu_test``. + """ if not self.by_epoch or not self.every_n_epochs(runner, self.interval): return from mmseg.apis import single_gpu_test @@ -39,45 +51,31 @@ def after_train_epoch(self, runner): results = single_gpu_test(runner.model, self.dataloader, show=False) self.evaluate(runner, results) - def evaluate(self, runner, results): - """Call evaluate function of dataset.""" - eval_res = self.dataloader.dataset.evaluate( - results, logger=runner.logger, **self.eval_kwargs) - for name, val in eval_res.items(): - runner.log_buffer.output[name] = val - runner.log_buffer.ready = True - -class DistEvalHook(EvalHook): - """Distributed evaluation hook. +class DistEvalHook(_DistEvalHook): + """Distributed EvalHook, with efficient test support. - Attributes: - dataloader (DataLoader): A PyTorch dataloader. - interval (int): Evaluation interval (by epochs). Default: 1. - tmpdir (str | None): Temporary directory to save the results of all - processes. Default: None. - gpu_collect (bool): Whether to use gpu or cpu to collect results. + Args: + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. Default: False. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + Returns: + list: The prediction results. """ - def __init__(self, - dataloader, - interval=1, - gpu_collect=False, - by_epoch=False, - **eval_kwargs): - if not isinstance(dataloader, DataLoader): - raise TypeError( - 'dataloader must be a pytorch DataLoader, but got {}'.format( - type(dataloader))) - self.dataloader = dataloader - self.interval = interval - self.gpu_collect = gpu_collect - self.by_epoch = by_epoch - self.eval_kwargs = eval_kwargs + greater_keys = ['mIoU', 'mAcc', 'aAcc'] + + def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs): + super().__init__(*args, by_epoch=by_epoch, **kwargs) + self.efficient_test = efficient_test def after_train_iter(self, runner): - """After train epoch hook.""" + """After train epoch hook. + + Override default ``multi_gpu_test``. + """ if self.by_epoch or not self.every_n_iters(runner, self.interval): return from mmseg.apis import multi_gpu_test @@ -86,13 +84,17 @@ def after_train_iter(self, runner): runner.model, self.dataloader, tmpdir=osp.join(runner.work_dir, '.eval_hook'), - gpu_collect=self.gpu_collect) + gpu_collect=self.gpu_collect, + efficient_test=self.efficient_test) if runner.rank == 0: print('\n') self.evaluate(runner, results) def after_train_epoch(self, runner): - """After train epoch hook.""" + """After train epoch hook. + + Override default ``multi_gpu_test``. + """ if not self.by_epoch or not self.every_n_epochs(runner, self.interval): return from mmseg.apis import multi_gpu_test diff --git a/tests/test_eval_hook.py b/tests/test_eval_hook.py index a6a1352ea5..c83623de0c 100644 --- a/tests/test_eval_hook.py +++ b/tests/test_eval_hook.py @@ -63,7 +63,7 @@ def test_iter_eval_hook(): # test EvalHook with tempfile.TemporaryDirectory() as tmpdir: - eval_hook = EvalHook(data_loader) + eval_hook = EvalHook(data_loader, by_epoch=False) runner = mmcv.runner.IterBasedRunner( model=model, optimizer=optimizer, @@ -143,7 +143,7 @@ def test_dist_eval_hook(): # test DistEvalHook with tempfile.TemporaryDirectory() as tmpdir: - eval_hook = DistEvalHook(data_loader) + eval_hook = DistEvalHook(data_loader, by_epoch=False) runner = mmcv.runner.IterBasedRunner( model=model, optimizer=optimizer,