-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathdecoder.py
162 lines (139 loc) · 5.45 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Sequence, List
from segmentation_models_pytorch.base import modules as md
class UnetDecoderBlock(nn.Module):
    """A decoder block in the U-Net architecture that performs upsampling and feature fusion.

    The input is resized to a caller-given spatial size with ``F.interpolate``. If a
    skip connection is supplied, it is concatenated along dim 1 and passed through
    ``attention1``. The result then goes through two 3x3 ``md.Conv2dReLU`` layers,
    followed by ``attention2``.
    """

    def __init__(
        self,
        in_channels: int,
        skip_channels: int,
        out_channels: int,
        use_batchnorm: bool = True,
        attention_type: Optional[str] = None,
        interpolation_mode: str = "nearest",
    ):
        super().__init__()
        self.interpolation_mode = interpolation_mode

        # Channel count after concatenating the upsampled input with the skip tensor.
        fused_channels = in_channels + skip_channels

        # Submodules are created in the same order as before (conv1, attention1,
        # conv2, attention2) so registration order and parameter init are unchanged.
        self.conv1 = md.Conv2dReLU(
            fused_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention1 = md.Attention(attention_type, in_channels=fused_channels)
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(
        self,
        feature_map: torch.Tensor,
        target_height: int,
        target_width: int,
        skip_connection: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Resize ``feature_map`` to ``(target_height, target_width)`` and refine it.

        Args:
            feature_map: Tensor to resize; dims 2 and 3 are the spatial dims passed to
                ``F.interpolate``.
            target_height: Output height requested from ``F.interpolate``.
            target_width: Output width requested from ``F.interpolate``.
            skip_connection: Optional tensor concatenated to the resized map along
                dim 1; it must match the requested spatial size.

        Returns:
            The output of ``attention2`` applied after the two convolutions.
        """
        upsampled = F.interpolate(
            feature_map,
            size=(target_height, target_width),
            mode=self.interpolation_mode,
        )
        if skip_connection is not None:
            # Fuse along the channel axis; the first attention only sees fused maps.
            upsampled = self.attention1(torch.cat([upsampled, skip_connection], dim=1))
        refined = self.conv2(self.conv1(upsampled))
        return self.attention2(refined)
class UnetCenterBlock(nn.Sequential):
    """Center block of the Unet decoder. Applied to the last feature map of the encoder.

    A sequence of two 3x3, padding-1 ``md.Conv2dReLU`` layers mapping
    ``in_channels -> out_channels -> out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
        # Built in order (first conv, then second) so the Sequential indices "0"/"1"
        # and parameter initialisation order match a hand-written pair.
        layers = [
            md.Conv2dReLU(
                layer_in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                use_batchnorm=use_batchnorm,
            )
            for layer_in_channels in (in_channels, out_channels)
        ]
        super().__init__(*layers)
class UnetDecoder(nn.Module):
    """The decoder part of the U-Net architecture.

    Takes encoded features from different stages of the encoder and progressively upsamples them while
    combining with skip connections. This helps preserve fine-grained details in the final segmentation.

    Args:
        encoder_channels: Channel counts of the encoder feature maps, in the same order
            as the ``features`` list passed to ``forward`` (last entry = encoder head).
            The first entry is dropped; its feature map is not used as a skip.
        decoder_channels: Output channel count of each decoder block, in order of use.
        n_blocks: Number of decoder blocks; must equal ``len(decoder_channels)``.
        use_batchnorm: Forwarded to every ``UnetDecoderBlock`` and to the center block.
        attention_type: Forwarded to every ``UnetDecoderBlock``.
        add_center_block: If True, ``UnetCenterBlock`` is applied to the encoder head
            before decoding; otherwise the head passes through ``nn.Identity``.
        interpolation_mode: Upsampling mode forwarded to every ``UnetDecoderBlock``.

    Raises:
        ValueError: If ``n_blocks`` differs from ``len(decoder_channels)``.
    """

    def __init__(
        self,
        encoder_channels: Sequence[int],
        decoder_channels: Sequence[int],
        n_blocks: int = 5,
        use_batchnorm: bool = True,
        attention_type: Optional[str] = None,
        add_center_block: bool = False,
        interpolation_mode: str = "nearest",
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                f"Model depth is {n_blocks}, but you provide `decoder_channels` "
                f"for {len(decoder_channels)} blocks."
            )

        # Drop the first encoder stage: it has the same spatial resolution as the input
        # and is not used as a skip connection.
        encoder_channels = encoder_channels[1:]
        # Reverse so index 0 is the encoder head, i.e. decoding starts from the deepest stage.
        encoder_channels = encoder_channels[::-1]

        # Block i takes the previous block's output (the head for i == 0), fuses it with
        # the i-th remaining skip and emits decoder_channels[i]. The trailing 0 makes the
        # block after the last available skip be built without skip channels.
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        if add_center_block:
            self.center = UnetCenterBlock(
                head_channels, head_channels, use_batchnorm=use_batchnorm
            )
        else:
            self.center = nn.Identity()

        # One decoder block per (input, skip, output) channel triple; zip stops at the
        # shortest of the three lists.
        self.blocks = nn.ModuleList()
        for block_in_channels, block_skip_channels, block_out_channels in zip(
            in_channels, skip_channels, out_channels
        ):
            block = UnetDecoderBlock(
                block_in_channels,
                block_skip_channels,
                block_out_channels,
                use_batchnorm=use_batchnorm,
                attention_type=attention_type,
                interpolation_mode=interpolation_mode,
            )
            self.blocks.append(block)

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Decode a list of encoder feature maps.

        Args:
            features: Encoder outputs ordered from the highest spatial resolution
                (``features[0]``) to the encoder head (``features[-1]``). Dims 2 and 3 of
                each tensor are read as height and width.

        Returns:
            The output of the last decoder block. Block ``i`` is asked to produce the
            spatial size of ``features[-(i + 2)]``.

        Raises:
            ValueError: If fewer than ``len(self.blocks) + 1`` feature maps are given,
                in which case some block would have no target spatial size.
        """
        if len(features) < len(self.blocks) + 1:
            raise ValueError(
                "Expected at least {} feature maps (n_blocks + 1), got {}.".format(
                    len(self.blocks) + 1, len(features)
                )
            )

        # Spatial shapes of features are [hw, hw/2, hw/4, hw/8, ...]; reverse them so
        # index i + 1 is the target size of decoder block i.
        spatial_shapes = [feature.shape[2:] for feature in features]
        spatial_shapes = spatial_shapes[::-1]

        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse feature maps to start from head of encoder

        head = features[0]
        skip_connections = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            # Upsample to the next (larger) spatial shape.
            height, width = spatial_shapes[i + 1]
            # Blocks beyond the available skips decode without one.
            skip_connection = skip_connections[i] if i < len(skip_connections) else None
            x = decoder_block(x, height, width, skip_connection=skip_connection)
        return x