Skip to content

Commit

Permalink
[Fix] Fix c++/cuda lint failed (open-mmlab#2528)
Browse files Browse the repository at this point in the history
* Fix c/cuda lint failed

* minor refine

* minor refine
  • Loading branch information
HAOCHENYE authored and root committed Jan 30, 2023
1 parent c8cc777 commit 141b6c4
Showing 1 changed file with 62 additions and 82 deletions.
144 changes: 62 additions & 82 deletions mmcv/ops/csrc/pytorch/npu/psa_mask_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,93 +3,73 @@
using namespace NPU_NAME_SPACE;
using namespace std;

void psamask_forward_npu(const int psa_type,
const Tensor x,
Tensor y,
const int num,
const int h_feature,
const int w_feature,
const int h_mask,
const int w_mask,
const int half_h_mask,
const int half_w_mask) {
int64_t psa_type_i64 = psa_type;
int64_t num_i64 = num;
int64_t h_feature_i64 = h_feature;
int64_t w_feature_i64 = w_feature;
int64_t h_mask_i64 = h_mask;
int64_t w_mask_i64 = w_mask;
int64_t half_h_mask_i64 = half_h_mask;
int64_t half_w_mask_i64 = half_w_mask;
OpCommand cmd;
cmd.Name("PSAMask")
.Input(x)
.Output(y)
.Attr("psa_type", psa_type_i64)
.Attr("num", num_i64)
.Attr("h_feature", h_feature_i64)
.Attr("w_feature", w_feature_i64)
.Attr("h_mask", h_mask_i64)
.Attr("w_mask", w_mask_i64)
.Attr("half_h_mask", half_h_mask_i64)
.Attr("half_w_mask", half_w_mask_i64)
.Run();
void psamask_forward_npu(const int psa_type, const Tensor x, Tensor y,
const int num, const int h_feature,
const int w_feature, const int h_mask,
const int w_mask, const int half_h_mask,
const int half_w_mask) {
int64_t psa_type_i64 = psa_type;
int64_t num_i64 = num;
int64_t h_feature_i64 = h_feature;
int64_t w_feature_i64 = w_feature;
int64_t h_mask_i64 = h_mask;
int64_t w_mask_i64 = w_mask;
int64_t half_h_mask_i64 = half_h_mask;
int64_t half_w_mask_i64 = half_w_mask;
OpCommand cmd;
cmd.Name("PSAMask")
.Input(x)
.Output(y)
.Attr("psa_type", psa_type_i64)
.Attr("num", num_i64)
.Attr("h_feature", h_feature_i64)
.Attr("w_feature", w_feature_i64)
.Attr("h_mask", h_mask_i64)
.Attr("w_mask", w_mask_i64)
.Attr("half_h_mask", half_h_mask_i64)
.Attr("half_w_mask", half_w_mask_i64)
.Run();
}

void psamask_forward_impl(const int psa_type,
const Tensor x,
Tensor y,
const int num,
const int h_feature,
const int w_feature,
const int h_mask,
const int w_mask,
const int half_h_mask,
const int half_w_mask);
void psamask_forward_impl(const int psa_type, const Tensor x, Tensor y,
const int num, const int h_feature,
const int w_feature, const int h_mask,
const int w_mask, const int half_h_mask,
const int half_w_mask);

void psamask_backward_npu(const int psa_type,
const Tensor y_grad,
Tensor x_grad,
const int num,
const int h_feature,
const int w_feature,
const int h_mask,
const int w_mask,
const int half_h_mask,
const int half_w_mask) {
int64_t psa_type_i64 = psa_type;
int64_t num_i64 = num;
int64_t h_feature_i64 = h_feature;
int64_t w_feature_i64 = w_feature;
int64_t h_mask_i64 = h_mask;
int64_t w_mask_i64 = w_mask;
int64_t half_h_mask_i64 = half_h_mask;
int64_t half_w_mask_i64 = half_w_mask;
OpCommand cmd;
cmd.Name("PSAMaskGrad")
.Input(y_grad)
.Output(x_grad)
.Attr("psa_type", psa_type_i64)
.Attr("num", num_i64)
.Attr("h_feature", h_feature_i64)
.Attr("w_feature", w_feature_i64)
.Attr("h_mask", h_mask_i64)
.Attr("w_mask", w_mask_i64)
.Attr("half_h_mask", half_h_mask_i64)
.Attr("half_w_mask", half_w_mask_i64)
.Run();
void psamask_backward_npu(const int psa_type, const Tensor y_grad,
Tensor x_grad, const int num, const int h_feature,
const int w_feature, const int h_mask,
const int w_mask, const int half_h_mask,
const int half_w_mask) {
int64_t psa_type_i64 = psa_type;
int64_t num_i64 = num;
int64_t h_feature_i64 = h_feature;
int64_t w_feature_i64 = w_feature;
int64_t h_mask_i64 = h_mask;
int64_t w_mask_i64 = w_mask;
int64_t half_h_mask_i64 = half_h_mask;
int64_t half_w_mask_i64 = half_w_mask;
OpCommand cmd;
cmd.Name("PSAMaskGrad")
.Input(y_grad)
.Output(x_grad)
.Attr("psa_type", psa_type_i64)
.Attr("num", num_i64)
.Attr("h_feature", h_feature_i64)
.Attr("w_feature", w_feature_i64)
.Attr("h_mask", h_mask_i64)
.Attr("w_mask", w_mask_i64)
.Attr("half_h_mask", half_h_mask_i64)
.Attr("half_w_mask", half_w_mask_i64)
.Run();
}

void psamask_backward_impl(const int psa_type,
const Tensor y_grad,
Tensor x_grad,
const int num,
const int h_feature,
const int w_feature,
const int h_mask,
const int w_mask,
const int half_h_mask,
const int half_w_mask);
void psamask_backward_impl(const int psa_type, const Tensor y_grad,
Tensor x_grad, const int num, const int h_feature,
const int w_feature, const int h_mask,
const int w_mask, const int half_h_mask,
const int half_w_mask);

REGISTER_NPU_IMPL(psamask_forward_impl, psamask_forward_npu);
REGISTER_NPU_IMPL(psamask_backward_impl, psamask_backward_npu);

0 comments on commit 141b6c4

Please sign in to comment.