-
Notifications
You must be signed in to change notification settings - Fork 597
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fix a bug for running vid & sot * change eval_hook in sot config * use evalhook of mmcv * fix a bug where detector is not initilized when training
- Loading branch information
Showing
4 changed files
with
50 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,72 @@ | ||
import os.path as osp | ||
|
||
from mmdet.core import DistEvalHook as _DistEvalHook | ||
from mmdet.core import EvalHook as _EvalHook | ||
import torch.distributed as dist | ||
from mmcv.runner import DistEvalHook as BaseDistEvalHook | ||
from mmcv.runner import EvalHook as BaseEvalHook | ||
from torch.nn.modules.batchnorm import _BatchNorm | ||
|
||
|
||
class EvalHook(_EvalHook): | ||
"""Please refer to `mmdet.core.evaluation.eval_hooks.py:EvalHook` for | ||
detailed docstring.""" | ||
class EvalHook(BaseEvalHook): | ||
"""Please refer to `mmcv.runner.hooks.evaluation.py:EvalHook` for detailed | ||
docstring.""" | ||
|
||
def after_train_epoch(self, runner): | ||
def _do_evaluate(self, runner): | ||
"""perform evaluation and save ckpt.""" | ||
if not self._should_evaluate(runner): | ||
return | ||
|
||
if hasattr(self.dataloader.dataset, | ||
'load_as_video') and self.dataloader.dataset.load_as_video: | ||
from mmtrack.apis import single_gpu_test | ||
else: | ||
from mmdet.apis import single_gpu_test | ||
results = single_gpu_test(runner.model, self.dataloader, show=False) | ||
self.evaluate(runner, results) | ||
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) | ||
key_score = self.evaluate(runner, results) | ||
if self.save_best: | ||
self._save_ckpt(runner, key_score) | ||
|
||
|
||
class DistEvalHook(_DistEvalHook): | ||
"""Please refer to `mmdet.core.evaluation.eval_hooks.py:DistEvalHook` for | ||
class DistEvalHook(BaseDistEvalHook): | ||
"""Please refer to `mmcv.runner.hooks.evaluation.py:DistEvalHook` for | ||
detailed docstring.""" | ||
|
||
def after_train_epoch(self, runner): | ||
def _do_evaluate(self, runner): | ||
"""perform evaluation and save ckpt.""" | ||
# Synchronization of BatchNorm's buffer (running_mean | ||
# and running_var) is not supported in the DDP of pytorch, | ||
# which may cause the inconsistent performance of models in | ||
# different ranks, so we broadcast BatchNorm's buffers | ||
# of rank 0 to other ranks to avoid this. | ||
if self.broadcast_bn_buffer: | ||
model = runner.model | ||
for name, module in model.named_modules(): | ||
if isinstance(module, | ||
_BatchNorm) and module.track_running_stats: | ||
dist.broadcast(module.running_var, 0) | ||
dist.broadcast(module.running_mean, 0) | ||
|
||
if not self._should_evaluate(runner): | ||
return | ||
|
||
tmpdir = self.tmpdir | ||
if tmpdir is None: | ||
tmpdir = osp.join(runner.work_dir, '.eval_hook') | ||
|
||
if hasattr(self.dataloader.dataset, | ||
'load_as_video') and self.dataloader.dataset.load_as_video: | ||
from mmtrack.apis import multi_gpu_test | ||
else: | ||
from mmdet.apis import multi_gpu_test | ||
tmpdir = self.tmpdir | ||
if tmpdir is None: | ||
tmpdir = osp.join(runner.work_dir, '.eval_hook') | ||
results = multi_gpu_test( | ||
runner.model, | ||
self.dataloader, | ||
tmpdir=tmpdir, | ||
gpu_collect=self.gpu_collect) | ||
if runner.rank == 0: | ||
print('\n') | ||
self.evaluate(runner, results) | ||
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) | ||
key_score = self.evaluate(runner, results) | ||
|
||
if self.save_best: | ||
self._save_ckpt(runner, key_score) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters