Commit 078ef5b: add VitSAM
1 parent 5c87ac5 commit 078ef5b

6 files changed: +621 -27 lines changed
mindcv/models/__init__.py (+3)

@@ -52,6 +52,7 @@
     vgg,
     visformer,
     vit,
+    vit_sam,
     volo,
     xception,
     xcit,
@@ -109,6 +110,7 @@
 from .vgg import *
 from .visformer import *
 from .vit import *
+from .vit_sam import *
 from .volo import *
 from .xception import *
 from .xcit import *
@@ -168,6 +170,7 @@
 __all__.extend(vgg.__all__)
 __all__.extend(visformer.__all__)
 __all__.extend(vit.__all__)
+__all__.extend(vit_sam.__all__)
 __all__.extend(volo.__all__)
 __all__.extend(["Xception", "xception"])
 __all__.extend(xcit.__all__)
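Wiring vit_sam into the package exports is what makes the new ViT-SAM variants reachable through the model factory. A minimal sketch, assuming the usual create_model entry point; the model name below is hypothetical, the real names come from vit_sam.__all__:

import mindspore as ms
from mindcv.models import create_model

# "samvit_b_16" is a placeholder; check mindcv/models/vit_sam.py for the registered names.
model = create_model("samvit_b_16", pretrained=False, num_classes=1000)
print(type(model))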

mindcv/models/layers/format.py (-4)

@@ -1,5 +1,4 @@
 from enum import Enum
-from typing import Union
 
 import mindspore
 
@@ -11,9 +10,6 @@ class Format(str, Enum):
     NLC = 'NLC'
 
 
-FormatT = Union[str, Format]
-
-
 def nchw_to(x: mindspore.Tensor, fmt: Format):
     if fmt == Format.NHWC:
         x = x.permute(0, 2, 3, 1)
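For context, nchw_to is the helper these layers use to convert a channels-first feature map into the requested layout. A minimal usage sketch; shapes are illustrative, and the module is imported by its file path since the public re-export may differ:

import numpy as np
import mindspore as ms
from mindcv.models.layers.format import Format, nchw_to

x = ms.Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32))  # NCHW feature map
y = nchw_to(x, Format.NHWC)
print(y.shape)  # (1, 224, 224, 3)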

mindcv/models/layers/patch_dropout.py (+13, -13)

@@ -8,41 +8,41 @@ class PatchDropout(nn.Cell):
     """
     https://arxiv.org/abs/2212.00794
     """
+
     def __init__(
-            self,
-            prob: float = 0.5,
-            num_prefix_tokens: int = 1,
-            ordered: bool = False,
-            return_indices: bool = False,
+        self,
+        prob: float = 0.5,
+        num_prefix_tokens: int = 1,
+        ordered: bool = False,
+        return_indices: bool = False,
     ):
         super().__init__()
-        assert 0 <= prob < 1.
+        assert 0 <= prob < 1.0
         self.prob = prob
         self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
         self.ordered = ordered
         self.return_indices = return_indices
-        self.sort = ops.Sort()
 
-    def forward(self, x):
-        if not self.training or self.prob == 0.:
+    def construct(self, x):
+        if not self.training or self.prob == 0.0:
             if self.return_indices:
                 return x, None
             return x
 
         if self.num_prefix_tokens:
-            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
+            prefix_tokens, x = x[:, : self.num_prefix_tokens], x[:, self.num_prefix_tokens :]
         else:
             prefix_tokens = None
 
         B = x.shape[0]
         L = x.shape[1]
-        num_keep = max(1, int(L * (1. - self.prob)))
-        _, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
+        num_keep = max(1, int(L * (1.0 - self.prob)))
+        _, indices = ops.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
         keep_indices = indices[:, :num_keep]
         if self.ordered:
             # NOTE does not need to maintain patch order in typical transformer use,
             # but possibly useful for debug / visualization
-            keep_indices, _ = self.sort(keep_indices)
+            keep_indices, _ = ops.sort(keep_indices)
         keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2]))
         x = ops.gather_elements(x, dim=1, index=keep_indices)
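PatchDropout randomly keeps a subset of patch tokens during training (per https://arxiv.org/abs/2212.00794) while leaving prefix tokens such as [CLS] untouched; at inference it is a no-op. A minimal calling sketch, assuming the remainder of construct re-attaches the prefix tokens as in the timm implementation this layer mirrors:

import numpy as np
import mindspore as ms
from mindcv.models.layers.patch_dropout import PatchDropout

drop = PatchDropout(prob=0.5, num_prefix_tokens=1)
drop.set_train(True)  # tokens are only dropped in training mode

tokens = ms.Tensor(np.random.rand(2, 1 + 196, 768), ms.float32)  # [CLS] + 14x14 patches
out = drop(tokens)
print(out.shape)  # expected roughly (2, 1 + 98, 768): half of the patch tokens are kept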

mindcv/models/layers/patch_embed.py (+6, -5)

@@ -18,6 +18,7 @@ class PatchEmbed(nn.Cell):
         embed_dim (int): Number of linear projection output channels. Default: 96.
         norm_layer (nn.Cell, optional): Normalization layer. Default: None
     """
+
     output_fmt: Format
 
     def __init__(
@@ -37,11 +38,11 @@ def __init__(
         self.patch_size = to_2tuple(patch_size)
         if image_size is not None:
             self.image_size = to_2tuple(image_size)
-            self.patches_resolution = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
-            self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
+            self.grid_size = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
+            self.num_patches = self.grid_size[0] * self.grid_size[1]
         else:
             self.image_size = None
-            self.patches_resolution = None
+            self.grid_size = None
             self.num_patches = None
 
         if output_fmt is not None:
@@ -86,8 +87,8 @@ def construct(self, x: Tensor) -> Tensor:
         # FIXME look at relaxing size constraints
         x = self.proj(x)
         if self.flatten:
-            x = ops.Reshape()(x, (B, self.embed_dim, -1))  # B Ph*Pw C
-            x = ops.Transpose()(x, (0, 2, 1))
+            x = ops.reshape(x, (B, self.embed_dim, -1))  # B Ph*Pw C
+            x = ops.transpose(x, (0, 2, 1))
         elif self.output_fmt != "NCHW":
             x = nchw_to(x, self.output_fmt)
         if self.norm is not None:
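The substantive change here is the rename of patches_resolution to grid_size; the construct path only moves to the functional ops API. A minimal sketch of what the attribute exposes, using the standard 224/16 ViT configuration (only argument names visible in this diff are used; defaults differ, so values are passed explicitly, and the output shape assumes flatten defaults to True as in the original layer):

import numpy as np
import mindspore as ms
from mindcv.models.layers.patch_embed import PatchEmbed

embed = PatchEmbed(image_size=224, patch_size=16, embed_dim=768)
print(embed.grid_size, embed.num_patches)  # (14, 14) 196

x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
tokens = embed(x)  # (1, 196, 768) if the default flattened (NLC) output is kept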

mindcv/models/layers/pos_embed.py (+5, -5)

@@ -11,11 +11,11 @@
 
 
 def resample_abs_pos_embed(
-        posemb,
-        new_size: List[int],
-        old_size: Optional[List[int]] = None,
-        num_prefix_tokens: int = 1,
-        interpolation: str = 'nearest',
+    posemb,
+    new_size: List[int],
+    old_size: Optional[List[int]] = None,
+    num_prefix_tokens: int = 1,
+    interpolation: str = 'nearest',
 ):
     # sort out sizes, assume square if old size not provided
     num_pos_tokens = posemb.shape[1]
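resample_abs_pos_embed interpolates a learned absolute position embedding from its training grid to a new grid, which is what lets ViT-SAM style backbones accept other input resolutions. A minimal sketch, assuming the function returns the resampled table with the prefix tokens kept in front, as its timm counterpart does:

import numpy as np
import mindspore as ms
from mindcv.models.layers.pos_embed import resample_abs_pos_embed

# 1 CLS token + 14x14 patch grid, 768-dim embeddings
posemb = ms.Tensor(np.random.rand(1, 1 + 14 * 14, 768), ms.float32)

# resample for a 16x16 grid, e.g. a larger input image
posemb_new = resample_abs_pos_embed(posemb, new_size=[16, 16], num_prefix_tokens=1)
print(posemb_new.shape)  # expected (1, 1 + 16 * 16, 768)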
