Skip to content

Commit

Permalink
【2.0 API】Enhance affine grid operator (#26385)
Browse files Browse the repository at this point in the history
* Enhance affine grid operator:
1. Add cuda kernel
2. Add align corners options
test=develop

* Move new affine_grid api to functional
test=develop

* Add CUDA kernel for affine_grid.
test=develop

* Add more unitest for grid sample API
test=develop
  • Loading branch information
wanghaoshuang authored Aug 25, 2020
1 parent 6f69fbc commit a065a24
Show file tree
Hide file tree
Showing 6 changed files with 534 additions and 26 deletions.
17 changes: 14 additions & 3 deletions paddle/fluid/operators/affine_grid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@ using Tensor = framework::Tensor;

template <typename T>
struct Linspace<paddle::platform::CPUDeviceContext, T> {
void operator()(T start, T end, int count, framework::Tensor* numbers,
void operator()(T start, T end, int count, bool align_corners,
framework::Tensor* numbers,
const framework::ExecutionContext& ctx) {
T* number_data = numbers->mutable_data<T>({count}, platform::CPUPlace());
T slice = (end - start) / (T)(count - 1);
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
for (int i = 0; i < count; ++i) {
number_data[i] = start + (T)i * slice;
}
Expand Down Expand Up @@ -130,6 +135,10 @@ class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(true);
AddAttr<bool>("align_corners",
"(bool, default false) Whether to align the corners of input"
"and ouput.")
.SetDefault(true);
AddAttr<std::vector<int>>(
"output_shape",
"The target output image shape with format [N, C, H, W].")
Expand Down Expand Up @@ -164,10 +173,12 @@ class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]]]
C[0] is the coordinates in height axis and C[1] is the coordinates in width axis.
C[0] is the coordinates in height axis and C[1] is the coordinates in
width axis.
Step2:
Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get:
Tanspose and reshape C to shape [H * W, 2] and append ones to last
dimension. The we get:
C_ = [[-1. -1. 1. ]
[-0.5 -1. 1. ]
[ 0. -1. 1. ]
Expand Down
209 changes: 209 additions & 0 deletions paddle/fluid/operators/affine_grid_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/affine_grid_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
}

template <typename T>
struct Linspace<paddle::platform::CUDADeviceContext, T> {
void operator()(T start, T end, int count, bool align_corners,
framework::Tensor* numbers,
const framework::ExecutionContext& ctx) {
T* number_data = numbers->mutable_data<T>({count}, ctx.GetPlace());
T slice = (end - start) / (T)(count - 1);
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
auto stream = ctx.cuda_device_context().stream();
int block = 512;
int grid = (count + block - 1) / block;
LinspaceKernel<T><<<grid, block, 0, stream>>>(start, slice, count,
number_data);
}
};

template <typename T>
__global__ void affine_grid_kernel(const int count, int n, int out_h, int out_w,
T h_start, T w_start, T h_step, T w_step,
const T* theta, // N, 2, 3
T* output) {
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
int n = index / (out_w * out_h);

T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);

int theta_offset = n * 6; // 2 * 3;
// affine from (h_coor, w_coor) to (x, y)
output[index * 2] = theta[theta_offset] * h_coor +
theta[theta_offset + 1] * w_coor +
theta[theta_offset + 2];
output[index * 2 + 1] = theta[theta_offset + 3] * h_coor +
theta[theta_offset + 4] * w_coor +
theta[theta_offset + 5];
}
}

template <typename T>
__global__ void affine_grid_grad_kernel(const int count, int n, int out_h,
int out_w, T h_start, T w_start,
T h_step, T w_step,
const T* out_grad, // N, H, W, 2
T* theta_grad) { // N, 2, 3
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
int n = index / (out_w * out_h);
T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);

int theta_offset = n * 6; // 2 * 3;
T out_grad_x = out_grad[index * 2];
atomicAdd(theta_grad + theta_offset, out_grad_x * h_coor);
atomicAdd(theta_grad + theta_offset + 1, out_grad_x * w_coor);
atomicAdd(theta_grad + theta_offset + 2, out_grad_x);

T out_grad_y = out_grad[index * 2 + 1];
atomicAdd(theta_grad + theta_offset + 3, out_grad_y * h_coor);
atomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor);
atomicAdd(theta_grad + theta_offset + 5, out_grad_y);
}
}

template <typename T>
class AffineGridOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* theta = ctx.Input<Tensor>("Theta");
int n = theta->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
auto* output = ctx.Output<Tensor>("Output");
T* out_data = output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());

