Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse Conv3d gpu backward #40143

Merged
merged 30 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
6d4f2fa
fix incorrect dims settings
Feb 17, 2022
ec6eed3
sparse conv3d
Feb 17, 2022
dc8d707
fix out dims
Feb 18, 2022
fa365cb
test performance
Feb 18, 2022
bb1c375
test large shape success
Feb 18, 2022
99c3c41
opt scatter, double performance
Feb 21, 2022
621fae1
test float16
Feb 21, 2022
2832f05
remove profiling code
Feb 21, 2022
c413e96
merge upstream develop
Feb 21, 2022
271eea6
remove pten
Feb 21, 2022
904d664
opt code lines
Feb 22, 2022
2eea16b
correct boundary judgment
Feb 24, 2022
a0c8714
merge upstream
Feb 28, 2022
4798f56
fix: used in the wrong place
Mar 1, 2022
9d72521
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 2, 2022
838cd41
gpu backward
Mar 3, 2022
f99112d
gpu backward
Mar 3, 2022
8cd00a1
adaptive rocm
Mar 4, 2022
5bf3c4b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 4, 2022
5abe044
fix: d_kernel needs to be initialized to 0; SetConstant uses the corre…
Mar 4, 2022
2448bee
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 7, 2022
54c6115
Merge branch 'develop' into conv3d_gpu_backward
Mar 7, 2022
4c8a65e
merge upstream
Mar 7, 2022
0f1ccd0
remove invalid empty method
Mar 7, 2022
cc61910
Merge branch 'conv3d_gpu_backward' of https://github.com/zkh2016/Padd…
Mar 7, 2022
f5aaa10
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 7, 2022
441c32e
remove unused function
Mar 7, 2022
ee7396b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 8, 2022
1ebe388
for rocm
Mar 9, 2022
2655d1d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Mar 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions paddle/phi/kernels/sparse/convolution_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const int groups) {
DenseTensor x_grad = phi::Empty<T, Context>(dev_ctx);
DenseTensor kernel_grad = phi::Empty<T, Context>(dev_ctx);
DenseTensor x_grad =
phi::Empty<Context>(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
DenseTensor kernel_grad = phi::Empty<Context>(
dev_ctx, DenseTensorMeta(kernel.dtype(), {1}, kernel.layout()));
// TODO(zhangkaihuo): call InferMeta func here
Conv3dGradKernel<T, Context>(dev_ctx,
x,
Expand Down
18 changes: 4 additions & 14 deletions paddle/phi/kernels/sparse/convolution_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,6 @@ limitations under the License. */
#include "paddle/phi/kernels/empty_kernel.h"

namespace phi {

template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx) {
phi::DenseTensor dense_out(
phi::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
{paddle::experimental::CppTypeToDataType<T>::Type(),
{-1},
DataLayout::NCHW});
return dense_out;
}

namespace sparse {

struct Dims4D {
Expand Down Expand Up @@ -149,8 +137,10 @@ SparseCooTensor Conv3d(const Context& dev_ctx,
const std::vector<int>& strides,
const int groups,
DenseTensor* rulebook) {
DenseTensor indices = phi::Empty<T, Context>(dev_ctx);
DenseTensor values = phi::Empty<T, Context>(dev_ctx);
DenseTensor indices = phi::Empty<Context>(
dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
DenseTensor values =
phi::Empty<Context>(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
SparseCooTensor coo(indices, values, x.dims());
Conv3dKernel<T, Context>(
dev_ctx, x, kernel, paddings, dilations, strides, groups, &coo, rulebook);
Expand Down
5 changes: 0 additions & 5 deletions paddle/phi/kernels/sparse/cpu/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ void ProductRuleBook(const Context& dev_ctx,
const int64_t non_zero_num = x.nnz();
const auto& non_zero_indices = x.non_zero_indices();
const int* indices_ptr = non_zero_indices.data<int>();
dev_ctx.Alloc(counter_per_kernel,
counter_per_kernel->dtype(),
sizeof(int) * counter_per_kernel->numel());
int* counter_ptr = counter_per_kernel->data<int>();
int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
memset(counter_ptr, 0, kernel_size * sizeof(int));
Expand Down Expand Up @@ -138,8 +135,6 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx,
x.dtype(), {out_non_zero_num, out_channels}, x.layout());
phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta));
phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta));
dev_ctx.Alloc(
&out_indices, out_indices.dtype(), out_indices.numel() * sizeof(int));
int* out_indices_ptr = out_indices.data<int>();
int i = 0;
for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) {
Expand Down
11 changes: 2 additions & 9 deletions paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"

namespace phi {
Expand Down Expand Up @@ -60,15 +61,8 @@ void Conv3dGradKernel(const Context& dev_ctx,
phi::DenseTensor out_grad_features =
phi::Empty(dev_ctx, std::move(out_grad_features_meta));

dev_ctx.Alloc(
&in_features, in_features.dtype(), sizeof(T) * in_features.numel());
T* in_features_ptr = in_features.data<T>();
dev_ctx.Alloc(
&d_x_features, d_x_features.dtype(), sizeof(T) * d_x_features.numel());
T* d_x_features_ptr = d_x_features.data<T>();
dev_ctx.Alloc(&out_grad_features,
out_grad_features.dtype(),
sizeof(T) * out_grad_features.numel());
T* out_grad_features_ptr = out_grad_features.data<T>();
kernel_grad->Resize(kernel_dims);
dev_ctx.Alloc(
Expand Down Expand Up @@ -156,12 +150,11 @@ void Conv3dGradKernel(const Context& dev_ctx,
} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(sparse_conv_grad,
PD_REGISTER_KERNEL(sparse_conv3d_grad,
CPU,
ALL_LAYOUT,
phi::sparse::Conv3dGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(3).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
5 changes: 0 additions & 5 deletions paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ void Conv3dKernel(const Context& dev_ctx,
phi::Empty(dev_ctx, std::move(in_features_meta));
phi::DenseTensor out_features =
phi::Empty(dev_ctx, std::move(out_features_meta));
dev_ctx.Alloc(&in_features, x.dtype(), sizeof(T) * in_features.numel());
dev_ctx.Alloc(&out_features, x.dtype(), sizeof(T) * out_features.numel());
T* in_features_ptr = in_features.data<T>();
T* out_features_ptr = out_features.data<T>();

Expand Down Expand Up @@ -128,9 +126,6 @@ void Conv3dKernel(const Context& dev_ctx,
}

// 4. scatter
dev_ctx.Alloc(out->mutable_non_zero_elements(),
out->mutable_non_zero_elements()->dtype(),
sizeof(T) * in_features.numel());
T* out_values_ptr = out->mutable_non_zero_elements()->data<T>();
memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels);
Scatter<T>(out_features_ptr,
Expand Down
139 changes: 139 additions & 0 deletions paddle/phi/kernels/sparse/gpu/convolution.cu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <thrust/execution_policy.h>
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"

namespace phi {
namespace sparse {

// TODO(zhangkaihuo): After the GatherCUDAKernel is migrated to phi, replace
// this kernel with phi::GatherCUDAKernel;
// Vectorization can be used to improve read and write bandwidth
/**
 * brief: gather rows of `params` selected by `indices` into `output`
 * params: the input feature matrix, rows of `slice_size` elements each
 * indices: for each output row, the row of `params` to copy from
 * output: the gathered result, `index_size` rows of `slice_size` elements
 * index_size: the number of indices (= number of output rows)
 * slice_size: elements copied per index, here it is the channel size
 **/
template <typename T, typename IndexT = int>
__global__ void GatherKernel(const T* params,
                             const IndexT* indices,
                             T* output,
                             size_t index_size,
                             size_t slice_size) {
  // NOTE(review): CUDA_KERNEL_LOOP_TYPE presumably expands to a grid-stride
  // loop of `i` over [0, index_size * slice_size) — confirm against the macro
  // definition in the gpu_launch_config header.
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;            // which output row
    int64_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT gather_i = indices[indices_i];          // source row in params
    int64_t params_i = gather_i * slice_size + slice_i;
    *(output + i) = *(params + params_i);
  }
}

/**
 * brief: scatter add — accumulate gathered features into output features
 * input: per-rulebook-entry features to be accumulated
 * unique_value: run-start offsets per output feature (see UpdateIndexKernel)
 * out_index: maps a sorted rulebook position to its feature row in `input`
 * non_zero_num: the number of output features
 * rulebook_len: the length of rulebook
 * channels: the output channel size
 * out: the outputs, non_zero_num x channels
 **/
template <typename T>
__global__ void ScatterKernel(const T* input,
                              const int* unique_value,
                              const int* out_index,
                              const int non_zero_num,
                              const int rulebook_len,
                              const int channels,
                              T* out) {
  const int total = non_zero_num * channels;
  const int stride = gridDim.x * blockDim.x;
  for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < total;
       idx += stride) {
    const int feature_i = idx / channels;
    const int channel_i = idx - feature_i * channels;

    // [begin, stop) is the run of sorted rulebook entries contributing to
    // this output feature; the final feature's run ends at rulebook_len.
    const int begin = unique_value[feature_i];
    const int stop = (feature_i == non_zero_num - 1)
                         ? rulebook_len
                         : unique_value[feature_i + 1];
    // max(stop - begin) = kernel_size
    T acc = static_cast<T>(0);
    for (int r = begin; r < stop; ++r) {
      acc += input[out_index[r] * channels + channel_i];
    }
    out[feature_i * channels + channel_i] = acc;
  }
}

/**
 * brief: sort the rulebook keys and, for each unique key, record the offset
 * where its run begins in the sorted order.
 * rulebook_ptr: device pointer to the keys (length `len`); left unmodified
 * out_index: pre-sized to `len`; initialized to 0..len-1 and permuted by the
 *            sort, so entry j holds the original position of sorted key j
 * unique_key: pre-sized to `len`; receives a copy of the keys which is then
 *             sorted and compacted in place
 * unique_value: pre-sized to `len`; initialized to 0..len-1, and after
 *               unique_by_key entry k holds the sorted-order offset at which
 *               the k-th unique key's run starts
 * returns: pointer one past the last unique key in unique_key
 **/
template <typename Context>
inline int* SortedAndUniqueIndex(const Context& dev_ctx,
                                 const int* rulebook_ptr,
                                 const int len,
                                 DenseTensor* out_index,
                                 DenseTensor* unique_key,
                                 DenseTensor* unique_value) {
  // Initialize out_index and unique_value to the identity sequence 0..len-1
  // (IndexKernel with IdentityFunctor — confirm in index_impl.cu.h).
  phi::IndexKernel<int, kps::IdentityFunctor<int>>(
      dev_ctx, out_index, kps::IdentityFunctor<int>());
  phi::IndexKernel<int, kps::IdentityFunctor<int>>(
      dev_ctx, unique_value, kps::IdentityFunctor<int>());

  // Copy the keys so the caller's rulebook is not disturbed; the copy and the
  // thrust calls below are all enqueued on the context's stream.
  phi::backends::gpu::GpuMemcpyAsync(unique_key->data<int>(),
                                     rulebook_ptr,
                                     sizeof(int) * len,
#ifdef PADDLE_WITH_HIP
                                     hipMemcpyDeviceToDevice,
#else
                                     cudaMemcpyDeviceToDevice,
#endif
                                     dev_ctx.stream());
// Compared with thrust::sort_by_key, thrust::merge_by_key may achieve higher
// performance, but thrust::merge_by_key is limited by the data size.
#ifdef PADDLE_WITH_HIP
  thrust::sort_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
  thrust::sort_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                      unique_key->data<int>(),
                      unique_key->data<int>() + len,
                      out_index->data<int>());

// 4. unique: compact duplicate keys; the value kept for each unique key is
// the one paired with its first sorted occurrence, i.e. the run-start offset.
  thrust::pair<int*, int*> new_end =
#ifdef PADDLE_WITH_HIP
      thrust::unique_by_key(thrust::hip::par.on(dev_ctx.stream()),
#else
      thrust::unique_by_key(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                            unique_key->data<int>(),
                            unique_key->data<int>() + len,
                            unique_value->data<int>());
  return new_end.first;
}

} // namespace sparse
} // namespace phi
Loading