Skip to content

Commit

Permalink
[Feat] Add pointnet operators (isl-org#5)
Browse files Browse the repository at this point in the history
* add pointnet op

* [Fix] Delete scripts directory

---------

Signed-off-by: Woodman3 <1025760745@qq.com>
Signed-off-by: RDIO <35186529+Woodman3@users.noreply.github.com>
  • Loading branch information
Woodman3 authored Nov 9, 2024
1 parent 501af0e commit 75cbc86
Show file tree
Hide file tree
Showing 10 changed files with 845 additions and 0 deletions.
12 changes: 12 additions & 0 deletions cpp/open3d/ml/paddle/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ target_sources(open3d_paddle_ops PRIVATE
misc/RoiPoolOps.cpp
)

target_sources(open3d_paddle_ops PRIVATE
pointnet/BallQueryOps.cpp
pointnet/InterpolateOps.cpp
pointnet/SamplingOps.cpp
)

target_sources(open3d_paddle_ops PRIVATE
../contrib/Nms.cpp
)
Expand All @@ -51,6 +57,12 @@ if (BUILD_CUDA_MODULE)
misc/VoxelizeOpKernel.cu
)

target_sources(open3d_paddle_ops PRIVATE
pointnet/BallQueryKernel.cu
pointnet/InterpolateKernel.cu
pointnet/SamplingKernel.cu
)

target_sources(open3d_paddle_ops PRIVATE
../contrib/BallQuery.cu
../contrib/InterpolatePoints.cu
Expand Down
74 changes: 74 additions & 0 deletions cpp/open3d/ml/paddle/pointnet/BallQueryKernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// ----------------------------------------------------------------------------
// - Open3D: www.open3d.org -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2023 www.open3d.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------
//***************************************************************************************/
//
// Based on Pointnet2 Library (MIT License):
// https://github.com/sshaoshuai/Pointnet2.PyPaddle
//
// Copyright (c) 2019 Shaoshuai Shi
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//
//***************************************************************************************/

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "open3d/ml/contrib/BallQuery.cuh"
#include "open3d/ml/contrib/cuda_utils.h"
#include "open3d/ml/paddle/pointnet/BallQueryKernel.h"

using namespace open3d::ml::contrib;

void ball_query_launcher(int b,
int n,
int m,
float radius,
int nsample,
const float *new_xyz,
const float *xyz,
int *idx,
uint64_t stream_id) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)

cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);

cudaError_t err;

dim3 blocks(DIVUP(m, THREADS_PER_BLOCK),
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);

ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample,
new_xyz, xyz, idx);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
44 changes: 44 additions & 0 deletions cpp/open3d/ml/paddle/pointnet/BallQueryKernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// ----------------------------------------------------------------------------
// - Open3D: www.open3d.org -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2023 www.open3d.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------
//***************************************************************************************/
//
// Based on Pointnet2 Library (MIT License):
// https://github.com/sshaoshuai/Pointnet2.PyPaddle
//
// Copyright (c) 2019 Shaoshuai Shi
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//
//***************************************************************************************/

#pragma once

void ball_query_launcher(int b,
int n,
int m,
float radius,
int nsample,
const float *xyz,
const float *new_xyz,
int *idx,
uint64_t stream_id);
87 changes: 87 additions & 0 deletions cpp/open3d/ml/paddle/pointnet/BallQueryOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// ----------------------------------------------------------------------------
// - Open3D: www.open3d.org -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2023 www.open3d.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------
//***************************************************************************************/
//
// Based on Pointnet2 Library (MIT License):
// https://github.com/sshaoshuai/Pointnet2.PyPaddle
//
// Copyright (c) 2019 Shaoshuai Shi
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//
//***************************************************************************************/

#include <vector>

#include "open3d/ml/paddle/PaddleHelper.h"
#include "open3d/ml/paddle/pointnet/BallQueryKernel.h"
#include "paddle/extension.h"

