Skip to content

Commit

Permalink
[Feature] Add ChamferDistance op for Parrots (#2189)
Browse files Browse the repository at this point in the history
* Support Parrots extension for op ChamferDistance

* Fix lint

* Adapt op backward error of IntTensor for parrots

* Fix Lint
  • Loading branch information
CokeDong authored Aug 15, 2022
1 parent 304f184 commit 595c2eb
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 37 deletions.
18 changes: 8 additions & 10 deletions mmcv/ops/chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,18 @@ def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:

@staticmethod
@once_differentiable
def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor,
grad_idx1: Tensor,
grad_idx2: Tensor) -> Tuple[Tensor, Tensor]:
def backward(ctx,
grad_dist1: Tensor,
grad_dist2: Tensor,
grad_idx1=None,
grad_idx2=None) -> Tuple[Tensor, Tensor]:
"""
Args:
grad_dist1 (Tensor): Gradient of chamfer distance
(xyz1 to xyz2) with shape (B, N).
grad_dist2 (Tensor): Gradient of chamfer distance
(xyz2 to xyz1) with shape (B, N).
grad_idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
with shape (B, N), which be used in compute gradient.
grad_idx2 (Tensor): Index of chamfer distance (xyz2 to xyz2)
with shape (B, N), which be used in compute gradient.
Returns:
Tuple[Tensor, Tensor]:
Expand All @@ -86,9 +84,9 @@ def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor,
grad_xyz1 = torch.zeros(xyz1.size()).to(device)
grad_xyz2 = torch.zeros(xyz2.size()).to(device)

ext_module.chamfer_distance_backward(xyz1, xyz2, grad_xyz1, grad_xyz2,
grad_dist1, grad_dist2, idx1,
idx2)
ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
grad_dist1, grad_dist2, grad_xyz1,
grad_xyz2)
return grad_xyz1, grad_xyz2


Expand Down
35 changes: 35 additions & 0 deletions mmcv/ops/csrc/parrots/chamfer_distance.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp

#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_forward_impl, xyz1, xyz2, dist1, dist2,
idx1, idx2);
}

void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, idx1, idx2,
graddist1, graddist2, gradxyz1, gradxyz2);
}

void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2) {
chamfer_distance_forward_impl(xyz1, xyz2, dist1, dist2, idx1, idx2);
}

void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
chamfer_distance_backward_impl(xyz1, xyz2, idx1, idx2, graddist1, graddist2,
gradxyz1, gradxyz2);
}
51 changes: 51 additions & 0 deletions mmcv/ops/csrc/parrots/chamfer_distance_parrots.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>

#include "chamfer_distance_pytorch.h"
using namespace parrots;

#ifdef MMCV_WITH_CUDA
void chamfer_distance_forward_cuda_parrots(CudaContext& ctx,
const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto xyz1 = buildATensor(ctx, ins[0]);
auto xyz2 = buildATensor(ctx, ins[1]);
auto dist1 = buildATensor(ctx, outs[0]);
auto dist2 = buildATensor(ctx, outs[1]);
auto idx1 = buildATensor(ctx, outs[2]);
auto idx2 = buildATensor(ctx, outs[3]);
chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}

void chamfer_distance_backward_cuda_parrots(CudaContext& ctx,
const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto xyz1 = buildATensor(ctx, ins[0]);
auto xyz2 = buildATensor(ctx, ins[1]);
auto idx1 = buildATensor(ctx, ins[2]);
auto idx2 = buildATensor(ctx, ins[3]);
auto graddist1 = buildATensor(ctx, ins[4]);
auto graddist2 = buildATensor(ctx, ins[5]);
auto gradxyz1 = buildATensor(ctx, outs[0]);
auto gradxyz2 = buildATensor(ctx, outs[1]);
chamfer_distance_backward(xyz1, xyz2, idx1, idx2, graddist1, graddist2,
gradxyz1, gradxyz2);
}

PARROTS_EXTENSION_REGISTER(chamfer_distance_forward)
.input(2)
.output(4)
.apply(chamfer_distance_forward_cuda_parrots)
.done();

PARROTS_EXTENSION_REGISTER(chamfer_distance_backward)
.input(6)
.output(2)
.apply(chamfer_distance_backward_cuda_parrots)
.done();

#endif
16 changes: 16 additions & 0 deletions mmcv/ops/csrc/parrots/chamfer_distance_pytorch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ACTIVE_CHAMFER_DISTANCE_PYTORCH_H
#define ACTIVE_CHAMFER_DISTANCE_PYTORCH_H
#include <torch/extension.h>
using namespace at;

void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx);

void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);

#endif // ACTIVE_CHAMFER_DISTANCE_PYTORCH_H
37 changes: 37 additions & 0 deletions mmcv/ops/csrc/parrots/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1589,3 +1589,40 @@ Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,

REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA,
diff_iou_rotated_sort_vertices_forward_cuda);

void ChamferDistanceForwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
const Tensor dist2, const Tensor idx1, const Tensor idx2);

void ChamferDistanceBackwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2,
Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2);

void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2) {
ChamferDistanceForwardCUDAKernelLauncher(xyz1, xyz2, dist1, dist2, idx1,
idx2);
};

void chamfer_distance_backward_cuda(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, idx1, idx2, graddist1,
graddist2, gradxyz1, gradxyz2);
};

void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2);

void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);

REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
chamfer_distance_forward_cuda);
REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA,
chamfer_distance_backward_cuda);
20 changes: 10 additions & 10 deletions mmcv/ops/csrc/pytorch/chamfer_distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
}

void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2,
Tensor idx1, Tensor idx2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, gradxyz1,
gradxyz2, graddist1, graddist2, idx1, idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, idx1, idx2,
graddist1, graddist2, gradxyz1, gradxyz2);
}

void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
Expand All @@ -27,9 +27,9 @@ void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
}

void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2, Tensor idx1,
Tensor idx2) {
chamfer_distance_backward_impl(xyz1, xyz2, gradxyz1, gradxyz2, graddist1,
graddist2, idx1, idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
chamfer_distance_backward_impl(xyz1, xyz2, idx1, idx2, graddist1, graddist2,
gradxyz1, gradxyz2);
}
4 changes: 2 additions & 2 deletions mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ void ChamferDistanceForwardCUDAKernelLauncher(
}

void ChamferDistanceBackwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2) {
const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2,
Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2) {
int batch_size = xyz1.size(0);
int n = xyz1.size(1);
int m = xyz2.size(1);
Expand Down
20 changes: 10 additions & 10 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1706,8 +1706,8 @@ void ChamferDistanceForwardCUDAKernelLauncher(
const Tensor dist2, const Tensor idx1, const Tensor idx2);

void ChamferDistanceBackwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2);
const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2,
Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2);

void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
Expand All @@ -1717,21 +1717,21 @@ void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2,
};

void chamfer_distance_backward_cuda(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2,
Tensor idx1, Tensor idx2) {
ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, gradxyz1, gradxyz2,
graddist1, graddist2, idx1, idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, idx1, idx2, graddist1,
graddist2, gradxyz1, gradxyz2);
};

void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2);

void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2,
Tensor idx1, Tensor idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);

REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
chamfer_distance_forward_cuda);
Expand Down
10 changes: 5 additions & 5 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,9 @@ void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
const Tensor idx1, const Tensor idx);

void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2, Tensor idx1,
Tensor idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
Expand Down Expand Up @@ -838,8 +838,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("dist1"), py::arg("dist2"), py::arg("idx1"), py::arg("idx2"));
m.def("chamfer_distance_backward", &chamfer_distance_backward,
"chamfer_distance_backward", py::arg("xyz1"), py::arg("xyz2"),
py::arg("gradxyz1"), py::arg("gradxyz2"), py::arg("graddist1"),
py::arg("graddist2"), py::arg("idx1"), py::arg("idx2"));
py::arg("idx1"), py::arg("idx2"), py::arg("graddist1"),
py::arg("graddist2"), py::arg("gradxyz1"), py::arg("gradxyz2"));
m.def("prroi_pool_forward", &prroi_pool_forward, "prroi_pool forward",
py::arg("input"), py::arg("rois"), py::arg("output"),
py::arg("pooled_height"), py::arg("pooled_width"),
Expand Down

0 comments on commit 595c2eb

Please sign in to comment.