Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 4 NO.23】为 Paddle 新增 vander API #51048

Merged
merged 24 commits into from
Mar 31, 2023

Conversation

Li-fAngyU
Copy link
Contributor

@Li-fAngyU Li-fAngyU commented Mar 1, 2023

PR types

New features

PR changes

APIs

Describe

rfc文档链接:PaddlePaddle/community#386
rfc 修复链接:PaddlePaddle/community#464
中文文档:PaddlePaddle/docs#5681

为Paddle新增Vander API

@paddle-bot
Copy link

paddle-bot bot commented Mar 1, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请补充下对应的中文文档

is named after Alexandre-Theophile Vandermonde.

Args:
x (Tensor): The input tensor, it must be 1-D Tensor, and it's data type should be ['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'].
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'complex64', 'complex128', 如果支持复数的话,单测以及示例代码中均要体现

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

res[0].size, out_ref.size, rtol=1e-05
)
else:
assert res[0] is None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

52-57行,为什么cpu/gpu模式下不一样呢?

Copy link
Contributor Author

@Li-fAngyU Li-fAngyU Mar 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是因为在静态图模式下,paddle.empty 的返回值有点区别,我不太清楚是不是bug就做了点处理。

example code:

import paddle
print(paddle.__version__)
import numpy as np
N = 0
places = [paddle.CPUPlace(),paddle.CUDAPlace(0)]
paddle.enable_static()
for place in places:
    with paddle.static.program_guard(paddle.static.Program()):
        out = paddle.empty([3, N])
        exe = paddle.static.Executor(place)
        res = exe.run(fetch_list=[out])
    print('static:', place, res)
# 2.3.2
# static: Place(cpu) [None]
# static: Place(gpu:0) [array([], shape=(3, 0), dtype=float32)]

@Li-fAngyU
Copy link
Contributor Author

Li-fAngyU commented Mar 6, 2023

已补充复数的单测代码,以及示例代码。

注:因静态图下, paddle.empty 不支持创建数据类型为complex 的Tensor,因此单测中没有包括静态图下测试 paddle.vander 是否支持复数类型的代码。

example code:

import paddle
print(paddle.__version__)
places = [paddle.CPUPlace()]
paddle.enable_static()
for place in places:
    with paddle.static.program_guard(paddle.static.Program()):
        out = paddle.empty([3, 0],dtype=paddle.complex64)
        exe = paddle.static.Executor(place)
        res = exe.run(fetch_list=[out])
# 0.0.0
#       5 for place in places:
#       6     with paddle.static.program_guard(paddle.static.Program()):
# ----> 7         out = paddle.empty([3, 0],dtype=paddle.complex64)
#       8         exe = paddle.static.Executor(place)
#       9         res = exe.run(fetch_list=[out])
# TypeError: The data type of 'dtype' in empty must be ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], but received complex64.

@luotao1
Copy link
Contributor

luotao1 commented Mar 7, 2023

因静态图下,paddle.empty 不支持创建数据类型为complex 的Tensor

check_dtype(
dtype,
'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'empty',

1791行加上complex64和complex128就可以了,可以单独提一个PR修一下。

if N < 0:
raise ValueError("N must be non-negative.")

res = paddle.empty([x.shape[0], N], dtype=x.dtype)
Copy link
Contributor

@luotao1 luotao1 Mar 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是因为在静态图模式下,paddle.empty 的返回值有点区别,我不太清楚是不是bug就做了点处理。#51048 (comment)

我能在develop分支复现,也看到你的issue了,但什么情况下N会等于0呢?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因torch,和numpy的vander都能支持N=0这一极端条件(torch的话是支持空tensor,以前的paddle版本是不支持空tensor,然后我就在单测中多考虑这一情况了)理论上N应该是大于0的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

>>> import torch
>>> x = torch.tensor([1, 2, 3, 5])
>>> torch.vander(x, N=0)
tensor([], size=(4, 0), dtype=torch.int64)
>>> import paddle
>>> paddle.to_tensor([4,0])
Tensor(shape=[2], dtype=int64, place=Place(cpu), stop_gradient=True,
       [4, 0])

Copy link
Contributor Author

@Li-fAngyU Li-fAngyU Mar 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的。

不是很理解 "N=0的时候,能否用paddle.to_tensor?"具体是什么意思, 是指能否用paddle.to_tensor创建空Tensor吗?
(如果用paddle.to_tensor创建空Tensor a的话,a的shape 好像就没法构造成(4,0)的形式了)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不好意思,paddle.to_tensor确实无法构造,看下paddle.rand系列

>>> import paddle
>>> paddle.randint(0, 1, [4, 0])
Tensor(shape=[4, 0], dtype=int64, place=Place(cpu), stop_gradient=True,
       [[],
        [],
        [],
        []])

Copy link
Contributor Author

@Li-fAngyU Li-fAngyU Mar 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

抱歉,paddle.to_tensor 是可以构建空Tensor的,shape的话直接用reshape就可以了。

>>> a = paddle.to_tensor([]).reshape([3,0])
>>> a
Tensor(shape=[3, 0], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[],
        [],
        []])

