Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Parrots extension for op ChamferDistance #2189

Merged
merged 4 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions mmcv/ops/chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,18 @@ def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:

@staticmethod
@once_differentiable
def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor,
grad_idx1: Tensor,
grad_idx2: Tensor) -> Tuple[Tensor, Tensor]:
def backward(ctx,
grad_dist1: Tensor,
grad_dist2: Tensor,
grad_idx1=None,
grad_idx2=None) -> Tuple[Tensor, Tensor]:
"""

Args:
grad_dist1 (Tensor): Gradient of chamfer distance
(xyz1 to xyz2) with shape (B, N).
grad_dist2 (Tensor): Gradient of chamfer distance
(xyz2 to xyz1) with shape (B, N).
grad_idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
with shape (B, N), which be used in compute gradient.
grad_idx2 (Tensor): Index of chamfer distance (xyz2 to xyz2)
with shape (B, N), which be used in compute gradient.

Returns:
Tuple[Tensor, Tensor]:
Expand All @@ -86,9 +84,9 @@ def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor,
grad_xyz1 = torch.zeros(xyz1.size()).to(device)
grad_xyz2 = torch.zeros(xyz2.size()).to(device)

ext_module.chamfer_distance_backward(xyz1, xyz2, grad_xyz1, grad_xyz2,
grad_dist1, grad_dist2, idx1,
idx2)
ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2,
grad_dist1, grad_dist2, grad_xyz1,
grad_xyz2)
return grad_xyz1, grad_xyz2


Expand Down
35 changes: 35 additions & 0 deletions mmcv/ops/csrc/parrots/chamfer_distance.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp

#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_forward_impl, xyz1, xyz2, dist1, dist2,
idx1, idx2);
}

void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, idx1, idx2,
graddist1, graddist2, gradxyz1, gradxyz2);
}

void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2) {
chamfer_distance_forward_impl(xyz1, xyz2, dist1, dist2, idx1, idx2);
}

void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
chamfer_distance_backward_impl(xyz1, xyz2, idx1, idx2, graddist1, graddist2,
gradxyz1, gradxyz2);
}
51 changes: 51 additions & 0 deletions mmcv/ops/csrc/parrots/chamfer_distance_parrots.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>

#include "chamfer_distance_pytorch.h"
using namespace parrots;

#ifdef MMCV_WITH_CUDA
void chamfer_distance_forward_cuda_parrots(CudaContext& ctx,
const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto xyz1 = buildATensor(ctx, ins[0]);
auto xyz2 = buildATensor(ctx, ins[1]);
auto dist1 = buildATensor(ctx, outs[0]);
auto dist2 = buildATensor(ctx, outs[1]);
auto idx1 = buildATensor(ctx, outs[2]);
auto idx2 = buildATensor(ctx, outs[3]);
chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}

void chamfer_distance_backward_cuda_parrots(CudaContext& ctx,
const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto xyz1 = buildATensor(ctx, ins[0]);
auto xyz2 = buildATensor(ctx, ins[1]);
auto idx1 = buildATensor(ctx, ins[2]);
auto idx2 = buildATensor(ctx, ins[3]);
auto graddist1 = buildATensor(ctx, ins[4]);
auto graddist2 = buildATensor(ctx, ins[5]);
auto gradxyz1 = buildATensor(ctx, outs[0]);
auto gradxyz2 = buildATensor(ctx, outs[1]);
chamfer_distance_backward(xyz1, xyz2, idx1, idx2, graddist1, graddist2,
gradxyz1, gradxyz2);
}

PARROTS_EXTENSION_REGISTER(chamfer_distance_forward)
.input(2)
.output(4)
.apply(chamfer_distance_forward_cuda_parrots)
.done();

PARROTS_EXTENSION_REGISTER(chamfer_distance_backward)
.input(6)
.output(2)
.apply(chamfer_distance_backward_cuda_parrots)
.done();

#endif
16 changes: 16 additions & 0 deletions mmcv/ops/csrc/parrots/chamfer_distance_pytorch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ACTIVE_CHAMFER_DISTANCE_PYTORCH_H
#define ACTIVE_CHAMFER_DISTANCE_PYTORCH_H
#include <torch/extension.h>
using namespace at;

void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx);

void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);

#endif // ACTIVE_CHAMFER_DISTANCE_PYTORCH_H
37 changes: 37 additions & 0 deletions mmcv/ops/csrc/parrots/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1589,3 +1589,40 @@ Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,

REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA,
diff_iou_rotated_sort_vertices_forward_cuda);

void ChamferDistanceForwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
const Tensor dist2, const Tensor idx1, const Tensor idx2);

void ChamferDistanceBackwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2,
Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2);

void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2) {
ChamferDistanceForwardCUDAKernelLauncher(xyz1, xyz2, dist1, dist2, idx1,
idx2);
};

void chamfer_distance_backward_cuda(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, idx1, idx2, graddist1,
graddist2, gradxyz1, gradxyz2);
};

void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2);

void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);

REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
chamfer_distance_forward_cuda);
REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA,
chamfer_distance_backward_cuda);
20 changes: 10 additions & 10 deletions mmcv/ops/csrc/pytorch/chamfer_distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
}

void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2,
Tensor idx1, Tensor idx2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, gradxyz1,
gradxyz2, graddist1, graddist2, idx1, idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, idx1, idx2,
graddist1, graddist2, gradxyz1, gradxyz2);
}

void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
Expand All @@ -27,9 +27,9 @@ void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
}

void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2, Tensor idx1,
Tensor idx2) {
chamfer_distance_backward_impl(xyz1, xyz2, gradxyz1, gradxyz2, graddist1,
graddist2, idx1, idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
chamfer_distance_backward_impl(xyz1, xyz2, idx1, idx2, graddist1, graddist2,
gradxyz1, gradxyz2);
}
4 changes: 2 additions & 2 deletions mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ void ChamferDistanceForwardCUDAKernelLauncher(
}

void ChamferDistanceBackwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2) {
const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2,
Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2) {
int batch_size = xyz1.size(0);
int n = xyz1.size(1);
int m = xyz2.size(1);
Expand Down
20 changes: 10 additions & 10 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1706,8 +1706,8 @@ void ChamferDistanceForwardCUDAKernelLauncher(
const Tensor dist2, const Tensor idx1, const Tensor idx2);

void ChamferDistanceBackwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2);
const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2,
Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2);

void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
Expand All @@ -1717,21 +1717,21 @@ void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2,
};

void chamfer_distance_backward_cuda(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2,
Tensor idx1, Tensor idx2) {
ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, gradxyz1, gradxyz2,
graddist1, graddist2, idx1, idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2) {
ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, idx1, idx2, graddist1,
graddist2, gradxyz1, gradxyz2);
};

void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2);

void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2,
Tensor idx1, Tensor idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);

REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
chamfer_distance_forward_cuda);
Expand Down
10 changes: 5 additions & 5 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,9 @@ void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
const Tensor idx1, const Tensor idx);

void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2, Tensor idx1,
Tensor idx2);
Tensor idx1, Tensor idx2, Tensor graddist1,
Tensor graddist2, Tensor gradxyz1,
Tensor gradxyz2);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
Expand Down Expand Up @@ -838,8 +838,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("dist1"), py::arg("dist2"), py::arg("idx1"), py::arg("idx2"));
m.def("chamfer_distance_backward", &chamfer_distance_backward,
"chamfer_distance_backward", py::arg("xyz1"), py::arg("xyz2"),
py::arg("gradxyz1"), py::arg("gradxyz2"), py::arg("graddist1"),
py::arg("graddist2"), py::arg("idx1"), py::arg("idx2"));
py::arg("idx1"), py::arg("idx2"), py::arg("graddist1"),
py::arg("graddist2"), py::arg("gradxyz1"), py::arg("gradxyz2"));
m.def("prroi_pool_forward", &prroi_pool_forward, "prroi_pool forward",
py::arg("input"), py::arg("rois"), py::arg("output"),
py::arg("pooled_height"), py::arg("pooled_width"),
Expand Down