From 016c03e75ce8d45ad5c1795a3c4259cb2cd12bf4 Mon Sep 17 00:00:00 2001 From: Prantl Date: Tue, 16 Feb 2021 17:29:47 +0100 Subject: [PATCH] PointRCNN ops (#3021) * pointnet++ and roipooling operators added, only torch gpu supported * style * moved ATen import in .cu files * ops builded for cuda module only * added newlines and namespace, removed npy files * removed security risk of urllib * removed security risk of urllib #2 * added license header * corrected link in header * style Co-authored-by: Yixing Lao --- cpp/open3d/ml/contrib/RoIPoolKernel.cu | 328 +++++++++++++++++ cpp/open3d/ml/contrib/RoIPoolKernel.h | 50 +++ cpp/open3d/ml/pytorch/CMakeLists.txt | 10 + cpp/open3d/ml/pytorch/misc/RoIPoolOps.cpp | 70 ++++ .../ml/pytorch/pointnet/BallQueryKernel.cu | 110 ++++++ .../ml/pytorch/pointnet/BallQueryKernel.h | 37 ++ .../ml/pytorch/pointnet/BallQueryOps.cpp | 62 ++++ .../ml/pytorch/pointnet/GroupPointsKernel.cu | 147 ++++++++ .../ml/pytorch/pointnet/GroupPointsKernel.h | 46 +++ .../ml/pytorch/pointnet/GroupPointsOps.cpp | 85 +++++ .../ml/pytorch/pointnet/InterpolateKernel.cu | 246 +++++++++++++ .../ml/pytorch/pointnet/InterpolateKernel.h | 54 +++ .../ml/pytorch/pointnet/InterpolateOps.cpp | 131 +++++++ .../ml/pytorch/pointnet/SamplingKernel.cu | 342 ++++++++++++++++++ .../ml/pytorch/pointnet/SamplingKernel.h | 47 +++ .../ml/pytorch/pointnet/SamplingOps.cpp | 113 ++++++ cpp/open3d/ml/pytorch/pointnet/cuda_utils.h | 40 ++ python/test/ml_ops/mltest.py | 15 + python/test/ml_ops/test_gathering.py | 57 +++ python/test/ml_ops/test_group_pts.py | 57 +++ python/test/ml_ops/test_query_pts.py | 60 +++ python/test/ml_ops/test_roi_pool.py | 65 ++++ python/test/ml_ops/test_sampling.py | 55 +++ python/test/ml_ops/test_three_interp.py | 60 +++ python/test/ml_ops/test_three_nn.py | 61 ++++ 25 files changed, 2348 insertions(+) create mode 100644 cpp/open3d/ml/contrib/RoIPoolKernel.cu create mode 100644 cpp/open3d/ml/contrib/RoIPoolKernel.h create mode 100644 cpp/open3d/ml/pytorch/misc/RoIPoolOps.cpp create mode 100644 cpp/open3d/ml/pytorch/pointnet/BallQueryKernel.cu create mode 100644 cpp/open3d/ml/pytorch/pointnet/BallQueryKernel.h create mode 100644 cpp/open3d/ml/pytorch/pointnet/BallQueryOps.cpp create mode 100644 cpp/open3d/ml/pytorch/pointnet/GroupPointsKernel.cu create mode 100644 cpp/open3d/ml/pytorch/pointnet/GroupPointsKernel.h create mode 100644 cpp/open3d/ml/pytorch/pointnet/GroupPointsOps.cpp create mode 100644 cpp/open3d/ml/pytorch/pointnet/InterpolateKernel.cu create mode 100644 cpp/open3d/ml/pytorch/pointnet/InterpolateKernel.h create mode 100644 cpp/open3d/ml/pytorch/pointnet/InterpolateOps.cpp create mode 100644 cpp/open3d/ml/pytorch/pointnet/SamplingKernel.cu create mode 100644 cpp/open3d/ml/pytorch/pointnet/SamplingKernel.h create mode 100644 cpp/open3d/ml/pytorch/pointnet/SamplingOps.cpp create mode 100644 cpp/open3d/ml/pytorch/pointnet/cuda_utils.h create mode 100644 python/test/ml_ops/test_gathering.py create mode 100644 python/test/ml_ops/test_group_pts.py create mode 100644 python/test/ml_ops/test_query_pts.py create mode 100644 python/test/ml_ops/test_roi_pool.py create mode 100644 python/test/ml_ops/test_sampling.py create mode 100644 python/test/ml_ops/test_three_interp.py create mode 100644 python/test/ml_ops/test_three_nn.py diff --git a/cpp/open3d/ml/contrib/RoIPoolKernel.cu b/cpp/open3d/ml/contrib/RoIPoolKernel.cu new file mode 100644 index 00000000000..b801c13d334 --- /dev/null +++ b/cpp/open3d/ml/contrib/RoIPoolKernel.cu @@ -0,0 +1,328 @@ +//***************************************************************************************/ +// +// Based on PointRCNN Library (MIT License): +// https://github.com/sshaoshuai/PointRCNN +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include + +#include "open3d/ml/contrib/RoIPoolKernel.h" + +namespace open3d { +namespace ml { +namespace contrib { + +#define THREADS_PER_BLOCK 256 +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +__device__ inline int pt_in_box3d(float x, + float y, + float z, + float cx, + float bottom_y, + float cz, + float h, + float w, + float l, + float angle, + float max_dis) { + float x_rot, z_rot, cosa, sina, cy; + int in_flag; + cy = bottom_y - h / 2.0; + if ((fabsf(x - cx) > max_dis) || (fabsf(y - cy) > h / 2.0) || + (fabsf(z - cz) > max_dis)) { + return 0; + } + cosa = cos(angle); + sina = sin(angle); + x_rot = (x - cx) * cosa + (z - cz) * (-sina); + z_rot = (x - cx) * sina + (z - cz) * cosa; + + in_flag = (x_rot >= -l / 2.0) & (x_rot <= l / 2.0) & (z_rot >= -w / 2.0) & + (z_rot <= w / 2.0); + return in_flag; +} + +__global__ void roipool3d_forward(int batch_size, + int pts_num, + int boxes_num, + int feature_in_len, + int sampled_pts_num, + const float *xyz, + const float *boxes3d, + const float *pts_feature, + float *pooled_features, + int *pooled_empty_flag) { + // params xyz: (B, N, 3) + // params boxes3d: (B, M, 7) + // params pts_feature: (B, N, C) + // params pooled_features: (B, M, 512, 3+C) + // params pooled_empty_flag: (B, M) + + int boxes_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (boxes_idx >= boxes_num) { + return; + } + + for (int i = 0; i < batch_size; i++) { + int cnt = 0; + for (int k = 0; k < pts_num; k++) { + int pt_offset = i * pts_num * 3 + k * 3; + int box_offset = i * boxes_num * 7 + boxes_idx * 7; + + int cur_in_flag = pt_in_box3d( + xyz[pt_offset], xyz[pt_offset + 1], xyz[pt_offset + 2], + boxes3d[box_offset], boxes3d[box_offset + 1], + boxes3d[box_offset + 2], boxes3d[box_offset + 3], + boxes3d[box_offset + 4], boxes3d[box_offset + 5], + boxes3d[box_offset + 6], 10.0); + if (cur_in_flag) { + if (cnt < sampled_pts_num) { + int feature_out_offset = + i * boxes_num * sampled_pts_num * + (3 + feature_in_len) + + boxes_idx * sampled_pts_num * (3 + feature_in_len) + + cnt * (3 + feature_in_len); + + int feature_in_offset = + i * pts_num * feature_in_len + k * feature_in_len; + + // copy xyz + for (int j = 0; j < 3; j++) + pooled_features[feature_out_offset + j] = + xyz[pt_offset + j]; + + // copy feature + for (int j = 0; j < feature_in_len; j++) + pooled_features[feature_out_offset + 3 + j] = + pts_feature[feature_in_offset + j]; + + cnt++; + } else + break; + } + } + + if (cnt == 0) { + pooled_empty_flag[i * boxes_num + boxes_idx] = 1; + } else if (cnt < sampled_pts_num) { + // duplicate same points for sampling + for (int k = cnt; k < sampled_pts_num; k++) { + int duplicate_idx = k % cnt; + int src_offset = + i * boxes_num * sampled_pts_num * (3 + feature_in_len) + + boxes_idx * sampled_pts_num * (3 + feature_in_len) + + duplicate_idx * (3 + feature_in_len); + int dst_offset = + i * boxes_num * sampled_pts_num * (3 + feature_in_len) + + boxes_idx * sampled_pts_num * (3 + feature_in_len) + + k * (3 + feature_in_len); + for (int j = 0; j < 3 + feature_in_len; j++) + pooled_features[dst_offset + j] = + pooled_features[src_offset + j]; + } + } + } +} + +__global__ void assign_pts_to_box3d(int batch_size, + int pts_num, + int boxes_num, + const float *xyz, + const float *boxes3d, + int *pts_assign) { + // params xyz: (B, N, 3) + // params boxes3d: (B, M, 7) + // params pts_assign: (B, N, M): idx of the corresponding box3d, -1 means + // background points + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + int box_idx = blockIdx.y; + int bs_idx = blockIdx.z; + + if (pt_idx >= pts_num || box_idx >= boxes_num || bs_idx >= batch_size) { + return; + } + int assign_idx = + bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx; + pts_assign[assign_idx] = 0; + + int box_offset = bs_idx * boxes_num * 7 + box_idx * 7; + int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3; + + int cur_in_flag = + pt_in_box3d(xyz[pt_offset], xyz[pt_offset + 1], xyz[pt_offset + 2], + boxes3d[box_offset], boxes3d[box_offset + 1], + boxes3d[box_offset + 2], boxes3d[box_offset + 3], + boxes3d[box_offset + 4], boxes3d[box_offset + 5], + boxes3d[box_offset + 6], 10.0); + + pts_assign[assign_idx] = cur_in_flag; + // printf("bs=%d, pt=%d, in=%d\n", bs_idx, pt_idx, pts_assign[bs_idx * + // pts_num + pt_idx]); +} + +__global__ void get_pooled_idx(int batch_size, + int pts_num, + int boxes_num, + int sampled_pts_num, + const int *pts_assign, + int *pts_idx, + int *pooled_empty_flag) { + // params xyz: (B, N, 3) + // params pts_feature: (B, N, C) + // params pts_assign: (B, N) + // params pts_idx: (B, M, 512) + // params pooled_empty_flag: (B, M) + + int boxes_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (boxes_idx >= boxes_num) { + return; + } + + int bs_idx = blockIdx.y; + + int cnt = 0; + for (int k = 0; k < pts_num; k++) { + if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num + + boxes_idx]) { + if (cnt < sampled_pts_num) { + pts_idx[bs_idx * boxes_num * sampled_pts_num + + boxes_idx * sampled_pts_num + cnt] = k; + cnt++; + } else + break; + } + } + + if (cnt == 0) { + pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1; + } else if (cnt < sampled_pts_num) { + // duplicate same points for sampling + for (int k = cnt; k < sampled_pts_num; k++) { + int duplicate_idx = k % cnt; + int base_offset = bs_idx * boxes_num * sampled_pts_num + + boxes_idx * sampled_pts_num; + pts_idx[base_offset + k] = pts_idx[base_offset + duplicate_idx]; + } + } +} + +__global__ void roipool3d_forward(int batch_size, + int pts_num, + int boxes_num, + int feature_in_len, + int sampled_pts_num, + const float *xyz, + const int *pts_idx, + const float *pts_feature, + float *pooled_features, + int *pooled_empty_flag) { + // params xyz: (B, N, 3) + // params pts_idx: (B, M, 512) + // params pts_feature: (B, N, C) + // params pooled_features: (B, M, 512, 3+C) + // params pooled_empty_flag: (B, M) + + int sample_pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + int box_idx = blockIdx.y; + int bs_idx = blockIdx.z; + + if (sample_pt_idx >= sampled_pts_num || box_idx >= boxes_num || + bs_idx >= batch_size) { + return; + } + + if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) { + return; + } + + int temp_idx = bs_idx * boxes_num * sampled_pts_num + + box_idx * sampled_pts_num + sample_pt_idx; + int src_pt_idx = pts_idx[temp_idx]; + int dst_feature_offset = temp_idx * (3 + feature_in_len); + + for (int j = 0; j < 3; j++) + pooled_features[dst_feature_offset + j] = + xyz[bs_idx * pts_num * 3 + src_pt_idx * 3 + j]; + + int src_feature_offset = + bs_idx * pts_num * feature_in_len + src_pt_idx * feature_in_len; + for (int j = 0; j < feature_in_len; j++) + pooled_features[dst_feature_offset + 3 + j] = + pts_feature[src_feature_offset + j]; +} + +void roipool3dLauncher(int batch_size, + int pts_num, + int boxes_num, + int feature_in_len, + int sampled_pts_num, + const float *xyz, + const float *boxes3d, + const float *pts_feature, + float *pooled_features, + int *pooled_empty_flag) { + // printf("batch_size=%d, pts_num=%d, boxes_num=%d\n", batch_size, pts_num, + // boxes_num); + int *pts_assign = NULL; + cudaMalloc(&pts_assign, batch_size * pts_num * boxes_num * + sizeof(int)); // (batch_size, N, M) + // cudaMemset(&pts_assign, -1, batch_size * pts_num * boxes_num * + // sizeof(int)); + + dim3 blocks(DIVUP(pts_num, THREADS_PER_BLOCK), boxes_num, + batch_size); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + assign_pts_to_box3d<<>>(batch_size, pts_num, boxes_num, + xyz, boxes3d, pts_assign); + + int *pts_idx = NULL; + cudaMalloc(&pts_idx, + batch_size * boxes_num * sampled_pts_num * + sizeof(int)); // (batch_size, M, sampled_pts_num) + + dim3 blocks2(DIVUP(boxes_num, THREADS_PER_BLOCK), + batch_size); // blockIdx.x(col), blockIdx.y(row) + get_pooled_idx<<>>(batch_size, pts_num, boxes_num, + sampled_pts_num, pts_assign, pts_idx, + pooled_empty_flag); + + dim3 blocks_pool(DIVUP(sampled_pts_num, THREADS_PER_BLOCK), boxes_num, + batch_size); + roipool3d_forward<<>>( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + xyz, pts_idx, pts_feature, pooled_features, pooled_empty_flag); + + cudaFree(pts_assign); + cudaFree(pts_idx); + +#ifdef DEBUG + cudaDeviceSynchronize(); // for using printf in kernel function +#endif +} + +} // namespace contrib +} // namespace ml +} // namespace open3d diff --git a/cpp/open3d/ml/contrib/RoIPoolKernel.h b/cpp/open3d/ml/contrib/RoIPoolKernel.h new file mode 100644 index 00000000000..ca35e9a2411 --- /dev/null +++ b/cpp/open3d/ml/contrib/RoIPoolKernel.h @@ -0,0 +1,50 @@ +//***************************************************************************************/ +// +// Based on PointRCNN Library (MIT License): +// https://github.com/sshaoshuai/PointRCNN +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#pragma once + +namespace open3d { +namespace ml { +namespace contrib { + +#ifdef BUILD_CUDA_MODULE + +void roipool3dLauncher(int batch_size, + int pts_num, + int boxes_num, + int feature_in_len, + int sampled_pts_num, + const float *xyz, + const float *boxes3d, + const float *pts_feature, + float *pooled_features, + int *pooled_empty_flag); +#endif + +} // namespace contrib +} // namespace ml +} // namespace open3d diff --git a/cpp/open3d/ml/pytorch/CMakeLists.txt b/cpp/open3d/ml/pytorch/CMakeLists.txt index 815412d9cbb..dad821c2fb2 100644 --- a/cpp/open3d/ml/pytorch/CMakeLists.txt +++ b/cpp/open3d/ml/pytorch/CMakeLists.txt @@ -38,6 +38,11 @@ set(TORCH_OPS_SOURCES "misc/VoxelizeOps.cpp" "misc/VoxelizeOpKernel.cpp" "misc/NmsOps.cpp" + "misc/RoIPoolOps.cpp" + "pointnet/BallQueryOps.cpp" + "pointnet/GroupPointsOps.cpp" + "pointnet/InterpolateOps.cpp" + "pointnet/SamplingOps.cpp" "../contrib/Nms.cpp" ) @@ -54,6 +59,11 @@ set(TORCH_OPS_CUDA_SOURCES "misc/VoxelizeOpKernel.cu" "../impl/continuous_conv/ContinuousConvCUDAKernels.cu" "../contrib/Nms.cu" + "../contrib/RoIPoolKernel.cu" + "pointnet/BallQueryKernel.cu" + "pointnet/GroupPointsKernel.cu" + "pointnet/InterpolateKernel.cu" + "pointnet/SamplingKernel.cu" ) if(BUILD_CUDA_MODULE) diff --git a/cpp/open3d/ml/pytorch/misc/RoIPoolOps.cpp b/cpp/open3d/ml/pytorch/misc/RoIPoolOps.cpp new file mode 100644 index 00000000000..e87561aa70d --- /dev/null +++ b/cpp/open3d/ml/pytorch/misc/RoIPoolOps.cpp @@ -0,0 +1,70 @@ +//***************************************************************************************/ +// +// Based on PointRCNN Library (MIT License): +// https://github.com/sshaoshuai/PointRCNN +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include "open3d/ml/contrib/RoIPoolKernel.h" +#include "open3d/ml/pytorch/TorchHelper.h" + +#ifdef BUILD_CUDA_MODULE +std::tuple roipool3d( + torch::Tensor xyz, + torch::Tensor boxes3d, + torch::Tensor pts_feature, + const int64_t sampled_pts_num) { + int batch_size = xyz.size(0); + int pts_num = xyz.size(1); + int boxes_num = boxes3d.size(1); + int feature_in_len = pts_feature.size(2); + + auto device = xyz.device(); + torch::Tensor features = torch::zeros( + {batch_size, boxes_num, sampled_pts_num, 3 + feature_in_len}, + torch::dtype(ToTorchDtype()).device(device)); + + torch::Tensor empty_flag = + torch::zeros({batch_size, boxes_num}, + torch::dtype(ToTorchDtype()).device(device)); + + const float *xyz_data = xyz.data(); + const float *boxes3d_data = boxes3d.data(); + const float *pts_feature_data = pts_feature.data(); + float *pooled_features_data = features.data(); + int *pooled_empty_flag_data = empty_flag.data(); + + open3d::ml::contrib::roipool3dLauncher( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + xyz_data, boxes3d_data, pts_feature_data, pooled_features_data, + pooled_empty_flag_data); + + return std::tuple(features, empty_flag); +} + +static auto registry = torch::RegisterOperators( + "open3d::roipool3d(Tensor xyz, Tensor boxes3d," + "Tensor pts_feature, int sampled_pts_num)" + " -> (Tensor features, Tensor flags)", + &roipool3d); +#endif diff --git a/cpp/open3d/ml/pytorch/pointnet/BallQueryKernel.cu b/cpp/open3d/ml/pytorch/pointnet/BallQueryKernel.cu new file mode 100644 index 00000000000..fb29e8bf66e --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/BallQueryKernel.cu @@ -0,0 +1,110 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include +#include + +#include "ATen/cuda/CUDAContext.h" +#include "open3d/ml/pytorch/pointnet/BallQueryKernel.h" +#include "open3d/ml/pytorch/pointnet/cuda_utils.h" + +__global__ void ball_query_kernel(int b, + int n, + int m, + float radius, + int nsample, + const float *__restrict__ new_xyz, + const float *__restrict__ xyz, + int *__restrict__ idx) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || pt_idx >= m) return; + + new_xyz += bs_idx * m * 3 + pt_idx * 3; + xyz += bs_idx * n * 3; + idx += bs_idx * m * nsample + pt_idx * nsample; + + float radius2 = radius * radius; + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + + int cnt = 0; + for (int k = 0; k < n; ++k) { + float x = xyz[k * 3 + 0]; + float y = xyz[k * 3 + 1]; + float z = xyz[k * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + + (new_z - z) * (new_z - z); + if (d2 < radius2) { + if (cnt == 0) { + for (int l = 0; l < nsample; ++l) { + idx[l] = k; + } + } + idx[cnt] = k; + ++cnt; + if (cnt >= nsample) break; + } + } +} + +void ball_query_launcher(int b, + int n, + int m, + float radius, + int nsample, + const float *new_xyz, + const float *xyz, + int *idx) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + + cudaError_t err; + + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + ball_query_kernel<<>>(b, n, m, radius, nsample, + new_xyz, xyz, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/cpp/open3d/ml/pytorch/pointnet/BallQueryKernel.h b/cpp/open3d/ml/pytorch/pointnet/BallQueryKernel.h new file mode 100644 index 00000000000..c76683b9db3 --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/BallQueryKernel.h @@ -0,0 +1,37 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#pragma once + +void ball_query_launcher(int b, + int n, + int m, + float radius, + int nsample, + const float *xyz, + const float *new_xyz, + int *idx); diff --git a/cpp/open3d/ml/pytorch/pointnet/BallQueryOps.cpp b/cpp/open3d/ml/pytorch/pointnet/BallQueryOps.cpp new file mode 100644 index 00000000000..32490b4ab9f --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/BallQueryOps.cpp @@ -0,0 +1,62 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include + +#include "open3d/ml/pytorch/TorchHelper.h" +#include "open3d/ml/pytorch/pointnet/BallQueryKernel.h" +#include "torch/script.h" + +#ifdef BUILD_CUDA_MODULE +torch::Tensor ball_query(torch::Tensor xyz, + torch::Tensor center, + double radius, + const int64_t nsample) { + int batch_size = xyz.size(0); + int pts_num = xyz.size(1); + int ball_num = center.size(1); + + auto device = xyz.device(); + torch::Tensor out = + torch::zeros({batch_size, ball_num, nsample}, + torch::dtype(ToTorchDtype()).device(device)); + + const float *center_data = center.data(); + const float *xyz_data = xyz.data(); + int *idx = out.data(); + + ball_query_launcher(batch_size, pts_num, ball_num, radius, nsample, + center_data, xyz_data, idx); + return out; +} + +static auto registry = torch::RegisterOperators( + "open3d::ball_query(Tensor xyz, Tensor center," + "float radius, int nsample)" + " -> Tensor out", + &ball_query); +#endif diff --git a/cpp/open3d/ml/pytorch/pointnet/GroupPointsKernel.cu b/cpp/open3d/ml/pytorch/pointnet/GroupPointsKernel.cu new file mode 100644 index 00000000000..7eb4228368d --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/GroupPointsKernel.cu @@ -0,0 +1,147 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include + +#include "ATen/cuda/CUDAContext.h" +#include "open3d/ml/pytorch/pointnet/GroupPointsKernel.h" +#include "open3d/ml/pytorch/pointnet/cuda_utils.h" + +__global__ void group_points_grad_kernel(int b, + int c, + int n, + int npoints, + int nsample, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + float *__restrict__ grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + + pt_idx * nsample + sample_idx; + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + + atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0], grad_out[0]); +} + +void group_points_grad_launcher(int b, + int c, + int n, + int npoints, + int nsample, + const float *grad_out, + const int *idx, + float *grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + cudaError_t err; + + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + group_points_grad_kernel<<>>( + b, c, n, npoints, nsample, grad_out, idx, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +__global__ void group_points_kernel(int b, + int c, + int n, + int npoints, + int nsample, + const float *__restrict__ points, + const int *__restrict__ idx, + float *__restrict__ out) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + int in_idx = bs_idx * c * n + c_idx * n + idx[0]; + int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + + pt_idx * nsample + sample_idx; + + out[out_idx] = points[in_idx]; +} + +void group_points_launcher(int b, + int c, + int n, + int npoints, + int nsample, + const float *points, + const int *idx, + float *out) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + cudaError_t err; + + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + group_points_kernel<<>>( + b, c, n, npoints, nsample, points, idx, out); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/cpp/open3d/ml/pytorch/pointnet/GroupPointsKernel.h b/cpp/open3d/ml/pytorch/pointnet/GroupPointsKernel.h new file mode 100644 index 00000000000..9ccff8b7367 --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/GroupPointsKernel.h @@ -0,0 +1,46 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#pragma once + +void group_points_launcher(int b, + int c, + int n, + int npoints, + int nsample, + const float *points, + const int *idx, + float *out); + +void group_points_grad_launcher(int b, + int c, + int n, + int npoints, + int nsample, + const float *grad_out, + const int *idx, + float *grad_points); diff --git a/cpp/open3d/ml/pytorch/pointnet/GroupPointsOps.cpp b/cpp/open3d/ml/pytorch/pointnet/GroupPointsOps.cpp new file mode 100644 index 00000000000..830a6d9683c --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/GroupPointsOps.cpp @@ -0,0 +1,85 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include "open3d/ml/pytorch/TorchHelper.h" +#include "open3d/ml/pytorch/pointnet/GroupPointsKernel.h" +#include "torch/script.h" + +#ifdef BUILD_CUDA_MODULE +torch::Tensor group_points_grad(torch::Tensor grad_out, + torch::Tensor idx, + const int64_t N) { + int batch_size = grad_out.size(0); + int C = grad_out.size(1); + int feature_size = grad_out.size(2); + int sample_size = grad_out.size(3); + + auto device = grad_out.device(); + torch::Tensor out = + torch::zeros({batch_size, C, N}, + torch::dtype(ToTorchDtype()).device(device)); + + float *grad_points = out.data(); + const int *idx_data = idx.data(); + const float *grad_out_data = grad_out.data(); + + group_points_grad_launcher(batch_size, C, N, feature_size, sample_size, + grad_out_data, idx_data, grad_points); + return out; +} + +torch::Tensor group_points(torch::Tensor points, torch::Tensor idx) { + int batch_size = idx.size(0); + int feature_size = idx.size(1); + int sample_size = idx.size(2); + int C = points.size(1); + int N = points.size(2); + + auto device = points.device(); + torch::Tensor out = + torch::zeros({batch_size, C, feature_size, sample_size}, + torch::dtype(ToTorchDtype()).device(device)); + + const float *points_data = points.data(); + const int *idx_data = idx.data(); + float *out_data = out.data(); + + group_points_launcher(batch_size, C, N, feature_size, sample_size, + points_data, idx_data, out_data); + return out; +} + +static auto registry = torch::RegisterOperators( + "open3d::group_points(Tensor points, Tensor idx)" + " -> Tensor out", + &group_points); + +static auto registry_grad = torch::RegisterOperators( + "open3d::group_points_grad(Tensor grad_out, Tensor idx, int N)" + " -> Tensor out", + &group_points_grad); +#endif diff --git a/cpp/open3d/ml/pytorch/pointnet/InterpolateKernel.cu b/cpp/open3d/ml/pytorch/pointnet/InterpolateKernel.cu new file mode 100644 index 00000000000..38ccd8290a6 --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/InterpolateKernel.cu @@ -0,0 +1,246 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include +#include + +#include + +#include "ATen/cuda/CUDAContext.h" +#include "open3d/ml/pytorch/pointnet/InterpolateKernel.h" +#include "open3d/ml/pytorch/pointnet/cuda_utils.h" + +__global__ void three_nn_kernel(int b, + int n, + int m, + const float *__restrict__ unknown, + const float *__restrict__ known, + float *__restrict__ dist2, + int *__restrict__ idx) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || pt_idx >= n) return; + + unknown += bs_idx * n * 3 + pt_idx * 3; + known += bs_idx * m * 3; + dist2 += bs_idx * n * 3 + pt_idx * 3; + idx += bs_idx * n * 3 + pt_idx * 3; + + float ux = unknown[0]; + float uy = unknown[1]; + float uz = unknown[2]; + + double best1 = 1e40, best2 = 1e40, best3 = 1e40; + int besti1 = 0, besti2 = 0, besti3 = 0; + for (int k = 0; k < m; ++k) { + float x = known[k * 3 + 0]; + float y = known[k * 3 + 1]; + float z = known[k * 3 + 2]; + float d = + (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best1) { + best3 = best2; + besti3 = besti2; + best2 = best1; + besti2 = besti1; + best1 = d; + besti1 = k; + } else if (d < best2) { + best3 = best2; + besti3 = besti2; + best2 = d; + besti2 = k; + } else if (d < best3) { + best3 = d; + besti3 = k; + } + } + dist2[0] = best1; + dist2[1] = best2; + dist2[2] = best3; + idx[0] = besti1; + idx[1] = besti2; + idx[2] = besti3; +} + +void three_nn_launcher(int b, + int n, + int m, + const float *unknown, + const float *known, + float *dist2, + int *idx) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + cudaError_t err; + + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + three_nn_kernel<<>>(b, n, m, unknown, known, + dist2, idx); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +__global__ void three_interpolate_kernel(int b, + int c, + int m, + int n, + const float *__restrict__ points, + const int *__restrict__ idx, + const float *__restrict__ weight, + float *__restrict__ out) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; + + weight += bs_idx * n * 3 + pt_idx * 3; + points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + out += bs_idx * c * n + c_idx * n; + + out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + + weight[2] * points[idx[2]]; +} + +void three_interpolate_launcher(int b, + int c, + int m, + int n, + const float *points, + const int *idx, + const float *weight, + float *out) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + cudaError_t err; + + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_kernel<<>>(b, c, m, n, points, + idx, weight, out); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +__global__ void three_interpolate_grad_kernel( + int b, + int c, + int n, + int m, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + const float *__restrict__ weight, + float *__restrict__ grad_points) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; + + grad_out += bs_idx * c * n + c_idx * n + pt_idx; + weight += bs_idx * n * 3 + pt_idx * 3; + grad_points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + + atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); + atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); + atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); +} + +void three_interpolate_grad_launcher(int b, + int c, + int n, + int m, + const float *grad_out, + const int *idx, + const float *weight, + float *grad_points) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + cudaError_t err; + + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_grad_kernel<<>>( + b, c, n, m, grad_out, idx, weight, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/cpp/open3d/ml/pytorch/pointnet/InterpolateKernel.h b/cpp/open3d/ml/pytorch/pointnet/InterpolateKernel.h new file mode 100644 index 00000000000..bd52fa5f26d --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/InterpolateKernel.h @@ -0,0 +1,54 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#pragma once + +void three_nn_launcher(int b, + int n, + int m, + const float *unknown, + const float *known, + float *dist2, + int *idx); + +void three_interpolate_launcher(int b, + int c, + int m, + int n, + const float *points, + const int *idx, + const float *weight, + float *out); + +void three_interpolate_grad_launcher(int b, + int c, + int n, + int m, + const float *grad_out, + const int *idx, + const float *weight, + float *grad_points); diff --git a/cpp/open3d/ml/pytorch/pointnet/InterpolateOps.cpp b/cpp/open3d/ml/pytorch/pointnet/InterpolateOps.cpp new file mode 100644 index 00000000000..9a193663957 --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/InterpolateOps.cpp @@ -0,0 +1,131 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include +#include + +#include +#include + +#include "open3d/ml/pytorch/TorchHelper.h" +#include "open3d/ml/pytorch/pointnet/InterpolateKernel.h" +#include "torch/script.h" + +#ifdef BUILD_CUDA_MODULE +std::tuple three_nn(torch::Tensor query_pts, + torch::Tensor data_pts) { + int batch_size = query_pts.size(0); + int pts_num_out = query_pts.size(1); + int pts_num_in = data_pts.size(1); + + auto device = data_pts.device(); + torch::Tensor out_idx = + torch::zeros({batch_size, pts_num_out, 3}, + torch::dtype(ToTorchDtype()).device(device)); + + torch::Tensor out_dist2 = + torch::zeros({batch_size, pts_num_out, 3}, + torch::dtype(ToTorchDtype()).device(device)); + + const float *pts_out = query_pts.data(); + const float *pts_in = data_pts.data(); + float *dist2 = out_dist2.data(); + int *idx = out_idx.data(); + + three_nn_launcher(batch_size, pts_num_out, pts_num_in, pts_out, pts_in, + dist2, idx); + + return std::tuple(out_dist2, out_idx); +} + +torch::Tensor three_interpolate(torch::Tensor points, + torch::Tensor idx, + torch::Tensor weights) { + int batch_size = points.size(0); + int C = points.size(1); + int M = points.size(2); + int N = idx.size(1); + + auto device = points.device(); + torch::Tensor out = + torch::zeros({batch_size, C, N}, + torch::dtype(ToTorchDtype()).device(device)); + + const float *points_data = points.data(); + const float *weights_data = weights.data(); + const int *idx_data = idx.data(); + float *out_data = out.data(); + + three_interpolate_launcher(batch_size, C, M, N, points_data, idx_data, + weights_data, out_data); + + return out; +} + +torch::Tensor three_interpolate_grad(torch::Tensor grad_out, + torch::Tensor idx, + torch::Tensor weights, + const int64_t M) { + int batch_size = grad_out.size(0); + int C = grad_out.size(1); + int N = grad_out.size(2); + + auto device = grad_out.device(); + torch::Tensor out = + torch::zeros({batch_size, C, M}, + torch::dtype(ToTorchDtype()).device(device)); + + const float *grad_out_data = grad_out.data(); + const float *weights_data = weights.data(); + const int *idx_data = idx.data(); + + float *out_data = out.data(); + + three_interpolate_grad_launcher(batch_size, C, N, M, grad_out_data, + idx_data, weights_data, out_data); + + return out; +} + +static auto registry_nn = torch::RegisterOperators( + "open3d::three_nn(Tensor query_pts, Tensor data_pts)" + " -> (Tensor dist, Tensor idx)", + &three_nn); + +static auto registry = torch::RegisterOperators( + "open3d::three_interpolate(Tensor points," + "Tensor idx, Tensor weights)" + " -> Tensor out", + &three_interpolate); + +static auto registry_grad = torch::RegisterOperators( + "open3d::three_interpolate_grad(Tensor points," + "Tensor idx, Tensor weights, int N)" + " -> Tensor out", + &three_interpolate_grad); +#endif diff --git a/cpp/open3d/ml/pytorch/pointnet/SamplingKernel.cu b/cpp/open3d/ml/pytorch/pointnet/SamplingKernel.cu new file mode 100644 index 00000000000..7fb2fccada8 --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/SamplingKernel.cu @@ -0,0 +1,342 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include +#include + +#include + +#include "ATen/cuda/CUDAContext.h" +#include "open3d/ml/pytorch/pointnet/SamplingKernel.h" +#include "open3d/ml/pytorch/pointnet/cuda_utils.h" + +__global__ void gather_points_kernel(int b, + int c, + int n, + int m, + const float *__restrict__ points, + const int *__restrict__ idx, + float *__restrict__ out) { + // points: (B, C, N) + // idx: (B, M) + // output: + // out: (B, C, M) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; + + out += bs_idx * c * m + c_idx * m + pt_idx; + idx += bs_idx * m + pt_idx; + points += bs_idx * c * n + c_idx * n; + out[0] = points[idx[0]]; +} + +void gather_points_launcher(int b, + int c, + int n, + int npoints, + const float *points, + const int *idx, + float *out) { + // points: (B, C, N) + // idx: (B, npoints) + // output: + // out: (B, C, npoints) + + auto stream = at::cuda::getCurrentCUDAStream(); + + cudaError_t err; + dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + gather_points_kernel<<>>(b, c, n, npoints, + points, idx, out); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +__global__ void gather_points_grad_kernel(int b, + int c, + int n, + int m, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + float *__restrict__ grad_points) { + // grad_out: (B, C, M) + // idx: (B, M) + // output: + // grad_points: (B, C, N) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; + + grad_out += bs_idx * c * m + c_idx * m + pt_idx; + idx += bs_idx * m + pt_idx; + grad_points += bs_idx * c * n + c_idx * n; + + atomicAdd(grad_points + idx[0], grad_out[0]); +} + +void gather_points_grad_launcher(int b, + int c, + int n, + int npoints, + const float *grad_out, + const int *idx, + float *grad_points) { + // grad_out: (B, C, npoints) + // idx: (B, npoints) + // output: + // grad_points: (B, C, N) + + auto stream = at::cuda::getCurrentCUDAStream(); + + cudaError_t err; + dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, + b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + gather_points_grad_kernel<<>>( + b, c, n, npoints, grad_out, idx, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +__device__ void __update(float *__restrict__ dists, + int *__restrict__ dists_i, + int idx1, + int idx2) { + const float v1 = dists[idx1], v2 = dists[idx2]; + const int i1 = dists_i[idx1], i2 = dists_i[idx2]; + dists[idx1] = max(v1, v2); + dists_i[idx1] = v2 > v1 ? i2 : i1; +} + +template +__global__ void furthest_point_sampling_kernel( + int b, + int n, + int m, + const float *__restrict__ dataset, + float *__restrict__ temp, + int *__restrict__ idxs) { + // dataset: (B, N, 3) + // tmp: (B, N) + // output: + // idx: (B, M) + + if (m <= 0) return; + __shared__ float dists[block_size]; + __shared__ int dists_i[block_size]; + + int batch_index = blockIdx.x; + dataset += batch_index * n * 3; + temp += batch_index * n; + idxs += batch_index * m; + + int tid = threadIdx.x; + const int stride = block_size; + + int old = 0; + if (threadIdx.x == 0) idxs[0] = old; + + __syncthreads(); + for (int j = 1; j < m; j++) { + int besti = 0; + float best = -1; + float x1 = dataset[old * 3 + 0]; + float y1 = dataset[old * 3 + 1]; + float z1 = dataset[old * 3 + 2]; + for (int k = tid; k < n; k += stride) { + float x2, y2, z2; + x2 = dataset[k * 3 + 0]; + y2 = dataset[k * 3 + 1]; + z2 = dataset[k * 3 + 2]; + // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); + // if (mag <= 1e-3) + // continue; + + float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + + (z2 - z1) * (z2 - z1); + float d2 = min(d, temp[k]); + temp[k] = d2; + besti = d2 > best ? k : besti; + best = d2 > best ? d2 : best; + } + dists[tid] = best; + dists_i[tid] = besti; + __syncthreads(); + + if (block_size >= 1024) { + if (tid < 512) { + __update(dists, dists_i, tid, tid + 512); + } + __syncthreads(); + } + + if (block_size >= 512) { + if (tid < 256) { + __update(dists, dists_i, tid, tid + 256); + } + __syncthreads(); + } + if (block_size >= 256) { + if (tid < 128) { + __update(dists, dists_i, tid, tid + 128); + } + __syncthreads(); + } + if (block_size >= 128) { + if (tid < 64) { + __update(dists, dists_i, tid, tid + 64); + } + __syncthreads(); + } + if (block_size >= 64) { + if (tid < 32) { + __update(dists, dists_i, tid, tid + 32); + } + __syncthreads(); + } + if (block_size >= 32) { + if (tid < 16) { + __update(dists, dists_i, tid, tid + 16); + } + __syncthreads(); + } + if (block_size >= 16) { + if (tid < 8) { + __update(dists, dists_i, tid, tid + 8); + } + __syncthreads(); + } + if (block_size >= 8) { + if (tid < 4) { + __update(dists, dists_i, tid, tid + 4); + } + __syncthreads(); + } + if (block_size >= 4) { + if (tid < 2) { + __update(dists, dists_i, tid, tid + 2); + } + __syncthreads(); + } + if (block_size >= 2) { + if (tid < 1) { + __update(dists, dists_i, tid, tid + 1); + } + __syncthreads(); + } + + old = dists_i[0]; + if (tid == 0) idxs[j] = old; + } +} + +void furthest_point_sampling_launcher( + int b, int n, int m, const float *dataset, float *temp, int *idxs) { + // dataset: (B, N, 3) + // tmp: (B, N) + // output: + // idx: (B, M) + + cudaError_t err; + + auto stream = at::cuda::getCurrentCUDAStream(); + + unsigned int n_threads = opt_n_threads(n); + + switch (n_threads) { + case 1024: + furthest_point_sampling_kernel<1024> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 512: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 256: + furthest_point_sampling_kernel<256> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 128: + furthest_point_sampling_kernel<128> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 64: + furthest_point_sampling_kernel<64> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 32: + furthest_point_sampling_kernel<32> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 16: + furthest_point_sampling_kernel<16> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 8: + furthest_point_sampling_kernel<8> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 4: + furthest_point_sampling_kernel<4> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 2: + furthest_point_sampling_kernel<2> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 1: + furthest_point_sampling_kernel<1> + <<>>(b, n, m, dataset, temp, idxs); + break; + default: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + } + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/cpp/open3d/ml/pytorch/pointnet/SamplingKernel.h b/cpp/open3d/ml/pytorch/pointnet/SamplingKernel.h new file mode 100644 index 00000000000..81ef723bf55 --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/SamplingKernel.h @@ -0,0 +1,47 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#pragma once + +void gather_points_launcher(int b, + int c, + int n, + int npoints, + const float *points, + const int *idx, + float *out); + +void gather_points_grad_launcher(int b, + int c, + int n, + int npoints, + const float *grad_out, + const int *idx, + float *grad_points); + +void furthest_point_sampling_launcher( + int b, int n, int m, const float *dataset, float *temp, int *idxs); diff --git a/cpp/open3d/ml/pytorch/pointnet/SamplingOps.cpp b/cpp/open3d/ml/pytorch/pointnet/SamplingOps.cpp new file mode 100644 index 00000000000..67fb2dededf --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/SamplingOps.cpp @@ -0,0 +1,113 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#include + +#include "open3d/ml/pytorch/TorchHelper.h" +#include "open3d/ml/pytorch/pointnet/SamplingKernel.h" +#include "torch/script.h" + +#ifdef BUILD_CUDA_MODULE +torch::Tensor gather_points(torch::Tensor points, torch::Tensor idx) { + int batch_size = idx.size(0); + int idx_size = idx.size(1); + int group_size = points.size(1); + int feature_size = points.size(2); + + auto device = points.device(); + torch::Tensor out = + torch::zeros({batch_size, group_size, idx_size}, + torch::dtype(ToTorchDtype()).device(device)); + + const float *points_data = points.data(); + const int *idx_data = idx.data(); + float *out_data = out.data(); + + gather_points_launcher(batch_size, group_size, feature_size, idx_size, + points_data, idx_data, out_data); + return out; +} + +torch::Tensor gather_points_grad(torch::Tensor grad_out, + torch::Tensor idx, + const int64_t C, + const int64_t N) { + int batch_size = idx.size(0); + int idx_size = idx.size(1); + + auto device = grad_out.device(); + torch::Tensor out = + torch::zeros({batch_size, C, N}, + torch::dtype(ToTorchDtype()).device(device)); + + const float *grad_out_data = grad_out.data(); + const int *idx_data = idx.data(); + float *out_data = out.data(); + + gather_points_grad_launcher(batch_size, C, N, idx_size, grad_out_data, + idx_data, out_data); + return out; +} + +torch::Tensor furthest_point_sampling(torch::Tensor points, + const int64_t sample_size) { + int batch_size = points.size(0); + int pts_size = points.size(1); + + auto device = points.device(); + torch::Tensor out = + torch::zeros({batch_size, sample_size}, + torch::dtype(ToTorchDtype()).device(device)); + torch::Tensor temp = + torch::full({batch_size, pts_size}, 1e10, + torch::dtype(ToTorchDtype()).device(device)); + + const float *points_data = points.data(); + float *temp_data = temp.data(); + int *out_data = out.data(); + + furthest_point_sampling_launcher(batch_size, pts_size, sample_size, + points_data, temp_data, out_data); + + return out; +} + +static auto registry_fp = torch::RegisterOperators( + "open3d::furthest_point_sampling(Tensor points, int sample_siz)" + " -> Tensor out", + &furthest_point_sampling); + +static auto registry = torch::RegisterOperators( + "open3d::gather_points(Tensor points, Tensor idx)" + " -> Tensor out", + &gather_points); + +static auto registry_grad = torch::RegisterOperators( + "open3d::gather_points_grad(Tensor grad_out, Tensor idx, int C, int N)" + " -> Tensor out", + &gather_points_grad); +#endif diff --git a/cpp/open3d/ml/pytorch/pointnet/cuda_utils.h b/cpp/open3d/ml/pytorch/pointnet/cuda_utils.h new file mode 100644 index 00000000000..8a3256d69b9 --- /dev/null +++ b/cpp/open3d/ml/pytorch/pointnet/cuda_utils.h @@ -0,0 +1,40 @@ +//***************************************************************************************/ +// +// Based on Pointnet2 Library (MIT License): +// https://github.com/sshaoshuai/Pointnet2.PyTorch +// +// Copyright (c) 2019 Shaoshuai Shi +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. +// +//***************************************************************************************/ + +#pragma once + +#include + +#define TOTAL_THREADS 1024 +#define THREADS_PER_BLOCK 256 +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} diff --git a/python/test/ml_ops/mltest.py b/python/test/ml_ops/mltest.py index 14a486c097d..1e9cd2f8e2a 100644 --- a/python/test/ml_ops/mltest.py +++ b/python/test/ml_ops/mltest.py @@ -4,6 +4,8 @@ from collections import namedtuple import importlib from types import SimpleNamespace +import urllib.request +import io # skip all tests if the ml ops were not built default_marks = [ @@ -192,3 +194,16 @@ def run_op_grad(ml, device_name, check_device, fn, x, y_attr_name, v for k, v in _ml_modules.items() if v.module.__name__ == 'tensorflow' ]), ) + + +def fetch_numpy(url): + # prevents security issue + if url.lower().startswith('http'): + req = urllib.request.Request(url) + else: + raise ValueError from None + + with urllib.request.urlopen(req) as response: #nosec + np_file = response.read() + return np.load(io.BytesIO(np_file)) + return None diff --git a/python/test/ml_ops/test_gathering.py b/python/test/ml_ops/test_gathering.py new file mode 100644 index 00000000000..b57389243da --- /dev/null +++ b/python/test/ml_ops/test_gathering.py @@ -0,0 +1,57 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# The MIT License (MIT) +# +# Copyright (c) 2020 www.open3d.org +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +import pytest +import mltest + +# Skip all tests if the ml ops were not built. +pytestmark = mltest.default_marks + +ml_torch_gpu_only = pytest.mark.parametrize('ml', [ + v for k, v in mltest._ml_modules.items() + if mltest.is_gpu_device_name(v.device) and v.module.__name__ == 'torch' +]) + + +@ml_torch_gpu_only +def test_gathering(ml): + + values0 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/gathering/values0.npy' + ) + values1 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/gathering/values1.npy' + ) + + ans = mltest.run_op(ml, ml.device, True, ml.ops.gather_points, values0, + values1) + + expected = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/gathering/out.npy' + ) + np.testing.assert_equal(ans, expected) diff --git a/python/test/ml_ops/test_group_pts.py b/python/test/ml_ops/test_group_pts.py new file mode 100644 index 00000000000..586fe3dec88 --- /dev/null +++ b/python/test/ml_ops/test_group_pts.py @@ -0,0 +1,57 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# The MIT License (MIT) +# +# Copyright (c) 2020 www.open3d.org +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +import pytest +import mltest + +# Skip all tests if the ml ops were not built. +pytestmark = mltest.default_marks + +ml_torch_gpu_only = pytest.mark.parametrize('ml', [ + v for k, v in mltest._ml_modules.items() + if mltest.is_gpu_device_name(v.device) and v.module.__name__ == 'torch' +]) + + +@ml_torch_gpu_only +def test_group_pts(ml): + + values0 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/group_pts/values0.npy' + ) + values1 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/group_pts/values1.npy' + ) + + ans = mltest.run_op(ml, ml.device, True, ml.ops.group_points, values0, + values1) + + expected = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/group_pts/out.npy' + ) + np.testing.assert_equal(ans, expected) diff --git a/python/test/ml_ops/test_query_pts.py b/python/test/ml_ops/test_query_pts.py new file mode 100644 index 00000000000..5d26bf106ae --- /dev/null +++ b/python/test/ml_ops/test_query_pts.py @@ -0,0 +1,60 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# The MIT License (MIT) +# +# Copyright (c) 2020 www.open3d.org +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +import pytest +import mltest + +# Skip all tests if the ml ops were not built. +pytestmark = mltest.default_marks + +ml_torch_gpu_only = pytest.mark.parametrize('ml', [ + v for k, v in mltest._ml_modules.items() + if mltest.is_gpu_device_name(v.device) and v.module.__name__ == 'torch' +]) + + +@ml_torch_gpu_only +def test_query_pts(ml): + + values0 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/query_pts/values0.npy' + ) + values1 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/query_pts/values1.npy' + ) + + sample = 16 + radius = 0.1 + + ans = mltest.run_op(ml, ml.device, True, ml.ops.ball_query, values0, + values1, radius, sample) + + expected = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/query_pts/out.npy' + ) + np.testing.assert_equal(ans, expected) diff --git a/python/test/ml_ops/test_roi_pool.py b/python/test/ml_ops/test_roi_pool.py new file mode 100644 index 00000000000..620e57f5269 --- /dev/null +++ b/python/test/ml_ops/test_roi_pool.py @@ -0,0 +1,65 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# The MIT License (MIT) +# +# Copyright (c) 2020 www.open3d.org +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +import pytest +import mltest + +# Skip all tests if the ml ops were not built. +pytestmark = mltest.default_marks + +ml_torch_gpu_only = pytest.mark.parametrize('ml', [ + v for k, v in mltest._ml_modules.items() + if mltest.is_gpu_device_name(v.device) and v.module.__name__ == 'torch' +]) + + +@ml_torch_gpu_only +def test_roi_pool(ml): + + values0 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/roi_pool/values0.npy' + ) + values1 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/roi_pool/values1.npy' + ) + values2 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/roi_pool/values2.npy' + ) + sampled_pts_num = 512 + + ans0, ans1 = mltest.run_op(ml, ml.device, True, ml.ops.roipool3d, values0, + values1, values2, sampled_pts_num) + + expected0 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/roi_pool/out0.npy' + ) + expected1 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/roi_pool/out1.npy' + ) + np.testing.assert_equal(ans0, expected0) + np.testing.assert_equal(ans1, expected1) diff --git a/python/test/ml_ops/test_sampling.py b/python/test/ml_ops/test_sampling.py new file mode 100644 index 00000000000..12b35dd0372 --- /dev/null +++ b/python/test/ml_ops/test_sampling.py @@ -0,0 +1,55 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# The MIT License (MIT) +# +# Copyright (c) 2020 www.open3d.org +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +import pytest +import mltest + +# Skip all tests if the ml ops were not built. +pytestmark = mltest.default_marks + +ml_torch_gpu_only = pytest.mark.parametrize('ml', [ + v for k, v in mltest._ml_modules.items() + if mltest.is_gpu_device_name(v.device) and v.module.__name__ == 'torch' +]) + + +@ml_torch_gpu_only +def test_furthest_point_sampling(ml): + + values = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/sampling/values.npy' + ) + samples = 4096 + + ans = mltest.run_op(ml, ml.device, True, ml.ops.furthest_point_sampling, + values, samples) + + expected = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/sampling/out.npy' + ) + np.testing.assert_equal(ans, expected) diff --git a/python/test/ml_ops/test_three_interp.py b/python/test/ml_ops/test_three_interp.py new file mode 100644 index 00000000000..69d39f99cda --- /dev/null +++ b/python/test/ml_ops/test_three_interp.py @@ -0,0 +1,60 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# The MIT License (MIT) +# +# Copyright (c) 2020 www.open3d.org +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +import pytest +import mltest + +# Skip all tests if the ml ops were not built. +pytestmark = mltest.default_marks + +ml_torch_gpu_only = pytest.mark.parametrize('ml', [ + v for k, v in mltest._ml_modules.items() + if mltest.is_gpu_device_name(v.device) and v.module.__name__ == 'torch' +]) + + +@ml_torch_gpu_only +def test_three_interp(ml): + + values0 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/three_interp/values0.npy' + ) + values1 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/three_interp/values1.npy' + ) + values2 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/three_interp/values2.npy' + ) + + ans = mltest.run_op(ml, ml.device, True, ml.ops.three_interpolate, values0, + values1, values2) + + expected = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/three_interp/out.npy' + ) + np.testing.assert_equal(ans, expected) diff --git a/python/test/ml_ops/test_three_nn.py b/python/test/ml_ops/test_three_nn.py new file mode 100644 index 00000000000..4913ec843ed --- /dev/null +++ b/python/test/ml_ops/test_three_nn.py @@ -0,0 +1,61 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# The MIT License (MIT) +# +# Copyright (c) 2020 www.open3d.org +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +import pytest +import mltest + +# Skip all tests if the ml ops were not built. +pytestmark = mltest.default_marks + +ml_torch_gpu_only = pytest.mark.parametrize('ml', [ + v for k, v in mltest._ml_modules.items() + if mltest.is_gpu_device_name(v.device) and v.module.__name__ == 'torch' +]) + + +@ml_torch_gpu_only +def test_three_nn(ml): + + values0 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/three_nn/values0.npy' + ) + values1 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/three_nn/values1.npy' + ) + + ans0, ans1 = mltest.run_op(ml, ml.device, True, ml.ops.three_nn, values0, + values1) + + expected0 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/three_nn/out0.npy' + ) + expected1 = mltest.fetch_numpy( + 'https://storage.googleapis.com/isl-datasets/open3d-dev/test/ml_ops/data/three_nn/out1.npy' + ) + np.testing.assert_equal(ans0, expected0) + np.testing.assert_equal(ans1, expected1)