-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add ChamferDistance op for Parrots (#2189)
* Support Parrots extension for op ChamferDistance * Fix lint * Adapt op backward error of IntTensor for parrots * Fix Lint
- Loading branch information
Showing
9 changed files
with
174 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
// Modified from | ||
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp | ||
|
||
#include "pytorch_cpp_helper.hpp" | ||
#include "pytorch_device_registry.hpp" | ||
|
||
void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2, | ||
const Tensor dist1, const Tensor dist2, | ||
const Tensor idx1, const Tensor idx2) { | ||
DISPATCH_DEVICE_IMPL(chamfer_distance_forward_impl, xyz1, xyz2, dist1, dist2, | ||
idx1, idx2); | ||
} | ||
|
||
void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2, | ||
Tensor idx1, Tensor idx2, Tensor graddist1, | ||
Tensor graddist2, Tensor gradxyz1, | ||
Tensor gradxyz2) { | ||
DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, idx1, idx2, | ||
graddist1, graddist2, gradxyz1, gradxyz2); | ||
} | ||
|
||
void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2, | ||
const Tensor dist1, const Tensor dist2, | ||
const Tensor idx1, const Tensor idx2) { | ||
chamfer_distance_forward_impl(xyz1, xyz2, dist1, dist2, idx1, idx2); | ||
} | ||
|
||
void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2, | ||
Tensor idx1, Tensor idx2, Tensor graddist1, | ||
Tensor graddist2, Tensor gradxyz1, | ||
Tensor gradxyz2) { | ||
chamfer_distance_backward_impl(xyz1, xyz2, idx1, idx2, graddist1, graddist2, | ||
gradxyz1, gradxyz2); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
#include <parrots/compute/aten.hpp> | ||
#include <parrots/extension.hpp> | ||
#include <parrots/foundation/ssattrs.hpp> | ||
|
||
#include "chamfer_distance_pytorch.h" | ||
using namespace parrots; | ||
|
||
#ifdef MMCV_WITH_CUDA | ||
void chamfer_distance_forward_cuda_parrots(CudaContext& ctx, | ||
const SSElement& attr, | ||
const OperatorBase::in_list_t& ins, | ||
OperatorBase::out_list_t& outs) { | ||
auto xyz1 = buildATensor(ctx, ins[0]); | ||
auto xyz2 = buildATensor(ctx, ins[1]); | ||
auto dist1 = buildATensor(ctx, outs[0]); | ||
auto dist2 = buildATensor(ctx, outs[1]); | ||
auto idx1 = buildATensor(ctx, outs[2]); | ||
auto idx2 = buildATensor(ctx, outs[3]); | ||
chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); | ||
} | ||
|
||
void chamfer_distance_backward_cuda_parrots(CudaContext& ctx, | ||
const SSElement& attr, | ||
const OperatorBase::in_list_t& ins, | ||
OperatorBase::out_list_t& outs) { | ||
auto xyz1 = buildATensor(ctx, ins[0]); | ||
auto xyz2 = buildATensor(ctx, ins[1]); | ||
auto idx1 = buildATensor(ctx, ins[2]); | ||
auto idx2 = buildATensor(ctx, ins[3]); | ||
auto graddist1 = buildATensor(ctx, ins[4]); | ||
auto graddist2 = buildATensor(ctx, ins[5]); | ||
auto gradxyz1 = buildATensor(ctx, outs[0]); | ||
auto gradxyz2 = buildATensor(ctx, outs[1]); | ||
chamfer_distance_backward(xyz1, xyz2, idx1, idx2, graddist1, graddist2, | ||
gradxyz1, gradxyz2); | ||
} | ||
|
||
PARROTS_EXTENSION_REGISTER(chamfer_distance_forward) | ||
.input(2) | ||
.output(4) | ||
.apply(chamfer_distance_forward_cuda_parrots) | ||
.done(); | ||
|
||
PARROTS_EXTENSION_REGISTER(chamfer_distance_backward) | ||
.input(6) | ||
.output(2) | ||
.apply(chamfer_distance_backward_cuda_parrots) | ||
.done(); | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
#ifndef ACTIVE_CHAMFER_DISTANCE_PYTORCH_H | ||
#define ACTIVE_CHAMFER_DISTANCE_PYTORCH_H | ||
#include <torch/extension.h> | ||
using namespace at; | ||
|
||
void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2, | ||
const Tensor dist1, const Tensor dist2, | ||
const Tensor idx1, const Tensor idx); | ||
|
||
void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2, | ||
Tensor idx1, Tensor idx2, Tensor graddist1, | ||
Tensor graddist2, Tensor gradxyz1, | ||
Tensor gradxyz2); | ||
|
||
#endif // ACTIVE_CHAMFER_DISTANCE_PYTORCH_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters