diff --git a/mmcv/ops/chamfer_distance.py b/mmcv/ops/chamfer_distance.py index d68eafb47c..1f908a5bbc 100644 --- a/mmcv/ops/chamfer_distance.py +++ b/mmcv/ops/chamfer_distance.py @@ -56,9 +56,11 @@ def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]: @staticmethod @once_differentiable - def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor, - grad_idx1: Tensor, - grad_idx2: Tensor) -> Tuple[Tensor, Tensor]: + def backward(ctx, + grad_dist1: Tensor, + grad_dist2: Tensor, + grad_idx1=None, + grad_idx2=None) -> Tuple[Tensor, Tensor]: """ Args: @@ -66,10 +68,6 @@ def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor, (xyz1 to xyz2) with shape (B, N). grad_dist2 (Tensor): Gradient of chamfer distance (xyz2 to xyz1) with shape (B, N). - grad_idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2) - with shape (B, N), which be used in compute gradient. - grad_idx2 (Tensor): Index of chamfer distance (xyz2 to xyz2) - with shape (B, N), which be used in compute gradient. Returns: Tuple[Tensor, Tensor]: @@ -86,9 +84,9 @@ def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor, grad_xyz1 = torch.zeros(xyz1.size()).to(device) grad_xyz2 = torch.zeros(xyz2.size()).to(device) - ext_module.chamfer_distance_backward(xyz1, xyz2, grad_xyz1, grad_xyz2, - grad_dist1, grad_dist2, idx1, - idx2) + ext_module.chamfer_distance_backward(xyz1, xyz2, idx1, idx2, + grad_dist1, grad_dist2, grad_xyz1, + grad_xyz2) return grad_xyz1, grad_xyz2 diff --git a/mmcv/ops/csrc/parrots/chamfer_distance.cpp b/mmcv/ops/csrc/parrots/chamfer_distance.cpp new file mode 100644 index 0000000000..dcff698931 --- /dev/null +++ b/mmcv/ops/csrc/parrots/chamfer_distance.cpp @@ -0,0 +1,35 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp + +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx2) { + DISPATCH_DEVICE_IMPL(chamfer_distance_forward_impl, xyz1, xyz2, dist1, dist2, + idx1, idx2); +} + +void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2, + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2) { + DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, idx1, idx2, + graddist1, graddist2, gradxyz1, gradxyz2); +} + +void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx2) { + chamfer_distance_forward_impl(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + +void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2, + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2) { + chamfer_distance_backward_impl(xyz1, xyz2, idx1, idx2, graddist1, graddist2, + gradxyz1, gradxyz2); +} diff --git a/mmcv/ops/csrc/parrots/chamfer_distance_parrots.cpp b/mmcv/ops/csrc/parrots/chamfer_distance_parrots.cpp new file mode 100644 index 0000000000..db8eff1d6f --- /dev/null +++ b/mmcv/ops/csrc/parrots/chamfer_distance_parrots.cpp @@ -0,0 +1,51 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include +#include +#include + +#include "chamfer_distance_pytorch.h" +using namespace parrots; + +#ifdef MMCV_WITH_CUDA +void chamfer_distance_forward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + auto xyz1 = buildATensor(ctx, ins[0]); + auto xyz2 = buildATensor(ctx, ins[1]); + auto dist1 = buildATensor(ctx, outs[0]); + auto dist2 = buildATensor(ctx, outs[1]); + auto idx1 = buildATensor(ctx, outs[2]); + auto idx2 = buildATensor(ctx, outs[3]); + chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + +void chamfer_distance_backward_cuda_parrots(CudaContext& ctx, + const SSElement& attr, + const OperatorBase::in_list_t& ins, + OperatorBase::out_list_t& outs) { + auto xyz1 = buildATensor(ctx, ins[0]); + auto xyz2 = buildATensor(ctx, ins[1]); + auto idx1 = buildATensor(ctx, ins[2]); + auto idx2 = buildATensor(ctx, ins[3]); + auto graddist1 = buildATensor(ctx, ins[4]); + auto graddist2 = buildATensor(ctx, ins[5]); + auto gradxyz1 = buildATensor(ctx, outs[0]); + auto gradxyz2 = buildATensor(ctx, outs[1]); + chamfer_distance_backward(xyz1, xyz2, idx1, idx2, graddist1, graddist2, + gradxyz1, gradxyz2); +} + +PARROTS_EXTENSION_REGISTER(chamfer_distance_forward) + .input(2) + .output(4) + .apply(chamfer_distance_forward_cuda_parrots) + .done(); + +PARROTS_EXTENSION_REGISTER(chamfer_distance_backward) + .input(6) + .output(2) + .apply(chamfer_distance_backward_cuda_parrots) + .done(); + +#endif diff --git a/mmcv/ops/csrc/parrots/chamfer_distance_pytorch.h b/mmcv/ops/csrc/parrots/chamfer_distance_pytorch.h new file mode 100644 index 0000000000..6405526b0c --- /dev/null +++ b/mmcv/ops/csrc/parrots/chamfer_distance_pytorch.h @@ -0,0 +1,16 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ACTIVE_CHAMFER_DISTANCE_PYTORCH_H +#define ACTIVE_CHAMFER_DISTANCE_PYTORCH_H +#include +using namespace at; + +void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx); + +void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2, + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2); + +#endif // ACTIVE_CHAMFER_DISTANCE_PYTORCH_H diff --git a/mmcv/ops/csrc/parrots/cudabind.cpp b/mmcv/ops/csrc/parrots/cudabind.cpp index 04c6e36c4a..35c18da321 100644 --- a/mmcv/ops/csrc/parrots/cudabind.cpp +++ b/mmcv/ops/csrc/parrots/cudabind.cpp @@ -1589,3 +1589,40 @@ Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask, REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA, diff_iou_rotated_sort_vertices_forward_cuda); + +void ChamferDistanceForwardCUDAKernelLauncher( + const Tensor xyz1, const Tensor xyz2, const Tensor dist1, + const Tensor dist2, const Tensor idx1, const Tensor idx2); + +void ChamferDistanceBackwardCUDAKernelLauncher( + const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2, + Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2); + +void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx2) { + ChamferDistanceForwardCUDAKernelLauncher(xyz1, xyz2, dist1, dist2, idx1, + idx2); +}; + +void chamfer_distance_backward_cuda(const Tensor xyz1, const Tensor xyz2, + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2) { + ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, idx1, idx2, graddist1, + graddist2, gradxyz1, gradxyz2); +}; + +void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx2); + +void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2, + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2); + +REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA, + chamfer_distance_forward_cuda); +REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA, + chamfer_distance_backward_cuda); diff --git a/mmcv/ops/csrc/pytorch/chamfer_distance.cpp b/mmcv/ops/csrc/pytorch/chamfer_distance.cpp index 6ea1ba675e..dcff698931 100644 --- a/mmcv/ops/csrc/pytorch/chamfer_distance.cpp +++ b/mmcv/ops/csrc/pytorch/chamfer_distance.cpp @@ -13,11 +13,11 @@ void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2, } void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2, - Tensor gradxyz1, Tensor gradxyz2, - Tensor graddist1, Tensor graddist2, - Tensor idx1, Tensor idx2) { - DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, gradxyz1, - gradxyz2, graddist1, graddist2, idx1, idx2); + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2) { + DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, idx1, idx2, + graddist1, graddist2, gradxyz1, gradxyz2); } void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2, @@ -27,9 +27,9 @@ void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2, } void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2, - Tensor gradxyz1, Tensor gradxyz2, - Tensor graddist1, Tensor graddist2, Tensor idx1, - Tensor idx2) { - chamfer_distance_backward_impl(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, - graddist2, idx1, idx2); + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2) { + chamfer_distance_backward_impl(xyz1, xyz2, idx1, idx2, graddist1, graddist2, + gradxyz1, gradxyz2); } diff --git a/mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu index 980482eb54..6effa29ee7 100644 --- a/mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu @@ -33,8 +33,8 @@ void ChamferDistanceForwardCUDAKernelLauncher( } void ChamferDistanceBackwardCUDAKernelLauncher( - const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2, - Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2) { + const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2, + Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2) { int batch_size = xyz1.size(0); int n = xyz1.size(1); int m = xyz2.size(1); diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp index 12cf7afdc2..1df35f510d 100644 --- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp +++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp @@ -1706,8 +1706,8 @@ void ChamferDistanceForwardCUDAKernelLauncher( const Tensor dist2, const Tensor idx1, const Tensor idx2); void ChamferDistanceBackwardCUDAKernelLauncher( - const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2, - Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2); + const Tensor xyz1, const Tensor xyz2, Tensor idx1, Tensor idx2, + Tensor grad_dist1, Tensor grad_dist2, Tensor grad_xyz1, Tensor grad_xyz2); void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2, const Tensor dist1, const Tensor dist2, @@ -1717,11 +1717,11 @@ void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2, }; void chamfer_distance_backward_cuda(const Tensor xyz1, const Tensor xyz2, - Tensor gradxyz1, Tensor gradxyz2, - Tensor graddist1, Tensor graddist2, - Tensor idx1, Tensor idx2) { - ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, gradxyz1, gradxyz2, - graddist1, graddist2, idx1, idx2); + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2) { + ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, idx1, idx2, graddist1, + graddist2, gradxyz1, gradxyz2); }; void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2, @@ -1729,9 +1729,9 @@ void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2, const Tensor idx1, const Tensor idx2); void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2, - Tensor gradxyz1, Tensor gradxyz2, - Tensor graddist1, Tensor graddist2, - Tensor idx1, Tensor idx2); + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2); REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA, chamfer_distance_forward_cuda); diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index c134090871..6fb7a8a53f 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -419,9 +419,9 @@ void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2, const Tensor idx1, const Tensor idx); void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2, - Tensor gradxyz1, Tensor gradxyz2, - Tensor graddist1, Tensor graddist2, Tensor idx1, - Tensor idx2); + Tensor idx1, Tensor idx2, Tensor graddist1, + Tensor graddist2, Tensor gradxyz1, + Tensor gradxyz2); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), @@ -838,8 +838,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("dist1"), py::arg("dist2"), py::arg("idx1"), py::arg("idx2")); m.def("chamfer_distance_backward", &chamfer_distance_backward, "chamfer_distance_backward", py::arg("xyz1"), py::arg("xyz2"), - py::arg("gradxyz1"), py::arg("gradxyz2"), py::arg("graddist1"), - py::arg("graddist2"), py::arg("idx1"), py::arg("idx2")); + py::arg("idx1"), py::arg("idx2"), py::arg("graddist1"), + py::arg("graddist2"), py::arg("gradxyz1"), py::arg("gradxyz2")); m.def("prroi_pool_forward", &prroi_pool_forward, "prroi_pool forward", py::arg("input"), py::arg("rois"), py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),