-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Hackathon 3rd No.22 ] add paddle.incubate.sparse.reshape (#46694)
* add sparse reshape * change the dtype in all test cases to int64 * just one test case * modify comments * Update test_sparse_reshape_op.py * chang the type of "shape" from vector<int64_t> to IntArray * check whether sp_out.to_dense() is the cause of error * print sp_out * Update reshape_kernel.cc * use numpy to generate the equal paddle tensor * just check dense_tensor.numpy() * check cpu and cuda versions * Update test_sparse_reshape_op.py * supply all test cases for cpu forward coo kernel * test forward coo cuda kernel * change configuration of cuda kernel * keep only one test case * test coo cpu kernel (forward and backward) * row major or column major ??? * test cuda coo forward kernel * complete declaration and registration * Update __init__.py * rebuild * retrigger CI * add cudaMalloc and cudaMemcpy in ReshapeCooKernel and change back to row major order in a cuda dense tensor * midify minor error * test only cpu coo forward kernel * add all test cases for coo forward kernel (both cpu and gpu) * test all forward kernels (coo, csr; cpu, gpu) * add all test cases for all kinds of kernels * just retrigger CI * Update sparse_ops.yaml * Update sparse_ops.yaml * Update sparse_ops.yaml * resolve conflicts * Update sparse_ops.yaml * don't specify tensor place * new shape has -1 or 0 in it * Update unary_grad_kernel.h * correct lvalue error * code style * Update sparse_backward.yaml * Update sparse_ops.yaml * Update unary_kernel.h * Update unary.py * Update sparse_backward.yaml * Update unary.py * code style * code style * code style * Update unary.py * specify tensor place explicitly * do not use numpy array * use numpy array in unit test again * modify example code in docstring
- Loading branch information
1 parent
6430790
commit abb3813
Showing
13 changed files
with
713 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" | ||
#include "paddle/phi/kernels/sparse/unary_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/sparse/empty_kernel.h" | ||
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" | ||
|
||
namespace phi { | ||
namespace sparse { | ||
|
||
template <typename T, typename Context> | ||
void ReshapeCooGradKernel(const Context& dev_ctx, | ||
const SparseCooTensor& x, | ||
const SparseCooTensor& dout, | ||
SparseCooTensor* dx) { | ||
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx); | ||
phi::IntArray x_shape(phi::vectorize(x.dims())); | ||
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx); | ||
} | ||
|
||
template <typename T, typename Context> | ||
void ReshapeCsrGradKernel(const Context& dev_ctx, | ||
const SparseCsrTensor& x, | ||
const SparseCsrTensor& dout, | ||
SparseCsrTensor* dx) { | ||
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx); | ||
phi::IntArray x_shape(phi::vectorize(x.dims())); | ||
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx); | ||
} | ||
|
||
} // namespace sparse | ||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(reshape_coo_grad, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCooGradKernel, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} | ||
|
||
PD_REGISTER_KERNEL(reshape_csr_grad, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCsrGradKernel, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/sparse/unary_kernel.h" | ||
|
||
#include "paddle/phi/core/ddim.h" | ||
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/eigen/common.h" | ||
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" | ||
#include "paddle/phi/kernels/sparse/empty_kernel.h" | ||
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" | ||
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h" | ||
|
||
namespace phi { | ||
namespace sparse { | ||
|
||
template <typename T, typename Context> | ||
void ReshapeCooKernel(const Context& dev_ctx, | ||
const SparseCooTensor& x, | ||
const phi::IntArray& shape, | ||
SparseCooTensor* out) { | ||
// TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims | ||
int64_t x_nnz = x.nnz(); | ||
|
||
// Use DDim::reshape to handle -1 and 0 in the argument "shape" | ||
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end()); | ||
phi::DDim out_dims = x.dims().reshape(new_shape); | ||
// get sparse part dimensions of x and out | ||
std::vector<int64_t> x_sparse_part_dims; | ||
std::vector<int64_t> out_sparse_part_dims; | ||
for (int i = 0; i < x.sparse_dim(); ++i) { | ||
x_sparse_part_dims.push_back(x.dims()[i]); | ||
} | ||
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) { | ||
out_sparse_part_dims.push_back(out_dims[i]); | ||
} | ||
DenseTensor out_indices = Empty<int64_t, Context>( | ||
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz}); | ||
DenseTensor out_values(x.values()); | ||
out->SetMember(out_indices, out_values, out_dims, x.coalesced()); | ||
|
||
// compute values of indices | ||
const DenseTensor& x_indices = x.indices(); | ||
const auto* x_indices_data = x_indices.data<int64_t>(); | ||
auto* out_indices_data = out_indices.data<int64_t>(); | ||
|
||
const phi::DDim& x_sparse_part_strides = | ||
phi::stride(phi::make_ddim(x_sparse_part_dims)); | ||
const phi::DDim& out_sparse_part_strides = | ||
phi::stride(phi::make_ddim(out_sparse_part_dims)); | ||
int64_t location = 0; | ||
for (int64_t j = 0; j < x_nnz; ++j) { | ||
location = 0; | ||
for (int i = 0; i < x.sparse_dim(); ++i) { | ||
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i]; | ||
} | ||
for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) { | ||
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i]; | ||
location %= out_sparse_part_strides[i]; | ||
} | ||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void ReshapeCsrKernel(const Context& dev_ctx, | ||
const SparseCsrTensor& x, | ||
const phi::IntArray& shape, | ||
SparseCsrTensor* out) { | ||
// transform csr format to coo format, and then use coo kernel | ||
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x); | ||
SparseCooTensor out_coo; | ||
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo); | ||
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out); | ||
} | ||
|
||
} // namespace sparse | ||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(reshape_coo, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCooKernel, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} | ||
|
||
PD_REGISTER_KERNEL(reshape_csr, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCsrKernel, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h" | ||
#include "paddle/phi/kernels/sparse/unary_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/sparse/empty_kernel.h" | ||
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" | ||
|
||
namespace phi { | ||
namespace sparse { | ||
|
||
// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc | ||
template <typename T, typename Context> | ||
void ReshapeCooGradKernel(const Context& dev_ctx, | ||
const SparseCooTensor& x, | ||
const SparseCooTensor& dout, | ||
SparseCooTensor* dx) { | ||
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx); | ||
phi::IntArray x_shape(phi::vectorize(x.dims())); | ||
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx); | ||
} | ||
|
||
// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc | ||
template <typename T, typename Context> | ||
void ReshapeCsrGradKernel(const Context& dev_ctx, | ||
const SparseCsrTensor& x, | ||
const SparseCsrTensor& dout, | ||
SparseCsrTensor* dx) { | ||
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx); | ||
phi::IntArray x_shape(phi::vectorize(x.dims())); | ||
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx); | ||
} | ||
|
||
} // namespace sparse | ||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(reshape_coo_grad, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCooGradKernel, | ||
phi::dtype::float16, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} | ||
|
||
PD_REGISTER_KERNEL(reshape_csr_grad, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::sparse::ReshapeCsrGradKernel, | ||
phi::dtype::float16, | ||
float, | ||
double, | ||
int8_t, | ||
uint8_t, | ||
int16_t, | ||
int, | ||
int64_t, | ||
bool) {} |
Oops, something went wrong.