forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request PaddlePaddle#8 from FDInSky/farthest_point_samplin…
…g_op add farthest point sampling op
- Loading branch information
Showing
7 changed files
with
316 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
#include "paddle/fluid/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using Tensor = framework::Tensor; | ||
|
||
class FarthestPointSamplingOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("X", | ||
"(Tensor)input point cloud dataset with shape (B, N, 3)" | ||
"B is batch size, N is points's nums, 3 is (x,y,z) coordinate"); | ||
AddOutput("Output", | ||
"(Tensor)return sampled points with shape (B, M)" | ||
"B is batch size, M is points's nums"); | ||
AddAttr<int>("sampled_point_num", "sampling points's num") | ||
.SetDefault(0) | ||
.EqualGreaterThan(0); | ||
AddComment( | ||
R"Doc( | ||
Sampling point based on | ||
its max eucliden distance with other points.)Doc"); | ||
} | ||
}; | ||
|
||
class FarthestPointSamplingOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shoud not be null"); | ||
auto x_dims = ctx->GetInputDim("X"); | ||
PADDLE_ENFORCE(x_dims.size() == 3, | ||
"Input(X) of FathestPointSamplingOp should be 3-D Tensor"); | ||
const int m = ctx->Attrs().Get<int>("sampled_point_num"); | ||
ctx->SetOutputDim("Output", {x_dims[0], m}); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext &ctx) const override { | ||
auto input_data_type = ctx.Input<Tensor>("X")->type(); | ||
return framework::OpKernelType(input_data_type, ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OPERATOR(farthest_point_sampling, ops::FarthestPointSamplingOp, | ||
ops::FarthestPointSamplingOpMaker); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/framework/eigen.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using Tensor = framework::Tensor; | ||
|
||
template <typename T, unsigned int block_size> | ||
__global__ void farthestpointsamplingKernel(int b, int n, int m, | ||
const T *__restrict__ dataset, | ||
T *__restrict__ temp, | ||
int *__restrict__ idxs) { | ||
// 1. add first point | ||
// 2. add the point having farthest distance with first point's | ||
// 3. make second point as first point, repeat 1,2 | ||
if (m <= 0) return; | ||
const int BlockSize = block_size; | ||
__shared__ float dists[BlockSize]; | ||
__shared__ int dists_i[BlockSize]; | ||
const int BufferSize = 3072; | ||
__shared__ float buf[BufferSize * 3]; | ||
|
||
// one block one batch, n points | ||
// one thread one point | ||
for (int i = blockIdx.x; i < b; i += gridDim.x) { | ||
// can select old point as first point randomly | ||
int old = 0; | ||
if (threadIdx.x == 0) idxs[i * m + 0] = old; | ||
|
||
for (int j = threadIdx.x; j < n; j += blockDim.x) { | ||
temp[blockIdx.x * n + j] = 1e38; | ||
} | ||
for (int j = threadIdx.x; j < min(BufferSize, n) * 3; j += blockDim.x) { | ||
buf[j] = dataset[i * n * 3 + j]; | ||
} | ||
// wait all threads do this in the same block | ||
__syncthreads(); | ||
|
||
// out m points | ||
for (int j = 1; j < m; j++) { | ||
// Step 1. | ||
// fatherest distance | ||
int besti = 0; | ||
float best = -1; | ||
// first point in m points | ||
float x1 = dataset[i * n * 3 + old * 3 + 0]; | ||
float y1 = dataset[i * n * 3 + old * 3 + 1]; | ||
float z1 = dataset[i * n * 3 + old * 3 + 2]; | ||
|
||
// Step 2. | ||
// find farthest point of (x1, y1, z1) | ||
for (int k = threadIdx.x; k < n; k += blockDim.x) { | ||
float td = temp[blockIdx.x * n + k]; | ||
float x2, y2, z2; | ||
if (k < BufferSize) { | ||
x2 = buf[k * 3 + 0]; | ||
y2 = buf[k * 3 + 1]; | ||
z2 = buf[k * 3 + 2]; | ||
} else { | ||
x2 = dataset[i * n * 3 + k * 3 + 0]; | ||
y2 = dataset[i * n * 3 + k * 3 + 1]; | ||
z2 = dataset[i * n * 3 + k * 3 + 2]; | ||
} | ||
// compute eucliden distance | ||
float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + | ||
(z2 - z1) * (z2 - z1); | ||
float d2 = min(d, td); | ||
if (d2 != td) temp[blockIdx.x * n + k] = d2; | ||
if (d2 > best) { | ||
best = d2; | ||
besti = k; | ||
} | ||
} | ||
|
||
// step 3. | ||
dists[threadIdx.x] = best; | ||
dists_i[threadIdx.x] = besti; | ||
for (int u = 0; (1 << u) < blockDim.x; u++) { | ||
__syncthreads(); | ||
if (threadIdx.x < (blockDim.x >> (u + 1))) { | ||
int i1 = (threadIdx.x * 2) << u; | ||
int i2 = (threadIdx.x * 2 + 1) << u; | ||
if (dists[i1] < dists[i2]) { | ||
dists[i1] = dists[i2]; | ||
dists_i[i1] = dists_i[i2]; | ||
} | ||
} | ||
} | ||
__syncthreads(); | ||
// store the found node index | ||
old = dists_i[0]; | ||
if (threadIdx.x == 0) idxs[i * m + j] = old; | ||
} | ||
} | ||
} | ||
|
||
template <typename T> | ||
class FarthestPointSamplingOpCUDAKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext &ctx) const override { | ||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), | ||
"This kernel only runs on GPU device."); | ||
auto *input = ctx.Input<Tensor>("X"); | ||
auto *output = ctx.Output<Tensor>("Output"); | ||
if (input->numel() == 0) return; | ||
// allocate memory | ||
auto *ptr_out_points_index = output->mutable_data<int>(ctx.GetPlace()); | ||
|
||
// b, n, m | ||
int batch_size = input->dims()[0]; | ||
int in_n_points = input->dims()[1]; | ||
int out_m_points = ctx.Attr<int>("sampled_point_num"); | ||
|
||
const T *ptr_in_points = input->data<T>(); | ||
|
||
Tensor tmp; | ||
auto *ptr_tmp_e = | ||
tmp.mutable_data<T>({batch_size, in_n_points}, ctx.GetPlace()); | ||
|
||
// run fathest point sampling kernel | ||
// P40 have max 512 thread | ||
farthestpointsamplingKernel<T, 512><<<32, 512>>>( | ||
batch_size, in_n_points, out_m_points, ptr_in_points, ptr_tmp_e, | ||
ptr_out_points_index); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_CUDA_KERNEL(farthest_point_sampling, | ||
ops::FarthestPointSamplingOpCUDAKernel<float>, | ||
ops::FarthestPointSamplingOpCUDAKernel<double>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
64 changes: 64 additions & 0 deletions
64
python/paddle/fluid/tests/unittests/test_farthest_point_sampling_op.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
import numpy as np | ||
from op_test import OpTest | ||
import paddle.fluid.core as core | ||
|
||
|
||
def farthest_point_sampling(xyz, npoint): | ||
B, N, C = xyz.shape | ||
S = npoint | ||
|
||
centroids = np.zeros((B, S)) | ||
distance = np.ones((B, N)) * 1e10 | ||
farthest = 0 | ||
batch_indices = np.arange(B).astype('int32') | ||
for i in range(S): | ||
centroids[:, i] = farthest | ||
centroid = xyz[batch_indices, farthest, :].reshape((B, 1, 3)) | ||
dist = np.sum((xyz - centroid)**2, -1) | ||
mask = dist < distance | ||
distance[mask] = dist[mask] | ||
farthest = np.argmax(distance, -1) | ||
return centroids.astype('int32') | ||
|
||
|
||
class TestFarthestPointSamplingOp(OpTest): | ||
def setUp(self): | ||
self.op_type = 'farthest_point_sampling' | ||
self.config() | ||
x = np.random.randint(1, 100, | ||
(self.x_shape[0] * self.x_shape[1] * | ||
3, )).reshape(self.x_shape).astype(self.x_type) | ||
m = self.sampled_point_num | ||
out_np = farthest_point_sampling(x, m) | ||
self.inputs = {'X': x, } | ||
self.attrs = {'sampled_point_num': m, } | ||
self.outputs = {'Output': out_np, } | ||
|
||
def config(self): | ||
self.x_shape = (1, 512, 3) | ||
self.x_type = 'float32' | ||
self.sampled_point_num = 256 | ||
|
||
def test_check_output(self): | ||
if core.is_compiled_with_cuda(): | ||
place = core.CUDAPlace(0) | ||
self.check_output_with_place(place, atol=1e-3) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |