Skip to content

Commit 76f4f8f

Browse files
authored
Remove train test cfg in bmn and bsn (open-mmlab#676)
* remove train_cfg and test_cfg in bmn and bsn and modify average_clips judge * skip checking
1 parent 2278fee commit 76f4f8f

File tree

4 files changed

+21
-25
lines changed

4 files changed

+21
-25
lines changed

configs/_base_/models/bmn_400x100.py

-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,3 @@
1010
soft_nms_low_threshold=0.5,
1111
soft_nms_high_threshold=0.9,
1212
post_process_top_k=100)
13-
# model training and testing settings
14-
train_cfg = None
15-
test_cfg = dict(average_clips='score')

configs/_base_/models/bsn_pem.py

-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,3 @@
1111
soft_nms_low_threshold=0.65,
1212
soft_nms_high_threshold=0.9,
1313
post_process_top_k=100)
14-
# model training and testing settings
15-
train_cfg = None
16-
test_cfg = dict(average_clips='score')

configs/_base_/models/bsn_tem.py

-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,3 @@
66
tem_feat_dim=400,
77
tem_hidden_dim=512,
88
tem_match_threshold=0.5)
9-
# model training and testing settings
10-
train_cfg = None
11-
test_cfg = dict(average_clips='score')

tools/test.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -136,30 +136,35 @@ def main():
136136

137137
dataset_type = cfg.data.test.type
138138
if output_config.get('out', None):
139-
out = output_config['out']
140-
# make sure the dirname of the output path exists
141-
mmcv.mkdir_or_exist(osp.dirname(out))
142-
_, suffix = osp.splitext(out)
143-
if dataset_type == 'AVADataset':
144-
assert suffix[1:] == 'csv', ('For AVADataset, the format of the '
145-
'output file should be csv')
139+
if 'output_format' in output_config:
140+
# ugly workround to make recognition and localization the same
141+
warnings.warn(
142+
'Skip checking `output_format` in localization task.')
146143
else:
147-
assert suffix[1:] in file_handlers, (
148-
'The format of the output '
149-
'file should be json, pickle or yaml')
144+
out = output_config['out']
145+
# make sure the dirname of the output path exists
146+
mmcv.mkdir_or_exist(osp.dirname(out))
147+
_, suffix = osp.splitext(out)
148+
if dataset_type == 'AVADataset':
149+
assert suffix[1:] == 'csv', ('For AVADataset, the format of '
150+
'the output file should be csv')
151+
else:
152+
assert suffix[1:] in file_handlers, (
153+
'The format of the output '
154+
'file should be json, pickle or yaml')
150155

151156
# set cudnn benchmark
152157
if cfg.get('cudnn_benchmark', False):
153158
torch.backends.cudnn.benchmark = True
154159
cfg.data.test.test_mode = True
155160

156-
if cfg.model.get('test_cfg') is None and cfg.get('test_cfg') is None:
157-
cfg.model.setdefault('test_cfg',
158-
dict(average_clips=args.average_clips))
159-
else:
161+
if args.average_clips is not None:
160162
# You can set average_clips during testing, it will override the
161-
# original settting
162-
if args.average_clips is not None:
163+
# original setting
164+
if cfg.model.get('test_cfg') is None and cfg.get('test_cfg') is None:
165+
cfg.model.setdefault('test_cfg',
166+
dict(average_clips=args.average_clips))
167+
else:
163168
if cfg.model.get('test_cfg') is not None:
164169
cfg.model.test_cfg.average_clips = args.average_clips
165170
else:

0 commit comments

Comments
 (0)