Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse.coalesce #44256

Merged
merged 11 commits into from
Jul 13, 2022
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/sparse_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,13 @@
layout : x
backward : values_grad

- api: coalesce
args : (Tensor x)
output : Tensor(out)
kernel :
func: coalesce{sparse_coo -> sparse_coo}
layout : x

- api: full_like
args : (Tensor x, Scalar value, DataType dtype=DataType::UNDEFINED)
output : Tensor(out)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,16 @@ namespace phi {
namespace sparse {

template <typename T, typename Context>
void CoalescedKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out);
void CoalesceKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out);

/// Value-returning convenience wrapper around CoalesceKernel.
///
/// Runs CoalesceKernel<T, Context> on `x` with the given device context and
/// hands the result back by value, so callers do not have to pre-declare an
/// output tensor themselves.
///
/// @param dev_ctx  Device context forwarded unchanged to CoalesceKernel.
/// @param x        Input sparse COO tensor; it is only read here.
/// @return         The SparseCooTensor filled in by CoalesceKernel.
template <typename T, typename Context>
SparseCooTensor Coalesce(const Context& dev_ctx, const SparseCooTensor& x) {
  // The kernel writes through an out-pointer, so a default-constructed tensor
  // is created first and passed in to receive the result.
  SparseCooTensor merged;
  CoalesceKernel<T, Context>(dev_ctx, x, &merged);
  return merged;
}

} // namespace sparse
} // namespace phi
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/coalesced_kernel.h"
#include "paddle/phi/kernels/sparse/coalesce_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
Expand All @@ -22,9 +22,9 @@ namespace phi {
namespace sparse {

template <typename T, typename IntT>
void CoalescedCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
void CoalesceCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
const DenseTensor& x_indices = x.non_zero_indices();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor out_indices = phi::EmptyLike<IntT>(dev_ctx, x_indices);
Expand Down Expand Up @@ -95,22 +95,22 @@ void CoalescedCPUKernel(const CPUContext& dev_ctx,
}

template <typename T, typename Context>
void CoalescedKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
void CoalesceKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "CoalescedCPUKernel", ([&] {
CoalescedCPUKernel<T, data_t>(dev_ctx, x, out);
x.non_zero_indices().dtype(), "CoalesceCPUKernel", ([&] {
CoalesceCPUKernel<T, data_t>(dev_ctx, x, out);
}));
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(sort,
PD_REGISTER_KERNEL(coalesce,
CPU,
ALL_LAYOUT,
phi::sparse::CoalescedKernel,
phi::sparse::CoalesceKernel,
float,
double,
phi::dtype::float16,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/coalesced_kernel.h"
#include "paddle/phi/kernels/sparse/coalesce_kernel.h"

#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
Expand All @@ -27,9 +27,9 @@ namespace phi {
namespace sparse {

template <typename T, typename IntT>
void CoalescedGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
void CoalesceGPUKernel(const GPUContext& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
const DenseTensor& x_indices = x.non_zero_indices();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor out_indices = phi::EmptyLike<IntT>(dev_ctx, x_indices);
Expand All @@ -55,11 +55,7 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
sparse_offsets.data(),
sizeof(IntT) * sparse_dim,
#ifdef PADDLE_WITH_HIP
hipMemcpyHostToDevice,
#else
cudaMemcpyHostToDevice,
#endif
gpuMemcpyHostToDevice,
dev_ctx.stream());

// 1. flatten indices
Expand Down Expand Up @@ -117,11 +113,7 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
phi::backends::gpu::GpuMemcpyAsync(&out_nnz,
out_indices.data<IntT>(),
sizeof(IntT),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost,
#else
cudaMemcpyDeviceToHost,
#endif
gpuMemcpyDeviceToHost,
dev_ctx.stream());
dev_ctx.Wait();

Expand Down Expand Up @@ -161,22 +153,21 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
}

template <typename T, typename Context>
void CoalescedKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
void CoalesceKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "CoalescedGPUKernel", ([&] {
CoalescedGPUKernel<T, data_t>(dev_ctx, x, out);
x.non_zero_indices().dtype(), "CoalesceGPUKernel", ([&] {
CoalesceGPUKernel<T, data_t>(dev_ctx, x, out);
}));
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(sort,
PD_REGISTER_KERNEL(coalesce,
GPU,
ALL_LAYOUT,
phi::sparse::CoalescedKernel,
phi::sparse::CoalesceKernel,
float,
double,
phi::dtype::float16,
Expand Down
6 changes: 2 additions & 4 deletions paddle/phi/kernels/sparse/sparse_utils_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ limitations under the License. */
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/coalesced_kernel.h"

namespace phi {
namespace sparse {
Expand Down Expand Up @@ -154,9 +153,8 @@ void SparseCooTensorKernel(const Context& dev_ctx,
const DenseTensor& indices,
const IntArray& dense_shape,
SparseCooTensor* out) {
SparseCooTensor before_coalesced(
indices, values, phi::make_ddim(dense_shape.GetData()));
CoalescedKernel<T, Context>(dev_ctx, before_coalesced, out);
*out =
SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData()));
}

} // namespace sparse
Expand Down
7 changes: 5 additions & 2 deletions paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/sparse/coalesce_kernel.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/sparse/convolution_kernel.h"

Expand Down Expand Up @@ -207,6 +208,8 @@ void TestConv3dBase(const std::vector<IntT>& indices,
subm,
&d_rulebook);

SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);

ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz());
for (int i = 0; i < correct_out_dims.size(); i++) {
Expand All @@ -217,7 +220,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
dev_ctx_cpu,
DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW));
phi::Copy(dev_ctx_gpu,
d_out.non_zero_indices(),
tmp_d_out.non_zero_indices(),
phi::CPUPlace(),
true,
&h_indices_tensor);
Expand All @@ -231,7 +234,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
phi::EmptyLike<T>(dev_ctx_cpu, d_out.non_zero_elements());

phi::Copy(dev_ctx_gpu,
d_out.non_zero_elements(),
tmp_d_out.non_zero_elements(),
phi::CPUPlace(),
true,
&h_features_tensor);
Expand Down
6 changes: 4 additions & 2 deletions paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/sparse/coalesce_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"

Expand Down Expand Up @@ -157,6 +158,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
dilations,
strides,
&d_rulebook);
SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);

ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz());
Expand All @@ -168,7 +170,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
dev_ctx_cpu,
DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW));
phi::Copy(dev_ctx_gpu,
d_out.non_zero_indices(),
tmp_d_out.non_zero_indices(),
phi::CPUPlace(),
true,
&h_indices_tensor);
Expand All @@ -182,7 +184,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
phi::EmptyLike<T>(dev_ctx_cpu, d_out.non_zero_elements());

