Optimize sparse convolution #43576

Status: Merged (97 commits; merged Jul 26, 2022)
The diff below shows the changes from 49 of the 97 commits.

Commits:
bb4db9b
test sparse model
May 9, 2022
441da36
refactor code structure
EsdeathYZH May 29, 2022
c48e076
add native kernel usage
EsdeathYZH May 29, 2022
0a68ba3
add wellford impl
EsdeathYZH Jun 3, 2022
b3248c9
add shmem impl
EsdeathYZH Jun 4, 2022
78349a2
add dispatch logic
EsdeathYZH Jun 4, 2022
98c66f0
add channel_last impl
EsdeathYZH Jun 4, 2022
570dc55
refine the global space init
EsdeathYZH Jun 6, 2022
aaca04a
impl 2d kernel
EsdeathYZH Jun 7, 2022
29ef723
Merge remote-tracking branch 'paddle/develop' into optim_batchnorm1d
EsdeathYZH Jun 11, 2022
74b792b
rm wellford
EsdeathYZH Jun 11, 2022
a0bd5b6
fix backward
EsdeathYZH Jun 11, 2022
2433ebf
add unit test for batchnorm1d
EsdeathYZH Jun 11, 2022
90c27a6
fix bug
EsdeathYZH Jun 11, 2022
91d83e5
impl channel last 2d
EsdeathYZH Jun 11, 2022
6871dbf
refine
EsdeathYZH Jun 15, 2022
0571ecc
fix memory thpt
EsdeathYZH Jun 15, 2022
97493af
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jun 16, 2022
3fc54ad
opt gather
Jun 16, 2022
804ba03
fix threshold
EsdeathYZH Jun 16, 2022
48c6344
fix backward threshold
EsdeathYZH Jun 16, 2022
6785f6f
refine unit test
EsdeathYZH Jun 16, 2022
e46ef54
refine test
EsdeathYZH Jun 16, 2022
938cde3
delete pragma unroll
EsdeathYZH Jun 17, 2022
a24f2aa
opt gather and scatter
Jun 20, 2022
64be38b
opt conv
Jun 20, 2022
c0fce45
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jun 21, 2022
26ca7ab
fix batch csr
Jun 21, 2022
11011c0
remove the unused file
Jun 21, 2022
de91d40
opt SparseMaskCopyKernel
Jun 21, 2022
a137947
merge origin
Jun 21, 2022
3625aae
Merge branch 'opt_conv' of https://github.com/zkh2016/Paddle into opt…
Jun 21, 2022
6a92b32
opt subm
Jun 22, 2022
e96f090
opt subm
Jun 22, 2022
7e272fb
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jun 22, 2022
dd8e9ca
merge upstream
Jun 22, 2022
c7eddc5
opt copy rulebook
Jun 22, 2022
dd5e4fd
check cache size
Jun 23, 2022
52367a3
check cache size
Jun 23, 2022
345ebb2
correct alloc out values
Jun 23, 2022
d7bb341
merge origin
Jun 23, 2022
a79206f
Merge remote-tracking branch 'paddle/develop' into optim_batchnorm1d
EsdeathYZH Jun 24, 2022
459fd81
save the rulebook of submanifold conv
Jun 24, 2022
596bfbd
fix backward
Jun 24, 2022
c906bdb
opt conv
Jun 25, 2022
823b5c6
Merge branch 'opt_conv' of https://github.com/zkh2016/Paddle into opt…
Jun 25, 2022
6dc1584
opt conv3d
Jun 26, 2022
8202771
opt scatter
Jun 27, 2022
75df1e2
opt SparseMaskCopy
Jun 27, 2022
2745b0e
coalesced is not performed by default
Jun 27, 2022
ad9c2b6
opt rulebook
Jun 27, 2022
214475b
remove a sync
Jun 27, 2022
13f0b93
gatherV2
Jun 28, 2022
c9929a2
opt gather of backward
Jun 28, 2022
32d1e03
merge upstream
Jun 28, 2022
dab4609
resolve conflict
Jun 28, 2022
f66d0c7
opt groups indexs
Jun 28, 2022
44ad03e
refine code
EsdeathYZH Jun 28, 2022
db9792c
replace sort with remove_copy
Jul 1, 2022
6f9d6ea
fix cache
Jul 1, 2022
c295776
unorder the out index of Conv3D
Jul 3, 2022
bf6fabf
unorder the out index of Conv3D
Jul 3, 2022
b92786d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 4, 2022
e181b1c
add coalesced
Jul 4, 2022
e2bf43a
add coalesced.py
Jul 4, 2022
0aa457f
coalesced before compare result
Jul 5, 2022
749dcce
the key of conv3d is not required
Jul 5, 2022
d38563b
opt code structure
Jul 5, 2022
d06527f
opt gather/scatter code structure
Jul 6, 2022
842acf7
fix pool
Jul 6, 2022
5751987
rename pool_kernel.cc
Jul 6, 2022
7c2fbf5
add new file
Jul 6, 2022
6684d94
for ci
Jul 6, 2022
4346bbb
fix comment
Jul 6, 2022
3187f52
opt code structure
Jul 6, 2022
aa284f4
rename conv_kernel
Jul 6, 2022
57bb27e
Merge remote-tracking branch 'zihang/optim_batchnorm1d' into opt_conv
Jul 7, 2022
0b5ca0e
fix
EsdeathYZH Jul 8, 2022
33ebaf5
rename table_ptr to indices_dict
Jul 8, 2022
d49a06b
Merge remote-tracking branch 'zihang/optim_batchnorm1d' into opt_conv
Jul 8, 2022
123b16c
fix test_sparse_utils
Jul 8, 2022
8658c28
merge upstream
Jul 13, 2022
c154249
Merge branch 'opt_conv' of https://github.com/zkh2016/Paddle into opt…
Jul 13, 2022
af66998
sparse support amp
Jul 13, 2022
8dcc194
Merge branch 'develop' into opt_conv
Jul 13, 2022
aba1e6e
Merge branch 'develop' into opt_conv
Jul 13, 2022
55b30a3
merge upstream
Jul 15, 2022
1864dfc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 18, 2022
3c9ec13
Merge remote-tracking branch 'zkh/opt_conv' into upstream_dev
Jul 18, 2022
91ee01b
resolve conflict
Jul 18, 2022
927e247
resolve conflict
Jul 18, 2022
4480cdf
Merge branch 'opt_conv' of https://github.com/zkh2016/Paddle into opt…
Jul 18, 2022
dbe743b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 19, 2022
a51402a
fix codestyle
Jul 19, 2022
0f9827e
merge upstream
Jul 19, 2022
8c442f6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Jul 25, 2022
ab996e0
supplement the description of key
Jul 25, 2022
Files changed:
21 changes: 14 additions & 7 deletions paddle/phi/api/yaml/sparse_api.yaml

@@ -8,12 +8,12 @@
   backward : add_grad
 
 - api : conv3d
-  args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
-  output : Tensor(out), Tensor(rulebook)
+  args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
+  output : Tensor(out), Tensor(rulebook), Tensor(counter)
   kernel :
-    func : sparse_conv3d{sparse_coo, dense -> sparse_coo, dense}
+    func : sparse_conv3d{sparse_coo, dense -> sparse_coo, dense, dense}
   layout : x
-  intermediate : rulebook
+  intermediate: rulebook, counter
   backward : conv3d_grad
 
 - api : coo_to_dense

@@ -132,6 +132,13 @@
   layout : x
   backward : values_grad
 
+- api: coalesced
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func: coalesced{sparse_coo -> sparse_coo}
+  layout : x
+
 - api: full_like
   args : (Tensor x, Scalar value, DataType dtype=DataType::UNDEFINED)
   output : Tensor(out)

@@ -162,11 +169,11 @@
 
 - api: maxpool
   args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides)
-  output : Tensor(out), Tensor(rulebook)
+  output : Tensor(out), Tensor(rulebook), Tensor(counter)
   kernel :
-    func : sparse_maxpool{sparse_coo -> sparse_coo, dense}
+    func : sparse_maxpool{sparse_coo -> sparse_coo, dense, dense}
   layout : x
-  intermediate : rulebook
+  intermediate : rulebook, counter
   backward : sparse_maxpool_grad
 
 - api: mv
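A note on the `func` annotation for readers unfamiliar with these YAML files: the types before the `->` describe the kernel's inputs and the types after it describe its outputs, in order. For the updated conv3d entry, the layout string lines up with the Conv3dKernel declaration that appears later in this PR (paddle/phi/kernels/sparse/convolution_kernel.h); the copy below only adds comments mapping each parameter to its slot.

// sparse_conv3d{sparse_coo, dense -> sparse_coo, dense, dense}
// Declaration as in convolution_kernel.h in this PR; comments added here.
template <typename T, typename Context>
void Conv3dKernel(const Context& dev_ctx,
                  const SparseCooTensor& x,       // input: sparse_coo
                  const DenseTensor& kernel,      // input: dense
                  const std::vector<int>& paddings,
                  const std::vector<int>& dilations,
                  const std::vector<int>& strides,
                  const int groups,
                  const bool subm,
                  const std::string& key,         // new attribute (str key)
                  SparseCooTensor* out,           // output: sparse_coo
                  DenseTensor* rulebook,          // output: dense, intermediate
                  DenseTensor* counter);          // output: dense, intermediate

The backward entries below follow the same convention, with the grad kernel's tensor arguments listed position by position in the layout string.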
12 changes: 6 additions & 6 deletions paddle/phi/api/yaml/sparse_bw_api.yaml

@@ -7,11 +7,11 @@
     add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
 
 - backward_api : conv3d_grad
-  forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
-  args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
+  forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor), Tensor(counter@DenseTensor)
+  args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
   output : Tensor(x_grad), Tensor(kernel_grad)
   kernel :
-    func : sparse_conv3d_grad{sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}
+    func : sparse_conv3d_grad{sparse_coo, dense, sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}
 
 - backward_api : coo_to_dense_grad
   forward : coo_to_dense(Tensor x) -> Tensor(out)

@@ -93,11 +93,11 @@
     func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
 
 - backward_api : sparse_maxpool_grad
-  forward : sparse_maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook)
-  args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
+  forward : sparse_maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook), Tensor(counter)
+  args : (Tensor x, Tensor rulebook, Tensor counter, Tensor out, Tensor out_grad, int[] kernel_sizes)
   output : Tensor(x_grad)
   kernel :
-    func : sparse_maxpool_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo}
+    func : sparse_maxpool_grad {sparse_coo, dense, dense, sparse_coo, sparse_coo -> sparse_coo}
 
 - backward_api : sqrt_grad
   forward : sqrt(Tensor x) -> Tensor(out)
38 changes: 38 additions & 0 deletions paddle/phi/core/sparse_coo_tensor.h

@@ -156,6 +156,38 @@ class SparseCooTensor : public TensorBase,
   /// \brief get the dense dim
   int32_t dense_dim() const;
 
+  const std::pair<DenseTensor, std::vector<int>>* table(
+      const std::string& key) const {
+    const auto& iter = table_ptr_->find(key);
+    if (iter == table_ptr_->end()) {
+      return nullptr;
+    }
+    return &iter->second;
+  }
+  // DenseTensor* mutable_rulebook() { return &rulebook_; }
+  void SetTable(const std::string& key,
+                const std::pair<DenseTensor, std::vector<int>>& table) {
+    auto ret = table_ptr_->insert({key, table});
+    if (ret.second == false) {
+      ret.first->second = table;
+    }
+  }
+
+  const std::shared_ptr<
+      std::map<std::string, std::pair<DenseTensor, std::vector<int>>>>&
+  GetTablePtr() const {
+    return table_ptr_;
+  }
+  void SetTablePtr(
+      const std::shared_ptr<
+          std::map<std::string, std::pair<DenseTensor, std::vector<int>>>>&
+          table_ptr) {
+    table_ptr_ = table_ptr;
+  }
+
+  // const bool subm() const { return subm_; }
+  // void SetSubm(const bool subm) { subm_ = subm; }
+
  private:
   // save the indices of non zero elements in original dense tensor
   DenseTensor non_zero_indices_;

@@ -165,6 +197,12 @@
   bool coalesced_ = false;
   // save the number of non zero elements in each batch
   DDim dims_;
+
+  // for sparse conv
+  std::shared_ptr<
+      std::map<std::string, std::pair<DenseTensor, std::vector<int>>>>
+      table_ptr_ = std::make_shared<
+          std::map<std::string, std::pair<DenseTensor, std::vector<int>>>>();
   /* --------------------------- */
   /* example: non zero element is scalar */
   /* --------------------------- */
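The new table_ptr_ member gives each SparseCooTensor a shared, per-key cache from a user-chosen string to a (rulebook, counter) pair; this is the plumbing that lets a submanifold conv reuse the rulebook built by an earlier layer with the same key instead of rebuilding it. A minimal sketch of the round-trip through the new accessors (hypothetical usage written for this review, not code from the PR; the key name "subm_conv1" is arbitrary):

#include "paddle/phi/core/sparse_coo_tensor.h"

// Sketch: cache a rulebook on a conv output, then look it up again later.
void CacheAndReuseRulebook(phi::SparseCooTensor* out,
                           const phi::DenseTensor& rulebook,
                           const std::vector<int>& counter) {
  // SetTable inserts the pair and overwrites any existing entry for the key.
  out->SetTable("subm_conv1", std::make_pair(rulebook, counter));

  // table() returns nullptr when the key is absent, so a caller can fall
  // back to building the rulebook from scratch.
  const auto* cached = out->table("subm_conv1");
  if (cached != nullptr) {
    const phi::DenseTensor& reused_rulebook = cached->first;
    const std::vector<int>& reused_counter = cached->second;
    (void)reused_rulebook;  // would feed the gather/GEMM/scatter pipeline
    (void)reused_counter;
  }
}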
121 changes: 109 additions & 12 deletions paddle/phi/kernels/funcs/sparse/scatter.cu.h

@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
+
+#define VecBytes 16
 
 namespace phi {
 namespace funcs {

@@ -28,33 +33,125 @@ namespace sparse {
  * channels: the output channel size
  * out: the outputs
  **/
-template <typename T>
+template <typename T, int VecSize>
 __global__ void ScatterKernel(const T* input,
                               const int* unique_value,
                               const int* out_index,
                               const int non_zero_num,
                               const int rulebook_len,
                               const int channels,
-                              T* out,
-                              const bool subm = false) {
+                              T* out) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
-    int indices_i = i / channels;
-    int channels_i = i - indices_i * channels;
+  const int vec_channels = channels / VecSize;
+  using LoadT = phi::AlignedVector<T, VecSize>;
+  using StoreT = phi::AlignedVector<T, VecSize>;
+  for (int i = tid; i < non_zero_num * vec_channels;
+       i += gridDim.x * blockDim.x) {
+    int indices_i = i / vec_channels;
+    int channels_i = i - indices_i * vec_channels;
 
     int start = unique_value[indices_i];
     int end = indices_i == non_zero_num - 1 ? rulebook_len
                                             : unique_value[indices_i + 1];
     // max(end-start) = kernel_size
-    T sum = static_cast<T>(0);
-    if (subm) {
-      sum = out[indices_i * channels + channels_i];
-    }
+    StoreT sums = {static_cast<T>(0)};
     for (int j = start; j < end; j++) {
       const int out_feature_i = out_index[j];
-      sum += input[out_feature_i * channels + channels_i];
+      LoadT vec_in;
+      phi::Load<T, VecSize>(
+          input + out_feature_i * channels + channels_i * VecSize, &vec_in);
+#pragma unroll
+      for (int k = 0; k < VecSize; k++) {
+        sums[k] += vec_in[k];
+      }
     }
-    out[indices_i * channels + channels_i] = sum;
+    phi::Store<T, VecSize>(sums,
+                           out + indices_i * channels + channels_i * VecSize);
   }
 }
+
+// The scatter indices have been grouped in advance:
+// index_counts records the size of each group and
+// index_groups stores the indices belonging to each group.
+template <typename T, int VecSize>
+__global__ void ScatterKernelV2(const T* input,
+                                const int* index_counts,
+                                const int* index_groups,
+                                const int non_zero_num,
+                                const int kernel_size,
+                                const int channels,
+                                const int buffer_counts,
+                                T* out) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const int vec_channels = channels / VecSize;
+  using LoadT = phi::AlignedVector<T, VecSize>;
+  using StoreT = phi::AlignedVector<T, VecSize>;
+  for (int i = tid; i < non_zero_num * vec_channels;
+       i += gridDim.x * blockDim.x) {
+    int indices_i = i / vec_channels;
+    int channels_i = i - indices_i * vec_channels;
+
+    StoreT sums = {static_cast<T>(0)};
+    phi::Load<T, VecSize>(out + indices_i * channels + channels_i * VecSize,
+                          &sums);
+    for (int it = 0; it < buffer_counts; it++) {
+      int len = index_counts[indices_i + it * non_zero_num];
+      const int group_offset = it * kernel_size * non_zero_num;
+      for (int j = 0; j < len; j++) {
+        const int out_feature_i =
+            index_groups[indices_i * kernel_size + j + group_offset];
+        LoadT vec_in;
+        phi::Load<T, VecSize>(
+            input + out_feature_i * channels + channels_i * VecSize, &vec_in);
+#pragma unroll
+        for (int k = 0; k < VecSize; k++) {
+          sums[k] += vec_in[k];
+        }
+      }
+    }
+    phi::Store<T, VecSize>(sums,
+                           out + indices_i * channels + channels_i * VecSize);
+  }
+}
+
+template <typename T>
+void ScatterV2(const GPUContext& dev_ctx,
+               const T* input,
+               const int* index_counts,
+               const int* index_groups,
+               const int non_zero_num,
+               const int kernel_size,
+               const int channels,
+               const int buffer_counts,
+               T* output) {
+  const int VecSize = VecBytes / sizeof(T);
+  if (channels % VecSize == 0) {
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+        dev_ctx, non_zero_num * channels / VecSize, 1);
+    ScatterKernelV2<T, VecSize><<<config.block_per_grid.x,
+                                  config.thread_per_block.x,
+                                  0,
+                                  dev_ctx.stream()>>>(input,
+                                                      index_counts,
+                                                      index_groups,
+                                                      non_zero_num,
+                                                      kernel_size,
+                                                      channels,
+                                                      buffer_counts,
+                                                      output);
+  } else {
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+        dev_ctx, non_zero_num * channels, 1);
+    ScatterKernelV2<T, 1><<<config.block_per_grid.x,
+                            config.thread_per_block.x,
+                            0,
+                            dev_ctx.stream()>>>(input,
+                                                index_counts,
+                                                index_groups,
+                                                non_zero_num,
+                                                kernel_size,
+                                                channels,
+                                                buffer_counts,
+                                                output);
+  }
+}
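Two details are worth spelling out. First, the vectorization: with VecBytes fixed at 16, VecSize = VecBytes / sizeof(T) yields 128-bit transactions (4 floats or 8 halves per load/store), and ScatterV2 falls back to the scalar VecSize = 1 instantiation whenever channels is not a multiple of the vector width. Second, the accumulation itself: the single-threaded C++ reference below restates what ScatterKernelV2 computes, minus the vectorization (a sketch written for this review, not code from the PR):

// Reference semantics of ScatterKernelV2 (sketch). For each output row i,
// each of the buffer_counts index groups contributes
// index_counts[i + it * non_zero_num] gathered input rows; `out` is read
// before accumulation, i.e. the kernel accumulates in place.
template <typename T>
void ScatterV2Reference(const T* input,
                        const int* index_counts,
                        const int* index_groups,
                        int non_zero_num,
                        int kernel_size,
                        int channels,
                        int buffer_counts,
                        T* out) {  // pre-initialized by the caller
  for (int i = 0; i < non_zero_num; i++) {
    for (int it = 0; it < buffer_counts; it++) {
      const int len = index_counts[i + it * non_zero_num];
      const int group_offset = it * kernel_size * non_zero_num;
      for (int j = 0; j < len; j++) {
        const int row = index_groups[i * kernel_size + j + group_offset];
        for (int c = 0; c < channels; c++) {
          out[i * channels + c] += input[row * channels + c];
        }
      }
    }
  }
}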
7 changes: 7 additions & 0 deletions paddle/phi/kernels/sparse/coalesced_kernel.h

@@ -26,5 +26,12 @@ void CoalescedKernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      SparseCooTensor* out);
 
+template <typename T, typename Context>
+SparseCooTensor Coalesced(const Context& dev_ctx, const SparseCooTensor& x) {
+  SparseCooTensor coo;
+  CoalescedKernel<T, Context>(dev_ctx, x, &coo);
+  return coo;
+}
+
 }  // namespace sparse
 }  // namespace phi
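The Coalesced wrapper only allocates the result and forwards to CoalescedKernel. It matters in this PR because conv outputs are no longer coalesced by default (commit 2745b0e), so callers that need unique indices, such as unit tests comparing against dense results, now coalesce explicitly. A usage sketch, assuming the usual COO coalesce semantics of merging duplicate indices and summing their values:

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/sparse/coalesced_kernel.h"

// Sketch: coalesce a COO tensor that may hold duplicate indices (e.g. a
// conv output) before comparing it against a reference result.
phi::SparseCooTensor MakeUnique(const phi::GPUContext& dev_ctx,
                                const phi::SparseCooTensor& x) {
  return phi::sparse::Coalesced<float>(dev_ctx, x);
}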
11 changes: 10 additions & 1 deletion paddle/phi/kernels/sparse/convolution_grad_kernel.h

@@ -26,13 +26,16 @@ template <typename T, typename Context>
 void Conv3dGradKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       const DenseTensor& kernel,
+                      const SparseCooTensor& out,
                       const DenseTensor& rulebook,
+                      const DenseTensor& counter,
                       const SparseCooTensor& out_grad,
                       const std::vector<int>& paddings,
                       const std::vector<int>& dilations,
                       const std::vector<int>& strides,
                       const int groups,
                       const bool subm,
+                      const std::string& key,
                       SparseCooTensor* x_grad,
                       DenseTensor* kernel_grad);
 
@@ -41,27 +44,33 @@ std::tuple<SparseCooTensor, DenseTensor> Conv3dGrad(
     const Context& dev_ctx,
     const SparseCooTensor& x,
     const DenseTensor& kernel,
+    const SparseCooTensor& out,
     const DenseTensor& rulebook,
+    const DenseTensor& counter,
     const SparseCooTensor& out_grad,
     const std::vector<int>& paddings,
     const std::vector<int>& dilations,
     const std::vector<int>& strides,
     const int groups,
-    const bool subm) {
+    const bool subm,
+    const std::string& key) {
   SparseCooTensor x_grad;
   DenseTensor kernel_grad;
 
   // TODO(zhangkaihuo): call InferMeta func here
   Conv3dGradKernel<T, Context>(dev_ctx,
                                x,
                                kernel,
+                               out,
                                rulebook,
+                               counter,
                                out_grad,
                                paddings,
                                dilations,
                                strides,
                                groups,
                                subm,
+                               key,
                                &x_grad,
                                &kernel_grad);
   return std::make_tuple(x_grad, kernel_grad);
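The grad entry points gain out, counter, and key so the backward kernel can locate the rulebook that forward cached on out under key instead of recomputing it. A call sketch for the widened wrapper (hypothetical surrounding code; the padding/dilation/stride values are placeholders):

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"

// Sketch: backward pass of a submanifold conv, reusing the forward rulebook.
std::tuple<phi::SparseCooTensor, phi::DenseTensor> BackwardSketch(
    const phi::GPUContext& dev_ctx,
    const phi::SparseCooTensor& x,
    const phi::DenseTensor& kernel,
    const phi::SparseCooTensor& out,
    const phi::DenseTensor& rulebook,
    const phi::DenseTensor& counter,
    const phi::SparseCooTensor& out_grad) {
  return phi::sparse::Conv3dGrad<float>(dev_ctx, x, kernel, out, rulebook,
                                        counter, out_grad,
                                        /*paddings=*/{1, 1, 1},
                                        /*dilations=*/{1, 1, 1},
                                        /*strides=*/{1, 1, 1},
                                        /*groups=*/1,
                                        /*subm=*/true,
                                        /*key=*/"subm_conv1");
}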
12 changes: 9 additions & 3 deletions paddle/phi/kernels/sparse/convolution_kernel.h

@@ -31,8 +31,10 @@ void Conv3dKernel(const Context& dev_ctx,
                   const std::vector<int>& strides,
                   const int groups,
                   const bool subm,
+                  const std::string& key,
                   SparseCooTensor* out,
-                  DenseTensor* rulebook);
+                  DenseTensor* rulebook,
+                  DenseTensor* counter);
 
 template <typename T, typename Context>
 SparseCooTensor Conv3d(const Context& dev_ctx,

@@ -43,7 +45,9 @@ SparseCooTensor Conv3d(const Context& dev_ctx,
                        const std::vector<int>& strides,
                        const int groups,
                        const bool subm,
-                       DenseTensor* rulebook) {
+                       const std::string& key,
+                       DenseTensor* rulebook,
+                       DenseTensor* counter) {
   SparseCooTensor coo;
   Conv3dKernel<T, Context>(dev_ctx,
                            x,

@@ -53,8 +57,10 @@ SparseCooTensor Conv3d(const Context& dev_ctx,
                            strides,
                            groups,
                            subm,
+                           key,
                            &coo,
-                           rulebook);
+                           rulebook,
+                           counter);
   return coo;
 }
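And the matching forward call, again a hypothetical sketch: rulebook and counter come back through plain output parameters, and, per the SetTable plumbing added in sparse_coo_tensor.h above, the same pair is presumably also cached on the output under key for the backward pass and for later layers sharing that key.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"

// Sketch: forward 3D submanifold convolution keeping rulebook and counter.
phi::SparseCooTensor ForwardSketch(const phi::GPUContext& dev_ctx,
                                   const phi::SparseCooTensor& x,
                                   const phi::DenseTensor& kernel) {
  phi::DenseTensor rulebook;  // filled by the kernel
  phi::DenseTensor counter;   // filled by the kernel
  return phi::sparse::Conv3d<float>(dev_ctx, x, kernel,
                                    /*paddings=*/{1, 1, 1},
                                    /*dilations=*/{1, 1, 1},
                                    /*strides=*/{1, 1, 1},
                                    /*groups=*/1,
                                    /*subm=*/true,
                                    /*key=*/"subm_conv1",
                                    &rulebook, &counter);
}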