From de2a45eecee296a1e6c114f8c90c9ecc0babfd78 Mon Sep 17 00:00:00 2001 From: fl9987 Date: Thu, 3 Aug 2023 15:53:41 +0800 Subject: [PATCH] update badcase or pred show logic --- tools/test.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tools/test.py b/tools/test.py index 80e7e409ca..f161c036d1 100644 --- a/tools/test.py +++ b/tools/test.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse import os -import warnings import os.path as osp import mmengine @@ -80,14 +79,15 @@ def merge_args(cfg, args): osp.splitext(osp.basename(args.config))[0]) # -------------------- visualization -------------------- - if args.show or (args.show_dir is not None): + if (args.show and not args.badcase) or (args.show_dir is not None): assert 'visualization' in cfg.default_hooks, \ 'PoseVisualizationHook is not set in the ' \ '`default_hooks` field of config. Please set ' \ '`visualization=dict(type="PoseVisualizationHook")`' cfg.default_hooks.visualization.enable = True - cfg.default_hooks.visualization.show = args.show + cfg.default_hooks.visualization.show = False \ + if args.badcase else args.show if args.show: cfg.default_hooks.visualization.wait_time = args.wait_time cfg.default_hooks.visualization.out_dir = args.show_dir @@ -99,21 +99,18 @@ def merge_args(cfg, args): 'BadcaseAnalyzeHook is not set in the ' \ '`default_hooks` field of config. Please set ' \ '`badcase=dict(type="BadcaseAnalyzeHook")`' - + cfg.default_hooks.badcase.enable = True - badcase_show = cfg.default_hooks.badcase.get('show', 'False') - if badcase_show: + cfg.default_hooks.badcase.show = args.show + if args.show: cfg.default_hooks.badcase.wait_time = args.wait_time - if args.show: - warnings.warn("Enabling both pred and badcase" - "visualiztion can be confusing") cfg.default_hooks.badcase.interval = args.interval metric_type = cfg.default_hooks.badcase.get('metric_type', 'loss') if metric_type not in ['loss', 'accuracy']: - raise ValueError("Only support badcase metric type" + raise ValueError('Only support badcase metric type' "in ['loss', 'accuracy']") - + if metric_type == 'loss': if not cfg.default_hooks.badcase.get('metric'): cfg.default_hooks.badcase.metric = cfg.model.head.loss