# The SegFormer code was heavily based on https://github.com/NVlabs/SegFormer
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/NVlabs/SegFormer#license

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np

from paddleseg.cvlibs import manager
from paddleseg.models import layers
from paddleseg.utils import utils

class MLP(nn.Layer):
    """
    Linear Embedding: flattens a [N, C, H, W] feature map into a token
    sequence and projects each token from C to embed_dim channels.
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        # [N, C, H, W] -> [N, C, H*W] -> [N, H*W, C]
        x = x.flatten(2).transpose([0, 2, 1])
        x = self.proj(x)
        return x
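
# A minimal shape walk-through of MLP (illustrative sketch, not part of the
# upstream module; the sizes below are assumed for the example):
#
#   mlp = MLP(input_dim=64, embed_dim=768)
#   feat = paddle.rand([2, 64, 32, 32])  # [N, C, H, W]
#   tokens = mlp(feat)                   # [2, 1024, 768], i.e. [N, H*W, embed_dim]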

@manager.MODELS.add_component
class SegFormer(nn.Layer):
    """
    The SegFormer implementation based on PaddlePaddle.

    The original article refers to
    Xie, Enze, et al. "SegFormer: Simple and Efficient Design for Semantic
    Segmentation with Transformers." arXiv preprint arXiv:2105.15203 (2021).

    Args:
        num_classes (int): The number of target classes.
        backbone (paddle.nn.Layer): A backbone network.
        embedding_dim (int): The MLP decoder channel dimension.
        align_corners (bool): An argument of F.interpolate. It should be set
            to False when the output feature size is even, e.g. 1024x512;
            otherwise it should be True, e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of pretrained model.
            Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 embedding_dim,
                 align_corners=False,
                 pretrained=None):
        super(SegFormer, self).__init__()
        self.pretrained = pretrained
        self.align_corners = align_corners
        self.backbone = backbone
        self.num_classes = num_classes
        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
            self.backbone.feat_channels

        # One linear embedding per backbone stage, all projecting to the
        # shared decoder dimension.
        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)

        self.dropout = nn.Dropout2D(0.1)
        self.linear_fuse = layers.ConvBNReLU(
            in_channels=embedding_dim * 4,
            out_channels=embedding_dim,
            kernel_size=1,
            bias_attr=False)
        self.linear_pred = nn.Conv2D(
            embedding_dim, self.num_classes, kernel_size=1)

        self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
    def forward(self, x):
        feats = self.backbone(x)
        c1, c2, c3, c4 = feats

        ############## MLP decoder on C1-C4 ###########
        c1_shape = paddle.shape(c1)
        c2_shape = paddle.shape(c2)
        c3_shape = paddle.shape(c3)
        c4_shape = paddle.shape(c4)

        # Project each stage to embedding_dim channels, reshape the token
        # sequence back to a feature map, and upsample to the C1 (highest)
        # resolution.
        _c4 = self.linear_c4(c4).transpose([0, 2, 1]).reshape(
            [0, 0, c4_shape[2], c4_shape[3]])
        _c4 = F.interpolate(
            _c4,
            size=c1_shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        _c3 = self.linear_c3(c3).transpose([0, 2, 1]).reshape(
            [0, 0, c3_shape[2], c3_shape[3]])
        _c3 = F.interpolate(
            _c3,
            size=c1_shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        _c2 = self.linear_c2(c2).transpose([0, 2, 1]).reshape(
            [0, 0, c2_shape[2], c2_shape[3]])
        _c2 = F.interpolate(
            _c2,
            size=c1_shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        _c1 = self.linear_c1(c1).transpose([0, 2, 1]).reshape(
            [0, 0, c1_shape[2], c1_shape[3]])

        # Fuse the four upsampled maps and predict per-pixel class logits.
        _c = self.linear_fuse(paddle.concat([_c4, _c3, _c2, _c1], axis=1))
        logit = self.dropout(_c)
        logit = self.linear_pred(logit)

        # Upsample logits back to the input resolution.
        return [
            F.interpolate(
                logit,
                size=paddle.shape(x)[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
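
# A minimal usage sketch, not part of the upstream module. It assumes
# MixVisionTransformer_B0, one of the backbones registered in
# paddleseg.models.backbones; any backbone that exposes `feat_channels`
# and returns four feature maps works the same way.
if __name__ == '__main__':
    from paddleseg.models.backbones import MixVisionTransformer_B0

    backbone = MixVisionTransformer_B0()
    model = SegFormer(num_classes=19, backbone=backbone, embedding_dim=256)
    x = paddle.rand([1, 3, 512, 512])
    logit = model(x)[0]  # [1, 19, 512, 512]: per-pixel logits at input size
    print(logit.shape)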