Skip to content

Commit

Permalink
[Fix] Fix psamask backward api call.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanieeelLiu committed Apr 12, 2023
1 parent 65a7115 commit 1b4a104
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
auto dy_impl = torch_mlu::getMluTensorImpl(dy_tensor);
auto dy_ptr = dy_impl->cnnlMalloc();

mluOpPsamaskForward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask, w_mask,
dx_tmp_desc.desc(), dx_ptr);
mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask, w_mask,
dx_tmp_desc.desc(), dx_ptr);

dx.copy_(dx_tmp);
}
Expand Down

0 comments on commit 1b4a104

Please sign in to comment.