diff --git a/configs/sot/siamese_rpn/README.md b/configs/sot/siamese_rpn/README.md index bbf32acc6..0c3b2a7bd 100644 --- a/configs/sot/siamese_rpn/README.md +++ b/configs/sot/siamese_rpn/README.md @@ -14,12 +14,26 @@ } ``` -## Results and models on LaSOT dataset +## Results and models -We observe around 1.0 points fluctuations in Success and 1.5 points fluctuations in Norm percision. We provide the best model. +### LaSOT -Note that all of checkpoints from 11-th to 20-th epoch need to be evaluated in order to achieve the best results. +Note that the checkpoints from the 10th to the 20th epoch will be evaluated during training. You can find the best checkpoint in the log file. + +We observe fluctuations of around 1.0 points in Success and 1.5 points in Norm precision. We provide the best model with its configuration and training log. | Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | Success | Norm precision | Config | Download | | :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :----: | :------: | :--------: | | R-50 | - | 20e | 7.54 | 50.0 | 49.9 | 57.9 | [config](siamese_rpn_r50_1x_lasot.py) | [model](https://download.openmmlab.com/mmtracking/sot/siamese_rpn/siamese_rpn_r50_1x_lasot/siamese_rpn_r50_1x_lasot_20201218_051019-3c522eff.pth) | [log](https://download.openmmlab.com/mmtracking/sot/siamese_rpn/siamese_rpn_r50_1x_lasot/siamese_rpn_r50_1x_lasot_20201218_051019.log.json) | + +### UAV123 + +The checkpoints from the 10th to the 20th epoch will be evaluated during training. + +After training, pick the best checkpoint from the log file, then use it to search the test-time hyperparameters on UAV123 as described [here](https://github.com/open-mmlab/mmtracking/blob/master/docs/useful_tools_scripts.md#siameserpn-test-time-parameter-search) to achieve the best results. + +We observe fluctuations of around xxx points in Success and xxx points in Norm precision. We provide the best model with its configuration and training log.
+ +| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | Success | Norm precision | Config | Download | +| :-------------: | :-----: | :-----: | :------: | :------------: | :----: | :----: | :------: | :--------: | +| R-50 | - | 20e | 7.54 | - | 61.8 | 77.3 | [config](siamese_rpn_r50_1x_uav123.py) | [model](https://download.openmmlab.com/mmtracking/sot/siamese_rpn/siamese_rpn_r50_1x_lasot/siamese_rpn_r50_1x_lasot_20201218_051019-3c522eff.pth) | [log](https://download.openmmlab.com/mmtracking/sot/siamese_rpn/siamese_rpn_r50_1x_lasot/siamese_rpn_r50_1x_lasot_20201218_051019.log.json) | diff --git a/configs/sot/siamese_rpn/siamese_rpn_r50_1x_uav123.py b/configs/sot/siamese_rpn/siamese_rpn_r50_1x_uav123.py new file mode 100644 index 000000000..a010bb616 --- /dev/null +++ b/configs/sot/siamese_rpn/siamese_rpn_r50_1x_uav123.py @@ -0,0 +1,17 @@ +_base_ = ['./siamese_rpn_r50_1x_lasot.py'] + +# model settings +model = dict( + test_cfg=dict(rpn=dict(penalty_k=0.01, window_influence=0.02, lr=0.46))) + +data_root = 'data/' +# dataset settings +data = dict( + val=dict( + type='UAV123Dataset', + ann_file=data_root + 'UAV123/annotations/uav123.json', + img_prefix=data_root + 'UAV123/data_seq/UAV123'), + test=dict( + type='UAV123Dataset', + ann_file=data_root + 'UAV123/annotations/uav123.json', + img_prefix=data_root + 'UAV123/data_seq/UAV123')) diff --git a/docs/dataset.md b/docs/dataset.md index 232f27dba..a93c4076a 100644 --- a/docs/dataset.md +++ b/docs/dataset.md @@ -8,6 +8,7 @@ This page provides the instructions for dataset preparation on existing benchmar - [MOT Challenge](https://motchallenge.net/) - Single Object Tracking - [LaSOT](http://vision.cs.stonybrook.edu/~lasot/) + - [UAV123](https://cemse.kaust.edu.sa/ivul/uav123/) ### 1. Download Datasets @@ -21,7 +22,7 @@ Notes: - For the training and testing of multi object tracking task, only one of the MOT Challenge dataset (e.g. MOT17) is needed. -- For the training and testing of single object tracking task, the MSCOCO, ILSVRC and LaSOT datasets are needed. +- For the training and testing of single object tracking task, the MSCOCO, ILSVRC, LaSOT and UAV123 datasets are needed. ``` mmtracking @@ -62,6 +63,14 @@ mmtracking | ├── MOT15/MOT16/MOT17/MOT20 | | ├── train | | ├── test +│ │ +│ ├── UAV123 +│ │ ├── data_seq +│ │ │ ├── UAV123 +│ │ │ │ ├── bike1 +│ │ │ │ ├── boat1 +│ │ ├── anno +│ │ │ ├── UAV123 ``` ### 2. Convert Annotations @@ -83,6 +92,9 @@ python ./tools/convert_datasets/lasot2coco.py -i ./data/lasot/LaSOTTesting -o ./ # The processing of other MOT Challenge dataset is the same as MOT17 python ./tools/convert_datasets/mot2coco.py -i ./data/MOT17/ -o ./data/MOT17/annotations --split-train --convert-det python ./tools/convert_datasets/mot2reid.py -i ./data/MOT17/ -o ./data/MOT17/reid --val-split 0.2 --vis-threshold 0.3 + +# UAV123 +python ./tools/convert_datasets/uav2coco.py -i ./data/UAV123/ -o ./data/UAV123/annotations ``` The folder structure will be as following after your run these scripts: @@ -132,6 +144,15 @@ mmtracking | | ├── reid │ │ │ ├── imgs │ │ │ ├── meta +│ │ +│ ├── UAV123 +│ │ ├── data_seq +│ │ │ ├── UAV123 +│ │ │ │ ├── bike1 +│ │ │ │ ├── boat1 +│ │ ├── anno (the official annotation files) +│ │ │ ├── UAV123 +│ │ ├── annotations (the converted annotation file) ``` #### The folder of annotations in ILSVRC @@ -200,3 +221,9 @@ MOT17-02-FRCNN_000009/000081.jpg 3 For validation, The annotation list `val_20.txt` remains the same as format above.
Images in `reid/imgs` are cropped from raw images in `MOT17/train` by the corresponding `gt.txt`. The value of ground-truth labels should fall in range `[0, num_classes - 1]`. + +#### The folder of annotations in UAV123 + +There is only one JSON file in `data/UAV123/annotations`: + +`uav123.json`: JSON file containing the annotation information of the UAV123 dataset. diff --git a/docs/useful_tools_scripts.md b/docs/useful_tools_scripts.md index ec0ca2a07..4edcc4262 100644 --- a/docs/useful_tools_scripts.md +++ b/docs/useful_tools_scripts.md @@ -111,6 +111,18 @@ python tools/publish_model.py work_dirs/dff_faster_rcnn_r101_dc5_1x_imagenetvid/ The final output filename will be `dff_faster_rcnn_r101_dc5_1x_imagenetvid_20201230-{hash id}.pth`. +## SiameseRPN++ Test-time Parameter Search + +`tools/sot_siamrpn_param_search.py` can search the test-time tracking parameters in SiameseRPN++: `penalty_k`, `lr` and `window_influence`. You need to pass the search range of each parameter as command-line arguments. + +Example on the UAV123 dataset: + +```shell +./tools/dist_sot_siamrpn_search.sh ${CONFIG_FILE} ${GPUS} \ +[--checkpoint ${CHECKPOINT}] [--log ${LOG_FILENAME}] [--eval ${EVAL}] \ +[--penalty-k-range 0.05,0.5,0.05] [--lr-range 0.3,0.45,0.02] [--win-influ-range 0.46,0.55,0.02] +``` + ## Miscellaneous ### Print the entire config diff --git a/docs_zh-CN/dataset.md b/docs_zh-CN/dataset.md index 00b47dfee..50574f012 100644 --- a/docs_zh-CN/dataset.md +++ b/docs_zh-CN/dataset.md @@ -8,6 +8,7 @@ - [MOT Challenge](https://motchallenge.net/) - 单目标跟踪 - [LaSOT](http://vision.cs.stonybrook.edu/~lasot/) + - [UAV123](https://cemse.kaust.edu.sa/ivul/uav123/) ### 1. 下载数据集 @@ -21,7 +22,7 @@ - 对于多目标跟踪任务的训练和测试,只需要 MOT Challenge 中的任意一个数据集(比如 MOT17)。 -- 对于单目标跟踪任务的训练和测试,需要 MSCOCO,ILSVRC 和 LaSOT 数据集。 +- 对于单目标跟踪任务的训练和测试,需要 MSCOCO,ILSVRC,LaSOT 和 UAV123 数据集。 ``` mmtracking @@ -62,6 +63,14 @@ mmtracking | ├── MOT15/MOT16/MOT17/MOT20 | | ├── train | | ├── test +│ │ +│ ├── UAV123 +│ │ ├── data_seq +│ │ │ ├── UAV123 +│ │ │ │ ├── bike1 +│ │ │ │ ├── boat1 +│ │ ├── anno +│ │ │ ├── UAV123 ``` ### 2.
转换标注格式 @@ -84,6 +93,9 @@ python ./tools/convert_datasets/lasot2coco.py -i ./data/lasot/LaSOTTesting -o ./ # MOT Challenge中其余数据集的处理与MOT17相同 python ./tools/convert_datasets/mot2coco.py -i ./data/MOT17/ -o ./data/MOT17/annotations --split-train --convert-det python ./tools/convert_datasets/mot2reid.py -i ./data/MOT17/ -o ./data/MOT17/reid --val-split 0.2 --vis-threshold 0.3 + +# UAV123 +python ./tools/convert_datasets/uav2coco.py -i ./data/UAV123/ -o ./data/UAV123/annotations ``` 完成以上格式转换后,文件目录结构如下: @@ -133,6 +145,15 @@ mmtracking | | ├── reid │ │ │ ├── imgs │ │ │ ├── meta +│ │ +│ ├── UAV123 +│ │ ├── data_seq +│ │ │ ├── UAV123 +│ │ │ │ ├── bike1 +│ │ │ │ ├── boat1 +│ │ ├── anno (the official annotation files) +│ │ │ ├── UAV123 +│ │ ├── annotations (the converted annotation file) ``` #### ILSVRC的标注文件夹 @@ -202,3 +223,9 @@ MOT17-02-FRCNN_000009/000081.jpg 3 验证集标注 `val_20.txt` 的结构和上面类似。 `reid/imgs` 中的图片是从 `MOT17/train` 中原始图片根据对应的 `gt.txt` 裁剪得到。真实类别标签值在 `[0, num_classes - 1]` 范围内。 + +#### UAV123的标注文件夹 + +在 `data/UAV123/annotations` 中只有一个 json 文件: + +`uav123.json`: 包含 UAV123 数据集标注信息的 json 文件。 diff --git a/docs_zh-CN/useful_tools_scripts.md b/docs_zh-CN/useful_tools_scripts.md index a9bc91114..4a616b046 100644 --- a/docs_zh-CN/useful_tools_scripts.md +++ b/docs_zh-CN/useful_tools_scripts.md @@ -111,6 +111,18 @@ python tools/publish_model.py work_dirs/dff_faster_rcnn_r101_dc5_1x_imagenetvid/ 最后输出的文件名为 `dff_faster_rcnn_r101_dc5_1x_imagenetvid_20201230-{hash id}.pth`。 +## SiameseRPN++ 测试参数搜索 + +`tools/sot_siamrpn_param_search.py` 用来搜索 SiameseRPN++ 测试时的跟踪相关参数: `penalty_k`, `lr` 和 `window_influence`。你需要通过命令行参数传入每个参数的搜索范围。 + +在 UAV123 上的超参搜索范例: + +```shell +./tools/dist_sot_siamrpn_search.sh ${CONFIG_FILE} ${GPUS} \ +[--checkpoint ${CHECKPOINT}] [--log ${LOG_FILENAME}] [--eval ${EVAL}] \ +[--penalty-k-range 0.05,0.5,0.05] [--lr-range 0.3,0.45,0.02] [--win-influ-range 0.46,0.55,0.02] +``` + ## 其它有用的工具脚本 ### 输出完整的配置 diff --git a/mmtrack/datasets/__init__.py b/mmtrack/datasets/__init__.py index c6bd3005a..f9ed16f7a 100644 --- a/mmtrack/datasets/__init__.py +++ b/mmtrack/datasets/__init__.py @@ -9,10 +9,13 @@ from .parsers import CocoVID from .pipelines import PIPELINES from .reid_dataset import ReIDDataset +from .sot_test_dataset import SOTTestDataset from .sot_train_dataset import SOTTrainDataset +from .uav123_dataset import UAV123Dataset __all__ = [ 'DATASETS', 'PIPELINES', 'build_dataloader', 'build_dataset', 'CocoVID', 'CocoVideoDataset', 'ImagenetVIDDataset', 'MOTChallengeDataset', - 'LaSOTDataset', 'SOTTrainDataset', 'ReIDDataset' + 'ReIDDataset', 'SOTTrainDataset', 'SOTTestDataset', 'LaSOTDataset', + 'UAV123Dataset' ] diff --git a/mmtrack/datasets/lasot_dataset.py b/mmtrack/datasets/lasot_dataset.py index f59861f2c..6547d87ce 100644 --- a/mmtrack/datasets/lasot_dataset.py +++ b/mmtrack/datasets/lasot_dataset.py @@ -1,24 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np -from mmcv.utils import print_log from mmdet.datasets import DATASETS -from mmtrack.core.evaluation import eval_sot_ope -from .coco_video_dataset import CocoVideoDataset +from .sot_test_dataset import SOTTestDataset @DATASETS.register_module() -class LaSOTDataset(CocoVideoDataset): +class LaSOTDataset(SOTTestDataset): """LaSOT dataset for the testing of single object tracking. The dataset doesn't support training mode. """ - CLASSES = (0, ) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - def _parse_ann_info(self, img_info, ann_info): """Parse bbox annotations.
@@ -39,60 +32,3 @@ def _parse_ann_info(self, img_info, ann_info): ignore = ann_info[0]['full_occlusion'] or ann_info[0]['out_of_view'] ann = dict(bboxes=gt_bboxes, labels=gt_labels, ignore=ignore) return ann - - def evaluate(self, results, metric=['track'], logger=None): - """Evaluation in OPE protocol. - - Args: - results (dict): Testing results of the dataset. - metric (str | list[str]): Metrics to be evaluated. Options are - 'track'. - logger (logging.Logger | str | None): Logger used for printing - related information during evaluation. Default: None. - - Returns: - dict[str, float]: OPE style evaluation metric (i.e. success, - norm precision and precision). - """ - if isinstance(metric, list): - metrics = metric - elif isinstance(metric, str): - metrics = [metric] - else: - raise TypeError('metric must be a list or a str.') - allowed_metrics = ['track'] - for metric in metrics: - if metric not in allowed_metrics: - raise KeyError(f'metric {metric} is not supported.') - - eval_results = dict() - if 'track' in metrics: - assert len(self.data_infos) == len(results['track_results']) - print_log('Evaluate OPE Benchmark...', logger=logger) - inds = [ - i for i, _ in enumerate(self.data_infos) if _['frame_id'] == 0 - ] - num_vids = len(inds) - inds.append(len(self.data_infos)) - - track_bboxes = [ - list( - map(lambda x: x[:4], - results['track_results'][inds[i]:inds[i + 1]])) - for i in range(num_vids) - ] - - ann_infos = [self.get_ann_info(_) for _ in self.data_infos] - ann_infos = [ - ann_infos[inds[i]:inds[i + 1]] for i in range(num_vids) - ] - track_eval_results = eval_sot_ope( - results=track_bboxes, annotations=ann_infos) - eval_results.update(track_eval_results) - - for k, v in eval_results.items(): - if isinstance(v, float): - eval_results[k] = float(f'{(v):.3f}') - print_log(eval_results, logger=logger) - - return eval_results diff --git a/mmtrack/datasets/sot_test_dataset.py b/mmtrack/datasets/sot_test_dataset.py new file mode 100644 index 000000000..684c45318 --- /dev/null +++ b/mmtrack/datasets/sot_test_dataset.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from mmcv.utils import print_log +from mmdet.datasets import DATASETS + +from mmtrack.core.evaluation import eval_sot_ope +from .coco_video_dataset import CocoVideoDataset + + +@DATASETS.register_module() +class SOTTestDataset(CocoVideoDataset): + """Dataset for the testing of single object tracking. + + The dataset doesn't support training mode. + """ + + CLASSES = (0, ) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _parse_ann_info(self, img_info, ann_info): + """Parse bbox annotations. + + Args: + img_info (dict): image information. + ann_info (list[dict]): Annotation information of an image. Each + image only has one bbox annotation. + + Returns: + dict: A dict containing the following keys: bboxes, labels. + labels are not useful in SOT. + """ + gt_bboxes = np.array(ann_info[0]['bbox'], dtype=np.float32) + # convert [x1, y1, w, h] to [x1, y1, x2, y2] + gt_bboxes[2] += gt_bboxes[0] + gt_bboxes[3] += gt_bboxes[1] + gt_labels = np.array(self.cat2label[ann_info[0]['category_id']]) + if 'ignore' in ann_info[0]: + ann = dict( + bboxes=gt_bboxes, + labels=gt_labels, + ignore=ann_info[0]['ignore']) + else: + ann = dict(bboxes=gt_bboxes, labels=gt_labels) + return ann + + def evaluate(self, results, metric=['track'], logger=None): + """Evaluation in OPE protocol. + + Args: + results (dict): Testing results of the dataset. 
+ metric (str | list[str]): Metrics to be evaluated. Options are + 'track'. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + + Returns: + dict[str, float]: OPE style evaluation metric (i.e. success, + norm precision and precision). + """ + if isinstance(metric, list): + metrics = metric + elif isinstance(metric, str): + metrics = [metric] + else: + raise TypeError('metric must be a list or a str.') + allowed_metrics = ['track'] + for metric in metrics: + if metric not in allowed_metrics: + raise KeyError(f'metric {metric} is not supported.') + + eval_results = dict() + if 'track' in metrics: + assert len(self.data_infos) == len(results['track_results']) + print_log('Evaluate OPE Benchmark...', logger=logger) + inds = [ + i for i, _ in enumerate(self.data_infos) if _['frame_id'] == 0 + ] + num_vids = len(inds) + inds.append(len(self.data_infos)) + + track_bboxes = [ + list( + map(lambda x: x[:4], + results['track_results'][inds[i]:inds[i + 1]])) + for i in range(num_vids) + ] + + ann_infos = [self.get_ann_info(_) for _ in self.data_infos] + ann_infos = [ + ann_infos[inds[i]:inds[i + 1]] for i in range(num_vids) + ] + track_eval_results = eval_sot_ope( + results=track_bboxes, annotations=ann_infos) + eval_results.update(track_eval_results) + + for k, v in eval_results.items(): + if isinstance(v, float): + eval_results[k] = float(f'{(v):.3f}') + print_log(eval_results, logger=logger) + + return eval_results diff --git a/mmtrack/datasets/uav123_dataset.py b/mmtrack/datasets/uav123_dataset.py new file mode 100644 index 000000000..8e41cac4a --- /dev/null +++ b/mmtrack/datasets/uav123_dataset.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.datasets import DATASETS + +from .sot_test_dataset import SOTTestDataset + + +@DATASETS.register_module() +class UAV123Dataset(SOTTestDataset): + """UAV123 dataset for the testing of single object tracking. + + The dataset doesn't support training mode. + """ + pass diff --git a/tests/test_data/test_datasets/test_lasot_dataset.py b/tests/test_data/test_datasets/test_sot_test_dataset.py similarity index 89% rename from tests/test_data/test_datasets/test_lasot_dataset.py rename to tests/test_data/test_datasets/test_sot_test_dataset.py index fb81ee3b1..28284373b 100644 --- a/tests/test_data/test_datasets/test_lasot_dataset.py +++ b/tests/test_data/test_datasets/test_sot_test_dataset.py @@ -11,8 +11,8 @@ LASOT_ANN_PATH = f'{PREFIX}/demo_sot_data/lasot' -@pytest.mark.parametrize('dataset', ['LaSOTDataset']) -def test_lasot_dataset_parse_ann_info(dataset): +@pytest.mark.parametrize('dataset', ['SOTTestDataset', 'LaSOTDataset']) +def test_parse_ann_info(dataset): dataset_class = DATASETS.get(dataset) dataset = dataset_class( @@ -29,8 +29,8 @@ def test_lasot_dataset_parse_ann_info(dataset): assert ann['labels'] == 0 -def test_lasot_evaluation(): - dataset_class = DATASETS.get('LaSOTDataset') +def test_sot_ope_evaluation(): + dataset_class = DATASETS.get('SOTTestDataset') dataset = dataset_class( ann_file=osp.join(LASOT_ANN_PATH, 'lasot_test_dummy.json'), pipeline=[]) diff --git a/tools/convert_datasets/imagenet2coco_det.py b/tools/convert_datasets/imagenet2coco_det.py index ab5e193d6..25dcdc87d 100644 --- a/tools/convert_datasets/imagenet2coco_det.py +++ b/tools/convert_datasets/imagenet2coco_det.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import argparse import glob +import os import os.path as osp import xml.etree.ElementTree as ET from collections import defaultdict @@ -167,7 +168,8 @@ def convert_det(DET, ann_dir, save_dir): is_vid_train_frame, records, DET, obj_num_classes) - + if not osp.isdir(save_dir): + os.makedirs(save_dir) mmcv.dump(DET, osp.join(save_dir, 'imagenet_det_30plus1cls.json')) print('-----ImageNet DET------') print(f'total {records["img_id"] - 1} images') diff --git a/tools/convert_datasets/imagenet2coco_vid.py b/tools/convert_datasets/imagenet2coco_vid.py index 85e82fe2c..b438d580b 100644 --- a/tools/convert_datasets/imagenet2coco_vid.py +++ b/tools/convert_datasets/imagenet2coco_vid.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import os import os.path as osp import xml.etree.ElementTree as ET from collections import defaultdict @@ -170,6 +171,8 @@ def convert_vid(VID, ann_dir, save_dir, mode='train'): records['ann_id'] += 1 records['img_id'] += 1 records['vid_id'] += 1 + if not osp.isdir(save_dir): + os.makedirs(save_dir) mmcv.dump(VID, osp.join(save_dir, f'imagenet_vid_{mode}.json')) print(f'-----ImageNet VID {mode}------') print(f'{records["vid_id"]- 1} videos') diff --git a/tools/convert_datasets/lasot2coco.py b/tools/convert_datasets/lasot2coco.py index 41cd7b61f..bba4e3292 100644 --- a/tools/convert_datasets/lasot2coco.py +++ b/tools/convert_datasets/lasot2coco.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import os import os.path as osp from collections import defaultdict @@ -83,6 +84,8 @@ def convert_lasot_test(lasot_test, ann_dir, save_dir): records['global_instance_id'] += 1 records['vid_id'] += 1 + if not osp.isdir(save_dir): + os.makedirs(save_dir) mmcv.dump(lasot_test, osp.join(save_dir, 'lasot_test.json')) print('-----LaSOT Test Dataset------') print(f'{records["vid_id"]- 1} videos') diff --git a/tools/convert_datasets/mot2coco.py b/tools/convert_datasets/mot2coco.py index 04958629b..0c2ae95a8 100644 --- a/tools/convert_datasets/mot2coco.py +++ b/tools/convert_datasets/mot2coco.py @@ -108,7 +108,7 @@ def parse_dets(dets): def main(): args = parse_args() - if not osp.exists(args.output): + if not osp.isdir(args.output): os.makedirs(args.output) sets = ['train', 'test'] diff --git a/tools/convert_datasets/mot2reid.py b/tools/convert_datasets/mot2reid.py index d4f766068..d607478a2 100644 --- a/tools/convert_datasets/mot2reid.py +++ b/tools/convert_datasets/mot2reid.py @@ -72,7 +72,7 @@ def parse_args(): def main(): args = parse_args() - if not osp.exists(args.output): + if not osp.isdir(args.output): os.makedirs(args.output) elif os.listdir(args.output): raise OSError(f'Directory must empty: \'{args.output}\'') diff --git a/tools/convert_datasets/uav123_info.txt b/tools/convert_datasets/uav123_info.txt new file mode 100644 index 000000000..d28b33b7b --- /dev/null +++ b/tools/convert_datasets/uav123_info.txt @@ -0,0 +1,124 @@ +The format of each line in the txt is anno_name,anno_path,video_path,start_frame,end_frame +bike1,anno/UAV123/bike1.txt,data_seq/UAV123/bike1,1,3085 +bike2,anno/UAV123/bike2.txt,data_seq/UAV123/bike2,1,553 +bike3,anno/UAV123/bike3.txt,data_seq/UAV123/bike3,1,433 +bird1_1,anno/UAV123/bird1_1.txt,data_seq/UAV123/bird1,1,253 +bird1_2,anno/UAV123/bird1_2.txt,data_seq/UAV123/bird1,775,1477 +bird1_3,anno/UAV123/bird1_3.txt,data_seq/UAV123/bird1,1573,2437 +boat1,anno/UAV123/boat1.txt,data_seq/UAV123/boat1,1,901 +boat2,anno/UAV123/boat2.txt,data_seq/UAV123/boat2,1,799 
+boat3,anno/UAV123/boat3.txt,data_seq/UAV123/boat3,1,901 +boat4,anno/UAV123/boat4.txt,data_seq/UAV123/boat4,1,553 +boat5,anno/UAV123/boat5.txt,data_seq/UAV123/boat5,1,505 +boat6,anno/UAV123/boat6.txt,data_seq/UAV123/boat6,1,805 +boat7,anno/UAV123/boat7.txt,data_seq/UAV123/boat7,1,535 +boat8,anno/UAV123/boat8.txt,data_seq/UAV123/boat8,1,685 +boat9,anno/UAV123/boat9.txt,data_seq/UAV123/boat9,1,1399 +building1,anno/UAV123/building1.txt,data_seq/UAV123/building1,1,469 +building2,anno/UAV123/building2.txt,data_seq/UAV123/building2,1,577 +building3,anno/UAV123/building3.txt,data_seq/UAV123/building3,1,829 +building4,anno/UAV123/building4.txt,data_seq/UAV123/building4,1,787 +building5,anno/UAV123/building5.txt,data_seq/UAV123/building5,1,481 +car10,anno/UAV123/car10.txt,data_seq/UAV123/car10,1,1405 +car11,anno/UAV123/car11.txt,data_seq/UAV123/car11,1,337 +car12,anno/UAV123/car12.txt,data_seq/UAV123/car12,1,499 +car13,anno/UAV123/car13.txt,data_seq/UAV123/car13,1,415 +car14,anno/UAV123/car14.txt,data_seq/UAV123/car14,1,1327 +car15,anno/UAV123/car15.txt,data_seq/UAV123/car15,1,469 +car16_1,anno/UAV123/car16_1.txt,data_seq/UAV123/car16,1,415 +car16_2,anno/UAV123/car16_2.txt,data_seq/UAV123/car16,415,1993 +car17,anno/UAV123/car17.txt,data_seq/UAV123/car17,1,1057 +car18,anno/UAV123/car18.txt,data_seq/UAV123/car18,1,1207 +car1_1,anno/UAV123/car1_1.txt,data_seq/UAV123/car1,1,751 +car1_2,anno/UAV123/car1_2.txt,data_seq/UAV123/car1,751,1627 +car1_3,anno/UAV123/car1_3.txt,data_seq/UAV123/car1,1627,2629 +car1_s,anno/UAV123/car1_s.txt,data_seq/UAV123/car1_s,1,1475 +car2,anno/UAV123/car2.txt,data_seq/UAV123/car2,1,1321 +car2_s,anno/UAV123/car2_s.txt,data_seq/UAV123/car2_s,1,320 +car3,anno/UAV123/car3.txt,data_seq/UAV123/car3,1,1717 +car3_s,anno/UAV123/car3_s.txt,data_seq/UAV123/car3_s,1,1300 +car4,anno/UAV123/car4.txt,data_seq/UAV123/car4,1,1345 +car4_s,anno/UAV123/car4_s.txt,data_seq/UAV123/car4_s,1,830 +car5,anno/UAV123/car5.txt,data_seq/UAV123/car5,1,745 +car6_1,anno/UAV123/car6_1.txt,data_seq/UAV123/car6,1,487 +car6_2,anno/UAV123/car6_2.txt,data_seq/UAV123/car6,487,1807 +car6_3,anno/UAV123/car6_3.txt,data_seq/UAV123/car6,1807,2953 +car6_4,anno/UAV123/car6_4.txt,data_seq/UAV123/car6,2953,3925 +car6_5,anno/UAV123/car6_5.txt,data_seq/UAV123/car6,3925,4861 +car7,anno/UAV123/car7.txt,data_seq/UAV123/car7,1,1033 +car8_1,anno/UAV123/car8_1.txt,data_seq/UAV123/car8,1,1357 +car8_2,anno/UAV123/car8_2.txt,data_seq/UAV123/car8,1357,2575 +car9,anno/UAV123/car9.txt,data_seq/UAV123/car9,1,1879 +group1_1,anno/UAV123/group1_1.txt,data_seq/UAV123/group1,1,1333 +group1_2,anno/UAV123/group1_2.txt,data_seq/UAV123/group1,1333,2515 +group1_3,anno/UAV123/group1_3.txt,data_seq/UAV123/group1,2515,3925 +group1_4,anno/UAV123/group1_4.txt,data_seq/UAV123/group1,3925,4873 +group2_1,anno/UAV123/group2_1.txt,data_seq/UAV123/group2,1,907 +group2_2,anno/UAV123/group2_2.txt,data_seq/UAV123/group2,907,1771 +group2_3,anno/UAV123/group2_3.txt,data_seq/UAV123/group2,1771,2683 +group3_1,anno/UAV123/group3_1.txt,data_seq/UAV123/group3,1,1567 +group3_2,anno/UAV123/group3_2.txt,data_seq/UAV123/group3,1567,2827 +group3_3,anno/UAV123/group3_3.txt,data_seq/UAV123/group3,2827,4369 +group3_4,anno/UAV123/group3_4.txt,data_seq/UAV123/group3,4369,5527 +person1,anno/UAV123/person1.txt,data_seq/UAV123/person1,1,799 +person10,anno/UAV123/person10.txt,data_seq/UAV123/person10,1,1021 +person11,anno/UAV123/person11.txt,data_seq/UAV123/person11,1,721 +person12_1,anno/UAV123/person12_1.txt,data_seq/UAV123/person12,1,601 
+person12_2,anno/UAV123/person12_2.txt,data_seq/UAV123/person12,601,1621 +person13,anno/UAV123/person13.txt,data_seq/UAV123/person13,1,883 +person14_1,anno/UAV123/person14_1.txt,data_seq/UAV123/person14,1,847 +person14_2,anno/UAV123/person14_2.txt,data_seq/UAV123/person14,847,1813 +person14_3,anno/UAV123/person14_3.txt,data_seq/UAV123/person14,1813,2923 +person15,anno/UAV123/person15.txt,data_seq/UAV123/person15,1,1339 +person16,anno/UAV123/person16.txt,data_seq/UAV123/person16,1,1147 +person17_1,anno/UAV123/person17_1.txt,data_seq/UAV123/person17,1,1501 +person17_2,anno/UAV123/person17_2.txt,data_seq/UAV123/person17,1501,2347 +person18,anno/UAV123/person18.txt,data_seq/UAV123/person18,1,1393 +person19_1,anno/UAV123/person19_1.txt,data_seq/UAV123/person19,1,1243 +person19_2,anno/UAV123/person19_2.txt,data_seq/UAV123/person19,1243,2791 +person19_3,anno/UAV123/person19_3.txt,data_seq/UAV123/person19,2791,4357 +person1_s,anno/UAV123/person1_s.txt,data_seq/UAV123/person1_s,1,1600 +person20,anno/UAV123/person20.txt,data_seq/UAV123/person20,1,1783 +person21,anno/UAV123/person21.txt,data_seq/UAV123/person21,1,487 +person22,anno/UAV123/person22.txt,data_seq/UAV123/person22,1,199 +person23,anno/UAV123/person23.txt,data_seq/UAV123/person23,1,397 +person2_1,anno/UAV123/person2_1.txt,data_seq/UAV123/person2,1,1189 +person2_2,anno/UAV123/person2_2.txt,data_seq/UAV123/person2,1189,2623 +person2_s,anno/UAV123/person2_s.txt,data_seq/UAV123/person2_s,1,250 +person3,anno/UAV123/person3.txt,data_seq/UAV123/person3,1,643 +person3_s,anno/UAV123/person3_s.txt,data_seq/UAV123/person3_s,1,505 +person4_1,anno/UAV123/person4_1.txt,data_seq/UAV123/person4,1,1501 +person4_2,anno/UAV123/person4_2.txt,data_seq/UAV123/person4,1501,2743 +person5_1,anno/UAV123/person5_1.txt,data_seq/UAV123/person5,1,877 +person5_2,anno/UAV123/person5_2.txt,data_seq/UAV123/person5,877,2101 +person6,anno/UAV123/person6.txt,data_seq/UAV123/person6,1,901 +person7_1,anno/UAV123/person7_1.txt,data_seq/UAV123/person7,1,1249 +person7_2,anno/UAV123/person7_2.txt,data_seq/UAV123/person7,1249,2065 +person8_1,anno/UAV123/person8_1.txt,data_seq/UAV123/person8,1,1075 +person8_2,anno/UAV123/person8_2.txt,data_seq/UAV123/person8,1075,1525 +person9,anno/UAV123/person9.txt,data_seq/UAV123/person9,1,661 +truck1,anno/UAV123/truck1.txt,data_seq/UAV123/truck1,1,463 +truck2,anno/UAV123/truck2.txt,data_seq/UAV123/truck2,1,385 +truck3,anno/UAV123/truck3.txt,data_seq/UAV123/truck3,1,535 +truck4_1,anno/UAV123/truck4_1.txt,data_seq/UAV123/truck4,1,577 +truck4_2,anno/UAV123/truck4_2.txt,data_seq/UAV123/truck4,577,1261 +uav1_1,anno/UAV123/uav1_1.txt,data_seq/UAV123/uav1,1,1555 +uav1_2,anno/UAV123/uav1_2.txt,data_seq/UAV123/uav1,1555,2377 +uav1_3,anno/UAV123/uav1_3.txt,data_seq/UAV123/uav1,2473,3469 +uav2,anno/UAV123/uav2.txt,data_seq/UAV123/uav2,1,133 +uav3,anno/UAV123/uav3.txt,data_seq/UAV123/uav3,1,265 +uav4,anno/UAV123/uav4.txt,data_seq/UAV123/uav4,1,157 +uav5,anno/UAV123/uav5.txt,data_seq/UAV123/uav5,1,139 +uav6,anno/UAV123/uav6.txt,data_seq/UAV123/uav6,1,109 +uav7,anno/UAV123/uav7.txt,data_seq/UAV123/uav7,1,373 +uav8,anno/UAV123/uav8.txt,data_seq/UAV123/uav8,1,301 +wakeboard1,anno/UAV123/wakeboard1.txt,data_seq/UAV123/wakeboard1,1,421 +wakeboard10,anno/UAV123/wakeboard10.txt,data_seq/UAV123/wakeboard10,1,469 +wakeboard2,anno/UAV123/wakeboard2.txt,data_seq/UAV123/wakeboard2,1,733 +wakeboard3,anno/UAV123/wakeboard3.txt,data_seq/UAV123/wakeboard3,1,823 +wakeboard4,anno/UAV123/wakeboard4.txt,data_seq/UAV123/wakeboard4,1,697 
+wakeboard5,anno/UAV123/wakeboard5.txt,data_seq/UAV123/wakeboard5,1,1675 +wakeboard6,anno/UAV123/wakeboard6.txt,data_seq/UAV123/wakeboard6,1,1165 +wakeboard7,anno/UAV123/wakeboard7.txt,data_seq/UAV123/wakeboard7,1,199 +wakeboard8,anno/UAV123/wakeboard8.txt,data_seq/UAV123/wakeboard8,1,1543 +wakeboard9,anno/UAV123/wakeboard9.txt,data_seq/UAV123/wakeboard9,1,355 diff --git a/tools/convert_datasets/uav2coco.py b/tools/convert_datasets/uav2coco.py new file mode 100644 index 000000000..3d04dae39 --- /dev/null +++ b/tools/convert_datasets/uav2coco.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +from collections import defaultdict + +import mmcv +from tqdm import tqdm + + +def parse_args(): + parser = argparse.ArgumentParser( + description='UAV123 dataset to COCO Video format') + parser.add_argument( + '-i', + '--input', + help='root directory of UAV123 dataset', + ) + parser.add_argument( + '-o', + '--output', + help='directory to save coco formatted label file', + ) + return parser.parse_args() + + +def convert_uav123(uav123, ann_dir, save_dir): + """Convert UAV123 dataset to COCO style. + + Args: + uav123 (dict): The converted COCO style annotations. + ann_dir (str): The root path of the UAV123 dataset. + save_dir (str): The path to save `uav123`. + """ + # The format of each line in "uav123_info.txt" is + # "anno_name,anno_path,video_path,start_frame,end_frame" + info_path = osp.join(os.path.dirname(__file__), 'uav123_info.txt') + uav_info = mmcv.list_from_file(info_path)[1:] + + records = dict(vid_id=1, img_id=1, ann_id=1, global_instance_id=1) + uav123['categories'] = [dict(id=0, name=0)] + + for info in tqdm(uav_info): + anno_name, anno_path, video_path, start_frame, end_frame = info.split( + ',') + start_frame = int(start_frame) + end_frame = int(end_frame) + + # video_name is not the same as anno_name since one video may have + # several fragments.
+ # Example: video_name: "bird" anno_name: "bird_1" + video_name = video_path.split('/')[-1] + video = dict(id=records['vid_id'], name=video_name) + uav123['videos'].append(video) + + gt_bboxes = mmcv.list_from_file(osp.join(ann_dir, anno_path)) + assert len(gt_bboxes) == end_frame - start_frame + 1 + + img = mmcv.imread( + osp.join(ann_dir, video_path, '%06d.jpg' % (start_frame))) + height, width, _ = img.shape + for frame_id, src_frame_id in enumerate( + range(start_frame, end_frame + 1)): + file_name = osp.join(video_name, '%06d.jpg' % (src_frame_id)) + image = dict( + file_name=file_name, + height=height, + width=width, + id=records['img_id'], + frame_id=frame_id, + video_id=records['vid_id']) + uav123['images'].append(image) + + if 'NaN' in gt_bboxes[frame_id]: + x1 = y1 = w = h = 0 + else: + x1, y1, w, h = gt_bboxes[frame_id].split(',') + ann = dict( + id=records['ann_id'], + image_id=records['img_id'], + instance_id=records['global_instance_id'], + category_id=0, + bbox=[int(x1), int(y1), int(w), + int(h)], + area=int(w) * int(h)) + uav123['annotations'].append(ann) + + records['ann_id'] += 1 + records['img_id'] += 1 + + records['global_instance_id'] += 1 + records['vid_id'] += 1 + + if not osp.isdir(save_dir): + os.makedirs(save_dir) + mmcv.dump(uav123, osp.join(save_dir, 'uav123.json')) + print('-----UAV123 Dataset------') + print(f'{records["vid_id"]- 1} videos') + print(f'{records["global_instance_id"]- 1} instances') + print(f'{records["img_id"]- 1} images') + print(f'{records["ann_id"] - 1} objects') + print('-----------------------------') + + +def main(): + args = parse_args() + uav123 = defaultdict(list) + convert_uav123(uav123, args.input, args.output) + + +if __name__ == '__main__': + main() diff --git a/tools/dist_sot_siamrpn_search.sh b/tools/dist_sot_siamrpn_search.sh new file mode 100644 index 000000000..af3b7195b --- /dev/null +++ b/tools/dist_sot_siamrpn_search.sh @@ -0,0 +1,8 @@ +CONFIG=$1 +GPUS=$2 +PORT=${PORT:-29500} + + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/sot_siamrpn_param_search.py $CONFIG --launcher pytorch ${@:3} diff --git a/tools/slurm_sot_siamrpn_search.sh b/tools/slurm_sot_siamrpn_search.sh new file mode 100755 index 000000000..31a2c06c4 --- /dev/null +++ b/tools/slurm_sot_siamrpn_search.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +GPUS=$4 +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-2} +PY_ARGS=${@:5} +SRUN_ARGS=${SRUN_ARGS:-""} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/sot_siamrpn_param_search.py ${CONFIG} --launcher="slurm" ${PY_ARGS} diff --git a/tools/sot_siamrpn_param_search.py b/tools/sot_siamrpn_param_search.py new file mode 100644 index 000000000..1c08754c7 --- /dev/null +++ b/tools/sot_siamrpn_param_search.py @@ -0,0 +1,250 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse +import os + +import numpy as np +import torch +from mmcv import Config, DictAction, get_logger, print_log +from mmcv.cnn import fuse_conv_bn +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, + wrap_fp16_model) +from mmdet.datasets import build_dataset + + +def parse_range(range_str): + range_list = range_str.split(',') + assert len(range_list) == 3 and float(range_list[1]) >= float( + range_list[0]) + param = map(float, range_list) + return np.round(np.arange(*param), decimals=2) + + +def parse_args(): + parser = argparse.ArgumentParser(description='mmtrack test model') + parser.add_argument('config', help='test config file path') + parser.add_argument('--checkpoint', help='checkpoint file') + parser.add_argument( + '--penalty-k-range', + type=parse_range, + help="the range of hyper-parameter 'penalty_k' in SiamRPN++; the format \ + is 'start,stop,step'") + parser.add_argument( + '--lr-range', + type=parse_range, + help="the range of hyper-parameter 'lr' in SiamRPN++; the format is \ + 'start,stop,step'") + parser.add_argument( + '--win-influ-range', + type=parse_range, + help="the range of hyper-parameter 'window_influence' in SiamRPN++; the \ + format is 'start,stop,step'") + parser.add_argument( + '--fuse-conv-bn', + action='store_true', + help='Whether to fuse conv and bn, this will slightly increase' + 'the inference speed') + parser.add_argument('--log', help='log file', default=None) + parser.add_argument('--eval', type=str, nargs='+', help='eval types') + parser.add_argument('--show', action='store_true', help='show results') + parser.add_argument( + '--show-score-thr', + type=float, + default=0.3, + help='score threshold (default: 0.3)') + parser.add_argument( + '--show-dir', help='directory where painted images will be saved') + parser.add_argument( + '--gpu-collect', + action='store_true', + help='whether to use gpu to collect results.') + parser.add_argument( + '--tmpdir', + help='tmp directory used for collecting results from multiple ' + 'workers, available when gpu-collect is not specified') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file.') + parser.add_argument( + '--eval-options', + nargs='+', + action=DictAction, + help='custom options for evaluation, the key-value pair in xxx=yyy ' + 'format will be kwargs for dataset.evaluate() function') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + return args + + +def main(): + args = parse_args() + + assert args.eval or args.show \ + or args.show_dir, \ + ('Please specify at least one operation (eval/show the ' + 'results) with the argument "--eval"' + ', "--show" or "--show-dir"') + + cfg = Config.fromfile(args.config) + + if cfg.get('USE_MMDET', False): + from mmdet.apis import multi_gpu_test, single_gpu_test + from mmdet.datasets import build_dataloader + from mmdet.models import build_detector as build_model + if 'detector' in cfg.model: + cfg.model = cfg.model.detector + elif cfg.get('USE_MMCLS', False): + from mmtrack.apis import multi_gpu_test, single_gpu_test + from mmtrack.datasets import build_dataloader + from 
mmtrack.models import build_reid as build_model + if 'reid' in cfg.model: + cfg.model = cfg.model.reid + else: + from mmtrack.apis import multi_gpu_test, single_gpu_test + from mmtrack.datasets import build_dataloader + from mmtrack.models import build_model + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + cfg.data.test.test_mode = True + + # init distributed env first, since logger depends on the dist info. + if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + + # build the dataloader + dataset = build_dataset(cfg.data.test) + data_loader = build_dataloader( + dataset, + samples_per_gpu=1, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False) + + logger = get_logger('SOTParamsSearcher', log_file=args.log) + + # build the model and load checkpoint + if cfg.get('test_cfg', False): + model = build_model( + cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) + else: + model = build_model(cfg.model) + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + wrap_fp16_model(model) + if args.checkpoint is not None: + checkpoint = load_checkpoint( + model, args.checkpoint, map_location='cpu') + if 'CLASSES' in checkpoint['meta']: + model.CLASSES = checkpoint['meta']['CLASSES'] + if not hasattr(model, 'CLASSES'): + model.CLASSES = dataset.CLASSES + + if args.fuse_conv_bn: + model = fuse_conv_bn(model) + + if not distributed: + model = MMDataParallel(model, device_ids=[0]) + else: + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False) + + if 'meta' in checkpoint and 'hook_msgs' in checkpoint[ + 'meta'] and 'best_score' in checkpoint['meta']['hook_msgs']: + best_score = checkpoint['meta']['hook_msgs']['best_score'] + else: + best_score = 0 + + best_result = dict(success=best_score, norm_precision=0., precision=0.) 
+ best_params = dict( + penalty_k=cfg.model.test_cfg.rpn.penalty_k, + lr=cfg.model.test_cfg.rpn.lr, + win_influ=cfg.model.test_cfg.rpn.window_influence) + print_log(f'init best score as: {best_score}', logger) + print_log(f'init best params as: {best_params}', logger) + + num_cases = len(args.penalty_k_range) * len(args.lr_range) * len( + args.win_influ_range) + case_count = 0 + + for penalty_k in args.penalty_k_range: + for lr in args.lr_range: + for win_influ in args.win_influ_range: + case_count += 1 + cfg.model.test_cfg.rpn.penalty_k = penalty_k + cfg.model.test_cfg.rpn.lr = lr + cfg.model.test_cfg.rpn.window_influence = win_influ + print_log(f'-----------[{case_count}/{num_cases}]-----------', + logger) + print_log( + f'penalty_k={penalty_k} lr={lr} win_influence={win_influ}', + logger) + + if not distributed: + outputs = single_gpu_test( + model, + data_loader, + args.show, + args.show_dir, + show_score_thr=args.show_score_thr) + else: + outputs = multi_gpu_test(model, data_loader, args.tmpdir, + args.gpu_collect) + + rank, _ = get_dist_info() + if rank == 0: + kwargs = args.eval_options if args.eval_options else {} + if args.eval: + eval_kwargs = cfg.get('evaluation', {}).copy() + # hard-code way to remove EvalHook args + eval_hook_args = [ + 'interval', 'tmpdir', 'start', 'gpu_collect', + 'save_best', 'rule', 'by_epoch' + ] + for key in eval_hook_args: + eval_kwargs.pop(key, None) + eval_kwargs.update(dict(metric=args.eval, **kwargs)) + eval_results = dataset.evaluate(outputs, **eval_kwargs) + print_log(f'evaluation results: {eval_results}', + logger) + print_log('------------------------------------------', + logger) + if eval_results['success'] > best_result['success']: + best_result = eval_results + best_params['penalty_k'] = penalty_k + best_params['lr'] = lr + best_params['win_influ'] = win_influ + + print_log( + f'The current best evaluation results: \ + {best_result}', logger) + print_log(f'The current best params: {best_params}', + logger) + + print_log( + f'After parameter searching, the best evaluation results: \ + {best_result}', logger) + print_log(f'After parameter searching, the best params: {best_params}', + logger) + + +if __name__ == '__main__': + main()
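Note on applying the searched parameters: `tools/sot_siamrpn_param_search.py` only logs the best `penalty_k`, `lr` and `window_influence`; to use them at test time they are written back into the model's `test_cfg`, exactly as `configs/sot/siamese_rpn/siamese_rpn_r50_1x_uav123.py` above does. A minimal sketch, with placeholder values rather than tuned results:

```python
# Hypothetical UAV123 test config after a parameter search.
# The three values below are placeholders; substitute the best ones
# reported in the search log for your own checkpoint.
_base_ = ['./siamese_rpn_r50_1x_lasot.py']

model = dict(
    test_cfg=dict(rpn=dict(penalty_k=0.05, window_influence=0.46, lr=0.30)))
```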