[Fix] fix vis hook bug and ut #1839

Merged: 1 commit, Sep 22, 2022
1 change: 1 addition & 0 deletions mmdet3d/datasets/det3d_dataset.py
@@ -241,6 +241,7 @@ def parse_data_info(self, info: dict) -> dict:
self.data_prefix.get('pts', ''),
info['lidar_points']['lidar_path'])

info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
info['lidar_path'] = info['lidar_points']['lidar_path']
if 'lidar_sweeps' in info:
for sweep in info['lidar_sweeps']:
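The added line mirrors `num_pts_feats` to the top level of the info dict, next to `lidar_path`. A hedged sketch of the relevant slice of a parsed info dict (the path is a placeholder; field names follow the keys in the diff):

```python
# Illustrative only: the relevant slice of a parsed data-info dict (the path
# is a placeholder). The new line copies num_pts_feats to the top level next
# to lidar_path so later transforms and hooks can read it directly.
info = {
    'lidar_points': {
        'lidar_path': 'data/kitti/training/velodyne_reduced/000000.bin',
        'num_pts_feats': 4,
    },
}
info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
info['lidar_path'] = info['lidar_points']['lidar_path']
```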
5 changes: 3 additions & 2 deletions mmdet3d/datasets/transforms/formating.py
@@ -68,8 +68,9 @@ def __init__(
'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
'pcd_trans', 'sample_idx', 'pcd_scale_factor',
'pcd_rotation', 'pcd_rotation_angle', 'lidar_path',
'num_pts_feats', 'pcd_trans', 'sample_idx',
'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat',
'affine_aug')):
self.keys = keys
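Adding `num_pts_feats` to `meta_keys` is what lets the value travel from the dataset info into the packed sample's metainfo, where the visualization hook later reads it. A minimal sketch (not from the PR) of that lookup, using only calls that already appear in this diff:

```python
from mmdet3d.structures import Det3DDataSample

# Illustrative only: keys listed in meta_keys end up in the packed sample's
# metainfo, which is where the visualization hook looks them up.
data_sample = Det3DDataSample()
data_sample.set_metainfo({
    'lidar_path': 'data/kitti/training/velodyne_reduced/000000.bin',
    'num_pts_feats': 4,
})
assert 'num_pts_feats' in data_sample
print(data_sample.num_pts_feats, data_sample.lidar_path)
```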
47 changes: 35 additions & 12 deletions mmdet3d/engine/hooks/visualization_hook.py
@@ -4,6 +4,7 @@
from typing import Optional, Sequence

import mmcv
import numpy as np
from mmengine.fileio import FileClient
from mmengine.hooks import Hook
from mmengine.runner import Runner
@@ -95,15 +96,27 @@ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
# is visualized for each evaluation.
total_curr_iter = runner.iter + batch_idx

data_input = dict()

# Visualize only the first data
img_path = outputs[0].img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
if 'img_path' in outputs[0]:
img_path = outputs[0].img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
data_input['img'] = img

if 'lidar_path' in outputs[0]:
lidar_path = outputs[0].lidar_path
num_pts_feats = outputs[0].num_pts_feats
pts_bytes = self.file_client.get(lidar_path)
points = np.frombuffer(pts_bytes, dtype=np.float32)
points = points.reshape(-1, num_pts_feats)
data_input['points'] = points

if total_curr_iter % self.interval == 0:
self._visualizer.add_datasample(
osp.basename(img_path) if self.show else 'val_img',
img,
'val sample',
data_input,
data_sample=outputs[0],
show=self.show,
wait_time=self.wait_time,
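The new branch rebuilds the point cloud from the raw file instead of assuming an image is present. A standalone sketch of the same pattern, assuming a local KITTI-style `.bin` file (the path is a placeholder):

```python
import numpy as np

# Standalone sketch of the point-loading pattern above, assuming a local
# KITTI-style .bin file (placeholder path): the file is a flat float32
# buffer, and num_pts_feats recovers the (N, C) shape.
lidar_path = 'data/kitti/training/velodyne_reduced/000000.bin'
num_pts_feats = 4  # x, y, z, reflectance for KITTI

with open(lidar_path, 'rb') as f:
    pts_bytes = f.read()
points = np.frombuffer(pts_bytes, dtype=np.float32).reshape(-1, num_pts_feats)
print(points.shape)
```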
@@ -135,18 +148,28 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
for data_sample in outputs:
self._test_index += 1

img_path = data_sample.img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
data_input = dict()
if 'img_path' in data_sample:
img_path = data_sample.img_path
img_bytes = self.file_client.get(img_path)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
data_input['img'] = img

if 'lidar_path' in data_sample:
lidar_path = data_sample.lidar_path
num_pts_feats = data_sample.num_pts_feats
pts_bytes = self.file_client.get(lidar_path)
points = np.frombuffer(pts_bytes, dtype=np.float32)
points = points.reshape(-1, num_pts_feats)
data_input['points'] = points

out_file = None
if self.test_out_dir is not None:
out_file = osp.basename(img_path)
out_file = osp.join(self.test_out_dir, out_file)
out_file = self.test_out_dir

self._visualizer.add_datasample(
osp.basename(img_path) if self.show else 'test_img',
img,
'test sample',
data_input,
data_sample=data_sample,
show=self.show,
wait_time=self.wait_time,
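With `out_file` now set to the output directory rather than a per-image filename, the visualizer chooses the file names itself. A hedged example of enabling the hook in a config; the `draw` argument and the `visualization` hook key are assumptions based on the usual MMEngine `default_hooks` layout, not part of this PR:

```python
# Hypothetical config snippet (not from this PR): enable drawing during test
# and write the rendered samples under test_out_dir.
default_hooks = dict(
    visualization=dict(
        type='Det3DVisualizationHook',
        draw=True,
        test_out_dir='work_dirs/vis_results'))
```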
16 changes: 8 additions & 8 deletions mmdet3d/visualization/local_visualizer.py
@@ -524,14 +524,14 @@ def add_datasample(self,
data_input, data_sample.gt_instances_3d,
data_sample.metainfo, vis_task, palette)
if 'gt_instances' in data_sample:
assert 'img' in data_input
if isinstance(data_input['img'], Tensor):
img = data_input['img'].permute(1, 2, 0).numpy()
img = img[..., [2, 1, 0]] # bgr to rgb
gt_img_data = self._draw_instances(img,
data_sample.gt_instances,
classes, palette)
if 'gt_pts_seg' in data_sample:
if len(data_sample.gt_instances) > 0:
assert 'img' in data_input
if isinstance(data_input['img'], Tensor):
img = data_input['img'].permute(1, 2, 0).numpy()
img = img[..., [2, 1, 0]] # bgr to rgb
gt_img_data = self._draw_instances(
img, data_sample.gt_instances, classes, palette)
if 'gt_pts_seg' in data_sample and vis_task == 'seg':
assert classes is not None, 'class information is ' \
'not provided when ' \
'visualizing panoptic ' \
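The guard now skips 2D drawing when there are no `gt_instances`, and `gt_pts_seg` is only drawn for segmentation tasks. A minimal, self-contained sketch of the tensor-to-image conversion used inside that branch:

```python
import torch

# Minimal sketch of the CHW-tensor-to-RGB-image conversion used above.
img_chw = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)  # BGR, CHW
img = img_chw.permute(1, 2, 0).numpy()  # to HWC
img = img[..., [2, 1, 0]]  # bgr to rgb
assert img.shape == (32, 32, 3)
```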
61 changes: 39 additions & 22 deletions tests/test_engine/test_hooks/test_visualization_hook.py
@@ -5,42 +5,59 @@
from unittest import TestCase
from unittest.mock import Mock

import numpy as np
import torch
from mmengine.structures import InstanceData

from mmdet3d.engine.hooks import Det3DVisualizationHook
from mmdet3d.structures import Det3DDataSample
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes
from mmdet3d.visualization import Det3DLocalVisualizer


def _rand_bboxes(num_boxes, h, w):
cx, cy, bw, bh = torch.rand(num_boxes, 4).T

tl_x = ((cx * w) - (w * bw / 2)).clip(0, w)
tl_y = ((cy * h) - (h * bh / 2)).clip(0, h)
br_x = ((cx * w) + (w * bw / 2)).clip(0, w)
br_y = ((cy * h) + (h * bh / 2)).clip(0, h)

bboxes = torch.vstack([tl_x, tl_y, br_x, br_y]).T
return bboxes


class TestVisualizationHook(TestCase):

def setUp(self) -> None:
Det3DLocalVisualizer.get_instance('visualizer')

pred_instances = InstanceData()
pred_instances.bboxes = _rand_bboxes(5, 10, 12)
pred_instances.labels = torch.randint(0, 2, (5, ))
pred_instances.scores = torch.rand((5, ))
pred_det_data_sample = Det3DDataSample()
pred_det_data_sample.set_metainfo({
pred_instances_3d = InstanceData()
pred_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
torch.tensor(
[[8.7314, -1.8559, -1.5997, 1.2000, 0.4800, 1.8900, -1.5808]]))
pred_instances_3d.labels_3d = torch.tensor([0])
pred_instances_3d.scores_3d = torch.tensor([0.8])

pred_det3d_data_sample = Det3DDataSample()
pred_det3d_data_sample.set_metainfo({
'num_pts_feats':
4,
'lidar2img':
np.array([[
6.02943734e+02, -7.07913286e+02, -1.22748427e+01,
-1.70942724e+02
],
[
1.76777261e+02, 8.80879902e+00, -7.07936120e+02,
-1.02568636e+02
],
[
9.99984860e-01, -1.52826717e-03, -5.29071223e-03,
-3.27567990e-01
],
[
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
1.00000000e+00
]]),
'img_path':
osp.join(osp.dirname(__file__), '../../data/color.jpg')
osp.join(
osp.dirname(__file__),
'../../data/kitti/training/image_2/000000.png'),
'lidar_path':
osp.join(
osp.dirname(__file__),
'../../data/kitti/training/velodyne_reduced/000000.bin')
})
pred_det_data_sample.pred_instances = pred_instances
self.outputs = [pred_det_data_sample] * 2
pred_det3d_data_sample.pred_instances_3d = pred_instances_3d
self.outputs = [pred_det3d_data_sample] * 2

def test_after_val_iter(self):
runner = Mock()
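The updated test builds a LiDAR 3D box instead of random 2D boxes. A hedged sketch of the same construction; the seven values are assumed to follow the usual `(x, y, z, x_size, y_size, z_size, yaw)` convention of `LiDARInstance3DBoxes`:

```python
import torch
from mmdet3d.structures import LiDARInstance3DBoxes

# Same box as in the updated test; seven values per box, assumed to be
# (x, y, z, x_size, y_size, z_size, yaw) in the LiDAR coordinate frame.
boxes = LiDARInstance3DBoxes(
    torch.tensor(
        [[8.7314, -1.8559, -1.5997, 1.2000, 0.4800, 1.8900, -1.5808]]))
print(boxes.dims)            # box sizes
print(boxes.gravity_center)  # box centers shifted to the gravity center
```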