Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 4 NO.23】为 Paddle 新增 vander API #51048

Merged
merged 24 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6bf60e0
first commit of vander api
Li-fAngyU Mar 1, 2023
040457f
fix file name
Li-fAngyU Mar 1, 2023
fc85609
fix code style problem
Li-fAngyU Mar 1, 2023
200ae0d
fix error of vander api at math.py
Li-fAngyU Mar 2, 2023
ce5a955
fix code style problem
Li-fAngyU Mar 2, 2023
e2d1f97
fix code style
Li-fAngyU Mar 2, 2023
b4bea35
code style
Li-fAngyU Mar 2, 2023
68942cb
code style
Li-fAngyU Mar 2, 2023
368adf2
fix test vander api error!
Li-fAngyU Mar 2, 2023
29b84b0
Merge branch 'PaddlePaddle:develop' into api_vander
Li-fAngyU Mar 6, 2023
eba525a
add complex test!
Li-fAngyU Mar 6, 2023
9f4ff05
add complex example code of vander.
Li-fAngyU Mar 6, 2023
9c8659b
fix sample code error
Li-fAngyU Mar 7, 2023
329d3cc
Merge branch 'api_vander' of https://github.com/Li-fAngyU/Paddle into…
Li-fAngyU Mar 16, 2023
f3fc1ba
fix vander api params type.
Li-fAngyU Mar 16, 2023
c81e532
Merge branch 'api_vander' of https://github.com/Li-fAngyU/Paddle into…
Li-fAngyU Mar 17, 2023
407893e
fix error in test_vander
Li-fAngyU Mar 20, 2023
acf59bd
Merge branch 'PaddlePaddle:develop' into api_vander
Li-fAngyU Mar 22, 2023
1f29e36
change paddle.fluid to paddle.static in test_vander file.
Li-fAngyU Mar 22, 2023
7196e17
fix params
Li-fAngyU Mar 23, 2023
03bb29a
Merge branch 'api_vander' of https://github.com/Li-fAngyU/Paddle into…
Li-fAngyU Mar 31, 2023
3507d28
update vander return describe
Li-fAngyU Mar 31, 2023
956a36b
vander examples.
Li-fAngyU Mar 31, 2023
5d748f1
code style
Li-fAngyU Mar 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@
from .tensor.math import frexp # noqa: F401
from .tensor.math import trapezoid # noqa: F401
from .tensor.math import cumulative_trapezoid # noqa: F401
from .tensor.math import vander # noqa: F401

from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
Expand Down Expand Up @@ -687,4 +688,5 @@
'trapezoid',
'cumulative_trapezoid',
'polar',
'vander',
]
100 changes: 100 additions & 0 deletions python/paddle/fluid/tests/unittests/test_vander.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.fluid import core

np.random.seed(10)


def ref_vander(x, N=None, increasing=False):
return np.vander(x, N, increasing)


class TestVanderAPI(unittest.TestCase):
# test paddle.tensor.math.vander

def setUp(self):
self.shape = [5]
self.x = np.random.uniform(-1, 1, self.shape).astype(np.float32)
self.place = (
paddle.CUDAPlace(0)
if core.is_compiled_with_cuda()
else paddle.CPUPlace()
)

def api_case(self, N=None, increasing=False):
paddle.enable_static()
out_ref = ref_vander(self.x, N, increasing)
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.shape)
out = paddle.vander(x, N, increasing)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x}, fetch_list=[out])
if N != 0:
np.testing.assert_allclose(res[0], out_ref, rtol=1e-05)
else:
np.testing.assert_allclose(res[0].size, out_ref.size, rtol=1e-05)

paddle.disable_static(self.place)
x = paddle.to_tensor(self.x)
out = paddle.vander(x, N, increasing)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=1e-05)
paddle.enable_static()

def test_api(self):
self.api_case()
N = list(range(9))
for n in N:
self.api_case(n)
self.api_case(n, increasing=True)

def test_complex(self):
paddle.disable_static(self.place)
real = np.random.rand(5)
imag = np.random.rand(5)
complex_np = real + 1j * imag
complex_paddle = paddle.complex(
paddle.to_tensor(real), paddle.to_tensor(imag)
)

