 from mindspore import Tensor, nn, ops
 
+from .format import Format, nchw_to
 from .helpers import to_2tuple
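The new `.format` module is referenced here but is not shown in this diff. Below is a minimal sketch of what `Format` and `nchw_to` could look like, assuming they mirror the equivalent timm helpers; the enum members and conversion logic are assumptions, not the PR's actual code.

```python
# Hypothetical sketch of the helpers imported above; not the code added by this PR.
from enum import Enum

from mindspore import Tensor, ops


class Format(str, Enum):
    NCHW = "NCHW"
    NHWC = "NHWC"
    NCL = "NCL"
    NLC = "NLC"


def nchw_to(x: Tensor, fmt: Format) -> Tensor:
    """Convert a (B, C, H, W) tensor to the requested layout."""
    if fmt == Format.NHWC:
        x = ops.transpose(x, (0, 2, 3, 1))  # B, H, W, C
    elif fmt == Format.NLC:
        B, C, H, W = x.shape
        x = ops.transpose(ops.reshape(x, (B, C, H * W)), (0, 2, 1))  # B, H*W, C
    elif fmt == Format.NCL:
        B, C, H, W = x.shape
        x = ops.reshape(x, (B, C, H * W))  # B, C, H*W
    return x
```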
@@ -17,29 +18,45 @@ class PatchEmbed(nn.Cell):
         embed_dim (int): Number of linear projection output channels. Default: 96.
         norm_layer (nn.Cell, optional): Normalization layer. Default: None
     """
+    output_fmt: Format
 
     def __init__(
         self,
-        image_size: int = 224,
+        image_size: Optional[int] = 224,
         patch_size: int = 4,
         in_chans: int = 3,
         embed_dim: int = 96,
         norm_layer: Optional[nn.Cell] = None,
+        flatten: bool = True,
+        output_fmt: Optional[str] = None,
+        bias: bool = True,
+        strict_img_size: bool = True,
+        dynamic_img_pad: bool = False,
     ) -> None:
         super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
-        patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
-        self.image_size = image_size
-        self.patch_size = patch_size
-        self.patches_resolution = patches_resolution
-        self.num_patches = patches_resolution[0] * patches_resolution[1]
-
-        self.in_chans = in_chans
+        self.patch_size = to_2tuple(patch_size)
+        if image_size is not None:
+            self.image_size = to_2tuple(image_size)
+            self.grid_size = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
+            self.num_patches = self.grid_size[0] * self.grid_size[1]
+        else:
+            self.image_size = None
+            self.grid_size = None
+            self.num_patches = None
+
+        if output_fmt is not None:
+            self.flatten = False
+            self.output_fmt = Format(output_fmt)
+        else:
+            self.flatten = flatten
+            self.output_fmt = Format.NCHW
+
+        self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
         self.embed_dim = embed_dim
 
         self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size,
-                              pad_mode='pad', has_bias=True, weight_init="TruncatedNormal")
+                              pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal")
 
         if norm_layer is not None:
             if isinstance(embed_dim, int):
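A quick check of the new grid bookkeeping in the constructor: with the defaults `image_size=224` and `patch_size=4`, the layer should report a 56x56 patch grid and 3136 patches. The snippet below reproduces that arithmetic standalone.

```python
# Worked example of the grid bookkeeping above (defaults: image_size=224, patch_size=4).
image_size, patch_size = (224, 224), (4, 4)
grid_size = tuple(s // p for s, p in zip(image_size, patch_size))
num_patches = grid_size[0] * grid_size[1]
print(grid_size, num_patches)  # (56, 56) 3136
```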
@@ -50,11 +67,29 @@ def __init__(
 
     def construct(self, x: Tensor) -> Tensor:
         """docstring"""
-        B = x.shape[0]
-        # FIXME look at relaxing size constraints
-        x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1))  # B Ph*Pw C
-        x = ops.Transpose()(x, (0, 2, 1))
+        B, C, H, W = x.shape
+        if self.image_size is not None:
+            if self.strict_img_size:
+                if (H, W) != (self.image_size[0], self.image_size[1]):
+                    raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]},"
+                                     f"{self.image_size[1]}).")
+            elif not self.dynamic_img_pad:
+                if H % self.patch_size[0] != 0:
+                    raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).")
+                if W % self.patch_size[1] != 0:
+                    raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).")
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = ops.pad(x, (0, pad_w, 0, pad_h))
 
+        # FIXME look at relaxing size constraints
+        x = self.proj(x)
+        if self.flatten:
+            x = ops.Reshape()(x, (B, self.embed_dim, -1))  # B, C, Ph*Pw
+            x = ops.Transpose()(x, (0, 2, 1))  # B, Ph*Pw, C
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
         if self.norm is not None:
             x = self.norm(x)
         return x
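For context, a minimal usage sketch of the updated layer; the import path and the shapes in the comments are assumptions based on the defaults shown in this diff.

```python
import numpy as np
import mindspore as ms

# Assumed import path; adjust to wherever PatchEmbed lives in this repo.
from mindcv.models.layers.patch_embed import PatchEmbed

# Default (strict) behaviour: a 224x224 input becomes a 56x56 grid flattened to 3136 tokens.
embed = PatchEmbed(image_size=224, patch_size=4, in_chans=3, embed_dim=96)
x = ms.Tensor(np.random.randn(1, 3, 224, 224), ms.float32)
print(embed(x).shape)  # (1, 3136, 96)

# Variable-size behaviour: no fixed image size, pad the input up to a multiple of the patch size.
embed_dyn = PatchEmbed(image_size=None, patch_size=4, in_chans=3, embed_dim=96, dynamic_img_pad=True)
y = ms.Tensor(np.random.randn(1, 3, 225, 230), ms.float32)
print(embed_dyn(y).shape)  # (1, 3306, 96): padded to 228x232 -> 57x58 patches
```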