Skip to content

Commit

Permalink
Support the testing process of MaskTrack R-CNN (#298)
Browse files Browse the repository at this point in the history
* Support the testing process of MaskTrack R-CNN

* refactor track2result and restore_result functions

* fix a typo

* fix typos

* update docstring of MaskTrack R-CNN

* refactor outs2results and results2outs functions

* update based on 1st comments

* fix bug

* update based on 2nd comments

* fix a typo
  • Loading branch information
GT9505 authored Oct 22, 2021
1 parent 6293ddb commit 2d5dce5
Show file tree
Hide file tree
Showing 25 changed files with 698 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,10 @@
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)),
# TODO: Support tracker
# tracker=dict(
# type='MaskTrackRCNNTracker',
# score_coefficient=1.0,
# iou_coefficient=2.0,
# label_coefficient=10.0))
tracker=None)
tracker=dict(
type='MaskTrackRCNNTracker',
match_weights=dict(det_score=1.0, iou=2.0, det_label=10.0),
num_frames_retain=10))

# dataset settings
img_norm_cfg = dict(
Expand Down
14 changes: 12 additions & 2 deletions mmtrack/apis/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info
from mmdet.core import encode_mask_results


def single_gpu_test(model,
Expand Down Expand Up @@ -44,8 +45,6 @@ def single_gpu_test(model,
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
for k, v in result.items():
results[k].append(v)

batch_size = data['img'][0].size(0)
if show or out_dir:
Expand Down Expand Up @@ -104,6 +103,13 @@ def single_gpu_test(model,

prev_img_meta = img_meta

for key in result:
if 'mask' in key:
result[key] = encode_mask_results(result[key])

for k, v in result.items():
results[k].append(v)

for _ in range(batch_size):
prog_bar.update()

Expand Down Expand Up @@ -141,6 +147,10 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
for key in result:
if 'mask' in key:
result[key] = encode_mask_results(result[key])

for k, v in result.items():
results[k].append(v)

Expand Down
16 changes: 10 additions & 6 deletions mmtrack/core/evaluation/eval_mot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import numpy as np
import pandas as pd
from mmcv.utils import print_log
from mmdet.core import bbox2result
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from motmetrics.lap import linear_sum_assignment
from motmetrics.math_util import quiet_divide

from ..track import track2result
from mmtrack.core.track import outs2results

METRIC_MAPS = {
'idf1': 'IDF1',
Expand Down Expand Up @@ -49,12 +48,17 @@ def acc_single_video(results,
]
for result, gt in zip(results, gts):
if ignore_by_classes:
gt_ignore = bbox2result(gt['bboxes_ignore'], gt['labels_ignore'],
num_classes)
gt_ignore = outs2results(
bboxes=gt['bboxes_ignore'],
labels=gt['labels_ignore'],
num_classes=num_classes)['bbox_results']
else:
gt_ignore = [gt['bboxes_ignore'] for i in range(num_classes)]
gt = track2result(gt['bboxes'], gt['labels'], gt['instance_ids'],
num_classes)
gt = outs2results(
bboxes=gt['bboxes'],
labels=gt['labels'],
ids=gt['instance_ids'],
num_classes=num_classes)['bbox_results']
for i in range(num_classes):
gt_ids, gt_bboxes = gt[i][:, 0].astype(np.int), gt[i][:, 1:]
pred_ids, pred_bboxes = result[i][:, 0].astype(
Expand Down
4 changes: 2 additions & 2 deletions mmtrack/core/track/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .correlation import depthwise_correlation
from .similarity import embed_similarity
from .transforms import imrenormalize, restore_result, track2result
from .transforms import imrenormalize, outs2results, results2outs

__all__ = [
'depthwise_correlation', 'track2result', 'restore_result',
'depthwise_correlation', 'outs2results', 'results2outs',
'embed_similarity', 'imrenormalize'
]
143 changes: 108 additions & 35 deletions mmtrack/core/track/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import mmcv
import numpy as np
import torch
from mmdet.core import bbox2result


def imrenormalize(img, img_norm_cfg, new_img_norm_cfg):
Expand Down Expand Up @@ -47,56 +48,128 @@ def _imrenormalize(img, img_norm_cfg, new_img_norm_cfg):
return img


def track2result(bboxes, labels, ids, num_classes):
"""Convert tracking results to a list of numpy arrays.
def outs2results(bboxes=None,
                 labels=None,
                 masks=None,
                 ids=None,
                 num_classes=None,
                 **kwargs):
    """Convert tracking/detection outputs to lists of per-class results.

    Args:
        bboxes (torch.Tensor | np.ndarray, optional): shape (n, 5), each
            row is (x1, y1, x2, y2, score).
        labels (torch.Tensor | np.ndarray): shape (n, ). Required.
        masks (torch.Tensor | np.ndarray, optional): shape (n, h, w).
        ids (torch.Tensor | np.ndarray, optional): shape (n, ). When given,
            entries whose id is not greater than -1 are dropped and each
            bbox row is prefixed with its id, giving rows of length 6.
        num_classes (int): class number, not including background class.
            Required.

    Returns:
        dict[str : list(ndarray) | list[list[np.ndarray]]]: tracking/detection
        results of each class. It may contain keys as below:

        - bbox_results (list[np.ndarray]): Each list denotes bboxes of one
          category.
        - mask_results (list[list[np.ndarray]]): Each outer list denotes masks
          of one category. Each inner list denotes one mask belonging to
          the category. Each mask has shape (h, w).
    """
    assert labels is not None
    assert num_classes is not None

    results = dict()

    if ids is not None:
        # Keep only entries that were assigned a valid track id.
        valid_inds = ids > -1
        ids = ids[valid_inds]
        labels = labels[valid_inds]

    if bboxes is not None:
        if ids is not None:
            bboxes = bboxes[valid_inds]
            if bboxes.shape[0] == 0:
                bbox_results = [
                    np.zeros((0, 6), dtype=np.float32)
                    for i in range(num_classes)
                ]
            else:
                if isinstance(bboxes, torch.Tensor):
                    bboxes = bboxes.cpu().numpy()
                    labels = labels.cpu().numpy()
                    ids = ids.cpu().numpy()
                # Prepend each row with its track id:
                # (id, x1, y1, x2, y2, score).
                bbox_results = [
                    np.concatenate(
                        (ids[labels == i, None], bboxes[labels == i, :]),
                        axis=1) for i in range(num_classes)
                ]
        else:
            bbox_results = bbox2result(bboxes, labels, num_classes)
        results['bbox_results'] = bbox_results

    if masks is not None:
        if ids is not None:
            masks = masks[valid_inds]
        if isinstance(masks, torch.Tensor):
            masks = masks.detach().cpu().numpy()
        masks_results = [[] for _ in range(num_classes)]
        # Iterate over the masks themselves (not `bboxes`, which may be
        # None) so mask conversion also works without bboxes.
        for i in range(len(masks)):
            masks_results[labels[i]].append(masks[i])
        results['mask_results'] = masks_results

    return results


def restore_result(result, return_ids=False):
def results2outs(bbox_results=None,
                 mask_results=None,
                 mask_shape=None,
                 **kwargs):
    """Restore per-class result lists into model-forward style outputs.

    Args:
        bbox_results (list[np.ndarray]): Each list denotes bboxes of one
            category.
        mask_results (list[list[np.ndarray]]): Each outer list denotes masks
            of one category. Each inner list denotes one mask belonging to
            the category. Each mask has shape (h, w).
        mask_shape (tuple[int]): The shape (h, w) of mask, used when all
            mask lists are empty.

    Returns:
        tuple: tracking results of each class. It may contain keys as below:

        - bboxes (np.ndarray): shape (n, 5)
        - labels (np.ndarray): shape (n, )
        - masks (np.ndarray): shape (n, h, w)
        - ids (np.ndarray): shape (n, )
    """
    outputs = dict()

    if bbox_results is not None:
        # One label per row, in the same order as the concatenation below.
        outputs['labels'] = np.array([
            cls_idx for cls_idx, cls_bboxes in enumerate(bbox_results)
            for _ in range(cls_bboxes.shape[0])
        ], dtype=np.int64)

        stacked = np.concatenate(bbox_results, axis=0).astype(np.float32)
        num_cols = stacked.shape[1]
        if num_cols == 5:
            outputs['bboxes'] = stacked
        elif num_cols == 6:
            # Tracking results carry the instance id in the first column.
            outputs['bboxes'] = stacked[:, 1:]
            outputs['ids'] = stacked[:, 0].astype(np.int64)
        else:
            raise NotImplementedError(
                f'Not supported bbox shape: (N, {num_cols})')

    if mask_results is not None:
        assert mask_shape is not None
        mask_height, mask_width = mask_shape
        # Flatten the per-category lists into one flat list of masks.
        flat_masks = [mask for per_cls in mask_results for mask in per_cls]
        if flat_masks:
            outputs['masks'] = np.stack(flat_masks, axis=0)
        else:
            outputs['masks'] = np.zeros(
                (0, mask_height, mask_width)).astype(bool)

    return outputs
14 changes: 9 additions & 5 deletions mmtrack/datasets/mot_challenge_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mmdet.core import eval_map
from mmdet.datasets import DATASETS

from mmtrack.core import restore_result
from mmtrack.core import results2outs
from .coco_video_dataset import CocoVideoDataset


Expand Down Expand Up @@ -187,8 +187,11 @@ def format_track_results(self, results, infos, resfile):
frame = info['mot_frame_id']
else:
frame = info['frame_id'] + 1
bboxes, labels, ids = restore_result(res, return_ids=True)
for bbox, label, id in zip(bboxes, labels, ids):

outs_track = results2outs(bbox_results=res)
for bbox, label, id in zip(outs_track['bboxes'],
outs_track['labels'],
outs_track['ids']):
x1, y1, x2, y2, conf = bbox
f.writelines(
f'{frame},{id},{x1:.3f},{y1:.3f},{(x2-x1):.3f},' +
Expand All @@ -202,8 +205,9 @@ def format_bbox_results(self, results, infos, resfile):
frame = info['mot_frame_id']
else:
frame = info['frame_id'] + 1
bboxes, labels = restore_result(res)
for bbox, label in zip(bboxes, labels):

outs_det = results2outs(bbox_results=res)
for bbox, label in zip(outs_det['bboxes'], outs_det['labels']):
x1, y1, x2, y2, conf = bbox
f.writelines(
f'{frame},-1,{x1:.3f},{y1:.3f},{(x2-x1):.3f},' +
Expand Down
7 changes: 4 additions & 3 deletions mmtrack/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile

from mmtrack.core import restore_result
from mmtrack.core import results2outs


@PIPELINES.register_module()
Expand Down Expand Up @@ -98,9 +98,10 @@ class LoadDetections(object):
"""

def __call__(self, results):
detections = results['detections']
outs_det = results2outs(bbox_results=results['detections'])
bboxes = outs_det['bboxes']
labels = outs_det['labels']

bboxes, labels = restore_result(detections)
results['public_bboxes'] = bboxes[:, :4]
if bboxes.shape[1] > 4:
results['public_scores'] = bboxes[:, -1]
Expand Down
Loading

0 comments on commit 2d5dce5

Please sign in to comment.