-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconvert.py
36 lines (31 loc) · 1.12 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
class Convert(nn.Module):
    """Map a backbone feature map to a fixed-length vector with one Linear layer.

    The input is flattened over every dimension after the batch dimension and
    passed through an ``nn.Linear``. The Linear's input width is derived from
    the image size, ``os`` and ``backbone_output_dim``, so it must equal the
    number of elements per sample that the backbone actually produces.

    Args:
        image_size: Side length of the input image, or a ``(height, width)``
            pair for non-square inputs.
        backbone_output_dim: Channel count of the feature map fed to ``forward``.
        os: Divisor turning the image size into the feature-map size
            (presumably the backbone's output stride). NOTE: the name shadows
            the stdlib ``os`` module inside ``__init__``; it is kept as-is for
            callers that pass it by keyword.
        v_dim: Length of the output vector.

    Raises:
        ValueError: If the derived feature-map height or width is not positive.
    """

    def __init__(self, image_size, backbone_output_dim, os, v_dim):
        super(Convert, self).__init__()
        if isinstance(image_size, (tuple, list)):
            height, width = image_size
        else:
            height = width = image_size
        # Feature-map side = image side // os (floor).
        # NOTE(review): backbones with "same" padding may round up instead;
        # confirm against the backbone actually used.
        feat_h = int(height // os)
        feat_w = int(width // os)
        if feat_h <= 0 or feat_w <= 0:
            # A zero-sized feature map would silently create a Linear with no
            # inputs, which can never match a real backbone output.
            raise ValueError(
                f"image_size={image_size!r} with os={os!r} gives an empty "
                f"feature map ({feat_h}x{feat_w})"
            )
        in_dim = feat_h * feat_w * backbone_output_dim
        self.linear = nn.Linear(in_dim, v_dim)
        self._init_weight()

    def forward(self, x):
        """Flatten ``x`` after the batch dimension and project it to ``v_dim``.

        Assumes ``x`` is (batch, backbone_output_dim, feat_h, feat_w) as
        configured in ``__init__`` (TODO confirm with callers). Any shape
        works as long as its per-sample element count equals the Linear's
        ``in_features``.

        Returns:
            Tensor of shape (batch, v_dim).
        """
        x = torch.flatten(x, 1)
        return self.linear(x)

    def _init_weight(self):
        """Xavier-normal Linear weights with zero biases; BatchNorm2d to identity."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # This class itself holds no BatchNorm2d; the branch only
                # matters for subclasses that add one and re-run this method.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
class GAPConvert(nn.Module):
    """Collapse each channel of a 2D feature map to its spatial mean.

    Applies ``AdaptiveAvgPool2d`` with output size 1 and then flattens every
    dimension after the batch dimension. A (batch, C, H, W) input therefore
    yields a (batch, C) output.
    """

    def __init__(self):
        super(GAPConvert, self).__init__()
        # Output size 1 -> a single averaged value per channel.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, x):
        """Return the globally average-pooled, flattened version of ``x``."""
        return torch.flatten(self.avg_pool(x), 1)