-
Notifications
You must be signed in to change notification settings - Fork 3
/
pillar_feature_net.py
64 lines (53 loc) · 3.12 KB
/
pillar_feature_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class PillarFeatureNet(nn.Module):
    """Encode padded point pillars into a dense bird's-eye-view feature map.

    Each point is augmented with its offset to the pillar's point centroid
    and its offset to the pillar's geometric centre, passed through a shared
    1x1 convolution + batch norm + ReLU, max-pooled over the points of the
    pillar, and finally scattered onto a zero-initialised 2D canvas.

    Args:
        voxel_size: Indexable whose first two entries are the pillar size
            along x and y.
        point_cloud_range: Indexable laid out as
            (x_min, y_min, z_min, x_max, y_max, z_max); only the x and y
            bounds are read.
        in_channel: Number of input channels of the point-wise layer. It
            must equal the number of raw point features plus 5 (three
            centroid offsets and two pillar-centre offsets).
        out_channel: Number of output channels per pillar (C in the paper).
    """

    def __init__(self, voxel_size, point_cloud_range, in_channel, out_channel):
        super().__init__()
        # Pillar feature dimension (C in the paper).
        self.out_channel = out_channel
        self.vx, self.vy = voxel_size[0], voxel_size[1]
        # Metric coordinates of the centre of pillar (0, 0).
        self.x_offset = voxel_size[0] / 2 + point_cloud_range[0]
        self.y_offset = voxel_size[1] / 2 + point_cloud_range[1]
        # Grid extent in pillars along x and y (truncated towards zero).
        self.x_l = int((point_cloud_range[3] - point_cloud_range[0]) / voxel_size[0])
        self.y_l = int((point_cloud_range[4] - point_cloud_range[1]) / voxel_size[1])
        # A 1x1 convolution is a linear layer shared across all points.
        self.conv = nn.Conv1d(in_channel, out_channel, 1, bias=False)
        self.bn = nn.BatchNorm1d(out_channel, eps=1e-3, momentum=0.01)

    def forward(
        self,
        pillars: torch.Tensor,
        coors_batch: torch.Tensor,
        npoints_per_pillar: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the scattered pillar feature map.

        Args:
            pillars: Padded points of every pillar in the batch, indexed as
                (pillar, point, feature). The first three features are read
                as x, y, z; the original comments assume (x, y, z, r).
            coors_batch: Integer tensor with one row per pillar. Column 0 is
                the sample index within the batch, column 1 the pillar's x
                grid index and column 2 its y grid index. Rows do not need
                to be sorted by sample index.
            npoints_per_pillar: Number of real (non-padded) points in each
                pillar, one entry per pillar.

        Returns:
            Tensor of shape (batch_size, out_channel, y_l, x_l) holding the
            pooled pillar features at their grid locations and zeros
            elsewhere. Its dtype and device follow the pooled features.
        """
        # Offset of every point to the centroid of its pillar (xc, yc, zc).
        # Padded points are expected to be zero so the sum only covers real
        # points; their own (meaningless) offsets are masked out below.
        xyz = pillars[:, :, :3]
        centroid = xyz.sum(dim=1, keepdim=True) / npoints_per_pillar[:, None, None]
        offset_pt_center = xyz - centroid

        # Offset of every point to the geometric centre of its pillar (xp, yp).
        x_center = coors_batch[:, None, 1:2] * self.vx + self.x_offset
        y_center = coors_batch[:, None, 2:3] * self.vy + self.y_offset
        x_offset_pi_center = pillars[:, :, 0:1] - x_center
        y_offset_pi_center = pillars[:, :, 1:2] - y_center

        # Augmented point vector. As in the original implementation, the
        # first two channels carry the pillar-centre offsets instead of the
        # absolute x / y: (xp, yp, z, r, xc, yc, zc, xp, yp).
        features = torch.cat(
            [
                x_offset_pi_center,
                y_offset_pi_center,
                pillars[:, :, 2:],
                offset_pt_center,
                x_offset_pi_center,
                y_offset_pi_center,
            ],
            dim=-1,
        )

        # Zero out the padded slots of pillars holding fewer points than the
        # padded length.
        point_ids = torch.arange(pillars.size(1), device=pillars.device)
        mask = point_ids[None, :] < npoints_per_pillar[:, None]  # (pillar, point)
        features = features * mask[:, :, None]

        # Point-wise features with a simple PointNet: (pillar, channel, point).
        features = features.permute(0, 2, 1).contiguous()
        features = F.relu(self.bn(self.conv(features)))

        # Max-pool over the points of each pillar: (pillar, channel).
        pooling_features = torch.max(features, dim=-1)[0]

        # Scatter the pillar features back to their grid locations.
        batch_idx = coors_batch[:, 0].long()
        x_idx = coors_batch[:, 1].long()
        y_idx = coors_batch[:, 2].long()
        # Derive the batch size from the largest sample index rather than the
        # last row, so unsorted coordinates cannot silently drop samples; the
        # guard keeps the reduction from failing when there are no pillars.
        batch_size = int(batch_idx.max()) + 1 if batch_idx.numel() > 0 else 0
        # new_zeros follows the dtype / device of the pooled features, which
        # keeps the indexed assignment valid under reduced or double precision.
        canvas = pooling_features.new_zeros(
            (batch_size, self.x_l, self.y_l, self.out_channel)
        )
        canvas[batch_idx, x_idx, y_idx] = pooling_features
        # (batch, x, y, channel) -> (batch, channel, y, x)
        return canvas.permute(0, 3, 2, 1).contiguous()