所以目前要先在N=0的时候,用paddle.to_tensor去替代paddle.empty吗?

Copy link
Contributor

@cuicheng01 cuicheng01 Mar 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Li-fAngyU paddle.to_tensor在静态图下不能构造含有0 dim的tensor,建议还是使用paddle.empty

import paddle
print(paddle.__version__)
import numpy as np
N = 0
#places = [paddle.CPUPlace(),paddle.CUDAPlace(0)]
places = [paddle.CPUPlace()]
paddle.enable_static()
for place in places:
    with paddle.static.program_guard(paddle.static.Program()):
        #out = paddle.empty([3, N])
        out =  paddle.to_tensor([]).reshape([3,N])
        exe = paddle.static.Executor(place)
        res = exe.run(fetch_list=[out])
        print('static:', place, res)
######
λ ffd50c7f717b /Paddle {develop} python3.7 test.py
grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
0.0.0
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    out =  paddle.to_tensor([]).reshape([3,N])
  File "/usr/local/python3.7.0/lib/python3.7/site-packages/paddle/tensor/manipulation.py", line 3588, in reshape
    attrs["shape"] = get_attr_shape(shape)
  File "/usr/local/python3.7.0/lib/python3.7/site-packages/paddle/tensor/manipulation.py", line 3571, in get_attr_shape
    % (dim_idx, len(x.shape))
AssertionError: The index of 0 in `shape` must be less than the input tensor X's dimensions. But received shape[1] = 0, X's dimensions = 1.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cuicheng01 @zhouwei25 讨论,应该是静态图CPUPlace的问题,换成 paddle.randint(0, 1, [4, 0]),输出也是None

static: Place(cpu) [None]

@Li-fAngyU
Copy link
Contributor Author

@cuicheng01

@Li-fAngyU Li-fAngyU reopened this Mar 17, 2023
@Li-fAngyU
Copy link
Contributor Author

@luotao1 请问PR-CR-Build里报错 make[2]: *** No rule to make target '../python/paddle/fluid/tests/unittests/test_vander.py', needed by 'python/build/.timestamp'. Stop. 该怎么解决呢?

@luotao1
Copy link
Contributor

luotao1 commented Mar 20, 2023

该怎么解决呢

已集中整理到 #51195 (comment)

@luotao1
Copy link
Contributor

luotao1 commented Mar 21, 2023

@Li-fAngyU paddle-build失败的问题已修复,可以rerun一下

@Li-fAngyU
Copy link
Contributor Author

收到

luotao1
luotao1 previously approved these changes Mar 24, 2023
Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

jeff41404
jeff41404 previously approved these changes Mar 27, 2023
Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM


x = paddle.to_tensor([1., 2., 3.], dtype="float32")
out = paddle.vander(x)
print(out.numpy())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

print输出一定要在后面加 .numpy()吗? 去除后能否正常输出

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加了.numpy() 是为了让输出更简洁一点,去除后可以正常输出。

increasing(bool): Order of the powers of the columns. If True, the powers increase from left to right, if False (the default) they are reversed.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
Vandermonde matrix with shape: (len(x), N). If increasing is False, the first column is x^(N-1), the second x^(N-2) and so forth. If increasing is True, the columns are x^0, x^1, ..., x^(N-1).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • x^(N-1) 这类带次方的数字写成公式叭, 即写成 :math:xxxxx
  • 返回值请按规范写,需要先描述 API 返回值的类型,然后再描述 API 的返回值及其含义。
  • shape后不加冒号。returns的描述里避免出现冒号,因为冒号前面的会解析为 Returns type

image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

收到

@paddle-bot
Copy link

paddle-bot bot commented Mar 31, 2023

很抱歉,经过我们的反复讨论,你的PR暂未达到合入标准,请阅读飞桨原生算子开发规范,你可以重新提交新的PR,我们先将此PR关闭,感谢你的贡献。
Sorry to inform you that through our discussion, your PR fails to meet the merging standard (Reference: Paddle Custom Operator Design Doc). You can also submit an new one. Thank you.

@Li-fAngyU Li-fAngyU reopened this Mar 31, 2023
@Li-fAngyU Li-fAngyU dismissed stale reviews from jeff41404 and luotao1 via 956a36b March 31, 2023 01:30
Copy link
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1 luotao1 merged commit 7b53923 into PaddlePaddle:develop Mar 31, 2023
@Li-fAngyU Li-fAngyU deleted the api_vander branch July 6, 2023 02:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants