-
Notifications
You must be signed in to change notification settings - Fork 5
/
swin_transformer.py
787 lines (658 loc) · 31.9 KB
/
swin_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement Transformer Class for Swin Transformer V2
"""
from types import TracebackType
import paddle
from paddle.framework import dtype
import paddle.nn as nn
from droppath import DropPath
class Identity(nn.Layer):
    """Pass-through layer.

    Hands back whatever it receives. Handy as a placeholder so that a
    forward method does not need an ``if`` branch around an optional op.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
class PatchEmbedding(nn.Layer):
    """Image-to-patch-token embedding.

    A strided Conv2D (kernel == stride == patch size) projects every
    non-overlapping patch to ``embed_dim`` channels; the result is flattened
    into a token sequence and layer-normalised.

    Attributes:
        image_size: tuple (h, w) built from the int argument, default: 224
        patch_size: tuple (h, w) built from the int argument, default: 4
        patches_resolution: list, number of patches along height and width
        num_patches: int, total number of patches
        in_channels: int, input image channels, default: 3
        embed_dim: int, embedding dimension, default: 96
        patch_embed: nn.Conv2D performing the patch projection
        norm: nn.LayerNorm applied on the embedded tokens
    """

    def __init__(self, image_size=224, patch_size=4, in_channels=3, embed_dim=96):
        super().__init__()
        # TODO: add to_2tuple
        image_hw = (image_size, image_size)
        patch_hw = (patch_size, patch_size)
        grid = [side // step for side, step in zip(image_hw, patch_hw)]

        self.image_size = image_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        self.patch_embed = nn.Conv2D(in_channels=in_channels,
                                     out_channels=embed_dim,
                                     kernel_size=patch_hw,
                                     stride=patch_hw)

        ln_weight, ln_bias = self._init_weights_layernorm()
        self.norm = nn.LayerNorm(embed_dim,
                                 weight_attr=ln_weight,
                                 bias_attr=ln_bias)

    def _init_weights_layernorm(self):
        """Return (weight_attr, bias_attr) initialising LayerNorm to scale 1, shift 0."""
        ones = paddle.nn.initializer.Constant(1)
        zeros = paddle.nn.initializer.Constant(0)
        return paddle.ParamAttr(initializer=ones), paddle.ParamAttr(initializer=zeros)

    def forward(self, x):
        """Embed an image batch into normalised patch tokens.

        Shapes (per the conv/flatten/transpose sequence below):
        [batch, embed_dim, h, w] -> [batch, embed_dim, h*w] -> [batch, h*w, embed_dim].
        """
        feature_map = self.patch_embed(x)
        tokens = feature_map.flatten(start_axis=2, stop_axis=-1).transpose([0, 2, 1])
        return self.norm(tokens)
class PatchMerging(nn.Layer):
    """Patch merging (spatial down-sampling) layer.

    Every 2x2 neighbourhood of patches (each of dim C) is concatenated into a
    single 4*C vector, layer-normalised, and then linearly projected to 2*C.

    Attributes:
        input_resolution: tuple of ints, (height, width) of the incoming token grid
        dim: int, channel dimension C of a single incoming patch
        reduction: nn.Linear (no bias) mapping 4C to 2C
        norm: nn.LayerNorm over 4C, applied before the linear reduction
    """

    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim

        linear_weight, _ = self._init_weights()
        self.reduction = nn.Linear(4 * dim,
                                   2 * dim,
                                   weight_attr=linear_weight,
                                   bias_attr=False)

        ln_weight, ln_bias = self._init_weights_layernorm()
        self.norm = nn.LayerNorm(4 * dim,
                                 weight_attr=ln_weight,
                                 bias_attr=ln_bias)

    def _init_weights_layernorm(self):
        """Return (weight_attr, bias_attr) initialising LayerNorm to scale 1, shift 0."""
        ones = paddle.nn.initializer.Constant(1)
        zeros = paddle.nn.initializer.Constant(0)
        return paddle.ParamAttr(initializer=ones), paddle.ParamAttr(initializer=zeros)

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated-normal(std=0.02) weight, zero bias."""
        trunc = paddle.nn.initializer.TruncatedNormal(std=.02)
        zeros = paddle.nn.initializer.Constant(0)
        return paddle.ParamAttr(initializer=trunc), paddle.ParamAttr(initializer=zeros)

    def forward(self, x):
        """Merge 2x2 patch groups: [B, H*W, C] -> [B, H/2*W/2, 2*C]."""
        height, width = self.input_resolution
        batch, _, channels = x.shape
        grid = x.reshape([batch, height, width, channels])

        # Concatenation order: (even row, even col), (odd row, even col),
        # (even row, odd col), (odd row, odd col); each is [B, H/2, W/2, C].
        corners = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = paddle.concat(corners, -1)                  # [B, H/2, W/2, 4*C]
        merged = merged.reshape([batch, -1, 4 * channels])   # [B, H/2*W/2, 4*C]

        return self.reduction(self.norm(merged))
class Mlp(nn.Layer):
    """Two-layer feed-forward module with GELU.

    Ops: fc -> act -> dropout -> fc -> dropout. The output width equals the
    input width, and the same dropout layer is used at both positions.

    Attributes:
        fc1: nn.Linear, in_features -> hidden_features
        fc2: nn.Linear, hidden_features -> in_features
        act: GELU
        dropout: nn.Dropout applied after the activation and after fc2
    """

    def __init__(self, in_features, hidden_features, dropout):
        super().__init__()
        weight_a, bias_a = self._init_weights()
        self.fc1 = nn.Linear(in_features,
                             hidden_features,
                             weight_attr=weight_a,
                             bias_attr=bias_a)

        weight_b, bias_b = self._init_weights()
        self.fc2 = nn.Linear(hidden_features,
                             in_features,
                             weight_attr=weight_b,
                             bias_attr=bias_b)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated-normal(std=0.02) weight, zero bias."""
        trunc = paddle.nn.initializer.TruncatedNormal(std=.02)
        zeros = paddle.nn.initializer.Constant(0)
        return paddle.ParamAttr(initializer=trunc), paddle.ParamAttr(initializer=zeros)

    def forward(self, x):
        """Apply fc1 -> GELU -> dropout -> fc2 -> dropout."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class Mlp_Relu(nn.Layer):
    """Two-layer feed-forward module with ReLU and a free output width.

    Ops: fc -> act -> dropout -> fc -> dropout. The same dropout layer is
    used at both positions.

    Attributes:
        fc1: nn.Linear, in_features -> hidden_features
        fc2: nn.Linear, hidden_features -> out_features
        act: ReLU
        dropout: nn.Dropout applied after the activation and after fc2
    """

    def __init__(self, in_features, hidden_features, out_features, dropout):
        super().__init__()
        weight_a, bias_a = self._init_weights()
        self.fc1 = nn.Linear(in_features,
                             hidden_features,
                             weight_attr=weight_a,
                             bias_attr=bias_a)

        weight_b, bias_b = self._init_weights()
        self.fc2 = nn.Linear(hidden_features,
                             out_features,
                             weight_attr=weight_b,
                             bias_attr=bias_b)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated-normal(std=0.02) weight, zero bias."""
        trunc = paddle.nn.initializer.TruncatedNormal(std=.02)
        zeros = paddle.nn.initializer.Constant(0)
        return paddle.ParamAttr(initializer=trunc), paddle.ParamAttr(initializer=zeros)

    def forward(self, x):
        """Apply fc1 -> ReLU -> dropout -> fc2 -> dropout."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class WindowAttention(nn.Layer):
    """Window based multi-head attention (Swin V2 flavour).

    Uses scaled cosine attention with a learnable temperature ``tau`` and a
    continuous relative position bias produced by a small MLP (``cpb``) from
    log-spaced relative coordinates. Works for both shifted and non-shifted
    windows; the shift is expressed through the optional ``mask``.

    Attributes:
        dim: int, input dimension (channels)
        window_size: sequence of two ints, height and width of the window
        num_heads: int, number of attention heads
        dim_head: int, dim // num_heads
        scale: float, qk_scale or dim_head**-0.5 (stored; not read by forward)
        relative_position_bias_table: v1 bias table parameter (not read by forward)
        log_relative_position_index: buffer, log-spaced relative coordinates
        cpb: Mlp_Relu, meta network mapping (dx, dy) to one bias per head
        qkv / proj: nn.Linear projections
        attn_dropout / proj_dropout: nn.Dropout layers
        tau: learnable temperature, shape [num_heads, N, N] with N = window_h * window_w
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attention_dropout=0.,
                 dropout=0.):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.dim = dim
        self.dim_head = dim // num_heads
        self.scale = qk_scale or self.dim_head ** -0.5

        win_h, win_w = window_size[0], window_size[1]
        tokens_per_window = win_h * win_w

        self.relative_position_bias_table = paddle.create_parameter(
            shape=[(2 * win_h - 1) * (2 * win_w - 1), num_heads],
            dtype='float32',
            default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))

        # Pairwise coordinate offsets between all tokens inside one window.
        axis_coords = [paddle.arange(win_h), paddle.arange(win_w)]
        grid = paddle.stack(paddle.meshgrid(axis_coords))   # [2, window_h, window_w]
        flat = paddle.flatten(grid, 1)                      # [2, window_h * window_w]
        offsets = flat.unsqueeze(2) - flat.unsqueeze(1)     # [2, N, N]
        offsets = offsets.transpose([1, 2, 0])              # [N, N, 2]

        # Swin-T v1 built an integer "relative_position_index" buffer from the
        # offsets here; v2 uses log-spaced coordinates instead, Eq.(4):
        # sign(d) * log(1 + |d|).
        offsets = offsets.cast(dtype='float32')
        log_relative_position_index = paddle.multiply(
            offsets.sign(), paddle.log(offsets.abs() + 1))
        self.register_buffer("log_relative_position_index", log_relative_position_index)

        # Swin-T v2, small meta network, Eq.(3)
        self.cpb = Mlp_Relu(in_features=2,        # delta x, delta y
                            hidden_features=512,  # TODO: hidden dims
                            out_features=self.num_heads,
                            dropout=dropout)

        qkv_weight, qkv_bias_attr = self._init_weights()
        self.qkv = nn.Linear(dim,
                             dim * 3,
                             weight_attr=qkv_weight,
                             bias_attr=qkv_bias_attr if qkv_bias else False)
        self.attn_dropout = nn.Dropout(attention_dropout)

        proj_weight, proj_bias = self._init_weights()
        self.proj = nn.Linear(dim,
                              dim,
                              weight_attr=proj_weight,
                              bias_attr=proj_bias)
        self.proj_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(axis=-1)

        # Swin-T v2, scaled cosine attention temperature.
        self.tau = paddle.create_parameter(
            shape=[num_heads, tokens_per_window, tokens_per_window],
            dtype='float32',
            default_initializer=paddle.nn.initializer.Constant(1))

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated-normal(std=0.02) weight, zero bias."""
        trunc = paddle.nn.initializer.TruncatedNormal(std=.02)
        zeros = paddle.nn.initializer.Constant(0)
        return paddle.ParamAttr(initializer=trunc), paddle.ParamAttr(initializer=zeros)

    def transpose_multihead(self, x):
        """Split the last axis into heads: [b, n, dim] -> [b, num_heads, n, dim_head]."""
        head_shape = x.shape[:-1] + [self.num_heads, self.dim_head]
        return x.reshape(head_shape).transpose([0, 2, 1, 3])

    def get_relative_pos_bias_from_pos_index(self):
        """Swin v1 table lookup of the relative position bias.

        NOTE(review): this reads ``self.relative_position_index``, which this
        class does not register (the v1 buffer setup is disabled), so calling
        it as-is fails on the missing attribute. ``forward`` does not use it.
        """
        # relative_position_bias_table is a ParamBase object
        # https://github.com/PaddlePaddle/Paddle/blob/067f558c59b34dd6d8626aad73e9943cf7f5960f/python/paddle/fluid/framework.py#L5727
        bias_table = self.relative_position_bias_table  # N x num_heads
        # flattened index tensor: window_h*window_w * window_h*window_w
        flat_index = self.relative_position_index.reshape([-1])
        # NOTE: paddle does NOT support indexing Tensor by a Tensor
        return paddle.index_select(x=bias_table, index=flat_index)

    def get_continuous_relative_position_bias(self):
        """Run the meta network on the log-spaced relative coordinates."""
        return self.cpb(self.log_relative_position_index)

    def forward(self, x, mask=None):
        """Attend within each window.

        Args:
            x: Tensor, [batch*num_windows, tokens_per_window, dim]
            mask: optional Tensor whose first axis is the number of windows;
                added to the attention logits before softmax.
        Returns:
            Tensor with the same shape as ``x``.
        """
        q, k, v = (self.transpose_multihead(part)
                   for part in self.qkv(x).chunk(3, axis=-1))

        # Swin-T v2, scaled cosine attention: q.k / max(|q||k|, 1e-6) / max(tau, 0.01)
        dots = paddle.matmul(q, k, transpose_y=True)
        q_norm = paddle.multiply(q, q).sum(-1).sqrt().unsqueeze(3)
        k_norm = paddle.multiply(k, k).sum(-1).sqrt().unsqueeze(3)
        norm_product = paddle.matmul(q_norm, k_norm, transpose_y=True)
        attn = dots / paddle.clip(norm_product, min=1e-6)
        attn = attn / paddle.clip(self.tau.unsqueeze(0), min=0.01)

        # Swin-T v2 continuous bias: [N, N, num_heads] -> [num_heads, N, N]
        tokens_per_window = self.window_size[0] * self.window_size[1]
        bias = self.get_continuous_relative_position_bias()
        bias = bias.reshape([tokens_per_window, tokens_per_window, -1])
        bias = bias.transpose([2, 0, 1])
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            # Expose the window axis so the per-window mask broadcasts over
            # batch and heads, then fold it back.
            num_windows = mask.shape[0]
            seq_len = x.shape[1]
            attn = attn.reshape(
                [x.shape[0] // num_windows, num_windows, self.num_heads, seq_len, seq_len])
            attn += mask.unsqueeze(1).unsqueeze(0)
            attn = attn.reshape([-1, self.num_heads, seq_len, seq_len])

        attn = self.attn_dropout(self.softmax(attn))

        out = paddle.matmul(attn, v)           # [b, num_heads, n, dim_head]
        out = out.transpose([0, 2, 1, 3])      # [b, n, num_heads, dim_head]
        out = out.reshape(out.shape[:-2] + [self.dim])
        return self.proj_dropout(self.proj(out))
def windows_partition(x, window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: Tensor, shape=[b, h, w, c]
        window_size: int, side length of each window
    Returns:
        Tensor, shape=[num_windows*b, window_size, window_size, c]
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    # [b, rows, ws, cols, ws, c] -> [b, rows, cols, ws, ws, c]
    tiles = x.reshape([batch, rows, window_size, cols, window_size, channels])
    tiles = tiles.transpose([0, 1, 3, 2, 4, 5])
    return tiles.reshape([-1, window_size, window_size, channels])
def windows_reverse(windows, window_size, H, W):
    """Stitch windows back into a full feature map.

    Args:
        windows: (n_windows * B, window_size, window_size, C)
        window_size: (int) window size
        H: (int) height of image
        W: (int) width of image
    Returns:
        x: (B, H, W, C)
    """
    rows = H // window_size
    cols = W // window_size
    # Integer arithmetic keeps the batch size exact; going through float
    # division and int() truncation can be off by one on rounding error.
    B = windows.shape[0] // (rows * cols)
    # [B, rows, cols, ws, ws, C] -> [B, rows, ws, cols, ws, C]
    x = windows.reshape([B, rows, cols, window_size, window_size, -1])
    x = x.transpose([0, 1, 3, 2, 4, 5])
    return x.reshape([B, H, W, -1])
class SwinTransformerBlock(nn.Layer):
    """Swin transformer block (V2, post-norm).

    Window multi-head self attention and an MLP, each followed by LayerNorm
    and joined to the main branch through a residual connection with
    optional drop path. An extra LayerNorm on the main branch can be enabled.

    Attributes:
        dim: int, input dimension (channels)
        input_resolution: tuple of ints, (height, width) of the token grid
        num_heads: int, number of attention heads
        window_size: int, window size, default: 7 (shrunk to the smallest
            input side when the input is not larger than the window)
        shift_size: int, shift size for SW-MSA, default: 0
        mlp_ratio: float, ratio of mlp hidden dim and input embedding dim, default: 4.
        qkv_bias: bool, if True, enable learnable bias to q,k,v, default: True
        qk_scale: float, forwarded to WindowAttention, default: None
        dropout: float, dropout for output, default: 0.
        extra_norm: bool, if True apply ``norm3`` at the end of forward, default: False
        attention_dropout: float, dropout of attention, default: 0.
        droppath: float, drop path rate, default: 0.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, dropout=0., extra_norm=False,
                 attention_dropout=0., droppath=0.):
        super().__init__()
        self.dim = dim
        # Swin-T v2, introduce a LN unit on the main branch every 6 layers
        self.extra_norm = extra_norm
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        # A window that already covers the input needs no shift and is
        # clamped to the smallest input side.
        shortest_side = min(self.input_resolution)
        if shortest_side <= self.window_size:
            self.shift_size = 0
            self.window_size = shortest_side

        ln1_weight, ln1_bias = self._init_weights_layernorm()
        self.norm1 = nn.LayerNorm(dim,
                                  weight_attr=ln1_weight,
                                  bias_attr=ln1_bias)
        self.attn = WindowAttention(dim,
                                    window_size=(self.window_size, self.window_size),
                                    num_heads=num_heads,
                                    qkv_bias=qkv_bias,
                                    qk_scale=qk_scale,
                                    attention_dropout=attention_dropout,
                                    dropout=dropout)
        if droppath > 0.:
            self.drop_path = DropPath(droppath)
        else:
            self.drop_path = None

        ln2_weight, ln2_bias = self._init_weights_layernorm()
        self.norm2 = nn.LayerNorm(dim,
                                  weight_attr=ln2_weight,
                                  bias_attr=ln2_bias)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       dropout=dropout)

        if extra_norm:
            # Swin-T v2, introduce a LN unit on the main branch every 6 layers
            ln3_weight, ln3_bias = self._init_weights_layernorm()
            self.norm3 = nn.LayerNorm(dim,
                                      weight_attr=ln3_weight,
                                      bias_attr=ln3_bias)

        self.register_buffer("attn_mask", self._build_attn_mask())

    def _build_attn_mask(self):
        """Build the SW-MSA attention mask, or None when no shift is used.

        Each of the 3x3 regions produced by the cyclic shift gets its own id;
        token pairs inside a window whose ids differ receive -100, all other
        pairs receive 0.
        """
        if self.shift_size <= 0:
            return None

        height, width = self.input_resolution
        region_map = paddle.zeros((1, height, width, 1))
        bands = (slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
                 slice(-self.shift_size, None))
        region_id = 0
        for row_band in bands:
            for col_band in bands:
                region_map[:, row_band, col_band, :] = region_id
                region_id += 1

        region_windows = windows_partition(region_map, self.window_size)
        region_windows = region_windows.reshape((-1, self.window_size * self.window_size))
        id_diff = region_windows.unsqueeze(1) - region_windows.unsqueeze(2)
        attn_mask = paddle.where(id_diff != 0,
                                 paddle.ones_like(id_diff) * float(-100.0),
                                 id_diff)
        attn_mask = paddle.where(attn_mask == 0,
                                 paddle.zeros_like(attn_mask),
                                 attn_mask)
        return attn_mask

    def _init_weights_layernorm(self):
        """Return (weight_attr, bias_attr) initialising LayerNorm to scale 1, shift 0."""
        ones = paddle.nn.initializer.Constant(1)
        zeros = paddle.nn.initializer.Constant(0)
        return paddle.ParamAttr(initializer=ones), paddle.ParamAttr(initializer=zeros)

    def _residual(self, shortcut, branch):
        """Add ``branch`` to ``shortcut``, through drop path when it is configured."""
        if self.drop_path is None:
            return shortcut + branch
        return shortcut + self.drop_path(branch)

    def forward(self, x):
        """Run attention and MLP sub-blocks on ``x`` of shape [bs, H*W, C]."""
        H, W = self.input_resolution
        B, L, C = x.shape
        shift = self.shift_size
        ws = self.window_size

        shortcut = x
        # Swin-T v1 applied norm1 here (pre-norm); v2 normalises after attention.
        feat = x.reshape([B, H, W, C])                                  # [bs,H,W,C]
        if shift > 0:
            # cyclic shift
            feat = paddle.roll(feat, shifts=(-shift, -shift), axis=(1, 2))

        tiles = windows_partition(feat, ws)                             # [bs*num_windows,ws,ws,C]
        tiles = tiles.reshape([-1, ws * ws, C])                         # [bs*num_windows,ws*ws,C]
        tiles = self.attn(tiles, mask=self.attn_mask)                   # [bs*num_windows,ws*ws,C]
        tiles = tiles.reshape([-1, ws, ws, C])                          # [bs*num_windows,ws,ws,C]
        feat = windows_reverse(tiles, ws, H, W)                         # [bs,H,W,C]

        if shift > 0:
            # reverse cyclic shift
            feat = paddle.roll(feat, shifts=(shift, shift), axis=(1, 2))

        feat = feat.reshape([B, H * W, C])                              # [bs,H*W,C]
        feat = self.norm1(feat)                                         # Swin-T v2, post-norm
        x = self._residual(shortcut, feat)

        shortcut = x                                                    # [bs,H*W,C]
        # Swin-T v1 applied norm2 here (pre-norm); v2 normalises after the MLP.
        feat = self.norm2(self.mlp(x))                                  # Swin-T v2, post-norm
        x = self._residual(shortcut, feat)

        if self.extra_norm:  # Swin-T v2
            x = self.norm3(x)
        return x
