【Hackathon 4th No.24】Add the paddle.sparse.is_nan sparse API to Paddle #51513
Changed file: sparse unary kernel implementation (C++)

@@ -22,6 +22,7 @@
#include "paddle/phi/kernels/abs_kernel.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/isfinite_kernel.h"
#include "paddle/phi/kernels/scale_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/trunc_kernel.h"

@@ -219,5 +220,44 @@ void CastCsrKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void IsnanCooKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    SparseCooTensor* out) {
  // Reuse the input's sparsity pattern; only the values change.
  *(out->mutable_indices()) = x.indices();
  const DenseTensor& x_values = x.non_zero_elements();
  DenseTensor* out_values = out->mutable_non_zero_elements();

  // The output values keep the input's shape but are bool, not T.
  phi::MetaTensor meta(out_values);
  meta.set_dims(x_values.dims());
  meta.set_dtype(DataType::BOOL);

  phi::IsnanKernel<T, Context>(
      dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements());
  out->SetIndicesDict(x.GetIndicesDict());
}

template <typename T, typename Context>
void IsnanCsrKernel(const Context& dev_ctx,
                    const SparseCsrTensor& x,
                    SparseCsrTensor* out) {
  const DenseTensor& x_crows = x.crows();
  const DenseTensor& x_cols = x.cols();
  const DenseTensor& x_values = x.non_zero_elements();
  DenseTensor* out_crows = out->mutable_crows();
  DenseTensor* out_cols = out->mutable_cols();
  DenseTensor* out_values = out->mutable_non_zero_elements();

  // Copy the CSR structure; the values become a bool mask of NaNs.
  *out_crows = x_crows;
  *out_cols = x_cols;

  phi::MetaTensor meta(out_values);
  meta.set_dims(x_values.dims());
  meta.set_dtype(DataType::BOOL);

  phi::IsnanKernel<T, Context>(
      dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements());
}

}  // namespace sparse
}  // namespace phi

Review thread on IsnanCooKernel:

Reviewer: Register this kernel with the existing common components and macro helpers; reusing them would cut down the amount of code.

Author: @zhouwei25 The common components cannot handle this operator: EmptyLike creates an output tensor with the same dtype as the input, but this operator's output is bool, so a dedicated kernel is written here.

Reviewer: OK
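For context on the dtype point in this thread, a minimal dygraph sketch (not part of the PR; it assumes the paddle.sparse.isnan name used in the unit test below): the input's non-zero values are float32 while the output's values are bool, which is why an EmptyLike-style output of the same dtype would not fit.

import numpy as np
import paddle

# Small dense matrix with one NaN among its non-zero entries.
dense = paddle.to_tensor(
    np.array([[1.0, 0.0], [float("nan"), 2.0]], dtype="float32")
)
sp_x = dense.to_sparse_coo(sparse_dim=2)

sp_out = paddle.sparse.isnan(sp_x)

# The sparsity pattern is carried over unchanged, but the values tensor
# switches from float32 to bool, matching meta.set_dtype(DataType::BOOL)
# in the kernel above.
print(sp_x.values().dtype)      # float32
print(sp_out.values().dtype)    # bool
print(sp_out.values().numpy())  # [False  True False] for the three non-zeros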
Changed file: unit test for paddle.sparse.isnan (Python, new file)

@@ -0,0 +1,97 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle


class TestSparseIsnan(unittest.TestCase):
    """
    Test the API paddle.sparse.isnan on some sparse tensors.
    x: sparse tensor, out: sparse tensor
    """

    def to_sparse(self, x, format):
        if format == 'coo':
            return x.detach().to_sparse_coo(sparse_dim=x.ndim)
        elif format == 'csr':
            return x.detach().to_sparse_csr()

    def check_result(self, x_shape, format, data_type="float32"):
        raw_inp = np.random.randint(-100, 100, x_shape)
        mask = np.random.randint(0, 2, x_shape)
        inp_x = (raw_inp * mask).astype(data_type)
        inp_x[inp_x > 0] = np.nan
        np_out = np.isnan(inp_x[inp_x != 0])

        dense_x = paddle.to_tensor(inp_x)
        sp_x = self.to_sparse(dense_x, format)
        sp_out = paddle.sparse.isnan(sp_x)
        sp_out_values = sp_out.values().numpy()

        np.testing.assert_allclose(np_out, sp_out_values, rtol=1e-05)

    def test_isnan_shape(self):
        self.check_result([20], 'coo')

        self.check_result([4, 5], 'coo')
        self.check_result([4, 5], 'csr')

        self.check_result([8, 16, 32], 'coo')
        self.check_result([8, 16, 32], 'csr')

    def test_isnan_dtype(self):
        self.check_result([4, 5], 'coo', "float32")
        self.check_result([4, 5], 'csr', "float32")

        self.check_result([8, 16, 32], 'coo', "float64")
        self.check_result([8, 16, 32], 'csr', "float64")


class TestStatic(unittest.TestCase):
    def test(self):
        paddle.enable_static()

        indices = paddle.static.data(
            name='indices', shape=[2, 3], dtype='int32'
        )
        values = paddle.static.data(name='values', shape=[3], dtype='float32')

        dense_shape = [3, 3]
        sp_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        sp_y = paddle.sparse.isnan(sp_x)
        out = sp_y.to_dense()

        exe = paddle.static.Executor()
        indices_data = [[0, 1, 2], [1, 2, 0]]
        values_data = np.array([1.0, float("nan"), 3.0]).astype('float32')

        fetch = exe.run(
            feed={'indices': indices_data, 'values': values_data},
            fetch_list=[out],
            return_numpy=True,
        )

        correct_out = np.array(
            [[False, False, False], [False, False, True], [False, False, False]]
        ).astype('float32')
        np.testing.assert_allclose(correct_out, fetch[0], rtol=1e-5)
        paddle.disable_static()


if __name__ == "__main__":
    unittest.main()
Review thread on the InferMeta registration:

Reviewer: This name doesn't look quite right?

Author: Do you mean IsfiniteInferMeta? The latest commit changed it to unchanged.

Author: @zhouwei25 Changing it to unchanged still fails, so the latest commit switched it back to IsfiniteInferMeta:

InvalidArgumentError: The type of data we are trying to retrieve (float32) does not match the type of data (bool) currently contained in the container.

Reviewer: It's just that the naming style is a bit off; shouldn't this be IsNanInferMeta?

Author: @zhouwei25 IsfiniteInferMeta follows the IsfiniteInferMeta designed on the dense tensor side; to avoid redundant code it is reused directly here. Are you suggesting that a separate IsNanInferMeta should be written for this?

PD_REGISTER_INFER_META_FN(isnan, phi::IsfiniteInferMeta);

Reviewer: OK, reusing it is fine too.
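As a Python-level illustration of the error quoted in this thread (a sketch reusing the TestStatic setup above, not code from the PR): the registered InferMeta decides the dtype of the static-graph output variable. Inferring it as bool matches what the kernel writes, whereas an unchanged InferMeta would leave it float32 and clash with the bool data the kernel produces, which is the InvalidArgumentError above.

import numpy as np
import paddle

paddle.enable_static()

indices = paddle.static.data(name='indices', shape=[2, 3], dtype='int32')
values = paddle.static.data(name='values', shape=[3], dtype='float32')
sp_x = paddle.sparse.sparse_coo_tensor(indices, values, [3, 3])

# The output's dtype comes from the InferMeta registered for isnan;
# with IsfiniteInferMeta it is inferred as bool.
sp_y = paddle.sparse.isnan(sp_x)
out = sp_y.to_dense()

exe = paddle.static.Executor()
fetch = exe.run(
    feed={
        'indices': [[0, 1, 2], [1, 2, 0]],
        'values': np.array([1.0, float("nan"), 3.0], dtype='float32'),
    },
    fetch_list=[out],
    return_numpy=True,
)
# Expected: a mask that is True only at the NaN position (1, 2), fetched
# without the float32-vs-bool container error.
print(fetch[0])
paddle.disable_static()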