Skip to content

Commit f27fff7

Browse files
authored
Merge pull request #81 from ason-rob/rc4main
roi_align_rotated_v2
2 parents 2b7def6 + ca26c89 commit f27fff7

File tree

4 files changed

+66
-53
lines changed

4 files changed

+66
-53
lines changed

mmcv/ops/csrc/pytorch/npu/roi_align_rotated_v2_npu.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,36 @@
33
using namespace NPU_NAME_SPACE;
44
using namespace std;
55

6-
void roi_align_rotated_v2_forward_npu(const Tensor input, Tensor rois_map,
7-
Tensor output,
6+
void roi_align_rotated_v2_forward_npu(const Tensor x, Tensor rois_map,
7+
Tensor y,
8+
int32_t pooled_h,
9+
int32_t pooled_w,
810
double spatial_scale,
911
int32_t sampling_ratio,
10-
int32_t pooled_height,
11-
int32_t pooled_width,
1212
bool aligned,
1313
bool clockwise) {
14-
at::Tensor feature_map = input.permute({0, 2, 3, 1}).contiguous();
14+
at::Tensor feature_map = x.permute({0, 2, 3, 1}).contiguous();
1515
at::Tensor rois = rois_map.permute({1, 0}).contiguous();
16-
EXEC_NPU_CMD(aclnnRoiAlignRotatedV2, feature_map, rois, spatial_scale, sampling_ratio, pooled_height, pooled_width, aligned, clockwise, output);
16+
at_npu::native::OpCommand cmd;
17+
cmd.Name("RoiAlignRotated")
18+
.Input(feature_map)
19+
.Input(rois)
20+
.Output(y)
21+
.Attr("pooled_h", static_cast<int64_t>(pooled_h))
22+
.Attr("pooled_w", static_cast<int64_t>(pooled_w))
23+
.Attr("spatial_scale", static_cast<float>(spatial_scale))
24+
.Attr("sampling_ratio", static_cast<int64_t>(sampling_ratio))
25+
.Attr("aligned", aligned)
26+
.Attr("clockwise", clockwise)
27+
.Run();
1728
}
1829

19-
void roi_align_rotated_v2_forward_impl(const Tensor input, Tensor rois,
20-
Tensor output,
30+
void roi_align_rotated_v2_forward_impl(const Tensor x, Tensor rois,
31+
Tensor y,
32+
int32_t pooled_h,
33+
int32_t pooled_w,
2134
double spatial_scale,
2235
int32_t sampling_ratio,
23-
int32_t pooled_height,
24-
int32_t pooled_width,
2536
bool aligned,
2637
bool clockwise);
2738

mmcv/ops/csrc/pytorch/pybind.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y,
209209
int sampling_ratio, int pool_mode, bool aligned);
210210

211211
void roi_align_rotated_v2_forward(Tensor input, Tensor rois, Tensor output,
212+
int pooled_h, int pooled_w,
212213
double spatial_scale, int sampling_ratio,
213-
int aligned_height, int aligned_width,
214214
bool aligned, bool clockwise);
215215

216216
void roi_align_rotated_v2_backward(Tensor input, Tensor rois,
@@ -343,9 +343,9 @@ void roi_align_rotated_backward(Tensor grad_output, Tensor rois,
343343
bool clockwise);
344344

345345
void roi_align_rotated_v2_forward(Tensor input, Tensor rois, Tensor output,
346-
double spatial_scale, int sampling_ratio,
347-
int aligned_height, int aligned_width,
348-
bool aligned, bool clockwise);
346+
int pooled_h, int pooled_w,
347+
double spatial_scale, int sampling_ratio,
348+
bool aligned, bool clockwise);
349349

350350
void roi_align_rotated_v2_backward(Tensor input, Tensor rois, Tensor grad_output, Tensor grad_input,
351351
int pooled_height, int pooled_width, double spatial_scale,
@@ -814,8 +814,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
814814
py::arg("sampling_ratio"), py::arg("aligned"), py::arg("clockwise"));
815815
m.def("roi_align_rotated_v2_forward", &roi_align_rotated_v2_forward,
816816
"roi_align_rotated_v2_forward", py::arg("input"), py::arg("rois"),
817-
py::arg("output"), py::arg("spatial_scale"), py::arg("sampling_ratio"),
818-
py::arg("pooled_height"), py::arg("pooled_width"),
817+
py::arg("output"), py::arg("pooled_h"), py::arg("pooled_w"),
818+
py::arg("spatial_scale"), py::arg("sampling_ratio"),
819819
py::arg("aligned"), py::arg("clockwise"));
820820
m.def("roi_align_rotated_v2_backward", &roi_align_rotated_v2_backward,
821821
"roi_align_rotated_v2_backward", py::arg("input"), py::arg("rois"),

mmcv/ops/csrc/pytorch/roi_align_rotated_v2.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,22 @@
22
#include "pytorch_cpp_helper.hpp"
33
#include "pytorch_device_registry.hpp"
44

5-
void roi_align_rotated_v2_forward_impl(Tensor input, Tensor rois, Tensor output,
5+
void roi_align_rotated_v2_forward_impl(Tensor x, Tensor rois, Tensor y,
6+
int pooled_h, int pooled_w,
67
double spatial_scale, int sampling_ratio,
7-
int pooled_height, int pooled_width,
88
bool aligned, bool clockwise) {
9-
DISPATCH_DEVICE_IMPL(roi_align_rotated_v2_forward_impl, input, rois, output,
10-
spatial_scale, sampling_ratio, pooled_height, pooled_width,
9+
DISPATCH_DEVICE_IMPL(roi_align_rotated_v2_forward_impl, x, rois, y,
10+
pooled_h, pooled_w, spatial_scale, sampling_ratio,
1111
aligned, clockwise);
1212
}
1313

1414

15-
void roi_align_rotated_v2_forward(Tensor input, Tensor rois, Tensor output,
15+
void roi_align_rotated_v2_forward(Tensor x, Tensor rois, Tensor y,
16+
int pooled_h, int pooled_w,
1617
double spatial_scale, int sampling_ratio,
17-
int pooled_height, int pooled_width,
1818
bool aligned, bool clockwise) {
19-
roi_align_rotated_v2_forward_impl(input, rois, output, spatial_scale, sampling_ratio,
20-
pooled_height, pooled_width, aligned, clockwise);
19+
roi_align_rotated_v2_forward_impl(x, rois, y, pooled_h, pooled_w,
20+
spatial_scale, sampling_ratio, aligned, clockwise);
2121
}
2222

2323

mmcv/ops/roi_align_rotated_v2.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,55 +14,55 @@
1414
class RoIAlignRotatedV2Function(Function):
1515

1616
@staticmethod
17-
def symbolic(g, input, rois, spatial_scale, sampling_ratio, pooled_height,
18-
pooled_width, aligned, clockwise):
17+
def symbolic(g, x, rois, spatial_scale, sampling_ratio, pooled_h,
18+
pooled_w, aligned, clockwise):
1919
return g.op(
2020
'mmcv::MMCVRoIAlignRotatedV2',
21-
input,
21+
x,
2222
rois,
23+
pooled_h=pooled_h,
24+
pooled_w=pooled_w,
2325
spatial_scale_f=spatial_scale,
2426
sampling_ratio_i=sampling_ratio,
25-
pooled_height=pooled_height,
26-
pooled_width=pooled_width,
2727
aligned_i=aligned,
2828
clockwise_i=clockwise)
2929

3030
@staticmethod
3131
def forward(ctx: Any,
32-
input: torch.Tensor,
32+
x: torch.Tensor,
3333
rois: torch.Tensor,
34+
pooled_h: int,
35+
pooled_w: int,
3436
spatial_scale: float,
3537
sampling_ratio: int,
36-
pooled_height: int,
37-
pooled_width: int,
3838
aligned: bool = True,
3939
clockwise: bool = False) -> torch.Tensor:
40-
ctx.pooled_height = pooled_height
41-
ctx.pooled_width = pooled_width
40+
ctx.pooled_h = pooled_h
41+
ctx.pooled_w = pooled_w
4242
ctx.spatial_scale = spatial_scale
4343
ctx.sampling_ratio = sampling_ratio
4444
ctx.aligned = aligned
4545
ctx.clockwise = clockwise
46-
ctx.save_for_backward(input, rois)
47-
ctx.feature_size = input.size()
48-
batch_size, num_channels, data_height, data_width = input.size()
46+
ctx.save_for_backward(x, rois)
47+
ctx.feature_size = x.size()
48+
batch_size, num_channels, data_height, data_width = x.size()
4949
num_rois = rois.size(0)
5050

51-
output = input.new_zeros(num_rois, ctx.pooled_height, ctx.pooled_width,
51+
y = x.new_zeros(num_rois, ctx.pooled_h, ctx.pooled_w,
5252
num_channels)
5353

5454
ext_module.roi_align_rotated_v2_forward(
55-
input,
55+
x,
5656
rois,
57-
output,
57+
y,
58+
pooled_h=ctx.pooled_h,
59+
pooled_w=ctx.pooled_w,
5860
spatial_scale=ctx.spatial_scale,
5961
sampling_ratio=ctx.sampling_ratio,
60-
pooled_height=ctx.pooled_height,
61-
pooled_width=ctx.pooled_width,
6262
aligned=ctx.aligned,
6363
clockwise=ctx.clockwise)
64-
output = output.transpose(2, 3).transpose(1, 2).contiguous()
65-
return output
64+
y = y.transpose(2, 3).transpose(1, 2).contiguous()
65+
return y
6666

6767
@staticmethod
6868
def backward(ctx: Any, grad_output: torch.Tensor):
@@ -74,7 +74,7 @@ def backward(ctx: Any, grad_output: torch.Tensor):
7474
input.size(0), input.size(2), input.size(3), input.size(1))
7575
ext_module.roi_align_rotated_v2_backward(
7676
input, rois_trans, grad_output_trans, grad_input,
77-
ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
77+
ctx.pooled_h, ctx.pooled_w, ctx.spatial_scale,
7878
ctx.sampling_ratio, ctx.aligned, ctx.clockwise)
7979
grad_input = grad_input.permute(0, 3, 1, 2).contiguous()
8080

@@ -134,31 +134,33 @@ class RoIAlignRotatedV2(nn.Module):
134134
},
135135
cls_name='RoIAlignRotatedV2')
136136
def __init__(self,
137+
pooled_h: int,
138+
pooled_w: int,
137139
spatial_scale: float,
138140
sampling_ratio: int,
139-
pooled_height: int,
140-
pooled_width: int,
141141
aligned: bool = True,
142142
clockwise: bool = False):
143143
super().__init__()
144144

145-
self.pooled_height = int(pooled_height)
146-
self.pooled_width = int(pooled_width)
145+
self.pooled_h = int(pooled_h)
146+
self.pooled_w = int(pooled_w)
147147
self.spatial_scale = float(spatial_scale)
148148
self.sampling_ratio = int(sampling_ratio)
149149
self.aligned = aligned
150150
self.clockwise = clockwise
151151

152152
def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
153-
return RoIAlignRotatedV2Function.apply(input, rois, self.spatial_scale,
153+
return RoIAlignRotatedV2Function.apply(input, rois,
154+
self.pooled_h,
155+
self.pooled_w,
156+
self.spatial_scale,
154157
self.sampling_ratio,
155-
self.pooled_height,
156-
self.pooled_width, self.aligned,
158+
self.aligned,
157159
self.clockwise)
158160

159161
def __repr__(self):
160162
s = self.__class__.__name__
161-
s += f'(pooled_height={self.pooled_height}, '
163+
s += f'(pooled_h={self.pooled_h}, '
162164
s += f'spatial_scale={self.spatial_scale}, '
163165
s += f'sampling_ratio={self.sampling_ratio}, '
164166
s += f'aligned={self.aligned}, '

0 commit comments

Comments
 (0)