T h_step;
T w_step;
T h_start = -1;
T w_start = -1;
if (align_corners) {
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
} else {
h_step = static_cast<T>(2) / static_cast<T>(h);
w_step = static_cast<T>(2) / static_cast<T>(w);

h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
}

const int count = n * h * w;
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = ctx.cuda_device_context().stream();
affine_grid_kernel<<<grid, block, 0, cu_stream>>>(
count, n, h, w, h_start, w_start, h_step, w_step,
theta->data<T>(), // N, 2, 3
out_data);
}
};

template <typename T>
class AffineGridGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
int n = output_grad->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
T* theta_grad_data = theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.cuda_device_context(), theta_grad, static_cast<T>(0));

T h_step;
T w_step;
T h_start = -1;
T w_start = -1;
if (align_corners) {
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
} else {
h_step = static_cast<T>(2) / static_cast<T>(h);
w_step = static_cast<T>(2) / static_cast<T>(w);

h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
}
const int count = n * h * w;
VLOG(3) << "count: " << count << "; h_step: " << h_step
<< "; w_step: " << w_step << "; h_start: " << h_start
<< "; w_start: " << w_start;
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = ctx.cuda_device_context().stream();
affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>(
count, n, h, w, h_start, w_start, h_step, w_step,
output_grad->data<T>(), theta_grad_data);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(affine_grid, ops::AffineGridOpCUDAKernel<float>,
ops::AffineGridOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(affine_grid_grad,
ops::AffineGridGradOpCUDAKernel<float>,
ops::AffineGridGradOpCUDAKernel<double>);
22 changes: 13 additions & 9 deletions paddle/fluid/operators/affine_grid_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,33 @@ using Array4 = Eigen::DSizes<int64_t, 4>;
*/
template <typename DeviceContext, typename T>
struct Linspace {
void operator()(T start, T end, int count, framework::Tensor* numbers,
void operator()(T start, T end, int count, bool align_corners,
framework::Tensor* numbers,
const framework::ExecutionContext& ctx);
};

template <typename DeviceContext, typename T>
inline void GetIdxMap(int n, int h, int w, Tensor* grid,
inline void GetIdxMap(int n, int h, int w, bool align_corners, Tensor* grid,
const framework::ExecutionContext& ctx) {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
grid->mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
auto grid_t = EigenTensor<T, 4>::From(*grid);
// Get indexes of height with shape [height, width, 1]
Tensor h_idx;
Linspace<DeviceContext, T> linspace;
linspace((T)-1, (T)1, h, &h_idx, ctx);
linspace((T)-1, (T)1, h, align_corners, &h_idx, ctx);
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
// Get indexes of width with shape [height, width, 1]
Tensor w_idx;
linspace((T)-1, (T)1, w, &w_idx, ctx);
linspace((T)-1, (T)1, w, align_corners, &w_idx, ctx);
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
// Get constant ones tensor with shape [height, width, 1]
Tensor ones;
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);

math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), &ones, static_cast<T>(1));
auto ones_t = EigenTensor<T, 3>::From(ones);
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
// ones
Tensor w_idx_map;
Expand All @@ -74,11 +78,9 @@ inline void GetIdxMap(int n, int h, int w, Tensor* grid,
Tensor w_h_one_idx_map;
w_h_one_idx_map.mutable_data<T>({h, w, 3}, ctx.GetPlace());
auto w_h_one_idx_map_t = EigenTensor<T, 3>::From(w_h_one_idx_map);

w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w))
.broadcast(Array2(h, 1))
.reshape(Array3(h, w, 1));

h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h))
.broadcast(Array2(w, 1))
.shuffle(Array2(1, 0))
Expand All @@ -97,6 +99,7 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
auto* theta = ctx.Input<Tensor>("Theta");
int n = theta->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
Expand All @@ -116,7 +119,7 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
ctx.template device_context<DeviceContext>(), output,
static_cast<T>(0));
Tensor grid;
GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = math::GetBlas<DeviceContext, T>(ctx);
Expand All @@ -140,6 +143,7 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
int n = output_grad->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
Expand All @@ -158,7 +162,7 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
ctx.template device_context<DeviceContext>(), theta_grad,
static_cast<T>(0));
Tensor grid;
GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = math::GetBlas<DeviceContext, T>(ctx);
Expand Down
Loading

0 comments on commit a065a24

Please sign in to comment.