Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick 2.4] add sparse api transpose/reshape/is_same_shape #47076

Merged
merged 5 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,15 @@ static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor_method_is_same_shape(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
return ToPyObject(self->tensor.shape() == other.shape());
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__inplace_version(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
Expand Down Expand Up @@ -1983,6 +1992,10 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"is_same_shape",
(PyCFunction)(void (*)(void))tensor_method_is_same_shape,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"to_sparse_csr",
(PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
METH_VARARGS | METH_KEYWORDS,
Expand Down
22 changes: 22 additions & 0 deletions paddle/phi/api/yaml/sparse_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,17 @@
func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_op : reshape_grad
forward : reshape(Tensor x, IntArray shape) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : reshape_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
reshape_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_op : scale_grad
forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, float scale)
Expand Down Expand Up @@ -385,6 +396,17 @@
kernel :
func : coo_to_dense { sparse_coo -> dense }

- backward_op : transpose_grad
forward : transpose(Tensor x, int[] perm) -> Tensor(out)
args : (Tensor out_grad, int[] perm)
output : Tensor(x_grad)
infer_meta :
func : TransposeGradInferMeta
param : [out_grad, perm]
kernel :
func : transpose_coo_grad {sparse_coo -> sparse_coo},
transpose_csr_grad {sparse_csr -> sparse_csr}

- backward_op : values_grad
forward : values_coo(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
Expand Down
23 changes: 23 additions & 0 deletions paddle/phi/api/yaml/sparse_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,26 @@
mv_csr{sparse_csr, dense -> dense}
layout : x
backward: mv_grad

- op : transpose
args : (Tensor x, int[] perm)
output : Tensor(out)
infer_meta :
func : TransposeInferMeta
param: [ x, perm ]
kernel :
func : transpose_coo{sparse_coo -> sparse_coo},
transpose_csr{sparse_csr -> sparse_csr}
layout : x
backward : transpose_grad

- op : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out)
infer_meta :
func : ReshapeInferMeta
kernel :
func : reshape_coo{sparse_coo -> sparse_coo},
reshape_csr{sparse_csr -> sparse_csr}
layout : x
backward : reshape_grad
2 changes: 1 addition & 1 deletion paddle/phi/core/sparse_coo_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class SparseCooTensor : public TensorBase,
[0, 0, 0, 0]]
dims_ = (4, 4)
non_zero_elements_ = [[0, 1, 0, 0], [0, 0, 4, 0]]
non_zero_indices_ = [0, 2],
non_zero_indices_ = [[0, 2], [1, 2]]
*/
};

Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/core/sparse_csr_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 6]]
dims_ = (4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6]
non_zero_elements_ = [1, 2, 3, 4, 5, 6]
non_zero_crows_ = [0, 1, 3, 4, 6]
non_zero_cols_ = [1, 0, 3, 2, 1, 3]
*/
Expand All @@ -228,7 +228,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 0]]]
dims_ = (2, 4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6, 1, 2, 3, 4, 5]
non_zero_elements_ = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5]
non_zero_crows_ = [0, 1, 3, 4, 6, 0, 1, 2, 4, 5]
non_zero_cols_ = [1, 0, 3, 2, 1, 3, 1, 0, 3, 2, 1]
*/
Expand Down
73 changes: 73 additions & 0 deletions paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}

template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(reshape_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(reshape_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
117 changes: 117 additions & 0 deletions paddle/phi/kernels/sparse/cpu/reshape_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
// TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims
int64_t x_nnz = x.nnz();

// Use DDim::reshape to handle -1 and 0 in the argument "shape"
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
// get sparse part dimensions of x and out
std::vector<int64_t> x_sparse_part_dims;
std::vector<int64_t> out_sparse_part_dims;
for (int i = 0; i < x.sparse_dim(); ++i) {
x_sparse_part_dims.push_back(x.dims()[i]);
}
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());

// compute values of indices
const DenseTensor& x_indices = x.indices();
const auto* x_indices_data = x_indices.data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();

const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
phi::stride(phi::make_ddim(out_sparse_part_dims));
int64_t location = 0;
for (int64_t j = 0; j < x_nnz; ++j) {
location = 0;
for (int i = 0; i < x.sparse_dim(); ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
}
for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
}
}
}

template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out) {
// transform csr format to coo format, and then use coo kernel
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
SparseCooTensor out_coo;
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(reshape_coo,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(reshape_csr,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
6 changes: 4 additions & 2 deletions paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}

PD_REGISTER_KERNEL(coo_to_csr,
CPU,
Expand All @@ -342,7 +343,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}

PD_REGISTER_KERNEL(dense_to_csr,
CPU,
Expand Down
Loading