# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
""" Demo of using VoteNet 3D object detector to detect objects from a point cloud.
"""
import os
import sys
import numpy as np
import argparse
import importlib
import time
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='sunrgbd', help='Dataset: sunrgbd or scannet [default: sunrgbd]')
parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]')
FLAGS = parser.parse_args()
import torch
import torch.nn as nn
import torch.optim as optim
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
sys.path.append(os.path.join(ROOT_DIR, 'models'))
from pc_util import random_sampling, read_ply
from ap_helper import parse_predictions
def preprocess_point_cloud(point_cloud):
    ''' Prepare the numpy point cloud (N,3) for forward pass '''
    point_cloud = point_cloud[:,0:3] # do not use color for now
    floor_height = np.percentile(point_cloud[:,2],0.99) # z at the 0.99th percentile (near-minimum) approximates the floor
    height = point_cloud[:,2] - floor_height
    point_cloud = np.concatenate([point_cloud, np.expand_dims(height, 1)],1) # (N,4): XYZ + height above floor
    point_cloud = random_sampling(point_cloud, FLAGS.num_point)
    pc = np.expand_dims(point_cloud.astype(np.float32), 0) # (1,num_point,4)
    return pc
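# Shape sketch for the helper above (hypothetical input array; the actual demo
# reads its input from a .ply file in the main block below):
#   raw_xyz = np.random.rand(50000, 3)        # (N,3) XYZ-only cloud
#   batch = preprocess_point_cloud(raw_xyz)   # -> (1, FLAGS.num_point, 4) float32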
if __name__=='__main__':
    # Set file paths and dataset config
    demo_dir = os.path.join(BASE_DIR, 'demo_files')
    if FLAGS.dataset == 'sunrgbd':
        sys.path.append(os.path.join(ROOT_DIR, 'sunrgbd'))
        from sunrgbd_detection_dataset import DC # dataset config
        checkpoint_path = os.path.join(demo_dir, 'pretrained_votenet_on_sunrgbd.tar')
        pc_path = os.path.join(demo_dir, 'input_pc_sunrgbd.ply')
    elif FLAGS.dataset == 'scannet':
        sys.path.append(os.path.join(ROOT_DIR, 'scannet'))
        from scannet_detection_dataset import DC # dataset config
        checkpoint_path = os.path.join(demo_dir, 'pretrained_votenet_on_scannet.tar')
        pc_path = os.path.join(demo_dir, 'input_pc_scannet.ply')
    else:
        print('Unknown dataset %s. Exiting.'%(FLAGS.dataset))
        exit(-1)
    eval_config_dict = {'remove_empty_box': True, 'use_3d_nms': True, 'nms_iou': 0.25,
        'use_old_type_nms': False, 'cls_nms': False, 'per_class_proposal': False,
        'conf_thresh': 0.5, 'dataset_config': DC}
    # Init the model and optimizer
    MODEL = importlib.import_module('votenet') # import network module
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = MODEL.VoteNet(num_proposal=256, input_feature_dim=1, vote_factor=1,
        sampling='seed_fps', num_class=DC.num_class,
        num_heading_bin=DC.num_heading_bin,
        num_size_cluster=DC.num_size_cluster,
        mean_size_arr=DC.mean_size_arr).to(device)
    print('Constructed model.')

    # Load checkpoint (map to the selected device so CPU-only machines can load it)
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print("Loaded checkpoint %s (epoch: %d)"%(checkpoint_path, epoch))
    # Load and preprocess input point cloud
    net.eval() # set model to eval mode (for batchnorm and dropout)
    point_cloud = read_ply(pc_path)
    pc = preprocess_point_cloud(point_cloud)
    print('Loaded point cloud data: %s'%(pc_path))

    # Model inference
    inputs = {'point_clouds': torch.from_numpy(pc).to(device)}
    tic = time.time()
    with torch.no_grad():
        end_points = net(inputs)
    toc = time.time()
    print('Inference time: %f'%(toc-tic))
    end_points['point_clouds'] = inputs['point_clouds']
    pred_map_cls = parse_predictions(end_points, eval_config_dict)
    print('Finished detection. %d objects detected.'%(len(pred_map_cls[0])))

    # Dump results to disk for visualization
    dump_dir = os.path.join(demo_dir, '%s_results'%(FLAGS.dataset))
    if not os.path.exists(dump_dir): os.mkdir(dump_dir)
    MODEL.dump_results(end_points, dump_dir, DC, True)
    print('Dumped detection results to folder %s'%(dump_dir))