14
14
class RoIAlignRotatedV2Function (Function ):
15
15
16
16
@staticmethod
17
- def symbolic (g , input , rois , spatial_scale , sampling_ratio , pooled_height ,
18
- pooled_width , aligned , clockwise ):
17
+ def symbolic (g , x , rois , spatial_scale , sampling_ratio , pooled_h ,
18
+ pooled_w , aligned , clockwise ):
19
19
return g .op (
20
20
'mmcv::MMCVRoIAlignRotatedV2' ,
21
- input ,
21
+ x ,
22
22
rois ,
23
+ pooled_h = pooled_h ,
24
+ pooled_w = pooled_w ,
23
25
spatial_scale_f = spatial_scale ,
24
26
sampling_ratio_i = sampling_ratio ,
25
- pooled_height = pooled_height ,
26
- pooled_width = pooled_width ,
27
27
aligned_i = aligned ,
28
28
clockwise_i = clockwise )
29
29
30
30
@staticmethod
31
31
def forward (ctx : Any ,
32
- input : torch .Tensor ,
32
+ x : torch .Tensor ,
33
33
rois : torch .Tensor ,
34
+ pooled_h : int ,
35
+ pooled_w : int ,
34
36
spatial_scale : float ,
35
37
sampling_ratio : int ,
36
- pooled_height : int ,
37
- pooled_width : int ,
38
38
aligned : bool = True ,
39
39
clockwise : bool = False ) -> torch .Tensor :
40
- ctx .pooled_height = pooled_height
41
- ctx .pooled_width = pooled_width
40
+ ctx .pooled_h = pooled_h
41
+ ctx .pooled_w = pooled_w
42
42
ctx .spatial_scale = spatial_scale
43
43
ctx .sampling_ratio = sampling_ratio
44
44
ctx .aligned = aligned
45
45
ctx .clockwise = clockwise
46
- ctx .save_for_backward (input , rois )
47
- ctx .feature_size = input .size ()
48
- batch_size , num_channels , data_height , data_width = input .size ()
46
+ ctx .save_for_backward (x , rois )
47
+ ctx .feature_size = x .size ()
48
+ batch_size , num_channels , data_height , data_width = x .size ()
49
49
num_rois = rois .size (0 )
50
50
51
- output = input .new_zeros (num_rois , ctx .pooled_height , ctx .pooled_width ,
51
+ y = x .new_zeros (num_rois , ctx .pooled_h , ctx .pooled_w ,
52
52
num_channels )
53
53
54
54
ext_module .roi_align_rotated_v2_forward (
55
- input ,
55
+ x ,
56
56
rois ,
57
- output ,
57
+ y ,
58
+ pooled_h = ctx .pooled_h ,
59
+ pooled_w = ctx .pooled_w ,
58
60
spatial_scale = ctx .spatial_scale ,
59
61
sampling_ratio = ctx .sampling_ratio ,
60
- pooled_height = ctx .pooled_height ,
61
- pooled_width = ctx .pooled_width ,
62
62
aligned = ctx .aligned ,
63
63
clockwise = ctx .clockwise )
64
- output = output .transpose (2 , 3 ).transpose (1 , 2 ).contiguous ()
65
- return output
64
+ y = y .transpose (2 , 3 ).transpose (1 , 2 ).contiguous ()
65
+ return y
66
66
67
67
@staticmethod
68
68
def backward (ctx : Any , grad_output : torch .Tensor ):
@@ -74,7 +74,7 @@ def backward(ctx: Any, grad_output: torch.Tensor):
74
74
input .size (0 ), input .size (2 ), input .size (3 ), input .size (1 ))
75
75
ext_module .roi_align_rotated_v2_backward (
76
76
input , rois_trans , grad_output_trans , grad_input ,
77
- ctx .pooled_height , ctx .pooled_width , ctx .spatial_scale ,
77
+ ctx .pooled_h , ctx .pooled_w , ctx .spatial_scale ,
78
78
ctx .sampling_ratio , ctx .aligned , ctx .clockwise )
79
79
grad_input = grad_input .permute (0 , 3 , 1 , 2 ).contiguous ()
80
80
@@ -134,31 +134,33 @@ class RoIAlignRotatedV2(nn.Module):
134
134
},
135
135
cls_name = 'RoIAlignRotatedV2' )
136
136
def __init__ (self ,
137
+ pooled_h : int ,
138
+ pooled_w : int ,
137
139
spatial_scale : float ,
138
140
sampling_ratio : int ,
139
- pooled_height : int ,
140
- pooled_width : int ,
141
141
aligned : bool = True ,
142
142
clockwise : bool = False ):
143
143
super ().__init__ ()
144
144
145
- self .pooled_height = int (pooled_height )
146
- self .pooled_width = int (pooled_width )
145
+ self .pooled_h = int (pooled_h )
146
+ self .pooled_w = int (pooled_w )
147
147
self .spatial_scale = float (spatial_scale )
148
148
self .sampling_ratio = int (sampling_ratio )
149
149
self .aligned = aligned
150
150
self .clockwise = clockwise
151
151
152
152
def forward (self , input : torch .Tensor , rois : torch .Tensor ) -> torch .Tensor :
153
- return RoIAlignRotatedV2Function .apply (input , rois , self .spatial_scale ,
153
+ return RoIAlignRotatedV2Function .apply (input , rois ,
154
+ self .pooled_h ,
155
+ self .pooled_w ,
156
+ self .spatial_scale ,
154
157
self .sampling_ratio ,
155
- self .pooled_height ,
156
- self .pooled_width , self .aligned ,
158
+ self .aligned ,
157
159
self .clockwise )
158
160
159
161
def __repr__ (self ):
160
162
s = self .__class__ .__name__
161
- s += f'(pooled_height ={ self .pooled_height } , '
163
+ s += f'(pooled_h ={ self .pooled_h } , '
162
164
s += f'spatial_scale={ self .spatial_scale } , '
163
165
s += f'sampling_ratio={ self .sampling_ratio } , '
164
166
s += f'aligned={ self .aligned } , '
0 commit comments