#ifdef BUILD_CUDA_MODULE

std::vector<paddle::Tensor> BallQuery(paddle::Tensor &xyz,
paddle::Tensor &center,
double radius,
const int64_t nsample) {
int batch_size = xyz.shape()[0];
int pts_num = xyz.shape()[1];
int ball_num = center.shape()[1];

auto place = xyz.place();
paddle::Tensor out =
paddle::full({batch_size, ball_num, nsample}, 0.0f,
paddle::DataType(ToPaddleDtype<int>()), place);

const float *center_data = center.data<float>();
const float *xyz_data = xyz.data<float>();
int *idx = out.data<int>();

ball_query_launcher(batch_size, pts_num, ball_num, radius, nsample,
center_data, xyz_data, idx,
reinterpret_cast<uint64_t>(xyz.stream()));
return {out};
}

std::vector<paddle::DataType> BallQueryInferDtype() {
return {paddle::DataType::FLOAT32};
}

std::vector<std::vector<int64_t>> BallQueryInferShape(
std::vector<int64_t> xyz_shape,
std::vector<int64_t> center_shape,
const int64_t nsample) {
return {{xyz_shape[0], xyz_shape[1], center_shape[1]}};
}

PD_BUILD_OP(open3d_ball_query)
.Inputs({"xyz", "center"})
.Outputs({"out"})
.Attrs({
"radius: double",
"nsample: int64_t",
})
.SetKernelFn(PD_KERNEL(BallQuery))
.SetInferShapeFn(PD_INFER_SHAPE(BallQueryInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(BallQueryInferDtype));

#endif
139 changes: 139 additions & 0 deletions cpp/open3d/ml/paddle/pointnet/InterpolateKernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// ----------------------------------------------------------------------------
// - Open3D: www.open3d.org -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2023 www.open3d.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------
//***************************************************************************************/
//
// Based on Pointnet2 Library (MIT License):
// https://github.com/sshaoshuai/Pointnet2.PyPaddle
//
// Copyright (c) 2019 Shaoshuai Shi
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//
//***************************************************************************************/

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <vector>

#include "open3d/ml/contrib/InterpolatePoints.cuh"
#include "open3d/ml/contrib/cuda_utils.h"
#include "open3d/ml/paddle/pointnet/InterpolateKernel.h"

using namespace open3d::ml::contrib;

void three_nn_launcher(int b,
int n,
int m,
const float *unknown,
const float *known,
float *dist2,
int *idx,
uint64_t stream_id) {
// unknown: (B, N, 3)
// known: (B, M, 3)
// output:
// dist2: (B, N, 3)
// idx: (B, N, 3)

cudaError_t err;

cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);

dim3 blocks(DIVUP(n, THREADS_PER_BLOCK),
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);

three_nn_kernel<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known,
dist2, idx);

err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}

void three_interpolate_launcher(int b,
int c,
int m,
int n,
const float *points,
const int *idx,
const float *weight,
float *out,
uint64_t stream_id) {
// points: (B, C, M)
// idx: (B, N, 3)
// weight: (B, N, 3)
// output:
// out: (B, C, N)

cudaError_t err;

cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);

dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c,
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
three_interpolate_kernel<<<blocks, threads, 0, stream>>>(b, c, m, n, points,
idx, weight, out);

err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}

void three_interpolate_grad_launcher(int b,
int c,
int n,
int m,
const float *grad_out,
const int *idx,
const float *weight,
float *grad_points,
uint64_t stream_id) {
// grad_out: (B, C, N)
// weight: (B, N, 3)
// output:
// grad_points: (B, C, M)

cudaError_t err;

cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);

dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c,
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
three_interpolate_grad_kernel<<<blocks, threads, 0, stream>>>(
b, c, n, m, grad_out, idx, weight, grad_points);

err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
Loading

0 comments on commit 75cbc86

Please sign in to comment.