class SwinTransformerStage(nn.Layer):
    """Stage layers for swin transformer

    A stage holds ``depth`` SwinTransformerBlocks, alternating between
    non-shifted (even index) and shifted (odd index) windows, followed by an
    optional patch merging layer. The model passes no patch merging layer
    for its last stage.

    Args:
        dim: int, embedding dimension
        input_resolution: tuple, input resolution
        depth: int, number of blocks in this stage
        num_heads: int, number of attention heads
        window_size: int, window size
        mlp_ratio, qkv_bias, qk_scale, dropout, attention_dropout: forwarded
            unchanged to every block
        droppath: float applied to all blocks, or a list/tuple with one rate
            per block
        downsample: callable ``(input_resolution, dim=...)`` building the
            merging layer, or None
        sum_depth: int or None, number of blocks in all previous stages;
            when given, a block gets an extra LayerNorm if its 1-based global
            index is a multiple of 6 (Swin-T v2)

    Attributes:
        dim: int, embedding dimension
        input_resolution: tuple, input resolution
        depth: int, number of blocks in this stage
        blocks: nn.LayerList, contains SwinTransformerBlocks for one stage
        downsample: PatchMerging, patch merging layer, None if not requested
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, dropout=0.,
                 attention_dropout=0., droppath=0., downsample=None, sum_depth=None):
        super(SwinTransformerStage, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth

        # A list or tuple is a per-block schedule; a scalar applies to all blocks.
        per_block_droppath = isinstance(droppath, (list, tuple))

        self.blocks = nn.LayerList()
        for i in range(depth):
            # Swin-T v2: extra LN on every 6th block, counted across stages.
            extra_norm = sum_depth is not None and (i + sum_depth + 1) % 6 == 0
            self.blocks.append(
                SwinTransformerBlock(
                    dim=dim, input_resolution=input_resolution,
                    num_heads=num_heads, window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    extra_norm=extra_norm,
                    qkv_bias=qkv_bias, qk_scale=qk_scale,
                    dropout=dropout, attention_dropout=attention_dropout,
                    droppath=droppath[i] if per_block_droppath else droppath))

        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim)
        else:
            self.downsample = None

    def forward(self, x):
        """Run all blocks, then the patch merging layer when present."""
        for block in self.blocks:
            x = block(x)  # [bs,56*56,96] -> [bs,28*28,96*2] -> [bs,14*14,96*4] -> [bs,7*7,96*8]
        if self.downsample is not None:
            x = self.downsample(x)  # [bs,28*28,96*2] -> [bs,14*14,96*4] -> [bs,7*7,96*8]
        return x
class SwinTransformer(nn.Layer):
    """SwinTransformer class

    Patch embedding, an optional absolute position embedding, a sequence of
    SwinTransformerStages, a final LayerNorm, average pooling over tokens and
    a linear classifier.

    Attributes:
        num_classes: int, num of image classes
        num_stages: int, num of stages, equal to len(depths)
        embed_dim: int, output dimension of patch embedding
        num_features: int, output dimension of whole network before classifier
        mlp_ratio: float, hidden dimension of mlp layer is mlp_ratio * mlp input dim
        ape: bool, if True, set to use absolute positional embeddings
        patch_embedding: PatchEmbedding, patch embedding instance
        patches_resolution: list, number of patches in row and column
        absolute_positional_embedding: ParameterList holding one
            [1, num_patches, embed_dim] parameter (only when ape is True)
        position_dropout: nn.Dropout, dropout op for position embedding
        stages: nn.LayerList of SwinTransformerStage instances
        norm: nn.LayerNorm, norm layer applied after transformer
        avgpool: nn.AdaptiveAvgPool1D, pooling layer before classifier
        fc: nn.Linear, classifier op.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=4,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.,
                 ape=False,
                 extra_norm=False):
        super(SwinTransformer, self).__init__()
        # NOTE: the list defaults for `depths` / `num_heads` are only read,
        # never mutated, so sharing them across calls is harmless.
        self.num_classes = num_classes
        self.num_stages = len(depths)
        self.embed_dim = embed_dim
        # The channel width doubles at each stage transition.
        self.num_features = int(self.embed_dim * 2 ** (self.num_stages - 1))
        self.mlp_ratio = mlp_ratio
        self.ape = ape

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              in_channels=in_channels,
                                              embed_dim=embed_dim)
        num_patches = self.patch_embedding.num_patches
        self.patches_resolution = self.patch_embedding.patches_resolution

        if self.ape:
            self.absolute_positional_embedding = paddle.nn.ParameterList([
                paddle.create_parameter(
                    shape=[1, num_patches, self.embed_dim], dtype='float32',
                    default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))])

        self.position_dropout = nn.Dropout(dropout)

        # Drop path rate grows linearly from 0 to `droppath` over all blocks.
        depth_decay = [x.item() for x in paddle.linspace(0, droppath, sum(depths))]

        self.stages = nn.LayerList()
        for stage_idx in range(self.num_stages):
            blocks_before = sum(depths[:stage_idx])
            blocks_through = sum(depths[:stage_idx + 1])
            is_last_stage = stage_idx == self.num_stages - 1
            stage = SwinTransformerStage(
                dim=int(self.embed_dim * 2 ** stage_idx),
                input_resolution=(
                    self.patches_resolution[0] // (2 ** stage_idx),
                    self.patches_resolution[1] // (2 ** stage_idx)),
                depth=depths[stage_idx],
                sum_depth=blocks_before if extra_norm else None,  # Swin-T v2
                num_heads=num_heads[stage_idx],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                dropout=dropout,
                attention_dropout=attention_dropout,
                droppath=depth_decay[blocks_before:blocks_through],
                downsample=None if is_last_stage else PatchMerging,
            )
            self.stages.append(stage)

        w_attr_1, b_attr_1 = self._init_weights_layernorm()
        self.norm = nn.LayerNorm(self.num_features,
                                 weight_attr=w_attr_1,
                                 bias_attr=b_attr_1)
        self.avgpool = nn.AdaptiveAvgPool1D(1)
        w_attr_2, b_attr_2 = self._init_weights()
        self.fc = nn.Linear(self.num_features,
                            self.num_classes,
                            weight_attr=w_attr_2,
                            bias_attr=b_attr_2)

    def _init_weights_layernorm(self):
        """Return (weight_attr, bias_attr) initialising LayerNorm to scale 1, shift 0."""
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated-normal(std=0.02) weight, zero bias."""
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def forward_features(self, x):
        """Map an image batch to pooled features of shape [bs, num_features]."""
        x = self.patch_embedding(x)  # [bs,H*W,96]
        if self.ape:
            # The embedding lives inside a ParameterList; add the parameter
            # itself (element 0), not the container.
            x = x + self.absolute_positional_embedding[0]
        x = self.position_dropout(x)  # [bs,H*W,96]
        for stage in self.stages:
            x = stage(x)  # [bs,784,192],[bs,196,384],[bs,49,768],[bs,49,768]
        x = self.norm(x)  # [bs,49,768]
        x = x.transpose([0, 2, 1])
        x = self.avgpool(x)  # [bs,768,1]
        x = x.flatten(1)  # [bs,768]
        return x

    def forward(self, x):
        """Return classification logits of shape [bs, num_classes]."""
        x = self.forward_features(x)  # [bs,768]
        x = self.fc(x)  # [bs,1000]
        return x
def build_swin(config):
    """Build a SwinTransformer from a config object.

    Reads ``config.DATA.IMAGE_SIZE``, ``config.MODEL.*`` and
    ``config.MODEL.TRANS.*`` and forwards them as constructor arguments.
    """
    model_cfg = config.MODEL
    trans_cfg = model_cfg.TRANS
    return SwinTransformer(
        image_size=config.DATA.IMAGE_SIZE,
        patch_size=trans_cfg.PATCH_SIZE,
        in_channels=trans_cfg.IN_CHANNELS,
        embed_dim=trans_cfg.EMBED_DIM,
        num_classes=model_cfg.NUM_CLASSES,
        depths=trans_cfg.STAGE_DEPTHS,
        num_heads=trans_cfg.NUM_HEADS,
        mlp_ratio=trans_cfg.MLP_RATIO,
        qkv_bias=trans_cfg.QKV_BIAS,
        qk_scale=trans_cfg.QK_SCALE,
        ape=trans_cfg.APE,
        window_size=trans_cfg.WINDOW_SIZE,
        dropout=model_cfg.DROPOUT,
        attention_dropout=model_cfg.ATTENTION_DROPOUT,
        droppath=model_cfg.DROP_PATH,
        extra_norm=trans_cfg.EXTRA_NORM)
if __name__ == '__main__':
    # Smoke test: build the model from the project config and push one
    # random 224x224 RGB image through it.
    from main_single_gpu import get_arguments
    from config import get_config
    from config import update_config

    config = update_config(get_config(), get_arguments())
    model = build_swin(config)
    output = model(paddle.randn([1, 3, 224, 224]))
    print(output.shape)