Commit 76a9170: add VitSAM
Parent commit: d714673

6 files changed: +763, -16 lines

mindcv/models/__init__.py (+3)

@@ -51,6 +51,7 @@
     vgg,
     visformer,
     vit,
+    vit_sam,
     volo,
     xception,
     xcit,
@@ -107,6 +108,7 @@
 from .vgg import *
 from .visformer import *
 from .vit import *
+from .vit_sam import *
 from .volo import *
 from .xception import *
 from .xcit import *
@@ -165,6 +167,7 @@
 __all__.extend(vgg.__all__)
 __all__.extend(visformer.__all__)
 __all__.extend(vit.__all__)
+__all__.extend(vit_sam.__all__)
 __all__.extend(volo.__all__)
 __all__.extend(["Xception", "xception"])
 __all__.extend(xcit.__all__)
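After this registration, the new module resolves like any other model family in mindcv.models. A minimal check (the concrete entrypoint names exported by vit_sam.__all__ are not shown in this diff, so none are assumed here):

import mindcv.models as models

# the ViT-SAM entrypoints registered by this commit
print(models.vit_sam.__all__)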

mindcv/models/layers/__init__.py (+13 -1)

@@ -1,9 +1,21 @@
 """layers init"""
-from . import activation, conv_norm_act, drop_path, identity, pooling, selective_kernel, squeeze_excite
+from . import (
+    activation,
+    conv_norm_act,
+    drop_path,
+    format,
+    identity,
+    patch_dropout,
+    pooling,
+    selective_kernel,
+    squeeze_excite,
+)
 from .activation import *
 from .conv_norm_act import *
 from .drop_path import *
+from .format import *
 from .identity import *
+from .patch_dropout import *
 from .pooling import *
 from .selective_kernel import *
 from .squeeze_excite import *
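Once re-exported here, the two new building blocks resolve from the layers package directly; a quick sanity-check sketch:

# format.py and patch_dropout.py define no __all__, so the star imports above
# re-export their public names at the package level.
from mindcv.models.layers import Format, PatchDropout, nchw_to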

mindcv/models/layers/format.py (new file, +30 lines)

from enum import Enum

import mindspore


class Format(str, Enum):
    """Supported tensor memory layouts."""
    NCHW = 'NCHW'
    NHWC = 'NHWC'
    NCL = 'NCL'
    NLC = 'NLC'


def nchw_to(x: mindspore.Tensor, fmt: Format):
    """Convert a tensor from NCHW layout to the requested format."""
    if fmt == Format.NHWC:
        x = x.permute(0, 2, 3, 1)
    elif fmt == Format.NLC:
        x = x.flatten(start_dim=2).transpose((0, 2, 1))
    elif fmt == Format.NCL:
        x = x.flatten(start_dim=2)
    return x


def nhwc_to(x: mindspore.Tensor, fmt: Format):
    """Convert a tensor from NHWC layout to the requested format."""
    if fmt == Format.NCHW:
        x = x.permute(0, 3, 1, 2)
    elif fmt == Format.NLC:
        x = x.flatten(start_dim=1, end_dim=2)
    elif fmt == Format.NCL:
        x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1))
    return x
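A short usage sketch of the conversion helpers (the shapes follow directly from the enum names; the import path assumes the file location shown above):

import numpy as np
import mindspore as ms

from mindcv.models.layers.format import Format, nchw_to

x = ms.Tensor(np.zeros((2, 3, 224, 224)), ms.float32)  # NCHW input
print(nchw_to(x, Format.NHWC).shape)  # (2, 224, 224, 3)
print(nchw_to(x, Format.NLC).shape)   # (2, 50176, 3), i.e. B, H*W, C
print(nchw_to(x, Format.NCL).shape)   # (2, 3, 50176)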

mindcv/models/layers/patch_dropout.py (new file, +54 lines)

import numpy as np

import mindspore as ms
from mindspore import nn, ops


class PatchDropout(nn.Cell):
    """
    Randomly drop a fraction of the patch tokens during training.
    https://arxiv.org/abs/2212.00794
    """
    def __init__(
        self,
        prob: float = 0.5,
        num_prefix_tokens: int = 1,
        ordered: bool = False,
        return_indices: bool = False,
    ):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
        self.ordered = ordered
        self.return_indices = return_indices
        self.sort = ops.Sort()

    def construct(self, x):  # nn.Cell dispatches __call__ to construct(), not forward()
        if not self.training or self.prob == 0.:
            if self.return_indices:
                return x, None
            return x

        if self.num_prefix_tokens:
            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
        else:
            prefix_tokens = None

        B = x.shape[0]
        L = x.shape[1]
        num_keep = max(1, int(L * (1. - self.prob)))
        _, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
        keep_indices = indices[:, :num_keep]
        if self.ordered:
            # NOTE does not need to maintain patch order in typical transformer use,
            # but possibly useful for debug / visualization
            keep_indices, _ = self.sort(keep_indices)
        keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2]))
        x = ops.gather_elements(x, dim=1, index=keep_indices)

        if prefix_tokens is not None:
            x = ops.concat((prefix_tokens, x), axis=1)

        if self.return_indices:
            return x, keep_indices
        return x
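A small usage sketch (shapes are illustrative; set_train(True) matters because the drop is skipped when self.training is False):

import numpy as np
import mindspore as ms

from mindcv.models.layers.patch_dropout import PatchDropout

tokens = ms.Tensor(np.random.rand(4, 1 + 196, 768), ms.float32)  # [CLS] + 14x14 patch tokens
drop = PatchDropout(prob=0.25, num_prefix_tokens=1)
drop.set_train(True)   # in eval mode the cell is a no-op
out = drop(tokens)     # CLS kept, int(196 * 0.75) = 147 patch tokens kept
print(out.shape)       # (4, 148, 768)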

mindcv/models/layers/patch_embed.py (+50 -15)

@@ -4,6 +4,7 @@
 
 from mindspore import Tensor, nn, ops
 
+from .format import Format, nchw_to
 from .helpers import to_2tuple
 
 
@@ -17,29 +18,45 @@ class PatchEmbed(nn.Cell):
         embed_dim (int): Number of linear projection output channels. Default: 96.
         norm_layer (nn.Cell, optional): Normalization layer. Default: None
     """
+    output_fmt: Format
 
     def __init__(
         self,
-        image_size: int = 224,
+        image_size: Optional[int] = 224,
         patch_size: int = 4,
         in_chans: int = 3,
         embed_dim: int = 96,
         norm_layer: Optional[nn.Cell] = None,
+        flatten: bool = True,
+        output_fmt: Optional[str] = None,
+        bias: bool = True,
+        strict_img_size: bool = True,
+        dynamic_img_pad: bool = False,
     ) -> None:
         super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
-        patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
-        self.image_size = image_size
-        self.patch_size = patch_size
-        self.patches_resolution = patches_resolution
-        self.num_patches = patches_resolution[0] * patches_resolution[1]
-
-        self.in_chans = in_chans
+        self.patch_size = to_2tuple(patch_size)
+        if image_size is not None:
+            self.image_size = to_2tuple(image_size)
+            self.grid_size = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
+            self.num_patches = self.grid_size[0] * self.grid_size[1]
+        else:
+            self.image_size = None
+            self.grid_size = None
+            self.num_patches = None
+
+        if output_fmt is not None:
+            self.flatten = False
+            self.output_fmt = Format(output_fmt)
+        else:
+            self.flatten = flatten
+            self.output_fmt = Format.NCHW
+
+        self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
         self.embed_dim = embed_dim
 
         self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size,
-                              pad_mode='pad', has_bias=True, weight_init="TruncatedNormal")
+                              pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal")
 
         if norm_layer is not None:
             if isinstance(embed_dim, int):
@@ -50,11 +67,29 @@ def __init__(
 
     def construct(self, x: Tensor) -> Tensor:
         """docstring"""
-        B = x.shape[0]
-        # FIXME look at relaxing size constraints
-        x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1))  # B Ph*Pw C
-        x = ops.Transpose()(x, (0, 2, 1))
+        B, C, H, W = x.shape
+        if self.image_size is not None:
+            if self.strict_img_size:
+                if (H, W) != (self.image_size[0], self.image_size[1]):
+                    raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]},"
+                                     f"{self.image_size[1]}).")
+            elif not self.dynamic_img_pad:
+                if H % self.patch_size[0] != 0:
+                    raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).")
+                if W % self.patch_size[1] != 0:
+                    raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).")
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = ops.pad(x, (0, pad_w, 0, pad_h))
 
+        # FIXME look at relaxing size constraints
+        x = self.proj(x)
+        if self.flatten:
+            x = ops.Reshape()(x, (B, self.embed_dim, -1))  # B Ph*Pw C
+            x = ops.Transpose()(x, (0, 2, 1))
+        elif self.output_fmt != "NCHW":
+            x = nchw_to(x, self.output_fmt)
         if self.norm is not None:
             x = self.norm(x)
         return x
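To illustrate the new options, a sketch under the assumption that norm_layer stays at its default of None (the concrete sizes below are arbitrary):

import numpy as np
import mindspore as ms

from mindcv.models.layers.patch_embed import PatchEmbed

# Request NHWC feature maps instead of the default flattened token sequence.
embed = PatchEmbed(image_size=224, patch_size=16, in_chans=3, embed_dim=768, output_fmt="NHWC")
x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
print(embed(x).shape)  # (1, 14, 14, 768)

# image_size=None skips the strict shape check; dynamic_img_pad pads up to a multiple of patch_size.
embed = PatchEmbed(image_size=None, patch_size=16, in_chans=3, embed_dim=768, dynamic_img_pad=True)
x = ms.Tensor(np.random.rand(1, 3, 200, 200), ms.float32)
print(embed(x).shape)  # padded to 208x208 -> 13*13 = 169 tokens -> (1, 169, 768)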
