-
Notifications
You must be signed in to change notification settings - Fork 3
/
detection_head.py
78 lines (67 loc) · 2.88 KB
/
detection_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
# SSD-style single-shot detection head; refer to the SSD paper.
class Head(nn.Module):
    """Anchor-based detection head built from three parallel 1x1 convolutions.

    For every anchor at every spatial location of the input feature map it
    predicts class scores, box-regression values and direction-class scores.

    Args:
        in_channel: number of channels of the input feature map.
        n_anchors: number of anchors per spatial location.
        n_classes: number of class scores predicted per anchor.
    """

    def __init__(self, in_channel, n_anchors, n_classes):
        super().__init__()
        self.conv_cls = nn.Conv2d(in_channel, n_anchors * n_classes, 1)
        # 7 regression values per anchor — presumably (x, y, z, w, l, h, yaw);
        # NOTE(review): confirm against the box encoder used in training.
        self.conv_reg = nn.Conv2d(in_channel, n_anchors * 7, 1)
        # 2 direction-classification scores per anchor.
        self.conv_dir_cls = nn.Conv2d(in_channel, n_anchors * 2, 1)

        # Initialisation intended to be consistent with mmdet3d:
        # every weight ~ N(0, 0.01); the classification bias is set to
        # logit(prior_prob), so that sigmoid(bias) == prior_prob at the start
        # of training (the usual focal-loss prior); all other biases are zero.
        # The convs are initialised explicitly, rather than by counting modules
        # in self.modules() order. The order below is the same as before, so
        # RNG consumption (and therefore the initial weights) is unchanged.
        prior_prob = 0.01
        cls_bias_init = float(-np.log((1 - prior_prob) / prior_prob))
        for conv, bias_init in ((self.conv_cls, cls_bias_init),
                                (self.conv_reg, 0.0),
                                (self.conv_dir_cls, 0.0)):
            nn.init.normal_(conv.weight, mean=0, std=0.01)
            nn.init.constant_(conv.bias, bias_init)

    def forward(self, x):
        """Apply the three prediction convolutions to ``x``.

        Args:
            x: feature map of shape (bs, in_channel, H, W); in the original
               configuration this is (bs, 384, 248, 216).

        Returns:
            A tuple ``(bbox_cls_pred, bbox_pred, bbox_dir_cls_pred)``:
              bbox_cls_pred:     (bs, n_anchors * n_classes, H, W) raw class scores
              bbox_pred:         (bs, n_anchors * 7, H, W) raw box regression values
              bbox_dir_cls_pred: (bs, n_anchors * 2, H, W) raw direction scores
            No activation is applied to any output.
        """
        bbox_cls_pred = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
        bbox_dir_cls_pred = self.conv_dir_cls(x)
        return bbox_cls_pred, bbox_pred, bbox_dir_cls_pred
class Neck(nn.Module):
    """Upsample several feature maps and concatenate them along channels.

    Branch ``i`` is ``ConvTranspose2d(in_channels[i] -> out_channels[i],
    kernel_size = stride = upsample_strides[i], bias=False)`` followed by
    ``BatchNorm2d(eps=1e-3, momentum=0.01)`` and an in-place ``ReLU``.

    Args:
        in_channels: channel count of each input feature map.
        upsample_strides: kernel size and stride of each transposed conv.
        out_channels: channel count produced by each branch.
        The three sequences must have equal length.
    """

    def __init__(self, in_channels, upsample_strides, out_channels):
        super().__init__()
        assert len(in_channels) == len(upsample_strides) == len(out_channels)

        branches = []
        for idx, c_in in enumerate(in_channels):
            stride = upsample_strides[idx]
            c_out = out_channels[idx]
            upsample = nn.ConvTranspose2d(c_in, c_out, stride,
                                          stride=stride, bias=False)
            norm = nn.BatchNorm2d(c_out, eps=1e-3, momentum=0.01)
            branches.append(nn.Sequential(upsample, norm, nn.ReLU(inplace=True)))
        self.decoder_blocks = nn.ModuleList(branches)

        # Intended to be consistent with mmdet3d: Kaiming-normal init of each
        # transposed-conv weight, visited in branch order.
        for branch in self.decoder_blocks:
            nn.init.kaiming_normal_(branch[0].weight, mode='fan_out',
                                    nonlinearity='relu')

    def forward(self, x):
        """Run each branch on its feature map and concatenate the results.

        Args:
            x: indexable sequence of feature maps; ``x[i]`` feeds branch ``i``.
               In the original configuration this is
               [(bs, 64, 248, 216), (bs, 128, 124, 108), (bs, 256, 62, 54)].

        Returns:
            Tensor of the branch outputs concatenated on dim 1; in the original
            configuration (bs, 384, 248, 216), i.e. three (bs, 128, 248, 216) maps.
        """
        upsampled = [branch(x[idx]) for idx, branch in enumerate(self.decoder_blocks)]
        return torch.cat(upsampled, dim=1)