Skip to content

Commit

Permalink
applies mutr bugfix #15
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Doll committed Jan 13, 2023
1 parent 895277b commit 99e39a5
Showing 1 changed file with 52 additions and 7 deletions.
59 changes: 52 additions & 7 deletions plugin/track/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import warnings
from nuscenes.utils.data_classes import Box as NuScenesBox
from os import path as osp
import copy
from typing import List, Tuple, Union

from mmdet.datasets import DATASETS
from mmdet3d.core import show_result
Expand All @@ -18,6 +20,39 @@
from nuscenes import NuScenes


class NuScenesTrackingBox(NuScenesBox):
def __init__(self,
center: List[float],
size: List[float],
orientation: Quaternion,
label: int = np.nan,
score: float = np.nan,
velocity: Tuple = (np.nan, np.nan, np.nan),
name: str = None,
token: str = None,
):
"""
:param center: Center of box given as x, y, z.
:param size: Size of box in width, length, height.
:param orientation: Box orientation.
:param label: Integer label, optional.
:param score: Classification score, optional.
:param velocity: Box velocity in x, y, z direction.
:param name: Box name, optional. Can be used e.g. for denote category name.
:param token: Unique string identifier from DB.
"""
super(NuScenesTrackingBox, self).__init__(center, size, orientation, label,
score, velocity, name, token)

def rotate(self, quaternion: Quaternion) -> None:
self.center = np.dot(quaternion.rotation_matrix, self.center)
self.orientation = quaternion * self.orientation
self.velocity = np.dot(quaternion.rotation_matrix, self.velocity)

def copy(self) -> 'NuScenesTrackingBox':
return copy.deepcopy(self)


@DATASETS.register_module()
class NuScenesTrackDataset(Dataset):
r"""NuScenes Dataset.
Expand Down Expand Up @@ -615,16 +650,16 @@ def _format_bbox(self, results, jsonfile_prefix=None):
center_[2] = center_[2] + (box.wlh.tolist()[2] / 2.0)
nusc_anno = dict(
sample_token=sample_token,
translation=center_,
translation=box.center.tolist(),
# translation=box.center.tolist(),
size=box.wlh.tolist(),
rotation=box.orientation.elements.tolist(),
velocity=box.velocity[:2].tolist(),
# detection_name=name,
tracking_name=name,
tracking_score=det['track_scores'][i].item(),
tracking_id=str(det['track_ids'][i].item()),
# attribute_name=attr)
attribute_name=attr,
tracking_score=box.score,
tracking_id=box.token,
)
# print(nusc_anno)
annos.append(nusc_anno)
Expand Down Expand Up @@ -847,14 +882,23 @@ def output_to_nusc_box(detection):
- boxes_3d (:obj:`BaseInstance3DBoxes`): Detection bbox.
- scores_3d (torch.Tensor): Detection scores.
- labels_3d (torch.Tensor): Predicted box labels.
- tracking (bool): if convert for tracking evaluation
Returns:
list[:obj:`NuScenesBox`]: List of standard NuScenesBoxes.
list[:obj:`NuScenesBox`]: List of NuScenesTrackingBoxes.
"""
box3d = detection['boxes_3d']
scores = detection['scores_3d'].numpy()
# overwrite the scores with the tracking scores
if 'track_scores' in detection.keys() and detection['track_scores'] is not None:
scores = detection['track_scores'].numpy()
labels = detection['labels_3d'].numpy()

if 'track_ids' in detection.keys() and detection['track_ids'] is not None:
track_ids = detection['track_ids']
else:
track_ids = [None for _ in range(len(box3d))]

box_gravity_center = box3d.gravity_center.numpy()
box_dims = box3d.dims.numpy()
box_yaw = box3d.yaw.numpy()
Expand All @@ -870,13 +914,14 @@ def output_to_nusc_box(detection):
# velo_ori = box3d[i, 6]
# velocity = (
# velo_val * np.cos(velo_ori), velo_val * np.sin(velo_ori), 0.0)
box = NuScenesBox(
box = NuScenesTrackingBox(
box_gravity_center[i],
box_dims[i],
quat,
label=labels[i],
score=scores[i],
velocity=velocity)
velocity=velocity,
token=str(track_ids[i]))
box_list.append(box)
return box_list

Expand Down

0 comments on commit 99e39a5

Please sign in to comment.