-
Notifications
You must be signed in to change notification settings - Fork 30
/
conv_modules.py
60 lines (50 loc) · 1.79 KB
/
conv_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import util
from torch import nn
import numpy as np
import torchvision
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
import torch.nn.functional as F
def normalize_imagenet(x):
    ''' Standardize the first three channels with the ImageNet statistics.

    Each of the channels 0, 1 and 2 (indexed along dimension 1) is shifted by
    its mean and divided by its standard deviation. Any further channels are
    copied through unchanged.

    Args:
        x (tensor): input images, indexed as x[:, channel]
    Returns:
        a new tensor of the same shape; the input itself is not modified
    '''
    # (mean, std) for channels 0, 1, 2.
    channel_stats = (
        (0.485, 0.229),
        (0.456, 0.224),
        (0.406, 0.225),
    )
    # Work on a copy so the caller's tensor is left untouched.
    normalized = x.clone()
    for channel, (mean, std) in enumerate(channel_stats):
        normalized[:, channel] = (normalized[:, channel] - mean) / std
    return normalized
def init_weights_normal(m):
    ''' Initialize a module's weight with Kaiming-normal (fan-in, ReLU gain).

    Intended to be passed to ``nn.Module.apply``, which calls it on every
    submodule. Modules without a usable weight are skipped: ``hasattr`` alone
    is not enough, because some layers expose ``weight = None`` (for example
    norm layers built with ``affine=False``) or a 1-D weight (for example
    BatchNorm scale), and ``kaiming_normal_`` cannot compute a fan-in for
    either.

    Args:
        m (nn.Module): module whose ``weight`` is re-initialized in place
    '''
    weight = getattr(m, 'weight', None)
    # Kaiming init needs at least a 2-D tensor to derive fan-in.
    if not isinstance(weight, torch.Tensor) or weight.dim() < 2:
        return
    nn.init.kaiming_normal_(weight, a=0.0, nonlinearity='relu', mode='fan_in')
class Resnet18(nn.Module):
    r''' ResNet-18 encoder network for image input.

    Wraps torchvision's pretrained ResNet-18 with its classification head
    removed, optionally followed by a linear projection from the 512-d
    backbone feature to ``c_dim``.

    Args:
        c_dim (int): output dimension of the latent embedding
        normalize (bool): whether the input images should be normalized
        use_linear (bool): whether a final linear layer should be used

    Raises:
        ValueError: if ``use_linear`` is False and ``c_dim`` is not 512
    '''

    @staticmethod
    def init_weights_normal(m):
        ''' Class-level alias of the module-level ``init_weights_normal``.

        Kept so that ``Resnet18.init_weights_normal(m)`` keeps working. It is
        declared static because it takes the module to initialize, not
        ``self``.

        Args:
            m (nn.Module): module whose weight is re-initialized in place
        '''
        # Inside a function body this name resolves to the module-level
        # initializer, not to this method, so there is no recursion.
        init_weights_normal(m)

    def __init__(self, c_dim, normalize=True, use_linear=True):
        super().__init__()
        # Fail fast: reject an impossible configuration before building (and
        # possibly downloading) the pretrained backbone.
        if not use_linear and c_dim != 512:
            raise ValueError('c_dim must be 512 if use_linear is False')

        self.normalize = normalize
        self.use_linear = use_linear

        # NOTE(review): recent torchvision releases deprecate ``pretrained``
        # in favour of ``weights=``; kept as-is for compatibility with the
        # torchvision version this project pins.
        self.features = torchvision.models.resnet18(pretrained=True)
        # Drop the classifier so the backbone returns its pooled feature.
        self.features.fc = nn.Sequential()

        if use_linear:
            self.fc = nn.Linear(512, c_dim)
            self.fc.apply(init_weights_normal)
        else:
            # c_dim == 512 here, so the backbone feature is used directly.
            self.fc = nn.Sequential()

    def forward(self, input):
        ''' Encode a batch of images into latent embeddings.

        Args:
            input (tensor): input images; rescaled with ``(input + 1) / 2``
                before encoding (presumably from [-1, 1] to [0, 1] — confirm
                against the data pipeline)
        Returns:
            the output of ``self.fc`` applied to the backbone features
        '''
        x = (input + 1) / 2
        if self.normalize:
            x = normalize_imagenet(x)
        return self.fc(self.features(x))