forked from TRI-ML/packnet-sfm
-
Notifications
You must be signed in to change notification settings - Fork 3
/
packnet_sfm_node
executable file
·165 lines (124 loc) · 5.57 KB
/
packnet_sfm_node
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#!/usr/bin/env python3
import rospy
import rospkg
from sensor_msgs.msg import Image
import numpy as np
import os
import torch
import time
import cv2
from cv_bridge import CvBridge, CvBridgeError
import packnet_sfm
from packnet_sfm.models.model_wrapper import ModelWrapper
from packnet_sfm.datasets.augmentations import resize_image, to_tensor
from packnet_sfm.utils.horovod import hvd_init, rank, world_size, print0
from packnet_sfm.utils.image import load_image, interpolate_image
from packnet_sfm.utils.config import parse_test_file
from packnet_sfm.utils.load import set_debug
from packnet_sfm.utils.depth import write_depth, inv2depth, viz_inv_depth
from packnet_sfm.utils.logging import pcolor
from packnet_sfm.utils.types import is_seq, is_tensor
# Checkpoint file name, looked up under the package's trained_models/ directory.
MODEL_NAME = "PackNet01_MR_velsup_CStoK.ckpt"


class DepthInference:
    """ROS node wrapper running PackNet-SfM monocular depth inference.

    Subscribes to ``/video/image_raw``, predicts a depth map for every received
    frame and republishes the frame on ``/camera/color/image_raw`` together
    with a mono16 depth image on ``/camera/depth/image_raw``.

    Note: the constructor ends in ``rospy.spin()`` and therefore blocks until
    the node is shut down.
    """

    def __init__(self):
        self.bridge = CvBridge()
        self.model_wrapper = None
        # Network input size read from the checkpoint config (see set_network_input_shape).
        self.network_input_shape = None
        # (height, width) of the most recent incoming image, set in cb_image.
        self.original_input_shape = None
        self.rgb_img_msg = None
        # Running frame counter, written into header.seq of published messages.
        self.rgb_counter = 0
        self.depth_img_msg = None
        self.set_model_wrapper()

        ## Communication
        # NOTE: in rospy, queue_size=None means publish() runs synchronously
        # (rospy logs a warning about it); it does not drop messages. Dropping
        # stale input frames is the job of the subscriber's queue_size=1.
        self.pub_rgb_image = rospy.Publisher('/camera/color/image_raw', Image, queue_size=None)
        self.pub_depth_image = rospy.Publisher('/camera/depth/image_raw', Image, queue_size=None)
        rospy.Subscriber("/video/image_raw", Image, self.cb_image, queue_size=1)
        # Blocks here until the node is shut down.
        rospy.spin()

    @staticmethod
    def _resolve_checkpoint_path(package_path, model_name):
        """Return the checkpoint path for ``model_name``.

        Resolution order:
        1. ``<package_path>/trained_models/<model_name>`` if that file exists
           (e.g. rospack resolved the package to its source directory).
        2. If ``package_path`` has a path component named exactly ``install``,
           ``<workspace>/src/packnet_sfm_ros/trained_models/<model_name>``,
           where ``<workspace>`` is the parent of the deepest such component.
        3. Otherwise the in-package path from step 1, so the loader's error
           message names a sensible location.

        Parameters
        ----------
        package_path : str
            Path returned by ``rospkg.RosPack().get_path('packnet_sfm_ros')``.
        model_name : str
            Checkpoint file name.

        Returns
        -------
        str
            Path to the checkpoint file.
        """
        in_package = os.path.join(package_path, 'trained_models', model_name)
        if os.path.isfile(in_package):
            return in_package
        # Walk up whole path components; matching the substring "install"
        # (as str.split did) would also cut at names like "/home/installer".
        head = os.path.normpath(package_path)
        while True:
            parent, name = os.path.split(head)
            if name == 'install':
                return os.path.join(parent, 'src', 'packnet_sfm_ros', 'trained_models', model_name)
            if parent == head:
                break
            head = parent
        return in_package

    def set_model_wrapper(self):
        """Load the checkpoint and build ``self.model_wrapper`` in eval mode.

        Also records the network input shape from the checkpoint config and
        moves the model to ``cuda:<rank()>`` when CUDA is available.
        """
        rospack = rospkg.RosPack()
        package_path = rospack.get_path('packnet_sfm_ros')
        checkpoint_path = self._resolve_checkpoint_path(package_path, MODEL_NAME)
        config, state_dict = parse_test_file(checkpoint_path)
        self.set_network_input_shape(config)
        # Initialize model wrapper from checkpoint arguments
        self.model_wrapper = ModelWrapper(config, load_datasets=False)
        # Restore monodepth_model state
        self.model_wrapper.load_state_dict(state_dict)
        if torch.cuda.is_available():
            self.model_wrapper = self.model_wrapper.to('cuda:{}'.format(rank()), dtype=None)
        # Set to eval mode
        self.model_wrapper.eval()

    def set_network_input_shape(self, config):
        """Store the network input size from the checkpoint config.

        Presumably (height, width): process() passes (shape[1], shape[0]) to
        cv2.resize, which expects (width, height).
        """
        self.network_input_shape = config.datasets.augmentation.image_shape

    def process(self, rgb_img_msg):
        """Predict depth for one image message and publish the RGB and depth images.

        The image is converted to RGB and resized to the network input size.
        The predicted inverse depth is resized back to the original image size,
        inverted to depth and encoded as a mono16 image. Both outgoing messages
        share one timestamp and frame id. Frames that cv_bridge cannot convert
        are logged and skipped.

        Parameters
        ----------
        rgb_img_msg : sensor_msgs.msg.Image
            Incoming image. Its header is modified in place and it is republished.
        """
        try:
            bgr_image = self.bridge.imgmsg_to_cv2(rgb_img_msg, "bgr8")
        except CvBridgeError as e:
            # Nothing to run inference on; drop this frame.
            rospy.logerr("packnet_sfm_node: failed to convert image: {}".format(e))
            return

        # Fall back to the message size if cb_image did not record it.
        original_shape = self.original_input_shape or (rgb_img_msg.height, rgb_img_msg.width)

        # Convert to RGB and shrink the image to fit the NN input.
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        net_height, net_width = self.network_input_shape[0], self.network_input_shape[1]
        # cv2.resize takes the target size as (width, height).
        rgb_image = cv2.resize(rgb_image, (net_width, net_height), interpolation=cv2.INTER_LANCZOS4)
        rgb_tensor = to_tensor(rgb_image).unsqueeze(0)
        if torch.cuda.is_available():
            rgb_tensor = rgb_tensor.to('cuda:{}'.format(rank()), dtype=None)

        # Inference only: disable autograd so no graph is built per frame.
        with torch.no_grad():
            # Depth inference (returns predicted inverse depth)
            pred_inv_depth = self.model_wrapper.depth(rgb_tensor)
            # interpolate_image takes (height, width), unlike cv2.resize above.
            pred_inv_depth_resized = interpolate_image(
                pred_inv_depth, (original_shape[0], original_shape[1]), mode='bicubic')

        # Convert inverse depth to a 16-bit depth image.
        depth_img = self.write_depth(self.inv2depth(pred_inv_depth_resized))
        depth_img_msg = self.bridge.cv2_to_imgmsg(depth_img, encoding="mono16")

        # One stamp for both messages so RGB and depth carry identical times.
        stamp = rospy.Time.now()
        for msg in (rgb_img_msg, depth_img_msg):
            msg.header.stamp = stamp
            msg.header.frame_id = "left_image"
        depth_img_msg.header.seq = rgb_img_msg.header.seq

        # Publish the image and the depth image.
        self.pub_rgb_image.publish(rgb_img_msg)
        self.pub_depth_image.publish(depth_img_msg)

    def inv2depth(self, inv_depth):
        """
        Invert an inverse depth map to produce a depth map

        Parameters
        ----------
        inv_depth : torch.Tensor or list/tuple of torch.Tensor [B,1,H,W]
            Inverse depth map

        Returns
        -------
        depth : torch.Tensor or list of torch.Tensor [B,1,H,W]
            Depth map (zeros in the input become inf)
        """
        if isinstance(inv_depth, (list, tuple)):
            return [self.inv2depth(item) for item in inv_depth]
        return 1. / inv_depth

    def write_depth(self, depth):
        """
        Convert a depth map into a uint16 image array (nothing is written to disk).

        Values are clipped to [0, 100] and then multiplied by 256, so the
        output ranges from 0 to 25600.
        NOTE(review): the original comment cites the TUM RGB-D format, whose
        documented scale factor is 5000 per metre; the consumer must be set up
        for a factor of 256 — confirm.

        Parameters
        ----------
        depth : torch.Tensor or np.array
            Depth map. Tensors are detached, squeezed (e.g. [1,1,H,W] -> [H,W])
            and moved to the CPU first.

        Returns
        -------
        np.ndarray
            uint16 depth image.
        """
        if torch.is_tensor(depth):
            # Convert explicitly instead of relying on numpy duck-typing tensors.
            depth = depth.detach().squeeze().cpu().numpy()
        depth = np.clip(depth, 0, 100)
        return np.uint16(depth * 256)

    def cb_image(self, data):
        """Subscriber callback: number the frame, record its size, then process it."""
        # seq is replaced by our own counter; process() copies it to the depth message.
        data.header.seq = self.rgb_counter
        self.original_input_shape = (data.height, data.width)
        self.rgb_counter += 1
        rospy.loginfo("cb_image: {}".format(self.rgb_counter))
        self.process(data)
if __name__ == "__main__":
    rospy.init_node('packnet_sfm_node')
    # DepthInference loads the model, wires up the topics and then blocks in
    # rospy.spin() inside its constructor until the node shuts down.
    node = DepthInference()