Skip to content

Commit

Permalink
update badcase or pred show logic
Browse files Browse the repository at this point in the history
  • Loading branch information
fl9987 committed Aug 3, 2023
1 parent d8ec3db commit de2a45e
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions tools/test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings
import os.path as osp

import mmengine
Expand Down Expand Up @@ -80,14 +79,15 @@ def merge_args(cfg, args):
osp.splitext(osp.basename(args.config))[0])

# -------------------- visualization --------------------
if args.show or (args.show_dir is not None):
if (args.show and not args.badcase) or (args.show_dir is not None):
assert 'visualization' in cfg.default_hooks, \
'PoseVisualizationHook is not set in the ' \
'`default_hooks` field of config. Please set ' \
'`visualization=dict(type="PoseVisualizationHook")`'

cfg.default_hooks.visualization.enable = True
cfg.default_hooks.visualization.show = args.show
cfg.default_hooks.visualization.show = False \
if args.badcase else args.show
if args.show:
cfg.default_hooks.visualization.wait_time = args.wait_time
cfg.default_hooks.visualization.out_dir = args.show_dir
Expand All @@ -99,21 +99,18 @@ def merge_args(cfg, args):
'BadcaseAnalyzeHook is not set in the ' \
'`default_hooks` field of config. Please set ' \
'`badcase=dict(type="BadcaseAnalyzeHook")`'

cfg.default_hooks.badcase.enable = True
badcase_show = cfg.default_hooks.badcase.get('show', 'False')
if badcase_show:
cfg.default_hooks.badcase.show = args.show
if args.show:
cfg.default_hooks.badcase.wait_time = args.wait_time
if args.show:
warnings.warn("Enabling both pred and badcase"
"visualiztion can be confusing")
cfg.default_hooks.badcase.interval = args.interval

metric_type = cfg.default_hooks.badcase.get('metric_type', 'loss')
if metric_type not in ['loss', 'accuracy']:
raise ValueError("Only support badcase metric type"
raise ValueError('Only support badcase metric type'
"in ['loss', 'accuracy']")

if metric_type == 'loss':
if not cfg.default_hooks.badcase.get('metric'):
cfg.default_hooks.badcase.metric = cfg.model.head.loss
Expand Down

0 comments on commit de2a45e

Please sign in to comment.