def test_api_case(N, increasing=False):
for n in N:
res_np = np.vander(complex_np, n, increasing)
res_paddle = paddle.vander(complex_paddle, n, increasing)
np.testing.assert_allclose(
res_paddle.numpy(), res_np, rtol=1e-05
)

N = [0, 1, 2, 3, 4]
test_api_case(N)
test_api_case(N, increasing=True)
paddle.enable_static()

def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
self.assertRaises(TypeError, paddle.vander, 1)
x = paddle.static.data('X', [10, 12], 'int32')
self.assertRaises(ValueError, paddle.vander, x)
x1 = paddle.static.data('X1', [10], 'int32')
self.assertRaises(ValueError, paddle.vander, x1, n=-1)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@
from .math import cumulative_trapezoid # noqa: F401
from .math import sigmoid # noqa: F401
from .math import sigmoid_ # noqa: F401
from .math import vander # noqa: F401

from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
Expand Down Expand Up @@ -538,6 +539,7 @@
'polar',
'sigmoid',
'sigmoid_',
'vander',
]

# this list used in math_op_patch.py for magic_method bind
Expand Down
76 changes: 76 additions & 0 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5333,3 +5333,79 @@ def cumulative_trapezoid(y, x=None, dx=None, axis=-1, name=None):
# [3.50000000, 8. ]])
"""
return _trapezoid(y, x, dx, axis, mode='cumsum')


def vander(x, n=None, increasing=False, name=None):
"""
Generate a Vandermonde matrix.

The columns of the output matrix are powers of the input vector. Order of the powers is
determined by the increasing Boolean parameter. Specifically, when the increment is
"false", the ith output column is a step-up in the order of the elements of the input
vector to the N - i - 1 power. Such a matrix with a geometric progression in each row
is named after Alexandre-Theophile Vandermonde.

Args:
x (Tensor): The input tensor, it must be 1-D Tensor, and it's data type should be ['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'].
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'complex64', 'complex128', 如果支持复数的话,单测以及示例代码中均要体现

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

n (int): Number of columns in the output. If n is not specified, a square array is returned (n = len(x)).
increasing(bool): Order of the powers of the columns. If True, the powers increase from left to right, if False (the default) they are reversed.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
Tensor, A vandermonde matrix with shape (len(x), N). If increasing is False, the first column is :math:`x^{(N-1)}`, the second :math:`x^{(N-2)}` and so forth.
If increasing is True, the columns are :math:`x^0`, :math:`x^1`, ..., :math:`x^{(N-1)}`.

Examples:
.. code-block:: python

import paddle
x = paddle.to_tensor([1., 2., 3.], dtype="float32")
out = paddle.vander(x)
print(out.numpy())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

print输出一定要在后面加 .numpy()吗? 去除后能否正常输出

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加了.numpy() 是为了让输出更简洁一点,去除后可以正常输出。

# [[1., 1., 1.],
# [4., 2., 1.],
# [9., 3., 1.]]
out1 = paddle.vander(x,2)
print(out1.numpy())
# [[1., 1.],
# [2., 1.],
# [3., 1.]]
out2 = paddle.vander(x, increasing = True)
print(out2.numpy())
# [[1., 1., 1.],
# [1., 2., 4.],
# [1., 3., 9.]]
real = paddle.to_tensor([2., 4.])
imag = paddle.to_tensor([1., 3.])
complex = paddle.complex(real, imag)
out3 = paddle.vander(complex)
print(out3.numpy())
# [[2.+1.j, 1.+0.j],
# [4.+3.j, 1.+0.j]]
"""
check_variable_and_dtype(
x,
'x',
['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'],
'vander',
)
if x.dim() != 1:
raise ValueError(
"The input of x is expected to be a 1-D Tensor."
"But now the dims of Input(X) is %d." % x.dim()
)

if n is None:
n = x.shape[0]

if n < 0:
raise ValueError("N must be non-negative.")

res = paddle.empty([x.shape[0], n], dtype=x.dtype)

if n > 0:
res[:, 0] = paddle.to_tensor([1], dtype=x.dtype)
if n > 1:
res[:, 1:] = x[:, None]
res[:, 1:] = paddle.cumprod(res[:, 1:], dim=-1)
res = res[:, ::-1] if not increasing else res
return res