-
Notifications
You must be signed in to change notification settings - Fork 18
/
simba.py
697 lines (586 loc) · 26.1 KB
/
simba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.fft
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math
import numpy as np
from mamba_ssm import Mamba
from einops import rearrange, repeat, einsum
class EinFFT(nn.Module):
    """Frequency-domain channel mixer built from block-wise complex linear layers.

    The channel axis is split into ``num_blocks`` blocks, a 2-D FFT is taken
    over the token axis and the block axis, two complex block-wise linear
    layers are applied (ReLU on the real and imaginary parts after the first),
    the result is soft-thresholded and transformed back with the inverse FFT.
    Only the real part of the inverse transform is returned.

    Args:
        dim: Channel width of the input; must be divisible by ``num_blocks``
            (fixed to 4 here).
    """

    def __init__(self, dim):
        super().__init__()
        self.hidden_size = dim  # 768
        self.num_blocks = 4
        self.block_size = self.hidden_size // self.num_blocks
        assert self.hidden_size % self.num_blocks == 0
        self.sparsity_threshold = 0.01
        self.scale = 0.02

        # Leading axis of every parameter: index 0 is the real part,
        # index 1 is the imaginary part (see how forward() combines them).
        self.complex_weight_1 = nn.Parameter(
            torch.randn(2, self.num_blocks, self.block_size, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_weight_2 = nn.Parameter(
            torch.randn(2, self.num_blocks, self.block_size, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_bias_1 = nn.Parameter(
            torch.randn(2, self.num_blocks, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_bias_2 = nn.Parameter(
            torch.randn(2, self.num_blocks, self.block_size, dtype=torch.float32) * self.scale)

    def multiply(self, input, weights):
        """Apply a separate ``(block_size, k)`` matrix to each channel block."""
        return torch.einsum('...bd,bdk->...bk', input, weights)

    def _complex_linear(self, real, imag, weight, bias):
        """Complex affine map written with real arithmetic: (W_r + iW_i)(x_r + ix_i) + b."""
        out_real = self.multiply(real, weight[0]) - self.multiply(imag, weight[1]) + bias[0]
        out_imag = self.multiply(real, weight[1]) + self.multiply(imag, weight[0]) + bias[1]
        return out_real, out_imag

    def forward(self, x, H, W):
        """Mix channels of ``x`` in the frequency domain.

        Args:
            x: Real tensor of shape ``(B, N, C)`` with ``C == dim``.
            H, W: Unused; accepted so the module is call-compatible with the
                other channel mixers in this file.

        Returns:
            Real ``float32`` tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        x = x.view(B, N, self.num_blocks, self.block_size)
        # 2-D FFT over the token axis (dim 1) and the block axis (dim 2).
        x = torch.fft.fft2(x, dim=(1, 2), norm='ortho')

        real_1, imag_1 = self._complex_linear(x.real, x.imag, self.complex_weight_1, self.complex_bias_1)
        real_1, imag_1 = F.relu(real_1), F.relu(imag_1)
        real_2, imag_2 = self._complex_linear(real_1, imag_1, self.complex_weight_2, self.complex_bias_2)

        x = torch.stack([real_2, imag_2], dim=-1).float()
        x = F.softshrink(x, lambd=self.sparsity_threshold) if self.sparsity_threshold else x
        x = torch.view_as_complex(x)
        x = torch.fft.ifft2(x, dim=(1, 2), norm="ortho")

        # Later layers need a real tensor (the original note here was:
        # RuntimeError: "fused_dropout" not implemented for 'ComplexFloat').
        # Take the real part explicitly: casting a complex tensor with
        # ``.to(torch.float32)`` yields the same values but emits a
        # "Casting complex values to real" warning.
        x = x.real
        x = x.reshape(B, N, C)
        return x
# For a fast implementation use MambaLayer. The implementation below is slow and is intended only for checking GFLOPs and other parameters.
# For more details please refer to https://github.com/johnma2006/mamba-minimal/blob/master/model.py
class MambaBlock(nn.Module):
    """Pure-PyTorch Mamba block with a sequential (step-by-step) selective scan.

    Structure: input projection into a conv/SSM branch and a gating branch,
    depth-wise 1-D convolution + SiLU on the first branch, selective state
    space scan, multiplicative SiLU gate, output projection.

    Args:
        d_model: Width of the input and output features.
        d_state: Size of the SSM state per inner channel.
        expand: Multiplier giving the inner width ``d_inner = expand * d_model``.
        d_conv: Kernel size of the depth-wise 1-D convolution.
        conv_bias: Whether the convolution has a bias.
        bias: Whether the input/output projections have a bias.
    """

    def __init__(self, d_model, d_state=64, expand=2, d_conv=4, conv_bias=True, bias=False):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.conv_bias = conv_bias
        self.bias = bias
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16)

        # One projection produces both the conv/SSM branch and the gate branch.
        self.in_proj = nn.Linear(self.d_model, 2 * self.d_inner, bias=self.bias)

        # Depth-wise convolution; padding of (kernel - 1) on both sides, the
        # surplus on the right is cut off again in forward().
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            self.d_conv,
            padding=self.d_conv - 1,
            groups=self.d_inner,
            bias=self.conv_bias,
        )

        # Produces the input-dependent delta (low rank), B and C.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        # Lifts delta from dt_rank up to d_inner.
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A is stored as log(1..d_state), repeated for every inner channel.
        state_index = torch.arange(1, self.d_state + 1)
        self.A_log = nn.Parameter(torch.log(repeat(state_index, 'n -> d n', d=self.d_inner)))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=self.bias)

    def forward(self, x):
        """Run the block.

        Args:
            x: Tensor of shape ``(b, l, d_model)``.

        Returns:
            Tensor of shape ``(b, l, d_model)``.
        """
        (_, seq_len, _) = x.shape

        branch, gate = self.in_proj(x).split(split_size=[self.d_inner, self.d_inner], dim=-1)

        # Conv1d wants channels first; keep only the first seq_len outputs.
        branch = rearrange(branch, 'b l d_in -> b d_in l')
        branch = self.conv1d(branch)[:, :, :seq_len]
        branch = F.silu(rearrange(branch, 'b d_in l -> b l d_in'))

        gated = self.ssm(branch) * F.silu(gate)
        return self.out_proj(gated)

    def ssm(self, x):
        """Derive the state-space parameters from ``x`` and run the scan.

        ``A`` and ``D`` are learned parameters independent of the input;
        ``delta``, ``B`` and ``C`` are computed from ``x`` through ``x_proj``.

        Args:
            x: Tensor of shape ``(b, l, d_in)``.

        Returns:
            Tensor of shape ``(b, l, d_in)``.
        """
        state_dim = self.A_log.shape[1]

        A = -torch.exp(self.A_log.float())  # (d_in, n)
        D = self.D.float()

        projected = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        delta, B, C = projected.split(split_size=[self.dt_rank, state_dim, state_dim], dim=-1)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        return self.selective_scan(x, delta, A, B, C, D)

    def selective_scan(self, u, delta, A, B, C, D):
        """Sequential selective scan.

        Implements the recurrence ``h_t = exp(delta_t * A) * h_{t-1} +
        delta_t * B_t * u_t`` and ``y_t = C_t . h_t + D * u_t`` one time step
        at a time.

        Args:
            u: ``(b, l, d_in)`` input sequence.
            delta: ``(b, l, d_in)`` step sizes.
            A: ``(d_in, n)``.
            B: ``(b, l, n)``.
            C: ``(b, l, n)``.
            D: ``(d_in,)``.

        Returns:
            Tensor of shape ``(b, l, d_in)``.
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretisation: exponential for A, plain product for B (with u folded in).
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # Python loop over time; this is what makes the block slow.
        state = torch.zeros((b, d_in, n), device=deltaA.device)
        outputs = []
        for t in range(l):
            state = deltaA[:, t] * state + deltaB_u[:, t]
            outputs.append(einsum(state, C[:, t, :], 'b d_in n, b n -> b d_in'))

        return torch.stack(outputs, dim=1) + u * D
class MambaLayer(nn.Module):
    """LayerNorm followed by a ``mamba_ssm.Mamba`` module.

    No residual connection is added here; callers add it themselves.

    Args:
        dim: Feature width, used for both the norm and ``Mamba``'s ``d_model``.
        d_state: Forwarded to ``Mamba`` as ``d_state``.
        d_conv: Forwarded to ``Mamba`` as ``d_conv``.
        expand: Forwarded to ``Mamba`` as ``expand``.
    """

    def __init__(self, dim, d_state=64, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.mamba = Mamba(d_model=dim, d_state=d_state, d_conv=d_conv, expand=expand)

    def forward(self, x):
        """Normalise ``x`` (three-dimensional, unpacked as ``B, L, C``) and pass it through Mamba."""
        # The unpack is kept only because it rejects inputs that are not rank 3.
        _batch, _length, _channels = x.shape
        return self.mamba(self.norm(x))
def rand_bbox(size, lam, scale=1):
    """Sample a random box whose area fraction is roughly ``1 - lam``.

    Args:
        size: Indexable shape; ``size[1]`` and ``size[2]`` are the two spatial
            extents the box is drawn in.
        lam: Mixing coefficient; the box side ratio is ``sqrt(1 - lam)``.
        scale: Both extents are integer-divided by this before sampling.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` clipped to ``[0, extent]`` on each axis,
        expressed in the down-scaled grid.
    """
    extent_x = size[1] // scale
    extent_y = size[2] // scale

    side_ratio = np.sqrt(1. - lam)
    half_x = np.int_(extent_x * side_ratio) // 2
    half_y = np.int_(extent_y * side_ratio) // 2

    # Box centre, drawn uniformly (x first, then y).
    centre_x = np.random.randint(extent_x)
    centre_y = np.random.randint(extent_y)

    lower_x = np.clip(centre_x - half_x, 0, extent_x)
    lower_y = np.clip(centre_y - half_y, 0, extent_y)
    upper_x = np.clip(centre_x + half_x, 0, extent_x)
    upper_y = np.clip(centre_y + half_y, 0, extent_y)
    return lower_x, lower_y, upper_x, upper_y
class PVT2FFN(nn.Module):
    """Feed-forward block: Linear -> depth-wise 3x3 conv -> GELU -> Linear.

    Args:
        in_features: Input and output width.
        hidden_features: Width between the two linear layers.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply the block to tokens ``x``; ``H`` and ``W`` are handed to the depth-wise conv."""
        hidden = self.dwconv(self.fc1(x), H, W)
        return self.fc2(self.act(hidden))
class FFN(nn.Module):
    """Plain two-layer MLP: Linear -> GELU -> Linear.

    Args:
        in_features: Input and output width.
        hidden_features: Width between the two linear layers.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply the MLP to ``x``; ``H`` and ``W`` are accepted but not used."""
        return self.fc2(self.act(self.fc1(x)))
class ClassBlock(nn.Module):
    """Block that updates only the first (class) token of the sequence.

    Args:
        dim: Token width.
        mlp_ratio: Hidden-width multiplier for the ``FFN`` channel mixer.
        norm_layer: Factory for the normalisation placed before the channel mixer.
        cm_type: ``'EinFFT'`` selects ``EinFFT``; any other value selects ``FFN``.
    """

    def __init__(self, dim, mlp_ratio, norm_layer=nn.LayerNorm, cm_type='mlp'):
        super().__init__()
        # No norm1 here: MambaLayer carries its own LayerNorm.
        self.norm2 = norm_layer(dim)
        self.attn = MambaLayer(dim)  # MambaBlock(d_model=dim) is the slow alternative
        self.mlp = EinFFT(dim) if cm_type == 'EinFFT' else FFN(dim, int(dim * mlp_ratio))
        # NOTE(review): apply() also visits the nn.Linear layers inside
        # self.attn, so whatever initialisation Mamba performed on them is
        # presumably overwritten here -- confirm this is intended.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Refine the class token and return it concatenated with the untouched remaining tokens.

        The Mamba layer receives the class token alone (a length-1 sequence);
        the remaining tokens pass through unchanged.
        """
        cls_token = x[:, :1]
        remaining = x[:, 1:]
        cls_token = cls_token + self.attn(cls_token)
        cls_token = cls_token + self.mlp(self.norm2(cls_token), H, W)
        return torch.cat([cls_token, remaining], dim=1)
class Block_mamba(nn.Module):
    """Residual block: Mamba sequence mixer followed by a channel mixer.

    Args:
        dim: Token width.
        mlp_ratio: Hidden-width multiplier for the ``PVT2FFN`` channel mixer.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        norm_layer: Factory for the normalisation placed before the channel mixer.
        sr_ratio: Accepted but not used by this block.
        cm_type: ``'EinFFT'`` selects ``EinFFT``; any other value selects ``PVT2FFN``.
    """

    def __init__(self,
                 dim,
                 mlp_ratio,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1,
                 cm_type='mlp'
                 ):
        super().__init__()
        # No norm1 here: MambaLayer carries its own LayerNorm.
        self.norm2 = norm_layer(dim)
        self.attn = MambaLayer(dim)  # MambaBlock(d_model=dim) is the slow alternative
        if cm_type == 'EinFFT':
            self.mlp = EinFFT(dim)
        else:
            self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # NOTE(review): apply() also visits the nn.Linear layers inside
        # self.attn, so whatever initialisation Mamba performed on them is
        # presumably overwritten here -- confirm this is intended.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply both residual branches to tokens ``x``; ``H``/``W`` go to the channel mixer."""
        sequence_mixed = self.attn(x)
        x = x + self.drop_path(sequence_mixed)
        channel_mixed = self.mlp(self.norm2(x), H, W)
        x = x + self.drop_path(channel_mixed)
        return x
class DownSamples(nn.Module):
    """Stride-2 3x3 convolution that halves the resolution and emits normalised tokens.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels (token width) after the projection.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Project a ``(B, C, H, W)`` map and return ``(tokens, H', W')``.

        ``tokens`` has shape ``(B, H' * W', out_channels)``.
        """
        feature_map = self.proj(x)
        H, W = feature_map.shape[2], feature_map.shape[3]
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.norm(tokens), H, W
class Stem(nn.Module):
    """Convolutional stem: three conv/BN/ReLU stages, then a stride-2 projection to tokens.

    The first convolution (7x7) and the final projection (3x3) both use
    stride 2, so the output grid is a quarter of the input resolution.

    Args:
        in_channels: Channels of the input image.
        stem_hidden_dim: Channels used by the three conv/BN/ReLU stages.
        out_channels: Token width produced by the projection.
    """

    def __init__(self, in_channels, stem_hidden_dim, out_channels):
        super().__init__()
        hidden_dim = stem_hidden_dim

        # Stage 1 halves the resolution (112x112, presumably for a 224x224 input).
        layers = [
            nn.Conv2d(in_channels, hidden_dim, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        # Stages 2 and 3 keep the resolution.
        for _ in range(2):
            layers += [
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

        self.proj = nn.Conv2d(hidden_dim, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Embed a ``(B, C, H, W)`` image and return ``(tokens, H', W')``.

        ``tokens`` has shape ``(B, H' * W', out_channels)``.
        """
        feature_map = self.proj(self.conv(x))
        H, W = feature_map.shape[2], feature_map.shape[3]
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.norm(tokens), H, W
class SiMBA(nn.Module):
    """Multi-stage vision backbone built from ``Block_mamba`` stages plus a class-token block.

    Stage 1 embeds the image with ``Stem``; every later stage starts with a
    ``DownSamples`` layer.  After the last stage a class token (the mean of all
    tokens) is prepended and refined by ``post_network``.

    Args:
        in_chans: Channels of the input image.
        num_classes: Classifier outputs; ``0`` replaces the heads with ``nn.Identity``.
        stem_hidden_dim: Hidden width of the stem.
        embed_dims: Token width per stage.
        mlp_ratios: Channel-mixer expansion ratio per stage.
        drop_path_rate: Maximum stochastic-depth rate, linearly increased over all blocks.
        norm_layer: Factory for normalisation layers.
        depths: Number of blocks per stage.
        sr_ratios: Passed to every block as ``sr_ratio``.
        num_stages: Number of stages to build.
        token_label: Enables the dense (token-labeling) path: an auxiliary
            per-token head and token mixing during training.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 stem_hidden_dim=32,
                 embed_dims=[64, 128, 320, 448],
                 mlp_ratios=[8, 8, 4, 4],
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3],
                 sr_ratios=[4, 2, 1, 1],
                 num_stages=4,
                 token_label=True,
                 **kwargs
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        # Stochastic depth decay rule: one rate per block, rising linearly.
        dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
        cur = 0

        for i in range(num_stages):
            patch_embed = (
                Stem(in_chans, stem_hidden_dim, embed_dims[i])
                if i == 0
                else DownSamples(embed_dims[i - 1], embed_dims[i])
            )
            # Change cm_type to 'EinFFT' to run the EinFFT based channel mixer.
            block = nn.ModuleList([
                Block_mamba(
                    dim=embed_dims[i],
                    mlp_ratio=mlp_ratios[i],
                    drop_path=dpr[cur + j],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios[i],
                    cm_type='mlp')
                for j in range(depths[i])
            ])
            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        # One ClassBlock per entry; change cm_type to 'EinFFT' for the EinFFT mixer.
        post_layers = ['ca']
        self.post_network = nn.ModuleList([
            ClassBlock(
                dim=embed_dims[-1],
                mlp_ratio=mlp_ratios[-1],
                norm_layer=norm_layer,
                cm_type='mlp')
            for _ in range(len(post_layers))
        ])

        # classification head
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        # ---- token labeling ----
        self.return_dense = token_label
        self.mix_token = token_label
        self.beta = 1.0
        self.pooling_scale = 8
        if self.return_dense:
            self.aux_head = nn.Linear(
                embed_dims[-1],
                num_classes) if num_classes > 0 else nn.Identity()

        # NOTE(review): apply() also visits the nn.Linear layers inside every
        # Mamba module, so whatever initialisation Mamba performed on them is
        # presumably overwritten here -- confirm this is intended.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out counted per group.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def _run_stage_blocks(self, i, x, H, W, B):
        """Run the blocks of stage ``i``; for non-final stages also normalise and reshape to ``(B, C, H, W)``."""
        for blk in getattr(self, f"block{i + 1}"):
            x = blk(x, H, W)
        if i != self.num_stages - 1:
            x = getattr(self, f"norm{i + 1}")(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        return x

    def forward_cls(self, x, H, W):
        """Prepend the mean token as a class token and run ``post_network``."""
        _batch, _tokens, _channels = x.shape  # x must be (B, N, C)
        cls_tokens = x.mean(dim=1, keepdim=True)
        x = torch.cat((cls_tokens, x), dim=1)
        for block in self.post_network:
            x = block(x, H, W)
        return x

    def forward_features(self, x):
        """Image -> normalised class-token feature of shape ``(B, embed_dims[-1])``."""
        B = x.shape[0]
        for i in range(self.num_stages):
            x, H, W = getattr(self, f"patch_embed{i + 1}")(x)
            x = self._run_stage_blocks(i, x, H, W, B)
        x = self.forward_cls(x, H, W)[:, 0]
        return getattr(self, f"norm{self.num_stages}")(x)

    def forward(self, x):
        """Classify ``x``.

        Returns:
            * ``token_label`` disabled: class logits.
            * ``token_label`` enabled, eval mode: class logits plus half the
              per-class maximum of the auxiliary token logits.
            * ``token_label`` enabled, training mode: ``(x_cls, x_aux, bbox)``
              where ``bbox`` is the mixed region ``(bbx1, bby1, bbx2, bby2)``
              on the final token grid (all zeros when nothing was mixed).
        """
        if not self.return_dense:
            return self.head(self.forward_features(x))

        x, H, W = self.forward_embeddings(x)

        # mix token, see token labeling for details.
        mixing = self.mix_token and self.training
        if mixing:
            lam = np.random.beta(self.beta, self.beta)
            patch_h = x.shape[1] // self.pooling_scale
            patch_w = x.shape[2] // self.pooling_scale
            bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
            # The box is sampled on the final token grid; scale it up to the stem grid.
            sbbx1 = self.pooling_scale * bbx1
            sbby1 = self.pooling_scale * bby1
            sbbx2 = self.pooling_scale * bbx2
            sbby2 = self.pooling_scale * bby2
            # Swap the box region with the batch-reversed sample.
            temp_x = x.clone()
            temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
            x = temp_x
        else:
            bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0

        x = self.forward_tokens(x, H, W)
        x_cls = self.head(x[:, 0])
        # generate classes in all feature tokens, see token labeling
        x_aux = self.aux_head(x[:, 1:])

        if not self.training:
            return x_cls + 0.5 * x_aux.max(1)[0]

        if mixing:
            # reverse "mix token", see token labeling for details.
            x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
            temp_x = x_aux.clone()
            temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
            x_aux = temp_x
            x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])

        return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)

    def forward_tokens(self, x, H, W):
        """Run all stages on stem output ``x`` (``(B, H, W, C)``) and return normalised tokens, class token first."""
        B = x.shape[0]
        x = x.view(B, -1, x.size(-1))
        for i in range(self.num_stages):
            # Stage 1's patch embedding was already applied in forward_embeddings().
            if i != 0:
                x, H, W = getattr(self, f"patch_embed{i + 1}")(x)
            x = self._run_stage_blocks(i, x, H, W, B)
        x = self.forward_cls(x, H, W)
        return getattr(self, f"norm{self.num_stages}")(x)

    def forward_embeddings(self, x):
        """Apply the stage-1 patch embedding and return ``(x, H, W)`` with ``x`` shaped ``(B, H, W, C)``."""
        tokens, H, W = getattr(self, f"patch_embed{0 + 1}")(x)
        return tokens.view(tokens.size(0), H, W, -1), H, W
class DWConv(nn.Module):
    """Depth-wise 3x3 convolution applied to a token sequence laid out on an ``H x W`` grid.

    Args:
        dim: Number of channels (also the number of convolution groups).
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Convolve tokens ``x`` of shape ``(B, N, C)`` with ``N == H * W``; the output has the same shape."""
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, H, W)
        grid = self.dwconv(grid)
        return grid.flatten(2).transpose(1, 2)
@register_model
def simba_s(pretrained=False, **kwargs):
    """Build the small SiMBA variant (widths 64/128/320/448, depths 3/4/6/3).

    ``pretrained`` is accepted but not used: no weights are loaded here.
    Extra keyword arguments are forwarded to ``SiMBA``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = SiMBA(
        stem_hidden_dim=32,
        embed_dims=[64, 128, 320, 448],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=layer_norm,
        depths=[3, 4, 6, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
@register_model
def simba_b(pretrained=False, **kwargs):
    """Build the base SiMBA variant (widths 64/128/320/512, depths 3/4/12/3).

    ``pretrained`` is accepted but not used: no weights are loaded here.
    Extra keyword arguments are forwarded to ``SiMBA``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = SiMBA(
        stem_hidden_dim=64,
        embed_dims=[64, 128, 320, 512],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=layer_norm,
        depths=[3, 4, 12, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
@register_model
def simba_l(pretrained=False, **kwargs):
    """Build the large SiMBA variant (widths 96/192/384/512, depths 3/6/18/3).

    ``pretrained`` is accepted but not used: no weights are loaded here.
    Extra keyword arguments are forwarded to ``SiMBA``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = SiMBA(
        stem_hidden_dim=64,
        embed_dims=[96, 192, 384, 512],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=layer_norm,
        depths=[3, 6, 18, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model