stream_segmentation.py
#!/usr/bin/env python
#
# Stream object segmentation results using Mask RCNN
#
# Note that because the 3D projection contains only
# part of the object, the 3D bounding box will not be
# accurate. You might want to rely on the 3D pose only.
import time
import sys
import argparse
import numpy as np
import cv2
import rospy
import struct
import torch
import torchvision
from sensor_msgs import point_cloud2
from sensor_msgs.msg import Image, CameraInfo, PointCloud2, PointField
from visualization_msgs.msg import Marker, MarkerArray
from vision_msgs.msg import BoundingBox3D, BoundingBox3DArray
from geometry_msgs.msg import Point, Quaternion, Vector3
from std_msgs.msg import Header
from std_msgs.msg import ColorRGBA
import rbd_spot
from rbd_spot_perception.utils.vision.detector import (COCO_CLASS_NAMES,
                                                        maskrcnn_filter_by_score,
                                                        maskrcnn_draw_result,
                                                        bbox3d_from_points)


def get_intrinsics(P):
    return dict(fx=P[0],
                fy=P[5],
                cx=P[2],
                cy=P[6])
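# Note on the indices above: assuming caminfo is a sensor_msgs/CameraInfo, its P
# field is the row-major, flattened 3x4 projection matrix
#     [fx  0 cx Tx]
#     [ 0 fy cy Ty]
#     [ 0  0  1  0]
# which is why fx, fy, cx and cy are read from indices 0, 5, 2 and 6.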


def make_bbox_msg(center, sizes):
    if len(center) == 7:
        x, y, z, qx, qy, qz, qw = center
        q = Quaternion(x=qx, y=qy, z=qz, w=qw)
    else:
        x, y, z = center
        q = Quaternion(x=0, y=0, z=0, w=1)
    s1, s2, s3 = sizes
    msg = BoundingBox3D()
    msg.center.position = Point(x=x, y=y, z=z)
    msg.center.orientation = q
    msg.size = Vector3(x=s1, y=s2, z=s3)
    return msg


def make_bbox_marker_msg(center, sizes, marker_id, header):
    if len(center) == 7:
        x, y, z, qx, qy, qz, qw = center
        q = Quaternion(x=qx, y=qy, z=qz, w=qw)
    else:
        x, y, z = center
        q = Quaternion(x=0, y=0, z=0, w=1)
    s1, s2, s3 = sizes
    marker = Marker()
    marker.header = header
    marker.id = marker_id
    marker.type = Marker.CUBE
    marker.pose.position = Point(x=x, y=y, z=z)
    marker.pose.orientation = q
    # Note: the box may look too large because the mask covers only the visible
    # part of the object (see the note at the top of this file); a fixed scale
    # (e.g. 0.2) could be used instead for visualization.
    marker.scale = Vector3(x=s1, y=s2, z=s3)
    marker.action = Marker.MODIFY
    marker.color = ColorRGBA(r=0.0, g=1.0, b=0.0, a=0.3)
    return marker


class SegmentationPublisher:
    def __init__(self, camera, mask_threshold=0.7):
        self._camera = camera
        self._mask_threshold = mask_threshold
        # Publishes the image with segmentation drawn
        self._segimg_pub = rospy.Publisher(f"/spot/segmentation/{camera}/result", Image, queue_size=10)
        # Publishes the point cloud of the back-projected segmentations
        self._segpcl_pub = rospy.Publisher(f"/spot/segmentation/{camera}/result_points", PointCloud2, queue_size=10)
        # Publishes bounding boxes of detected objects, with reasonable filtering done
        self._segbox_pub = rospy.Publisher(f"/spot/segmentation/{camera}/result_boxes", BoundingBox3DArray, queue_size=10)
        self._segbox_markers_pub = rospy.Publisher(f"/spot/segmentation/{camera}/result_boxes_viz", MarkerArray, queue_size=10)

    def publish_result(self, pred, visual_img, depth_img, caminfo):
        """
        Args:
            pred (Tensor): output of the Mask RCNN model
            visual_img (np.ndarray): image from the visual source (not rotated)
            depth_img (np.ndarray): image from the corresponding depth source
            caminfo (CameraInfo): camera info of the visual source
        """
        # Because the prediction was made on an upright image, make sure the
        # drawn result is also on an upright image.
        if self._camera == "front":
            visual_img_upright = torch.tensor(cv2.rotate(visual_img, cv2.ROTATE_90_CLOCKWISE)).permute(2, 0, 1)
            result_img = maskrcnn_draw_result(pred, visual_img_upright)
        else:
            result_img = maskrcnn_draw_result(pred, torch.tensor(visual_img).permute(2, 0, 1))
        result_img_msg = rbd_spot.image.imgmsg_from_imgarray(result_img.permute(1, 2, 0).numpy())
        result_img_msg.header.stamp = caminfo.header.stamp
        result_img_msg.header.frame_id = caminfo.header.frame_id
        self._segimg_pub.publish(result_img_msg)
        rospy.loginfo("Published segmentation result (image)")

        # For each mask, obtain a set of points.
        masks = pred['masks'].squeeze()
        masks = masks.reshape(-1, masks.shape[-2], masks.shape[-1])  # ensure shape is (N, H, W) where N is the number of masks
        masks = torch.greater(masks, self._mask_threshold)
        # If the camera is "front", rotate the masks back into the frame of the
        # original (unrotated) visual and depth images.
        if self._camera == "front":
            masks = torch.rot90(masks, 1, (1, 2))
        points = []
        markers = []
        boxes = []
        for i, mask in enumerate(masks):
            mask_coords = mask.nonzero().cpu().numpy()  # works with boolean tensors too
            mask_coords_T = mask_coords.T
            mask_visual = visual_img[mask_coords_T[0], mask_coords_T[1], :].reshape(-1, 3)  # colors on the mask
            mask_depth = depth_img[mask_coords_T[0], mask_coords_T[1]]  # depth on the mask
            v, u = mask_coords_T[0], mask_coords_T[1]
            I = get_intrinsics(caminfo.P)
            z = mask_depth / 1000.0
            x = (u - I['cx']) * z / I['fx']
            y = (v - I['cy']) * z / I['fy']
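            # For reference, the lines above are the standard pinhole back-projection
            # (depth values are assumed to be in millimeters, hence the division by 1000):
            #     z = depth / 1000
            #     x = (u - cx) * z / fx
            #     y = (v - cy) * z / fy
            # where u is the pixel column and v is the pixel row of each mask pixel.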
            # Filter out points too close to the gripper (most likely noise)
            keep_indices = np.argwhere(z > 0.06).flatten()
            z = z[keep_indices]
            if len(z) == 0:
                continue  # we won't have points for this mask
            x = x[keep_indices]
            y = y[keep_indices]
            rgb = [struct.unpack('I', struct.pack('BBBB',
                                                  mask_visual[j][0],
                                                  mask_visual[j][1],
                                                  mask_visual[j][2], 255))[0]
                   for j in keep_indices]
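            # Each point's color is packed into a single (native-endian) 32-bit
            # unsigned integer (bytes: channel 0, channel 1, channel 2, alpha=255),
            # matching the 4-byte 'rgb' PointField declared below for the PointCloud2.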
            # The points for a single detection mask
            mask_points = [[x[j], y[j], z[j], rgb[j]]
                           for j in range(len(x))]
            points.extend(mask_points)
            try:
                box_center, box_sizes = bbox3d_from_points([x, y, z], axis_aligned=True, no_rotation=True)
                boxes.append(make_bbox_msg(box_center, box_sizes))
                markers.append(make_bbox_marker_msg(box_center, box_sizes, 1000 + i, result_img_msg.header))
            except Exception as ex:
                rospy.logerr(f"Error: {ex}")

        fields = [PointField('x', 0, PointField.FLOAT32, 1),
                  PointField('y', 4, PointField.FLOAT32, 1),
                  PointField('z', 8, PointField.FLOAT32, 1),
                  PointField('rgb', 12, PointField.UINT32, 1)]
        header = Header()
        header.stamp = caminfo.header.stamp
        header.frame_id = caminfo.header.frame_id
        pc2 = point_cloud2.create_cloud(header, fields, points)
        # The static transform is already published by ros_publish_image_result, so no need here.
        self._segpcl_pub.publish(pc2)
        rospy.loginfo("Published segmentation result (points)")
        # Publish bounding boxes and markers
        bboxes_array = BoundingBox3DArray(header=header,
                                          boxes=boxes)
        self._segbox_pub.publish(bboxes_array)
        rospy.loginfo("Published segmentation result (bboxes)")
        markers_array = MarkerArray(markers=markers)
        self._segbox_markers_pub.publish(markers_array)
        rospy.loginfo("Published segmentation result (markers)")


def main():
    parser = argparse.ArgumentParser(description="stream segmentation with Mask RCNN")
    parser.add_argument("--camera", type=str, help="camera set to stream images from",
                        choices=['front', 'left', 'right', 'back', 'hand'],
                        default='hand')
    parser.add_argument("-q", "--quality", type=int,
                        help="image quality [0-100]", default=75)
    formats = ["UNKNOWN", "JPEG", "RAW", "RLE"]
    parser.add_argument("-f", "--format", type=str, default="RAW",
                        help="image format", choices=formats)
    parser.add_argument("-t", "--timeout", type=float, help="time (seconds) to keep streaming")
    parser.add_argument("-p", "--pub", action="store_true", help="publish as ROS messages")
    parser.add_argument("-r", "--rate", type=float,
                        help="maximum number of detections per second", default=3.0)
    args, _ = parser.parse_known_args()

    conn = rbd_spot.SpotSDKConn(sdk_name="StreamSegmentationClient")
    image_client = rbd_spot.image.create_client(conn)

    # Make image requests for the specified camera
    if args.camera == "hand":
        sources = ["hand_color_image", "hand_depth_in_hand_color_frame"]
    elif args.camera == "front":
        sources = ["frontleft_fisheye_image",
                   "frontleft_depth_in_visual_frame",
                   "frontright_fisheye_image",
                   "frontright_depth_in_visual_frame"]
    else:
        sources = [f"{args.camera}_fisheye_image",
                   f"{args.camera}_depth_in_visual_frame"]

    if args.pub:
        # Create ROS publishers. We publish: (1) raw image (2) depth (3) camera info
        # (4) image with the segmentation result drawn (5) segmentation point cloud.
        # The first three are handled by rbd_spot.image, while the last two are
        # handled by SegmentationPublisher.
        rospy.init_node(f"stream_segmentation_{args.camera}")
        image_publishers = rbd_spot.image.ros_create_publishers(sources, name_space="segmentation")
        seg_publisher = SegmentationPublisher(args.camera)
        rate = rospy.Rate(args.rate)

    print(f"Will stream images from {sources}")
    image_requests = rbd_spot.image.build_image_requests(
        sources, quality=args.quality, fmt=args.format)

    print("Loading model...")
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    model.to(device)

    # Stream images from the specified sources
    _start_time = time.time()
    while True:
        try:
            result, time_taken = rbd_spot.image.getImage(image_client, image_requests)
            print("GetImage took: {:.3f}".format(time_taken))
            if args.pub:
                rbd_spot.image.ros_publish_image_result(conn, result, image_publishers)

            # Pair up the visual and depth responses
            if args.camera == "front":
                # each element is a (visual response, depth response) pair
                visual_depths = [(result[0], result[1]),
                                 (result[2], result[3])]
            else:
                visual_depths = [(result[0], result[1])]

            # Run each pair through the model
            for visual_response, depth_response in visual_depths:
                visual_msg, caminfo = rbd_spot.image.imgmsg_from_response(visual_response, conn)
                depth_msg, caminfo = rbd_spot.image.imgmsg_from_response(depth_response, conn)
                image = rbd_spot.image.imgarray_from_imgmsg(visual_msg)
                if args.camera != "hand":
                    # grayscale image; make it 3 channels
                    image = cv2.merge([image, image, image])
                depth_image = rbd_spot.image.imgarray_from_imgmsg(depth_msg)

                image_input = torch.tensor(image)
                if args.camera == "front":
                    # front camera images are sideways; rotate them 90 degrees clockwise to make them upright
                    image_input = torch.tensor(cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE))
                image_input = torch.div(image_input, 255)
                if device.type == 'cuda':
                    image_input = image_input.cuda(device)
                pred = model([image_input.permute(2, 0, 1)])[0]
                pred = maskrcnn_filter_by_score(pred, 0.7)

                # Print out a summary
                print("detected objects: {}".format(list(sorted(COCO_CLASS_NAMES[l] for l in pred['labels']))))
                if len(pred['labels']) > 0:
                    if args.pub:
                        seg_publisher.publish_result(pred, image, depth_image, caminfo)
            if args.pub:
                rate.sleep()
            _used_time = time.time() - _start_time
            if args.timeout and _used_time > args.timeout:
                break
        finally:
            if args.pub and rospy.is_shutdown():
                sys.exit(1)


if __name__ == "__main__":
    main()
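

# Example usage (a sketch; assumes a running ROS master and that the Spot SDK
# connection created by rbd_spot.SpotSDKConn is already configured with the
# robot's address and credentials):
#
#   python stream_segmentation.py --camera hand --pub --rate 2.0
#   python stream_segmentation.py --camera front -f RAW -t 60
#
# With --pub, results are published on /spot/segmentation/<camera>/result,
# .../result_points, .../result_boxes and .../result_boxes_viz.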