Skip to content

Commit

Permalink
change the knn and three nn code
Browse files Browse the repository at this point in the history
  • Loading branch information
huangyuan64 committed Oct 29, 2024
1 parent d5085da commit 8f5a5b9
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 15 deletions.
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
at::Tensor target = new_xyz.contiguous();

bool is_from_knn = true;
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx);
}

void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
Expand Down
3 changes: 2 additions & 1 deletion mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
at::Tensor target = unknown.contiguous();

bool is_from_knn = false;
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
int nsample = 3;
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx);
}

void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
Expand Down
13 changes: 3 additions & 10 deletions mmcv/ops/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,10 @@ def forward(ctx,
N = xyz.shape[1]

if xyz.device.type == 'npu':
dist = center_xyz.new_zeros((B, npoint, N)).float()
dist2 = center_xyz.new_zeros((B, npoint, k)).float()
idx = center_xyz.new_zeros((B, npoint, k)).int()
ext_module.knn_forward(
xyz,
center_xyz,
torch.Tensor([]).npu(),
dist,
b=B,
n=N,
m=npoint,
nsample=k)
dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
zeros_idx = torch.zeros(
xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu()
idx.where(dist2 >= 1e10, zeros_idx)
Expand Down
6 changes: 3 additions & 3 deletions mmcv/ops/three_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def forward(ctx: Any, target: torch.Tensor,
if dtype_ == torch.float16:
target = target.float()
source = source.float()
dist = target.new_empty(B, N, m)
dist2 = target.new_empty(B, N, 3)
idx = target.new_empty(B, N, 3, dtype=torch.int32)
ext_module.three_nn_forward(
target, source, dist, torch.Tensor([]).npu(), b=B, n=N, m=m)
dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True)
target, source, dist2, idx, b=B, n=N, m=m)
dist2 = torch.sqrt(dist2)
if dtype_ == torch.float16:
dist2 = dist2.half()
Expand Down

0 comments on commit 8f5a5b9

Please sign in to comment.