Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix iou3d in parrots #2054

Merged
merged 1 commit into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions mmcv/ops/csrc/parrots/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,11 +564,11 @@ REGISTER_DEVICE_IMPL(group_points_forward_impl, CUDA,
REGISTER_DEVICE_IMPL(group_points_backward_impl, CUDA,
group_points_backward_cuda);

void IoU3DBoxesIoU3DForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a,
const int num_b,
const Tensor boxes_b,
Tensor ans_iou);
void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a,
const int num_b,
const Tensor boxes_b,
Tensor ans_overlap);

void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long* mask,
Expand All @@ -580,11 +580,11 @@ void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes,
int boxes_num,
float nms_overlap_thresh);

void iou3d_boxes_iou3d_forward_cuda(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_iou) {
IoU3DBoxesIoU3DForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b,
ans_iou);
void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_overlap) {
IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b,
ans_overlap);
};

void iou3d_nms3d_forward_cuda(const Tensor boxes, unsigned long long* mask,
Expand All @@ -600,9 +600,9 @@ void iou3d_nms3d_normal_forward_cuda(const Tensor boxes,
nms_overlap_thresh);
};

void iou3d_boxes_iou3d_forward_impl(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_iou);
void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_overlap);

void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long* mask,
int boxes_num, float nms_overlap_thresh);
Expand All @@ -611,8 +611,8 @@ void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
unsigned long long* mask, int boxes_num,
float nms_overlap_thresh);

REGISTER_DEVICE_IMPL(iou3d_boxes_iou3d_forward_impl, CUDA,
iou3d_boxes_iou3d_forward_cuda);
REGISTER_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, CUDA,
iou3d_boxes_overlap_bev_forward_cuda);
REGISTER_DEVICE_IMPL(iou3d_nms3d_forward_impl, CUDA, iou3d_nms3d_forward_cuda);
REGISTER_DEVICE_IMPL(iou3d_nms3d_normal_forward_impl, CUDA,
iou3d_nms3d_normal_forward_cuda);
Expand Down
16 changes: 9 additions & 7 deletions mmcv/ops/csrc/parrots/iou3d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ All Rights Reserved 2019-2020.

const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;

void iou3d_boxes_iou3d_forward_impl(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_iou) {
DISPATCH_DEVICE_IMPL(iou3d_boxes_iou3d_forward_impl, num_a, boxes_a, num_b,
boxes_b, ans_iou);
void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_overlap) {
DISPATCH_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, num_a, boxes_a,
num_b, boxes_b, ans_overlap);
}

void iou3d_nms3d_forward_impl(const Tensor boxes, unsigned long long *mask,
Expand All @@ -32,14 +32,16 @@ void iou3d_nms3d_normal_forward_impl(const Tensor boxes,
nms_overlap_thresh);
}

void iou3d_boxes_iou3d_forward(Tensor boxes_a, Tensor boxes_b, Tensor ans_iou) {
void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b,
Tensor ans_overlap) {
// params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 5)
// params ans_overlap: (N, M)
int num_a = boxes_a.size(0);
int num_b = boxes_b.size(0);

iou3d_boxes_iou3d_forward_impl(num_a, boxes_a, num_b, boxes_b, ans_iou);
iou3d_boxes_overlap_bev_forward_impl(num_a, boxes_a, num_b, boxes_b,
ans_overlap);
}

void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num,
Expand Down
13 changes: 6 additions & 7 deletions mmcv/ops/csrc/parrots/iou3d_parrots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@
using namespace parrots;

#ifdef MMCV_WITH_CUDA
void iou3d_boxes_iou3d_forward_cuda_parrots(CudaContext& ctx,
const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
void iou3d_boxes_overlap_bev_forward_cuda_parrots(
CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
auto boxes_a = buildATensor(ctx, ins[0]);
auto boxes_b = buildATensor(ctx, ins[1]);

auto ans_iou = buildATensor(ctx, outs[0]);

iou3d_boxes_iou3d_forward(boxes_a, boxes_b, ans_iou);
iou3d_boxes_overlap_bev_forward(boxes_a, boxes_b, ans_iou);
}

void iou3d_nms3d_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
Expand Down Expand Up @@ -49,10 +48,10 @@ void iou3d_nms3d_normal_forward_cuda_parrots(CudaContext& ctx,
iou3d_nms3d_normal_forward(boxes, keep, keep_num, nms_overlap_thresh);
}

PARROTS_EXTENSION_REGISTER(iou3d_boxes_iou3d_forward)
PARROTS_EXTENSION_REGISTER(iou3d_boxes_overlap_bev_forward)
.input(2)
.output(1)
.apply(iou3d_boxes_iou3d_forward_cuda_parrots)
.apply(iou3d_boxes_overlap_bev_forward_cuda_parrots)
.done();

PARROTS_EXTENSION_REGISTER(iou3d_nms3d_forward)
Expand Down
3 changes: 2 additions & 1 deletion mmcv/ops/csrc/parrots/iou3d_pytorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#include <torch/extension.h>
using namespace at;

void iou3d_boxes_iou3d_forward(Tensor boxes_a, Tensor boxes_b, Tensor ans_iou);
void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b,
Tensor ans_overlap);

void iou3d_nms3d_forward(Tensor boxes, Tensor keep, Tensor keep_num,
float nms_overlap_thresh);
Expand Down