-
Notifications
You must be signed in to change notification settings - Fork 1
/
Model.py
63 lines (48 loc) · 2.42 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import torch.nn.functional as F
class UpSample(nn.Sequential):
    """One decoder stage: upsample, concatenate a skip tensor, then refine.

    ``x`` is bilinearly resized (``align_corners=True``) to the spatial size of
    ``concat_with``. The two are concatenated along the channel axis and passed
    through two 3x3 convolutions (stride 1, padding 1, so the spatial size is
    kept). Each convolution is followed by LeakyReLU with negative slope 0.2.

    Args:
        skip_input: channel count of the concatenation, i.e. channels of ``x``
            plus channels of ``concat_with``.
        output_features: channel count produced by both convolutions.

    NOTE(review): this subclasses ``nn.Sequential`` but overrides ``forward``
    with a two-input signature, so it behaves like a plain ``nn.Module``. The
    base class is kept unchanged for compatibility with existing callers.
    """

    def __init__(self, skip_input, output_features):
        super().__init__()
        # Attribute names and creation order are part of the state_dict layout
        # (and of seeded weight initialisation), so they must stay as they are.
        self.convA = nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1)
        self.leakyreluA = nn.LeakyReLU(0.2)
        self.convB = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1)
        self.leakyreluB = nn.LeakyReLU(0.2)

    def forward(self, x, concat_with):
        """Return the refined fusion of upsampled ``x`` and ``concat_with``.

        Bilinear interpolation requires 4-D (N, C, H, W) inputs. The result has
        ``output_features`` channels and the spatial size of ``concat_with``.
        """
        target_hw = [concat_with.size(2), concat_with.size(3)]
        upsampled = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
        merged = torch.cat([upsampled, concat_with], dim=1)
        hidden = self.leakyreluA(self.convA(merged))
        return self.leakyreluB(self.convB(hidden))
class Decoder(nn.Module):
    """Upsampling decoder that fuses encoder skip connections into a 1-channel map.

    The deepest encoder map is projected by a 1x1 convolution to
    ``int(num_features * decoder_width)`` channels. Four ``UpSample`` stages then
    each resize to the next skip tensor's spatial size, concatenate it, and halve
    the decoder channel count. A final 3x3 convolution produces one channel.

    Args:
        num_features: channel count of the deepest encoder map (``features[11]``).
        decoder_width: fraction of ``num_features`` kept by the 1x1 projection.
        bottleneck_padding: padding of the 1x1 projection ``conv2``. The default
            ``0`` keeps the spatial size. The earlier hard-coded value was ``1``,
            which surrounds the map with a one-pixel border of bias-only values.
            Pass ``1`` only to reproduce checkpoints trained with that setting.
            Weight shapes do not depend on padding, so checkpoints load either way.

    The skip channel counts below (384, 192, 96, 96) are hard-coded. They must
    match the encoder outputs at ``features[8]``, ``features[6]``, ``features[4]``
    and ``features[3]``. NOTE(review): they appear to match the children of
    torchvision's DenseNet-161 ``features``; confirm before using another backbone.
    """

    def __init__(self, num_features=2208, decoder_width=0.5, bottleneck_padding=0):
        super(Decoder, self).__init__()
        features = int(num_features * decoder_width)
        # A 1x1 kernel needs no padding to preserve size. Padding it only adds a
        # border whose values are the bias alone; up1 would then interpolate that
        # border into the fused map.
        self.conv2 = nn.Conv2d(num_features, features, kernel_size=1, stride=1,
                               padding=bottleneck_padding)
        self.up1 = UpSample(skip_input=features // 1 + 384, output_features=features // 2)
        self.up2 = UpSample(skip_input=features // 2 + 192, output_features=features // 4)
        self.up3 = UpSample(skip_input=features // 4 + 96, output_features=features // 8)
        self.up4 = UpSample(skip_input=features // 8 + 96, output_features=features // 16)
        self.conv3 = nn.Conv2d(features // 16, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, features):
        """Decode an indexable sequence of encoder outputs into a 1-channel map.

        ``features`` is presumably the list returned by ``Encoder.forward``; only
        indices 3, 4, 6, 8 and 11 are read. The output has the spatial size of
        ``features[3]`` (up4 resizes to it, and conv3 preserves size).
        """
        # Skip connections, from the shallowest (x_block0) to the deepest (x_block4).
        x_block0 = features[3]
        x_block1 = features[4]
        x_block2 = features[6]
        x_block3 = features[8]
        x_block4 = features[11]

        x_d0 = self.conv2(x_block4)
        x_d1 = self.up1(x_d0, x_block3)
        x_d2 = self.up2(x_d1, x_block2)
        x_d3 = self.up3(x_d2, x_block1)
        x_d4 = self.up4(x_d3, x_block0)
        return self.conv3(x_d4)
class Encoder(nn.Module):
    """Pretrained DenseNet-161 feature stack that records every intermediate output.

    ``forward`` returns a list. Element 0 is the input itself; element ``i``
    (``i >= 1``) is the output of the i-th child of ``original_model.features``,
    in registration order. All intermediate activations are kept in this list
    because the decoder picks its skip connections out of it by index.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # Local import kept from the original: torchvision is only needed when an
        # Encoder is actually constructed.
        import torchvision.models as models
        weights_enum = getattr(models, 'DenseNet161_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
            # IMAGENET1K_V1 is the weight set the legacy ``pretrained=True`` selected.
            self.original_model = models.densenet161(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision without the weights-enum API.
            self.original_model = models.densenet161(pretrained=True)

    def forward(self, x):
        """Run ``x`` through each feature layer and return all outputs, input first."""
        features = [x]
        # Iterating an nn.Sequential yields its children in registration order,
        # without touching the private ``_modules`` dict.
        for layer in self.original_model.features:
            features.append(layer(features[-1]))
        return features
class FullModel(nn.Module):
    """End-to-end encoder-decoder network.

    The ``Encoder`` produces a list of intermediate feature maps. The ``Decoder``
    fuses selected entries of that list into a single-channel output map
    (presumably a depth map; confirm against the training code).
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, then decode the collected feature maps."""
        skip_features = self.encoder(x)
        return self.decoder(skip_features)