Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 4th No.30】为 Paddle 新增 paddle.sparse.sum 稀疏 API #51406

Merged
merged 57 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from 55 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
3ba8bce
add paddle/phi/api/yaml
zrr1999 Feb 25, 2023
2b4ea91
implement cpu kernel
zrr1999 Feb 26, 2023
fe92011
implement cpu grad kernel
zrr1999 Mar 4, 2023
12e8f33
add unit tests
zrr1999 Mar 5, 2023
dc1c028
implement gpu kernel
zrr1999 Mar 5, 2023
6e2ceda
fix bug
zrr1999 Mar 6, 2023
60b8780
cuda 2d csr
zrr1999 Mar 6, 2023
6b818e7
cuda csr
zrr1999 Mar 7, 2023
8ab243f
cuda coo
zrr1999 Mar 9, 2023
ad9166c
cuda coo grad
zrr1999 Mar 9, 2023
ef0e89a
unitest
zrr1999 Mar 9, 2023
0fb2707
add 1D unitest
zrr1999 Mar 9, 2023
f5c7765
support keepdim in csr
zrr1999 Mar 9, 2023
fd2fed0
add static graph
zrr1999 Mar 9, 2023
b2137b1
add dtype
zrr1999 Mar 9, 2023
1786e84
support dtype
zrr1999 Mar 15, 2023
260706b
add unitest and fix relative import
zrr1999 Mar 16, 2023
c2844de
fix import
zrr1999 Mar 16, 2023
f1a441e
rebase
zrr1999 Apr 3, 2023
40fb8c2
fix import
zrr1999 Apr 4, 2023
6bb3977
add static unitest
zrr1999 Apr 5, 2023
6cc42be
fix static unitest
zrr1999 Apr 5, 2023
d273ad2
add more static unitests
zrr1999 Apr 5, 2023
8525c6a
remove some old func
zrr1999 Apr 5, 2023
2007c83
fix bug in static unittests for parallel
zrr1999 Apr 6, 2023
bbf4923
remove some unused note
zrr1999 Apr 10, 2023
198856c
update license date
zrr1999 Apr 10, 2023
6abccfa
use phi::funcs::SetConstant
zrr1999 Apr 10, 2023
98a3d4d
fix bug
zrr1999 Apr 11, 2023
13e16a3
opt for in sum_grad_kernel.cc
zrr1999 Apr 11, 2023
9cecc0f
use SetConstant
zrr1999 Apr 11, 2023
a66e3a4
rename map_indices to indices_map
zrr1999 Apr 11, 2023
3d61c7a
support tensor values
zrr1999 Apr 12, 2023
94edf06
fix bug
zrr1999 Apr 12, 2023
645d82b
fix bug in cpu
zrr1999 Apr 13, 2023
4426cf6
fix bug when keep_dim is false
zrr1999 Apr 13, 2023
907d0e9
ensure the sparse_dim is not less than 1
zrr1999 Apr 13, 2023
80a6be0
fix out_shape bug
zrr1999 Apr 13, 2023
6134ba5
support gpu tensor values
zrr1999 Apr 13, 2023
1378c9f
support same dtype of indices
zrr1999 Apr 13, 2023
aa84e2c
remove unused import
zrr1999 Apr 14, 2023
ee92a93
replace intT to IntT
zrr1999 Apr 14, 2023
90c0314
optimize cuda
zrr1999 Apr 14, 2023
544aa85
accelerate cuda
zrr1999 Apr 15, 2023
80fd0de
use offset
zrr1999 Apr 17, 2023
1a63118
fix dtype unitest bug
zrr1999 Apr 17, 2023
28a5173
add explicit function template specialization
zrr1999 Apr 17, 2023
e6428fe
add explicit function template specialization
zrr1999 Apr 17, 2023
1b2c6b1
optimize code
zrr1999 Apr 18, 2023
4157bd5
fix bug
zrr1999 Apr 18, 2023
f487a74
add dx_values dtype conversion
zrr1999 Apr 20, 2023
b6c482e
test in cpu and gpu
zrr1999 Apr 25, 2023
09411ee
test in cpu and gpu
zrr1999 Apr 25, 2023
27a1cca
fix bug in test
zrr1999 Apr 25, 2023
76942bc
add sum in __all__
zrr1999 May 4, 2023
0ebfb9f
fix doc, test=document_fix
zrr1999 May 8, 2023
b3ce86a
fix sample code, test=document_fix
zrr1999 May 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/sparse_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,17 @@
func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}

# Backward rule for paddle.sparse.sum: dispatches to the COO- or
# CSR-specific gradient kernel based on the sparse layout of x.
# `axis`/`keepdim` mirror the forward call; `dtype` is not needed here
# because the grad kernels cast dx's values back to dx's dtype.
- backward_op : sum_grad
forward : sum(Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray axis={}, bool keepdim=false)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : sum_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
sum_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_op : sync_batch_norm_grad
forward : sync_batch_norm_(Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_layout, bool use_global_stats, bool trainable_statistics) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics)
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/sparse_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,17 @@
layout : x
backward : subtract_grad

