@@ -8,41 +8,41 @@ class PatchDropout(nn.Cell):
8
8
"""
9
9
https://arxiv.org/abs/2212.00794
10
10
"""
11
+
11
12
def __init__ (
12
- self ,
13
- prob : float = 0.5 ,
14
- num_prefix_tokens : int = 1 ,
15
- ordered : bool = False ,
16
- return_indices : bool = False ,
13
+ self ,
14
+ prob : float = 0.5 ,
15
+ num_prefix_tokens : int = 1 ,
16
+ ordered : bool = False ,
17
+ return_indices : bool = False ,
17
18
):
18
19
super ().__init__ ()
19
- assert 0 <= prob < 1.
20
+ assert 0 <= prob < 1.0
20
21
self .prob = prob
21
22
self .num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
22
23
self .ordered = ordered
23
24
self .return_indices = return_indices
24
- self .sort = ops .Sort ()
25
25
26
- def forward (self , x ):
27
- if not self .training or self .prob == 0. :
26
+ def construct (self , x ):
27
+ if not self .training or self .prob == 0.0 :
28
28
if self .return_indices :
29
29
return x , None
30
30
return x
31
31
32
32
if self .num_prefix_tokens :
33
- prefix_tokens , x = x [:, :self .num_prefix_tokens ], x [:, self .num_prefix_tokens :]
33
+ prefix_tokens , x = x [:, : self .num_prefix_tokens ], x [:, self .num_prefix_tokens :]
34
34
else :
35
35
prefix_tokens = None
36
36
37
37
B = x .shape [0 ]
38
38
L = x .shape [1 ]
39
- num_keep = max (1 , int (L * (1. - self .prob )))
40
- _ , indices = self .sort (ms .Tensor (np .random .rand (B , L )).astype (ms .float32 ))
39
+ num_keep = max (1 , int (L * (1.0 - self .prob )))
40
+ _ , indices = ops .sort (ms .Tensor (np .random .rand (B , L )).astype (ms .float32 ))
41
41
keep_indices = indices [:, :num_keep ]
42
42
if self .ordered :
43
43
# NOTE does not need to maintain patch order in typical transformer use,
44
44
# but possibly useful for debug / visualization
45
- keep_indices , _ = self .sort (keep_indices )
45
+ keep_indices , _ = ops .sort (keep_indices )
46
46
keep_indices = ops .broadcast_to (ops .expand_dims (keep_indices , axis = - 1 ), (- 1 , - 1 , x .shape [2 ]))
47
47
x = ops .gather_elements (x , dim = 1 , index = keep_indices )
48
48
0 commit comments