[Feature] Support DanceTrack dataset for MOT #543

Merged: 10 commits, May 12, 2022
74 changes: 74 additions & 0 deletions configs/_base_/datasets/dancetrack.py
@@ -0,0 +1,74 @@
# dataset settings
dataset_type = 'DanceTrackDataset'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadMultiImagesFromFile', to_float32=True),
dict(type='SeqLoadAnnotations', with_bbox=True, with_track=True),
dict(
type='SeqResize',
img_scale=(1088, 1088),
share_params=True,
ratio_range=(0.8, 1.2),
keep_ratio=True,
bbox_clip_border=False),
dict(type='SeqPhotoMetricDistortion', share_params=True),
dict(
type='SeqRandomCrop',
share_params=False,
crop_size=(1088, 1088),
bbox_clip_border=False),
dict(type='SeqRandomFlip', share_params=True, flip_ratio=0.5),
dict(type='SeqNormalize', **img_norm_cfg),
dict(type='SeqPad', size_divisor=32),
dict(type='MatchInstances', skip_nomatch=True),
dict(
type='VideoCollect',
keys=[
'img', 'gt_bboxes', 'gt_labels', 'gt_match_indices',
'gt_instance_ids'
]),
dict(type='SeqDefaultFormatBundle', ref_prefix='ref')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1088, 1088),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]
data_root = 'data/dancetrack/'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
visibility_thr=-1,
ann_file=data_root + 'annotations/train_cocoformat.json',
img_prefix=data_root + 'train',
ref_img_sampler=dict(
num_ref_imgs=1,
frame_range=10,
filter_key_img=True,
method='uniform'),
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/val_cocoformat.json',
img_prefix=data_root + 'val',
ref_img_sampler=None,
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/val_cocoformat.json',
img_prefix=data_root + 'val',
ref_img_sampler=None,
pipeline=test_pipeline))
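The config above is self-contained; as a quick check, it can be loaded and the training split built directly. A minimal sketch, assuming mmtracking 0.x with mmcv 1.x and that the converted annotations already exist under `data/dancetrack/`:

```python
# Sketch: load the base dataset config and build the training split.
# Importing mmtrack.datasets first registers DanceTrackDataset in the
# shared DATASETS registry that mmdet's build_dataset reads from.
import mmtrack.datasets  # noqa: F401

from mmcv import Config
from mmdet.datasets import build_dataset

cfg = Config.fromfile('configs/_base_/datasets/dancetrack.py')
train_dataset = build_dataset(cfg.data.train)
print(len(train_dataset))  # number of key frames available for training
```

Note that `ref_img_sampler` makes each training sample a key frame plus one reference frame drawn uniformly from within 10 frames of it, which is what `MatchInstances` consumes downstream.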
130 changes: 130 additions & 0 deletions configs/mot/qdtrack/qdtrack_faster-rcnn_r50_fpn_4e_dancetrack.py
@@ -0,0 +1,130 @@
_base_ = [
'../../_base_/models/faster_rcnn_r50_fpn.py',
'../../_base_/datasets/dancetrack.py', '../../_base_/default_runtime.py'
]
**Contributor Author:** Done.

model = dict(
type='QDTrack',
detector=dict(
backbone=dict(
norm_cfg=dict(requires_grad=False),
style='caffe',
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet50')),
rpn_head=dict(bbox_coder=dict(clip_border=False)),
roi_head=dict(
bbox_head=dict(
loss_bbox=dict(type='L1Loss', loss_weight=1.0),
bbox_coder=dict(clip_border=False),
num_classes=1)),
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco-person/faster_rcnn_r50_fpn_1x_coco-person_20201216_175929-d022e227.pth' # noqa: E501
)),
track_head=dict(
type='QuasiDenseTrackHead',
**Collaborator:** After using the base config `./qdtrack_faster-rcnn_r50_fpn_4e_crowdhuman_mot17-private-half.py`, some duplicated keys can be removed.

**Contributor Author:** I realized it was a mistake to set `./qdtrack_faster-rcnn_r50_fpn_4e_crowdhuman_mot17-private-half.py` as a base, because it is based on `'../../_base_/datasets/mot_challenge.py'`, while the DanceTrack QDTrack config should be based on `'../../_base_/datasets/dancetrack.py'`. The right way is to create a new config based on:

_base_ = [
    '../../_base_/models/faster_rcnn_r50_fpn.py',
    '../../_base_/datasets/dancetrack.py', '../../_base_/default_runtime.py'
]

Given the rule that different base config files must not share keys, the DanceTrack QDTrack config must not inherit any config file that uses `'../../_base_/datasets/mot_challenge.py'`. To be precise, if it inherits `'../../_base_/datasets/dancetrack.py'` and `'./qdtrack_faster-rcnn_r50_fpn_4e_mot17-private-half.py'` at the same time, it raises an error:

KeyError: "Duplicate key is not allowed among bases. Duplicate keys: {'data_root', 'train_pipeline', 'img_norm_cfg', 'test_pipeline', 'data', 'dataset_type'}"
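
A minimal repro sketch of that rule, assuming mmcv 1.x `Config` (the temporary file names below are illustrative):

```python
# Two base files defining the same top-level key make Config.fromfile fail.
import os
import tempfile

from mmcv import Config

tmp = tempfile.mkdtemp()
for name in ('base_a.py', 'base_b.py'):
    with open(os.path.join(tmp, name), 'w') as f:
        f.write("dataset_type = 'DanceTrackDataset'\n")  # duplicated key
with open(os.path.join(tmp, 'child.py'), 'w') as f:
    f.write("_base_ = ['./base_a.py', './base_b.py']\n")

try:
    Config.fromfile(os.path.join(tmp, 'child.py'))
except KeyError as err:
    print(err)  # Duplicate key is not allowed among bases. ...
```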

roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
embed_head=dict(
type='QuasiDenseEmbedHead',
num_convs=4,
num_fcs=1,
embed_channels=256,
norm_cfg=dict(type='GN', num_groups=32),
loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),
loss_track_aux=dict(
type='L2Loss',
neg_pos_ub=3,
pos_margin=0,
neg_margin=0.1,
hard_mining=True,
loss_weight=1.0)),
loss_bbox=dict(type='L1Loss', loss_weight=1.0),
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='CombinedSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=3,
add_gt_as_proposals=True,
pos_sampler=dict(type='InstanceBalancedPosSampler'),
neg_sampler=dict(type='RandomSampler')))),
tracker=dict(
type='QuasiDenseEmbedTracker',
init_score_thr=0.9,
obj_score_thr=0.5,
match_score_thr=0.5,
memo_tracklet_frames=30,
memo_backdrop_frames=1,
memo_momentum=0.8,
nms_conf_thr=0.5,
nms_backdrop_iou_thr=0.3,
nms_class_iou_thr=0.7,
with_cats=True,
match_metric='bisoftmax'))
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadMultiImagesFromFile', to_float32=True),
dict(type='SeqLoadAnnotations', with_bbox=True, with_track=True),
dict(
type='SeqResize',
img_scale=(1088, 1088),
share_params=True,
ratio_range=(0.8, 1.2),
keep_ratio=True,
bbox_clip_border=False),
dict(type='SeqPhotoMetricDistortion', share_params=True),
dict(
type='SeqRandomCrop',
share_params=False,
crop_size=(1088, 1088),
bbox_clip_border=False),
dict(type='SeqRandomFlip', share_params=True, flip_ratio=0.5),
dict(type='SeqNormalize', **img_norm_cfg),
dict(type='SeqPad', size_divisor=32),
dict(type='MatchInstances', skip_nomatch=True),
dict(
type='VideoCollect',
keys=[
'img', 'gt_bboxes', 'gt_labels', 'gt_match_indices',
'gt_instance_ids'
]),
dict(type='SeqDefaultFormatBundle', ref_prefix='ref')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1088, 1088),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
# optimizer && learning policy
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(policy='step', step=[3])
# runtime settings
total_epochs = 4
evaluation = dict(metric=['bbox', 'track'], interval=1)
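
For reference, a hedged usage sketch of this config with MMTracking's standard MOT inference API (`init_model` and `inference_mot` exist in mmtracking 0.x; the checkpoint and video paths below are placeholders, and training itself goes through the usual `python tools/train.py <config>`):

```python
import mmcv

from mmtrack.apis import inference_mot, init_model

config = 'configs/mot/qdtrack/qdtrack_faster-rcnn_r50_fpn_4e_dancetrack.py'
# Pass a real .pth path here; None builds the model with initial weights only.
model = init_model(config, checkpoint=None, device='cuda:0')

video = mmcv.VideoReader('demo/demo.mp4')  # placeholder clip
for frame_id, img in enumerate(video):
    result = inference_mot(model, img, frame_id=frame_id)
    # In recent 0.x releases, result['track_bboxes'] holds per-class arrays
    # of [track_id, x1, y1, x2, y2, score] rows for the current frame.
```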
17 changes: 16 additions & 1 deletion docs/en/dataset.md
@@ -9,6 +9,7 @@ This page provides the instructions for dataset preparation on existing benchmar
- [CrowdHuman](https://www.crowdhuman.org/)
- [LVIS](https://www.lvisdataset.org/)
- [TAO](https://taodataset.org/)
- [DanceTrack](https://dancetrack.github.io)
- Single Object Tracking
- [LaSOT](http://vision.cs.stonybrook.edu/~lasot/)
- [UAV123](https://cemse.kaust.edu.sa/ivul/uav123/)
@@ -31,7 +32,7 @@ Please download the datasets from the official websites. It is recommended to sy

#### 1.2 Multiple Object Tracking

-- For the training and testing of multi object tracking task, one of the MOT Challenge datasets (e.g. MOT17) and TAO are needed, CrowdHuman and LVIS can be served as comlementary dataset.
+- For training and testing of the multiple object tracking task, one of the MOT Challenge datasets (e.g. MOT17), TAO, or DanceTrack is needed; CrowdHuman and LVIS can serve as complementary datasets.

- The `annotations` under `tao` contains the official annotations from [here](https://github.com/TAO-Dataset/annotations).

@@ -98,6 +99,11 @@ mmtracking
| | ├── train
| | ├── test
│ │
| ├── DanceTrack
| | ├── train
| | ├── val
| | ├── test
| |
│ ├── crowdhuman
│ │ ├── annotation_train.odgt
│ │ ├── annotation_val.odgt
@@ -230,6 +236,9 @@ python ./tools/convert_datasets/ilsvrc/imagenet2coco_vid.py -i ./data/ILSVRC -o
python ./tools/convert_datasets/mot/mot2coco.py -i ./data/MOT17/ -o ./data/MOT17/annotations --split-train --convert-det
python ./tools/convert_datasets/mot/mot2reid.py -i ./data/MOT17/ -o ./data/MOT17/reid --val-split 0.2 --vis-threshold 0.3

# DanceTrack
python ./tools/convert_datasets/dancetrack/dancetrack2coco.py -i ./data/DanceTrack -o ./data/DanceTrack/annotations

# CrowdHuman
python ./tools/convert_datasets/mot/crowdhuman2coco.py -i ./data/crowdhuman -o ./data/crowdhuman/annotations

@@ -320,6 +329,12 @@ mmtracking
│ │ │ ├── imgs
│ │ │ ├── meta
│ │
│ ├── DanceTrack
│ │ ├── train
│ │ ├── val
│ │ ├── test
│ │ ├── annotations
│ │
│ ├── crowdhuman
│ │ ├── annotation_train.odgt
│ │ ├── annotation_val.odgt
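After running the converter, a quick sanity check on its output can catch path or format mistakes early. This sketch only assumes the converter emits the usual CocoVID-style JSON keys:

```python
import json

with open('data/DanceTrack/annotations/train_cocoformat.json') as f:
    ann = json.load(f)

# Expect the CocoVID top-level keys: annotations, categories, images, videos.
print(sorted(ann))
print(len(ann['videos']), 'videos,', len(ann['images']), 'frames,',
      len(ann['annotations']), 'boxes')
```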
17 changes: 16 additions & 1 deletion docs/zh_cn/dataset.md
@@ -9,6 +9,7 @@
- [CrowdHuman](https://www.crowdhuman.org/)
- [LVIS](https://www.lvisdataset.org/)
- [TAO](https://taodataset.org/)
- [DanceTrack](https://dancetrack.github.io)
- Single Object Tracking
- [LaSOT](http://vision.cs.stonybrook.edu/~lasot/)
- [UAV123](https://cemse.kaust.edu.sa/ivul/uav123/)
@@ -31,7 +32,7 @@

#### 1.2 Multiple Object Tracking

-- For training and testing of the multiple object tracking task, any one of the MOT Challenge datasets (e.g. MOT17) and TAO are needed; CrowdHuman and LVIS can serve as supplementary data.
+- For training and testing of the multiple object tracking task, one of the MOT Challenge datasets (e.g. MOT17), TAO, or DanceTrack is needed; CrowdHuman and LVIS can serve as supplementary data.

- The `annotations` under the `tao` folder contain the official annotations, which can be obtained from [here](https://github.com/TAO-Dataset/annotations).

@@ -98,6 +99,11 @@ mmtracking
| | ├── train
| | ├── test
│ │
| ├── DanceTrack
| | ├── train
| | ├── val
| | ├── test
| |
│ ├── crowdhuman
│ │ ├── annotation_train.odgt
│ │ ├── annotation_val.odgt
@@ -231,6 +237,9 @@ python ./tools/convert_datasets/ilsvrc/imagenet2coco_vid.py -i ./data/ILSVRC -o
python ./tools/convert_datasets/mot/mot2coco.py -i ./data/MOT17/ -o ./data/MOT17/annotations --split-train --convert-det
python ./tools/convert_datasets/mot/mot2reid.py -i ./data/MOT17/ -o ./data/MOT17/reid --val-split 0.2 --vis-threshold 0.3

# DanceTrack
python ./tools/convert_datasets/dancetrack/dancetrack2coco.py -i ./data/DanceTrack -o ./data/DanceTrack/annotations

# CrowdHuman
python ./tools/convert_datasets/mot/crowdhuman2coco.py -i ./data/crowdhuman -o ./data/crowdhuman/annotations

@@ -321,6 +330,12 @@ mmtracking
│ │ │ ├── imgs
│ │ │ ├── meta
│ │
│ ├── DanceTrack
│ │ ├── train
│ │ ├── val
│ │ ├── test
│ │ ├── annotations
│ │
│ ├── crowdhuman
│ │ ├── annotation_train.odgt
│ │ ├── annotation_val.odgt
3 changes: 2 additions & 1 deletion mmtrack/datasets/__init__.py
@@ -4,6 +4,7 @@
from .base_sot_dataset import BaseSOTDataset
from .builder import build_dataloader
from .coco_video_dataset import CocoVideoDataset
from .dancetrack_dataset import DanceTrackDataset
from .dataset_wrappers import RandomSampleConcatDataset
from .got10k_dataset import GOT10kDataset
from .imagenet_vid_dataset import ImagenetVIDDataset
@@ -30,5 +31,5 @@
'UAV123Dataset', 'TrackingNetDataset', 'OTB100Dataset',
'YouTubeVISDataset', 'GOT10kDataset', 'VOTDataset', 'BaseSOTDataset',
'SOTCocoDataset', 'SOTImageNetVIDDataset', 'RandomSampleConcatDataset',
-    'TaoDataset'
+    'TaoDataset', 'DanceTrackDataset'
]
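
The new import implies `dancetrack_dataset.py` follows the registration pattern of the other MOT datasets. A hedged sketch of that pattern; the base class and `CLASSES` below are assumptions mirroring `MOTChallengeDataset`, not code copied from the PR:

```python
# mmtrack/datasets/dancetrack_dataset.py (illustrative sketch)
from mmdet.datasets import DATASETS

from .mot_challenge_dataset import MOTChallengeDataset


@DATASETS.register_module()
class DanceTrackDataset(MOTChallengeDataset):
    """Dancers in similar outfits with large, non-linear motion.

    DanceTrack reuses the MOTChallenge annotation format, so data loading
    and track evaluation can be inherited; only dataset-specific defaults
    need overriding.
    """

    CLASSES = ('pedestrian', )
```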