# Forward op for paddle.sparse.sum. Reuses the dense SumInferMeta for
# shape/dtype inference; one kernel per sparse layout (COO / CSR).
# data_type : x — the kernel is selected by x's value dtype.
- op : sum
args : (Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false)
output : Tensor(out)
infer_meta :
func : SumInferMeta
kernel :
func : sum_coo{sparse_coo -> sparse_coo},
sum_csr{sparse_csr -> sparse_csr}
data_type : x
backward : sum_grad

- op : sync_batch_norm_
args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_layout, bool use_global_stats, bool trainable_statistics)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ PD_REGISTER_KERNEL(sum_grad,
float,
double,
phi::dtype::float16,
int16_t,
int,
int64_t,
phi::dtype::complex<float>,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ PD_REGISTER_KERNEL(sum_grad,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int16_t,
int,
int64_t,
phi::dtype::complex<float>,
Expand Down
219 changes: 219 additions & 0 deletions paddle/phi/kernels/sparse/cpu/sum_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/reduce_sum_grad_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"

namespace phi {
namespace sparse {

/**
 * @brief CPU gradient of paddle.sparse.sum for a SparseCooTensor.
 *
 * @param dev_ctx   CPU device context.
 * @param x         Forward input; dx inherits its sparsity pattern.
 * @param dout      Gradient of the forward output (sparse COO).
 * @param axis      Reduction axis of the forward op; empty means "sum all".
 * @param keep_dim  Whether the forward reduction kept the reduced dim.
 * @param dx        Output gradient w.r.t. x (pre-allocated like x).
 */
template <typename T, typename IntT, typename Context>
void SumCooGradCPUKernel(const Context& dev_ctx,
                         const SparseCooTensor& x,
                         const SparseCooTensor& dout,
                         const IntArray& axis,
                         bool keep_dim,
                         SparseCooTensor* dx) {
  EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
  unsigned int n_dim = axis.size();

  const DenseTensor& x_indices = x.indices();
  const DenseTensor& dout_indices = dout.indices();
  const DenseTensor& dout_values = dout.values();
  // BUG FIX: this kernel is instantiated per indices dtype (IntT) via
  // PD_VISIT_BASE_INTEGRAL_TYPES in SumCooGradKernel, so indices must be
  // read as IntT. The previous data<int64_t>() call raised a dtype
  // mismatch whenever the indices tensor was int32.
  const auto* dout_indices_data = dout_indices.data<IntT>();
  const auto* dout_values_data = dout_values.data<T>();

  DenseTensor* dx_indices = dx->mutable_indices();
  DenseTensor* dx_values = dx->mutable_values();
  // dx shares x's sparsity pattern.
  *dx_indices = x_indices;

  const auto* dx_indices_data = dx_indices->data<IntT>();
  auto* dx_values_data = dx_values->data<T>();

  phi::funcs::SetConstant<Context, T> set_constant;
  // axis == None: the forward output is a scalar, so its single gradient
  // value is broadcast to every stored value of dx.
  if (n_dim == 0) {
    T value = dout_values_data[0];
    set_constant(dev_ctx, dx_values, value);
    if (dx_values->dtype() != dx->dtype()) {
      // A user-specified forward dtype makes dout's (and this buffer's)
      // dtype differ from dx's; cast back so dx is self-consistent.
      *dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
    }
    return;
  }

  // Normalize the (single) reduction axis to a non-negative dim.
  auto dim = axis[0] < 0 ? x.dims().size() + axis[0] : axis[0];
  auto sparse_dim = x.sparse_dim();
  if (dim >= sparse_dim) {
    // The reduction was over a dense dimension of the values tensor:
    // delegate to the dense reduce_sum gradient on values directly.
    dim = dim - sparse_dim + 1;
    phi::ReduceSumGradKernel<T, Context>(
        dev_ctx, x.values(), dout.values(), {dim}, keep_dim, false, dx_values);
    if (dx_values->dtype() != dx->dtype()) {
      *dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
    }
    return;
  }
  // Ensure the sparse_dim is not less than 1: the forward pass keeps the
  // reduced dim in that case, so the lookup key below must include it.
  if (sparse_dim == 1) {
    keep_dim = true;
  }

  // Number of dense elements stored per sparse coordinate
  // (product of the values tensor's trailing dims).
  int64_t dense_dim = 1;
  for (auto i = 1; i < x.values().dims().size(); ++i) {
    dense_dim *= x.values().dims()[i];
  }

  // Map each dout sparse coordinate -> its column index in dout.indices.
  std::map<std::vector<IntT>, int64_t> indices_map;
  for (auto j = 0; j < dout_indices.dims()[1]; ++j) {
    std::vector<IntT> pos;
    for (int i = 0; i < dout_indices.dims()[0]; ++i) {
      pos.push_back(dout_indices_data[j + i * dout_indices.dims()[1]]);
    }
    indices_map[pos] = j;
  }

  // For every coordinate of x: drop the reduced dim (or pin it to 0 when
  // keep_dim) to form the matching dout coordinate, then copy that
  // coordinate's gradient slice into dx.
  for (auto j = 0; j < dx_indices->dims()[1]; ++j) {
    std::vector<IntT> pos;
    for (int i = 0; i < dx_indices->dims()[0]; ++i) {
      if (i != dim) {
        pos.push_back(dx_indices_data[j + i * dx_indices->dims()[1]]);
      } else if (keep_dim) {
        pos.push_back(0);
      }
    }
    for (int i = 0; i < dense_dim; ++i) {
      dx_values_data[i + j * dense_dim] =
          dout_values_data[i + indices_map[pos] * dense_dim];
    }
  }
  if (dx_values->dtype() != dx->dtype()) {
    *dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
  }
}

template <typename T, typename Context>
void SumCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
const IntArray& axis,
bool keep_dim,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
unsigned int n_dim = axis.size();

const DenseTensor& x_crows = x.crows();
const DenseTensor& x_cols = x.cols();
const DenseTensor& dout_values = dout.values();
const auto* x_crows_data = x_crows.data<int64_t>();

DenseTensor* dx_crows = dx->mutable_crows();
DenseTensor* dx_cols = dx->mutable_cols();
DenseTensor* dx_values = dx->mutable_values();

*dx_crows = x_crows;
*dx_cols = x_cols;

phi::funcs::SetConstant<Context, T> set_constant;
if (n_dim == 0) {
T value = dout_values.data<T>()[0];
set_constant(dev_ctx, dx_values, value);
zkh2016 marked this conversation as resolved.
Show resolved Hide resolved
if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
}
return;
}
PADDLE_ENFORCE_EQ(axis[0],
-1,
phi::errors::Unimplemented(
"`axis` of SumCsrKernel only support None or -1 now."
"More number will be supported in the future."));

if (x.dims().size() == 2) {
int value_index = 0;
for (int k = 0; k < x.dims()[0]; ++k) {
if (x_crows_data[k] == x_crows_data[k + 1]) {
continue;
}
T value = dout_values.data<T>()[value_index];
set_constant(dev_ctx, dx_values, value);
value_index += 1;
}
} else {
int dout_value_index = 0;
int dx_value_index = 0;
for (auto batch = 0; batch < x.dims()[0]; ++batch) {
for (auto k = batch * (x.dims()[1] + 1);
k < batch * (x.dims()[1] + 1) + x.dims()[1];
++k) {
if (x_crows_data[k] == x_crows_data[k + 1]) {
continue;
}
T value = dout_values.data<T>()[dout_value_index];
for (auto i = x_crows_data[k]; i < x_crows_data[k + 1]; ++i) {
dx_values->data<T>()[dx_value_index] = value;
dx_value_index++;
}
dout_value_index++;
}
}
}

if (dx_values->dtype() != dx->dtype()) {
*dx_values = phi::Cast<T, Context>(dev_ctx, *dx_values, dx->dtype());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请教下,这里为什么要加一个判断呢?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是因为如果指定了输出 dtype,dxvalues会保持和输出dout的类型一样,但是和dx不同,调用sparse_to_dense就会报错
ValueError: (InvalidArgument) The type of data we are trying to retrieve (float64) does not match the type of data (float32) currently contained in the container.
[Hint: Expected dtype() == phi::CppTypeToDataType::Type(), but received dtype():10 != phi::CppTypeToDataType::Type():11.] (at /Paddle/paddle/phi/core/dense_tensor.cc:163)
这里通过这个把dxvalues转换成和dx相同的类型

}
}

/**
 * @brief Public entry point for the COO sum gradient on CPU.
 *
 * Dispatches on the runtime dtype of x's indices tensor: the
 * PD_VISIT_BASE_INTEGRAL_TYPES macro instantiates SumCooGradCPUKernel
 * with `data_t` bound to that integral type (e.g. int32/int64).
 */
template <typename T, typename Context>
void SumCooGradKernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      const SparseCooTensor& dout,
                      const IntArray& axis,
                      bool keep_dim,
                      SparseCooTensor* dx) {
  PD_VISIT_BASE_INTEGRAL_TYPES(
      x.indices().dtype(), "SumCooGradCPUKernel", ([&] {
        SumCooGradCPUKernel<T, data_t, Context>(
            dev_ctx, x, dout, axis, keep_dim, dx);
      }));
}

} // namespace sparse
} // namespace phi

// Register the COO sum-gradient kernel for CPU over the supported value
// dtypes (the trailing type list is the set of T instantiations).
PD_REGISTER_KERNEL(sum_coo_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::SumCooGradKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

// Register the CSR sum-gradient kernel for CPU over the same dtypes.
PD_REGISTER_KERNEL(sum_csr_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::SumCsrGradKernel,
                   float,
                   double,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}
Loading