model.py
import functools

import torch
from torch import nn
from torchvision.models import vgg19

from utils import adaIN
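
# `adaIN` is imported from this repo's utils module (not shown here). Per the
# original AdaIN paper (Huang & Belongie, 2017), it presumably computes, per
# channel over the spatial dimensions:
#
#     AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y)
#
# i.e. the content features x are re-normalized to match the channel-wise mean
# and standard deviation of the style features y.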


def wrap_reflection_pad(network):
    """Replace the zero padding of every Conv2d with reflection padding.

    Each padded Conv2d is swapped for an nn.Sequential of a ReflectionPad2d
    followed by the same conv with its padding set to zero.
    """
    # Materialize the module list first so the network can be safely mutated
    # while iterating.
    for name, m in list(network.named_modules()):
        if not isinstance(m, nn.Conv2d):
            continue
        if m.padding[0] == 0 and m.padding[1] == 0:
            continue
        y_pad, x_pad = int(m.padding[0]), int(m.padding[1])
        m.padding = (0, 0)
        names = name.split('.')
        # Walk down to the conv's parent module, then replace the conv
        # attribute with the padded Sequential.
        root = functools.reduce(lambda o, i: getattr(o, i), [network] + names[:-1])
        setattr(
            root, names[-1],
            nn.Sequential(
                nn.ReflectionPad2d((x_pad, x_pad, y_pad, y_pad)),
                m
            )
        )
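
# Hypothetical sanity check (not part of the original file): after wrapping,
# a padded conv becomes a ReflectionPad2d followed by a zero-padded conv.
#
#     net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1))
#     wrap_reflection_pad(net)
#     # net[0] is now Sequential(ReflectionPad2d((1, 1, 1, 1)),
#     #                          Conv2d(3, 8, ..., padding=(0, 0)))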


class VGG19_Reflection_Encoder(nn.Module):
    """VGG19 feature extractor up to relu4_1, using reflection padding."""

    # End index (exclusive) of each feature block within vgg19().features.
    BLOCKS = {
        'relu1_1': 1 + 1,
        'relu2_1': 6 + 1,
        'relu3_1': 11 + 1,
        'relu4_1': 20 + 1
    }
    # Indices of layers to replace with nn.Identity().
    REMOVE = [
        # 4  # Remove the first max-pool
    ]

    def __init__(self):
        super().__init__()
        base = vgg19(pretrained=True)
        # Replace unwanted layers with no-ops
        for i in self.REMOVE:
            base.features[i] = nn.Identity()
        # Split the pretrained feature stack into one sub-network per block so
        # the forward pass can tap each intermediate activation.
        offset = 0
        self.feature_names = list(self.BLOCKS.keys())
        self.feature_extractor = nn.Module()
        for name, output_num_layer in self.BLOCKS.items():
            setattr(
                self.feature_extractor,
                name,
                nn.Sequential(*base.features[offset:output_num_layer])
            )
            offset = output_num_layer
        # The encoder stays frozen: it is a fixed, pretrained feature extractor.
        self.feature_extractor.eval()
        self.feature_extractor.requires_grad_(False)
        wrap_reflection_pad(self)

    def forward(self, x):
        # Chain the blocks, recording each block's activation along the way.
        output = {}
        for name in self.feature_names:
            x = getattr(self.feature_extractor, name)(x)
            output[name] = x
        # Return all intermediate features plus the final (relu4_1) features.
        return output, x


class Reflection_Decoder(nn.Module):
    """Roughly mirrors the encoder: maps relu4_1 features back to an RGB image."""

    def __init__(self):
        super().__init__()
        self.base = nn.Sequential(
            nn.Conv2d(512, 256, (3, 3), padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(256, 256, (3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, (3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, (3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 128, (3, 3), padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 128, (3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, (3, 3), padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(64, 64, (3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, (3, 3), padding=1),
        )
        wrap_reflection_pad(self)

    def forward(self, x):
        return self.base(x)


class AdaINModel(nn.Module):
    """Complete AdaIN style-transfer model: frozen VGG19 encoder plus decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = VGG19_Reflection_Encoder()
        self.decoder = Reflection_Decoder()

    def forward(self, images_content, images_style, alpha=1.0):
        # alpha blends stylized features with the raw content features; it can
        # be a Python float or a per-sample tensor of shape (batch,).
        if isinstance(alpha, float):
            alpha = torch.tensor(alpha)[None, None, None, None]
        else:
            alpha = alpha[:, None, None, None]
        alpha = alpha.to(dtype=images_content.dtype, device=images_content.device)
        # Encode both images; only the final relu4_1 features are needed here.
        _, feat_content = self.encoder(images_content)
        _, feat_style = self.encoder(images_style)
        # Re-normalize the content features with the style statistics (AdaIN),
        # then interpolate between stylized and content features by alpha.
        t = adaIN(feat_content, feat_style)
        interpolate_t = alpha * t + (1. - alpha) * feat_content
        # Decode the blended features back to image space.
        g_t = self.decoder(interpolate_t)
        return g_t
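

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): run one
    # stylization pass on random 256x256 inputs to check tensor shapes.
    model = AdaINModel()
    content = torch.randn(1, 3, 256, 256)
    style = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        stylized = model(content, style, alpha=1.0)
    print(stylized.shape)  # expected: torch.Size([1, 3, 256, 256])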