diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml
index 2208d34ecf4cd..85b64e867eef8 100644
--- a/paddle/phi/api/yaml/sparse_ops.yaml
+++ b/paddle/phi/api/yaml/sparse_ops.yaml
@@ -155,6 +155,17 @@
     layout : x
   backward : expm1_grad
 
+- op : isnan
+  args : (Tensor x)
+  output : Tensor(out)
+  infer_meta :
+    func : IsfiniteInferMeta
+    param: [x]
+  kernel :
+    func : isnan_coo{sparse_coo -> sparse_coo},
+           isnan_csr{sparse_csr -> sparse_csr}
+    layout : x
+
 - op : leaky_relu
   args : (Tensor x, float alpha)
   output : Tensor(out)
diff --git a/paddle/phi/kernels/sparse/cpu/mask_kernel.cc b/paddle/phi/kernels/sparse/cpu/mask_kernel.cc
index 89c1fc1a9eb23..517c4c078e66a 100644
--- a/paddle/phi/kernels/sparse/cpu/mask_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/mask_kernel.cc
@@ -165,7 +165,8 @@ PD_REGISTER_KERNEL(mask_coo,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
index 99ccb878e4254..41c39cf387b06 100644
--- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
@@ -370,7 +370,8 @@ PD_REGISTER_KERNEL(coo_to_dense,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   bool) {}
 
 PD_REGISTER_KERNEL(csr_to_dense,
                    CPU,
@@ -383,7 +384,8 @@ PD_REGISTER_KERNEL(csr_to_dense,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   bool) {}
 
 PD_REGISTER_KERNEL(values_coo,
                    CPU,
@@ -396,7 +398,8 @@ PD_REGISTER_KERNEL(values_coo,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
@@ -426,7 +429,8 @@ PD_REGISTER_KERNEL(values_csr,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
 }
 
diff --git a/paddle/phi/kernels/sparse/cpu/unary_kernel.cc b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc
index 4bbb97936e6e4..d36439549cf1f 100644
--- a/paddle/phi/kernels/sparse/cpu/unary_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc
@@ -141,3 +141,27 @@ PD_REGISTER_KERNEL(cast_csr,
                    int,
                    int64_t,
                    bool) {}
+
+PD_REGISTER_KERNEL(isnan_coo,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::IsnanCooKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+PD_REGISTER_KERNEL(isnan_csr,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::IsnanCsrKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
diff --git a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu
index d04d96aa72cb8..bae969cf23eb7 100644
--- a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu
@@ -303,7 +303,8 @@ PD_REGISTER_KERNEL(mask_coo,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
index c72a38cd8fd32..2f86a643aa55b 100644
--- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
@@ -580,7 +580,8 @@ PD_REGISTER_KERNEL(coo_to_dense,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   bool) {}
 
 PD_REGISTER_KERNEL(csr_to_dense,
                    GPU,
@@ -593,7 +594,8 @@ PD_REGISTER_KERNEL(csr_to_dense,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   bool) {}
 
 PD_REGISTER_KERNEL(values_coo,
                    GPU,
@@ -606,7 +608,8 @@ PD_REGISTER_KERNEL(values_coo,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
@@ -621,7 +624,8 @@ PD_REGISTER_KERNEL(values_csr,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
 }
 
diff --git a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
index 98a7248b9845a..ba9c3dbf6e39f 100644
--- a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
@@ -148,3 +148,27 @@ PD_REGISTER_KERNEL(cast_csr,
                    int,
                    int64_t,
                    bool) {}
+
+PD_REGISTER_KERNEL(isnan_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::IsnanCooKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+PD_REGISTER_KERNEL(isnan_csr,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::IsnanCsrKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   int,
+                   int64_t) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
index 426e580262dbc..06d7a4333640f 100644
--- a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
+++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
@@ -22,6 +22,7 @@
 #include "paddle/phi/kernels/abs_kernel.h"
 #include "paddle/phi/kernels/activation_kernel.h"
 #include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/isfinite_kernel.h"
 #include "paddle/phi/kernels/scale_kernel.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
 #include "paddle/phi/kernels/trunc_kernel.h"
@@ -219,5 +220,44 @@ void CastCsrKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void IsnanCooKernel(const Context& dev_ctx,
+                    const SparseCooTensor& x,
+                    SparseCooTensor* out) {
+  *(out->mutable_indices()) = x.indices();
+  const DenseTensor& x_values = x.non_zero_elements();
+  DenseTensor* out_values = out->mutable_non_zero_elements();
+
+  phi::MetaTensor meta(out_values);
+  meta.set_dims(x_values.dims());
+  meta.set_dtype(DataType::BOOL);
+
+  phi::IsnanKernel<T, Context>(
+      dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements());
+  out->SetIndicesDict(x.GetIndicesDict());
+}
+
+template <typename T, typename Context>
+void IsnanCsrKernel(const Context& dev_ctx,
+                    const SparseCsrTensor& x,
+                    SparseCsrTensor* out) {
+  const DenseTensor& x_crows = x.crows();
+  const DenseTensor& x_cols = x.cols();
+  const DenseTensor& x_values = x.non_zero_elements();
+  DenseTensor* out_crows = out->mutable_crows();
+  DenseTensor* out_cols = out->mutable_cols();
+  DenseTensor* out_values = out->mutable_non_zero_elements();
+
+  *out_crows = x_crows;
+  *out_cols = x_cols;
+
+  phi::MetaTensor meta(out_values);
+  meta.set_dims(x_values.dims());
+  meta.set_dtype(DataType::BOOL);
+
+  phi::IsnanKernel<T, Context>(
+      dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements());
+}
+
 }  // namespace sparse
 }  // namespace phi
diff --git a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc
index b7b2986c5e34b..5c4aa7b6b56b6 100644
--- a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc
+++ b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc
@@ -48,7 +48,8 @@ PD_REGISTER_KERNEL(values_coo_grad,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
@@ -62,7 +63,8 @@ PD_REGISTER_KERNEL(coo_to_dense_grad,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
@@ -91,7 +93,8 @@ PD_REGISTER_KERNEL(values_coo_grad,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 PD_REGISTER_KERNEL(coo_to_dense_grad,
@@ -105,7 +108,8 @@ PD_REGISTER_KERNEL(coo_to_dense_grad,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {
+                   int64_t,
+                   bool) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 PD_REGISTER_KERNEL(sparse_coo_tensor_grad,
diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h
index 90c52504ecb3e..b219ec07236df 100644
--- a/paddle/phi/kernels/sparse/unary_kernel.h
+++ b/paddle/phi/kernels/sparse/unary_kernel.h
@@ -52,6 +52,7 @@ DECLARE_SPARSE_UNARY_KERNEL(Sinh)
 DECLARE_SPARSE_UNARY_KERNEL(Asinh)
 DECLARE_SPARSE_UNARY_KERNEL(Atanh)
 DECLARE_SPARSE_UNARY_KERNEL(Relu)
+DECLARE_SPARSE_UNARY_KERNEL(Isnan)
 DECLARE_SPARSE_UNARY_KERNEL(Tanh)
 DECLARE_SPARSE_UNARY_KERNEL(Square)
 DECLARE_SPARSE_UNARY_KERNEL(Sqrt)
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_isnan_op.py b/python/paddle/fluid/tests/unittests/test_sparse_isnan_op.py
new file mode 100644
index 0000000000000..b807e6ba62445
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_sparse_isnan_op.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+
+
+class TestSparseIsnan(unittest.TestCase):
+    """
+    Test the API paddle.sparse.isnan on some sparse tensors.
+    x: sparse tensor, out: sparse tensor
+    """
+
+    def to_sparse(self, x, format):
+        if format == 'coo':
+            return x.detach().to_sparse_coo(sparse_dim=x.ndim)
+        elif format == 'csr':
+            return x.detach().to_sparse_csr()
+
+    def check_result(self, x_shape, format, data_type="float32"):
+        raw_inp = np.random.randint(-100, 100, x_shape)
+        mask = np.random.randint(0, 2, x_shape)
+        inp_x = (raw_inp * mask).astype(data_type)
+        inp_x[inp_x > 0] = np.nan
+        np_out = np.isnan(inp_x[inp_x != 0])
+
+        dense_x = paddle.to_tensor(inp_x)
+        sp_x = self.to_sparse(dense_x, format)
+        sp_out = paddle.sparse.isnan(sp_x)
+        sp_out_values = sp_out.values().numpy()
+
+        np.testing.assert_allclose(np_out, sp_out_values, rtol=1e-05)
+
+    def test_isnan_shape(self):
+        self.check_result([20], 'coo')
+
+        self.check_result([4, 5], 'coo')
+        self.check_result([4, 5], 'csr')
+
+        self.check_result([8, 16, 32], 'coo')
+        self.check_result([8, 16, 32], 'csr')
+
+    def test_isnan_dtype(self):
+        self.check_result([4, 5], 'coo', "float32")
+        self.check_result([4, 5], 'csr', "float32")
+
+        self.check_result([8, 16, 32], 'coo', "float64")
+        self.check_result([8, 16, 32], 'csr', "float64")
+
+
+class TestStatic(unittest.TestCase):
+    def test(self):
+        paddle.enable_static()
+
+        indices = paddle.static.data(
+            name='indices', shape=[2, 3], dtype='int32'
+        )
+        values = paddle.static.data(name='values', shape=[3], dtype='float32')
+
+        dense_shape = [3, 3]
+        sp_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
+        sp_y = paddle.sparse.isnan(sp_x)
+        out = sp_y.to_dense()
+
+        exe = paddle.static.Executor()
+        indices_data = [[0, 1, 2], [1, 2, 0]]
+        values_data = np.array([1.0, float("nan"), 3.0]).astype('float32')
+
+        fetch = exe.run(
+            feed={'indices': indices_data, 'values': values_data},
+            fetch_list=[out],
+            return_numpy=True,
+        )
+
+        correct_out = np.array(
+            [[False, False, False], [False, False, True], [False, False, False]]
+        ).astype('float32')
+        np.testing.assert_allclose(correct_out, fetch[0], rtol=1e-5)
+        paddle.disable_static()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/sparse/__init__.py b/python/paddle/sparse/__init__.py
index 9ca932ac46b6a..e92a5936c4cfc 100644
--- a/python/paddle/sparse/__init__.py
+++ b/python/paddle/sparse/__init__.py
@@ -36,6 +36,7 @@
 from .unary import expm1
 from .unary import transpose
 from .unary import reshape
+from .unary import isnan
 
 from .binary import mv
 from .binary import matmul
@@ -83,4 +84,5 @@
     'coalesce',
     'is_same_shape',
     'reshape',
+    'isnan',
 ]
diff --git a/python/paddle/sparse/unary.py b/python/paddle/sparse/unary.py
index 23a1aa1c030f0..da1d0b549aa41 100644
--- a/python/paddle/sparse/unary.py
+++ b/python/paddle/sparse/unary.py
@@ -14,12 +14,13 @@
 
 import numpy as np
 
-from paddle import _C_ops
+from paddle import _C_ops, in_dynamic_mode
 from paddle.fluid.framework import (
     convert_np_dtype_to_dtype_,
     core,
     dygraph_only,
 )
+from paddle.fluid.layer_helper import LayerHelper
 
 __all__ = []
 
@@ -700,3 +701,51 @@ def reshape(x, shape, name=None):
 
     """
     return _C_ops.sparse_reshape(x, shape)
+
+
+def isnan(x, name=None):
+    """
+
+    Return whether every element of the input tensor is `NaN` or not. The input ``x`` must be a SparseCooTensor or SparseCsrTensor.
+
+    Args:
+        x (Tensor): The input tensor (SparseCooTensor or SparseCsrTensor), its data type should be float16, float32, float64, int32, int64.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same shape as ``x``, holding bool values that indicate whether each element of ``x`` is `NaN` or not.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+
+            format = "coo"
+            np_x = np.asarray([[[0., 0], [1., 2.]], [[0., 0], [3., float('nan')]]])
+            dense_x = paddle.to_tensor(np_x)
+
+            if format == "coo":
+                sparse_x = dense_x.to_sparse_coo(len(np_x.shape))
+            else:
+                sparse_x = dense_x.to_sparse_csr()
+
+            sparse_out = paddle.sparse.isnan(sparse_x)
+            print(sparse_out)
+            # Tensor(shape=[2, 2, 2], dtype=paddle.bool, place=Place(gpu:0), stop_gradient=True,
+            #        indices=[[0, 0, 1, 1],
+            #                 [1, 1, 1, 1],
+            #                 [0, 1, 0, 1]],
+            #        values=[False, False, False, True ])
+
+    """
+    if in_dynamic_mode():
+        return _C_ops.sparse_isnan(x)
+    else:
+        op_type = 'sparse_isnan'
+        helper = LayerHelper(op_type)
+        out = helper.create_sparse_variable_for_type_inference(x.dtype)
+        helper.append_op(
+            type=op_type, inputs={'x': x}, outputs={'out': out}, attrs={}
+        )
+        return out
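
Not part of the patch: a minimal NumPy sketch of the semantics the IsnanCooKernel/IsnanCsrKernel additions implement, useful for sanity-checking the unit test above without building Paddle. The helper isnan_coo_reference is hypothetical (it is not a Paddle API); the assumption, matching the kernel bodies in unary_kernel_impl.h, is that the sparsity pattern is reused unchanged and isnan is applied only to the stored values, so unstored zeros read as False in the dense view.

# sketch_sparse_isnan_reference.py -- illustrative only, not part of the diff
import numpy as np


def isnan_coo_reference(indices, values):
    # The indices buffer is reused as-is; only the values buffer changes,
    # becoming a bool array of the same nnz length.
    return indices, np.isnan(values)


# Mirrors the static-graph unit test: a NaN stored at dense position (1, 2).
indices = np.array([[0, 1, 2], [1, 2, 0]])
values = np.array([1.0, float("nan"), 3.0], dtype=np.float32)
out_indices, out_values = isnan_coo_reference(indices, values)
print(out_values)  # [False  True False]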