Skip to content

Commit

Permalink
add online visualization function for semantic segmentation results
Browse files Browse the repository at this point in the history
  • Loading branch information
Wuziyi616 committed Apr 7, 2021
1 parent 9cb75e7 commit 9970aa4
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 16 deletions.
27 changes: 27 additions & 0 deletions mmdet3d/core/visualizer/open3d_vis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import cv2
import numpy as np
import torch
Expand Down Expand Up @@ -44,6 +45,9 @@ def _draw_points(points,
elif mode == 'xyzrgb':
pcd.points = o3d.utility.Vector3dVector(points[:, :3])
points_colors = points[:, 3:6]
# normalize to [0, 1] for open3d drawing
if not ((points_colors >= 0.0) & (points_colors <= 1.0)).all():
points_colors /= 255.0
else:
raise NotImplementedError

Expand Down Expand Up @@ -462,6 +466,7 @@ def __init__(self,
self.rot_axis = rot_axis
self.center_mode = center_mode
self.mode = mode
self.seg_num = 0

# draw points
if points is not None:
Expand Down Expand Up @@ -494,6 +499,28 @@ def add_bboxes(self, bbox3d, bbox_color=None, points_in_box_color=None):
bbox_color, points_in_box_color, self.rot_axis,
self.center_mode, self.mode)

def add_seg_mask(self, seg_mask_colors):
"""Add segmentation mask to visualizer via per-point colorization.
Args:
seg_mask_colors (numpy.array, shape=[N, 6]):
The segmentation mask whose first 3 dims are point coordinates
and last 3 dims are converted colors.
"""
# we can't draw the colors on existing points
# in case gt and pred mask would overlap
# instead we set a large offset along x-axis for each seg mask
self.seg_num += 1
offset = (np.array(self.pcd.points).max(0) -
np.array(self.pcd.points).min(0))[0] * 1.2 * self.seg_num
mesh_frame = geometry.TriangleMesh.create_coordinate_frame(
size=1, origin=[offset, 0, 0]) # create coordinate frame for seg
self.o3d_visualizer.add_geometry(mesh_frame)
seg_points = copy.deepcopy(seg_mask_colors)
seg_points[:, 0] += offset
_draw_points(
seg_points, self.o3d_visualizer, self.points_size, mode='xyzrgb')

def show(self, save_path=None):
"""Visualize the points cloud.
Expand Down
38 changes: 22 additions & 16 deletions mmdet3d/core/visualizer/show_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,10 @@ def show_seg_result(points,
unannotated points. Defaults to None.
show (bool, optional): Visualize the results online. Defaults to False.
"""
'''
# TODO: not sure how to draw colors online, maybe we need two frames?
from .open3d_vis import Visualizer
if show:
vis = Visualizer(points)
if pred_bboxes is not None:
vis.add_bboxes(bbox3d=pred_bboxes)
if gt_bboxes is not None:
vis.add_bboxes(bbox3d=gt_bboxes, bbox_color=(0, 0, 1))
vis.show()
'''
# we need 3D coordinates to visualize segmentation mask
if gt_seg is not None or pred_seg is not None:
assert points is not None, \
'3D coordinates are required for segmentation visualization'

# filter out ignored points
if gt_seg is not None and ignore_index is not None:
Expand All @@ -156,8 +148,23 @@ def show_seg_result(points,

if gt_seg is not None:
gt_seg_color = palette[gt_seg]
gt_seg_color = np.concatenate([points[:, :3], gt_seg_color], axis=1)
if pred_seg is not None:
pred_seg_color = palette[pred_seg]
pred_seg_color = np.concatenate([points[:, :3], pred_seg_color],
axis=1)

# online visualization of segmentation mask
# we show three masks in a row, scene_points, gt_mask, pred_mask
if show:
from .open3d_vis import Visualizer
mode = 'xyzrgb' if points.shape[1] == 6 else 'xyz'
vis = Visualizer(points, mode=mode)
if gt_seg is not None:
vis.add_seg_mask(gt_seg_color)
if pred_seg is not None:
vis.add_seg_mask(pred_seg_color)
vis.show()

result_path = osp.join(out_dir, filename)
mmcv.mkdir_or_exist(result_path)
Expand All @@ -166,9 +173,8 @@ def show_seg_result(points,
_write_obj(points, osp.join(result_path, f'{filename}_points.obj'))

if gt_seg is not None:
gt_seg = np.concatenate([points[:, :3], gt_seg_color], axis=1)
_write_obj(gt_seg, osp.join(result_path, f'{filename}_gt.obj'))
_write_obj(gt_seg_color, osp.join(result_path, f'{filename}_gt.obj'))

if pred_seg is not None:
pred_seg = np.concatenate([points[:, :3], pred_seg_color], axis=1)
_write_obj(pred_seg, osp.join(result_path, f'{filename}_pred.obj'))
_write_obj(pred_seg_color, osp.join(result_path,
f'{filename}_pred.obj'))

0 comments on commit 9970aa4

Please sign in to comment.