phi::Copy(dev_ctx_gpu,
d_out.non_zero_elements(),
tmp_d_out.non_zero_elements(),
phi::CPUPlace(),
true,
&h_features_tensor);
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/test_sparse_conv_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_conv3d(self):
groups=1,
data_format="NDHWC")
out.backward(out)
out = paddle.incubate.sparse.coalesce(out)
assert np.array_equal(correct_out_values, out.values().numpy())

def test_subm_conv3d(self):
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def test_sparse_coo_tensor_sorted(self):
values = paddle.to_tensor(values, dtype='float32')
sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
indices, values)
sparse_x = paddle.incubate.sparse.coalesce(sparse_x)
indices_sorted = [[0, 1], [1, 0]]
values_sorted = [5.0, 1.0]
assert np.array_equal(indices_sorted,
Expand All @@ -310,6 +311,7 @@ def test_sparse_coo_tensor_sorted(self):
values = paddle.to_tensor(values, dtype='float32')
sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
indices, values)
sparse_x = paddle.incubate.sparse.coalesce(sparse_x)
values_sorted = [[5.0, 5.0], [1.0, 1.0]]
assert np.array_equal(indices_sorted,
sparse_x.indices().numpy())
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/incubate/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .unary import pow
from .unary import cast
from .unary import neg
from .unary import coalesce

from .binary import mv
from .binary import matmul
Expand Down Expand Up @@ -66,4 +67,5 @@
'subtract',
'multiply',
'divide',
'coalesce',
]
31 changes: 31 additions & 0 deletions python/paddle/incubate/sparse/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,34 @@ def abs(x, name=None):

"""
return _C_ops.final_state_sparse_abs(x)


@dygraph_only
def coalesce(x):
r"""
The coalesce operator sorts the indices and merges duplicate entries; after coalescing, the indices of x are sorted and unique.

Parameters:
x (Tensor): the input SparseCooTensor.

Returns:
Tensor: the coalesced SparseCooTensor.

Examples:
.. code-block:: python

import paddle
from paddle.incubate import sparse
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
indices = [[0, 0, 1], [1, 1, 2]]
values = [1.0, 2.0, 3.0]
sp_x = sparse.sparse_coo_tensor(indices, values)
sp_x = sparse.coalesce(sp_x)
print(sp_x.indices())
#[[0, 1], [1, 2]]
print(sp_x.values())
zkh2016 marked this conversation as resolved.
Show resolved Hide resolved
#[3.0, 3.0]
"""
return _C_ops.final_state_sparse_coalesce(x)