Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[zero-dim] Support 0-d for kthvalue and mode #49340

Merged

Conversation

zoooo0820
Copy link
Contributor

PR types

New features

PR changes

OPs

Describe

This PR is same with #49122 and fix the bug on Windows.
The original PR will met unittest error on Windows and already been reverted. This is because the flag is_runtime is different with Linux.

@paddle-bot
Copy link

paddle-bot bot commented Dec 26, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

zhwesky2010
zhwesky2010 previously approved these changes Dec 29, 2022
Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kthvalue和mode的axis应该限制None,0,-1。然后kthvalue应该还有k值的限制为1
输入其他值应该报错,可以和竞品对比下

@zoooo0820 zoooo0820 force-pushed the support_0D_for_kthvalue_and_mode branch from 2ca67bf to e08ba7f Compare January 3, 2023 11:55
@zoooo0820
Copy link
Contributor Author

kthvalue和mode的axis应该限制None,0,-1。然后kthvalue应该还有k值的限制为1 输入其他值应该报错,可以和竞品对比下

kthvalue的部分:

  • k值检查,已在kthvalue_kernel.cc/cu中增加检查k==1,并增加了单测测试
  • axis合法性检查, 已要求axis<=dim_size(此前是axis<dim_size),这样可以在0-d 输入时接收axis=0作为输入。其他维度下axis=dim_size时会有eigen/common.h抛出错误
  • None,-1的接收,目前默认定义与竞品一致(None=-1,默认None)

mode的部分:

  • axis合法性检查, 已要求axis<=dim_size(此前是axis<dim_size),这样可以在0-d 输入时接收axis=0作为输入。其他维度下axis=dim_size时会有eigen/common.h抛出错误
  • None, -1的接收:目前默认与竞品一致,默认-1,不可接收None


out = paddle.kthvalue(x, 1)
out[0].backward()
self.assertEqual(out[0].shape, [])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

动态图还需要测下反向grad的shape,把具体值也测下吧,因为比较固定

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 merged commit 292738f into PaddlePaddle:develop Jan 6, 2023
@zoooo0820 zoooo0820 deleted the support_0D_for_kthvalue_and_mode branch January 6, 2023 08:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants